from fix_int8 import fix_pytorch_int8

fix_pytorch_int8()

# import subprocess
# result = subprocess.run(['git', 'clone', 'https://huggingface.co/KumaTea/twitter-int8', 'model'], capture_output=True, text=True)
# print(result.stdout)

# Credit:
# https://huggingface.co/spaces/ljsabc/Fujisaki/blob/main/app.py

import torch
import logging
import gradio as gr
from transformers import AutoTokenizer, GenerationConfig, AutoModel


gr_title = """
Uses INT4 quantization, so it is very slow; for backup use only.
GitHub Repo: KumaTea/ChatGLM
"""

gr_footer = """This project is based on ljsabc/Fujisaki; the model is THUDM/chatglm-6b.
The first thing to say when you wake up every day!
"""

# Seed conversation prepended to every prompt: "Who are you?" / "I am kuma"
default_start = ["你是谁?", "我是 kuma"]

# device = torch.device('cpu')
# torch.cuda.current_device = lambda : device

logging.basicConfig(
    format='%(asctime)s %(levelname)-8s %(message)s',
    level=logging.INFO,
    datefmt='%m/%d %H:%M:%S')

model = AutoModel.from_pretrained(
    "KumaTea/twitter-int4",
    trust_remote_code=True,
    revision="e2aecb2"
).float()  # .to(device)
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, revision="4de8efe")

# dump a log to ensure everything works well
# print(model.peft_config)

# We have to use full precision, as some tokens are >65535
model.eval()
# print(model)

torch.set_default_tensor_type(torch.FloatTensor)


def evaluate(context, temperature, top_p, top_k=None):
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        # top_k=top_k,
        # repetition_penalty=1.1,
        num_beams=1,
        do_sample=True,
    )
    with torch.no_grad():
        # input_text = f"Context: {context}Answer: "
        # Prompt format: seed conversation and user input joined by '||'
        input_text = '||'.join(default_start) + '||'
        input_text += context + '||'
        logging.info('[API] Incoming request: ' + input_text)
        ids = tokenizer([input_text], return_tensors="pt")
        inputs = ids.to("cpu")
        out = model.generate(
            **inputs,
            max_length=224,
            generation_config=generation_config
        )
        out = out.tolist()[0]
        decoder_output = tokenizer.decode(out)
        # out_text = decoder_output.split("Answer: ")[1]
        out_text = decoder_output
        logging.info('[API] Result: ' + out_text)
        return out_text


def evaluate_stream(msg, history, temperature, top_p):
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        # repetition_penalty=1.1,
        num_beams=1,
        do_sample=True,
    )

    if not msg:
        msg = '……'  # placeholder when the incoming message is empty

    history.append([msg, ""])

    context = '||'.join(default_start) + '||'
    # Drop the oldest turn so only the most recent exchanges are kept
    if len(history) > 4:
        history.pop(0)
    for j in range(len(history)):
        history[j][0] = history[j][0].replace("