Spaces:
Build error
Build error
| from threading import Thread | |
| import torch | |
| import gradio as gr | |
| import spaces | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
| import subprocess | |
| #subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True) | |
# HTML shown above the chat panel (rendered with gr.HTML): the project banner
# image linking to the GitHub project, a pointer to that project, and a usage
# notice naming the default model checkpoint.
BANNER_HTML = """
<p align="center">
<a href="https://github.com/ymcui/Chinese-LLaMA-Alpaca-3">
<img src="https://ymcui.com/images/chinese-llama-alpaca-3-banner.png" width="600"/>
</a>
</p>
<h3>
<center>Check our <a href='https://github.com/ymcui/Chinese-LLaMA-Alpaca-3' target='_blank'>Chinese-LLaMA-Alpaca-3 GitHub Project</a> for more information.
</center>
</h3>
<p>
<center><em>The demo is mainly for academic purposes. Illegal usages are prohibited. Default model: <a href="https://huggingface.co/hfl/llama-3-chinese-8b-instruct-v3">hfl/llama-3-chinese-8b-instruct-v3</a></em></center>
</p>
"""
# System prompt used when the "System Prompt" box is left empty; it is also the
# box's initial value. The Chinese half repeats the English sentence
# ("You are a helpful assistant.").
DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant. 你是一个乐于助人的助手。"
# Hugging Face Hub checkpoint for each selectable instruct-model version.
MODEL_VERSIONS = {
    "v1": "hfl/llama-3-chinese-8b-instruct",
    "v2": "hfl/llama-3-chinese-8b-instruct-v2",
    "v3": "hfl/llama-3-chinese-8b-instruct-v3",
}


def load_model(version):
    """Load the tokenizer and model for one instruct-model version.

    The loaded objects are stored in the module-level globals ``tokenizer``
    and ``model``, which the chat handler reads on every request.

    Args:
        version: One of the keys of ``MODEL_VERSIONS`` ("v1", "v2" or "v3").

    Returns:
        A status message naming the checkpoint that was loaded.

    Raises:
        ValueError: If ``version`` is not a known version. Previously an
            unknown version fell through every branch and raised
            ``UnboundLocalError`` on ``model_name``.
    """
    global tokenizer, model
    import importlib.util

    try:
        model_name = MODEL_VERSIONS[version]
    except KeyError:
        raise ValueError(
            f"Unknown model version {version!r}; "
            f"expected one of {sorted(MODEL_VERSIONS)}"
        ) from None

    # flash-attn is optional here (its pip install line is commented out at the
    # top of the file). Requesting "flash_attention_2" without the package
    # installed makes from_pretrained fail, so fall back to PyTorch SDPA.
    if importlib.util.find_spec("flash_attn") is not None:
        attn_implementation = "flash_attention_2"
    else:
        attn_implementation = "sdpa"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation=attn_implementation,
    )
    return f"Model {model_name} loaded."
