# NOTE(review): removed non-code paste residue ("Spaces:" / "Runtime error")
# left over from the page this file was exported from; it was not Python.
| import json | |
| import re | |
| from copy import deepcopy | |
| from typing import List, Union, Optional, Dict, Any, Tuple | |
| from fastapi import HTTPException | |
| from loguru import logger | |
| from openai.types.chat import ( | |
| ChatCompletionMessageParam, | |
| ChatCompletionUserMessageParam, | |
| ChatCompletionAssistantMessageParam, | |
| ) | |
| from transformers import PreTrainedTokenizer | |
| from api.generation.utils import parse_messages | |
| from api.utils.protocol import Role | |
| TOOL_DESC = """{name_for_model}: Call this tool to interact with the {name_for_human} API. What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters}""" | |
| REACT_INSTRUCTION = """Answer the following questions as best you can. You have access to the following APIs: | |
| {tools_text} | |
| Use the following format: | |
| Question: the input question you must answer | |
| Thought: you should always think about what to do | |
| Action: the action to take, should be one of [{tools_name_text}] | |
| Action Input: the input to the action | |
| Observation: the result of the action | |
| (this Thought/Action/Action Input/Observation can be repeated zero or more times) | |
| Thought: I now know the final answer | |
| Final Answer: the final answer to the original input question | |
| Begin!""" | |
| _TEXT_COMPLETION_CMD = object() | |
| def build_qwen_chat_input( | |
| tokenizer: PreTrainedTokenizer, | |
| messages: List[ChatCompletionMessageParam], | |
| context_len: int = 8192, | |
| max_new_tokens: int = 256, | |
| functions: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, | |
| tools: Optional[List[Dict[str, Any]]] = None, | |
| ) -> List[int]: | |
| """ | |
| Builds the input tokens for Qwen chat generation. | |
| Refs: | |
| https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/qwen_generation_utils.py | |
| Args: | |
| tokenizer: The tokenizer used to encode the input tokens. | |
| messages: The list of chat messages. | |
| context_len: The maximum length of the context. | |
| max_new_tokens: The maximum number of new tokens to add. | |
| functions: Optional dictionary or list of dictionaries representing the functions. | |
| tools: Optional list of dictionaries representing the tools. | |
| Returns: | |
| The list of input tokens. | |
| """ | |
| query, history = process_qwen_messages(messages, functions, tools) | |
| if query is _TEXT_COMPLETION_CMD: | |
| return build_last_message_input(tokenizer, history) | |
| messages = [] | |
| for q, r in history: | |
| messages.extend( | |
| [ | |
| ChatCompletionUserMessageParam(role="user", content=q), | |
| ChatCompletionAssistantMessageParam(role="assistant", content=r) | |
| ] | |
| ) | |
| messages.append(ChatCompletionUserMessageParam(role="user", content=query)) | |
| max_input_tokens = context_len - max_new_tokens | |
| system, rounds = parse_messages(messages) | |
| system = f"You are a helpful assistant.{system}" | |
| im_start_tokens, im_end_tokens = [tokenizer.im_start_id], [tokenizer.im_end_id] | |
| nl_tokens = tokenizer.encode("\n") | |
| def _tokenize_str(role, content): | |
| return tokenizer.encode( | |
| role, allowed_special=set() | |
| ) + nl_tokens + tokenizer.encode(content, allowed_special=set()) | |
| system_tokens_part = _tokenize_str("system", system) | |
| system_tokens = im_start_tokens + system_tokens_part + im_end_tokens | |
| max_history_tokens = max_input_tokens - len(system_tokens) | |
| history_tokens = [] | |
| for r in rounds[::-1]: | |
| round_tokens = [] | |
| for message in r: | |
| if round_tokens: | |
| round_tokens += nl_tokens | |
| if message["role"] == Role.USER: | |
| content_tokens = im_start_tokens + _tokenize_str("user", message["content"]) + im_end_tokens | |
| else: | |
| content_tokens = im_start_tokens + _tokenize_str("assistant", message["content"]) + im_end_tokens | |
| round_tokens.extend(content_tokens) | |
| if len(history_tokens) == 0 or len(history_tokens) + len(round_tokens) <= max_history_tokens: | |
| if history_tokens: | |
| history_tokens = nl_tokens + history_tokens | |
| history_tokens = round_tokens + history_tokens # concat left | |
| if len(history_tokens) < max_history_tokens: | |
| continue | |
| break | |
| input_tokens = system_tokens + nl_tokens + history_tokens | |
| if messages[-1]["role"] != Role.ASSISTANT: | |
| input_tokens += nl_tokens + im_start_tokens + tokenizer.encode("assistant") + nl_tokens | |
| return input_tokens[-max_input_tokens:] # truncate left | |
| def check_is_qwen(model) -> bool: | |
| """ | |
| Checks if the given model is a Qwen model. | |
| Args: | |
| model: The model to be checked. | |
| Returns: | |
| bool: True if the model is a Qwen model, False otherwise. | |
| """ | |
| return "QWenBlock" in getattr(model, "_no_split_modules", []) | |
| def process_qwen_messages( | |
| messages: List[ChatCompletionMessageParam], | |
| functions: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, | |
| tools: Optional[List[Dict[str, Any]]] = None, | |
| ) -> Tuple[str, List[List[str]]]: | |
| """ | |
| Process the Qwen messages and generate a query and history. | |
| Args: | |
| messages (List[ChatCompletionMessageParam]): The list of chat completion messages. | |
| functions (Optional[Union[Dict[str, Any], List[Dict[str, Any]]]]): The functions to be used. | |
| tools (Optional[List[Dict[str, Any]]]): The tools to be used. | |
| Returns: | |
| Tuple[str, List[List[str]]]: The generated query and history. | |
| """ | |
| if all(m["role"] != Role.USER for m in messages): | |
| raise HTTPException( | |
| status_code=400, | |
| detail=f"Invalid request: Expecting at least one user message.", | |
| ) | |
| messages = deepcopy(messages) | |
| default_system = "You are a helpful assistant." | |
| system = "" | |
| if messages[0]["role"] == Role.SYSTEM: | |
| system = messages.pop(0)["content"].lstrip("\n").rstrip() | |
| if system == default_system: | |
| system = "" | |
| if tools: | |
| functions = [t["function"] for t in tools] | |
| if functions: | |
| tools_text = [] | |
| tools_name_text = [] | |
| for func_info in functions: | |
| name = func_info.get("name", "") | |
| name_m = func_info.get("name_for_model", name) | |
| name_h = func_info.get("name_for_human", name) | |
| desc = func_info.get("description", "") | |
| desc_m = func_info.get("description_for_model", desc) | |
| tool = TOOL_DESC.format( | |
| name_for_model=name_m, | |
| name_for_human=name_h, | |
| # Hint: You can add the following format requirements in description: | |
| # "Format the arguments as a JSON object." | |
| # "Enclose the code within triple backticks (`) at the beginning and end of the code." | |
| description_for_model=desc_m, | |
| parameters=json.dumps(func_info["parameters"], ensure_ascii=False), | |
| ) | |
| tools_text.append(tool) | |
| tools_name_text.append(name_m) | |
| tools_text = "\n\n".join(tools_text) | |
| tools_name_text = ", ".join(tools_name_text) | |
| system += "\n\n" + REACT_INSTRUCTION.format( | |
| tools_text=tools_text, | |
| tools_name_text=tools_name_text, | |
| ) | |
| system = system.lstrip("\n").rstrip() | |
| dummy_thought = { | |
| "en": "\nThought: I now know the final answer.\nFinal answer: ", | |
| "zh": "\nThought: 我会作答了。\nFinal answer: ", | |
| } | |
| _messages = messages | |
| messages = [] | |
| for m_idx, m in enumerate(_messages): | |
| role, content = m["role"], m["content"] | |
| func_call, tool_calls = m.get("function_call", None), m.get("tool_calls", None) | |
| if content: | |
| content = content.lstrip("\n").rstrip() | |
| if role in [Role.FUNCTION, Role.TOOL]: | |
| if (len(messages) == 0) or (messages[-1]["role"] != Role.ASSISTANT): | |
| raise HTTPException( | |
| status_code=400, | |
| detail=f"Invalid request: Expecting role assistant before role function.", | |
| ) | |
| messages[-1]["content"] += f"\nObservation: {content}" | |
| if m_idx == len(_messages) - 1: | |
| messages[-1]["content"] += "\nThought:" | |
| elif role == Role.ASSISTANT: | |
| if len(messages) == 0: | |
| raise HTTPException( | |
| status_code=400, | |
| detail=f"Invalid request: Expecting role user before role assistant.", | |
| ) | |
| last_msg = messages[-1]["content"] | |
| last_msg_has_zh = len(re.findall(r"[\u4e00-\u9fff]+", last_msg)) > 0 | |
| if func_call is None and tool_calls is None: | |
| if functions or tool_calls: | |
| content = dummy_thought["zh" if last_msg_has_zh else "en"] + content | |
| else: | |
| if func_call: | |
| f_name, f_args = func_call.get("name"), func_call.get("arguments") | |
| else: | |
| f_name, f_args = tool_calls[0]["function"]["name"], tool_calls[0]["function"]["arguments"] | |
| if not content: | |
| if last_msg_has_zh: | |
| content = f"Thought: 我可以使用 {f_name} API。" | |
| else: | |
| content = f"Thought: I can use {f_name}." | |
| if messages[-1]["role"] == Role.USER: | |
| messages.append( | |
| ChatCompletionAssistantMessageParam(role="assistant", content=content.lstrip("\n").rstrip()) | |
| ) | |
| else: | |
| messages[-1]["content"] += content | |
| elif role == Role.USER: | |
| messages.append( | |
| ChatCompletionUserMessageParam(role="user", content=content.lstrip("\n").rstrip()) | |
| ) | |
| else: | |
| raise HTTPException( | |
| status_code=400, detail=f"Invalid request: Incorrect role {role}." | |
| ) | |
| query = _TEXT_COMPLETION_CMD | |
| if messages[-1]["role"] == Role.USER: | |
| query = messages[-1]["content"] | |
| messages = messages[:-1] | |
| if len(messages) % 2 != 0: | |
| raise HTTPException(status_code=400, detail="Invalid request") | |
| history = [] # [(Q1, A1), (Q2, A2), ..., (Q_last_turn, A_last_turn)] | |
| for i in range(0, len(messages), 2): | |
| if messages[i]["role"] == Role.USER and messages[i + 1]["role"] == Role.ASSISTANT: | |
| usr_msg = messages[i]["content"].lstrip("\n").rstrip() | |
| bot_msg = messages[i + 1]["content"].lstrip("\n").rstrip() | |
| if system and (i == len(messages) - 2): | |
| usr_msg = f"{system}\n\nQuestion: {usr_msg}" | |
| system = "" | |
| for t in dummy_thought.values(): | |
| t = t.lstrip("\n") | |
| if bot_msg.startswith(t) and ("\nAction: " in bot_msg): | |
| bot_msg = bot_msg[len(t):] | |
| history.append([usr_msg, bot_msg]) | |
| else: | |
| raise HTTPException( | |
| status_code=400, | |
| detail="Invalid request: Expecting exactly one user (or function) role before every assistant role.", | |
| ) | |
| if system: | |
| assert query is not _TEXT_COMPLETION_CMD | |
| query = f"{system}\n\nQuestion: {query}" | |
| return query, history | |
| def build_last_message_input(tokenizer: PreTrainedTokenizer, history: list): | |
| im_start = "<|im_start|>" | |
| im_end = "<|im_end|>" | |
| prompt = f"{im_start}system\nYou are a helpful assistant.{im_end}" | |
| for i, (query, response) in enumerate(history): | |
| query = query.lstrip("\n").rstrip() | |
| response = response.lstrip("\n").rstrip() | |
| prompt += f"\n{im_start}user\n{query}{im_end}" | |
| prompt += f"\n{im_start}assistant\n{response}{im_end}" | |
| prompt = prompt[:-len(im_end)] | |
| logger.debug(f"==== Prompt with tools ====\n{prompt}") | |
| return tokenizer.encode(prompt) | |