Create tests/test_models.py
Browse files- tests/test_models.py +322 -295
tests/test_models.py
CHANGED
|
@@ -1,349 +1,376 @@
|
|
| 1 |
"""
|
| 2 |
-
Tests for
|
| 3 |
"""
|
| 4 |
|
| 5 |
import pytest
|
| 6 |
-
import
|
| 7 |
-
import cv2
|
| 8 |
-
from unittest.mock import Mock, patch, MagicMock
|
| 9 |
from pathlib import Path
|
|
|
|
|
|
|
| 10 |
|
| 11 |
-
from
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
| 17 |
)
|
| 18 |
|
| 19 |
|
| 20 |
-
class
|
| 21 |
-
"""Test
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
-
def
|
| 24 |
-
"""Test
|
| 25 |
-
|
| 26 |
-
assert
|
| 27 |
-
assert
|
| 28 |
-
assert config.use_gpu == True
|
| 29 |
-
assert config.enable_cache == True
|
| 30 |
|
| 31 |
-
def
|
| 32 |
-
"""Test
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
)
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
assert
|
| 42 |
-
assert
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
class TestProcessingPipeline:
|
| 46 |
-
"""Test the main processing pipeline."""
|
| 47 |
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
mock_factory.return_value.load_model.return_value = Mock()
|
| 55 |
-
|
| 56 |
-
pipeline = ProcessingPipeline(pipeline_config)
|
| 57 |
-
return pipeline
|
| 58 |
|
| 59 |
-
def
|
| 60 |
-
"""Test
|
| 61 |
-
|
| 62 |
-
assert
|
| 63 |
-
assert
|
| 64 |
|
| 65 |
-
def
|
| 66 |
-
"""Test
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
})
|
| 73 |
-
|
| 74 |
-
result = mock_pipeline.process_image(sample_image, sample_background)
|
| 75 |
-
|
| 76 |
-
assert result is not None
|
| 77 |
-
assert isinstance(result, PipelineResult)
|
| 78 |
-
assert result.success == True
|
| 79 |
-
assert result.output_image is not None
|
| 80 |
|
| 81 |
-
def
|
| 82 |
-
"""Test
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
-
#
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
-
|
| 93 |
|
| 94 |
-
assert
|
| 95 |
-
assert
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
|
| 97 |
-
def
|
| 98 |
-
"""Test
|
| 99 |
-
|
| 100 |
-
|
| 101 |
|
| 102 |
-
|
| 103 |
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
-
@pytest.
|
| 109 |
-
def
|
| 110 |
-
"""
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
# Mock processing
|
| 114 |
-
mock_pipeline._segment_image = Mock(return_value=np.ones((512, 512), dtype=np.uint8) * 255)
|
| 115 |
-
mock_pipeline.alpha_matting.process = Mock(return_value={
|
| 116 |
-
'alpha': np.ones((512, 512), dtype=np.float32),
|
| 117 |
-
'confidence': 0.95
|
| 118 |
-
})
|
| 119 |
-
|
| 120 |
-
result = mock_pipeline.process_image(sample_image, None)
|
| 121 |
-
|
| 122 |
-
assert result is not None
|
| 123 |
-
assert result.success == True
|
| 124 |
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
)
|
|
|
|
| 135 |
|
| 136 |
-
|
|
|
|
| 137 |
|
| 138 |
-
assert
|
| 139 |
-
|
| 140 |
|
| 141 |
-
def
|
| 142 |
-
"""Test progress
|
| 143 |
progress_values = []
|
| 144 |
|
| 145 |
-
def progress_callback(
|
| 146 |
-
progress_values.append(
|
| 147 |
|
| 148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
|
|
|
| 156 |
|
| 157 |
-
|
| 158 |
|
| 159 |
-
assert
|
| 160 |
-
assert
|
| 161 |
|
| 162 |
-
def
|
| 163 |
-
"""Test
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
# Verify segmentation was only called once (cache hit on second call)
|
| 181 |
-
assert mock_pipeline._segment_image.call_count == 1
|
| 182 |
|
| 183 |
-
def
|
| 184 |
-
"""Test
|
| 185 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
|
| 187 |
-
|
| 188 |
-
for i in range(10):
|
| 189 |
-
image = np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8)
|
| 190 |
-
mock_pipeline.cache[f"test_{i}"] = PipelineResult(success=True)
|
| 191 |
|
| 192 |
-
|
| 193 |
-
|
| 194 |
|
| 195 |
-
|
|
|
|
| 196 |
|
| 197 |
-
def
|
| 198 |
-
"""Test
|
| 199 |
-
#
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
assert 'total_processed' in stats
|
| 213 |
-
assert stats['total_processed'] > 0
|
| 214 |
-
assert 'avg_time' in stats
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
class TestPipelineIntegration:
|
| 218 |
-
"""Integration tests for the pipeline."""
|
| 219 |
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
""
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
quality_preset="medium",
|
| 227 |
-
enable_cache=False
|
| 228 |
)
|
|
|
|
| 229 |
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
# Process image
|
| 237 |
-
result = pipeline.process_image(sample_image, sample_background)
|
| 238 |
-
|
| 239 |
-
if result.success:
|
| 240 |
-
assert result.output_image is not None
|
| 241 |
-
assert result.output_image.shape == sample_image.shape
|
| 242 |
-
assert result.quality_score > 0
|
| 243 |
-
|
| 244 |
-
# Save output
|
| 245 |
-
output_path = temp_dir / "test_output.png"
|
| 246 |
-
cv2.imwrite(str(output_path), result.output_image)
|
| 247 |
-
assert output_path.exists()
|
| 248 |
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
""
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 257 |
)
|
| 258 |
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
|
|
|
|
|
|
|
|
|
| 263 |
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
if not ret:
|
| 272 |
-
break
|
| 273 |
-
|
| 274 |
-
result = pipeline.process_image(frame, None)
|
| 275 |
-
if result.success:
|
| 276 |
-
processed_frames.append(result.output_image)
|
| 277 |
-
|
| 278 |
-
cap.release()
|
| 279 |
-
|
| 280 |
-
assert len(processed_frames) > 0
|
| 281 |
-
|
| 282 |
-
# Save as video
|
| 283 |
-
if processed_frames:
|
| 284 |
-
output_path = temp_dir / "test_video_out.mp4"
|
| 285 |
-
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
| 286 |
-
out = cv2.VideoWriter(str(output_path), fourcc, 30.0,
|
| 287 |
-
(processed_frames[0].shape[1], processed_frames[0].shape[0]))
|
| 288 |
-
|
| 289 |
-
for frame in processed_frames:
|
| 290 |
-
out.write(frame)
|
| 291 |
-
|
| 292 |
-
out.release()
|
| 293 |
-
assert output_path.exists()
|
| 294 |
|
| 295 |
|
| 296 |
-
class
|
| 297 |
-
"""
|
| 298 |
-
|
| 299 |
-
@pytest.mark.slow
|
| 300 |
-
def test_processing_speed(self, mock_pipeline, sample_image, performance_timer):
|
| 301 |
-
"""Test processing speed."""
|
| 302 |
-
# Mock processing
|
| 303 |
-
mock_pipeline._segment_image = Mock(return_value=np.ones((512, 512), dtype=np.uint8) * 255)
|
| 304 |
-
mock_pipeline.alpha_matting.process = Mock(return_value={
|
| 305 |
-
'alpha': np.ones((512, 512), dtype=np.float32),
|
| 306 |
-
'confidence': 0.95
|
| 307 |
-
})
|
| 308 |
-
|
| 309 |
-
with performance_timer as timer:
|
| 310 |
-
result = mock_pipeline.process_image(sample_image, None)
|
| 311 |
-
|
| 312 |
-
assert result.success == True
|
| 313 |
-
assert timer.elapsed < 1.0 # Should process in under 1 second
|
| 314 |
|
|
|
|
| 315 |
@pytest.mark.slow
|
| 316 |
-
def
|
| 317 |
-
"""Test
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
assert len(results) == 10
|
| 331 |
-
assert timer.elapsed < 5.0 # Should process 10 images in under 5 seconds
|
| 332 |
-
|
| 333 |
-
def test_memory_usage(self, mock_pipeline, sample_image):
|
| 334 |
-
"""Test memory usage during processing."""
|
| 335 |
-
import psutil
|
| 336 |
-
import os
|
| 337 |
-
|
| 338 |
-
process = psutil.Process(os.getpid())
|
| 339 |
-
initial_memory = process.memory_info().rss / 1024 / 1024 # MB
|
| 340 |
|
| 341 |
-
|
| 342 |
-
for _ in range(10):
|
| 343 |
-
mock_pipeline.process_image(sample_image, None)
|
| 344 |
|
| 345 |
-
|
| 346 |
-
|
| 347 |
|
| 348 |
-
#
|
| 349 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""
|
| 2 |
+
Tests for model management functionality.
|
| 3 |
"""
|
| 4 |
|
| 5 |
import pytest
|
| 6 |
+
import tempfile
|
|
|
|
|
|
|
| 7 |
from pathlib import Path
|
| 8 |
+
from unittest.mock import Mock, patch, MagicMock
|
| 9 |
+
import json
|
| 10 |
|
| 11 |
+
from models import (
|
| 12 |
+
ModelRegistry,
|
| 13 |
+
ModelInfo,
|
| 14 |
+
ModelStatus,
|
| 15 |
+
ModelTask,
|
| 16 |
+
ModelFramework,
|
| 17 |
+
ModelDownloader,
|
| 18 |
+
ModelLoader,
|
| 19 |
+
ModelOptimizer
|
| 20 |
)
|
| 21 |
|
| 22 |
|
| 23 |
+
class TestModelRegistry:
|
| 24 |
+
"""Test model registry functionality."""
|
| 25 |
+
|
| 26 |
+
@pytest.fixture
|
| 27 |
+
def registry(self):
|
| 28 |
+
"""Create a test registry."""
|
| 29 |
+
temp_dir = tempfile.mkdtemp()
|
| 30 |
+
return ModelRegistry(models_dir=Path(temp_dir))
|
| 31 |
|
| 32 |
+
def test_registry_initialization(self, registry):
|
| 33 |
+
"""Test registry initialization."""
|
| 34 |
+
assert registry is not None
|
| 35 |
+
assert len(registry.models) > 0 # Should have default models
|
| 36 |
+
assert registry.models_dir.exists()
|
|
|
|
|
|
|
| 37 |
|
| 38 |
+
def test_register_model(self, registry):
|
| 39 |
+
"""Test registering a new model."""
|
| 40 |
+
model = ModelInfo(
|
| 41 |
+
model_id="test-model",
|
| 42 |
+
name="Test Model",
|
| 43 |
+
version="1.0",
|
| 44 |
+
task=ModelTask.SEGMENTATION,
|
| 45 |
+
framework=ModelFramework.PYTORCH,
|
| 46 |
+
url="http://example.com/model.pth",
|
| 47 |
+
filename="test.pth",
|
| 48 |
+
file_size=1000000
|
| 49 |
)
|
| 50 |
+
|
| 51 |
+
success = registry.register_model(model)
|
| 52 |
+
assert success == True
|
| 53 |
+
assert "test-model" in registry.models
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
+
def test_get_model(self, registry):
|
| 56 |
+
"""Test getting a model by ID."""
|
| 57 |
+
model = registry.get_model("rmbg-1.4")
|
| 58 |
+
assert model is not None
|
| 59 |
+
assert model.model_id == "rmbg-1.4"
|
| 60 |
+
assert model.task == ModelTask.SEGMENTATION
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
+
def test_list_models_by_task(self, registry):
|
| 63 |
+
"""Test listing models by task."""
|
| 64 |
+
segmentation_models = registry.list_models(task=ModelTask.SEGMENTATION)
|
| 65 |
+
assert len(segmentation_models) > 0
|
| 66 |
+
assert all(m.task == ModelTask.SEGMENTATION for m in segmentation_models)
|
| 67 |
|
| 68 |
+
def test_list_models_by_framework(self, registry):
|
| 69 |
+
"""Test listing models by framework."""
|
| 70 |
+
pytorch_models = registry.list_models(framework=ModelFramework.PYTORCH)
|
| 71 |
+
onnx_models = registry.list_models(framework=ModelFramework.ONNX)
|
| 72 |
+
|
| 73 |
+
assert all(m.framework == ModelFramework.PYTORCH for m in pytorch_models)
|
| 74 |
+
assert all(m.framework == ModelFramework.ONNX for m in onnx_models)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
+
def test_get_best_model(self, registry):
|
| 77 |
+
"""Test getting best model for a task."""
|
| 78 |
+
# Best for accuracy
|
| 79 |
+
best_accuracy = registry.get_best_model(
|
| 80 |
+
ModelTask.SEGMENTATION,
|
| 81 |
+
prefer_speed=False
|
| 82 |
+
)
|
| 83 |
+
assert best_accuracy is not None
|
| 84 |
|
| 85 |
+
# Best for speed
|
| 86 |
+
best_speed = registry.get_best_model(
|
| 87 |
+
ModelTask.SEGMENTATION,
|
| 88 |
+
prefer_speed=True
|
| 89 |
+
)
|
| 90 |
+
assert best_speed is not None
|
| 91 |
+
|
| 92 |
+
def test_update_model_usage(self, registry):
|
| 93 |
+
"""Test updating model usage statistics."""
|
| 94 |
+
model_id = "rmbg-1.4"
|
| 95 |
+
initial_count = registry.models[model_id].use_count
|
| 96 |
|
| 97 |
+
registry.update_model_usage(model_id)
|
| 98 |
|
| 99 |
+
assert registry.models[model_id].use_count == initial_count + 1
|
| 100 |
+
assert registry.models[model_id].last_used is not None
|
| 101 |
+
|
| 102 |
+
def test_get_total_size(self, registry):
|
| 103 |
+
"""Test calculating total model size."""
|
| 104 |
+
total_size = registry.get_total_size()
|
| 105 |
+
assert total_size > 0
|
| 106 |
+
|
| 107 |
+
# Size of available models should be 0 initially
|
| 108 |
+
available_size = registry.get_total_size(status=ModelStatus.AVAILABLE)
|
| 109 |
+
assert available_size == 0
|
| 110 |
|
| 111 |
+
def test_export_registry(self, registry, temp_dir):
|
| 112 |
+
"""Test exporting registry to file."""
|
| 113 |
+
export_path = temp_dir / "registry_export.json"
|
| 114 |
+
registry.export_registry(export_path)
|
| 115 |
|
| 116 |
+
assert export_path.exists()
|
| 117 |
|
| 118 |
+
with open(export_path) as f:
|
| 119 |
+
data = json.load(f)
|
| 120 |
+
assert "models" in data
|
| 121 |
+
assert len(data["models"]) > 0
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class TestModelDownloader:
|
| 125 |
+
"""Test model downloading functionality."""
|
| 126 |
|
| 127 |
+
@pytest.fixture
|
| 128 |
+
def downloader(self, mock_registry):
|
| 129 |
+
"""Create a test downloader."""
|
| 130 |
+
return ModelDownloader(mock_registry)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
|
| 132 |
+
@patch('requests.get')
|
| 133 |
+
def test_download_model(self, mock_get, downloader):
|
| 134 |
+
"""Test downloading a model."""
|
| 135 |
+
# Mock HTTP response
|
| 136 |
+
mock_response = MagicMock()
|
| 137 |
+
mock_response.headers = {'content-length': '1000000'}
|
| 138 |
+
mock_response.iter_content = MagicMock(
|
| 139 |
+
return_value=[b'data' * 1000]
|
| 140 |
+
)
|
| 141 |
+
mock_response.raise_for_status = MagicMock()
|
| 142 |
+
mock_get.return_value = mock_response
|
| 143 |
|
| 144 |
+
# Test download
|
| 145 |
+
success = downloader.download_model("test-model", force=True)
|
| 146 |
|
| 147 |
+
assert mock_get.called
|
| 148 |
+
# Note: Full download test would require more mocking
|
| 149 |
|
| 150 |
+
def test_download_progress_tracking(self, downloader):
|
| 151 |
+
"""Test download progress tracking."""
|
| 152 |
progress_values = []
|
| 153 |
|
| 154 |
+
def progress_callback(progress):
|
| 155 |
+
progress_values.append(progress.progress)
|
| 156 |
|
| 157 |
+
# Start a download (will fail but we can test progress initialization)
|
| 158 |
+
with patch.object(downloader, '_download_model_task', return_value=True):
|
| 159 |
+
downloader.download_model(
|
| 160 |
+
"test-model",
|
| 161 |
+
progress_callback=progress_callback
|
| 162 |
+
)
|
| 163 |
|
| 164 |
+
assert "test-model" in downloader.downloads
|
| 165 |
+
|
| 166 |
+
def test_cancel_download(self, downloader):
|
| 167 |
+
"""Test cancelling a download."""
|
| 168 |
+
# Start a mock download
|
| 169 |
+
downloader.downloads["test-model"] = Mock()
|
| 170 |
+
downloader._stop_events["test-model"] = Mock()
|
| 171 |
|
| 172 |
+
success = downloader.cancel_download("test-model")
|
| 173 |
|
| 174 |
+
assert success == True
|
| 175 |
+
assert downloader._stop_events["test-model"].set.called
|
| 176 |
|
| 177 |
+
def test_download_with_resume(self, downloader, temp_dir):
|
| 178 |
+
"""Test download with resume support."""
|
| 179 |
+
# Create a partial file
|
| 180 |
+
partial_file = temp_dir / "test.pth.part"
|
| 181 |
+
partial_file.write_bytes(b"partial_data")
|
| 182 |
+
|
| 183 |
+
# Mock download would check for partial file
|
| 184 |
+
assert partial_file.exists()
|
| 185 |
+
assert partial_file.stat().st_size > 0
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class TestModelLoader:
|
| 189 |
+
"""Test model loading functionality."""
|
| 190 |
+
|
| 191 |
+
@pytest.fixture
|
| 192 |
+
def loader(self, mock_registry):
|
| 193 |
+
"""Create a test loader."""
|
| 194 |
+
return ModelLoader(mock_registry, device='cpu')
|
|
|
|
|
|
|
| 195 |
|
| 196 |
+
def test_loader_initialization(self, loader):
|
| 197 |
+
"""Test loader initialization."""
|
| 198 |
+
assert loader is not None
|
| 199 |
+
assert loader.device == 'cpu'
|
| 200 |
+
assert loader.max_memory_bytes > 0
|
| 201 |
+
|
| 202 |
+
@patch('torch.load')
|
| 203 |
+
def test_load_pytorch_model(self, mock_torch_load, loader):
|
| 204 |
+
"""Test loading a PyTorch model."""
|
| 205 |
+
mock_model = MagicMock()
|
| 206 |
+
mock_torch_load.return_value = mock_model
|
| 207 |
+
|
| 208 |
+
# Mock model info
|
| 209 |
+
model_info = ModelInfo(
|
| 210 |
+
model_id="test-pytorch",
|
| 211 |
+
name="Test PyTorch Model",
|
| 212 |
+
version="1.0",
|
| 213 |
+
task=ModelTask.SEGMENTATION,
|
| 214 |
+
framework=ModelFramework.PYTORCH,
|
| 215 |
+
url="",
|
| 216 |
+
filename="model.pth",
|
| 217 |
+
local_path="/tmp/model.pth",
|
| 218 |
+
status=ModelStatus.AVAILABLE
|
| 219 |
+
)
|
| 220 |
|
| 221 |
+
loader.registry.get_model = Mock(return_value=model_info)
|
|
|
|
|
|
|
|
|
|
| 222 |
|
| 223 |
+
with patch.object(Path, 'exists', return_value=True):
|
| 224 |
+
loaded = loader.load_model("test-pytorch")
|
| 225 |
|
| 226 |
+
# Note: Full test would require more setup
|
| 227 |
+
assert mock_torch_load.called
|
| 228 |
|
| 229 |
+
def test_memory_management(self, loader):
|
| 230 |
+
"""Test memory management during model loading."""
|
| 231 |
+
# Add mock models to loaded cache
|
| 232 |
+
for i in range(5):
|
| 233 |
+
loader.loaded_models[f"model_{i}"] = Mock(
|
| 234 |
+
memory_usage=100 * 1024 * 1024 # 100MB each
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
loader.current_memory_usage = 500 * 1024 * 1024 # 500MB
|
| 238 |
+
|
| 239 |
+
# Free memory
|
| 240 |
+
loader._free_memory(200 * 1024 * 1024) # Need 200MB
|
| 241 |
+
|
| 242 |
+
# Should have freed at least 2 models
|
| 243 |
+
assert len(loader.loaded_models) < 5
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
|
| 245 |
+
def test_unload_model(self, loader):
|
| 246 |
+
"""Test unloading a model."""
|
| 247 |
+
# Add a mock model
|
| 248 |
+
loader.loaded_models["test"] = Mock(
|
| 249 |
+
model=Mock(),
|
| 250 |
+
memory_usage=100 * 1024 * 1024
|
|
|
|
|
|
|
| 251 |
)
|
| 252 |
+
loader.current_memory_usage = 100 * 1024 * 1024
|
| 253 |
|
| 254 |
+
success = loader.unload_model("test")
|
| 255 |
+
|
| 256 |
+
assert success == True
|
| 257 |
+
assert "test" not in loader.loaded_models
|
| 258 |
+
assert loader.current_memory_usage == 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
|
| 260 |
+
def test_get_memory_usage(self, loader):
|
| 261 |
+
"""Test getting memory usage statistics."""
|
| 262 |
+
# Add mock models
|
| 263 |
+
loader.loaded_models["model1"] = Mock(memory_usage=100 * 1024 * 1024)
|
| 264 |
+
loader.loaded_models["model2"] = Mock(memory_usage=200 * 1024 * 1024)
|
| 265 |
+
loader.current_memory_usage = 300 * 1024 * 1024
|
| 266 |
+
|
| 267 |
+
usage = loader.get_memory_usage()
|
| 268 |
+
|
| 269 |
+
assert usage["current_usage_mb"] == 300
|
| 270 |
+
assert usage["loaded_models"] == 2
|
| 271 |
+
assert "model1" in usage["models"]
|
| 272 |
+
assert "model2" in usage["models"]
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
class TestModelOptimizer:
|
| 276 |
+
"""Test model optimization functionality."""
|
| 277 |
+
|
| 278 |
+
@pytest.fixture
|
| 279 |
+
def optimizer(self, mock_registry):
|
| 280 |
+
"""Create a test optimizer."""
|
| 281 |
+
loader = ModelLoader(mock_registry, device='cpu')
|
| 282 |
+
return ModelOptimizer(loader)
|
| 283 |
+
|
| 284 |
+
@patch('torch.quantization.quantize_dynamic')
|
| 285 |
+
def test_quantize_pytorch_model(self, mock_quantize, optimizer):
|
| 286 |
+
"""Test PyTorch model quantization."""
|
| 287 |
+
# Create mock model
|
| 288 |
+
mock_model = MagicMock()
|
| 289 |
+
mock_quantize.return_value = mock_model
|
| 290 |
+
|
| 291 |
+
loaded = Mock(
|
| 292 |
+
model_id="test",
|
| 293 |
+
model=mock_model,
|
| 294 |
+
framework=ModelFramework.PYTORCH,
|
| 295 |
+
metadata={'input_size': (1, 3, 512, 512)}
|
| 296 |
)
|
| 297 |
|
| 298 |
+
with patch.object(optimizer, '_get_model_size', return_value=1000000):
|
| 299 |
+
with patch.object(optimizer, '_benchmark_model', return_value=0.1):
|
| 300 |
+
result = optimizer._quantize_pytorch(
|
| 301 |
+
loaded,
|
| 302 |
+
Path("/tmp"),
|
| 303 |
+
"dynamic"
|
| 304 |
+
)
|
| 305 |
|
| 306 |
+
assert mock_quantize.called
|
| 307 |
+
# Note: Full test would require more setup
|
| 308 |
+
|
| 309 |
+
def test_optimization_result(self, optimizer):
|
| 310 |
+
"""Test optimization result structure."""
|
| 311 |
+
from models.optimizer import OptimizationResult
|
| 312 |
+
|
| 313 |
+
result = OptimizationResult(
|
| 314 |
+
original_size_mb=100,
|
| 315 |
+
optimized_size_mb=25,
|
| 316 |
+
compression_ratio=4.0,
|
| 317 |
+
original_speed_ms=100,
|
| 318 |
+
optimized_speed_ms=50,
|
| 319 |
+
speedup=2.0,
|
| 320 |
+
accuracy_loss=0.01,
|
| 321 |
+
optimization_time=10.0,
|
| 322 |
+
output_path="/tmp/optimized.pth"
|
| 323 |
+
)
|
| 324 |
|
| 325 |
+
assert result.compression_ratio == 4.0
|
| 326 |
+
assert result.speedup == 2.0
|
| 327 |
+
assert result.accuracy_loss == 0.01
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 328 |
|
| 329 |
|
| 330 |
+
class TestModelIntegration:
|
| 331 |
+
"""Integration tests for model management."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 332 |
|
| 333 |
+
@pytest.mark.integration
|
| 334 |
@pytest.mark.slow
|
| 335 |
+
def test_model_registry_persistence(self, temp_dir):
|
| 336 |
+
"""Test registry persistence across instances."""
|
| 337 |
+
# Create registry and add model
|
| 338 |
+
registry1 = ModelRegistry(models_dir=temp_dir)
|
| 339 |
+
|
| 340 |
+
test_model = ModelInfo(
|
| 341 |
+
model_id="persistence-test",
|
| 342 |
+
name="Persistence Test",
|
| 343 |
+
version="1.0",
|
| 344 |
+
task=ModelTask.SEGMENTATION,
|
| 345 |
+
framework=ModelFramework.PYTORCH,
|
| 346 |
+
url="http://example.com/model.pth",
|
| 347 |
+
filename="persist.pth"
|
| 348 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 349 |
|
| 350 |
+
registry1.register_model(test_model)
|
|
|
|
|
|
|
| 351 |
|
| 352 |
+
# Create new registry instance
|
| 353 |
+
registry2 = ModelRegistry(models_dir=temp_dir)
|
| 354 |
|
| 355 |
+
# Check if model persisted
|
| 356 |
+
loaded_model = registry2.get_model("persistence-test")
|
| 357 |
+
assert loaded_model is not None
|
| 358 |
+
assert loaded_model.name == "Persistence Test"
|
| 359 |
+
|
| 360 |
+
@pytest.mark.integration
|
| 361 |
+
def test_model_manager_workflow(self):
|
| 362 |
+
"""Test complete model manager workflow."""
|
| 363 |
+
from models import create_model_manager
|
| 364 |
+
|
| 365 |
+
manager = create_model_manager()
|
| 366 |
+
|
| 367 |
+
# Test model discovery
|
| 368 |
+
stats = manager.get_stats()
|
| 369 |
+
assert "registry" in stats
|
| 370 |
+
assert stats["registry"]["total_models"] > 0
|
| 371 |
+
|
| 372 |
+
# Test benchmark (without actual model loading)
|
| 373 |
+
with patch.object(manager.loader, 'load_model', return_value=Mock()):
|
| 374 |
+
benchmarks = manager.benchmark()
|
| 375 |
+
# Would return empty without real models
|
| 376 |
+
assert isinstance(benchmarks, dict)
|