"""
Tests for model management functionality.

Covers the model registry, downloader, loader and optimizer, plus
integration-level workflows.

The ``mock_registry`` and ``temp_dir`` fixtures requested by several tests
are not defined in this module; they are presumably provided by a
``conftest.py``. ``temp_dir`` is used as a ``pathlib.Path``.
"""

import json
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, Mock, patch

import pytest

from models import (
    ModelRegistry,
    ModelInfo,
    ModelStatus,
    ModelTask,
    ModelFramework,
    ModelDownloader,
    ModelLoader,
    ModelOptimizer,
)


class TestModelRegistry:
    """Test model registry functionality."""

    @pytest.fixture
    def registry(self, tmp_path):
        """Create a registry rooted in a per-test temporary directory.

        Uses pytest's built-in ``tmp_path`` fixture rather than
        ``tempfile.mkdtemp()``. The directory is then managed, and eventually
        cleaned up, by pytest instead of being leaked on every run.
        """
        return ModelRegistry(models_dir=tmp_path)

    def test_registry_initialization(self, registry):
        """Test registry initialization."""
        assert registry is not None
        assert len(registry.models) > 0  # Should have default models
        assert registry.models_dir.exists()

    def test_register_model(self, registry):
        """Test registering a new model."""
        model = ModelInfo(
            model_id="test-model",
            name="Test Model",
            version="1.0",
            task=ModelTask.SEGMENTATION,
            framework=ModelFramework.PYTORCH,
            url="http://example.com/model.pth",
            filename="test.pth",
            file_size=1000000,
        )

        success = registry.register_model(model)
        assert success is True
        assert "test-model" in registry.models

    def test_get_model(self, registry):
        """Test getting a model by ID."""
        model = registry.get_model("rmbg-1.4")
        assert model is not None
        assert model.model_id == "rmbg-1.4"
        assert model.task == ModelTask.SEGMENTATION

    def test_list_models_by_task(self, registry):
        """Test listing models by task."""
        segmentation_models = registry.list_models(task=ModelTask.SEGMENTATION)
        assert len(segmentation_models) > 0
        assert all(m.task == ModelTask.SEGMENTATION for m in segmentation_models)

    def test_list_models_by_framework(self, registry):
        """Test listing models by framework."""
        pytorch_models = registry.list_models(framework=ModelFramework.PYTORCH)
        onnx_models = registry.list_models(framework=ModelFramework.ONNX)

        assert all(m.framework == ModelFramework.PYTORCH for m in pytorch_models)
        assert all(m.framework == ModelFramework.ONNX for m in onnx_models)

    def test_get_best_model(self, registry):
        """Test getting the best model for a task."""
        # Best for accuracy
        best_accuracy = registry.get_best_model(
            ModelTask.SEGMENTATION,
            prefer_speed=False,
        )
        assert best_accuracy is not None

        # Best for speed
        best_speed = registry.get_best_model(
            ModelTask.SEGMENTATION,
            prefer_speed=True,
        )
        assert best_speed is not None

    def test_update_model_usage(self, registry):
        """Test updating model usage statistics."""
        model_id = "rmbg-1.4"
        initial_count = registry.models[model_id].use_count

        registry.update_model_usage(model_id)

        assert registry.models[model_id].use_count == initial_count + 1
        assert registry.models[model_id].last_used is not None

    def test_get_total_size(self, registry):
        """Test calculating total model size."""
        total_size = registry.get_total_size()
        assert total_size > 0

        # Nothing has been downloaded yet, so the AVAILABLE subset is empty.
        available_size = registry.get_total_size(status=ModelStatus.AVAILABLE)
        assert available_size == 0

    def test_export_registry(self, registry, temp_dir):
        """Test exporting the registry to a JSON file."""
        export_path = temp_dir / "registry_export.json"

        registry.export_registry(export_path)

        assert export_path.exists()

        with open(export_path) as f:
            data = json.load(f)

        assert "models" in data
        assert len(data["models"]) > 0


