|
|
""" |
|
|
Configuration Management Module |
|
|
============================== |
|
|
|
|
|
Centralized configuration management for BackgroundFX Pro. |
|
|
Handles settings, model paths, quality parameters, and environment variables. |
|
|
|
|
|
Features: |
|
|
- YAML and JSON configuration files |
|
|
- Environment variable integration |
|
|
- Model path management (works with checkpoints/ folder) |
|
|
- Quality thresholds and processing parameters |
|
|
- Development vs Production configurations |
|
|
- Runtime configuration updates |
|
|
|
|
|
Author: BackgroundFX Pro Team |
|
|
License: MIT |
|
|
""" |
|
|
|
|
|
import os |
|
|
import yaml |
|
|
import json |
|
|
from typing import Dict, Any, Optional, Union |
|
|
from pathlib import Path |
|
|
from dataclasses import dataclass, field |
|
|
import logging |
|
|
from copy import deepcopy |
|
|
|
|
|
# Module-level logger named after this module; handlers and levels are left
# to whatever logging configuration the application installs.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
@dataclass
class ModelConfig:
    """Settings for a single AI model.

    Only ``name`` is required; every other field has a default.
    """

    # Identifier of the model.
    name: str
    # Location of the model file, or None when no path is set.
    path: Optional[str] = None
    # Device selector; defaults to "auto".
    device: str = "auto"
    # Whether the model is enabled.
    enabled: bool = True
    # Whether the model is flagged as a fallback.
    fallback: bool = False
    # Free-form, model-specific options. A fresh dict is created per instance.
    parameters: Dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
@dataclass
class QualityConfig:
    """Thresholds used for quality assessment.

    All fields are floats with defaults, so ``QualityConfig()`` yields the
    stock threshold set.
    """

    # Detection / edge / coverage thresholds.
    min_detection_confidence: float = 0.5
    min_edge_quality: float = 0.3
    min_mask_coverage: float = 0.05
    max_asymmetry_score: float = 0.8
    # Frame-to-frame consistency threshold.
    temporal_consistency_threshold: float = 0.05
    # Quality threshold associated with the MatAnyone model.
    matanyone_quality_threshold: float = 0.3
|
|
|
|
|
@dataclass
class ProcessingConfig:
    """Options for the processing pipeline.

    Every field has a default, so ``ProcessingConfig()`` is a usable
    configuration.
    """

    # Batch size; defaults to 1.
    batch_size: int = 1
    # Resolution cap as a two-element tuple; defaults to (1920, 1080).
    max_resolution: tuple = (1920, 1080)
    # Feature toggles, all on by default.
    temporal_smoothing: bool = True
    edge_refinement: bool = True
    fallback_enabled: bool = True
    cache_enabled: bool = True
|
|
|
|
|
@dataclass
class VideoConfig:
    """Output settings for video processing.

    Every field has a default, so ``VideoConfig()`` is a usable configuration.
    """

    # Output format name; defaults to "mp4".
    output_format: str = "mp4"
    # Output quality label; defaults to "high".
    output_quality: str = "high"
    # Whether audio is preserved.
    preserve_audio: bool = True
    # Optional FPS limit; None means no limit is configured.
    fps_limit: Optional[int] = None
    # Codec name; defaults to "h264".
    codec: str = "h264"
