MogensR commited on
Commit
724ecba
·
1 Parent(s): b76f5a3

Create tests/test_models.py

Browse files
Files changed (1) hide show
  1. tests/test_models.py +322 -295
tests/test_models.py CHANGED
@@ -1,349 +1,376 @@
1
  """
2
- Tests for the processing pipeline.
3
  """
4
 
5
  import pytest
6
- import numpy as np
7
- import cv2
8
- from unittest.mock import Mock, patch, MagicMock
9
  from pathlib import Path
 
 
10
 
11
- from api.pipeline import (
12
- ProcessingPipeline,
13
- PipelineConfig,
14
- PipelineResult,
15
- ProcessingMode,
16
- PipelineStage
 
 
 
17
  )
18
 
19
 
20
- class TestPipelineConfig:
21
- """Test pipeline configuration."""
 
 
 
 
 
 
22
 
23
- def test_default_config(self):
24
- """Test default configuration values."""
25
- config = PipelineConfig()
26
- assert config.mode == ProcessingMode.PHOTO
27
- assert config.quality_preset == "high"
28
- assert config.use_gpu == True
29
- assert config.enable_cache == True
30
 
31
- def test_custom_config(self):
32
- """Test custom configuration."""
33
- config = PipelineConfig(
34
- mode=ProcessingMode.VIDEO,
35
- quality_preset="ultra",
36
- use_gpu=False,
37
- batch_size=4
 
 
 
 
38
  )
39
- assert config.mode == ProcessingMode.VIDEO
40
- assert config.quality_preset == "ultra"
41
- assert config.use_gpu == False
42
- assert config.batch_size == 4
43
-
44
-
45
- class TestProcessingPipeline:
46
- """Test the main processing pipeline."""
47
 
48
- @pytest.fixture
49
- def mock_pipeline(self, pipeline_config):
50
- """Create a pipeline with mocked components."""
51
- with patch('api.pipeline.ModelFactory') as mock_factory:
52
- with patch('api.pipeline.DeviceManager') as mock_device:
53
- mock_device.return_value.get_device.return_value = 'cpu'
54
- mock_factory.return_value.load_model.return_value = Mock()
55
-
56
- pipeline = ProcessingPipeline(pipeline_config)
57
- return pipeline
58
 
59
- def test_pipeline_initialization(self, mock_pipeline):
60
- """Test pipeline initialization."""
61
- assert mock_pipeline is not None
62
- assert mock_pipeline.config is not None
63
- assert mock_pipeline.current_stage == PipelineStage.INITIALIZATION
64
 
65
- def test_process_image_success(self, mock_pipeline, sample_image, sample_background):
66
- """Test successful image processing."""
67
- # Mock the processing methods
68
- mock_pipeline._segment_image = Mock(return_value=np.ones((512, 512), dtype=np.uint8) * 255)
69
- mock_pipeline.alpha_matting.process = Mock(return_value={
70
- 'alpha': np.ones((512, 512), dtype=np.float32),
71
- 'confidence': 0.95
72
- })
73
-
74
- result = mock_pipeline.process_image(sample_image, sample_background)
75
-
76
- assert result is not None
77
- assert isinstance(result, PipelineResult)
78
- assert result.success == True
79
- assert result.output_image is not None
80
 
81
- def test_process_image_with_effects(self, mock_pipeline, sample_image):
82
- """Test image processing with effects."""
83
- mock_pipeline.config.apply_effects = ['bokeh', 'vignette']
 
 
 
 
 
84
 
85
- # Mock processing
86
- mock_pipeline._segment_image = Mock(return_value=np.ones((512, 512), dtype=np.uint8) * 255)
87
- mock_pipeline.alpha_matting.process = Mock(return_value={
88
- 'alpha': np.ones((512, 512), dtype=np.float32),
89
- 'confidence': 0.95
90
- })
 
 
 
 
 
91
 
92
- result = mock_pipeline.process_image(sample_image, None)
93
 