class TestModelDownloader:
    """Test model downloading functionality."""

    @pytest.fixture
    def downloader(self, mock_registry):
        """Create a downloader backed by the externally provided mock registry."""
        return ModelDownloader(mock_registry)

    @patch('requests.get')
    def test_download_model(self, mock_get, downloader):
        """Test downloading a model."""
        # Mock HTTP response
        mock_response = MagicMock()
        mock_response.headers = {'content-length': '1000000'}
        mock_response.iter_content = MagicMock(
            return_value=[b'data' * 1000]
        )
        mock_response.raise_for_status = MagicMock()
        mock_get.return_value = mock_response

        # Only the HTTP call is checked here; the return value is not asserted.
        downloader.download_model("test-model", force=True)

        assert mock_get.called
        # Note: Full download test would require more mocking

    def test_download_progress_tracking(self, downloader):
        """Test download progress tracking."""
        progress_values = []

        def progress_callback(progress):
            progress_values.append(progress.progress)

        # The worker task is stubbed out, so only the bookkeeping that
        # download_model() performs before running it is exercised.
        with patch.object(downloader, '_download_model_task', return_value=True):
            downloader.download_model(
                "test-model",
                progress_callback=progress_callback,
            )

        assert "test-model" in downloader.downloads

    def test_cancel_download(self, downloader):
        """Test cancelling a download."""
        # Start a mock download
        downloader.downloads["test-model"] = Mock()
        downloader._stop_events["test-model"] = Mock()

        success = downloader.cancel_download("test-model")

        assert success is True
        assert downloader._stop_events["test-model"].set.called

    def test_download_with_resume(self, downloader, temp_dir):
        """Test download with resume support."""
        # Create a partial file
        partial_file = temp_dir / "test.pth.part"
        partial_file.write_bytes(b"partial_data")

        # NOTE(review): this only checks the fixture file itself; the
        # downloader's resume logic is not exercised yet.
        assert partial_file.exists()
        assert partial_file.stat().st_size > 0


class TestModelLoader:
    """Test model loading functionality."""

    @pytest.fixture
    def loader(self, mock_registry):
        """Create a CPU loader backed by the externally provided mock registry."""
        return ModelLoader(mock_registry, device='cpu')

    def test_loader_initialization(self, loader):
        """Test loader initialization."""
        assert loader is not None
        assert loader.device == 'cpu'
        assert loader.max_memory_bytes > 0

    @patch('torch.load')
    def test_load_pytorch_model(self, mock_torch_load, loader):
        """Test loading a PyTorch model."""
        mock_model = MagicMock()
        mock_torch_load.return_value = mock_model

        # Mock model info
        model_info = ModelInfo(
            model_id="test-pytorch",
            name="Test PyTorch Model",
            version="1.0",
            task=ModelTask.SEGMENTATION,
            framework=ModelFramework.PYTORCH,
            url="",
            filename="model.pth",
            local_path="/tmp/model.pth",
            status=ModelStatus.AVAILABLE,
        )

        loader.registry.get_model = Mock(return_value=model_info)

        # Pretend the weights file exists so the loader reaches torch.load.
        with patch.object(Path, 'exists', return_value=True):
            loader.load_model("test-pytorch")

        # Note: Full test would require more setup
        assert mock_torch_load.called

    def test_memory_management(self, loader):
        """Test memory management during model loading."""
        # Add mock models to loaded cache
        for i in range(5):
            loader.loaded_models[f"model_{i}"] = Mock(
                memory_usage=100 * 1024 * 1024  # 100MB each
            )
        loader.current_memory_usage = 500 * 1024 * 1024  # 500MB

        # Free memory
        loader._free_memory(200 * 1024 * 1024)  # Need 200MB

        # At least one cached model must have been evicted; the exact count
        # depends on the eviction policy of ModelLoader._free_memory.
        assert len(loader.loaded_models) < 5

    def test_unload_model(self, loader):
        """Test unloading a model."""
        # Add a mock model
        loader.loaded_models["test"] = Mock(
            model=Mock(),
            memory_usage=100 * 1024 * 1024,
        )
        loader.current_memory_usage = 100 * 1024 * 1024

        success = loader.unload_model("test")

        assert success is True
        assert "test" not in loader.loaded_models
        assert loader.current_memory_usage == 0

    def test_get_memory_usage(self, loader):
        """Test getting memory usage statistics."""
        # Add mock models
        loader.loaded_models["model1"] = Mock(memory_usage=100 * 1024 * 1024)
        loader.loaded_models["model2"] = Mock(memory_usage=200 * 1024 * 1024)
        loader.current_memory_usage = 300 * 1024 * 1024

        usage = loader.get_memory_usage()

        assert usage["current_usage_mb"] == 300
        assert usage["loaded_models"] == 2
        assert "model1" in usage["models"]
        assert "model2" in usage["models"]


