Spaces: Runtime error
| import gradio as gr | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer, BitsAndBytesConfig | |
| from threading import Thread | |
| import time | |
| import logging | |
| import gc | |
| from pathlib import Path | |
| import re | |
| from huggingface_hub import HfApi, list_models | |
| import os | |
| import queue | |
| import threading | |
| from collections import deque | |
| # Set PyTorch memory management environment variables | |
| os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' | |
| # Configure logging | |
| logging.basicConfig( | |
| level=logging.INFO, | |
| format='%(asctime)s - %(levelname)s - %(message)s', | |
| handlers=[ | |
| logging.FileHandler('gradio-chat-ui.log'), | |
| logging.StreamHandler() | |
| ] | |
| ) | |
| logger = logging.getLogger(__name__) | |
| # Log memory management settings | |
| logger.info(f"PyTorch CUDA allocation config: {os.environ.get('PYTORCH_CUDA_ALLOC_CONF')}") | |
| logger.info(f"CUDA device count: {torch.cuda.device_count() if torch.cuda.is_available() else 'N/A'}") | |
| # Model parameters | |
| MODEL_NAME = "No Model Loaded" | |
| MAX_LENGTH = 16384 | |
| DEFAULT_TEMPERATURE = 0.15 | |
| DEFAULT_TOP_P = 0.93 | |
| DEFAULT_TOP_K = 50 | |
| DEFAULT_REP_PENALTY = 1.15 | |
| # Base location for local models | |
| LOCAL_MODELS_BASE = "/home/llm-models/" | |
| # Global variables | |
| model = None | |
| tokenizer = None | |
| hf_api = HfApi() | |
| # Generation metadata storage with automatic cleanup | |
| generation_metadata = deque(maxlen=100) # Fixed size deque to prevent unlimited growth | |
| class RAMSavingIteratorStreamer: | |
| """ | |
| Custom streamer that saves VRAM by moving tokens to CPU and provides iteration interface for Gradio. | |
| Combines the benefits of TextStreamer (RAM saving) with TextIteratorStreamer (iteration). | |
| """ | |
| def __init__(self, tokenizer, skip_special_tokens=True, skip_prompt=True, timeout=None): | |
| self.tokenizer = tokenizer | |
| self.skip_special_tokens = skip_special_tokens | |
| self.skip_prompt = skip_prompt | |
| self.timeout = timeout | |
| # Token and text storage (CPU-based) | |
| self.generated_tokens = [] | |
| self.generated_text = "" | |
| self.token_cache = "" | |
| # Queue for streaming interface | |
| self.text_queue = queue.Queue() | |
| self.stop_signal = threading.Event() | |
| # Track prompt tokens to skip them | |
| self.prompt_length = 0 | |
| self.tokens_processed = 0 | |
| # Decoding state | |
| self.print_len = 0 | |
| def put(self, value): | |
| """ | |
| Receive new token(s) and process them for streaming. | |
| This method is called by the model during generation. | |
| """ | |
        try:
            # Normalize the incoming value to a flat list of token ids
            if isinstance(value, torch.Tensor):
                if value.dim() > 1:
                    value = value[0]  # Remove batch dimension if present
                token_ids = value.tolist()
            else:
                token_ids = value if isinstance(value, list) else [value]
            # Track tokens processed
            if isinstance(token_ids, list):
                self.tokens_processed += len(token_ids)
            else:
                self.tokens_processed += 1
            # Skip prompt tokens if requested; return before storing them so the
            # prompt never leaks into the decoded stream or the token history
            if self.skip_prompt and self.tokens_processed <= self.prompt_length:
                return
            # Store a CPU copy to save VRAM
            if isinstance(value, torch.Tensor):
                self.generated_tokens.append(value.detach().cpu())
            else:
                self.generated_tokens.append(torch.tensor(token_ids, dtype=torch.long))
| # Decode incrementally for real-time streaming | |
| try: | |
| # Get all generated tokens so far | |
| if self.generated_tokens: | |
| all_tokens = [] | |
| for tokens in self.generated_tokens: | |
| if isinstance(tokens, torch.Tensor): | |
| if tokens.dim() == 0: | |
| all_tokens.append(tokens.item()) | |
| else: | |
| all_tokens.extend(tokens.tolist()) | |
| elif isinstance(tokens, list): | |
| all_tokens.extend(tokens) | |
| else: | |
| all_tokens.append(tokens) | |
| # Decode the full sequence | |
| full_text = self.tokenizer.decode( | |
| all_tokens, | |
| skip_special_tokens=self.skip_special_tokens | |
| ) | |
| # Get new text since last update | |
| if len(full_text) > self.print_len: | |
| new_text = full_text[self.print_len:] | |
| self.print_len = len(full_text) | |
| self.generated_text = full_text | |
| # Put new text in queue for iteration | |
| if new_text: | |
| self.text_queue.put(new_text) | |
| except Exception as decode_error: | |
| logger.warning(f"Decoding error in streamer: {decode_error}") | |
| except Exception as e: | |
| logger.error(f"Error in RAMSavingIteratorStreamer.put: {e}") | |
| def end(self): | |
| """Signal end of generation.""" | |
| self.text_queue.put(None) # Sentinel value | |
| def __iter__(self): | |
| """Make this streamer iterable for Gradio compatibility.""" | |
| return self | |
| def __next__(self): | |
| """Get next chunk of text for streaming.""" | |
| try: | |
| value = self.text_queue.get(timeout=self.timeout) | |
| if value is None: # End signal | |
| raise StopIteration | |
| return value | |
| except queue.Empty: | |
| raise StopIteration | |
| def set_prompt_length(self, prompt_length): | |
| """Set the length of prompt tokens to skip.""" | |
| self.prompt_length = prompt_length | |
| def get_generated_text(self): | |
| """Get the complete generated text.""" | |
| return self.generated_text | |
| def get_generated_tokens(self): | |
| """Get all generated tokens as a single tensor.""" | |
| if not self.generated_tokens: | |
| return torch.tensor([]) | |
| # Combine all tokens | |
| all_tokens = [] | |
| for tokens in self.generated_tokens: | |
| if isinstance(tokens, torch.Tensor): | |
| if tokens.dim() == 0: | |
| all_tokens.append(tokens.item()) | |
| else: | |
| all_tokens.extend(tokens.tolist()) | |
| elif isinstance(tokens, list): | |
| all_tokens.extend(tokens) | |
| else: | |
| all_tokens.append(tokens) | |
| return torch.tensor(all_tokens, dtype=torch.long) | |
| def cleanup(self): | |
| """Clean up resources.""" | |
| self.generated_tokens.clear() | |
| self.generated_text = "" | |
| self.token_cache = "" | |
| # Clear queue | |
| while not self.text_queue.empty(): | |
| try: | |
| self.text_queue.get_nowait() | |
| except queue.Empty: | |
| break | |
| self.stop_signal.set() | |
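# Illustrative sketch only (not executed): roughly how this streamer is wired up
# later in chat_with_model(). The names prompt_ids and loaded_model are
# placeholders for this comment, not objects defined at this point in the file.
#
#     streamer = RAMSavingIteratorStreamer(tokenizer, skip_special_tokens=True,
#                                          skip_prompt=True, timeout=1.0)
#     streamer.set_prompt_length(prompt_ids.shape[1])
#     Thread(target=loaded_model.generate,
#            kwargs={"input_ids": prompt_ids, "max_new_tokens": 64,
#                    "streamer": streamer},
#            daemon=True).start()
#     for chunk in streamer:  # __next__() pops text queued by put()
#         print(chunk, end="", flush=True)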
| def scan_local_models(base_path=LOCAL_MODELS_BASE): | |
| """Scan for valid models in the local models directory""" | |
| try: | |
| base_path = Path(base_path) | |
| if not base_path.exists(): | |
| logger.warning(f"Base path does not exist: {base_path}") | |
| return [] | |
| valid_models = [] | |
| # Scan subdirectories (depth 1 only) | |
| for item in base_path.iterdir(): | |
| if item.is_dir(): | |
| # Check if directory contains required model files | |
| config_file = item / "config.json" | |
| # Look for model weight files (safetensors or bin) | |
| safetensors_files = list(item.glob("*.safetensors")) | |
| bin_files = list(item.glob("*.bin")) | |
| # Check if it's a valid model directory | |
| if config_file.exists() and (safetensors_files or bin_files): | |
| valid_models.append(str(item)) | |
| logger.info(f"Found valid model: {item}") | |
| # Sort models for consistent ordering | |
| valid_models.sort() | |
| logger.info(f"Found {len(valid_models)} valid models in {base_path}") | |
| return valid_models | |
| except Exception as e: | |
| logger.error(f"Error scanning local models: {e}") | |
| return [] | |
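# For reference, scan_local_models() treats a depth-1 subdirectory as a model
# only if it contains config.json plus at least one *.safetensors or *.bin file.
# Example layout (directory names are illustrative, not required):
#
#   /home/llm-models/
#   ├── some-model-4bit/
#   │   ├── config.json
#   │   └── model.safetensors
#   └── another-model/
#       ├── config.json
#       └── pytorch_model.bin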
| def update_local_models_dropdown(base_path): | |
| """Update the local models dropdown based on base path""" | |
| if not base_path or not base_path.strip(): | |
| return gr.Dropdown(choices=[], value=None, interactive=True) | |
| models = scan_local_models(base_path) | |
| model_choices = [Path(model).name for model in models] # Show just the model name | |
| model_paths = models # Keep full paths for internal use | |
| # Create a mapping for display name to full path | |
| if model_choices: | |
| return gr.Dropdown( | |
| choices=list(zip(model_choices, model_paths)), | |
| value=model_paths[0] if model_paths else None, | |
| label="๐ Available Local Models", | |
| interactive=True, | |
| allow_custom_value=False, # Don't allow custom for local models | |
| filterable=True | |
| ) | |
| else: | |
| return gr.Dropdown( | |
| choices=[], | |
| value=None, | |
| label="๐ Available Local Models (None found)", | |
| interactive=True, | |
| allow_custom_value=False, | |
| filterable=True | |
| ) | |
| def search_hf_models(query, limit=20): | |
| """Enhanced search for models on Hugging Face Hub with better coverage""" | |
| if not query or len(query.strip()) < 2: | |
| return [] | |
| try: | |
| query = query.strip() | |
| model_choices = [] | |
| # Strategy 1: Direct model ID search (if query looks like a model ID) | |
| if '/' in query: | |
| try: | |
| # Try to get the specific model | |
| model_info = hf_api.model_info(query) | |
| if model_info and hasattr(model_info, 'id'): | |
| model_choices.append(model_info.id) | |
| logger.info(f"Found direct model: {model_info.id}") | |
| except Exception as direct_error: | |
| logger.debug(f"Direct model search failed: {direct_error}") | |
| # Strategy 2: Search with different parameters | |
| search_strategies = [ | |
| # Exact search | |
| {"search": query, "sort": "downloads", "direction": -1, "limit": limit//2}, | |
| # Author search (if query contains /) | |
| {"author": query.split('/')[0] if '/' in query else query, "sort": "downloads", "direction": -1, "limit": limit//4} if '/' in query else None, | |
| # Broader search | |
| {"search": query, "sort": "trending", "direction": -1, "limit": limit//4}, | |
| ] | |
| for strategy in search_strategies: | |
| if strategy is None: | |
| continue | |
| try: | |
| models = list_models( | |
| task="text-generation", | |
| **strategy | |
| ) | |
| for model in models: | |
| if model.id not in model_choices: | |
| model_choices.append(model.id) | |
| except Exception as strategy_error: | |
| logger.debug(f"Search strategy failed: {strategy_error}") | |
| # Remove duplicates while preserving order | |
| seen = set() | |
| unique_choices = [] | |
| for choice in model_choices: | |
| if choice not in seen: | |
| seen.add(choice) | |
| unique_choices.append(choice) | |
| # Limit results | |
| final_choices = unique_choices[:limit] | |
| logger.info(f"HF search for '{query}' returned {len(final_choices)} models") | |
| return final_choices | |
| except Exception as e: | |
| logger.error(f"Error searching models: {str(e)}") | |
| return [] | |
| def update_model_dropdown(query): | |
| """Update dropdown with enhanced search results""" | |
| if not query or len(query.strip()) < 2: | |
| return gr.Dropdown(choices=[], value=None, interactive=True) | |
| choices = search_hf_models(query, limit=20) | |
| return gr.Dropdown( | |
| choices=choices, | |
| value=choices[0] if choices else None, | |
| interactive=True, | |
| allow_custom_value=True, # Allow manual typing | |
| filterable=True | |
| ) | |
| def load_model_with_progress(model_source, hf_model, local_path, local_model_selection, quantization, memory_optimization): | |
| """Load model with progress tracking and memory optimization""" | |
| global model, tokenizer, MODEL_NAME | |
    # Determine model path based on source
    if model_source == "Hugging Face Model":
        if not hf_model:
            yield "❌ Error: Please select a model from the dropdown"
            return
        model_path = hf_model
    else:
        # Use selected local model if available, otherwise use manual path
        if local_model_selection:
            model_path = local_model_selection
        else:
            model_path = local_path
        if not Path(model_path).exists():
            logger.error(f"Local path does not exist: {model_path}")
            yield f"❌ Error: Local path does not exist: {model_path}"
            return
| MODEL_NAME = model_path.split("/")[-1] if "/" in model_path else model_path | |
| logger.info(f"Loading model from {model_path} with memory optimization: {memory_optimization}") | |
| try: | |
        # Yield progress updates
        yield "🔄 Initializing model loading..."
| # Setup memory configuration (GPU-only, generous allocation) | |
| if torch.cuda.is_available(): | |
| device_properties = torch.cuda.get_device_properties(0) | |
| total_memory_gb = device_properties.total_memory / (1024**3) | |
            # Cap GPU memory for model weights at 11.5 GB (GPU-bound allocation)
            max_memory_val = 11.5
            max_memory = f"{max_memory_val}GB"
            logger.info(f"Setting max GPU memory to {max_memory} (Total available: {total_memory_gb:.2f}GB)")
        else:
            max_memory = "11GB"
            logger.info("CUDA not available. Using CPU fallback.")
        yield "🔄 Configuring quantization settings..."
| # Configure quantization (removed CPU offloading) | |
| bnb_config = BitsAndBytesConfig( | |
| load_in_4bit=quantization == "4bit", | |
| load_in_8bit=quantization == "8bit", | |
| bnb_4bit_use_double_quant=True, | |
| bnb_4bit_compute_dtype=torch.bfloat16, | |
| bnb_4bit_quant_type="nf4", | |
| ) | |
| yield "๐ Loading tokenizer..." | |
| # Load tokenizer | |
| if model_source == "Local Path": | |
| tokenizer = AutoTokenizer.from_pretrained( | |
| model_path, | |
| trust_remote_code=True, | |
| local_files_only=True | |
| ) | |
| else: | |
| tokenizer = AutoTokenizer.from_pretrained( | |
| model_path, | |
| trust_remote_code=True | |
| ) | |
| yield "๐ Cleaning memory cache..." | |
| # Clean memory | |
| gc.collect() | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| # Determine torch dtype | |
| if quantization in ["4bit", "8bit"]: | |
| torch_dtype = torch.bfloat16 | |
| elif quantization == "f16": | |
| torch_dtype = torch.float16 | |
| else: # bf16 | |
| torch_dtype = torch.bfloat16 | |
| yield "๐ Loading model weights (this may take a while)..." | |
| # Simple GPU-only model loading parameters | |
| model_kwargs = { | |
| "device_map": "auto", | |
| "max_memory": {0: max_memory} if torch.cuda.is_available() else None, | |
| "torch_dtype": torch_dtype, | |
| "quantization_config": bnb_config if quantization in ["4bit", "8bit"] else None, | |
| "trust_remote_code": True, | |
| } | |
| # Memory optimization specific settings (GPU-only) | |
| if memory_optimization: | |
| model_kwargs.update({ | |
| "attn_implementation": "flash_attention_2" if torch.cuda.is_available() else "sdpa", | |
| "use_cache": False, # Disable cache by default for memory optimization | |
| }) | |
| else: | |
| model_kwargs.update({ | |
| "attn_implementation": "flash_attention_2" if torch.cuda.is_available() else "sdpa", | |
| #"use_cache": True, # Enable cache for performance | |
| }) | |
| # Add local files only for local models | |
| if model_source == "Local Path": | |
| model_kwargs["local_files_only"] = True | |
| # Load model | |
| model = AutoModelForCausalLM.from_pretrained(model_path, **model_kwargs) | |
| # Post-loading memory optimization | |
| if memory_optimization: | |
| yield "๐ Applying memory optimizations..." | |
| # Additional memory cleanup after loading | |
| gc.collect() | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| torch.cuda.synchronize() | |
| logger.info("Model loaded successfully with memory optimization") | |
| yield "โ Model loaded successfully with memory optimization!" if memory_optimization else "โ Model loaded successfully!" | |
| except Exception as e: | |
| logger.error(f"Error loading model: {str(e)}", exc_info=True) | |
| yield f"โ Error loading model: {str(e)}" | |
| def unload_model(): | |
| """Unload the model and free memory with aggressive cleanup""" | |
| global model, tokenizer, MODEL_NAME | |
| if model is None: | |
| return "No model loaded" | |
| try: | |
| logger.info("Unloading model with aggressive memory cleanup...") | |
| # Step 1: Move model to CPU first (if it was on GPU) | |
| if torch.cuda.is_available() and hasattr(model, 'device'): | |
| try: | |
| model.cpu() | |
| logger.info("Model moved to CPU") | |
| except Exception as cpu_error: | |
| logger.warning(f"Could not move model to CPU: {cpu_error}") | |
| # Step 2: Clear model cache if available | |
| if hasattr(model, 'clear_cache'): | |
| model.clear_cache() | |
| # Step 3: Delete model and tokenizer references | |
| del model | |
| del tokenizer | |
| model = None | |
| tokenizer = None | |
| # Step 4: Reset model name | |
| MODEL_NAME = "No Model Loaded" | |
| # Step 5: Clear metadata deque | |
| generation_metadata.clear() | |
| # Step 6: Aggressive garbage collection (multiple rounds) | |
| for i in range(5): # More aggressive - 5 rounds | |
| gc.collect() | |
| time.sleep(0.1) # Small delay between rounds | |
| # Step 7: Aggressive CUDA cleanup | |
| if torch.cuda.is_available(): | |
| logger.info("Performing aggressive CUDA cleanup...") | |
| # Multiple rounds of cache clearing | |
| for i in range(5): | |
| torch.cuda.empty_cache() | |
| torch.cuda.synchronize() | |
| # Additional PyTorch CUDA cleanup | |
| if hasattr(torch.cuda, 'ipc_collect'): | |
| torch.cuda.ipc_collect() | |
| # Reset memory stats | |
| if hasattr(torch.cuda, 'reset_peak_memory_stats'): | |
| torch.cuda.reset_peak_memory_stats() | |
| if hasattr(torch.cuda, 'reset_accumulated_memory_stats'): | |
| torch.cuda.reset_accumulated_memory_stats() | |
| time.sleep(0.1) | |
| # Step 8: Force PyTorch to release all unused memory | |
| if torch.cuda.is_available(): | |
| try: | |
| # Try to trigger the memory pool cleanup | |
| torch.cuda.empty_cache() | |
| # Force a small allocation and deallocation to trigger cleanup | |
| dummy_tensor = torch.zeros(1, device='cuda') | |
| del dummy_tensor | |
| torch.cuda.empty_cache() | |
| logger.info("Forced memory pool cleanup") | |
| except Exception as cleanup_error: | |
| logger.warning(f"Advanced cleanup failed: {cleanup_error}") | |
| # Step 9: Final garbage collection | |
| gc.collect() | |
| logger.info("Model unloaded successfully with aggressive cleanup") | |
| return "โ Model unloaded with aggressive memory cleanup" | |
| except Exception as e: | |
| logger.error(f"Error unloading model: {str(e)}", exc_info=True) | |
| # Emergency cleanup even if unload fails | |
| model = None | |
| tokenizer = None | |
| MODEL_NAME = "No Model Loaded" | |
| generation_metadata.clear() | |
| # Emergency memory cleanup | |
| for _ in range(3): | |
| gc.collect() | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| return f"โ Error unloading model: {str(e)} (Emergency cleanup performed)" | |
| def cleanup_memory(): | |
| """Enhanced memory cleanup function with PyTorch optimizations""" | |
| try: | |
| # Clear Python garbage | |
| gc.collect() | |
| # Clear CUDA cache if available | |
| if torch.cuda.is_available(): | |
| # Multiple aggressive cleanup rounds | |
| for i in range(3): | |
| torch.cuda.empty_cache() | |
| torch.cuda.synchronize() | |
| if hasattr(torch.cuda, 'ipc_collect'): | |
| torch.cuda.ipc_collect() | |
| # PyTorch specific memory management | |
| if hasattr(torch.cuda, 'reset_peak_memory_stats'): | |
| torch.cuda.reset_peak_memory_stats() | |
| if hasattr(torch.cuda, 'reset_accumulated_memory_stats'): | |
| torch.cuda.reset_accumulated_memory_stats() | |
| # Brief pause between cleanup rounds | |
| time.sleep(0.1) | |
| # Clear metadata deque | |
| generation_metadata.clear() | |
| # Force garbage collection again | |
| gc.collect() | |
| logger.info("Enhanced memory cleanup completed") | |
| return "๐งน Enhanced memory cleanup completed" | |
| except Exception as e: | |
| logger.error(f"Memory cleanup error: {e}") | |
| return f"Memory cleanup error: {e}" | |
| def nuclear_memory_cleanup(): | |
| """Nuclear option: Complete VRAM reset (use if normal unload doesn't work)""" | |
| global model, tokenizer, MODEL_NAME | |
| try: | |
| logger.info("Performing nuclear memory cleanup...") | |
| # Force unload everything | |
| model = None | |
| tokenizer = None | |
| MODEL_NAME = "No Model Loaded" | |
| generation_metadata.clear() | |
| # Import PyTorch again to reset some internal states | |
| import torch | |
| # Multiple aggressive cleanup rounds | |
| for round_num in range(10): # Very aggressive - 10 rounds | |
| gc.collect() | |
| if torch.cuda.is_available(): | |
| # Multiple types of CUDA cleanup | |
| torch.cuda.empty_cache() | |
| torch.cuda.synchronize() | |
| # Try to reset CUDA context | |
| try: | |
| if hasattr(torch.cuda, 'ipc_collect'): | |
| torch.cuda.ipc_collect() | |
| if hasattr(torch.cuda, 'memory_summary'): | |
| logger.info(f"Round {round_num + 1}: {torch.cuda.memory_summary()}") | |
| except Exception: | |
| pass | |
| # Reset memory stats | |
| try: | |
| if hasattr(torch.cuda, 'reset_peak_memory_stats'): | |
| torch.cuda.reset_peak_memory_stats() | |
| if hasattr(torch.cuda, 'reset_accumulated_memory_stats'): | |
| torch.cuda.reset_accumulated_memory_stats() | |
| except Exception: | |
| pass | |
| time.sleep(0.1) | |
| # Final attempt: allocate and free a small tensor to trigger cleanup | |
| if torch.cuda.is_available(): | |
| try: | |
| for _ in range(5): | |
| dummy = torch.zeros(1024, 1024, device='cuda') # 4MB tensor | |
| del dummy | |
| torch.cuda.empty_cache() | |
| torch.cuda.synchronize() | |
| except Exception as nuclear_error: | |
| logger.warning(f"Nuclear tensor cleanup failed: {nuclear_error}") | |
| logger.info("Nuclear memory cleanup completed") | |
| return "โข๏ธ Nuclear memory cleanup completed! VRAM should be minimal now." | |
| except Exception as e: | |
| logger.error(f"Nuclear cleanup error: {e}") | |
| return f"โข๏ธ Nuclear cleanup error: {e}" | |
| def get_memory_stats(): | |
| """Get comprehensive VRAM usage information""" | |
| if not torch.cuda.is_available(): | |
| return """ | |
| <div style="text-align: center; padding: 15px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 10px; color: white;"> | |
| <h3 style="margin: 0; font-size: 16px;">๐ป CPU Mode</h3> | |
| <p style="margin: 5px 0; opacity: 0.9;">GPU not available</p> | |
| </div> | |
| """ | |
| try: | |
| torch.cuda.synchronize() | |
| total = torch.cuda.get_device_properties(0).total_memory / (1024**3) | |
| allocated = torch.cuda.memory_allocated(0) / (1024**3) | |
| reserved = torch.cuda.memory_reserved(0) / (1024**3) | |
| free = total - reserved | |
| usage_percent = (reserved/total)*100 | |
| # Get peak memory if available | |
| peak_allocated = 0 | |
| if hasattr(torch.cuda, 'max_memory_allocated'): | |
| peak_allocated = torch.cuda.max_memory_allocated(0) / (1024**3) | |
| # Dynamic color based on usage | |
| if usage_percent < 50: | |
| color = "#10b981" # Green | |
| elif usage_percent < 80: | |
| color = "#f59e0b" # Orange | |
| else: | |
| color = "#ef4444" # Red | |
| return f""" | |
| <div style="text-align: center; padding: 15px; background: linear-gradient(135deg, {color}22 0%, {color}44 100%); border: 2px solid {color}; border-radius: 10px;"> | |
| <h3 style="margin: 0; font-size: 16px; color: {color};">๐ฎ VRAM Usage</h3> | |
| <div style="margin: 10px 0;"> | |
| <div style="background: #f3f4f6; border-radius: 8px; height: 8px; overflow: hidden;"> | |
| <div style="width: {usage_percent}%; height: 100%; background: {color}; transition: width 0.3s ease;"></div> | |
| </div> | |
| </div> | |
| <p style="margin: 5px 0; font-weight: 600;">Total: {total:.2f} GB</p> | |
| <p style="margin: 5px 0;">Allocated: {allocated:.2f} GB ({usage_percent:.1f}%)</p> | |
| <p style="margin: 5px 0;">Reserved: {reserved:.2f} GB</p> | |
| <p style="margin: 5px 0;">Free: {free:.2f} GB</p> | |
| <p style="margin: 5px 0; font-size: 12px; opacity: 0.8;">Peak: {peak_allocated:.2f} GB</p> | |
| <p style="margin: 5px 0; font-size: 10px; opacity: 0.6;">RAM-Saving Streamer Active</p> | |
| </div> | |
| """ | |
| except Exception as e: | |
| logger.error(f"Error getting memory stats: {str(e)}") | |
| return f""" | |
| <div style="text-align: center; padding: 15px; background: #fee2e2; border: 2px solid #ef4444; border-radius: 10px;"> | |
| <h3 style="margin: 0; color: #ef4444;">โ Error</h3> | |
| <p style="margin: 5px 0;">{str(e)}</p> | |
| </div> | |
| """ | |
| def process_latex_content(text): | |
| """Enhanced LaTeX processing for streaming without UI glitches""" | |
| # Don't process LaTeX here - let Gradio handle it natively | |
| # Just return the text as-is for now | |
| return text | |
| def process_think_tags(text): | |
| """Process thinking tags with progressive streaming support""" | |
| # Check if we're in the middle of generating a think section | |
| if '<think>' in text and '</think>' not in text: | |
| # We're currently generating inside a think section | |
| parts = text.split('<think>') | |
| if len(parts) == 2: | |
| before_think = parts[0] | |
| thinking_content = parts[1] | |
| # Create a progressive thinking display | |
| formatted_thinking = f""" | |
| <div style="background: linear-gradient(135deg, #e0e7ff 0%, #c7d2fe 100%); border-left: 4px solid #6366f1; padding: 12px; margin: 8px 0; border-radius: 8px;"> | |
| <div style="display: flex; align-items: center; margin-bottom: 8px;"> | |
| <span style="font-size: 16px; margin-right: 8px;">๐ค</span> | |
| <strong style="color: #4338ca;">Thinking...</strong> | |
| </div> | |
| <div style="color: #475569; font-style: italic;">{thinking_content}</div> | |
| </div> | |
| """ | |
| return before_think + formatted_thinking | |
| # Handle completed think sections | |
| think_pattern = re.compile(r'<think>(.*?)</think>', re.DOTALL) | |
| def replace_think(match): | |
| think_content = match.group(1).strip() | |
| return f""" | |
| <div style="background: linear-gradient(135deg, #e0e7ff 0%, #c7d2fe 100%); border-left: 4px solid #6366f1; padding: 12px; margin: 8px 0; border-radius: 8px;"> | |
| <div style="display: flex; align-items: center; margin-bottom: 8px;"> | |
| <span style="font-size: 16px; margin-right: 8px;">๐ค</span> | |
| <strong style="color: #4338ca;">Thinking...</strong> | |
| </div> | |
| <div style="color: #475569; font-style: italic;">{think_content}</div> | |
| </div> | |
| """ | |
| # Replace completed <think> tags with formatted version | |
| processed_text = think_pattern.sub(replace_think, text) | |
| return processed_text | |
| def calculate_generation_metrics(start_time, total_tokens): | |
| """Calculate generation metrics""" | |
| end_time = time.time() | |
| generation_time = end_time - start_time | |
| tokens_per_second = total_tokens / generation_time if generation_time > 0 else 0 | |
| return { | |
| "generation_time": generation_time, | |
| "total_tokens": total_tokens, | |
| "tokens_per_second": tokens_per_second, | |
| "model_name": MODEL_NAME | |
| } | |
| def format_metadata_tooltip(metadata): | |
| """Format metadata for tooltip display""" | |
| return f"""Model: {metadata['model_name']} | |
| Tokens: {metadata['total_tokens']} | |
| Speed: {metadata['tokens_per_second']:.2f} tok/s | |
| Time: {metadata['generation_time']:.2f}s""" | |
| def add_metadata_to_response(response_text, metadata): | |
| """Add metadata icon with tooltip to the response""" | |
| tooltip_content = format_metadata_tooltip(metadata) | |
| # Create a metadata icon with tooltip using HTML | |
| metadata_html = f""" | |
| <div style="position: relative; display: inline-block; margin-left: 8px;"> | |
| <span class="metadata-icon" style="cursor: help; opacity: 0.6; font-size: 14px;" title="{tooltip_content}">โน๏ธ</span> | |
| </div> | |
| """ | |
| # Add metadata icon at the end of the response | |
| return response_text + "\n\n" + metadata_html | |
| def chat_with_model(message, history, system_prompt, temp, top_p_val, top_k_val, rep_penalty_val, memory_opt): | |
| """ | |
| Enhanced chat function with RAM-saving streamer and improved memory management. | |
| Uses direct generation approach for better memory control and VRAM efficiency. | |
| """ | |
| global model, tokenizer, generation_metadata | |
    # Check if model is loaded
    if model is None or tokenizer is None:
        yield "❌ Model not loaded. Please load the model first."
        return
| # Initialize variables for cleanup | |
| input_ids = None | |
| streamer = None | |
| try: | |
| # Record start time for metrics | |
| start_time = time.time() | |
| token_count = 0 | |
| # Format conversation for model | |
| messages = [{"role": "system", "content": system_prompt}] | |
| # Add chat history - HANDLE BOTH FORMATS (tuples from original and dicts from new) | |
| for h in history: | |
| if isinstance(h, dict): | |
| # New dict format | |
| if h.get("role") == "user": | |
| messages.append({"role": "user", "content": h["content"]}) | |
| elif h.get("role") == "assistant": | |
| messages.append({"role": "assistant", "content": h["content"]}) | |
| else: | |
| # Original tuple format (user_msg, bot_msg) | |
| if len(h) >= 2: | |
| messages.append({"role": "user", "content": h[0]}) | |
| if h[1] is not None: | |
| messages.append({"role": "assistant", "content": h[1]}) | |
| # Add the current message | |
| messages.append({"role": "user", "content": message}) | |
| # Wrap generation in torch.no_grad() to prevent gradient accumulation | |
| with torch.no_grad(): | |
| # Create model input with memory-efficient approach | |
| input_ids = tokenizer.apply_chat_template( | |
| messages, | |
| tokenize=True, | |
| add_generation_prompt=True, | |
| return_tensors="pt" | |
| ) | |
| # Handle edge case | |
| if input_ids.ndim == 1: | |
| input_ids = input_ids.unsqueeze(0) | |
| # Move to device | |
| input_ids = input_ids.to(model.device) | |
| # Setup RAM-saving streamer | |
| streamer = RAMSavingIteratorStreamer( | |
| tokenizer, | |
| skip_special_tokens=True, | |
| skip_prompt=True, | |
| timeout=1.0 | |
| ) | |
| # Set prompt length for the streamer | |
| streamer.set_prompt_length(input_ids.shape[1]) | |
| # Pre-generation memory cleanup (only if memory optimization is on) | |
| if memory_opt: | |
| gc.collect() | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| # Conditional generation parameters based on memory optimization | |
| gen_kwargs = { | |
| "input_ids": input_ids, | |
| "max_new_tokens": MAX_LENGTH, | |
| "temperature": temp, | |
| "top_p": top_p_val, | |
| "top_k": top_k_val, | |
| "repetition_penalty": rep_penalty_val, | |
| "do_sample": temp > 0, | |
| "streamer": streamer, | |
| "use_cache": not memory_opt, # Disable cache only if memory optimization is on | |
| } | |
| # Generate in a thread for real-time streaming | |
| thread = Thread( | |
| target=model.generate, | |
| kwargs=gen_kwargs, | |
| daemon=True | |
| ) | |
| thread.start() | |
| # Stream the response with conditional memory management | |
| partial_text = "" | |
| try: | |
| for new_text in streamer: | |
| partial_text += new_text | |
| token_count += 1 | |
| # Process the text to handle think tags while preserving LaTeX | |
| processed_text = process_think_tags(partial_text) | |
| yield processed_text | |
| # Conditional cleanup based on memory optimization setting (less frequent) | |
| if memory_opt and token_count % 150 == 0: # Reduced frequency for performance | |
| gc.collect() # Only light cleanup if memory optimization is on | |
| except StopIteration: | |
| # Normal end of generation | |
| pass | |
| except Exception as stream_error: | |
| logger.error(f"Streaming error: {stream_error}") | |
| yield f"โ Streaming error: {stream_error}" | |
| return | |
| finally: | |
| # Add metadata to final response | |
| try: | |
| metrics = calculate_generation_metrics(start_time, token_count) | |
| partial_text = add_metadata_to_response(partial_text, metrics) | |
| except Exception as e: | |
| logger.warning(f"Couldn't add metadata: {str(e)}") | |
| yield partial_text | |
| # Ensure thread completion | |
| if thread.is_alive(): | |
| thread.join(timeout=5.0) | |
| if thread.is_alive(): | |
| logger.warning("Generation thread did not complete in time") | |
| # Calculate generation metrics | |
| try: | |
| metrics = calculate_generation_metrics(start_time, token_count) | |
| # Store metadata (using deque with max size to prevent memory leaks) | |
| generation_metadata.append(metrics) | |
| # Log the metrics | |
| logger.info(f"Generation metrics - Tokens: {metrics['total_tokens']}, Speed: {metrics['tokens_per_second']:.2f} tok/s, Time: {metrics['generation_time']:.2f}s") | |
| except Exception as metrics_error: | |
| logger.warning(f"Error calculating metrics: {metrics_error}") | |
| # Final cleanup | |
| try: | |
| # Clean up streamer | |
| if streamer: | |
| streamer.cleanup() | |
| del streamer | |
| streamer = None | |
| # Clean up input tensors | |
| if input_ids is not None: | |
| del input_ids | |
| input_ids = None | |
| # Conditional cleanup based on memory optimization setting | |
| if memory_opt: | |
| # Aggressive cleanup only if memory optimization is enabled | |
| if torch.cuda.is_available(): | |
| for _ in range(2): # Reduced rounds for performance | |
| torch.cuda.empty_cache() | |
| torch.cuda.synchronize() | |
| # Force garbage collection | |
| for _ in range(2): | |
| gc.collect() | |
| else: | |
| # Light cleanup for performance mode | |
| gc.collect() | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| logger.info(f"Generation completed, {token_count} tokens, memory_opt: {memory_opt}, VRAM saved with RAM-saving streamer") | |
| except Exception as cleanup_error: | |
| logger.warning(f"Final cleanup warning: {cleanup_error}") | |
| except Exception as e: | |
| logger.error(f"Error in chat_with_model: {str(e)}", exc_info=True) | |
| # Emergency cleanup | |
| try: | |
| if streamer: | |
| streamer.cleanup() | |
| del streamer | |
| if input_ids is not None: | |
| del input_ids | |
| gc.collect() | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| except Exception as emergency_cleanup_error: | |
| logger.error(f"Emergency cleanup failed: {emergency_cleanup_error}") | |
| yield f"โ Error: {str(e)}" | |
| def update_model_name(): | |
| """Update the displayed model name""" | |
| return f"๐ฎ AI Chat Assistant ({MODEL_NAME})" | |
| def add_page_refresh_warning(): | |
| """Add JavaScript to warn about page refresh when model is loaded""" | |
| return """ | |
| <script> | |
| window.addEventListener('beforeunload', function (e) { | |
| // Check if model is loaded by looking for specific text in the page | |
| const statusElements = document.querySelectorAll('input[type="text"], textarea'); | |
| let modelLoaded = false; | |
| statusElements.forEach(element => { | |
| if (element.value && element.value.includes('Model loaded successfully')) { | |
| modelLoaded = true; | |
| } | |
| }); | |
| if (modelLoaded) { | |
| e.preventDefault(); | |
| e.returnValue = 'A model is currently loaded. Are you sure you want to leave?'; | |
| return 'A model is currently loaded. Are you sure you want to leave?'; | |
| } | |
| }); | |
| </script> | |
| """ | |
| # Custom CSS for elegant styling with fixed dropdown behavior | |
| custom_css = """ | |
| /* Main container styling */ | |
| .gradio-container { | |
| font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif !important; | |
| background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important; | |
| min-height: 100vh; | |
| } | |
| /* Header styling */ | |
| .header-text { | |
| background: rgba(255, 255, 255, 0.95); | |
| backdrop-filter: blur(10px); | |
| border-radius: 15px; | |
| padding: 20px; | |
| margin: 20px 0; | |
| text-align: center; | |
| box-shadow: 0 8px 32px rgba(0, 0, 0, 0.1); | |
| border: 1px solid rgba(255, 255, 255, 0.2); | |
| } | |
| /* Chat interface styling */ | |
| .chat-container { | |
| background: rgba(255, 255, 255, 0.95) !important; | |
| border-radius: 20px !important; | |
| box-shadow: 0 20px 40px rgba(0, 0, 0, 0.1) !important; | |
| border: 1px solid rgba(255, 255, 255, 0.2) !important; | |
| backdrop-filter: blur(10px) !important; | |
| } | |
| /* Control panel styling */ | |
| .control-panel { | |
| background: rgba(255, 255, 255, 0.9) !important; | |
| border-radius: 15px !important; | |
| padding: 20px !important; | |
| box-shadow: 0 10px 30px rgba(0, 0, 0, 0.1) !important; | |
| border: 1px solid rgba(255, 255, 255, 0.3) !important; | |
| backdrop-filter: blur(10px) !important; | |
| overflow: visible !important; /* Allow dropdowns to overflow */ | |
| } | |
| /* Button styling */ | |
| .btn-primary { | |
| background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important; | |
| border: none !important; | |
| border-radius: 10px !important; | |
| color: white !important; | |
| font-weight: 600 !important; | |
| transition: all 0.3s ease !important; | |
| box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important; | |
| } | |
| .btn-primary:hover { | |
| transform: translateY(-2px) !important; | |
| box-shadow: 0 8px 25px rgba(102, 126, 234, 0.6) !important; | |
| } | |
| .btn-secondary { | |
| background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%) !important; | |
| border: none !important; | |
| border-radius: 10px !important; | |
| color: white !important; | |
| font-weight: 600 !important; | |
| transition: all 0.3s ease !important; | |
| } | |
| /* Input field styling */ | |
| .input-field { | |
| border-radius: 10px !important; | |
| border: 2px solid rgba(102, 126, 234, 0.2) !important; | |
| transition: all 0.3s ease !important; | |
| } | |
| .input-field:focus { | |
| border-color: #667eea !important; | |
| box-shadow: 0 0 0 3px rgba(102, 126, 234, 0.1) !important; | |
| } | |
| /* Dropdown fixes */ | |
| .dropdown-container { | |
| position: relative !important; | |
| z-index: 1000 !important; | |
| overflow: visible !important; | |
| } | |
| /* Fix dropdown menu positioning and styling */ | |
| .dropdown select, | |
| .dropdown-menu, | |
| .svelte-select, | |
| .svelte-select-list { | |
| position: relative !important; | |
| z-index: 1001 !important; | |
| background: white !important; | |
| border: 2px solid rgba(102, 126, 234, 0.2) !important; | |
| border-radius: 10px !important; | |
| box-shadow: 0 4px 20px rgba(0, 0, 0, 0.15) !important; | |
| max-height: 200px !important; | |
| overflow-y: auto !important; | |
| } | |
| /* Fix dropdown option styling */ | |
| .dropdown option, | |
| .svelte-select-option { | |
| padding: 8px 12px !important; | |
| background: white !important; | |
| color: #333 !important; | |
| border: none !important; | |
| } | |
| .dropdown option:hover, | |
| .svelte-select-option:hover { | |
| background: #f0f0f0 !important; | |
| color: #667eea !important; | |
| } | |
| /* Ensure dropdown arrow is clickable */ | |
| .dropdown::after, | |
| .dropdown-arrow { | |
| pointer-events: none !important; | |
| z-index: 1002 !important; | |
| } | |
| /* Fix any overflow issues in parent containers */ | |
| .gradio-group, | |
| .gradio-column { | |
| overflow: visible !important; | |
| } | |
| /* Accordion styling */ | |
| .accordion { | |
| border-radius: 10px !important; | |
| border: 1px solid rgba(102, 126, 234, 0.2) !important; | |
| overflow: visible !important; /* Allow dropdowns to overflow accordion */ | |
| } | |
| /* Status indicators */ | |
| .status-success { | |
| color: #10b981 !important; | |
| font-weight: 600 !important; | |
| } | |
| .status-error { | |
| color: #ef4444 !important; | |
| font-weight: 600 !important; | |
| } | |
| /* Reduced transition frequency to avoid conflicts */ | |
| .gradio-container * { | |
| transition: background-color 0.3s ease, border-color 0.3s ease !important; | |
| } | |
| /* Chat bubble styling */ | |
| .message { | |
| border-radius: 18px !important; | |
| padding: 12px 16px !important; | |
| margin: 8px 0 !important; | |
| max-width: 80% !important; | |
| } | |
| .user-message { | |
| background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important; | |
| color: white !important; | |
| margin-left: auto !important; | |
| } | |
| .bot-message { | |
| background: #f8fafc !important; | |
| border: 1px solid #e2e8f0 !important; | |
| } | |
| /* Metadata tooltip styling - Enhanced */ | |
| .metadata-icon { | |
| display: inline-block; | |
| margin-left: 8px; | |
| cursor: help; | |
| opacity: 0.6; | |
| transition: opacity 0.3s ease, transform 0.2s ease; | |
| font-size: 14px; | |
| user-select: none; | |
| vertical-align: middle; | |
| } | |
| .metadata-icon:hover { | |
| opacity: 1; | |
| transform: scale(1.1); | |
| } | |
| /* Enhanced tooltip styling */ | |
| .metadata-icon[title]:hover::after { | |
| content: attr(title); | |
| position: absolute; | |
| bottom: 100%; | |
| left: 50%; | |
| transform: translateX(-50%); | |
| background: rgba(0, 0, 0, 0.9); | |
| color: white; | |
| padding: 8px 12px; | |
| border-radius: 6px; | |
| font-size: 12px; | |
| white-space: pre-line; | |
| z-index: 1000; | |
| box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3); | |
| margin-bottom: 5px; | |
| min-width: 200px; | |
| text-align: left; | |
| } | |
| .metadata-icon[title]:hover::before { | |
| content: ''; | |
| position: absolute; | |
| bottom: 100%; | |
| left: 50%; | |
| transform: translateX(-50%); | |
| border: 5px solid transparent; | |
| border-top-color: rgba(0, 0, 0, 0.9); | |
| z-index: 1001; | |
| } | |
| /* Compact system prompt */ | |
| .compact-prompt { | |
| min-height: 40px !important; | |
| transition: min-height 0.3s ease !important; | |
| } | |
| .compact-prompt:focus { | |
| min-height: 80px !important; | |
| } | |
| """ | |
| # Main application | |
with gr.Blocks(css=custom_css, title="🎮 AI Chat Assistant") as demo:
| # Add page refresh warning script | |
| gr.HTML(add_page_refresh_warning()) | |
| # Header | |
| with gr.Row(): | |
| title = gr.Markdown("# ๐ฎ AI Chat Assistant (No Model Loaded)", elem_classes="header-text") | |
| with gr.Row(equal_height=True): | |
| # Main chat area (left side - 70% width) | |
| with gr.Column(scale=7, elem_classes="chat-container"): | |
| # Compact system prompt (changed from 4 lines to 1) | |
| system_prompt = gr.Textbox( | |
| label="๐ฏ System Prompt", | |
| value="You are a helpful AI assistant.", | |
| lines=1, # Changed from 4 to 1 | |
| elem_classes="input-field compact-prompt" | |
| ) | |
| # Generation settings in accordion | |
| with gr.Accordion("โ๏ธ Generation Settings", open=False, elem_classes="accordion"): | |
| with gr.Row(): | |
                    temperature = gr.Slider(0.0, 2.0, DEFAULT_TEMPERATURE, step=0.05, label="🌡️ Temperature")
                    top_p = gr.Slider(0.0, 1.0, DEFAULT_TOP_P, step=0.01, label="🎯 Top-p")
                with gr.Row():
                    top_k = gr.Slider(1, 200, DEFAULT_TOP_K, step=1, label="🔝 Top-k")
                    rep_penalty = gr.Slider(1.0, 2.0, DEFAULT_REP_PENALTY, step=0.01, label="🔁 Repetition Penalty")
| # Memory optimization for chat (moved here to be defined before use) | |
| memory_opt_chat = gr.Checkbox( | |
| label="๐ง Memory Optimization for Chat", | |
| value=True, | |
| info="Use memory optimization during chat generation (disables KV cache)" | |
| ) | |
| # Chat interface using original gr.ChatInterface for fast streaming and stop button | |
            chatbot = gr.Chatbot(
                height=500,
                latex_delimiters=[
                    {"left": "$$", "right": "$$", "display": True},
                    {"left": "$", "right": "$", "display": False},
                    {"left": "\\(", "right": "\\)", "display": False},
                    {"left": "\\[", "right": "\\]", "display": True}
                ],
                show_copy_button=True,
                avatar_images=("👤", "🤖"),
                type="messages",
                render_markdown=True
            )
            chat_interface = gr.ChatInterface(
                fn=chat_with_model,
                chatbot=chatbot,
                additional_inputs=[system_prompt, temperature, top_p, top_k, rep_penalty, memory_opt_chat],
                type="messages",
                submit_btn="Send 📤",
                stop_btn="⏹️ Stop"
            )
| # Control panel (right side - 30% width) | |
| with gr.Column(scale=3, elem_classes="control-panel"): | |
| # Model status and controls | |
| with gr.Group(): | |
| gr.Markdown("### ๐ Model Controls") | |
| with gr.Row(): | |
| load_btn = gr.Button("๐ Load Model", variant="primary", elem_classes="btn-primary") | |
| unload_btn = gr.Button("๐๏ธ Unload", variant="secondary", elem_classes="btn-secondary") | |
| model_status = gr.Textbox( | |
| label="๐ Status", | |
| value="Model not loaded", | |
| interactive=False, | |
| elem_classes="input-field" | |
| ) | |
| progress_display = gr.Textbox( | |
| label="๐ Progress", | |
| value="Ready to load model", | |
| interactive=False, | |
| elem_classes="input-field" | |
| ) | |
| # Model selection | |
| with gr.Group(): | |
| gr.Markdown("### ๐๏ธ Model Selection") | |
| model_source = gr.Radio( | |
| choices=["Hugging Face Model", "Local Path"], | |
| value="Local Path", # Changed default to Local Path | |
| label="๐ Model Source" | |
| ) | |
| # HF Model search and selection (initially hidden) | |
| with gr.Group(visible=False) as hf_group: | |
| model_search = gr.Textbox( | |
| label="๐ Search Models", | |
| placeholder="e.g., microsoft/Phi-3, meta-llama/Llama-3, ykarout/your-model", | |
| elem_classes="input-field" | |
| ) | |
| hf_model = gr.Dropdown( | |
| label="๐ Select Model", | |
| choices=[], | |
| interactive=True, | |
| elem_classes="input-field dropdown-container", | |
| allow_custom_value=True, # Allow typing custom model names | |
| filterable=True # Enable filtering | |
| ) | |
| # Local path group (visible by default) | |
| with gr.Group(visible=True) as local_group: | |
| local_path = gr.Textbox( | |
| value=LOCAL_MODELS_BASE, # Changed default to new base location | |
| label="๐ Local Models Base Path", | |
| elem_classes="input-field" | |
| ) | |
| # Button to refresh local models | |
| refresh_local_btn = gr.Button("๐ Scan Local Models", elem_classes="btn-secondary") | |
| # Dropdown for local models with better configuration | |
| local_models_dropdown = gr.Dropdown( | |
| label="๐ Available Local Models", | |
| choices=[], | |
| interactive=True, | |
| elem_classes="input-field dropdown-container", | |
| allow_custom_value=False, # Don't allow custom for local models | |
| filterable=True # Enable filtering | |
| ) | |
| quantization = gr.Radio( | |
| choices=["4bit", "8bit", "bf16", "f16"], | |
| value="4bit", | |
| label="โก Quantization" | |
| ) | |
| # Advanced memory optimization toggle | |
| memory_optimization = gr.Checkbox( | |
| label="๐ง Advanced Memory Optimization", | |
| value=True, | |
| info="Reduces VRAM usage but may slightly impact speed" | |
| ) | |
| # Note: Memory optimization for chat is now in Generation Settings | |
| # Memory stats with cleanup buttons | |
| with gr.Group(): | |
| gr.Markdown("### ๐พ System Status") | |
| memory_info = gr.HTML() | |
                with gr.Row():
                    refresh_btn = gr.Button("♻ Refresh Stats", elem_classes="btn-secondary")
                    cleanup_btn = gr.Button("🧹 Clean Memory", elem_classes="btn-secondary")
                with gr.Row():
                    nuclear_btn = gr.Button("☢️ Nuclear Cleanup", elem_classes="btn-secondary", variant="stop")
| # Event handlers | |
| # Model search functionality for HF | |
| model_search.change( | |
| update_model_dropdown, | |
| inputs=[model_search], | |
| outputs=[hf_model] | |
| ) | |
| # Show/hide model selection based on source | |
| def toggle_model_source(choice): | |
| return ( | |
| gr.Group(visible=choice == "Hugging Face Model"), | |
| gr.Group(visible=choice == "Local Path") | |
| ) | |
| model_source.change( | |
| toggle_model_source, | |
| inputs=[model_source], | |
| outputs=[hf_group, local_group] | |
| ) | |
| # Local model scanning | |
| refresh_local_btn.click( | |
| update_local_models_dropdown, | |
| inputs=[local_path], | |
| outputs=[local_models_dropdown] | |
| ) | |
| # Auto-scan on path change | |
| local_path.change( | |
| update_local_models_dropdown, | |
| inputs=[local_path], | |
| outputs=[local_models_dropdown] | |
| ) | |
| # Model loading with progress | |
| load_btn.click( | |
| load_model_with_progress, | |
| inputs=[model_source, hf_model, local_path, local_models_dropdown, quantization, memory_optimization], | |
| outputs=[progress_display] | |
| ).then( | |
| lambda: "โ Model loaded successfully!" if model is not None else "โ Model loading failed", | |
| outputs=[model_status] | |
| ).then( | |
| get_memory_stats, | |
| outputs=[memory_info] | |
| ).then( | |
| update_model_name, | |
| outputs=[title] | |
| ) | |
| # Model unloading | |
| unload_btn.click( | |
| unload_model, | |
| outputs=[model_status] | |
| ).then( | |
| lambda: "Ready to load model", | |
| outputs=[progress_display] | |
| ).then( | |
| get_memory_stats, | |
| outputs=[memory_info] | |
| ).then( | |
| lambda: "# ๐ฎ AI Chat Assistant (No Model Loaded)", | |
| outputs=[title] | |
| ) | |
| # Refresh memory stats | |
| refresh_btn.click(get_memory_stats, outputs=[memory_info]) | |
| # Manual memory cleanup | |
| cleanup_btn.click(cleanup_memory, outputs=[]).then( | |
| get_memory_stats, outputs=[memory_info] | |
| ) | |
| # Nuclear memory cleanup | |
| nuclear_btn.click(nuclear_memory_cleanup, outputs=[]).then( | |
| get_memory_stats, outputs=[memory_info] | |
| ) | |
| # Initialize on load | |
| demo.load(get_memory_stats, outputs=[memory_info]) | |
| demo.load( | |
| lambda: update_local_models_dropdown(LOCAL_MODELS_BASE), | |
| outputs=[local_models_dropdown] | |
| ) | |
| # Enable queue for streaming | |
| demo.queue() |
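# The listing stops after demo.queue() without ever starting the server, which
# would let the process exit immediately (one plausible cause of the Space's
# "Runtime error" banner). A minimal, assumed launch step:
if __name__ == "__main__":
    demo.launch()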