| """ | |
| Tokenizer Service - Handles tokenizer loading, caching, and management | |
| """ | |
| import time | |
| from typing import Dict, Tuple, Optional, Any | |
| from transformers import AutoTokenizer | |
| from flask import current_app | |


class TokenizerService:
    """Service for managing tokenizer loading and caching."""

    # Predefined tokenizer models with aliases
    TOKENIZER_MODELS = {
        'qwen3': {
            'name': 'Qwen/Qwen3-0.6B',
            'alias': 'Qwen 3'
        },
        'gemma3-27b': {
            'name': 'google/gemma-3-27b-it',
            'alias': 'Gemma 3 27B'
        },
        'glm4': {
            'name': 'THUDM/GLM-4-32B-0414',
            'alias': 'GLM 4'
        },
        'mistral-small': {
            'name': 'mistralai/Mistral-Small-3.1-24B-Instruct-2503',
            'alias': 'Mistral Small 3.1'
        },
        'llama4': {
            'name': 'meta-llama/Llama-4-Scout-17B-16E-Instruct',
            'alias': 'Llama 4'
        },
        'deepseek-r1': {
            'name': 'deepseek-ai/DeepSeek-R1',
            'alias': 'DeepSeek R1'
        },
        'qwen_25_72b': {
            'name': 'Qwen/Qwen2.5-72B-Instruct',
            'alias': 'Qwen 2.5 72B'
        },
        'llama_33': {
            'name': 'unsloth/Llama-3.3-70B-Instruct-bnb-4bit',
            'alias': 'Llama 3.3 70B'
        },
        'gemma2_2b': {
            'name': 'google/gemma-2-2b-it',
            'alias': 'Gemma 2 2B'
        },
        'bert-large-uncased': {
            'name': 'google-bert/bert-large-uncased',
            'alias': 'BERT Large Uncased'
        },
        'gpt2': {
            'name': 'openai-community/gpt2',
            'alias': 'GPT-2'
        }
    }

    def __init__(self):
        """Initialize the tokenizer service with empty caches."""
        self.tokenizers: Dict[str, Any] = {}                        # predefined models, keyed by model ID
        self.custom_tokenizers: Dict[str, Tuple[Any, float]] = {}  # custom path -> (tokenizer, load timestamp)
        self.tokenizer_info_cache: Dict[str, Dict] = {}             # model ID/path -> extracted info dict
        self.custom_model_errors: Dict[str, str] = {}               # custom path -> last load error

    def get_tokenizer_info(self, tokenizer) -> Dict:
        """Extract useful information from a tokenizer."""
        info = {}

        try:
            # Get vocabulary size (dictionary size)
            if hasattr(tokenizer, 'vocab_size'):
                info['vocab_size'] = tokenizer.vocab_size
            elif hasattr(tokenizer, 'get_vocab'):
                info['vocab_size'] = len(tokenizer.get_vocab())

            # Get model max length if it is set to a real value (transformers
            # uses a very large sentinel integer when the limit is unknown)
            if hasattr(tokenizer, 'model_max_length') and tokenizer.model_max_length < 1000000:
                info['model_max_length'] = tokenizer.model_max_length

            # Check tokenizer type
            info['tokenizer_type'] = tokenizer.__class__.__name__

            # Get special tokens
            special_tokens = {}
            for token_name in ['pad_token', 'eos_token', 'bos_token', 'sep_token', 'cls_token', 'unk_token', 'mask_token']:
                if hasattr(tokenizer, token_name) and getattr(tokenizer, token_name) is not None:
                    token_value = getattr(tokenizer, token_name)
                    if token_value and str(token_value).strip():
                        special_tokens[token_name] = str(token_value)
            info['special_tokens'] = special_tokens
        except Exception as e:
            info['error'] = f"Error extracting tokenizer info: {str(e)}"

        return info
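
    # Illustrative shape of the dict returned by get_tokenizer_info. The
    # values shown are what GPT-2's fast tokenizer typically reports; other
    # models will differ:
    #
    # {
    #     'vocab_size': 50257,
    #     'model_max_length': 1024,
    #     'tokenizer_type': 'GPT2TokenizerFast',
    #     'special_tokens': {
    #         'bos_token': '<|endoftext|>',
    #         'eos_token': '<|endoftext|>',
    #         'unk_token': '<|endoftext|>'
    #     }
    # }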

    def load_tokenizer(self, model_id_or_name: str) -> Tuple[Optional[Any], Dict, Optional[str]]:
        """
        Load tokenizer if not already loaded.

        Returns:
            Tuple of (tokenizer, tokenizer_info, error_message)
        """
        error_message = None
        tokenizer_info = {}

        # Check if we have cached tokenizer info
        if model_id_or_name in self.tokenizer_info_cache:
            tokenizer_info = self.tokenizer_info_cache[model_id_or_name]

        try:
            # Check if it's a predefined model ID
            if model_id_or_name in self.TOKENIZER_MODELS:
                model_name = self.TOKENIZER_MODELS[model_id_or_name]['name']
                if model_id_or_name not in self.tokenizers:
                    self.tokenizers[model_id_or_name] = AutoTokenizer.from_pretrained(model_name)
                tokenizer = self.tokenizers[model_id_or_name]

                # Get tokenizer info if not already cached
                if model_id_or_name not in self.tokenizer_info_cache:
                    tokenizer_info = self.get_tokenizer_info(tokenizer)
                    self.tokenizer_info_cache[model_id_or_name] = tokenizer_info

                return tokenizer, tokenizer_info, None

            # It's a custom model path.
            # Check if we have it in the custom cache and it's not expired.
            current_time = time.time()
            cache_expiration = current_app.config.get('CACHE_EXPIRATION', 3600)

            if model_id_or_name in self.custom_tokenizers:
                cached_tokenizer, timestamp = self.custom_tokenizers[model_id_or_name]
                if current_time - timestamp < cache_expiration:
                    # Get tokenizer info if not already cached
                    if model_id_or_name not in self.tokenizer_info_cache:
                        tokenizer_info = self.get_tokenizer_info(cached_tokenizer)
                        self.tokenizer_info_cache[model_id_or_name] = tokenizer_info
                    return cached_tokenizer, tokenizer_info, None

            # Not in cache or expired, load it
            tokenizer = AutoTokenizer.from_pretrained(model_id_or_name)

            # Store in cache with timestamp
            self.custom_tokenizers[model_id_or_name] = (tokenizer, current_time)

            # Clear any previous errors for this model
            if model_id_or_name in self.custom_model_errors:
                del self.custom_model_errors[model_id_or_name]

            # Get tokenizer info
            tokenizer_info = self.get_tokenizer_info(tokenizer)
            self.tokenizer_info_cache[model_id_or_name] = tokenizer_info

            return tokenizer, tokenizer_info, None
        except Exception as e:
            error_message = f"Failed to load tokenizer: {str(e)}"
            # Store error for future reference
            self.custom_model_errors[model_id_or_name] = error_message
            return None, tokenizer_info, error_message
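
    # NOTE: the custom-path branch above reads CACHE_EXPIRATION from
    # current_app.config, so calls with a non-predefined model path must run
    # inside an active Flask application context.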

    def get_model_alias(self, model_id: str) -> str:
        """Get the display alias for a model ID."""
        if model_id in self.TOKENIZER_MODELS:
            return self.TOKENIZER_MODELS[model_id]['alias']
        return model_id

    def is_predefined_model(self, model_id: str) -> bool:
        """Check if a model ID is a predefined model."""
        return model_id in self.TOKENIZER_MODELS

    def clear_cache(self):
        """Clear all caches."""
        self.tokenizers.clear()
        self.custom_tokenizers.clear()
        self.tokenizer_info_cache.clear()
        self.custom_model_errors.clear()


# Global instance
tokenizer_service = TokenizerService()
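

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the service). Loading a predefined
# model by its short ID needs no Flask app context; custom paths do (see the
# NOTE in load_tokenizer). Assumes the tokenizer files can be fetched from
# the Hugging Face Hub or found in a local cache.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    tokenizer, info, error = tokenizer_service.load_tokenizer('gpt2')
    if error:
        print(error)
    else:
        print(tokenizer_service.get_model_alias('gpt2'))  # -> GPT-2
        print(info.get('vocab_size'))                     # e.g. 50257
        print(tokenizer.tokenize("Hello, world!"))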