94
- assert result is not None
95
- assert result.success == True
 
 
 
 
 
 
 
 
 
96
 
97
- def test_process_image_failure(self, mock_pipeline, sample_image):
98
- """Test image processing failure handling."""
99
- # Mock segmentation to fail
100
- mock_pipeline._segment_image = Mock(side_effect=Exception("Segmentation failed"))
101
 
102
- result = mock_pipeline.process_image(sample_image, None)
103
 
104
- assert result is not None
105
- assert result.success == False
106
- assert len(result.errors) > 0
 
 
 
 
 
107
 
108
- @pytest.mark.parametrize("quality", ["low", "medium", "high", "ultra"])
109
- def test_quality_presets(self, mock_pipeline, sample_image, quality):
110
- """Test different quality presets."""
111
- mock_pipeline.config.quality_preset = quality
112
-
113
- # Mock processing
114
- mock_pipeline._segment_image = Mock(return_value=np.ones((512, 512), dtype=np.uint8) * 255)
115
- mock_pipeline.alpha_matting.process = Mock(return_value={
116
- 'alpha': np.ones((512, 512), dtype=np.float32),
117
- 'confidence': 0.95
118
- })
119
-
120
- result = mock_pipeline.process_image(sample_image, None)
121
-
122
- assert result is not None
123
- assert result.success == True
124
 
125
- def test_batch_processing(self, mock_pipeline, sample_image):
126
- """Test batch processing of multiple images."""
127
- images = [sample_image] * 3
128
-
129
- # Mock processing
130
- mock_pipeline.process_image = Mock(return_value=PipelineResult(
131
- success=True,
132
- output_image=sample_image,
133
- quality_score=0.9
134
- ))
 
135
 
136
- results = mock_pipeline.process_batch(images)
 
137
 
138
- assert len(results) == 3
139
- assert all(r.success for r in results)
140
 
141
- def test_progress_callback(self, mock_pipeline, sample_image):
142
- """Test progress callback functionality."""
143
  progress_values = []
144
 
145
- def progress_callback(value, message):
146
- progress_values.append(value)
147
 
148
- mock_pipeline.config.progress_callback = progress_callback
 
 
 
 
 
149
 
150
- # Mock processing
151
- mock_pipeline._segment_image = Mock(return_value=np.ones((512, 512), dtype=np.uint8) * 255)
152
- mock_pipeline.alpha_matting.process = Mock(return_value={
153
- 'alpha': np.ones((512, 512), dtype=np.float32),
154
- 'confidence': 0.95
155
- })
 
156
 
157
- result = mock_pipeline.process_image(sample_image, None)
158
 
159
- assert len(progress_values) > 0
160
- assert 0.0 <= max(progress_values) <= 1.0
161
 
162
- def test_cache_functionality(self, mock_pipeline, sample_image):
163
- """Test caching functionality."""
164
- mock_pipeline.config.enable_cache = True
165
-
166
- # Mock processing
167
- mock_pipeline._segment_image = Mock(return_value=np.ones((512, 512), dtype=np.uint8) * 255)
168
- mock_pipeline.alpha_matting.process = Mock(return_value={
169
- 'alpha': np.ones((512, 512), dtype=np.float32),
170
- 'confidence': 0.95
171
- })
172
-
173
- # First call
174
- result1 = mock_pipeline.process_image(sample_image, None)
175
-
176
- # Second call (should use cache)
177
- result2 = mock_pipeline.process_image(sample_image, None)
178
-
179
- assert result1.success == result2.success
180
- # Verify segmentation was only called once (cache hit on second call)
181
- assert mock_pipeline._segment_image.call_count == 1
182
 
183
- def test_memory_management(self, mock_pipeline):
184
- """Test memory management and cleanup."""
185
- initial_cache_size = len(mock_pipeline.cache)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
 
187
- # Process multiple images to fill cache
188
- for i in range(10):
189
- image = np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8)
190
- mock_pipeline.cache[f"test_{i}"] = PipelineResult(success=True)
191
 
