fllay commited on
Commit
86c17bb
·
verified ·
1 Parent(s): f3cdb01

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -52
app.py CHANGED
@@ -1,62 +1,75 @@
1
- import gradio as gr
2
  import torch
3
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
4
-
5
- # Pick one of the models from your collection:
6
- MODEL_NAME = "NextGLab/oransight-20-gemma-2b" # <-- edit this to whichever in your collection
7
-
8
- @torch.inference_mode()
9
- def load_model():
10
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
11
- model = AutoModelForCausalLM.from_pretrained(
12
- MODEL_NAME,
13
- device_map="auto",
14
- torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
15
- )
16
- pipe = pipeline(
17
- "text-generation",
18
- model=model,
19
- tokenizer=tokenizer,
20
- device=model.device,
21
- )
22
- return pipe
23
-
24
- pipe = load_model()
25
-
26
- # Simple chatbot fn
27
- def chat(message, history, max_new_tokens=256, temperature=0.7):
28
- prompt = message
29
- outputs = pipe(
30
- prompt,
31
- max_new_tokens=max_new_tokens,
32
- temperature=temperature,
33
- do_sample=True,
34
- pad_token_id=pipe.tokenizer.eos_token_id,
35
- )
36
- text = outputs[0]["generated_text"]
37
- # Hugging Face pipeline often echos input text → remove
38
- return text[len(prompt):].strip()
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  with gr.Blocks() as demo:
41
- gr.Markdown(f"# 🤖 Oransight‑20‑Gemma Demo\nModel: **{MODEL_NAME}**")
42
-
43
  chatbot = gr.Chatbot()
44
- msg = gr.Textbox(label="Your message")
45
  send = gr.Button("Send")
46
- clear = gr.Button("Clear")
47
-
48
- max_tokens = gr.Slider(50, 600, value=256, step=10, label="Max new tokens")
49
  temperature = gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature")
50
-
51
  state = gr.State([])
52
-
53
- def respond(message, history, max_tokens, temperature):
54
- response = chat(message, history, max_tokens, temperature)
55
- history.append((message, response))
56
- return history, history, ""
57
-
58
- send.click(respond, [msg, state, max_tokens, temperature], [chatbot, state, msg])
59
- msg.submit(respond, [msg, state, max_tokens, temperature], [chatbot, state, msg])
60
  clear.click(lambda: ([], []), None, [chatbot, state])
61
 
62
  if __name__ == "__main__":
 
 
1
  import torch
2
+ import gradio as gr
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
+
5
+ # Hugging Face repo ID (from the model page)
6
+ MODEL_NAME = "NextGLab/ORANSight_Gemma_2_2B_Instruct"
7
+
8
+ # Load tokenizer & model
9
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
10
+
11
+ model = AutoModelForCausalLM.from_pretrained(
12
+ MODEL_NAME,
13
+ torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
14
+ device_map="auto"
15
+ )
16
+
17
+ # --- Helper function ---
18
+ def chat(message, history, max_new_tokens=128, temperature=0.7):
19
+ """
20
+ message: user input
21
+ history: running chat history (list of [user, assistant])
22
+ """
23
+
24
+ # Convert Gradio-style history into chat template
25
+ messages = []
26
+ for user_msg, bot_msg in history:
27
+ messages.append({"role": "user", "content": user_msg})
28
+ messages.append({"role": "assistant", "content": bot_msg})
29
+ messages.append({"role": "user", "content": message})
30
+
31
+ # Prepare input using Gemma chat template
32
+ inputs = tokenizer.apply_chat_template(
33
+ messages,
34
+ add_generation_prompt=True,
35
+ tokenize=True,
36
+ return_tensors="pt",
37
+ ).to(model.device)
38
 
39
+ with torch.no_grad():
40
+ outputs = model.generate(
41
+ **inputs,
42
+ max_new_tokens=max_new_tokens,
43
+ temperature=temperature,
44
+ do_sample=True,
45
+ pad_token_id=tokenizer.eos_token_id
46
+ )
47
+
48
+ # Decode only new tokens (avoid echoing input)
49
+ response = tokenizer.decode(
50
+ outputs[0][inputs["input_ids"].shape[-1]:],
51
+ skip_special_tokens=True
52
+ ).strip()
53
+
54
+ history.append((message, response))
55
+ return history, history, ""
56
+
57
+ # --- Gradio App ---
58
  with gr.Blocks() as demo:
59
+ gr.Markdown("# 🤖 ORANSight Gemma 2 2B Instruct")
60
+
61
  chatbot = gr.Chatbot()
62
+ msg = gr.Textbox(show_label=False, placeholder="Type a message...")
63
  send = gr.Button("Send")
64
+ clear = gr.Button("Clear Chat")
65
+
66
+ max_tokens = gr.Slider(50, 512, value=128, step=10, label="Max new tokens")
67
  temperature = gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature")
68
+
69
  state = gr.State([])
70
+
71
+ msg.submit(chat, [msg, state, max_tokens, temperature], [chatbot, state, msg])
72
+ send.click(chat, [msg, state, max_tokens, temperature], [chatbot, state, msg])
 
 
 
 
 
73
  clear.click(lambda: ([], []), None, [chatbot, state])
74
 
75
  if __name__ == "__main__":