|
|
|
|
|
class ConfigManager:
    """Main configuration manager.

    Holds the per-model settings (``models``) together with the ``quality``,
    ``processing`` and ``video`` option groups, plus the ``environment`` name
    and the ``debug_mode`` flag. Values start from built-in defaults and can
    then be overridden from a YAML/JSON file, from environment variables, or
    through the update methods at runtime.
    """

    def __init__(self, config_dir: str = ".", checkpoints_dir: str = "checkpoints"):
        """Create a manager populated with the default configuration.

        Args:
            config_dir: Directory associated with configuration files; stored
                as ``self.config_dir``.
            checkpoints_dir: Directory searched for model checkpoint files.
        """
        self.config_dir = Path(config_dir)
        self.checkpoints_dir = Path(checkpoints_dir)

        self.models: Dict[str, ModelConfig] = {}
        self.quality = QualityConfig()
        self.processing = ProcessingConfig()
        self.video = VideoConfig()

        self.debug_mode = False
        self.environment = "development"

        self._initialize_default_configs()

    def _initialize_default_configs(self):
        """Register the built-in model configurations.

        Three models are registered: ``sam2`` (path resolved from the
        checkpoints directory), ``matanyone`` (no local path) and
        ``traditional_cv`` (CPU, flagged as fallback).
        """
        self.models['sam2'] = ModelConfig(
            name='sam2',
            path=self._find_model_path('sam2', ['sam2_hiera_large.pt', 'sam2_hiera_base.pt']),
            device='auto',
            enabled=True,
            fallback=False,
            parameters={
                'model_type': 'vit_l',
                'checkpoint': None,
                'multimask_output': False,
                'use_checkpoint': True,
            },
        )

        self.models['matanyone'] = ModelConfig(
            name='matanyone',
            path=None,
            device='auto',
            enabled=True,
            fallback=False,
            parameters={
                'use_hf_api': True,
                'hf_model': 'PeiqingYang/MatAnyone',
                'api_timeout': 60,
                'quality_threshold': 0.3,
                'fallback_enabled': True,
            },
        )

        self.models['traditional_cv'] = ModelConfig(
            name='traditional_cv',
            path=None,
            device='cpu',
            enabled=True,
            fallback=True,
            parameters={
                'methods': ['canny', 'color_detection', 'texture_analysis'],
                'edge_threshold': [50, 150],
                'color_ranges': {
                    'dark_hair': [[0, 0, 0], [180, 255, 80]],
                    'brown_hair': [[8, 50, 20], [25, 255, 200]],
                },
            },
        )

    def _find_model_path(self, model_name: str, possible_files: list) -> Optional[str]:
        """Find a model file in the checkpoints directory.

        For each candidate filename, in order, ``<checkpoints>/<filename>`` is
        checked first and then ``<checkpoints>/<model_name>/<filename>``.

        Args:
            model_name: Model identifier, also used as the sub-directory name.
            possible_files: Candidate filenames, most preferred first.

        Returns:
            The first existing path as a string, or ``None`` when nothing is
            found or the lookup itself fails.
        """
        try:
            for filename in possible_files:
                candidates = (
                    self.checkpoints_dir / filename,
                    self.checkpoints_dir / model_name / filename,
                )
                for candidate in candidates:
                    if candidate.exists():
                        logger.info(f"✅ Found {model_name} at: {candidate}")
                        return str(candidate)

            logger.warning(f"⚠️ {model_name} model not found in {self.checkpoints_dir}")
            return None

        except Exception as e:
            # Best effort: a failed lookup must not prevent construction.
            logger.error(f"❌ Error finding {model_name}: {e}")
            return None

    def load_from_file(self, config_path: str) -> bool:
        """Load configuration from a YAML or JSON file.

        The format is chosen from the file suffix (``.yaml``/``.yml`` or
        ``.json``). An empty document is treated as "no overrides".

        Args:
            config_path: Path of the file to read.

        Returns:
            ``True`` when the file was read and applied; ``False`` when it is
            missing, has an unsupported suffix, or loading/applying failed
            (the reason is logged).
        """
        try:
            config_path = Path(config_path)

            if not config_path.exists():
                logger.warning(f"⚠️ Config file not found: {config_path}")
                return False

            suffix = config_path.suffix.lower()
            if suffix in ('.yaml', '.yml'):
                with open(config_path, 'r', encoding='utf-8') as f:
                    config_data = yaml.safe_load(f)
            elif suffix == '.json':
                with open(config_path, 'r', encoding='utf-8') as f:
                    config_data = json.load(f)
            else:
                logger.error(f"❌ Unsupported config format: {config_path.suffix}")
                return False

            # An empty YAML document (or JSON ``null``) parses to None.
            if config_data is None:
                config_data = {}
            if not isinstance(config_data, dict):
                raise ValueError(
                    f"top-level config must be a mapping, got {type(config_data).__name__}"
                )

            self._apply_config_data(config_data)
            logger.info(f"✅ Configuration loaded from: {config_path}")
            return True

        except Exception as e:
            logger.error(f"❌ Failed to load config from {config_path}: {e}")
            return False

    @staticmethod
    def _update_known_fields(target: Any, values: Optional[Dict[str, Any]], label: str) -> None:
        """Copy entries of *values* onto the declared dataclass fields of *target*.

        Keys that are not fields of *target* are skipped with a warning so
        that a typo in a config file is visible and cannot overwrite methods
        or dunder attributes.
        """
        # Local import: the module only imports ``dataclass`` and ``field``.
        from dataclasses import fields

        known = {f.name for f in fields(target)}
        for key, value in (values or {}).items():
            if key in known:
                setattr(target, key, value)
            else:
                logger.warning(f"⚠️ Ignoring unknown {label} setting: {key}")

    def _apply_config_data(self, config_data: Dict[str, Any]):
        """Apply parsed configuration data to the current settings.

        Recognised top-level keys are ``models``, ``quality``, ``processing``,
        ``video``, ``environment`` and ``debug_mode``. Only models that are
        already registered are updated; a model's ``parameters`` mapping is
        merged into the existing parameters rather than replacing them.

        Raises:
            Exception: Any error raised while applying is logged and
                re-raised. Settings applied before the error stay applied.
        """
        try:
            for model_name, model_config in (config_data.get('models') or {}).items():
                model = self.models.get(model_name)
                if model is None:
                    logger.warning(f"⚠️ Ignoring unknown model in config: {model_name}")
                    continue

                settings = dict(model_config or {})
                # ``parameters`` is itself a field, so it must be handled
                # before the generic field copy or it would be overwritten
                # wholesale instead of merged.
                extra_parameters = settings.pop('parameters', None)
                if extra_parameters:
                    model.parameters.update(extra_parameters)
                self._update_known_fields(model, settings, f"model '{model_name}'")

            sections = (
                ('quality', self.quality),
                ('processing', self.processing),
                ('video', self.video),
            )
            for section_name, section in sections:
                self._update_known_fields(section, config_data.get(section_name), section_name)

            # YAML and JSON have no tuple type, so a loaded resolution arrives
            # as a list; restore the tuple form used by the defaults.
            if isinstance(self.processing.max_resolution, list):
                self.processing.max_resolution = tuple(self.processing.max_resolution)

            if 'environment' in config_data:
                self.environment = config_data['environment']

            if 'debug_mode' in config_data:
                self.debug_mode = config_data['debug_mode']

        except Exception as e:
            logger.error(f"❌ Error applying config data: {e}")
            raise

    def load_from_environment(self):
        """Load configuration overrides from environment variables.

        Variables read:
            SAM2_MODEL_PATH: Path for the ``sam2`` model (used only if it exists).
            HUGGINGFACE_TOKEN: Stored as ``hf_token`` in the ``matanyone`` parameters.
            TORCH_DEVICE / DEVICE: Device for every model still set to ``'auto'``.
            BATCH_SIZE: Integer batch size.
            MIN_DETECTION_CONFIDENCE: Float detection threshold.
            ENVIRONMENT / ENV: Environment name.
            DEBUG / DEBUG_MODE: Debug flag; ``true``, ``1`` or ``yes`` enable it.

        For the paired variables the first non-empty one wins. A malformed
        numeric value is skipped with a warning and does not stop the
        remaining variables from being applied. Errors are logged, never raised.
        """
        try:
            sam2_path = os.getenv('SAM2_MODEL_PATH')
            if sam2_path and Path(sam2_path).exists():
                self.models['sam2'].path = sam2_path

            # NOTE: the token lives in ``parameters`` and is therefore part of
            # what ``to_dict()`` / ``save_to_file()`` emit.
            hf_token = os.getenv('HUGGINGFACE_TOKEN')
            if hf_token:
                self.models['matanyone'].parameters['hf_token'] = hf_token

            device = os.getenv('TORCH_DEVICE') or os.getenv('DEVICE')
            if device:
                for model in self.models.values():
                    if model.device == 'auto':
                        model.device = device

            batch_size = os.getenv('BATCH_SIZE')
            if batch_size:
                try:
                    self.processing.batch_size = int(batch_size)
                except ValueError:
                    logger.warning(f"⚠️ Ignoring invalid BATCH_SIZE: {batch_size!r}")

            min_confidence = os.getenv('MIN_DETECTION_CONFIDENCE')
            if min_confidence:
                try:
                    self.quality.min_detection_confidence = float(min_confidence)
                except ValueError:
                    logger.warning(
                        f"⚠️ Ignoring invalid MIN_DETECTION_CONFIDENCE: {min_confidence!r}"
                    )

            env_mode = os.getenv('ENVIRONMENT') or os.getenv('ENV')
            if env_mode:
                self.environment = env_mode

            debug = os.getenv('DEBUG') or os.getenv('DEBUG_MODE')
            if debug:
                self.debug_mode = debug.strip().lower() in ('true', '1', 'yes')

            logger.info("✅ Environment variables loaded")

        except Exception as e:
            logger.error(f"❌ Error loading environment variables: {e}")

    def save_to_file(self, config_path: str, format: str = 'yaml') -> bool:
        """Save the current configuration to a file.

        Args:
            config_path: Destination path; missing parent directories are created.
            format: ``'yaml'``/``'yml'`` or ``'json'`` (case-insensitive).

        Returns:
            ``True`` on success, ``False`` for an unsupported format or when
            serialising/writing fails (the reason is logged).
        """
        try:
            config_path = Path(config_path)
            config_data = self.to_dict()

            # Serialise to a string first so that a serialisation error
            # cannot leave a truncated file behind.
            fmt = format.lower()
            if fmt in ('yaml', 'yml'):
                # safe_dump emits only standard YAML tags, which is what
                # ``load_from_file`` (yaml.safe_load) is able to read back.
                serialized = yaml.safe_dump(config_data, default_flow_style=False, indent=2)
            elif fmt == 'json':
                serialized = json.dumps(config_data, indent=2)
            else:
                logger.error(f"❌ Unsupported save format: {format}")
                return False

            config_path.parent.mkdir(parents=True, exist_ok=True)
            with open(config_path, 'w', encoding='utf-8') as f:
                f.write(serialized)

            logger.info(f"✅ Configuration saved to: {config_path}")
            return True

        except Exception as e:
            logger.error(f"❌ Failed to save config to {config_path}: {e}")
            return False

    def to_dict(self) -> Dict[str, Any]:
        """Convert the configuration to a plain dictionary.

        The result contains the keys ``models``, ``quality``, ``processing``,
        ``video``, ``environment`` and ``debug_mode``. Model parameters are
        deep-copied so that editing the result does not modify the live
        configuration, and ``max_resolution`` is emitted as a list so the
        structure serialises identically to YAML and JSON.
        """
        max_resolution = self.processing.max_resolution
        if isinstance(max_resolution, (tuple, list)):
            max_resolution = list(max_resolution)

        return {
            'models': {
                name: {
                    'name': config.name,
                    'path': config.path,
                    'device': config.device,
                    'enabled': config.enabled,
                    'fallback': config.fallback,
                    'parameters': deepcopy(config.parameters),
                }
                for name, config in self.models.items()
            },
            'quality': {
                'min_detection_confidence': self.quality.min_detection_confidence,
                'min_edge_quality': self.quality.min_edge_quality,
                'min_mask_coverage': self.quality.min_mask_coverage,
                'max_asymmetry_score': self.quality.max_asymmetry_score,
                'temporal_consistency_threshold': self.quality.temporal_consistency_threshold,
                'matanyone_quality_threshold': self.quality.matanyone_quality_threshold,
            },
            'processing': {
                'batch_size': self.processing.batch_size,
                'max_resolution': max_resolution,
                'temporal_smoothing': self.processing.temporal_smoothing,
                'edge_refinement': self.processing.edge_refinement,
                'fallback_enabled': self.processing.fallback_enabled,
                'cache_enabled': self.processing.cache_enabled,
            },
            'video': {
                'output_format': self.video.output_format,
                'output_quality': self.video.output_quality,
                'preserve_audio': self.video.preserve_audio,
                'fps_limit': self.video.fps_limit,
                'codec': self.video.codec,
            },
            'environment': self.environment,
            'debug_mode': self.debug_mode,
        }

    def get_model_config(self, model_name: str) -> Optional[ModelConfig]:
        """Return the configuration for *model_name*, or ``None`` if unknown."""
        return self.models.get(model_name)

    def is_model_enabled(self, model_name: str) -> bool:
        """Return whether *model_name* is enabled; unknown models count as disabled."""
        model = self.models.get(model_name)
        return model.enabled if model else False

    def get_enabled_models(self) -> Dict[str, ModelConfig]:
        """Return a name -> config mapping of all enabled models."""
        return {name: config for name, config in self.models.items() if config.enabled}

    def get_fallback_models(self) -> Dict[str, ModelConfig]:
        """Return a name -> config mapping of models that are enabled and flagged as fallback."""
        return {
            name: config
            for name, config in self.models.items()
            if config.enabled and config.fallback
        }

    def update_model_path(self, model_name: str, path: str) -> bool:
        """Point a registered model at a new file.

        Args:
            model_name: Name of a model present in ``self.models``.
            path: New location; it must exist.

        Returns:
            ``True`` when the path was updated, ``False`` when the model is
            unknown or the path does not exist (both cases are logged).
        """
        if model_name not in self.models:
            logger.error(f"❌ Unknown model: {model_name}")
            return False

        if not Path(path).exists():
            logger.error(f"❌ Model path does not exist: {path}")
            return False

        self.models[model_name].path = path
        logger.info(f"✅ Updated {model_name} path: {path}")
        return True

    def validate_configuration(self) -> Dict[str, Any]:
        """Validate the current configuration and return a status report.

        Returns:
            A dict with ``valid`` (bool), ``errors`` and ``warnings`` (lists
            of messages) and ``model_status`` (per-model dict with
            ``enabled``, ``path_exists`` and ``issues``). ``path_exists`` is
            only set to ``False`` for an enabled model whose configured path
            is missing; a model without a path keeps ``True``.
        """
        validation_results = {
            'valid': True,
            'errors': [],
            'warnings': [],
            'model_status': {},
        }

        try:
            for name, config in self.models.items():
                model_status = {'enabled': config.enabled, 'path_exists': True, 'issues': []}

                if config.enabled and config.path and not Path(config.path).exists():
                    model_status['path_exists'] = False
                    model_status['issues'].append(f"Model file not found: {config.path}")
                    validation_results['errors'].append(f"{name}: Model file not found")
                    validation_results['valid'] = False

                validation_results['model_status'][name] = model_status

            if not 0 <= self.quality.min_detection_confidence <= 1:
                validation_results['errors'].append("min_detection_confidence must be between 0 and 1")
                validation_results['valid'] = False

            if self.processing.batch_size < 1:
                validation_results['errors'].append("batch_size must be >= 1")
                validation_results['valid'] = False

            if not self.get_enabled_models():
                validation_results['warnings'].append("No models are enabled")

            if not self.get_fallback_models():
                validation_results['warnings'].append("No fallback models configured")

            logger.info(f"✅ Configuration validation completed: {'Valid' if validation_results['valid'] else 'Invalid'}")

        except Exception as e:
            # E.g. a non-numeric value loaded from a config file makes the
            # comparisons above raise; report it instead of propagating.
            validation_results['valid'] = False
            validation_results['errors'].append(f"Validation error: {str(e)}")
            logger.error(f"❌ Configuration validation failed: {e}")

        return validation_results

    def create_runtime_config(self) -> Dict[str, Any]:
        """Create the runtime configuration for the processing pipeline.

        Returns:
            A dict with ``models`` (the enabled ``ModelConfig`` objects
            themselves, not copies), ``quality_thresholds``,
            ``processing_options``, ``video_settings`` and ``debug_mode``.
        """
        return {
            'models': self.get_enabled_models(),
            'quality_thresholds': {
                'min_confidence': self.quality.min_detection_confidence,
                'min_edge_quality': self.quality.min_edge_quality,
                'temporal_threshold': self.quality.temporal_consistency_threshold,
                'matanyone_threshold': self.quality.matanyone_quality_threshold,
            },
            'processing_options': {
                'batch_size': self.processing.batch_size,
                'temporal_smoothing': self.processing.temporal_smoothing,
                'edge_refinement': self.processing.edge_refinement,
                'fallback_enabled': self.processing.fallback_enabled,
                'cache_enabled': self.processing.cache_enabled,
            },
            'video_settings': {
                'format': self.video.output_format,
                'quality': self.video.output_quality,
                'preserve_audio': self.video.preserve_audio,
                'codec': self.video.codec,
            },
            'debug_mode': self.debug_mode,
        }
