import spaces
import torch
import numpy as np
from typing import Generator
from transformers import AutoModelForCausalLM, AutoTokenizer

from config import MODEL_NAME, MAX_NEW_TOKENS, TEMPERATURE, DO_SAMPLE
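
# For reference, a minimal config.py satisfying the import above could look
# like this (values are illustrative placeholders, not the Space's actual
# settings):
#
#     MODEL_NAME = "org/kat-model-id"   # hypothetical model identifier
#     MAX_NEW_TOKENS = 1024
#     TEMPERATURE = 0.7
#     DO_SAMPLE = True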

# Global variables to store the shared model and tokenizer
tokenizer = None
model = None

def initialize_model():
    """Initializes and loads the model and tokenizer once onto the GPU."""
    global tokenizer, model
    if model is None:
        try:
            print(f"Loading model {MODEL_NAME}...")
            # Use bfloat16 for efficiency on modern GPUs
            dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
            tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                torch_dtype=dtype,
                device_map="auto",
            )
            model.eval()
            # Fall back to the EOS token if no padding token is defined
            if tokenizer.pad_token_id is None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
            print("Model loaded successfully.")
        except Exception as e:
            print(f"Failed to load model: {e}")
            raise
    return tokenizer, model


# Call initialization eagerly at import time
try:
    initialize_model()
except Exception as e:
    print(f"Warning: Global model initialization failed: {e}")

def stream_generate_response(prompt: str, history: list) -> Generator[str, None, None]:
    """
    Generates a response from the KAT model, streaming the accumulated text
    after each new token.
    """
    global tokenizer, model

    # Fallback initialization in case the import-time load failed
    if model is None or tokenizer is None:
        initialize_model()

    # Convert Gradio tuple-style history into the chat-template message format
    messages = []
    for human, bot in history:
        if human:
            messages.append({"role": "user", "content": human})
        if bot:
            messages.append({"role": "assistant", "content": bot})

    # Add the current prompt
    messages.append({"role": "user", "content": prompt})

    # Apply the model's chat template
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Tokenize with an explicit attention mask and move tensors to the model's device
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    input_ids = inputs.input_ids.to(model.device)
    attention_mask = inputs.attention_mask.to(model.device)

    # Stream generation token by token. For simplicity this re-runs the full
    # forward pass over the growing sequence on every step (no KV cache), which
    # is correct but O(n^2) in sequence length.
    accumulated_text = ""
    generated_tokens = 0

    while generated_tokens < MAX_NEW_TOKENS:
        with torch.no_grad():
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True,
            )

        # Logits for the next-token position
        next_token_logits = outputs.logits[:, -1, :]

        # Apply temperature scaling
        if TEMPERATURE > 0:
            next_token_logits = next_token_logits / TEMPERATURE

        # Sample from the softmax distribution, or take the argmax greedily
        if DO_SAMPLE:
            probs = torch.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
        else:
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

        # Stop at the end-of-sequence token
        if next_token.item() == tokenizer.eos_token_id:
            break

        # Decode the new token and append it to the running text
        new_token_text = tokenizer.decode(next_token[0], skip_special_tokens=True)
        accumulated_text += new_token_text

        # Yield the text accumulated so far
        yield accumulated_text

        # Append the new token to the context for the next iteration
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)

        generated_tokens += 1

    # Final yield to ensure the complete text is delivered
    if accumulated_text:
        yield accumulated_text.strip()
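

# Alternative sketch (not wired into the app): transformers' TextIteratorStreamer
# can stream from model.generate() running in a background thread, reusing the
# KV cache instead of re-encoding the whole sequence every step as the manual
# loop above does. The function below is illustrative only; any name not already
# defined in this module is an assumption.
def stream_generate_response_with_streamer(prompt: str, history: list) -> Generator[str, None, None]:
    """Illustrative streaming variant built on model.generate() + TextIteratorStreamer."""
    from threading import Thread
    from transformers import TextIteratorStreamer

    messages = []
    for human, bot in history:
        if human:
            messages.append({"role": "user", "content": human})
        if bot:
            messages.append({"role": "assistant", "content": bot})
    messages.append({"role": "user", "content": prompt})

    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors="pt").to(model.device)

    # The streamer yields decoded text fragments as generate() produces tokens.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=DO_SAMPLE,
        temperature=TEMPERATURE,  # ignored (with a warning) when DO_SAMPLE is False
        pad_token_id=tokenizer.pad_token_id,
    )
    Thread(target=model.generate, kwargs=generation_kwargs).start()

    accumulated = ""
    for fragment in streamer:
        accumulated += fragment
        yield accumulated

# Either generator can be dropped into a Gradio ChatInterface as its fn, e.g.
# gr.ChatInterface(fn=stream_generate_response). This is illustrative; the
# Space's actual UI wiring lives elsewhere.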