saneowl committed
Commit 84e4eb6 · verified · 1 Parent(s): 787bf13

Create app.py

Files changed (1)
app.py +95 -0
app.py ADDED
@@ -0,0 +1,95 @@
+ import os
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, logging as hf_logging
+ from threading import Thread
+ import gradio as gr
+ from huggingface_hub import login
+
+ # --- Hugging Face authentication ---
+ HF_TOKEN = os.environ.get("HF_TOKEN")
+ if HF_TOKEN is None:
+     raise ValueError("Please set the HF_TOKEN environment variable.")
+ login(token=HF_TOKEN)
+
+ hf_logging.set_verbosity_error()  # suppress transformers warnings
+
+ # --- Model ID ---
+ model_id = "motionlabs/NEWT-1.7B-QWEN-PREVIEW"
+
+ # --- Logs helper ---
+ log_messages = []
+
+ def log(msg):
+     log_messages.append(msg)
+     print(msg)
+     return "\n".join(log_messages)
+
+ log("Initializing tokenizer and model…")
+
+ # Load tokenizer (token= replaces the deprecated use_auth_token= argument)
+ tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
+ log("Tokenizer loaded.")
+
+ # Load model
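+ # device_map="auto" spreads the weights across the available GPU(s)/CPU;
+ # this loading path requires the accelerate package to be installed.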
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id,
+     torch_dtype=torch.float16,
+     device_map="auto",
+     token=HF_TOKEN,
+ )
+ log("Model loaded.")
+
+ # --- Chat streaming ---
+ def stream_chat(history):
+     # Rebuild the message list from the chat history. The last pair is the
+     # pending (user_message, None) entry appended by user_submit, so the
+     # loop below already ends with the new user message.
+     messages = []
+     for user, bot in history:
+         messages.append({"role": "user", "content": user})
+         if bot:
+             messages.append({"role": "assistant", "content": bot})
+
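+     # apply_chat_template turns the role/content dicts into the single
+     # prompt string the model's chat template expects.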
+     prompt = tokenizer.apply_chat_template(
+         messages,
+         tokenize=False,
+         add_generation_prompt=True,
+     )
+
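+     # Tokenize the prompt and create a streamer that yields decoded text
+     # chunks as soon as they are generated.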
+     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+     gen_kwargs = dict(
+         **inputs,
+         streamer=streamer,
+         max_new_tokens=1024,
+         do_sample=True,
+         top_p=0.9,
+         temperature=0.7,
+     )
+
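+     # model.generate() blocks until generation finishes, so it runs in a
+     # background thread while this generator consumes the streamer.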
+     thread = Thread(target=model.generate, kwargs=gen_kwargs)
+     thread.start()
+
+     output_text = ""
+     for token in streamer:
+         output_text += token
+         # Fill in the bot half of the pending history pair and re-emit
+         # the whole history so the Chatbot updates incrementally.
+         history[-1] = (history[-1][0], output_text)
+         yield history
+
+ # --- Gradio UI ---
+ with gr.Blocks(title=f"Chat with {model_id}") as demo:
+     gr.Markdown(f"# Chat with {model_id}")
+
+     chatbot = gr.Chatbot()
+     msg = gr.Textbox(placeholder="Type your message here…")
+     clear = gr.Button("Clear")
+     logs = gr.Textbox(label="Logs", value="\n".join(log_messages), interactive=False)
+
+     def user_submit(user_message, history):
+         return "", history + [(user_message, None)]
+
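+     # Two-step chain: user_submit clears the textbox and appends the pending
+     # (user_message, None) pair, then stream_chat reads that pair from the
+     # history and streams the reply into it.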
+     msg.submit(user_submit, [msg, chatbot], [msg, chatbot]).then(
+         stream_chat, chatbot, chatbot
+     )
+     clear.click(lambda: None, None, chatbot, queue=False)
+
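+ # queue() enables the request queue that generator handlers such as
+ # stream_chat rely on for streaming output.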
+ demo.queue()
+ demo.launch()