import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

# Choose your model repo (from NextGLab)
MODEL_NAME = "NextGLab/ORANSight_Gemma_2_2B_Instruct"

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",   # lets HF decide (fp16/bf16/fp32 depending on GPU)
    device_map="auto"     # put on GPU if available
)
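
# A lower-memory alternative (a hedged sketch, not part of the original file):
# 4-bit loading via transformers' BitsAndBytesConfig can fit the model on
# smaller GPUs. It requires the optional bitsandbytes package and a CUDA device.
#
# from transformers import BitsAndBytesConfig
# model = AutoModelForCausalLM.from_pretrained(
#     MODEL_NAME,
#     quantization_config=BitsAndBytesConfig(load_in_4bit=True),
#     device_map="auto",
# )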

# --- Chat function ---
def chat(message, history, max_new_tokens=128, temperature=0.7):
    try:
        # Convert conversation history into messages
        messages = []
        for user_msg, bot_msg in history:
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "assistant", "content": bot_msg})
        messages.append({"role": "user", "content": message})

        # Apply the chat template; return_dict=True yields both input_ids
        # and attention_mask, so generate(**inputs) gets a proper mask
        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt"
        ).to(model.device)

        # Generate output (Gradio sliders pass floats, so cast the token count)
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            temperature=temperature,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id  # Gemma defines no pad token
        )

        # Decode only the newly generated tokens, skipping the prompt
        response = tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[-1]:],
            skip_special_tokens=True
        ).strip()

        history.append((message, response))
        return history, history, ""

    except Exception as e:
        import traceback
        traceback.print_exc()
        return history + [(message, f"⚠️ Error: {str(e)}")], history, ""
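
# Optional streaming sketch (an assumption, not part of the original app):
# transformers' TextIteratorStreamer can yield text chunks as they are
# generated, which Gradio could surface by yielding partial histories.
# Left commented out; the blocking generate() above is what the UI uses.
#
# from threading import Thread
# from transformers import TextIteratorStreamer
#
# def stream_reply(inputs, max_new_tokens=128):
#     streamer = TextIteratorStreamer(
#         tokenizer, skip_prompt=True, skip_special_tokens=True
#     )
#     Thread(target=model.generate,
#            kwargs=dict(**inputs, max_new_tokens=int(max_new_tokens),
#                        streamer=streamer)).start()
#     for chunk in streamer:  # decoded text, prompt tokens skipped
#         yield chunk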


# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("# 🤖 ORANSight Gemma Chat (2B Instruct)")

    chatbot = gr.Chatbot()
    msg = gr.Textbox(show_label=False, placeholder="Type a message...")
    send = gr.Button("Send")
    clear = gr.Button("Clear Chat")

    with gr.Row():
        max_tokens = gr.Slider(50, 512, step=10, value=128, label="Max tokens")
        temperature = gr.Slider(0.1, 1.5, step=0.1, value=0.7, label="Temperature")

    state = gr.State([])

    msg.submit(chat, [msg, state, max_tokens, temperature], [chatbot, state, msg])
    send.click(chat, [msg, state, max_tokens, temperature], [chatbot, state, msg])
    clear.click(lambda: ([], []), None, [chatbot, state])

if __name__ == "__main__":
    demo.launch()
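
# To run (assumed setup; the file name and package versions are illustrative):
#   pip install torch transformers gradio
#   python app.py
# demo.launch(share=True) would additionally serve a temporary public URL.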