192
- # Clear cache
193
- mock_pipeline.clear_cache()
194
 
195
- assert len(mock_pipeline.cache) == 0
 
196
 
197
- def test_statistics_tracking(self, mock_pipeline, sample_image):
198
- """Test statistics tracking."""
199
- # Mock processing
200
- mock_pipeline._segment_image = Mock(return_value=np.ones((512, 512), dtype=np.uint8) * 255)
201
- mock_pipeline.alpha_matting.process = Mock(return_value={
202
- 'alpha': np.ones((512, 512), dtype=np.float32),
203
- 'confidence': 0.95
204
- })
205
-
206
- # Process image
207
- result = mock_pipeline.process_image(sample_image, None)
208
-
209
- # Get statistics
210
- stats = mock_pipeline.get_statistics()
211
-
212
- assert 'total_processed' in stats
213
- assert stats['total_processed'] > 0
214
- assert 'avg_time' in stats
215
-
216
-
217
- class TestPipelineIntegration:
218
- """Integration tests for the pipeline."""
219
 
220
- @pytest.mark.integration
221
- @pytest.mark.slow
222
- def test_end_to_end_processing(self, sample_image, sample_background, temp_dir):
223
- """Test end-to-end processing pipeline."""
224
- config = PipelineConfig(
225
- use_gpu=False,
226
- quality_preset="medium",
227
- enable_cache=False
228
  )
 
229
 
230
- # Create pipeline (will use real components if available)
231
- try:
232
- pipeline = ProcessingPipeline(config)
233
- except Exception:
234
- pytest.skip("Models not available for integration test")
235
-
236
- # Process image
237
- result = pipeline.process_image(sample_image, sample_background)
238
-
239
- if result.success:
240
- assert result.output_image is not None
241
- assert result.output_image.shape == sample_image.shape
242
- assert result.quality_score > 0
243
-
244
- # Save output
245
- output_path = temp_dir / "test_output.png"
246
- cv2.imwrite(str(output_path), result.output_image)
247
- assert output_path.exists()
248
 
249
- @pytest.mark.integration
250
- @pytest.mark.slow
251
- def test_video_frame_processing(self, sample_video, temp_dir):
252
- """Test processing video frames."""
253
- config = PipelineConfig(
254
- mode=ProcessingMode.VIDEO,
255
- use_gpu=False,
256
- quality_preset="low"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
  )
258
 
259
- try:
260
- pipeline = ProcessingPipeline(config)
261
- except Exception:
262
- pytest.skip("Models not available for integration test")
 
 
 
263
 
264
- # Open video
265
- cap = cv2.VideoCapture(sample_video)
266
- processed_frames = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
 
268
- # Process first 5 frames
269
- for i in range(5):
270
- ret, frame = cap.read()
271
- if not ret:
272
- break
273
-
274
- result = pipeline.process_image(frame, None)
275
- if result.success:
276
- processed_frames.append(result.output_image)
277
-
278
- cap.release()
279
-
280
- assert len(processed_frames) > 0
281
-
282
- # Save as video
283
- if processed_frames:
284
- output_path = temp_dir / "test_video_out.mp4"
285
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
286
- out = cv2.VideoWriter(str(output_path), fourcc, 30.0,
287
- (processed_frames[0].shape[1], processed_frames[0].shape[0]))
288
-
289
- for frame in processed_frames:
290
- out.write(frame)
291
-
292
- out.release()
293
- assert output_path.exists()
294
 
295
 
296
- class TestPipelinePerformance:
297
- """Performance tests for the pipeline."""
298
-
299
- @pytest.mark.slow
300
- def test_processing_speed(self, mock_pipeline, sample_image, performance_timer):
301
- """Test processing speed."""
302
- # Mock processing
303
- mock_pipeline._segment_image = Mock(return_value=np.ones((512, 512), dtype=np.uint8) * 255)
304
- mock_pipeline.alpha_matting.process = Mock(return_value={
305
- 'alpha': np.ones((512, 512), dtype=np.float32),
306
- 'confidence': 0.95
307
- })
308
-
309
- with performance_timer as timer:
310
- result = mock_pipeline.process_image(sample_image, None)
311
-
312
- assert result.success == True
313
- assert timer.elapsed < 1.0 # Should process in under 1 second
314
 
 
315
  @pytest.mark.slow
