"""Gradio chat UI for the ORANSight Gemma 2 2B Instruct model.

Loads the tokenizer and model once at import time, exposes a ``chat``
callback that turns the running conversation into a chat-template prompt,
samples a reply, and wires that callback into a small ``gr.Blocks`` UI.
"""

import traceback

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

# Model repository to load (from NextGLab).
MODEL_NAME = "NextGLab/ORANSight_Gemma_2_2B_Instruct"

# Load tokenizer and model once; every chat request reuses them.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",  # let Transformers choose the weight dtype
    device_map="auto",  # place the weights on a GPU if one is available
)


# --- Chat function ---
def _build_messages(history, message):
    """Convert (user, assistant) history pairs plus the new user message
    into the role/content dict list expected by ``apply_chat_template``."""
    messages = []
    for user_msg, bot_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": message})
    return messages


def chat(message, history, max_new_tokens=128, temperature=0.7):
    """Generate a reply to ``message`` given the conversation so far.

    Args:
        message: The new user message from the textbox.
        history: List of ``(user_msg, bot_msg)`` pairs from earlier turns.
            It is not mutated; a new list is returned instead.
        max_new_tokens: Upper bound on generated tokens. Coerced to ``int``
            because UI sliders may deliver a float.
        temperature: Sampling temperature passed to ``model.generate``.

    Returns:
        A ``(chatbot_value, state_value, textbox_value)`` tuple matching the
        three Gradio outputs. The textbox value is always ``""`` so the input
        is cleared. On failure the chatbot shows an error entry, but the
        state is returned unchanged so the error text is never fed back to
        the model as an assistant turn.
    """
    # Nothing to answer: leave the conversation untouched instead of
    # spending a generation on an empty prompt.
    if not message or not message.strip():
        return history, history, ""

    try:
        messages = _build_messages(history, message)

        # return_dict=True yields input_ids together with attention_mask.
        # Passing the mask explicitly matters here because pad_token_id is
        # set to the EOS id below, so generate() cannot infer it reliably.
        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(model.device)
        prompt_len = inputs["input_ids"].shape[-1]

        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            temperature=float(temperature),
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

        # Decode only the newly generated tokens (skip the echoed prompt).
        response = tokenizer.decode(
            outputs[0][prompt_len:],
            skip_special_tokens=True,
        ).strip()

        updated = history + [(message, response)]
        return updated, updated, ""

    except Exception as exc:
        # UI boundary: log the full traceback to the console and surface a
        # short error in the chat window rather than crashing the callback.
        traceback.print_exc()
        return history + [(message, f"⚠️ Error: {exc}")], history, ""


# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("# 🤖 ORANSight Gemma Chat (2B Instruct)")

    chatbot = gr.Chatbot()
    msg = gr.Textbox(show_label=False, placeholder="Type a message...")
    send = gr.Button("Send")
    clear = gr.Button("Clear Chat")

    with gr.Row():
        max_tokens = gr.Slider(50, 512, step=10, value=128, label="Max tokens")
        temperature = gr.Slider(0.1, 1.5, step=0.1, value=0.7, label="Temperature")

    # Conversation history as a list of (user, assistant) pairs.
    state = gr.State([])

    # Pressing Enter in the textbox and clicking "Send" run the same callback.
    msg.submit(chat, [msg, state, max_tokens, temperature], [chatbot, state, msg])
    send.click(chat, [msg, state, max_tokens, temperature], [chatbot, state, msg])
    # Reset both the visible chat and the stored history.
    clear.click(lambda: ([], []), None, [chatbot, state])


if __name__ == "__main__":
    demo.launch()