# KAT-Dev / models.py
import spaces
import torch
from typing import Generator
from transformers import AutoModelForCausalLM, AutoTokenizer
from config import MODEL_NAME, MAX_NEW_TOKENS, TEMPERATURE, DO_SAMPLE
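# config.py (imported above) is assumed to define roughly the following
# constants; the values below are illustrative, not the Space's actual config:
#
#     MODEL_NAME = "Kwaipilot/KAT-Dev"   # assumed checkpoint id
#     MAX_NEW_TOKENS = 1024              # assumed generation cap
#     TEMPERATURE = 0.7                  # assumed sampling temperature
#     DO_SAMPLE = True                   # assumed sampling toggle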
# Global variables to store the model and tokenizer
tokenizer = None
model = None
def initialize_model():
"""Initializes and loads the model and tokenizer once onto the GPU."""
global tokenizer, model
if model is None:
try:
print(f"Loading model {MODEL_NAME}...")
# Use bfloat16 for efficiency on modern GPUs
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=dtype,
device_map="auto"
)
model.eval()
# Set padding token if not defined
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
print("Model loaded successfully.")
except Exception as e:
print(f"Failed to load model: {e}")
raise
return tokenizer, model
# Load the model eagerly at import time so the first request does not pay the startup cost
try:
initialize_model()
except Exception as e:
print(f"Warning: Global model initialization failed: {e}")
@spaces.GPU(duration=120)
def stream_generate_response(prompt: str, history: list) -> Generator[str, None, None]:
"""
Generates a response from the KAT model with proper streaming.
"""
global tokenizer, model
# Fallback initialization
if model is None or tokenizer is None:
initialize_model()
    # Convert Gradio tuple-style history ([(user, assistant), ...]) to the
    # chat template's message format
messages = []
for human, bot in history:
if human:
messages.append({"role": "user", "content": human})
if bot:
messages.append({"role": "assistant", "content": bot})
# Add the current prompt
messages.append({"role": "user", "content": prompt})
# Apply chat template
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
# Tokenize with attention mask
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
input_ids = inputs.input_ids.to(model.device)
attention_mask = inputs.attention_mask.to(model.device)
# Store initial input length
initial_length = input_ids.shape[-1]
    # Stream the response with a manual decode loop: sample one token per step
    # and yield the running text after each new token
    accumulated_text = ""
    generated_tokens = 0
while generated_tokens < MAX_NEW_TOKENS:
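        # Note: each step below re-runs a full forward pass over the entire
        # sequence (no past_key_values / KV cache), so decoding cost grows
        # quadratically with length. A cached alternative using
        # TextIteratorStreamer is sketched at the end of this file.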
with torch.no_grad():
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask,
return_dict=True
)
# Get next token probabilities
next_token_logits = outputs.logits[:, -1, :]
# Apply temperature
if TEMPERATURE > 0:
next_token_logits = next_token_logits / TEMPERATURE
# Apply softmax and sample
probs = torch.softmax(next_token_logits, dim=-1)
if DO_SAMPLE:
next_token = torch.multinomial(probs, num_samples=1)
else:
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
# Check for EOS token
if next_token.item() == tokenizer.eos_token_id:
break
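        # Note: only the tokenizer's primary eos_token_id is checked above; chat
        # models with a separate end-of-turn token may need additional stop ids.
        # Token-by-token decoding can also briefly emit partial multi-byte
        # characters, which a streamer-based approach buffers internally.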
# Decode the new token
new_token_text = tokenizer.decode(next_token[0], skip_special_tokens=True)
# Update accumulated text
accumulated_text += new_token_text
# Yield the current accumulated text
yield accumulated_text
# Prepare for next iteration
input_ids = torch.cat([input_ids, next_token], dim=-1)
attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)
# Increment generated tokens counter
generated_tokens += 1
# Final yield to ensure complete text
if accumulated_text:
yield accumulated_text.strip()
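# --- Alternative streaming sketch (not used by the app above) ---------------
# The manual loop in stream_generate_response is easy to follow but skips the
# KV cache. transformers' TextIteratorStreamer lets model.generate run with its
# usual cache in a background thread while decoded text is consumed
# incrementally. This is a minimal sketch under the same globals and config
# values; the name stream_generate_response_fast is illustrative, and history
# handling plus the @spaces.GPU decorator are omitted for brevity.
def stream_generate_response_fast(prompt: str) -> Generator[str, None, None]:
    from threading import Thread
    from transformers import TextIteratorStreamer

    if model is None or tokenizer is None:
        initialize_model()

    messages = [{"role": "user", "content": prompt}]
    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(text, return_tensors="pt").to(model.device)

    # skip_prompt drops the echoed input; extra kwargs are forwarded to decode()
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=DO_SAMPLE,
        temperature=TEMPERATURE,
    )
    # generate() runs in a worker thread with its normal KV cache; the streamer
    # yields decoded text chunks as they are produced
    Thread(target=model.generate, kwargs=generation_kwargs).start()

    accumulated_text = ""
    for chunk in streamer:
        accumulated_text += chunk
        yield accumulated_text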