Spaces:
Sleeping
Sleeping
devjas1
(FEAT/REFAC)[Expand Registry & Metadata]: Enhance model registry with new models, richer metadata, and utility functions.
0be85e4
| # models/registry.py | |
import copy
from typing import Callable, Dict, List, Any

from models.figure2_cnn import Figure2CNN
from models.resnet_cnn import ResNet1D
from models.resnet18_vision import ResNet18Vision
from models.enhanced_cnn import EnhancedCNN, EfficientSpectralCNN, HybridSpectralNet
# Builder table: short model name -> factory that takes the input length.
# Each factory forwards its single argument to the model class as the
# ``input_length`` keyword.
_REGISTRY: Dict[str, Callable[[int], object]] = {
    "figure2": lambda length: Figure2CNN(input_length=length),
    "resnet": lambda length: ResNet1D(input_length=length),
    "resnet18vision": lambda length: ResNet18Vision(input_length=length),
    "enhanced_cnn": lambda length: EnhancedCNN(input_length=length),
    "efficient_cnn": lambda length: EfficientSpectralCNN(input_length=length),
    "hybrid_net": lambda length: HybridSpectralNet(input_length=length),
}
# Model specifications with metadata for enhanced features.
#
# Keyed by the same short names as the builder registry. Every built-in entry
# carries: input_length, num_classes, description, modalities, citation,
# performance, parameters and speed; the three newer CNN variants also carry
# a "features" list.
#
# Within this module, "modalities" is read by models_for_modality() and
# is_model_compatible(), and "description" is read by
# get_model_capabilities(). The remaining fields are only copied out to
# callers.
#
# NOTE(review): the accuracy/F1 figures, parameter counts and speed labels are
# recorded here as literals; this file does not show how they were measured.
# Confirm their provenance before quoting them.
_MODEL_SPECS: Dict[str, Dict[str, Any]] = {
    "figure2": {
        "input_length": 500,
        "num_classes": 2,
        "description": "Figure 2 baseline custom implementation",
        "modalities": ["raman", "ftir"],
        "citation": "Neo et al., 2023, Resour. Conserv. Recycl., 188, 106718",
        "performance": {"accuracy": 0.948, "f1_score": 0.943},
        "parameters": "~500K",
        "speed": "fast",
    },
    "resnet": {
        "input_length": 500,
        "num_classes": 2,
        "description": "(Residual Network) uses skip connections to train much deeper networks",
        "modalities": ["raman", "ftir"],
        "citation": "Custom ResNet implementation",
        "performance": {"accuracy": 0.962, "f1_score": 0.959},
        "parameters": "~100K",
        "speed": "very_fast",
    },
    "resnet18vision": {
        "input_length": 500,
        "num_classes": 2,
        "description": "excels at image recognition tasks by using 'residual blocks' to train more efficiently",
        "modalities": ["raman", "ftir"],
        "citation": "ResNet18 Vision adaptation",
        "performance": {"accuracy": 0.945, "f1_score": 0.940},
        "parameters": "~11M",
        "speed": "medium",
    },
    "enhanced_cnn": {
        "input_length": 500,
        "num_classes": 2,
        "description": "Enhanced CNN with attention mechanisms and multi-scale feature extraction",
        "modalities": ["raman", "ftir"],
        "citation": "Custom enhanced architecture with attention",
        "performance": {"accuracy": 0.975, "f1_score": 0.973},
        "parameters": "~800K",
        "speed": "medium",
        "features": ["attention", "multi_scale", "batch_norm", "dropout"],
    },
    "efficient_cnn": {
        "input_length": 500,
        "num_classes": 2,
        "description": "Efficient CNN optimized for real-time inference with depthwise separable convolutions",
        "modalities": ["raman", "ftir"],
        "citation": "Custom efficient architecture",
        "performance": {"accuracy": 0.955, "f1_score": 0.952},
        "parameters": "~200K",
        "speed": "very_fast",
        "features": ["depthwise_separable", "lightweight", "real_time"],
    },
    "hybrid_net": {
        "input_length": 500,
        "num_classes": 2,
        "description": "Hybrid network combining CNN backbone with self-attention mechanisms",
        "modalities": ["raman", "ftir"],
        "citation": "Custom hybrid CNN-Transformer architecture",
        "performance": {"accuracy": 0.968, "f1_score": 0.965},
        "parameters": "~1.2M",
        "speed": "medium",
        "features": ["self_attention", "cnn_backbone", "transformer_head"],
    },
}
# Placeholder for future model expansions.
#
# These keys have metadata only: none of them has a builder in the registry,
# so build() rejects them. They are surfaced through planned_models() and, as
# a fallback lookup, through get_model_info(). Each entry carries description,
# status, modalities and features; every status is currently "planned".
_FUTURE_MODELS: Dict[str, Dict[str, Any]] = {
    "densenet1d": {
        "description": "DenseNet1D for spectroscopy with dense connections",
        "status": "planned",
        "modalities": ["raman", "ftir"],
        "features": ["dense_connections", "parameter_efficient"],
    },
    "ensemble_cnn": {
        "description": "Ensemble of multiple CNN variants for robust predictions",
        "status": "planned",
        "modalities": ["raman", "ftir"],
        "features": ["ensemble", "robust", "high_accuracy"],
    },
    "vision_transformer": {
        "description": "Vision Transformer adapted for 1D spectral data",
        "status": "planned",
        "modalities": ["raman", "ftir"],
        "features": ["transformer", "attention", "state_of_art"],
    },
    "autoencoder_cnn": {
        "description": "CNN with autoencoder for unsupervised feature learning",
        "status": "planned",
        "modalities": ["raman", "ftir"],
        "features": ["autoencoder", "unsupervised", "feature_learning"],
    },
}
def choices():
    """List the short names of every model that can currently be built.

    Returns:
        A new list of the builder-registry keys, in insertion order.
    """
    return list(_REGISTRY)