316
- def test_batch_processing_speed(self, mock_pipeline, sample_image, performance_timer):
317
- """Test batch processing speed."""
318
- images = [sample_image] * 10
319
-
320
- # Mock processing
321
- mock_pipeline.process_image = Mock(return_value=PipelineResult(
322
- success=True,
323
- output_image=sample_image,
324
- quality_score=0.9
325
- ))
326
-
327
- with performance_timer as timer:
328
- results = mock_pipeline.process_batch(images)
329
-
330
- assert len(results) == 10
331
- assert timer.elapsed < 5.0 # Should process 10 images in under 5 seconds
332
-
333
- def test_memory_usage(self, mock_pipeline, sample_image):
334
- """Test memory usage during processing."""
335
- import psutil
336
- import os
337
-
338
- process = psutil.Process(os.getpid())
339
- initial_memory = process.memory_info().rss / 1024 / 1024 # MB
340
 
341
- # Process multiple images
342
- for _ in range(10):
343
- mock_pipeline.process_image(sample_image, None)
344
 
345
- final_memory = process.memory_info().rss / 1024 / 1024 # MB
346
- memory_increase = final_memory - initial_memory
347
 
348
- # Memory increase should be reasonable (less than 500MB for 10 images)
349
- assert memory_increase < 500
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
+ Tests for model management functionality.
3
  """
4
 
5
  import pytest
6
+ import tempfile
 
 
7
  from pathlib import Path
8
+ from unittest.mock import Mock, patch, MagicMock
9
+ import json
10
 
11
+ from models import (
12
+ ModelRegistry,
13
+ ModelInfo,
14
+ ModelStatus,
15
+ ModelTask,
16
+ ModelFramework,
17
+ ModelDownloader,
18
+ ModelLoader,
19
+ ModelOptimizer
20
  )
21
 
22
 
23
+ class TestModelRegistry:
24
+ """Test model registry functionality."""
25
+
26
+ @pytest.fixture
27
+ def registry(self):
28
+ """Create a test registry."""
29
+ temp_dir = tempfile.mkdtemp()
30
+ return ModelRegistry(models_dir=Path(temp_dir))
31
 
32
+ def test_registry_initialization(self, registry):
33
+ """Test registry initialization."""
34
+ assert registry is not None
35
+ assert len(registry.models) > 0 # Should have default models
36
+ assert registry.models_dir.exists()
 
 
37
 
38
+ def test_register_model(self, registry):
39
+ """Test registering a new model."""
40
+ model = ModelInfo(
41
+ model_id="test-model",
42
+ name="Test Model",
43
+ version="1.0",
44
+ task=ModelTask.SEGMENTATION,
45
+ framework=ModelFramework.PYTORCH,
46
+ url="http://example.com/model.pth",
47
+ filename="test.pth",
48
+ file_size=1000000
49
  )
50
+
51
+ success = registry.register_model(model)
52
+ assert success == True
53
+ assert "test-model" in registry.models
 
 
 
 
54
 
55
+ def test_get_model(self, registry):
56
+ """Test getting a model by ID."""
57
+ model = registry.get_model("rmbg-1.4")
58
+ assert model is not None
59
+ assert model.model_id == "rmbg-1.4"
60
+ assert model.task == ModelTask.SEGMENTATION
 
 
 
 
61
 
62
+ def test_list_models_by_task(self, registry):
63
+ """Test listing models by task."""
64
+ segmentation_models = registry.list_models(task=ModelTask.SEGMENTATION)
65
+ assert len(segmentation_models) > 0
66
+ assert all(m.task == ModelTask.SEGMENTATION for m in segmentation_models)
67
 
68
+ def test_list_models_by_framework(self, registry):
69
+ """Test listing models by framework."""
70
+ pytorch_models = registry.list_models(framework=ModelFramework.PYTORCH)
71
+ onnx_models = registry.list_models(framework=ModelFramework.ONNX)
72
+
73
+ assert all(m.framework == ModelFramework.PYTORCH for m in pytorch_models)
74
+ assert all(m.framework == ModelFramework.ONNX for m in onnx_models)
 
 
 
 
 
 
 
 
75
 
76
+ def test_get_best_model(self, registry):
77
+ """Test getting best model for a task."""
78
+ # Best for accuracy
79
+ best_accuracy = registry.get_best_model(
80
+ ModelTask.SEGMENTATION,
81
+ prefer_speed=False
82
+ )
83
+ assert best_accuracy is not None
84
 
85
+ # Best for speed
86
+ best_speed = registry.get_best_model(
87
+ ModelTask.SEGMENTATION,
88
+ prefer_speed=True
89
+ )
90
+ assert best_speed is not None
91
+
92
+ def test_update_model_usage(self, registry):
93
+ """Test updating model usage statistics."""
94
+ model_id = "rmbg-1.4"
95
+ initial_count = registry.models[model_id].use_count
96
 
97
+ registry.update_model_usage(model_id)
98
 
99
+ assert registry.models[model_id].use_count == initial_count + 1
100
+ assert registry.models[model_id].last_used is not None
101
+
102
+ def test_get_total_size(self, registry):
103
+ """Test calculating total model size."""
104
+ total_size = registry.get_total_size()
105
+ assert total_size > 0
106
+
107
+ # Size of available models should be 0 initially
108
+ available_size = registry.get_total_size(status=ModelStatus.AVAILABLE)
109
+ assert available_size == 0
110
 
111
+ def test_export_registry(self, registry, temp_dir):
112
+ """Test exporting registry to file."""
113
+ export_path = temp_dir / "registry_export.json"
114
+ registry.export_registry(export_path)
115
 
116
+ assert export_path.exists()
117
 
118
+ with open(export_path) as f:
119
+ data = json.load(f)
120
+ assert "models" in data
121
+ assert len(data["models"]) > 0
122
+
123
+
124
+ class TestModelDownloader:
125
+ """Test model downloading functionality."""
126
 
127
+ @pytest.fixture
128
+ def downloader(self, mock_registry):
129
+ """Create a test downloader."""
130
+ return ModelDownloader(mock_registry)
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
+ @patch('requests.get')
133
+ def test_download_model(self, mock_get, downloader):
134
+ """Test downloading a model."""
135
+ # Mock HTTP response
136
+ mock_response = MagicMock()
137
+ mock_response.headers = {'content-length': '1000000'}
138
+ mock_response.iter_content = MagicMock(
139
+ return_value=[b'data' * 1000]
140
+ )
141
+ mock_response.raise_for_status = MagicMock()
142
+ mock_get.return_value = mock_response
143
 
144
+ # Test download
145
+ success = downloader.download_model("test-model", force=True)
146
 
147
+ assert mock_get.called
148
+ # Note: Full download test would require more mocking
149
 
150
+ def test_download_progress_tracking(self, downloader):
151
+ """Test download progress tracking."""
152
  progress_values = []
153
 
154
+ def progress_callback(progress):
155
+ progress_values.append(progress.progress)
156
 
157
+ # Start a download (will fail but we can test progress initialization)
158
+ with patch.object(downloader, '_download_model_task', return_value=True):
159
+ downloader.download_model(
160
+ "test-model",
161
+ progress_callback=progress_callback
162
+ )
163
 
164
+ assert "test-model" in downloader.downloads
165
+
166
+ def test_cancel_download(self, downloader):
167
+ """Test cancelling a download."""
168
+ # Start a mock download
169
+ downloader.downloads["test-model"] = Mock()
170
+ downloader._stop_events["test-model"] = Mock()
171
 
172
+ success = downloader.cancel_download("test-model")
173
 
174
+ assert success == True
175
+ assert downloader._stop_events["test-model"].set.called
176
 
177
+ def test_download_with_resume(self, downloader, temp_dir):
178
+ """Test download with resume support."""
179
+ # Create a partial file
180
+ partial_file = temp_dir / "test.pth.part"
181
+ partial_file.write_bytes(b"partial_data")
182
+
183
+ # Mock download would check for partial file
184
+ assert partial_file.exists()
185
+ assert partial_file.stat().st_size > 0
186
+
187
+
188
+ class TestModelLoader:
189
+ """Test model loading functionality."""
190
+
191
+ @pytest.fixture
192
+ def loader(self, mock_registry):
193
+ """Create a test loader."""
194
+ return ModelLoader(mock_registry, device='cpu')
 
 
195
 
196
+ def test_loader_initialization(self, loader):
197
+ """Test loader initialization."""
198
+ assert loader is not None
199
+ assert loader.device == 'cpu'
200
+ assert loader.max_memory_bytes > 0
201
+
202
+ @patch('torch.load')
203
+ def test_load_pytorch_model(self, mock_torch_load, loader):
204
+ """Test loading a PyTorch model."""
205
+ mock_model = MagicMock()
206
+ mock_torch_load.return_value = mock_model
207
+
208
+ # Mock model info
209
+ model_info = ModelInfo(
210
+ model_id="test-pytorch",
211
+ name="Test PyTorch Model",
212
+ version="1.0",
213
+ task=ModelTask.SEGMENTATION,
214
+ framework=ModelFramework.PYTORCH,
215
+ url="",
216
+ filename="model.pth",
217
+ local_path="/tmp/model.pth",
218
+ status=ModelStatus.AVAILABLE
219
+ )
220
 
221
+ loader.registry.get_model = Mock(return_value=model_info)
 
 
 
222
 
223
+ with patch.object(Path, 'exists', return_value=True):
224
+ loaded = loader.load_model("test-pytorch")
225
 
226
+ # Note: Full test would require more setup
227
+ assert mock_torch_load.called
228
 
229
+ def test_memory_management(self, loader):
230
+ """Test memory management during model loading."""
231
+ # Add mock models to loaded cache
232
+ for i in range(5):
233
+ loader.loaded_models[f"model_{i}"] = Mock(
234
+ memory_usage=100 * 1024 * 1024 # 100MB each
235
+ )
236
+
237
+ loader.current_memory_usage = 500 * 1024 * 1024 # 500MB
238
+
239
+ # Free memory
240
+ loader._free_memory(200 * 1024 * 1024) # Need 200MB
241
+
242
+ # Should have freed at least 2 models
243
+ assert len(loader.loaded_models) < 5
 
 
 
 
 
 
 
244
 
245
+ def test_unload_model(self, loader):
246
+ """Test unloading a model."""
247
+ # Add a mock model
248
+ loader.loaded_models["test"] = Mock(
249
+ model=Mock(),
250
+ memory_usage=100 * 1024 * 1024
 
 
251
  )
252
+ loader.current_memory_usage = 100 * 1024 * 1024
253
 
254
+ success = loader.unload_model("test")
255
+
256
+ assert success == True
257
+ assert "test" not in loader.loaded_models
258
+ assert loader.current_memory_usage == 0
 
 
 
 
 
 
 
 
 
 
 
 
 
259
 
260
+ def test_get_memory_usage(self, loader):
261
+ """Test getting memory usage statistics."""
262
+ # Add mock models
263
+ loader.loaded_models["model1"] = Mock(memory_usage=100 * 1024 * 1024)
264
+ loader.loaded_models["model2"] = Mock(memory_usage=200 * 1024 * 1024)
265
+ loader.current_memory_usage = 300 * 1024 * 1024
266
+
267
+ usage = loader.get_memory_usage()
268
+
269
+ assert usage["current_usage_mb"] == 300
270
+ assert usage["loaded_models"] == 2
271
+ assert "model1" in usage["models"]
272
+ assert "model2" in usage["models"]
273
+
274
+
275
+ class TestModelOptimizer:
276
+ """Test model optimization functionality."""
277
+
278
+ @pytest.fixture
279
+ def optimizer(self, mock_registry):
280
+ """Create a test optimizer."""
281
+ loader = ModelLoader(mock_registry, device='cpu')
282
+ return ModelOptimizer(loader)
283
+
284
+ @patch('torch.quantization.quantize_dynamic')
285
+ def test_quantize_pytorch_model(self, mock_quantize, optimizer):
286
+ """Test PyTorch model quantization."""
287
+ # Create mock model
288
+ mock_model = MagicMock()
289
+ mock_quantize.return_value = mock_model
290
+
291
+ loaded = Mock(
292
+ model_id="test",
293
+ model=mock_model,
294
+ framework=ModelFramework.PYTORCH,
295
+ metadata={'input_size': (1, 3, 512, 512)}
296
  )
297
 
298
+ with patch.object(optimizer, '_get_model_size', return_value=1000000):
299
+ with patch.object(optimizer, '_benchmark_model', return_value=0.1):
300
+ result = optimizer._quantize_pytorch(
301
+ loaded,
302
+ Path("/tmp"),
303
+ "dynamic"
304
+ )
305
 
306
+ assert mock_quantize.called
307
+ # Note: Full test would require more setup
308
+
309
+ def test_optimization_result(self, optimizer):
310
+ """Test optimization result structure."""
311
+ from models.optimizer import OptimizationResult
312
+
313
+ result = OptimizationResult(
314
+ original_size_mb=100,
315
+ optimized_size_mb=25,
316
+ compression_ratio=4.0,
317
+ original_speed_ms=100,
318
+ optimized_speed_ms=50,
319
+ speedup=2.0,
320
+ accuracy_loss=0.01,
321
+ optimization_time=10.0,
322
+ output_path="/tmp/optimized.pth"
323
+ )
324
 
325
+ assert result.compression_ratio == 4.0
326
+ assert result.speedup == 2.0
327
+ assert result.accuracy_loss == 0.01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
328
 
329
 
330
+ class TestModelIntegration:
331
+ """Integration tests for model management."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
 
