Spaces:
Runtime error
Runtime error
import os
import platform
from threading import Thread
from typing import Union

import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from safetensors.torch import load_model
from transformers import PreTrainedTokenizerFast, TextIteratorStreamer

# Project-local classes and helpers.
from config import InferConfig, T5ModelConfig
from model.chat_model import TextToTextModel
from utils.functions import get_T5_config
class ChatBot:
    """Text-to-text chat wrapper around a T5-style model.

    Loads the tokenizer and model weights described by ``infer_config``
    and exposes streaming (``stream_chat``) and non-streaming (``chat``)
    generation.
    """

    def __init__(self, infer_config: InferConfig) -> None:
        """Load tokenizer and model weights.

        Weight loading is tried in order: ``from_pretrained`` on a
        checkpoint directory, a ``.safetensors`` file, then a raw torch
        ``state_dict``. On any failure it falls back to
        ``accelerate.load_checkpoint_and_dispatch`` (device_map='auto',
        fp16), e.g. when the model does not fit on a single device.
        """
        self.infer_config = infer_config

        # Tokenizer plus bound helpers used by chat() / stream_chat().
        tokenizer = PreTrainedTokenizerFast.from_pretrained(infer_config.model_dir)
        self.tokenizer = tokenizer
        self.encode = tokenizer.encode_plus
        self.batch_decode = tokenizer.batch_decode
        self.batch_encode_plus = tokenizer.batch_encode_plus

        t5_config = get_T5_config(
            T5ModelConfig(),
            vocab_size=len(tokenizer),
            decoder_start_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

        try:
            model = TextToTextModel(t5_config)
            if os.path.isdir(infer_config.model_dir):
                # Transformers-style checkpoint directory.
                model = model.from_pretrained(infer_config.model_dir)
            elif infer_config.model_dir.endswith('.safetensors'):
                # Single safetensors weight file (loads weights in place).
                load_model(model, infer_config.model_dir)
            else:
                # Plain torch state_dict checkpoint.
                model.load_state_dict(torch.load(infer_config.model_dir))
            self.model = model
        except Exception as e:
            # Best-effort fallback: build an empty-weight model and let
            # accelerate load/dispatch the checkpoint across devices.
            print(str(e), 'transformers and pytorch load fail, try accelerate load function.')
            with init_empty_weights():
                empty_model = TextToTextModel(t5_config)
            self.model = load_checkpoint_and_dispatch(
                model=empty_model,
                checkpoint=infer_config.model_dir,
                device_map='auto',
                dtype=torch.float16,
            )

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)

        # Single shared streamer consumed by stream_chat(). NOTE(review):
        # concurrent stream_chat() calls would interleave output on this
        # one instance — confirm single-consumer usage with callers.
        self.streamer = TextIteratorStreamer(
            tokenizer=tokenizer,
            clean_up_tokenization_spaces=True,
            skip_special_tokens=True,
        )

    def stream_chat(self, input_txt: str) -> TextIteratorStreamer:
        """Streamed generation: start a worker thread, return the streamer.

        Iterate the returned streamer to receive decoded text as it is
        produced. Only greedy search is supported on this path.
        """
        encoded = self.encode(input_txt + '[EOS]')

        input_ids = torch.LongTensor([encoded.input_ids]).to(self.device)
        attention_mask = torch.LongTensor([encoded.attention_mask]).to(self.device)

        generation_kwargs = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'max_seq_len': self.infer_config.max_seq_len,
            'streamer': self.streamer,
            'search_type': 'greedy',
        }

        # Generation runs in a daemon-less worker thread; the caller
        # consumes tokens from self.streamer while it runs.
        thread = Thread(target=self.model.my_generate, kwargs=generation_kwargs)
        thread.start()

        return self.streamer

    def chat(self, input_txt: Union[str, list[str]]) -> Union[str, list[str]]:
        """Non-streaming generation for a single prompt or a batch.

        Uses greedy search (search_type='greedy'). Returns a single string
        when given a string, a list of strings when given a list. Empty
        generations are replaced with a fallback notice.

        Raises:
            Exception: if ``input_txt`` is neither ``str`` nor ``list``.
        """
        if isinstance(input_txt, str):
            input_txt = [input_txt]
        elif not isinstance(input_txt, list):
            # Bug fix: message previously read "mast be".
            raise Exception('input_txt must be a str or list[str]')

        # Append the EOS marker expected by the model's training format.
        input_txts = [f"{txt}[EOS]" for txt in input_txt]
        encoded = self.batch_encode_plus(input_txts, padding=True)
        input_ids = torch.LongTensor(encoded.input_ids).to(self.device)
        attention_mask = torch.LongTensor(encoded.attention_mask).to(self.device)

        outputs = self.model.my_generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_seq_len=self.infer_config.max_seq_len,
            search_type='greedy',
        )
        outputs = self.batch_decode(
            outputs.cpu().numpy(),
            clean_up_tokenization_spaces=True,
            skip_special_tokens=True,
        )

        # Fallback notice for empty generations (small model, limited knowledge).
        note = "我是一个参数很少的AI模型🥺,知识库较少,无法直接回答您的问题,换个问题试试吧👋"
        outputs = [item if len(item) != 0 else note for item in outputs]

        return outputs[0] if len(outputs) == 1 else outputs