def planned_models():
    """List the short names of models that are planned but not yet buildable.

    Returns:
        A new list of the keys of the future-models table, in insertion order.
    """
    return list(_FUTURE_MODELS)
def build(name: str, input_length: int):
    """Construct the model registered under *name*.

    Args:
        name: Short registry key (see ``choices()``).
        input_length: Passed as the sole argument to the registered builder.

    Returns:
        Whatever the registered builder returns.

    Raises:
        ValueError: If *name* is not a registered key.
    """
    if name in _REGISTRY:
        builder = _REGISTRY[name]
        return builder(input_length)
    raise ValueError(f"Unknown model '{name}'. Choices: {choices()}")
def build_multiple(names: List[str], input_length: int) -> Dict[str, Any]:
    """Build several models at once, e.g. for side-by-side comparison.

    Args:
        names: Registry keys of the models to build. A key that appears more
            than once is built each time; the last instance is kept.
        input_length: Passed to every builder.

    Returns:
        A dict mapping each requested key to its newly built model.

    Raises:
        ValueError: If any entry of *names* is not registered. Every name is
            checked before the first model is constructed, so a bad name
            never costs the construction of the models listed before it.
    """
    # Materialise once: the names are walked twice (validate, then build).
    requested = list(names)
    for name in requested:
        if name not in _REGISTRY:
            raise ValueError(f"Unknown model '{name}'. Available: {choices()}")
    return {name: build(name, input_length) for name in requested}
def register_model(
    name: str, builder: Callable[[int], object], spec: Dict[str, Any]
) -> None:
    """Register a new model builder and its metadata at runtime.

    Args:
        name: Short key under which the model becomes available.
        builder: Callable taking the input length and returning a model.
        spec: Metadata stored for the model. Keys this module reads are
            ``"modalities"`` and ``"description"``; other keys are only
            handed back to callers.

    Raises:
        ValueError: If *name* is already present in the builder registry.
        TypeError: If *builder* is not callable.
    """
    if name in _REGISTRY:
        raise ValueError(f"Model '{name}' already registered.")
    if not callable(builder):
        raise TypeError("Builder must be a callable that accepts an integer argument.")
    _REGISTRY[name] = builder
    # Store an independent copy so that later mutation of the caller's dict
    # (or of its nested lists/dicts) cannot silently alter registry metadata.
    _MODEL_SPECS[name] = copy.deepcopy(spec)
def spec(name: str) -> Dict[str, Any]:
    """Return the metadata recorded for a registered model.

    The result is the model's whole metadata entry. For the built-in models
    that includes ``input_length`` and ``num_classes`` alongside the
    descriptive fields, not only those two values.

    Args:
        name: Short registry key.

    Returns:
        An independent deep copy of the entry. Nested containers such as
        ``performance`` and ``modalities`` are copied too, so the caller may
        modify the result without affecting the registry.

    Raises:
        KeyError: If no metadata is recorded under *name*.
    """
    if name not in _MODEL_SPECS:
        raise KeyError(f"Unknown model '{name}'. Available: {choices()}")
    return copy.deepcopy(_MODEL_SPECS[name])
