import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
# Choose your model repo (from NextGLab)
MODEL_NAME = "NextGLab/ORANSight_Gemma_2_2B_Instruct"
# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",  # lets HF decide (fp16/bf16/fp32 depending on GPU)
    device_map="auto"    # put on GPU if available
)
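# Optional (untested sketch, assumes the bitsandbytes package is installed):
# on small GPUs the model could instead be loaded with 4-bit quantization:
#
#   from transformers import BitsAndBytesConfig
#   model = AutoModelForCausalLM.from_pretrained(
#       MODEL_NAME,
#       quantization_config=BitsAndBytesConfig(load_in_4bit=True),
#       device_map="auto",
#   )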
# --- Chat function ---
def chat(message, history, max_new_tokens=128, temperature=0.7):
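    # NOTE: "history" is a list of (user, assistant) tuples, the default
    # "tuples" format of gr.Chatbot in Gradio 4.x; newer Gradio versions
    # deprecate tuples in favor of the "messages" format and would need
    # this loop adapted.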
    try:
        # Convert conversation history into messages
        messages = []
        for user_msg, bot_msg in history:
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "assistant", "content": bot_msg})
        messages.append({"role": "user", "content": message})
        # Apply chat template -> returns tensor of input_ids
        input_ids = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt"
        ).to(model.device)
        # Wrap as dict so generate(**inputs) works
        inputs = {"input_ids": input_ids}
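        # NOTE: generate() warns about a missing attention_mask; for a
        # single unpadded sequence this is harmless, but one could be
        # passed explicitly (untested sketch):
        #   inputs["attention_mask"] = torch.ones_like(input_ids)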
        # Generate output
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )
        # Decode new tokens only
        response = tokenizer.decode(
            outputs[0][input_ids.shape[-1]:],
            skip_special_tokens=True
        ).strip()
        history.append((message, response))
        return history, history, ""
    except Exception as e:
        import traceback
        traceback.print_exc()
        return history + [(message, f"⚠️ Error: {str(e)}")], history, ""
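# Optional (untested sketch, assumes transformers' TextIteratorStreamer):
# for token-by-token streaming, generate() could run in a background thread
# and chat() become a generator that yields partial replies:
#
#   from threading import Thread
#   from transformers import TextIteratorStreamer
#   streamer = TextIteratorStreamer(tokenizer, skip_prompt=True,
#                                   skip_special_tokens=True)
#   Thread(target=model.generate,
#          kwargs=dict(**inputs, streamer=streamer,
#                      max_new_tokens=max_new_tokens)).start()
#   partial = ""
#   for new_text in streamer:
#       partial += new_text  # yield partial to update the Chatbot live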
# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("# 🤖 ORANSight Gemma Chat (2B Instruct)")
    chatbot = gr.Chatbot()
    msg = gr.Textbox(show_label=False, placeholder="Type a message...")
    send = gr.Button("Send")
    clear = gr.Button("Clear Chat")
    with gr.Row():
        max_tokens = gr.Slider(50, 512, step=10, value=128, label="Max tokens")
        temperature = gr.Slider(0.1, 1.5, step=0.1, value=0.7, label="Temperature")
    state = gr.State([])

    msg.submit(chat, [msg, state, max_tokens, temperature], [chatbot, state, msg])
    send.click(chat, [msg, state, max_tokens, temperature], [chatbot, state, msg])
    clear.click(lambda: ([], []), None, [chatbot, state])
if __name__ == "__main__":
    demo.launch()
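# To run locally (assumed environment; this file does not pin versions):
#   pip install torch transformers accelerate gradio
#   python app.py
# Note: device_map="auto" requires the accelerate package.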