|
|
|
|
|
|
|
|
# Process-wide ConfigManager instance, created lazily by get_config().
_config_manager: Optional[ConfigManager] = None


def get_config(config_dir: str = ".", checkpoints_dir: str = "checkpoints") -> ConfigManager:
    """Get the global configuration manager, creating it on first use.

    On the first call a ``ConfigManager`` is built, environment variables are
    applied, and then the first of ``config.yaml``, ``config.yml`` and
    ``config.json`` found in *config_dir* is loaded on top. Later calls return
    the same instance and ignore their arguments.

    Args:
        config_dir: Directory searched for the default config files.
        checkpoints_dir: Directory searched for model checkpoints.

    Returns:
        The shared ``ConfigManager`` instance.
    """
    global _config_manager

    if _config_manager is None:
        manager = ConfigManager(config_dir, checkpoints_dir)
        manager.load_from_environment()

        for config_file in ('config.yaml', 'config.yml', 'config.json'):
            # Resolve against the requested config directory; with the
            # default "." this is the current working directory.
            candidate = manager.config_dir / config_file
            if candidate.exists():
                manager.load_from_file(str(candidate))
                break

        _config_manager = manager

    return _config_manager
|
|
|
|
|
def load_config(config_path: str) -> ConfigManager:
    """Apply the file at *config_path* to the global configuration manager.

    The manager is returned whether or not loading succeeded; the result of
    ``load_from_file`` is not inspected here.
    """
    manager = get_config()
    manager.load_from_file(config_path)
    return manager
|
|
|
|
|
def get_model_config(model_name: str) -> Optional[ModelConfig]:
    """Look up *model_name* on the global configuration manager.

    Returns whatever the manager's ``get_model_config`` returns for that name.
    """
    manager = get_config()
    return manager.get_model_config(model_name)
|
|
|
|
|
def is_model_enabled(model_name: str) -> bool:
    """Report whether *model_name* is enabled according to the global configuration manager."""
    manager = get_config()
    return manager.is_model_enabled(model_name)
|
|
|
|
|
def get_quality_thresholds() -> QualityConfig:
    """Return the ``quality`` settings object of the global configuration manager.

    The manager's own object is returned, not a copy.
    """
    manager = get_config()
    return manager.quality
|
|
|
|
|
def get_processing_config() -> ProcessingConfig:
    """Return the ``processing`` settings object of the global configuration manager.

    The manager's own object is returned, not a copy.
    """
    manager = get_config()
    return manager.processing