import gc
import re
import time
import uuid
from typing import List, Union, Dict, Any, Iterator

import torch
from loguru import logger
from openai.types.chat import ChatCompletionMessageParam
from transformers import PreTrainedTokenizer, PreTrainedModel
from transformers.generation.logits_process import LogitsProcessor

from api.generation.utils import apply_stopping_strings
from api.utils.protocol import Role


class InvalidScoreLogitsProcessor(LogitsProcessor):
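    """Guards against invalid logits: if any score is NaN or Inf, the scores are
    zeroed and a large value is placed on token id 5 so that sampling cannot fail
    on an invalid distribution."""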

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 5] = 5e4
        return scores


def process_response(response: str) -> str:
    """
    Process the response by stripping leading and trailing whitespace,
    replacing the placeholder for training time, and normalizing punctuation.

    Args:
        response: The input response string.

    Returns:
        The processed response string.
    """
    response = response.strip()
    response = response.replace("[[训练时间]]", "2023年")
    punkts = [
        [",", ","],
        ["!", "!"],
        [":", ":"],
        [";", ";"],
        [r"\?", "?"],
    ]
    for item in punkts:
        response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
        response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
    return response
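
# Example of the normalization above: ASCII punctuation that directly follows or precedes
# a CJK character is rewritten to its full-width form, and the training-time placeholder
# is substituted, e.g. process_response("[[训练时间]],你好!") -> "2023年,你好!".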


def check_is_chatglm(model) -> bool:
    """
    Checks if the given model is a ChatGLM model.

    Args:
        model: The model to be checked.

    Returns:
        bool: True if the model is a ChatGLM model, False otherwise.
    """
    return "GLMBlock" in getattr(model, "_no_split_modules", [])


def generate_stream_chatglm(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    params: Dict[str, Any],
) -> Iterator:
    """
    Generates text in a streaming manner using the ChatGLM model.

    Args:
        model: The pre-trained ChatGLM model.
        tokenizer: The tokenizer used for tokenizing the input.
        params: A dictionary containing the input parameters.

    Yields:
        A dictionary representing each generated text completion.
    """
    inputs = params["inputs"]
    model_name = params.get("model", "llm")
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    max_new_tokens = int(params.get("max_tokens", 256))
    echo = params.get("echo", True)

    input_echo_len = len(inputs["input_ids"][0])
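
    # If the prompt exceeds the model's context window, keep only the most recent
    # ``seq_length`` tokens (and move the input tensors onto the model's device).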
    if input_echo_len >= model.config.seq_length:
        logger.warning(f"Input length larger than {model.config.seq_length}")

    inputs = {k: v[:, -model.config.seq_length:].to(model.device) for k, v in inputs.items()}

    gen_kwargs = {
        "max_length": min(max_new_tokens + input_echo_len, model.config.seq_length),
        "do_sample": temperature > 1e-5,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "logits_processor": [InvalidScoreLogitsProcessor()],
    }
    if temperature > 1e-5:
        gen_kwargs["temperature"] = temperature

    total_len, previous_text = 0, ""
    completion_id: str = f"cmpl-{str(uuid.uuid4())}"
    created: int = int(time.time())

    for total_ids in model.stream_generate(**inputs, **gen_kwargs):
        total_ids = total_ids.tolist()[0]
        total_len = len(total_ids)

        output_ids = total_ids if echo else total_ids[input_echo_len:]
        response = tokenizer.decode(output_ids)
        response = process_response(response)

        delta_text = response[len(previous_text):]
        previous_text = response

        yield {
            "id": completion_id,
            "object": "text_completion",
            "created": created,
            "model": model_name,
            "delta": delta_text,
            "text": response,
            "logprobs": None,
            "finish_reason": None,
            "usage": {
                "prompt_tokens": input_echo_len,
                "completion_tokens": total_len - input_echo_len,
                "total_tokens": total_len,
            },
        }

    # Only the last stream result carries a finish_reason; we set it to "stop".
    yield {
        "id": completion_id,
        "object": "text_completion",
        "created": created,
        "model": model_name,
        "delta": "",
        "text": response,
        "logprobs": None,
        "finish_reason": "stop",
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": total_len - input_echo_len,
            "total_tokens": total_len,
        },
    }

    gc.collect()
    torch.cuda.empty_cache()
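

# A minimal usage sketch for the streaming generator above. This is an illustration only:
# it assumes a ChatGLM checkpoint that ships the custom ``stream_generate`` method via
# ``trust_remote_code`` (the repo id is just an example) and that a CUDA device is available.
#
#     from transformers import AutoModel, AutoTokenizer
#
#     model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).half().cuda()
#     tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
#     inputs = tokenizer(["你好"], return_tensors="pt")
#     params = {"inputs": inputs, "model": "chatglm2-6b", "max_tokens": 64, "temperature": 0.8}
#     for chunk in generate_stream_chatglm(model, tokenizer, params):
#         print(chunk["delta"], end="", flush=True)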


def generate_stream_chatglm_v3(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    params: Dict[str, Any],
) -> Iterator:
    """
    Generates text in a streaming manner using the ChatGLM3 model.

    Args:
        model: The pre-trained ChatGLM3 model.
        tokenizer: The tokenizer used for tokenizing the input.
        params: A dictionary containing the input parameters.

    Yields:
        A dictionary representing each generated text completion.
    """
    inputs = params["inputs"]
    model_name = params.get("model", "llm")
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    max_new_tokens = int(params.get("max_tokens", 256))
    echo = params.get("echo", True)

    input_echo_len = len(inputs["input_ids"][0])
    if input_echo_len >= model.config.seq_length:
        logger.warning(f"Input length larger than {model.config.seq_length}")

    inputs = {k: v[:, -model.config.seq_length:].to(model.device) for k, v in inputs.items()}
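
    # ChatGLM3 marks the end of an assistant turn either with the regular EOS token or
    # with the special <|user|> token, so both are treated as stop tokens here.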
    eos_token_id = [
        tokenizer.eos_token_id,
        tokenizer.get_command("<|user|>"),
    ]

    gen_kwargs = {
        "max_length": min(max_new_tokens + input_echo_len, model.config.seq_length),
        "do_sample": temperature > 1e-5,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "logits_processor": [InvalidScoreLogitsProcessor()],
    }
    if temperature > 1e-5:
        gen_kwargs["temperature"] = temperature

    total_len, previous_text = 0, ""
    completion_id: str = f"cmpl-{str(uuid.uuid4())}"
    created: int = int(time.time())

    for total_ids in model.stream_generate(**inputs, eos_token_id=eos_token_id, **gen_kwargs):
        total_ids = total_ids.tolist()[0]
        total_len = len(total_ids)

        output_ids = total_ids[:-1] if echo else total_ids[input_echo_len:-1]
        response = tokenizer.decode(output_ids)
        # Skip chunks that end in an incomplete multi-byte character ("�") and only yield
        # once the decoded text is stable; nesting the yield here also guarantees that
        # ``stop_found`` is defined before it is checked.
        if response and response[-1] != "�":
            response, stop_found = apply_stopping_strings(response, ["<|observation|>"])

            delta_text = response[len(previous_text):]
            previous_text = response

            yield {
                "id": completion_id,
                "object": "text_completion",
                "created": created,
                "model": model_name,
                "delta": delta_text,
                "text": response,
                "logprobs": None,
                "finish_reason": "function_call" if stop_found else None,
                "usage": {
                    "prompt_tokens": input_echo_len,
                    "completion_tokens": total_len - input_echo_len,
                    "total_tokens": total_len,
                },
            }

            if stop_found:
                break

    # Only the last stream result carries a finish_reason; we set it to "stop".
    yield {
        "id": completion_id,
        "object": "text_completion",
        "created": created,
        "model": model_name,
        "delta": "",
        "text": response,
        "logprobs": None,
        "finish_reason": "stop",
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": total_len - input_echo_len,
            "total_tokens": total_len,
        },
    }

    gc.collect()
    torch.cuda.empty_cache()


def process_chatglm_messages(
    messages: List[ChatCompletionMessageParam],
    functions: Union[dict, List[dict]] = None,
) -> List[dict]:
    """
    Processes a list of chat messages and returns a modified list of messages.

    Args:
        messages: A list of chat messages to be processed.
        functions: Optional. A dictionary or list of dictionaries representing the available tools.

    Returns:
        A modified list of chat messages.
    """
    _messages = messages
    messages = []

    if functions:
        messages.append(
            {
                "role": Role.SYSTEM,
                "content": "Answer the following questions as best as you can. You have access to the following tools:",
                "tools": functions,
            }
        )

    for m in _messages:
        role, content = m["role"], m["content"]
        if role == Role.FUNCTION:
            messages.append({"role": "observation", "content": content})
        elif role == Role.ASSISTANT:
            for response in content.split("<|assistant|>"):
                if "\n" in response:
                    metadata, sub_content = response.split("\n", maxsplit=1)
                else:
                    metadata, sub_content = "", response
                messages.append({"role": role, "metadata": metadata, "content": sub_content.strip()})
        else:
            messages.append({"role": role, "content": content})
    return messages
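

# Illustrative transformation (assuming ``Role`` is a string enum with values such as
# ``Role.FUNCTION == "function"``):
#
#     process_chatglm_messages([
#         {"role": "user", "content": "What's the weather in Beijing?"},
#         {"role": "function", "content": '{"temperature": 22}'},
#     ])
#     # -> [{"role": "user", "content": "What's the weather in Beijing?"},
#     #     {"role": "observation", "content": '{"temperature": 22}'}]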