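"""Model loading and streaming chat generation for the KAT Hugging Face Space."""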
import spaces
import torch
import numpy as np
from typing import Generator
from transformers import AutoModelForCausalLM, AutoTokenizer
from config import MODEL_NAME, MAX_NEW_TOKENS, TEMPERATURE, DO_SAMPLE
# Global variables to store the model and tokenizer
tokenizer = None
model = None
def initialize_model():
"""Initializes and loads the model and tokenizer once onto the GPU."""
global tokenizer, model
if model is None:
try:
print(f"Loading model {MODEL_NAME}...")
# Use bfloat16 for efficiency on modern GPUs
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=dtype,
device_map="auto"
)
model.eval()
# Set padding token if not defined
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
print("Model loaded successfully.")
except Exception as e:
print(f"Failed to load model: {e}")
raise
return tokenizer, model
# Call initialization
try:
initialize_model()
except Exception as e:
print(f"Warning: Global model initialization failed: {e}")
@spaces.GPU(duration=120)
def stream_generate_response(prompt: str, history: list) -> Generator[str, None, None]:
"""
Generates a response from the KAT model with proper streaming.
"""
global tokenizer, model
# Fallback initialization
if model is None or tokenizer is None:
initialize_model()
# Convert Gradio history format to the model's chat template format
messages = []
for human, bot in history:
if human:
messages.append({"role": "user", "content": human})
if bot:
messages.append({"role": "assistant", "content": bot})
# Add the current prompt
messages.append({"role": "user", "content": prompt})
# Apply chat template
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
# Tokenize with attention mask
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
input_ids = inputs.input_ids.to(model.device)
attention_mask = inputs.attention_mask.to(model.device)
# Store initial input length
initial_length = input_ids.shape[-1]
# Generate with streaming using yield-based approach
accumulated_text = ""
generated_tokens = 0
# Generate tokens incrementally
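    # Note: each step below re-runs the full sequence through the model (no
    # past_key_values / KV cache), which is simple but roughly quadratic in
    # sequence length; transformers' `generate` with a `TextIteratorStreamer`
    # is the faster alternative if throughput matters.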
    while generated_tokens < MAX_NEW_TOKENS:
        with torch.no_grad():
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True
            )
        # Get next-token logits
        next_token_logits = outputs.logits[:, -1, :]
        # Apply temperature
        if TEMPERATURE > 0:
            next_token_logits = next_token_logits / TEMPERATURE
        if DO_SAMPLE:
            # Apply softmax and sample
            probs = torch.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
        else:
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        # Check for EOS token
        if next_token.item() == tokenizer.eos_token_id:
            break
        # Append the new token and extend the attention mask
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)
        # Re-decode the whole generated suffix each step; decoding tokens one at a
        # time can split multi-byte characters and drop spaces for some tokenizers
        accumulated_text = tokenizer.decode(
            input_ids[0, initial_length:], skip_special_tokens=True
        )
        # Yield the current accumulated text
        yield accumulated_text
        # Increment generated tokens counter
        generated_tokens += 1
    # Final yield to ensure complete text
    if accumulated_text:
        yield accumulated_text.strip()
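

# ------------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): wiring the generator above into a
# Gradio chat UI. Assumes Gradio is installed, that this module is importable
# (the module name below is hypothetical), and that the installed Gradio version
# uses the tuple-style (user, bot) history format this function expects.
#
#   import gradio as gr
#   from kat_model import stream_generate_response  # hypothetical module name
#
#   demo = gr.ChatInterface(fn=stream_generate_response)
#   demo.queue().launch()
# ------------------------------------------------------------------------------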