def _history_to_messages(history):
    """Convert Gradio chat history into chat-template ``{"role", "content"}`` messages.

    Accepts both history shapes gr.ChatInterface can pass:
    ``[user, assistant]`` pairs ("tuples" format) and ``{"role", "content"}``
    dicts ("messages" format). Unpacking a dict as a pair would silently
    yield its keys ("role", "content") instead of the text, so dicts are
    handled explicitly. ``None`` halves of a pair (a turn with no reply yet)
    are skipped so they never reach the chat template.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            messages.append({"role": turn["role"], "content": turn["content"]})
            continue
        prompt, answer = turn
        if prompt is not None:
            messages.append({"role": "user", "content": prompt})
        if answer is not None:
            messages.append({"role": "assistant", "content": answer})
    return messages


def stream_chat(message: str, history: list, system_prompt: str, model_version: str, temperature: float, max_new_tokens: int):
    """Stream the model's reply to ``message``, given the prior chat ``history``.

    Generation runs in a background thread. Text is read back through a
    ``TextIteratorStreamer``, which gives up if no new text arrives within
    10 s.

    Args:
        message: The latest user message.
        history: Previous turns as passed by gr.ChatInterface, either as
            ``[user, assistant]`` pairs or as role/content dicts.
        system_prompt: System prompt text. ``DEFAULT_SYSTEM_PROMPT`` is used
            when it is empty.
        model_version: Value of the "Model Version" input. Not used here; the
            model is whichever one ``load_model`` last loaded.
        temperature: ``0`` selects greedy decoding. Any other value enables
            sampling with this temperature, ``top_k=40`` and ``top_p=0.9``.
        max_new_tokens: Upper bound on the number of generated tokens.

    Yields:
        The reply generated so far, as one growing string (not per-token
        deltas), which is the form gr.ChatInterface streams.
    """
    # NOTE(review): `spaces` is imported at the top of the file but
    # `@spaces.GPU` is not applied here. If this Space runs on ZeroGPU
    # hardware, this function probably needs that decorator — confirm.
    conversation = [{"role": "system", "content": system_prompt or DEFAULT_SYSTEM_PROMPT}]
    conversation.extend(_history_to_messages(history))
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    # Stop on the tokenizer's EOS token or on <|eot_id|>, the end-of-turn
    # marker used by Llama-3 chat templates.
    terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]

    do_sample = temperature != 0
    generate_kwargs = {
        "input_ids": input_ids,
        # One unpadded sequence, so every position is a real token. Passing
        # the mask explicitly avoids generate()'s pad==eos ambiguity warning.
        "attention_mask": torch.ones_like(input_ids),
        "streamer": streamer,
        "eos_token_id": terminators,
        "pad_token_id": tokenizer.eos_token_id,
        "max_new_tokens": int(max_new_tokens),
        "num_beams": 1,
        "repetition_penalty": 1.1,
        "do_sample": do_sample,
    }
    if do_sample:
        # Sampling-only settings. generate() ignores them (with a warning)
        # under greedy decoding, so they are only passed when sampling.
        generate_kwargs.update(temperature=temperature, top_k=40, top_p=0.9)

    generation_thread = Thread(target=model.generate, kwargs=generate_kwargs)
    generation_thread.start()
    output = ""
    for new_text in streamer:
        output += new_text
        yield output
    generation_thread.join()
# Transcript widget. It is created up front and handed to ChatInterface, which renders it.
chatbot = gr.Chatbot(height=500)

with gr.Blocks() as demo:
    gr.HTML(BANNER_HTML)

    # Collapsed panel for the extra generation controls. Each control uses
    # render=False so that ChatInterface places it. The list order matches
    # stream_chat's parameters after (message, history): system prompt,
    # model version, temperature, max new tokens.
    parameters_panel = gr.Accordion(label="Parameters / 参数设置", open=False, render=False)
    parameter_inputs = [
        gr.Text(value=DEFAULT_SYSTEM_PROMPT, label="System Prompt / 系统提示词", render=False),
        gr.Radio(choices=["v1", "v2", "v3"], label="Model Version / 模型版本", value="v3", interactive=False, render=False),
        gr.Slider(minimum=0, maximum=1.5, step=0.1, value=0.6, label="Temperature / 温度系数", render=False),
        gr.Slider(minimum=128, maximum=2048, step=1, value=512, label="Max new tokens / 最大生成长度", render=False),
    ]

    # Bilingual labels for the chat control buttons.
    button_labels = {
        "submit_btn": "Send / 发送",
        "stop_btn": "Stop / 停止",
        "retry_btn": "🔄 Retry / 重试",
        "undo_btn": "↩️ Undo / 撤销",
        "clear_btn": "🗑️ Clear / 清空",
    }

    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=parameters_panel,
        additional_inputs=parameter_inputs,
        cache_examples=False,
        **button_labels,
    )
def main():
    """Load the default v3 model, then start the Gradio server."""
    # Must run before launch(): the chat handler reads the `tokenizer` and
    # `model` globals that load_model() sets.
    load_model("v3")
    demo.launch()


if __name__ == "__main__":
    main()