import os
from threading import Thread

import torch
import gradio as gr
from huggingface_hub import login
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    logging as hf_logging,
)
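
# Authenticate with the Hugging Face Hub. The token must be supplied via
# the HF_TOKEN environment variable; the script fails fast if it is missing.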
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN is None:
    raise ValueError("Please set the HF_TOKEN environment variable.")
login(token=HF_TOKEN)
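
# Suppress transformers' informational and warning output; only errors
# will reach the console.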
hf_logging.set_verbosity_error()

model_id = "motionlabs/NEWT-1.7B-QWEN-PREVIEW"
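
# Minimal in-memory log so load progress can also be surfaced in the UI.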
log_messages = []


def log(msg):
    log_messages.append(msg)
    print(msg)
    return "\n".join(log_messages)


log("Initializing tokenizer and model…")

tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
log("Tokenizer loaded.")
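
# fp16 weights for a 1.7B-parameter model need roughly 3.5 GB of memory.
# device_map="auto" requires the accelerate package and places the model
# on GPU when one is available, otherwise on CPU.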
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
    token=HF_TOKEN,
)
log("Model loaded.")
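

# stream_chat reads the pending user message from the history itself
# (appended by user_submit below) rather than from the textbox, which
# user_submit has already cleared by the time this generator runs.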
def stream_chat(history):
    # The last history entry is the pending turn: (user message, None).
    message = history[-1][0]
    messages = []
    for user, bot in history[:-1]:
        messages.append({"role": "user", "content": user})
        if bot:
            messages.append({"role": "assistant", "content": bot})
    messages.append({"role": "user", "content": message})

    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
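
    # TextIteratorStreamer yields decoded text fragments as generate()
    # produces them; skip_prompt drops the echoed input tokens.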
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
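
    # The top_p/temperature values here are generic sampling defaults,
    # not settings tuned specifically for this model.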
    gen_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.9,
        temperature=0.7,
    )
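
    # generate() blocks until completion, so it runs on a background
    # thread while this function consumes the streamer.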
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    output_text = ""
    for token in streamer:
        output_text += token
        yield history[:-1] + [(message, output_text)]


with gr.Blocks(title=f"Chat with {model_id}") as demo:
    gr.Markdown(f"# Chat with {model_id}")

    chatbot = gr.Chatbot()
    msg = gr.Textbox(placeholder="Type your message here…")
    clear = gr.Button("Clear")
    logs = gr.Textbox(label="Logs", value="\n".join(log_messages), interactive=False)
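
    # Clear the textbox and append the user turn with a None placeholder
    # that stream_chat fills in.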
    def user_submit(user_message, history):
        return "", history + [(user_message, None)]
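
    # user_submit fires first (clearing the box and recording the turn),
    # then stream_chat streams the reply into the same Chatbot history.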
    msg.submit(user_submit, [msg, chatbot], [msg, chatbot]).then(
        stream_chat, chatbot, chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)
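
# queue() enables Gradio's event queue, which generator callbacks such as
# stream_chat rely on to push partial results to the client.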
demo.queue()
demo.launch()