import gradio as gr
import torch
import html
from threading import Thread

from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
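
# This Space assumes gradio, torch, and transformers are installed, plus a
# CUDA-capable GPU (the model is moved to the GPU in fp16 below).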

model_name_or_path = 'TencentARC/LLaMA-Pro-8B-Instruct'

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
model.half().cuda()
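
# Note (assumption, not part of the original Space): .half().cuda() requires a
# CUDA GPU with enough memory for the 8B weights in fp16 (roughly 16 GB). One
# possible alternative sketch is to let accelerate place the weights instead:
#
#   model = AutoModelForCausalLM.from_pretrained(
#       model_name_or_path, torch_dtype=torch.float16, device_map="auto"
#   )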

def convert_message(message):
    """Render one chat message into the model's tagged prompt format."""
    message_text = ""
    if message["content"] is None and message["role"] == "assistant":
        # Final, empty assistant turn: open the tag so generation starts here.
        message_text += "<|assistant|>\n"
    elif message["role"] == "system":
        message_text += "<|system|>\n" + message["content"].strip() + "\n"
    elif message["role"] == "user":
        message_text += "<|user|>\n" + message["content"].strip() + "\n"
    elif message["role"] == "assistant":
        message_text += "<|assistant|>\n" + message["content"].strip() + "\n"
    else:
        raise ValueError("Invalid role: {}".format(message["role"]))
    # Gradio escapes messages into HTML entities; undo that here. (We would
    # need special handling anywhere we actually want to keep the HTML.)
    message_text = html.unescape(message_text)
    # Gradio also converts newlines to <br>; undo this as well.
    message_text = message_text.replace("<br>", "\n")
    return message_text
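
# For illustration, a two-turn history renders to a prompt like:
#
#   <|user|>
#   Hello!
#   <|assistant|>
#   Hi! How can I help?
#   <|user|>
#   Tell me a joke.
#   <|assistant|>
#
# where the trailing open <|assistant|> tag is the point generation continues from.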

def convert_history(chat_history, max_input_length=1024):
    history_text = ""
    idx = len(chat_history) - 1
    # Add messages in reverse order (newest first) until we hit max_input_length.
    while len(tokenizer(history_text).input_ids) < max_input_length and idx >= 0:
        user_message, chatbot_message = chat_history[idx]
        user_message = convert_message({"role": "user", "content": user_message})
        chatbot_message = convert_message({"role": "assistant", "content": chatbot_message})
        history_text = user_message + chatbot_message + history_text
        idx = idx - 1
    # If nothing was added, add <|assistant|> to start generation.
    if history_text == "":
        history_text = "<|assistant|>\n"
    return history_text
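
# Note: the length check above runs before each turn is prepended, so the
# final prompt can overshoot max_input_length by up to one turn. A stricter
# sketch (an assumption, not in the original) would measure the candidate
# turn before prepending it:
#
#   turn = user_message + chatbot_message
#   if len(tokenizer(turn + history_text).input_ids) > max_input_length:
#       break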

def instruct(instruction, max_token_output=1024):
    # skip_prompt=True keeps the input prompt out of the streamed output;
    # skip_special_tokens=True keeps tokens such as </s> out of the chat text.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer(instruction, return_tensors='pt', truncation=False)
    inputs["input_ids"] = inputs["input_ids"].cuda()
    inputs["attention_mask"] = inputs["attention_mask"].cuda()
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_token_output, do_sample=False)
    # Run generation on a background thread so the caller can consume the
    # streamer while tokens are still being produced.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    return streamer
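
# Usage sketch: TextIteratorStreamer is a plain Python iterator, so decoded
# text chunks can be consumed as they arrive, e.g.:
#
#   for chunk in instruct("<|user|>\nHello\n<|assistant|>\n"):
#       print(chunk, end="", flush=True)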

with gr.Blocks() as demo:
    # Chatbot-style UI for the model.
    with gr.Tab("Chatbot"):
        chatbot = gr.Chatbot([], elem_id="chatbot")
        msg = gr.Textbox()
        clear = gr.Button("Clear")

        # Append the user's message to the history and clear the textbox.
        def user(user_message, history):
            return "", history + [[user_message, None]]

        # Stream the model's reply token by token into the last history entry.
        def bot(history):
            prompt = convert_history(history)
            streaming_out = instruct(prompt)
            history[-1][1] = ""
            for new_token in streaming_out:
                history[-1][1] += new_token
                yield history

        msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
            bot, chatbot, chatbot
        )
        clear.click(lambda: None, None, chatbot, queue=False)

if __name__ == "__main__":
    # queue() is required for the generator-based bot() to stream partial output.
    demo.queue().launch(share=True)