""" |
|
|
Model optimizer for BackgroundFX Pro. |
|
|
Handles model optimization, quantization, and conversion. |
|
|
""" |
import torch
import torch.nn as nn
import numpy as np
from pathlib import Path
from typing import Optional, Dict, Any, Tuple, List
import logging
import time
import onnx
import onnxruntime as ort
from dataclasses import dataclass

from .registry import ModelInfo, ModelFramework
from .loaders.model_loader import ModelLoader, LoadedModel

logger = logging.getLogger(__name__)


@dataclass
class OptimizationResult:
    """Result of model optimization."""
    original_size_mb: float
    optimized_size_mb: float
    compression_ratio: float
    original_speed_ms: float
    optimized_speed_ms: float
    speedup: float
    accuracy_loss: float
    optimization_time: float
    output_path: str


class ModelOptimizer:
    """Optimize models for deployment."""

    def __init__(self, loader: ModelLoader):
        """
        Initialize model optimizer.

        Args:
            loader: Model loader instance
        """
        self.loader = loader
        self.device = loader.device

    def optimize_model(self,
                       model_id: str,
                       optimization_type: str = 'quantization',
                       output_dir: Optional[Path] = None,
                       **kwargs) -> Optional[OptimizationResult]:
        """
        Optimize a model.

        Args:
            model_id: Model ID to optimize
            optimization_type: Type of optimization ('quantization', 'pruning',
                'onnx', 'tensorrt', or 'coreml')
            output_dir: Output directory (defaults to 'optimized_models')
            **kwargs: Optimization parameters forwarded to the specific optimizer

        Returns:
            OptimizationResult, or None if loading or optimization failed
        """
        loaded = self.loader.load_model(model_id)
        if not loaded:
            logger.error(f"Failed to load model: {model_id}")
            return None

        output_dir = output_dir or Path("optimized_models")
        output_dir.mkdir(parents=True, exist_ok=True)

        start_time = time.time()

        try:
            # Dispatch to the requested optimization backend.
            if optimization_type == 'quantization':
                result = self._quantize_model(loaded, output_dir, **kwargs)
            elif optimization_type == 'pruning':
                result = self._prune_model(loaded, output_dir, **kwargs)
            elif optimization_type == 'onnx':
                result = self._convert_to_onnx(loaded, output_dir, **kwargs)
            elif optimization_type == 'tensorrt':
                result = self._convert_to_tensorrt(loaded, output_dir, **kwargs)
            elif optimization_type == 'coreml':
                result = self._convert_to_coreml(loaded, output_dir, **kwargs)
            else:
                logger.error(f"Unknown optimization type: {optimization_type}")
                return None

            if result:
                result.optimization_time = time.time() - start_time
                logger.info(f"Optimization completed in {result.optimization_time:.2f}s")
                logger.info(f"Size reduction: {result.compression_ratio:.2f}x")
                logger.info(f"Speed improvement: {result.speedup:.2f}x")

            return result

        except Exception as e:
            logger.error(f"Optimization failed: {e}")
            return None

    def _quantize_model(self,
                        loaded: LoadedModel,
                        output_dir: Path,
                        quantization_type: str = 'dynamic',
                        **kwargs) -> Optional[OptimizationResult]:
        """
        Quantize model to reduce size.

        Args:
            loaded: Loaded model
            output_dir: Output directory
            quantization_type: 'dynamic', 'static', or any other value for
                quantization-aware training (QAT)

        Returns:
            Optimization result
        """
        if loaded.framework == ModelFramework.PYTORCH:
            return self._quantize_pytorch(loaded, output_dir, quantization_type, **kwargs)
        elif loaded.framework == ModelFramework.ONNX:
            return self._quantize_onnx(loaded, output_dir, **kwargs)
        else:
            logger.error(f"Quantization not supported for: {loaded.framework}")
            return None

    def _quantize_pytorch(self,
                          loaded: LoadedModel,
                          output_dir: Path,
                          quantization_type: str,
                          calibration_data: Optional[List] = None) -> Optional[OptimizationResult]:
        """Quantize PyTorch model."""
        try:
            import torch.quantization as quantization

            model = loaded.model
            if not isinstance(model, nn.Module):
                logger.error("Model is not a PyTorch module")
                return None

            # Baseline metrics before quantization.
            original_size = self._get_model_size(model)
            original_speed = self._benchmark_model(model, loaded.metadata.get('input_size', (1, 3, 512, 512)))

            model.eval()
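
            # The branches below cover three quantization modes:
            #   * 'dynamic': weights are converted to int8 up front and
            #     activations are quantized on the fly; in current PyTorch this
            #     effectively applies to nn.Linear (nn.Conv2d is not dynamically
            #     quantized and passes through unchanged).
            #   * 'static': weights and activations are quantized using ranges
            #     observed on calibration data; the model is normally expected
            #     to be wrapped with QuantStub/DeQuantStub and run on CPU.
            #   * anything else: quantization-aware training (QAT) preparation;
            #     the model still needs fine-tuning and a later
            #     torch.quantization.convert() call.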
            if quantization_type == 'dynamic':
                # Dynamic (weights-only) quantization.
                quantized_model = torch.quantization.quantize_dynamic(
                    model, {nn.Linear, nn.Conv2d}, dtype=torch.qint8
                )

            elif quantization_type == 'static':
                # Static quantization: insert observers, calibrate, then convert.
                model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
                torch.quantization.prepare(model, inplace=True)

                # Calibrate on up to 100 representative inputs.
                if calibration_data:
                    with torch.no_grad():
                        for data in calibration_data[:100]:
                            model(data)

                quantized_model = torch.quantization.convert(model)

            else:
                # Quantization-aware training: the model is only prepared here.
                model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
                torch.quantization.prepare_qat(model, inplace=True)
                quantized_model = model

            # Save the quantized weights.
            output_path = output_dir / f"{loaded.model_id}_quantized.pth"
            torch.save(quantized_model.state_dict(), output_path)

            # Metrics after quantization.
            optimized_size = self._get_model_size(quantized_model)
            optimized_speed = self._benchmark_model(quantized_model, loaded.metadata.get('input_size', (1, 3, 512, 512)))

            return OptimizationResult(
                original_size_mb=original_size / (1024 * 1024),
                optimized_size_mb=optimized_size / (1024 * 1024),
                compression_ratio=original_size / optimized_size,
                original_speed_ms=original_speed * 1000,
                optimized_speed_ms=optimized_speed * 1000,
                speedup=original_speed / optimized_speed,
                accuracy_loss=0.01,  # rough estimate, not measured
                optimization_time=0,
                output_path=str(output_path)
            )

        except Exception as e:
            logger.error(f"PyTorch quantization failed: {e}")
            return None

    def _quantize_onnx(self,
                       loaded: LoadedModel,
                       output_dir: Path,
                       **kwargs) -> Optional[OptimizationResult]:
        """Quantize ONNX model."""
        try:
            from onnxruntime.quantization import quantize_dynamic, QuantType

            model_path = self.loader.registry.get_model(loaded.model_id).local_path
            output_path = output_dir / f"{loaded.model_id}_quantized.onnx"

            # Baseline metrics from the original ONNX file.
            original_size = Path(model_path).stat().st_size
            original_speed = self._benchmark_onnx(str(model_path))

            # Dynamic (weight-only) quantization to int8.
            quantize_dynamic(
                model_path,
                str(output_path),
                weight_type=QuantType.QInt8
            )

            # Metrics for the quantized file.
            optimized_size = output_path.stat().st_size
            optimized_speed = self._benchmark_onnx(str(output_path))

            return OptimizationResult(
                original_size_mb=original_size / (1024 * 1024),
                optimized_size_mb=optimized_size / (1024 * 1024),
                compression_ratio=original_size / optimized_size,
                original_speed_ms=original_speed * 1000,
                optimized_speed_ms=optimized_speed * 1000,
                speedup=original_speed / optimized_speed,
                accuracy_loss=0.01,  # rough estimate, not measured
                optimization_time=0,
                output_path=str(output_path)
            )

        except Exception as e:
            logger.error(f"ONNX quantization failed: {e}")
            return None

    def _prune_model(self,
                     loaded: LoadedModel,
                     output_dir: Path,
                     sparsity: float = 0.5,
                     **kwargs) -> Optional[OptimizationResult]:
        """
        Prune model to reduce parameters.

        Args:
            loaded: Loaded model
            output_dir: Output directory
            sparsity: Target sparsity (0-1)

        Returns:
            Optimization result
        """
        if loaded.framework != ModelFramework.PYTORCH:
            logger.error("Pruning only supported for PyTorch models")
            return None

        try:
            import torch.nn.utils.prune as prune

            model = loaded.model

            # Baseline metrics before pruning.
            original_size = self._get_model_size(model)
            original_speed = self._benchmark_model(model)

            # L1-unstructured pruning on every conv and linear layer.
            for name, module in model.named_modules():
                if isinstance(module, (nn.Conv2d, nn.Linear)):
                    prune.l1_unstructured(module, name='weight', amount=sparsity)
                    prune.remove(module, 'weight')
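
            # Note: prune.remove() makes the zeroed weights permanent, but the
            # tensors stay dense, so on-disk size and latency are essentially
            # unchanged; realizing gains would require sparse serialization or
            # structured pruning, which is outside the scope of this method.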

            # Save the pruned weights.
            output_path = output_dir / f"{loaded.model_id}_pruned.pth"
            torch.save(model.state_dict(), output_path)

            # Metrics after pruning.
            optimized_size = self._get_model_size(model)
            optimized_speed = self._benchmark_model(model)

            return OptimizationResult(
                original_size_mb=original_size / (1024 * 1024),
                optimized_size_mb=optimized_size / (1024 * 1024),
                compression_ratio=original_size / optimized_size,
                original_speed_ms=original_speed * 1000,
                optimized_speed_ms=optimized_speed * 1000,
                speedup=original_speed / optimized_speed,
                accuracy_loss=0.02,  # rough estimate, not measured
                optimization_time=0,
                output_path=str(output_path)
            )

        except Exception as e:
            logger.error(f"Model pruning failed: {e}")
            return None

    def _convert_to_onnx(self,
                         loaded: LoadedModel,
                         output_dir: Path,
                         opset_version: int = 11,
                         **kwargs) -> Optional[OptimizationResult]:
        """Convert model to ONNX format."""
        if loaded.framework != ModelFramework.PYTORCH:
            logger.error("ONNX conversion only supported for PyTorch models")
            return None

        try:
            model = loaded.model
            model.eval()

            # Build a dummy input matching the model's expected shape.
            input_size = loaded.metadata.get('input_size', (1, 3, 512, 512))
            dummy_input = torch.randn(*input_size).to(self.device)

            output_path = output_dir / f"{loaded.model_id}.onnx"

            # Export with a dynamic batch dimension.
            torch.onnx.export(
                model,
                dummy_input,
                str(output_path),
                export_params=True,
                opset_version=opset_version,
                do_constant_folding=True,
                input_names=['input'],
                output_names=['output'],
                dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
            )
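
            # Optional sanity check on the exported graph (illustrative; the
            # checker raises if the model is structurally invalid):
            #     onnx.checker.check_model(onnx.load(str(output_path)))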

            # Graph-level optimization. The optimizer was split out of the onnx
            # package (onnx >= 1.9) into the separate "onnxoptimizer" project,
            # so treat it as an optional step.
            try:
                import onnxoptimizer

                model_onnx = onnx.load(str(output_path))
                optimized_model = onnxoptimizer.optimize(model_onnx)
                onnx.save(optimized_model, str(output_path))
            except ImportError:
                logger.warning("onnxoptimizer not installed; skipping ONNX graph optimization")

            # Compare the in-memory PyTorch model against the exported file.
            original_size = self._get_model_size(model)
            optimized_size = output_path.stat().st_size

            original_speed = self._benchmark_model(model, input_size)
            optimized_speed = self._benchmark_onnx(str(output_path))

            return OptimizationResult(
                original_size_mb=original_size / (1024 * 1024),
                optimized_size_mb=optimized_size / (1024 * 1024),
                compression_ratio=original_size / optimized_size,
                original_speed_ms=original_speed * 1000,
                optimized_speed_ms=optimized_speed * 1000,
                speedup=original_speed / optimized_speed,
                accuracy_loss=0.0,
                optimization_time=0,
                output_path=str(output_path)
            )

        except Exception as e:
            logger.error(f"ONNX conversion failed: {e}")
            return None

    def _convert_to_tensorrt(self,
                             loaded: LoadedModel,
                             output_dir: Path,
                             **kwargs) -> Optional[OptimizationResult]:
        """Convert model to TensorRT."""
        try:
            import tensorrt as trt

            # TensorRT builds its engine from ONNX, so export to ONNX first.
            onnx_result = self._convert_to_onnx(loaded, output_dir)
            if not onnx_result:
                return None

            onnx_path = onnx_result.output_path
            output_path = output_dir / f"{loaded.model_id}.trt"

            # Builder setup. Note: this uses the pre-TensorRT-10 API
            # (max_workspace_size and build_engine are deprecated/removed in
            # newer releases).
            TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
            builder = trt.Builder(TRT_LOGGER)
            network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
            parser = trt.OnnxParser(network, TRT_LOGGER)

            # Parse the exported ONNX graph.
            with open(onnx_path, 'rb') as f:
                if not parser.parse(f.read()):
                    logger.error("Failed to parse ONNX model")
                    return None

            config = builder.create_builder_config()
            config.max_workspace_size = 1 << 30  # 1 GiB workspace

            if kwargs.get('fp16', False):
                config.set_flag(trt.BuilderFlag.FP16)

            engine = builder.build_engine(network, config)
            if engine is None:
                logger.error("TensorRT engine build failed")
                return None

            # Serialize the engine to disk.
            with open(output_path, 'wb') as f:
                f.write(engine.serialize())
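
            # Loading the serialized engine later (illustrative sketch):
            #     runtime = trt.Runtime(TRT_LOGGER)
            #     with open(output_path, 'rb') as f:
            #         engine = runtime.deserialize_cuda_engine(f.read())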

            # Size is compared against the intermediate ONNX file; the speed
            # figures below are estimates, since the TensorRT engine itself is
            # not benchmarked here.
            original_size = Path(onnx_path).stat().st_size
            optimized_size = output_path.stat().st_size

            return OptimizationResult(
                original_size_mb=original_size / (1024 * 1024),
                optimized_size_mb=optimized_size / (1024 * 1024),
                compression_ratio=original_size / optimized_size,
                original_speed_ms=onnx_result.original_speed_ms,
                optimized_speed_ms=onnx_result.optimized_speed_ms / 2,  # estimated
                speedup=2.0,  # estimated
                accuracy_loss=0.001,  # rough estimate, not measured
                optimization_time=0,
                output_path=str(output_path)
            )

        except Exception as e:
            logger.error(f"TensorRT conversion failed: {e}")
            return None

    def _convert_to_coreml(self,
                           loaded: LoadedModel,
                           output_dir: Path,
                           **kwargs) -> Optional[OptimizationResult]:
        """Convert model to CoreML."""
        try:
            import coremltools as ct

            model = loaded.model
            model.eval()

            # Trace the model with an example input on the same device as the model.
            input_size = loaded.metadata.get('input_size', (1, 3, 512, 512))
            example_input = torch.randn(*input_size).to(self.device)

            traced_model = torch.jit.trace(model, example_input)

            coreml_model = ct.convert(
                traced_model,
                inputs=[ct.TensorType(shape=input_size)]
            )

            # Note: recent coremltools releases default to the ML Program format,
            # which is saved as a .mlpackage; the .mlmodel extension here assumes
            # the older neural-network backend.
            output_path = output_dir / f"{loaded.model_id}.mlmodel"
            coreml_model.save(str(output_path))
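
            # Using the converted model later (illustrative sketch; prediction
            # requires macOS, and the input key depends on the traced graph):
            #     mlmodel = ct.models.MLModel(str(output_path))
            #     prediction = mlmodel.predict({"input": example_array})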

            original_size = self._get_model_size(model)
            optimized_size = output_path.stat().st_size

            return OptimizationResult(
                original_size_mb=original_size / (1024 * 1024),
                optimized_size_mb=optimized_size / (1024 * 1024),
                compression_ratio=original_size / optimized_size,
                original_speed_ms=100,  # placeholder; CoreML models are not benchmarked here
                optimized_speed_ms=50,  # placeholder
                speedup=2.0,  # placeholder
                accuracy_loss=0.0,
                optimization_time=0,
                output_path=str(output_path)
            )

        except Exception as e:
            logger.error(f"CoreML conversion failed: {e}")
            return None

    def _get_model_size(self, model: nn.Module) -> int:
        """Get model size in bytes."""
        param_size = 0
        buffer_size = 0

        for param in model.parameters():
            param_size += param.nelement() * param.element_size()

        for buffer in model.buffers():
            buffer_size += buffer.nelement() * buffer.element_size()

        return param_size + buffer_size

    def _benchmark_model(self, model: nn.Module, input_size: Tuple = (1, 3, 512, 512)) -> float:
        """Benchmark model speed; returns the median latency in seconds."""
        model.eval()
        dummy_input = torch.randn(*input_size).to(self.device)
        use_cuda = torch.cuda.is_available() and 'cuda' in str(self.device)

        # Warm-up runs so lazy initialization does not skew the timings.
        for _ in range(10):
            with torch.no_grad():
                _ = model(dummy_input)

        # Timed runs; synchronize around CUDA work so asynchronous kernel
        # launches are not mistaken for completed inference.
        times = []
        for _ in range(100):
            if use_cuda:
                torch.cuda.synchronize()
            start = time.time()
            with torch.no_grad():
                _ = model(dummy_input)
            if use_cuda:
                torch.cuda.synchronize()
            times.append(time.time() - start)

        return np.median(times)

    def _benchmark_onnx(self, model_path: str) -> float:
        """Benchmark ONNX model speed; returns the median latency in seconds."""
        session = ort.InferenceSession(model_path, providers=ort.get_available_providers())
        input_name = session.get_inputs()[0].name
        input_shape = session.get_inputs()[0].shape

        # Replace symbolic dimensions (e.g. the dynamic 'batch_size' axis) with 1.
        input_shape = [dim if isinstance(dim, int) else 1 for dim in input_shape]

        dummy_input = np.random.randn(*input_shape).astype(np.float32)

        # Warm-up runs.
        for _ in range(10):
            _ = session.run(None, {input_name: dummy_input})

        # Timed runs.
        times = []
        for _ in range(100):
            start = time.time()
            _ = session.run(None, {input_name: dummy_input})
            times.append(time.time() - start)

        return np.median(times)