Spaces:
Runtime error
Runtime error
| # chat helper | |
class ChatState:
    """Keep a running chat transcript and render it with model-specific turn tags.

    Each turn is stored in ``self.history`` as an already-formatted string
    (start tag + text + end tag). ``get_full_prompt`` optionally prepends the
    system prompt and appends an open model-turn tag, so that the model's
    ``generate`` call continues the conversation as the assistant.

    Args:
        model: Object exposing ``generate(prompt, max_length=..., strip_prompt=...)``.
            Only ``generate`` is used here (presumably a KerasNLP ``CausalLM`` —
            confirm against callers).
        system: Optional system prompt. It is left out of the prompt when empty.
        chat_template: Key into ``_TEMPLATES``. ``"auto"`` uses the model's
            class name (e.g. ``"GemmaCausalLM"``).

    Raises:
        ValueError: If ``chat_template`` does not name a known template.
    """

    # Turn tags for each supported template. "banner" is printed once when the
    # template is selected. The other keys become the __START_TURN_*__ /
    # __END_TURN_*__ instance attributes.
    _TEMPLATES = {
        "Llama3CausalLM": {
            "banner": "Using chat template for: Llama",
            "start_system": "<|start_header_id|>system<|end_header_id|>\n\n",
            "start_user": "<|start_header_id|>user<|end_header_id|>\n\n",
            "start_model": "<|start_header_id|>assistant<|end_header_id|>\n\n",
            "end_system": "<|eot_id|>",
            "end_user": "<|eot_id|>",
            "end_model": "<|eot_id|>",
        },
        "GemmaCausalLM": {
            "banner": "Using chat template for: Gemma",
            "start_system": "",
            "start_user": "<start_of_turn>user\n",
            "start_model": "<start_of_turn>model\n",
            "end_system": "\n",
            "end_user": "<end_of_turn>\n",
            "end_model": "<end_of_turn>\n",
        },
        "MistralCausalLM": {
            "banner": "Using chat template for: Mistral",
            "start_system": "",
            "start_user": "[INST]",
            "start_model": "",
            "end_system": "<s>",
            "end_user": "[/INST]",
            "end_model": "</s>",
        },
        "Vicuna": {
            "banner": "Using chat template for : Vicuna",
            "start_system": "",
            "start_user": "USER: ",
            "start_model": "ASSISTANT: ",
            "end_system": "\n\n",
            "end_user": "\n",
            "end_model": "</s>\n",
        },
    }

    def __init__(self, model, system="", chat_template="auto"):
        if chat_template == "auto":
            chat_template = type(model).__name__
        tags = self._TEMPLATES.get(chat_template)
        if tags is None:
            # Fail fast: without turn tags, every other method would later die
            # with an AttributeError that hides the real cause.
            raise ValueError(
                f"Unknown turn tags for this model class: {chat_template!r}"
            )
        # Names ending in "__" are not name-mangled, so these stay reachable
        # from outside the class under exactly these attribute names.
        self.__START_TURN_SYSTEM__ = tags["start_system"]
        self.__START_TURN_USER__ = tags["start_user"]
        self.__START_TURN_MODEL__ = tags["start_model"]
        self.__END_TURN_SYSTEM__ = tags["end_system"]
        self.__END_TURN_USER__ = tags["end_user"]
        self.__END_TURN_MODEL__ = tags["end_model"]
        print(tags["banner"])
        self.model = model
        self.system = system
        self.history = []

    def add_to_history_as_user(self, message):
        """Append ``message`` to the history, wrapped in user-turn tags."""
        self.history.append(
            self.__START_TURN_USER__ + message + self.__END_TURN_USER__
        )

    def add_to_history_as_model(self, message):
        """Append ``message`` to the history, wrapped in model-turn tags."""
        self.history.append(
            self.__START_TURN_MODEL__ + message + self.__END_TURN_MODEL__
        )

    def get_history(self):
        """Return all formatted turns concatenated into a single string."""
        return "".join(self.history)

    def get_full_prompt(self):
        """Build the prompt to send to the model.

        The prompt is: [system block] + history + an opening model-turn tag.
        The system block is included only when ``self.system`` is non-empty.
        """
        prompt = self.get_history() + self.__START_TURN_MODEL__
        if self.system:
            prompt = (
                self.__START_TURN_SYSTEM__
                + self.system
                + self.__END_TURN_SYSTEM__
                + prompt
            )
        return prompt

    def send_message(self, message, *, max_length=2048):
        """Record a user message, query the model, and record its reply.

        Args:
            message: The user's message.
            max_length: Forwarded to ``model.generate``. Defaults to 2048.

        Returns:
            A ``(message, response)`` tuple, where ``response`` is whatever
            ``model.generate`` returned (called with ``strip_prompt=True``).
        """
        self.add_to_history_as_user(message)
        prompt = self.get_full_prompt()
        response = self.model.generate(
            prompt, max_length=max_length, strip_prompt=True
        )
        self.add_to_history_as_model(response)
        return (message, response)