class TestModelOptimizer:
    """Test model optimization functionality."""

    @pytest.fixture
    def optimizer(self, mock_registry):
        """Create an optimizer wrapping a CPU loader."""
        loader = ModelLoader(mock_registry, device='cpu')
        return ModelOptimizer(loader)

    @patch('torch.quantization.quantize_dynamic')
    def test_quantize_pytorch_model(self, mock_quantize, optimizer):
        """Test PyTorch model quantization."""
        # Create mock model
        mock_model = MagicMock()
        mock_quantize.return_value = mock_model

        loaded = Mock(
            model_id="test",
            model=mock_model,
            framework=ModelFramework.PYTORCH,
            metadata={'input_size': (1, 3, 512, 512)},
        )

        # Stub out size measurement and benchmarking so only the quantization
        # call path is exercised.
        with patch.object(optimizer, '_get_model_size', return_value=1000000):
            with patch.object(optimizer, '_benchmark_model', return_value=0.1):
                optimizer._quantize_pytorch(
                    loaded,
                    Path("/tmp"),
                    "dynamic",
                )

        assert mock_quantize.called
        # Note: Full test would require more setup

    def test_optimization_result(self, optimizer):
        """Test optimization result structure."""
        from models.optimizer import OptimizationResult

        result = OptimizationResult(
            original_size_mb=100,
            optimized_size_mb=25,
            compression_ratio=4.0,
            original_speed_ms=100,
            optimized_speed_ms=50,
            speedup=2.0,
            accuracy_loss=0.01,
            optimization_time=10.0,
            output_path="/tmp/optimized.pth",
        )

        assert result.compression_ratio == 4.0
        assert result.speedup == 2.0
        assert result.accuracy_loss == 0.01


class TestModelIntegration:
    """Integration tests for model management."""

    @pytest.mark.integration
    @pytest.mark.slow
    def test_model_registry_persistence(self, temp_dir):
        """Test registry persistence across instances."""
        # Create registry and add model
        registry1 = ModelRegistry(models_dir=temp_dir)

        test_model = ModelInfo(
            model_id="persistence-test",
            name="Persistence Test",
            version="1.0",
            task=ModelTask.SEGMENTATION,
            framework=ModelFramework.PYTORCH,
            url="http://example.com/model.pth",
            filename="persist.pth",
        )

        registry1.register_model(test_model)

        # A fresh instance over the same directory must see the model.
        registry2 = ModelRegistry(models_dir=temp_dir)

        loaded_model = registry2.get_model("persistence-test")
        assert loaded_model is not None
        assert loaded_model.name == "Persistence Test"

    @pytest.mark.integration
    def test_model_manager_workflow(self):
        """Test the complete model manager workflow."""
        from models import create_model_manager

        manager = create_model_manager()

        # Test model discovery
        stats = manager.get_stats()
        assert "registry" in stats
        assert stats["registry"]["total_models"] > 0

        # Test benchmark (without actual model loading)
        with patch.object(manager.loader, 'load_model', return_value=Mock()):
            benchmarks = manager.benchmark()
            # Would return empty without real models
            assert isinstance(benchmarks, dict)