333
+ @pytest.mark.integration
334
  @pytest.mark.slow
335
+ def test_model_registry_persistence(self, temp_dir):
336
+ """Test registry persistence across instances."""
337
+ # Create registry and add model
338
+ registry1 = ModelRegistry(models_dir=temp_dir)
339
+
340
+ test_model = ModelInfo(
341
+ model_id="persistence-test",
342
+ name="Persistence Test",
343
+ version="1.0",
344
+ task=ModelTask.SEGMENTATION,
345
+ framework=ModelFramework.PYTORCH,
346
+ url="http://example.com/model.pth",
347
+ filename="persist.pth"
348
+ )
 
 
 
 
 
 
 
 
 
 
349
 
350
+ registry1.register_model(test_model)
 
 
351
 
352
+ # Create new registry instance
353
+ registry2 = ModelRegistry(models_dir=temp_dir)
354
 
355
+ # Check if model persisted
356
+ loaded_model = registry2.get_model("persistence-test")
357
+ assert loaded_model is not None
358
+ assert loaded_model.name == "Persistence Test"
359
+
360
+ @pytest.mark.integration
361
+ def test_model_manager_workflow(self):
362
+ """Test complete model manager workflow."""
363
+ from models import create_model_manager
364
+
365
+ manager = create_model_manager()
366
+
367
+ # Test model discovery
368
+ stats = manager.get_stats()
369
+ assert "registry" in stats
370
+ assert stats["registry"]["total_models"] > 0
371
+
372
+ # Test benchmark (without actual model loading)
373
+ with patch.object(manager.loader, 'load_model', return_value=Mock()):
374
+ benchmarks = manager.benchmark()
375
+ # Would return empty without real models
376
+ assert isinstance(benchmarks, dict)