# app.py — Gradio streaming chat demo for motionlabs/NEWT-1.7B-QWEN-PREVIEW
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, logging as hf_logging
from threading import Thread
import gradio as gr
from huggingface_hub import login
# --- Hugging Face authentication ---
# The token must be supplied via the environment (e.g. a Space secret);
# fail fast at startup if it is missing rather than at first download.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN is None:
    raise ValueError("Please set the HF_TOKEN environment variable.")
# Authenticate this process with the Hub so gated/private repos resolve.
login(token=HF_TOKEN)
hf_logging.set_verbosity_error()  # suppress warnings
# --- Model ID ---
model_id = "motionlabs/NEWT-1.7B-QWEN-PREVIEW"

# --- Logs helper ---
# Accumulated startup messages, shown later in the UI "Logs" textbox.
log_messages = []


def log(msg):
    """Record *msg*, echo it to stdout, and return the full transcript.

    The joined transcript is what the Gradio log textbox displays.
    """
    print(msg)
    log_messages.append(msg)
    return "\n".join(log_messages)
log("Initializing tokenizer and model…")

# Load tokenizer. `token=` replaces the deprecated `use_auth_token=`
# keyword (removed in recent transformers releases). The earlier login()
# call already authenticates the process; passing the token explicitly
# keeps gated-repo access working regardless.
tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
log("Tokenizer loaded.")

# Load model in fp16, with weights placed automatically across the
# available devices (GPU/CPU) by accelerate's device_map.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
    token=HF_TOKEN,
)
log("Model loaded.")
# --- Chat streaming ---
def stream_chat(history, message):
    """Stream a model reply token by token as a Gradio generator.

    Yields successively longer chat histories so the Chatbot component can
    render the partial response.

    Bug fix: the UI wiring (`user_submit` → `.then(stream_chat, ...)`)
    clears the textbox and appends a pending ``(message, None)`` pair to
    ``history`` *before* this function runs, so ``message`` arrives empty
    and the pending pair was previously double-counted — producing an extra
    empty user turn in the prompt and a duplicated user message in the chat
    log. Recover the message from the pending pair and strip that pair
    before building the prompt and the yielded history.
    """
    # A trailing (user, None) pair is the turn currently being answered.
    if history and history[-1][1] is None:
        if not message:
            message = history[-1][0]
        history = history[:-1]

    # Rebuild the conversation in the chat-template message format.
    messages = []
    for user, bot in history:
        messages.append({"role": "user", "content": user})
        if bot:
            messages.append({"role": "assistant", "content": bot})
    messages.append({"role": "user", "content": message})

    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # generate() blocks, so it runs on a worker thread; the streamer hands
    # decoded text fragments back to this thread as they are produced.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.9,
        temperature=0.7,
    )
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    output_text = ""
    for token in streamer:
        output_text += token
        yield history + [(message, output_text)]
# --- Gradio UI ---
with gr.Blocks(title=f"Chat with {model_id}") as demo:
    gr.Markdown(f"# Chat with {model_id}")
    chatbot = gr.Chatbot()
    msg = gr.Textbox(placeholder="Type your message here…")
    clear = gr.Button("Clear")
    # Read-only transcript of the startup messages collected by log().
    logs = gr.Textbox(label="Logs", value="\n".join(log_messages), interactive=False)

    def user_submit(user_message, history):
        # Clear the textbox and append a pending (user, None) pair so the
        # user's message appears immediately while generation runs.
        return "", history + [(user_message, None)]

    # NOTE(review): user_submit empties `msg` before the chained call, so
    # stream_chat receives an empty `message` while the pending pair is
    # already in `chatbot` — verify stream_chat recovers the message from
    # the pending pair, otherwise the user turn is duplicated/empty.
    msg.submit(user_submit, [msg, chatbot], [msg, chatbot]).then(
        stream_chat, [chatbot, msg], chatbot
    )
    # queue=False: clearing should not wait behind in-flight generations.
    clear.click(lambda: None, None, chatbot, queue=False)

# Enable queuing (required for streaming generators) and start the server.
demo.queue()
demo.launch()