import os

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Point the Hugging Face caches at the Space's persistent storage
os.environ.setdefault("HF_HOME", "/data/.huggingface")
os.environ.setdefault("TRANSFORMERS_CACHE", "/data/.huggingface/transformers")

# Available base models (for local inference)
model_list = {
    "SafeLM 1.7B": "locuslab/safelm-1.7b",
    "SmolLM2 1.7B": "HuggingFaceTB/SmolLM2-1.7B",
    "Llama 3.2 1B": "meta-llama/Llama-3.2-1B",
}

# Hub token from the environment (set as a secret on HF Spaces); required
# for gated models such as Llama 3.2
HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN") or os.getenv("HF_TOKEN")

# Cache of already-loaded models, keyed by model ID
model_cache = {}


def load_model(model_name):
    """Load model and tokenizer, caching them for reuse."""
    if model_name not in model_cache:
        print(f"Loading model: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,  # Use float32 for CPU
            device_map="cpu",
            low_cpu_mem_usage=True,
            token=HF_TOKEN,
        )
        # Add a padding token if the tokenizer doesn't define one
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        model_cache[model_name] = {"tokenizer": tokenizer, "model": model}
        print(f"Model {model_name} loaded successfully")
    return model_cache[model_name]


def respond(message, history, max_tokens, temperature, top_p, selected_model):
    try:
        # Resolve the display name to a model ID, defaulting to SafeLM
        model_id = model_list.get(selected_model, "locuslab/safelm-1.7b")

        # Load the model and tokenizer
        try:
            model_data = load_model(model_id)
            tokenizer = model_data["tokenizer"]
            model = model_data["model"]
        except Exception as e:
            yield f"❌ Error loading model '{model_id}': {str(e)}"
            return

        # Build a plain-text conversation prompt for the base model,
        # stripping the emoji prefixes the UI adds to each turn
        conversation = ""
        for u, a in history:
            if u:
                conversation += f"User: {u.removeprefix('👤 ').strip()}\n"
            if a:
                conversation += f"Assistant: {a.removeprefix('🛡️ ').strip()}\n"

        # Add the current message
        conversation += f"User: {message}\nAssistant:"

        # Tokenize the prompt
        inputs = tokenizer.encode(conversation, return_tensors="pt")

        # Keep only the most recent tokens to bound memory use
        max_input_length = 1024
        if inputs.shape[1] > max_input_length:
            inputs = inputs[:, -max_input_length:]

        # Generate a response
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_new_tokens=min(max_tokens, 150),
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.1,
                no_repeat_ngram_size=3,
            )

        # Decode only the newly generated tokens
        new_tokens = outputs[0][inputs.shape[1]:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

        # Truncate at natural break points so the model doesn't
        # continue the dialogue on the user's behalf
        for stop_seq in ("\nUser:", "\nHuman:", "\n\n"):
            if stop_seq in response:
                response = response.split(stop_seq)[0]

        yield response if response else "I'm not sure how to respond to that."

    except Exception as e:
        yield f"❌ Error generating response: {str(e)}"
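
# Optional local smoke test of the generation path (an addition, not part of
# the original app): set the hypothetical SAFELM_SMOKE_TEST env var to run it.
# The prompt and sampling values below are illustrative only.
if os.getenv("SAFELM_SMOKE_TEST"):
    for chunk in respond("Hello!", [], max_tokens=32, temperature=0.7,
                         top_p=0.9, selected_model="SafeLM 1.7B"):
        print(chunk)
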
css = """ body { background-color: #f0f5fb; /* Light pastel blue background */ } .gradio-container { background-color: white; border-radius: 16px; box-shadow: 0 2px 10px rgba(0,0,0,0.05); max-width: 90%; margin: 15px auto; padding-bottom: 20px; } /* Header styling with diagonal shield */ .app-header { position: relative; overflow: hidden; } .app-header::before { content: "🛡️"; position: absolute; font-size: 100px; opacity: 0.1; right: -20px; top: -30px; transform: rotate(15deg); pointer-events: none; } /* Simple styling for buttons */ #send-btn { background-color: white !important; color: #333 !important; border: 2px solid #e6c200 !important; } #send-btn:hover { background-color: #fff9e6 !important; } #clear-btn { background-color: white !important; color: #333 !important; border: 2px solid #e6c200 !important; } #clear-btn:hover { background-color: #fff9e6 !important; } /* Hide elements */ footer { display: none !important; } .footer { display: none !important; } """ with gr.Blocks(css=css, analytics_enabled=False) as demo: # Custom header with branding gr.HTML("""