import asyncio
import base64
import os
import re
from dataclasses import asdict, dataclass
from math import ceil

import jsonlines
import requests
import tiktoken
import yaml
from FlagEmbedding import BGEM3FlagModel
from jinja2 import Environment, Template
from oaib import Auto
from openai import OpenAI
from PIL import Image
from torch import Tensor, cosine_similarity

from src.model_utils import get_text_embedding
from src.utils import get_json_from_response, pexists, pjoin, print, tenacity

ENCODING = tiktoken.encoding_for_model("gpt-4o")


def run_async(coroutine):
    """
    Run an asynchronous coroutine from synchronous code.

    Args:
        coroutine: The coroutine to run.

    Returns:
        The result of the coroutine.
    """
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop.run_until_complete(coroutine)
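
# Usage sketch: run_async bridges the async oaib batch client into
# synchronous call sites, e.g. `results = run_async(self.oai_batch.run())`
# below. Note that run_until_complete raises RuntimeError if the fetched
# loop is already running (e.g. inside Jupyter).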


def calc_image_tokens(images: list[str]):
    """
    Estimate the number of input tokens for a list of image files.

    Approximates OpenAI's tile-based vision pricing: each image is
    downscaled so its longer side is at most 1024px, then charged a base
    cost of 85 tokens plus 170 tokens per 512x512 tile.
    """
    tokens = 0
    for image in images:
        with open(image, "rb") as f:
            width, height = Image.open(f).size
        if width > 1024 or height > 1024:
            if width > height:
                height = int(height * 1024 / width)
                width = 1024
            else:
                width = int(width * 1024 / height)
                height = 1024
        h = ceil(height / 512)
        w = ceil(width / 512)
        tokens += 85 + 170 * h * w
    return tokens
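
# Worked example: a 2048x1024 image is downscaled to 1024x512, which spans
# ceil(1024/512) * ceil(512/512) = 2 tiles, costing 85 + 170 * 2 = 425 tokens.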


class LLM:
    """
    A wrapper class to interact with a language model.
    """

    def __init__(
        self,
        model: str = "gpt-4o-2024-08-06",
        api_base: str = None,
        use_openai: bool = True,
        use_batch: bool = False,
    ) -> None:
        """
        Initialize the LLM.

        Args:
            model (str): The model name.
            api_base (str): The base URL for the API.
            use_openai (bool): Whether to use the OpenAI client.
            use_batch (bool): Whether to use OpenAI's Batch API,
                which is single-threaded only.
        """
        if use_batch:
            assert use_openai, "use_batch must be used with use_openai"
        if "OPENAI_API_KEY" not in os.environ:
            print("Warning: OPENAI_API_KEY not found in environment")
        else:
            if use_openai:
                self.client = OpenAI(base_url=api_base)
            if use_batch:
                self.oai_batch = Auto(loglevel=0)
        self.model = model
        self.api_base = api_base
        self._use_openai = use_openai
        self._use_batch = use_batch

    @tenacity
    def __call__(
        self,
        content: str,
        images: list[str] = None,
        system_message: str = None,
        history: list = None,
        delay_batch: bool = False,
        return_json: bool = False,
        return_message: bool = False,
    ) -> str | dict | list:
        """
        Call the language model with a prompt and optional images.

        Args:
            content (str): The prompt content.
            images (list[str]): A list of image file paths.
            system_message (str): The system message.
            history (list): The conversation history.
            delay_batch (bool): Whether to queue the request for later batch
                execution instead of returning a response now.
            return_json (bool): Whether to parse the response as JSON.
            return_message (bool): Whether to also return the message list.

        Returns:
            str | dict | list: The response from the model.
        """
        # Prompts that begin with "You are" carry their system message on
        # the first line; split it off from the user content.
        if content.startswith("You are") and "\n" in content:
            system_message, content = content.split("\n", 1)
        if history is None:
            history = []
        if isinstance(images, str):
            images = [images]
        system, message = self.format_message(content, images, system_message)
        if self._use_batch:
            result = run_async(
                self._run_batch(system + history + message, delay_batch)
            )
            if delay_batch:
                return
            try:
                response = result.to_dict()["result"][0]["choices"][0]["message"][
                    "content"
                ]
            except Exception as e:
                print("Failed to get response from batch")
                raise e
        elif self._use_openai:
            completion = self.client.chat.completions.create(
                model=self.model, messages=system + history + message
            )
            response = completion.choices[0].message.content
        else:
            # Fall back to a plain HTTP endpoint that accepts a system prompt,
            # a user prompt, and base64 data URLs for the images.
            response = requests.post(
                self.api_base,
                json={
                    "system": system_message,
                    "prompt": content,
                    "image": [
                        i["image_url"]["url"]
                        for i in message[-1]["content"]
                        if i["type"] == "image_url"
                    ],
                },
            )
            response.raise_for_status()
            response = response.text
        message.append({"role": "assistant", "content": response})
        if return_json:
            response = get_json_from_response(response)
        if return_message:
            response = (response, message)
        return response
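
    # Typical calls (sketch):
    #   answer = llm("Summarize this page", images=["page.png"])
    #   data = llm("Return the outline as JSON", return_json=True)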

    def __repr__(self) -> str:
        return f"LLM(model={self.model}, api_base={self.api_base})"

    async def _run_batch(self, messages: list, delay_batch: bool = False):
        """
        Queue a request on the batch client; run the batch immediately
        unless delay_batch is set.
        """
        await self.oai_batch.add(
            "chat.completions.create",
            model=self.model,
            messages=messages,
        )
        if delay_batch:
            return
        return await self.oai_batch.run()

    def format_message(
        self,
        content: str,
        images: list[str] = None,
        system_message: str = None,
    ):
        """
        Format a prompt into OpenAI-style system and user messages.
        """
        if system_message is None:
            system_message = "You are a helpful assistant"
        system = [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_message}],
            }
        ]
        message = [{"role": "user", "content": [{"type": "text", "text": content}]}]
        if images is not None:
            if not isinstance(images, list):
                images = [images]
            for image in images:
                with open(image, "rb") as f:
                    # Images are inlined as base64 data URLs; the jpeg MIME
                    # type is assumed regardless of the actual file format.
                    message[0]["content"].append(
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{base64.b64encode(f.read()).decode('utf-8')}"
                            },
                        }
                    )
        return system, message
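
    # The returned structure, e.g. for format_message("hi", ["img.jpg"]):
    #   system  -> [{"role": "system",
    #                "content": [{"type": "text", "text": "You are a helpful assistant"}]}]
    #   message -> [{"role": "user",
    #                "content": [{"type": "text", "text": "hi"},
    #                            {"type": "image_url",
    #                             "image_url": {"url": "data:image/jpeg;base64,..."}}]}]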

    def get_batch_result(self):
        """
        Get responses from delayed batch calls.
        """
        results = run_async(self.oai_batch.run())
        return [
            r["choices"][0]["message"]["content"]
            for r in results.to_dict()["result"].values()
        ]
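
    # Deferred-batch flow (sketch, using this module's own names):
    #   llm = LLM(use_batch=True)
    #   for prompt in prompts:
    #       llm(prompt, delay_batch=True)   # queue only, returns None
    #   responses = llm.get_batch_result()  # run the batch, collect replies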

    def clear_history(self):
        # LLM itself never reads this attribute; per-turn conversation
        # history is tracked on Role (see Role.history below).
        self.history = []


@dataclass
class Turn:
    """
    A class to represent a turn in a conversation.
    """

    id: int
    prompt: str
    response: str
    message: list
    images: list[str] = None
    input_tokens: int = 0
    output_tokens: int = 0
    embedding: Tensor = None

    def to_dict(self):
        # The embedding tensor is not JSON-serializable, so it is dropped.
        return {k: v for k, v in asdict(self).items() if k != "embedding"}

    def calc_token(self):
        """
        Calculate the number of tokens for the turn.
        """
        if self.images is not None:
            self.input_tokens += calc_image_tokens(self.images)
        self.input_tokens += len(ENCODING.encode(self.prompt))
        self.output_tokens = len(ENCODING.encode(self.response))

    def __eq__(self, other):
        # Identity comparison, so membership tests in Role.get_history match
        # this exact turn object rather than turns with equal field values.
        return self is other


class Role:
    """
    An agent, defined by its instruction template and model.
    """

    def __init__(
        self,
        name: str,
        env: Environment,
        record_cost: bool,
        llm: LLM = None,
        config: dict = None,
        text_model: BGEM3FlagModel = None,
    ):
        """
        Initialize the Role.

        Args:
            name (str): The name of the role.
            env (Environment): The Jinja2 environment.
            record_cost (bool): Whether to record the token cost.
            llm (LLM): The language model.
            config (dict): The configuration.
            text_model (BGEM3FlagModel): The text embedding model.
        """
        self.name = name
        if config is None:
            with open(f"roles/{name}.yaml", "r") as f:
                config = yaml.safe_load(f)
        if llm is None:
            # Resolve e.g. use_model: "language" to the module-level
            # `language_model` defined at the bottom of this file.
            llm = globals()[config["use_model"] + "_model"]
        self.llm = llm
        self.model = llm.model
        self.record_cost = record_cost
        self.text_model = text_model
        self.return_json = config["return_json"]
        self.system_message = config["system_prompt"]
        self.prompt_args = set(config["jinja_args"])
        self.template = env.from_string(config["template"])
        self.retry_template = Template(
            """The previous output is invalid. Please carefully analyze the traceback and feedback information, and correct the errors in your previous output.
feedback:
{{feedback}}
traceback:
{{traceback}}
Give your corrected output in the same format, without repeating the previous output:
"""
        )
        self.system_tokens = len(ENCODING.encode(self.system_message))
        self.input_tokens = 0
        self.output_tokens = 0
        self.history: list[Turn] = []
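
    # A minimal roles/<name>.yaml sketch (field values are illustrative):
    #   use_model: language          # resolved to the `language_model` global
    #   return_json: false
    #   system_prompt: "You are ..."
    #   jinja_args: [outline]
    #   template: |
    #     Refine the following outline: {{outline}}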

    def calc_cost(self, turns: list[Turn]):
        """
        Accumulate the token cost of a list of turns.
        """
        for turn in turns:
            self.input_tokens += turn.input_tokens
            self.output_tokens += turn.output_tokens
        self.input_tokens += self.system_tokens
        # Account for the fixed 3-token priming overhead of each chat reply.
        self.output_tokens += 3

    def get_history(self, similar: int, recent: int, prompt: str):
        """
        Collect conversation history: the `recent` latest turns, topped up
        with the turns most similar to `prompt` until `similar + recent`
        turns are gathered, returned in chronological order.
        """
        history = self.history[-recent:] if recent > 0 else []
        if similar > 0:
            embedding = get_text_embedding(prompt, self.text_model)
            # Rank all past turns by similarity to the prompt, most similar
            # first, and fill the remaining slots with turns not yet included.
            candidates = sorted(
                self.history,
                key=lambda x: cosine_similarity(embedding, x.embedding),
                reverse=True,
            )
            for turn in candidates:
                if len(history) >= similar + recent:
                    break
                if turn not in history:
                    history.append(turn)
        history.sort(key=lambda x: x.id)
        return history
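
    # Example: with recent=2 and similar=1, the two latest turns are kept,
    # the single most similar remaining turn is added, and the result is
    # re-ordered by turn id before being flattened into messages.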

    def save_history(self, output_dir: str):
        """
        Save the conversation history to a JSONL file.
        """
        history_file = pjoin(output_dir, f"{self.name}.jsonl")
        # Do not overwrite an existing history file with an empty one.
        if pexists(history_file) and len(self.history) == 0:
            return
        with jsonlines.open(history_file, "w") as writer:
            writer.write(
                {
                    "input_tokens": self.input_tokens,
                    "output_tokens": self.output_tokens,
                }
            )
            for turn in self.history:
                writer.write(turn.to_dict())

    def retry(self, feedback: str, traceback: str, error_idx: int):
        """
        Retry a failed turn with feedback and traceback.
        """
        assert error_idx > 0, "error_idx must be greater than 0"
        prompt = self.retry_template.render(feedback=feedback, traceback=traceback)
        history = []
        for turn in self.history[-error_idx:]:
            history.extend(turn.message)
        response, message = self.llm(
            prompt,
            history=history,
            return_message=True,
        )
        turn = Turn(
            id=len(self.history),
            prompt=prompt,
            response=response,
            message=message,
        )
        return self.__post_process__(response, self.history[-error_idx:], turn)

    def __repr__(self) -> str:
        return f"Role(name={self.name}, model={self.model})"

    def __call__(
        self,
        images: list[str] = None,
        recent: int = 0,
        similar: int = 0,
        **jinja_args,
    ):
        """
        Call the agent with prompt arguments.

        Args:
            images (list[str]): A list of image file paths.
            recent (int): The number of recent turns to include.
            similar (int): The number of similar turns to include.
            **jinja_args: Arguments for the Jinja2 template.

        Returns:
            The response from the role.
        """
        if isinstance(images, str):
            images = [images]
        assert self.prompt_args == set(jinja_args.keys()), (
            f"Invalid arguments: expected {self.prompt_args}, got {set(jinja_args.keys())}"
        )
        prompt = self.template.render(**jinja_args)
        history = self.get_history(similar, recent, prompt)
        history_msg = []
        for turn in history:
            history_msg.extend(turn.message)

        response, message = self.llm(
            prompt,
            system_message=self.system_message,
            history=history_msg,
            images=images,
            return_message=True,
        )
        turn = Turn(
            id=len(self.history),
            prompt=prompt,
            response=response,
            message=message,
            images=images,
        )
        return self.__post_process__(response, history, turn, similar)
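
    # Usage sketch (assumes a hypothetical role config with jinja_args: [outline]):
    #   editor = Role("editor", env, record_cost=True, llm=gpt4o)
    #   result = editor(outline=outline_text, recent=2)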

    def __post_process__(
        self, response: str, history: list[Turn], turn: Turn, similar: int = 0
    ):
        """
        Post-process the response: record the turn, its embedding and token
        cost, and optionally parse the response as JSON.
        """
        self.history.append(turn)
        if similar > 0:
            # Only embed the prompt when similarity-based retrieval is in use.
            turn.embedding = get_text_embedding(turn.prompt, self.text_model)
        if self.record_cost:
            turn.calc_token()
            self.calc_cost(history + [turn])
        if self.return_json:
            response = get_json_from_response(response)
        return response


def get_simple_modelname(llms: list[LLM]):
    """
    Get an abbreviated name from a list of LLMs.
    """
    if isinstance(llms, LLM):
        llms = [llms]
    names = []
    for llm in llms:
        # Strip version/date suffixes such as "-2024-08-06" or "-72B-...".
        match = re.search(r"^(.*?)-\d{2}", llm.model)
        names.append(match.group(1) if match else llm.model)
    return "+".join(names)


# Default model instances; the api_base addresses for the self-hosted models
# are deployment-specific.
gpt4o = LLM(model="gpt-4o-2024-08-06", use_batch=True)
gpt4omini = LLM(model="gpt-4o-mini-2024-07-18", use_batch=True)
qwen2_5 = LLM(
    model="Qwen2.5-72B-Instruct-GPTQ-Int4", api_base="http://124.16.138.143:7812/v1"
)
qwen_vl = LLM(model="Qwen2-VL-72B-Instruct", api_base="http://124.16.138.144:7999/v1")
qwen_coder = LLM(
    model="Qwen2.5-Coder-32B-Instruct", api_base="http://127.0.0.1:8008/v1"
)
intern_vl = LLM(model="InternVL2_5-78B", api_base="http://124.16.138.144:8009/v1")

# Globals resolved by Role.__init__ via `use_model` ("language" or "vision").
language_model = gpt4o
vision_model = gpt4o

if __name__ == "__main__":
    gpt4o = LLM(model="gpt-4o-2024-08-06")
    print(gpt4o("who r u"))