def get_model_info(name: str) -> Dict[str, Any]:
    """Return the metadata for an active or a planned model.

    Active models are looked up first; planned models are the fallback.

    Args:
        name: Short model key.

    Returns:
        An independent deep copy of the matching entry, so modifying the
        result (including its nested lists and dicts) leaves the module's
        tables untouched.

    Raises:
        KeyError: If *name* is neither an active nor a planned model.
    """
    for table in (_MODEL_SPECS, _FUTURE_MODELS):
        if name in table:
            return copy.deepcopy(table[name])
    raise KeyError(f"Unknown model '{name}'")
def models_for_modality(modality: str) -> List[str]:
    """List the registered models whose metadata names *modality*.

    Args:
        modality: Value looked for in each entry's ``"modalities"`` field.

    Returns:
        Matching model keys in table order. An entry without a
        ``"modalities"`` field never matches.
    """
    return [
        key
        for key, meta in _MODEL_SPECS.items()
        if modality in meta.get("modalities", [])
    ]
def validate_model_list(names: List[str]) -> List[str]:
    """Filter *names* down to the ones that are registered.

    Unknown names are dropped without an error. The order of the input, and
    any duplicates in it, are preserved.

    Args:
        names: Candidate model keys.

    Returns:
        A new list containing only the candidates found in ``choices()``.
    """
    registered = choices()
    return [candidate for candidate in names if candidate in registered]
def get_models_metadata() -> Dict[str, Dict[str, Any]]:
    """Return the metadata of every registered model.

    Returns:
        A dict mapping each model key to its metadata entry. The whole
        structure is deep-copied: a per-entry shallow copy would still share
        the nested ``performance`` dicts and ``modalities``/``features``
        lists with the registry, letting callers corrupt it by accident.
    """
    return copy.deepcopy(_MODEL_SPECS)
def is_model_compatible(name: str, modality: str) -> bool:
    """Tell whether the model *name* lists *modality* in its metadata.

    Args:
        name: Short model key.
        modality: Value looked for in the entry's ``"modalities"`` field.

    Returns:
        ``False`` when *name* has no metadata entry or the entry does not
        list *modality*; ``True`` otherwise.
    """
    return name in _MODEL_SPECS and modality in _MODEL_SPECS[name].get(
        "modalities", []
    )
def get_model_capabilities(name: str) -> Dict[str, Any]:
    """Return a model's metadata extended with capability flags.

    Args:
        name: Short key of a registered model.

    Returns:
        A deep copy of the model's metadata entry with these keys added (or
        overwritten if the entry already had them): ``available`` (True),
        ``status`` ("active"), ``supported_tasks``
        (["binary_classification"]) and ``performance_metrics``, a dict with
        ``supports_confidence``, ``supports_batch`` and ``memory_efficient``.

    Raises:
        KeyError: If *name* has no metadata entry.
    """
    if name not in _MODEL_SPECS:
        raise KeyError(f"Unknown model '{name}'")
    # Deep copy so the returned dict shares no nested containers with the
    # registry; the local is not called ``spec`` to avoid shadowing the
    # module-level function of that name.
    capabilities = copy.deepcopy(_MODEL_SPECS[name])
    description = capabilities.get("description", "").lower()
    # The flag marks ResNet-family models. Matching on the description alone
    # never fires for the built-in entries (they say "Residual Network" and
    # "residual blocks", not "resnet"), so the registry key is checked too.
    is_resnet_family = "resnet" in str(name).lower() or "resnet" in description
    capabilities.update(
        {
            "available": True,
            "status": "active",
            "supported_tasks": ["binary_classification"],
            "performance_metrics": {
                "supports_confidence": True,
                "supports_batch": True,
                "memory_efficient": is_resnet_family,
            },
        }
    )
    return capabilities
# Public API of this module: the names bound by ``from ... import *``.
# Only functions are listed; the underscore-prefixed tables are not.
__all__ = [
    "choices",
    "build",
    "spec",
    "build_multiple",
    "register_model",
    "get_model_info",
    "models_for_modality",
    "validate_model_list",
    "planned_models",
    "get_models_metadata",
    "is_model_compatible",
    "get_model_capabilities",
]