'''
This script creates a Gradio chatbot interface for the ibm-granite/granite-3.3-8b-instruct model.

Key Features:
- Loads the model and tokenizer from the Hugging Face Hub.
- Uses a chat interface for interactive conversations.
- Passes the chat history back to the model to maintain conversational context.
- Reads the Hugging Face access token from a Spaces secret (HUGGINGFACE_TOKEN).
'''
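
# To run outside of a Hugging Face Space (a sketch; the exact package set and
# the filename "app.py" are assumptions, not part of the original script):
#   pip install gradio torch transformers accelerate
#   export HUGGINGFACE_TOKEN=<your token>
#   python app.py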

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
import os

# --- Configuration ---
MODEL_ID = "ibm-granite/granite-3.3-8b-instruct"

# --- Model and Tokenizer Loading ---
def load_model_and_tokenizer():
    '''Load the model and tokenizer, handling potential errors.'''
    try:
        # Securely get the Hugging Face token from secrets
        hf_token = os.getenv("HUGGINGFACE_TOKEN")
        if not hf_token:
            raise ValueError("HUGGINGFACE_TOKEN secret not found. Please add it to your Space settings.")

        device = "cuda" if torch.cuda.is_available() else "cpu"
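        # Assumption: the Space provides a GPU; on a CPU-only runtime the
        # bfloat16 load below still works but inference will be slow, and
        # float32 may be safer on older hardware.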
        
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            device_map=device,
            torch_dtype=torch.bfloat16,
            token=hf_token
        )
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=hf_token)
        
        return model, tokenizer, device
    except Exception as e:
        # Provide a user-friendly error message
        raise RuntimeError(f"Failed to load model or tokenizer: {e}")

model, tokenizer, device = load_model_and_tokenizer()

# --- Chatbot Logic ---
def chat_function(message, history):
    '''Generate the model's response to the user's message, given the prior chat history.'''
    # Set seed for reproducibility
    set_seed(42)

    # Format the conversation history for the model
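    # Assumption: history arrives in Gradio's default tuple format (a list of
    # [user_message, assistant_message] pairs); if the interface is created
    # with type="messages", entries are dicts and this loop would need adjusting.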
    conv = []
    for user_msg, model_msg in history:
        conv.append({"role": "user", "content": user_msg})
        conv.append({"role": "assistant", "content": model_msg})
    conv.append({"role": "user", "content": message})

    # Tokenize the input
    input_ids = tokenizer.apply_chat_template(
        conv, 
        return_tensors="pt", 
        thinking=False,  # Disable Granite's extended "thinking" mode for a direct response
        add_generation_prompt=True
    ).to(device)

    # Generate the response
    output = model.generate(
        input_ids,
        max_new_tokens=1024,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=0.7,
    )

    # Decode the prediction
    prediction = tokenizer.decode(output[0, input_ids.shape[1]:], skip_special_tokens=True)
    
    return prediction

# --- Gradio Interface ---
def create_gradio_interface():
    '''Create and return the Gradio ChatInterface.'''
    return gr.ChatInterface(
        fn=chat_function,
        title="Granite 3.3 8B Chatbot",
        description="A chatbot powered by the ibm-granite/granite-3.3-8b-instruct model. Ask any question!",
        theme="soft",
        examples=[
            ["Hello, who are you?"],
            ["What is the capital of France?"],
            ["Explain the theory of relativity in simple terms."]
        ]
    )

# --- Main Execution ---
if __name__ == "__main__":
    chatbot_interface = create_gradio_interface()
    chatbot_interface.launch()