| 
							 | 
						 | 
					
					
						
						| 
							 | 
						class ChatState: | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def __init__(self, model, system="", chat_template="auto"): | 
					
					
						
						| 
							 | 
						        chat_template = ( | 
					
					
						
						| 
							 | 
						            type(model).__name__ if chat_template == "auto" else chat_template | 
					
					
						
						| 
							 | 
						        ) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        if chat_template == "Llama3CausalLM": | 
					
					
						
						| 
							 | 
						            self.__START_TURN_SYSTEM__ = ( | 
					
					
						
						| 
							 | 
						                "<|start_header_id|>system<|end_header_id|>\n\n" | 
					
					
						
						| 
							 | 
						            ) | 
					
					
						
						| 
							 | 
						            self.__START_TURN_USER__ = ( | 
					
					
						
						| 
							 | 
						                "<|start_header_id|>user<|end_header_id|>\n\n" | 
					
					
						
						| 
							 | 
						            ) | 
					
					
						
						| 
							 | 
						            self.__START_TURN_MODEL__ = ( | 
					
					
						
						| 
							 | 
						                "<|start_header_id|>assistant<|end_header_id|>\n\n" | 
					
					
						
						| 
							 | 
						            ) | 
					
					
						
						| 
							 | 
						            self.__END_TURN_SYSTEM__ = "<|eot_id|>" | 
					
					
						
						| 
							 | 
						            self.__END_TURN_USER__ = "<|eot_id|>" | 
					
					
						
						| 
							 | 
						            self.__END_TURN_MODEL__ = "<|eot_id|>" | 
					
					
						
						| 
							 | 
						            print("Using chat template for: Llama") | 
					
					
						
						| 
							 | 
						        elif chat_template == "GemmaCausalLM": | 
					
					
						
						| 
							 | 
						            self.__START_TURN_SYSTEM__ = "" | 
					
					
						
						| 
							 | 
						            self.__START_TURN_USER__ = "<start_of_turn>user\n" | 
					
					
						
						| 
							 | 
						            self.__START_TURN_MODEL__ = "<start_of_turn>model\n" | 
					
					
						
						| 
							 | 
						            self.__END_TURN_SYSTEM__ = "\n" | 
					
					
						
						| 
							 | 
						            self.__END_TURN_USER__ = "<end_of_turn>\n" | 
					
					
						
						| 
							 | 
						            self.__END_TURN_MODEL__ = "<end_of_turn>\n" | 
					
					
						
						| 
							 | 
						            print("Using chat template for: Gemma") | 
					
					
						
						| 
							 | 
						        elif chat_template == "MistralCausalLM": | 
					
					
						
						| 
							 | 
						            self.__START_TURN_SYSTEM__ = "" | 
					
					
						
						| 
							 | 
						            self.__START_TURN_USER__ = "[INST]" | 
					
					
						
						| 
							 | 
						            self.__START_TURN_MODEL__ = "" | 
					
					
						
						| 
							 | 
						            self.__END_TURN_SYSTEM__ = "<s>" | 
					
					
						
						| 
							 | 
						            self.__END_TURN_USER__ = "[/INST]" | 
					
					
						
						| 
							 | 
						            self.__END_TURN_MODEL__ = "</s>" | 
					
					
						
						| 
							 | 
						            print("Using chat template for: Mistral") | 
					
					
						
						| 
							 | 
						        elif chat_template == "Vicuna": | 
					
					
						
						| 
							 | 
						            self.__START_TURN_SYSTEM__ = "" | 
					
					
						
						| 
							 | 
						            self.__START_TURN_USER__ = "USER: " | 
					
					
						
						| 
							 | 
						            self.__START_TURN_MODEL__ = "ASSISTANT: " | 
					
					
						
						| 
							 | 
						            self.__END_TURN_SYSTEM__ = "\n\n" | 
					
					
						
						| 
							 | 
						            self.__END_TURN_USER__ = "\n" | 
					
					
						
						| 
							 | 
						            self.__END_TURN_MODEL__ = "</s>\n" | 
					
					
						
						| 
							 | 
						            print("Using chat template for : Vicuna") | 
					
					
						
						| 
							 | 
						        else: | 
					
					
						
						| 
							 | 
						            assert (0, "Unknown turn tags for this model class") | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        self.model = model | 
					
					
						
						| 
							 | 
						        self.system = system | 
					
					
						
						| 
							 | 
						        self.history = [] | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def add_to_history_as_user(self, message): | 
					
					
						
						| 
							 | 
						        self.history.append( | 
					
					
						
						| 
							 | 
						            self.__START_TURN_USER__ + message + self.__END_TURN_USER__ | 
					
					
						
						| 
							 | 
						        ) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def add_to_history_as_model(self, message): | 
					
					
						
						| 
							 | 
						        self.history.append( | 
					
					
						
						| 
							 | 
						            self.__START_TURN_MODEL__ + message + self.__END_TURN_MODEL__ | 
					
					
						
						| 
							 | 
						        ) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def get_history(self): | 
					
					
						
						| 
							 | 
						        return "".join([*self.history]) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def get_full_prompt(self): | 
					
					
						
						| 
							 | 
						        prompt = self.get_history() + self.__START_TURN_MODEL__ | 
					
					
						
						| 
							 | 
						        if len(self.system) > 0: | 
					
					
						
						| 
							 | 
						            prompt = ( | 
					
					
						
						| 
							 | 
						                self.__START_TURN_SYSTEM__ | 
					
					
						
						| 
							 | 
						                + self.system | 
					
					
						
						| 
							 | 
						                + self.__END_TURN_SYSTEM__ | 
					
					
						
						| 
							 | 
						                + prompt | 
					
					
						
						| 
							 | 
						            ) | 
					
					
						
						| 
							 | 
						        return prompt | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def send_message(self, message): | 
					
					
						
						| 
							 | 
						        """ | 
					
					
						
						| 
							 | 
						        Handles sending a user message and getting a model response. | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						        Args: | 
					
					
						
						| 
							 | 
						            message: The user's message. | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						        Returns: | 
					
					
						
						| 
							 | 
						            The model's response. | 
					
					
						
						| 
							 | 
						        """ | 
					
					
						
						| 
							 | 
						        self.add_to_history_as_user(message) | 
					
					
						
						| 
							 | 
						        prompt = self.get_full_prompt() | 
					
					
						
						| 
							 | 
						        response = self.model.generate( | 
					
					
						
						| 
							 | 
						            prompt, max_length=2048, strip_prompt=True | 
					
					
						
						| 
							 | 
						        ) | 
					
					
						
						| 
							 | 
						        self.add_to_history_as_model(response) | 
					
					
						
						| 
							 | 
						        return (message, response) | 
					
					
						
						| 
							 | 
						
 |