'''
This script creates a Gradio chatbot interface for the ibm-granite/granite-3.3-8b-instruct model.
Key Features:
- Loads the model and tokenizer from Hugging Face Hub.
- Uses a chat interface for interactive conversations.
- Manages chat history to maintain context.
- Reads the Hugging Face access token from a Spaces secret (HUGGINGFACE_TOKEN).
'''
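# Note: the runtime dependencies assumed here are gradio, torch, and transformers,
# plus accelerate, which transformers requires when `device_map` is passed to from_pretrained.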
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
import os
# --- Configuration ---
MODEL_ID = "ibm-granite/granite-3.3-8b-instruct"
# --- Model and Tokenizer Loading ---
def load_model_and_tokenizer():
    '''Load the model and tokenizer, handling potential errors.'''
    try:
        # Securely get the Hugging Face token from secrets
        hf_token = os.getenv("HUGGINGFACE_TOKEN")
        if not hf_token:
            raise ValueError("HUGGINGFACE_TOKEN secret not found. Please add it to your Space settings.")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            device_map=device,
            torch_dtype=torch.bfloat16,
            token=hf_token
        )
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=hf_token)
        return model, tokenizer, device
    except Exception as e:
        # Provide a user-friendly error message
        raise RuntimeError(f"Failed to load model or tokenizer: {e}")
model, tokenizer, device = load_model_and_tokenizer()
# --- Chatbot Logic ---
def chat_function(message, history):
    '''
    This function processes the user's message and returns the model's response.
    '''
    # Set seed for reproducibility
    set_seed(42)
    # Format the conversation history for the model
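    # Assumption: with Gradio's default (tuple-style) ChatInterface, `history` is a
    # list of (user_message, assistant_message) pairs, which is what this loop expects.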
    conv = []
    for user_msg, model_msg in history:
        conv.append({"role": "user", "content": user_msg})
        conv.append({"role": "assistant", "content": model_msg})
    conv.append({"role": "user", "content": message})
    # Tokenize the input
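    # The `thinking` keyword below is passed through to the model's chat template;
    # Granite 3.x templates use it to toggle reasoning output, and templates that
    # do not define the variable simply ignore it.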
    input_ids = tokenizer.apply_chat_template(
        conv,
        return_tensors="pt",
        thinking=False,  # Set to False for direct response
        add_generation_prompt=True
    ).to(device)
    # Generate the response
    output = model.generate(
        input_ids,
        max_new_tokens=1024,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=0.7,
    )
    # Decode only the newly generated tokens (everything after the prompt)
    prediction = tokenizer.decode(output[0, input_ids.shape[1]:], skip_special_tokens=True)
    return prediction
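# Illustrative usage (not executed here): chat_function("What is the capital of France?", [])
# would return the model's reply as a plain string, given an empty history.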
# --- Gradio Interface ---
def create_gradio_interface():
    '''Create and return the Gradio ChatInterface.'''
    return gr.ChatInterface(
        fn=chat_function,
        title="Granite 3.3 8B Chatbot",
        description="A chatbot powered by the ibm-granite/granite-3.3-8b-instruct model. Ask any question!",
        theme="soft",
        examples=[
            ["Hello, who are you?"],
            ["What is the capital of France?"],
            ["Explain the theory of relativity in simple terms."]
        ]
    )
# --- Main Execution ---
if __name__ == "__main__":
    chatbot_interface = create_gradio_interface()
    chatbot_interface.launch()
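    # On Hugging Face Spaces, launch() with its default arguments is typically enough;
    # options such as share=True mainly matter when running locally.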