import time
from abc import ABC, abstractmethod
from typing import List, Tuple

import torch
from transformers import AutoModel, AutoTokenizer
from transformers import LogitsProcessor, LogitsProcessorList

from .singleton import Singleton


def parse_codeblock(text):
    """Convert fenced Markdown code blocks in `text` to HTML for the web UI."""
    lines = text.split("\n")
    for i, line in enumerate(lines):
        if "```" in line:
            if line != "```":
                # Opening fence: keep the language tag (everything after ```)
                lines[i] = f'<pre><code class="{lines[i][3:]}">'
            else:
                lines[i] = '</code></pre>'
        else:
            if i > 0:
                # Escape HTML special characters and preserve line breaks
                lines[i] = "<br/>" + line.replace("<", "&lt;").replace(">", "&gt;")
    return "".join(lines)


class BasePredictor(ABC):

    def __init__(self, model_name):
        self.model = None
        self.tokenizer = None

    def stream_chat_continue(self, *args, **kwargs):
        raise NotImplementedError

    def predict_continue(self, query, latest_message, max_length, top_p,
                         temperature, allow_generate, history, *args,
                         **kwargs):
        if history is None:
            history = []
        allow_generate[0] = True
        history.append((query, latest_message))
        for response in self.stream_chat_continue(
                self.model,
                self.tokenizer,
                query=query,
                history=history,
                max_length=max_length,
                top_p=top_p,
                temperature=temperature):
            history[-1] = (history[-1][0], response)
            yield history, '', ''
            if not allow_generate[0]:
                break
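

# Minimal consumption sketch (not part of the original module): predict_continue
# is a generator that yields the updated history after every streamed chunk, and
# flipping allow_generate[0] to False stops generation early. The `predictor`
# argument stands for any concrete BasePredictor subclass such as ChatGLM.
def _demo_predict_continue(predictor: BasePredictor) -> None:
    allow_generate = [True]   # shared mutable flag the caller can flip to stop
    history = []              # list of (query, partial_response) pairs
    for history, _, _ in predictor.predict_continue(
            query="Hello", latest_message="", max_length=2048, top_p=0.7,
            temperature=0.95, allow_generate=allow_generate, history=history):
        print(history[-1][1])  # the partial response streamed so far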


class InvalidScoreLogitsProcessor(LogitsProcessor):

    def __init__(self, start_pos=20005):
        self.start_pos = start_pos

    def __call__(self, input_ids: torch.LongTensor,
                 scores: torch.FloatTensor) -> torch.FloatTensor:
        # If the logits contain NaN/Inf, replace the whole distribution with a
        # spike at `start_pos` so sampling still produces a valid token.
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., self.start_pos] = 5e4
        return scores
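

# Quick self-contained check (illustrative only; start_pos=3 is chosen for this
# tiny toy vocabulary, not the value used with ChatGLM below):
def _demo_invalid_score_processor() -> None:
    proc = InvalidScoreLogitsProcessor(start_pos=3)
    input_ids = torch.tensor([[1, 2]])
    scores = torch.tensor([[0.5, float("nan"), 1.0, 0.0]])
    fixed = proc(input_ids, scores)
    assert fixed[0, 3] == 5e4 and fixed[0, 0] == 0.0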


class ChatGLM(BasePredictor):

    def __init__(self, model_name="THUDM/chatglm-6b-int4"):
        print(f'Loading model {model_name}')
        start = time.perf_counter()

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            resume_download=True
        )
        model = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            resume_download=True
        ).half().to(self.device)
        model = model.eval()
        self.model = model
        self.model_name = model_name

        end = time.perf_counter()
        print(
            f'Successfully loaded model {model_name}, time cost: {end - start:.2f}s'
        )

    def generator_image_text(self, text):
        # Ask the model to describe the scene; the prompt "描述画面:{}" means
        # "Describe the scene: {}".
        response, history = self.model.chat(
            self.tokenizer, "描述画面:{}".format(text), history=[])
        return response

    def stream_chat_continue(self,
                             model,
                             tokenizer,
                             query: str,
                             history: List[Tuple[str, str]] = None,
                             max_length: int = 2048,
                             do_sample=True,
                             top_p=0.7,
                             temperature=0.95,
                             logits_processor=None,
                             **kwargs):
        if history is None:
            history = []
        if logits_processor is None:
            logits_processor = LogitsProcessorList()
        # The answer of the last round is the partial response to continue from.
        if len(history) > 0:
            answer = history[-1][1]
        else:
            answer = ''
        logits_processor.append(
            InvalidScoreLogitsProcessor(
                start_pos=20005 if 'slim' not in self.model_name else 5))
        gen_kwargs = {
            "max_length": max_length,
            "do_sample": do_sample,
            "top_p": top_p,
            "temperature": temperature,
            "logits_processor": logits_processor,
            **kwargs
        }
        if not history:
            prompt = query
        else:
            prompt = ""
            for i, (old_query, response) in enumerate(history):
                if i != len(history) - 1:
                    prompt += "[Round {}]\n问:{}\n答:{}\n".format(
                        i, old_query, response)
                else:
                    prompt += "[Round {}]\n问:{}\n答:".format(i, old_query)
        batch_input = tokenizer([prompt], return_tensors="pt", padding=True)
        batch_input = batch_input.to(model.device)

        batch_answer = tokenizer(answer, return_tensors="pt")
        batch_answer = batch_answer.to(model.device)

        input_length = len(batch_input['input_ids'][0])

        # Append the already-generated answer (dropping the two trailing tokens,
        # which the tokenizer adds as special tokens) so generation continues
        # from the partial response.
        final_input_ids = torch.cat(
            [batch_input['input_ids'], batch_answer['input_ids'][:, :-2]],
            dim=-1).to(model.device)
        attention_mask = model.get_masks(
            final_input_ids, device=final_input_ids.device)

        batch_input['input_ids'] = final_input_ids
        batch_input['attention_mask'] = attention_mask

        input_ids = final_input_ids
        MASK, gMASK = self.model.config.bos_token_id - 4, self.model.config.bos_token_id - 3
        mask_token = MASK if MASK in input_ids else gMASK
        mask_positions = [seq.tolist().index(mask_token) for seq in input_ids]
        batch_input['position_ids'] = self.model.get_position_ids(
            input_ids, mask_positions, device=input_ids.device)

        for outputs in model.stream_generate(**batch_input, **gen_kwargs):
            outputs = outputs.tolist()[0][input_length:]
            response = tokenizer.decode(outputs)
            response = model.process_response(response)
            yield parse_codeblock(response)
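
    # For reference, the prompt assembled above for a history such as
    #   [("Hi", "Hello! How can I help?"), ("Write a poem", "")]
    # looks like the following; the last round ends with "答:" so the model
    # keeps generating the (possibly partial) answer:
    #   [Round 0]
    #   问:Hi
    #   答:Hello! How can I help?
    #   [Round 1]
    #   问:Write a poem
    #   答: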


@Singleton
class Models(object):

    def __getattr__(self, item):
        if item in self.__dict__:
            return getattr(self, item)
        if item == 'chatglm':
            # Lazily create the predictor on first access.
            self.chatglm = ChatGLM("THUDM/chatglm-6b-int4")
            return getattr(self, item)
        raise AttributeError(item)


models = Models.instance()


def chat2text(text: str) -> str:
    return models.chatglm.generator_image_text(text)
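

# Minimal usage sketch (assumes the ChatGLM weights can be downloaded from the
# Hugging Face Hub; the predictor is created lazily on first access through the
# Models singleton, so importing this module stays cheap):
if __name__ == "__main__":
    caption = "a cat sitting on a windowsill at sunset"
    print(chat2text(caption))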