import os
import gc
from string import Template
from threading import Thread

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BatchEncoding, TextIteratorStreamer
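
# Load the LiteChat-Preview tokenizer and model. device_map="auto" lets
# `transformers` place the weights on whatever hardware is available
# (the free-tier CPU on this Space).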
tokenizer = AutoTokenizer.from_pretrained(
    "PY007/LiteChat-Preview",
)
model = AutoModelForCausalLM.from_pretrained(
    "PY007/LiteChat-Preview",
    trust_remote_code=True,
    device_map="auto",
)
model.eval()
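
# Context budget: the model's maximum position embeddings bound the total
# prompt length, and up to 1024 new tokens are reserved for the response.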
max_context_length = model.config.max_position_embeddings
max_new_tokens = 1024

prompt_template = Template("""\
### Instruction: $human
### Response: $bot\
""")

system_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
system_prompt_tokens = tokenizer([f"{system_prompt}\n\n"], return_tensors="pt")
max_sys_tokens = system_prompt_tokens["input_ids"].size(-1)
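

# Streaming chat callback: formats the conversation history into the
# instruction/response prompt, trims it to the context window, and yields the
# updated chat history as tokens arrive from the model.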
def bot(history):
    history = history or []
    # Inject prompt formatting into the history
    prompt_history = []
    for human, bot_message in history:
        if bot_message is not None:
            bot_message = bot_message.replace("<br>", "\n")
            bot_message = bot_message.rstrip()
        prompt_history.append(
            prompt_template.substitute(
                human=human, bot=bot_message if bot_message is not None else "")
        )

    msg_tokens = tokenizer(
        "\n\n".join(prompt_history).strip(),
        return_tensors="pt",
        add_special_tokens=False,  # Use <BOS> from the system prompt
    )

    # Keep only the most recent messages that fit in the context window
    # (after reserving room for the system prompt and the generated tokens),
    # then prepend the tokenized system prompt. `max_tokens` is negative, so
    # slicing with it keeps the tail of the message tokens.
    max_tokens = -max_context_length + max_new_tokens + max_sys_tokens
    # GPU variant, kept for reference (this Space runs on CPU):
    # inputs = BatchEncoding({
    #     k: torch.concat([system_prompt_tokens[k], msg_tokens[k][:, max_tokens:]], dim=-1)
    #     for k in msg_tokens
    # }).to('cuda')
    inputs = BatchEncoding({
        k: torch.concat([system_prompt_tokens[k], msg_tokens[k][:, max_tokens:]], dim=-1)
        for k in msg_tokens
    })
    # Remove `token_type_ids` b/c it's not yet supported for LLaMA `transformers` models
    if inputs.get("token_type_ids", None) is not None:
        inputs.pop("token_type_ids")
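
    # Run generation in a background thread and stream tokens back through
    # the iterator so the UI can update incrementally.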
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=0.95,
        top_k=50,
        temperature=0.7,
    )
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    partial_text = ""
    for new_text in streamer:
        # Process out the prompt separator
        new_text = new_text.replace("<br>", "\n")
        if "###" in new_text:
            # The model emitted a "###" turn separator; keep only the text
            # before it and stop streaming.
            new_text = new_text.split("###")[0]
            partial_text += new_text.strip()
            history[-1][1] = partial_text
            break
        else:
            # Filter empty trailing new lines
            if new_text == "\n":
                new_text = new_text.strip()
            partial_text += new_text
            history[-1][1] = partial_text
        yield history
    # Yield once more so the final chunk (including text kept before a "###"
    # separator) is reflected in the chat.
    yield history
    return partial_text
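

# Gradio callback: append the user's message to the chat history and clear
# the input textbox.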
def user(user_message, history):
    return "", history + [[user_message, None]]
with gr.Blocks() as demo:
    gr.Markdown("# LiteChat by StatNLP")
    gr.Markdown("The model is currently running on a free-tier CPU, which has limited speed.")
    gr.Markdown("Paper and code will be released soon.")
    chatbot = gr.Chatbot([], elem_id="chatbot").style(height=500)
    state = gr.State([])
    with gr.Row():
        with gr.Column():
            msg = gr.Textbox(
                label="Send a message",
                placeholder="Send a message",
                show_label=False,
            ).style(container=False)
        with gr.Column():
            with gr.Row():
                submit = gr.Button("Send")
                stop = gr.Button("Stop")
                clear = gr.Button("Clear History")

    # Submitting via Enter or the Send button adds the message, then streams
    # the bot response; Stop cancels an in-flight generation.
    submit_event = msg.submit(user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False).then(
        fn=bot, inputs=[chatbot], outputs=[chatbot], queue=True)
    submit_click_event = submit.click(user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False).then(
        fn=bot, inputs=[chatbot], outputs=[chatbot], queue=True)
    stop.click(fn=None, inputs=None, outputs=None, cancels=[submit_event, submit_click_event], queue=False)
    clear.click(lambda: None, None, [chatbot], queue=True)

demo.queue(max_size=32)
demo.launch()