Make ActionRunner accept custom LLMs (#13)
Files changed:
- autoagents/agents/search.py  +5  -12
- autoagents/models/custom.py  +33 -0
- autoagents/spaces/app.py     +16 -11
- autoagents/tools/tools.py    +2  -7
- autoagents/utils/utils.py    +0  -8
- test.py                      +11 -6
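In short: ActionRunner no longer builds a ChatOpenAI internally from an OpenAICred and a model_name. It (together with CustomOutputParser and rewrite_search_query) now accepts any LangChain BaseLanguageModel, the OpenAICred dataclass is deleted, and a new autoagents/models/custom.py provides a minimal CustomLLM wrapper around a locally hosted chat-completions endpoint. The call sites in the Streamlit app and in test.py construct the model themselves and pass it in.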
autoagents/agents/search.py
CHANGED

@@ -14,11 +14,10 @@ from langchain.schema import AgentAction, AgentFinish
 from langchain.callbacks import get_openai_callback
 from langchain.callbacks.base import AsyncCallbackHandler
 from langchain.callbacks.manager import AsyncCallbackManager
-
+from langchain.base_language import BaseLanguageModel
 
 from autoagents.tools.tools import search_tool, note_tool, rewrite_search_query
 from autoagents.utils.logger import InteractionsLogger
-from autoagents.utils.utils import OpenAICred
 
 
 # Set up the base template

@@ -124,9 +123,8 @@ class CustomOutputParser(AgentOutputParser):
     class Config:
         arbitrary_types_allowed = True
     ialogger: InteractionsLogger
-
+    llm: BaseLanguageModel
     new_action_input: Optional[str]
-
     action_history = defaultdict(set)
 
     def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:

@@ -154,7 +152,7 @@ class CustomOutputParser(AgentOutputParser):
         if action_input in self.action_history[action]:
             new_action_input = rewrite_search_query(action_input,
                                                     self.action_history[action],
-
+                                                    self.llm)
             self.ialogger.add_message({"query_rewrite": True})
             self.new_action_input = new_action_input
             self.action_history[action].add(new_action_input)

@@ -168,8 +166,7 @@ class CustomOutputParser(AgentOutputParser):
 class ActionRunner:
     def __init__(self,
                  outputq,
-
-                 model_name: str,
+                 llm: BaseLanguageModel,
                  persist_logs: bool = False):
         self.ialogger = InteractionsLogger(name=f"{uuid.uuid4().hex[:6]}", persist=persist_logs)
         tools = [search_tool, note_tool]

@@ -179,7 +176,7 @@ class ActionRunner:
                                       input_variables=["input", "intermediate_steps"],
                                       ialogger=self.ialogger)
 
-        output_parser = CustomOutputParser(ialogger=self.ialogger,
+        output_parser = CustomOutputParser(ialogger=self.ialogger, llm=llm)
 
         class MyCustomHandler(AsyncCallbackHandler):
             def __init__(self):

@@ -225,10 +222,6 @@ class ActionRunner:
 
         handler = MyCustomHandler()
 
-        llm = ChatOpenAI(openai_api_key=cred.key,
-                         openai_organization=cred.org,
-                         temperature=0,
-                         model_name=model_name)
         llm_chain = LLMChain(llm=llm, prompt=prompt, callbacks=[handler])
         tool_names = [tool.name for tool in tools]
         for tool in tools:
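With the new signature the caller injects the model instead of passing credentials; since BaseLanguageModel is the common base class of LangChain chat models and plain LLMs, both ChatOpenAI and the CustomLLM added below satisfy the type hint. A minimal construction sketch (the question string and the surrounding driver are illustrative, not from this diff):

    import asyncio

    from langchain.chat_models import ChatOpenAI

    from autoagents.agents.search import ActionRunner


    async def main():
        outputq = asyncio.Queue()
        llm = ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo")  # any BaseLanguageModel works
        runner = ActionRunner(outputq, llm=llm)
        # run() streams intermediate steps onto outputq; see test.py below for a fuller consumer loop.
        await runner.run("What is AutoGPT?", outputq)

    asyncio.run(main())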
autoagents/models/custom.py
ADDED

@@ -0,0 +1,33 @@
+import requests
+
+from langchain.llms.base import LLM
+
+
+class CustomLLM(LLM):
+    @property
+    def _llm_type(self) -> str:
+        return "custom"
+
+    def _call(self, prompt: str, stop=None) -> str:
+        r = requests.post(
+            "http://localhost:8000/v1/chat/completions",
+            json={
+                "model": "283-vicuna-7b",
+                "messages": [{"role": "user", "content": prompt}],
+                "stop": stop
+            },
+        )
+        result = r.json()
+        return result["choices"][0]["message"]["content"]
+
+    async def _acall(self, prompt: str, stop=None) -> str:
+        r = requests.post(
+            "http://localhost:8000/v1/chat/completions",
+            json={
+                "model": "283-vicuna-7b",
+                "messages": [{"role": "user", "content": prompt}],
+                "stop": stop
+            },
+        )
+        result = r.json()
+        return result["choices"][0]["message"]["content"]
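CustomLLM assumes an OpenAI-compatible chat-completions server reachable at http://localhost:8000 serving a model registered as "283-vicuna-7b"; both values are hard-coded above, and note that _acall still issues a blocking requests.post rather than using an async HTTP client. A hedged usage sketch, wiring the wrapper into the runner (the question text is illustrative):

    import asyncio

    from autoagents.agents.search import ActionRunner
    from autoagents.models.custom import CustomLLM


    async def main():
        outputq = asyncio.Queue()
        llm = CustomLLM()  # talks to the local endpoint hard-coded in custom.py
        runner = ActionRunner(outputq, llm=llm)
        await runner.run("Who are the founders of OpenAI?", outputq)

    asyncio.run(main())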
autoagents/spaces/app.py
CHANGED

@@ -9,7 +9,9 @@ import openai
 
 from autoagents.utils.constants import MAIN_HEADER, MAIN_CAPTION, SAMPLE_QUESTIONS
 from autoagents.agents.search import ActionRunner
-
+
+from langchain.chat_models import ChatOpenAI
+
 
 async def run():
     output_acc = ""

@@ -44,10 +46,9 @@
 
     # Ask the user to enter their OpenAI API key
    if (api_key := st.sidebar.text_input("OpenAI api-key", type="password")):
-
+        api_org = None
     else:
-
-                              os.getenv("OPENAI_API_ORG"))
+        api_key, api_org = os.getenv("OPENAI_API_KEY"), os.getenv("OPENAI_API_ORG")
     with st.sidebar:
         model_dict = {
             "gpt-3.5-turbo": "GPT-3.5-turbo",

@@ -67,18 +68,22 @@
     for q in SAMPLE_QUESTIONS:
         st.markdown(f"*{q}*")
 
-    if not
+    if not api_key:
         st.warning(
             "API key required to try this app. The API key is not stored in any form. [This](https://help.openai.com/en/articles/4936850-where-do-i-find-my-secret-api-key) might help."
         )
+    elif api_org and st.session_state.model_name == "gpt-4":
+        st.warning(
+            "The free API key does not support GPT-4. Please switch to GPT-3.5-turbo or input your own API key."
+        )
     else:
         outputq = asyncio.Queue()
-        runner = ActionRunner(
-
-
-
-
-
+        runner = ActionRunner(outputq,
+                              ChatOpenAI(openai_api_key=api_key,
+                                         openai_organization=api_org,
+                                         temperature=0,
+                                         model_name=st.session_state.model_name),
+                              persist_logs=True)  # log to HF-dataset
 
     async def cleanup(e):
         st.error(e)
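The sidebar now resolves credentials in two steps: a key typed by the user takes precedence and clears api_org (presumably so personal keys are never attributed to the Space's organization), otherwise the app falls back to its own OPENAI_API_KEY / OPENAI_API_ORG environment variables; when that shared, org-scoped fallback key is active, selecting gpt-4 only produces a warning instead of constructing a runner.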
autoagents/tools/tools.py
CHANGED

@@ -3,9 +3,7 @@ import os
 from duckpy import Client
 from langchain import PromptTemplate, OpenAI, LLMChain
 from langchain.agents import Tool
-from langchain.
-
-from autoagents.utils.utils import OpenAICred
+from langchain.base_language import BaseLanguageModel
 
 
 MAX_SEARCH_RESULTS = 20  # Number of search results to observe at a time

@@ -55,15 +53,12 @@ note_tool = Tool(name="Notepad",
                  description=notepad_description)
 
 
-def rewrite_search_query(q: str, search_history,
+def rewrite_search_query(q: str, search_history, llm: BaseLanguageModel) -> str:
     history_string = '\n'.join(search_history)
     template ="""We are using the Search tool.
                  # Previous queries:
                  {history_string}. \n\n Rewrite query {action_input} to be
                  different from the previous ones."""
-    llm = ChatOpenAI(temperature=0,
-                     openai_api_key=cred.key,
-                     openai_organization=cred.org)
     prompt = PromptTemplate(template=template,
                             input_variables=["action_input", "history_string"])
     llm_chain = LLMChain(prompt=prompt, llm=llm)
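rewrite_search_query no longer constructs its own ChatOpenAI from an OpenAICred; the caller supplies the model. A small sketch of calling it directly (the query and history values are made up for illustration):

    from langchain.chat_models import ChatOpenAI

    from autoagents.tools.tools import rewrite_search_query

    llm = ChatOpenAI(temperature=0)  # or CustomLLM(), or any other BaseLanguageModel
    previous = {"autogpt github repo", "autogpt paper"}
    new_query = rewrite_search_query("autogpt github repo", previous, llm)
    print(new_query)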
autoagents/utils/utils.py
DELETED

@@ -1,8 +0,0 @@
-from dataclasses import dataclass
-from typing import Optional
-
-
-@dataclass
-class OpenAICred:
-    key: str
-    org: Optional[str]
test.py
CHANGED

@@ -1,17 +1,22 @@
 import os
 import asyncio
-
-from langchain.callbacks import get_openai_callback
+
 from pprint import pprint
-import pdb
 from ast import literal_eval
 from multiprocessing import Pool, TimeoutError
 
+from autoagents.agents.search import ActionRunner
+from langchain.callbacks import get_openai_callback
+from langchain.chat_models import ChatOpenAI
+
+
 async def work(user_input):
     outputq = asyncio.Queue()
-
-
-
+    llm = ChatOpenAI(openai_api_key=os.getenv("OPENAI_API_KEY"),
+                     openai_organization=os.getenv("OPENAI_API_ORG"),
+                     temperature=0,
+                     model_name="gpt-3.5-turbo")
+    runner = ActionRunner(outputq, llm=llm)
     task = asyncio.create_task(runner.run(user_input, outputq))
 
     while True: