MogensR committed on
Commit ca7243a · 1 Parent(s): 31e1565

Create api/pipeline.py

Files changed (1): api/pipeline.py (+763, -0)

api/pipeline.py (ADDED)
"""
Main processing pipeline for BackgroundFX Pro.
Orchestrates the complete background removal and replacement workflow.
"""

import cv2
import numpy as np
import torch
from typing import Dict, List, Optional, Tuple, Union, Callable, Any
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
import time
import json
import hashlib
from concurrent.futures import ThreadPoolExecutor

from ..utils.logger import setup_logger
from ..utils.device import DeviceManager
from ..utils import TimeEstimator, MemoryMonitor

from ..core.models import ModelFactory, ModelType
from ..core.temporal import TemporalCoherence
from ..core.quality import QualityAnalyzer
from ..core.edge import EdgeRefinement
from ..core.hair_segmentation import HairSegmentation

from ..processing.matting import AlphaMatting, MattingConfig, CompositingEngine
from ..processing.fallback import FallbackStrategy
from ..processing.effects import BackgroundEffects, CompositeEffects, EffectType

logger = setup_logger(__name__)

class ProcessingMode(Enum):
    """Processing mode types."""
    PHOTO = "photo"
    VIDEO = "video"
    REALTIME = "realtime"
    BATCH = "batch"


class PipelineStage(Enum):
    """Pipeline processing stages."""
    INITIALIZATION = "initialization"
    PREPROCESSING = "preprocessing"
    SEGMENTATION = "segmentation"
    MATTING = "matting"
    REFINEMENT = "refinement"
    EFFECTS = "effects"
    COMPOSITING = "compositing"
    POSTPROCESSING = "postprocessing"
    COMPLETE = "complete"

@dataclass
class PipelineConfig:
    """Configuration for the processing pipeline."""
    # Model settings
    model_type: ModelType = ModelType.RMBG_1_4
    use_gpu: bool = True
    device: Optional[str] = None

    # Processing settings
    mode: ProcessingMode = ProcessingMode.PHOTO
    enable_temporal: bool = True
    enable_hair_refinement: bool = True
    enable_edge_refinement: bool = True
    enable_fallback: bool = True

    # Quality settings
    quality_preset: str = "high"  # low, medium, high, ultra
    target_resolution: Optional[Tuple[int, int]] = None
    maintain_aspect_ratio: bool = True

    # Matting settings
    matting_method: str = "auto"  # auto, trimap, deep, guided
    matting_config: MattingConfig = field(default_factory=MattingConfig)

    # Effects settings
    background_blur: bool = False
    blur_strength: float = 15.0
    apply_effects: List[EffectType] = field(default_factory=list)

    # Performance settings
    batch_size: int = 1
    num_workers: int = 4
    enable_caching: bool = True
    cache_size_mb: int = 500

    # Output settings
    output_format: str = "png"  # png, jpg, webp
    output_quality: int = 95
    preserve_metadata: bool = True

    # Callbacks
    progress_callback: Optional[Callable[[float, str], None]] = None
    stage_callback: Optional[Callable[[PipelineStage, Dict], None]] = None

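# Illustrative sketch (not part of the original commit): constructing a config
# for a fast preview pass. The callback name is a hypothetical example;
# target_resolution is consumed as (height, width) by _preprocess_image below.
#
#   def on_progress(fraction: float, message: str) -> None:
#       print(f"[{fraction:6.1%}] {message}")
#
#   preview_config = PipelineConfig(
#       mode=ProcessingMode.PHOTO,
#       quality_preset="low",
#       target_resolution=(720, 1280),   # (height, width)
#       enable_hair_refinement=False,
#       progress_callback=on_progress,
#   )
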
@dataclass
class PipelineResult:
    """Result from pipeline processing."""
    success: bool
    output_image: Optional[np.ndarray] = None
    alpha_matte: Optional[np.ndarray] = None
    foreground: Optional[np.ndarray] = None
    background: Optional[np.ndarray] = None
    metadata: Dict[str, Any] = field(default_factory=dict)
    processing_time: float = 0.0
    stages_completed: List[PipelineStage] = field(default_factory=list)
    errors: List[str] = field(default_factory=list)
    quality_score: float = 0.0

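# Illustrative sketch (not part of the original commit): persisting a result.
# alpha_matte is a float matte in [0, 1], so scale it before writing as 8-bit:
#
#   if result.success:
#       cv2.imwrite("composite.png", result.output_image)
#       cv2.imwrite("alpha.png", (result.alpha_matte * 255).astype(np.uint8))
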
class ProcessingPipeline:
    """Main processing pipeline orchestrator."""

    def __init__(self, config: Optional[PipelineConfig] = None):
        """
        Initialize the processing pipeline.

        Args:
            config: Pipeline configuration
        """
        self.config = config or PipelineConfig()
        self.logger = setup_logger(f"{__name__}.ProcessingPipeline")

        # Initialize components
        self._initialize_components()

        # State management
        self.current_stage = PipelineStage.INITIALIZATION
        self.processing_stats = {}
        self.cache = {}
        self.is_processing = False

        # Thread pool for parallel processing
        self.executor = ThreadPoolExecutor(max_workers=self.config.num_workers)

        self.logger.info("Pipeline initialized successfully")

    def _initialize_components(self):
        """Initialize all pipeline components."""
        try:
            # Device management
            self.device_manager = DeviceManager()
            if self.config.device:
                self.device_manager.set_device(self.config.device)
            elif not self.config.use_gpu:
                self.device_manager.set_device('cpu')

            # Core components
            self.model_factory = ModelFactory()
            self.quality_analyzer = QualityAnalyzer()
            self.edge_refinement = EdgeRefinement()
            self.temporal_coherence = TemporalCoherence() if self.config.enable_temporal else None
            self.hair_segmentation = HairSegmentation() if self.config.enable_hair_refinement else None

            # Processing components
            self.alpha_matting = AlphaMatting(self.config.matting_config)
            self.compositing_engine = CompositingEngine()
            self.background_effects = BackgroundEffects()
            self.composite_effects = CompositeEffects()

            # Fallback strategy
            self.fallback_strategy = FallbackStrategy() if self.config.enable_fallback else None

            # Monitoring
            self.memory_monitor = MemoryMonitor()
            self.time_estimator = TimeEstimator()

            # Load model
            self._load_model()

        except Exception as e:
            self.logger.error(f"Component initialization failed: {e}")
            raise

    def _load_model(self):
        """Load the segmentation model."""
        try:
            self.logger.info(f"Loading model: {self.config.model_type.value}")

            self.model = self.model_factory.load_model(
                self.config.model_type,
                device=self.device_manager.get_device(),
                optimize=True
            )

            self.logger.info("Model loaded successfully")

        except Exception as e:
            self.logger.error(f"Model loading failed: {e}")
            if not self.config.enable_fallback:
                # Without a fallback there is no usable model; re-raise
                raise
            self.logger.info("Attempting fallback model loading")
            self.config.model_type = ModelType.U2NET_LITE
            self.model = self.model_factory.load_model(
                self.config.model_type,
                device='cpu'
            )

    def process_image(self,
                      image: Union[np.ndarray, str, Path],
                      background: Optional[Union[np.ndarray, str, Path]] = None,
                      **kwargs) -> PipelineResult:
        """
        Process a single image through the pipeline.

        Args:
            image: Input image (array or path)
            background: Optional background image/path
            **kwargs: Additional processing parameters

        Returns:
            PipelineResult with processed image and metadata
        """
        start_time = time.time()
        self.is_processing = True
        result = PipelineResult(success=False)
        image_array: Optional[np.ndarray] = None
        bg_array: Optional[np.ndarray] = None

        try:
            # Stage 1: Initialization
            self._update_stage(PipelineStage.INITIALIZATION)
            image_array = self._load_image(image)
            bg_array = self._load_image(background) if background is not None else None

            # Generate cache key
            cache_key = self._generate_cache_key(image_array, kwargs)

            # Check cache
            if self.config.enable_caching and cache_key in self.cache:
                self.logger.info("Using cached result")
                cached_result = self.cache[cache_key]
                cached_result.processing_time = time.time() - start_time
                return cached_result

            # Stage 2: Preprocessing
            self._update_stage(PipelineStage.PREPROCESSING)
            preprocessed = self._preprocess_image(image_array)
            result.metadata['original_size'] = image_array.shape[:2]
            result.metadata['preprocessed_size'] = preprocessed.shape[:2]

            # Quality analysis
            quality_metrics = self.quality_analyzer.analyze_frame(preprocessed)
            result.metadata['quality_metrics'] = quality_metrics

            # Stage 3: Segmentation
            self._update_stage(PipelineStage.SEGMENTATION)
            segmentation_mask = self._segment_image(preprocessed)

            # Hair refinement if enabled
            if self.config.enable_hair_refinement and self.hair_segmentation is not None:
                self.logger.info("Applying hair refinement")
                hair_mask = self.hair_segmentation.segment_hair(preprocessed)
                segmentation_mask = self._combine_masks(segmentation_mask, hair_mask)

            # Stage 4: Matting
            self._update_stage(PipelineStage.MATTING)
            matting_result = self.alpha_matting.process(
                preprocessed,
                segmentation_mask,
                method=self.config.matting_method
            )
            alpha_matte = matting_result['alpha']
            result.metadata['matting_confidence'] = matting_result['confidence']

            # Stage 5: Refinement
            self._update_stage(PipelineStage.REFINEMENT)
            if self.config.enable_edge_refinement:
                alpha_matte = self.edge_refinement.refine_edges(
                    preprocessed,
                    (alpha_matte * 255).astype(np.uint8)
                ) / 255.0

            # Resize alpha to original size if needed
            if preprocessed.shape[:2] != image_array.shape[:2]:
                alpha_matte = cv2.resize(
                    alpha_matte,
                    (image_array.shape[1], image_array.shape[0]),
                    interpolation=cv2.INTER_LINEAR
                )

            # Extract foreground
            foreground = self._extract_foreground(image_array, alpha_matte)

            # Stage 6: Background & Effects
            self._update_stage(PipelineStage.EFFECTS)

            if bg_array is not None:
                # Resize background to match image
                bg_array = self._resize_background(bg_array, image_array.shape[:2])

                # Apply background effects
                if self.config.background_blur:
                    bg_array = self.background_effects.apply_blur(
                        bg_array,
                        strength=self.config.blur_strength,
                        mask=1 - alpha_matte
                    )

                # Apply configured effects
                if self.config.apply_effects:
                    bg_array = self._apply_effects(bg_array, alpha_matte)
            else:
                # No replacement supplied: composite over a plain black background
                bg_array = np.zeros_like(image_array)

            # Stage 7: Compositing
            self._update_stage(PipelineStage.COMPOSITING)

            if self.config.apply_effects and EffectType.LIGHT_WRAP in self.config.apply_effects:
                foreground = self.background_effects.apply_light_wrap(
                    foreground, bg_array, alpha_matte
                )

            composited = self.compositing_engine.composite(
                foreground, bg_array, alpha_matte
            )

            # Apply post-composite effects
            if self.config.apply_effects:
                composited = self._apply_post_effects(composited, alpha_matte)

            # Stage 8: Postprocessing
            self._update_stage(PipelineStage.POSTPROCESSING)
            final_output = self._postprocess_image(composited, alpha_matte)

            # Calculate quality score
            result.quality_score = self._calculate_quality_score(
                final_output, alpha_matte, quality_metrics
            )

            # Build result
            result.success = True
            result.output_image = final_output
            result.alpha_matte = alpha_matte
            result.foreground = foreground
            result.background = bg_array
            result.stages_completed = list(PipelineStage)
            result.processing_time = time.time() - start_time

            # Cache result
            if self.config.enable_caching:
                self._cache_result(cache_key, result)

            # Complete
            self._update_stage(PipelineStage.COMPLETE)
            self.logger.info(f"Processing completed in {result.processing_time:.2f}s")

            # Update statistics
            self._update_statistics(result)

        except Exception as e:
            self.logger.error(f"Pipeline processing failed: {e}")
            result.errors.append(str(e))

            # Only attempt fallback if the input image was loaded successfully
            if self.config.enable_fallback and self.fallback_strategy and image_array is not None:
                self.logger.info("Attempting fallback processing")
                result = self._fallback_processing(image_array, bg_array)

        finally:
            self.is_processing = False

        return result

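    # Note on caching (illustrative, not part of the original commit): the
    # cache key hashes the raw input pixels plus the **kwargs dict, so a
    # repeated call with the same inputs returns the cached PipelineResult.
    # A config change alone does NOT invalidate the cache; call clear_cache()
    # after reconfiguring, e.g.:
    #
    #   result_a = pipeline.process_image(img)   # computed
    #   result_b = pipeline.process_image(img)   # served from the cache
    #   pipeline.clear_cache()                   # force recomputation
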
    def _preprocess_image(self, image: np.ndarray) -> np.ndarray:
        """Preprocess image for optimal processing."""
        processed = image.copy()

        # Resize if needed
        if self.config.target_resolution:
            target_h, target_w = self.config.target_resolution
            h, w = image.shape[:2]

            if self.config.maintain_aspect_ratio:
                scale = min(target_w / w, target_h / h)
                new_w = int(w * scale)
                new_h = int(h * scale)
            else:
                new_w, new_h = target_w, target_h

            if (new_w, new_h) != (w, h):
                processed = cv2.resize(processed, (new_w, new_h),
                                       interpolation=cv2.INTER_AREA)

        # Apply quality-based preprocessing
        if self.config.quality_preset == "low":
            # Reduce noise for faster processing (colored variant: the
            # single-channel version is intended for grayscale input)
            processed = cv2.fastNlMeansDenoisingColored(processed, None, 10, 10, 7, 21)
        elif self.config.quality_preset in ["high", "ultra"]:
            # Enhance details
            processed = cv2.detailEnhance(processed, sigma_s=10, sigma_r=0.15)

        return processed

    def _segment_image(self, image: np.ndarray) -> np.ndarray:
        """Perform image segmentation."""
        try:
            # Use the loaded model for segmentation
            with torch.no_grad():
                # Prepare input
                input_tensor = self._prepare_input_tensor(image)

                # Run inference
                output = self.model(input_tensor)

                # Process output
                if isinstance(output, tuple):
                    output = output[0]

                # Convert to numpy mask
                mask = output.squeeze().cpu().numpy()

                # Threshold and convert to uint8
                mask = (mask > 0.5).astype(np.uint8) * 255

                # Resize to original size if needed
                if mask.shape[:2] != image.shape[:2]:
                    mask = cv2.resize(mask, (image.shape[1], image.shape[0]))

                return mask

        except Exception as e:
            self.logger.error(f"Segmentation failed: {e}")
            if self.config.enable_fallback:
                # Use basic segmentation as fallback
                from ..processing.fallback import ProcessingFallback
                fallback = ProcessingFallback()
                return fallback.basic_segmentation(image)
            raise

    def _prepare_input_tensor(self, image: np.ndarray) -> torch.Tensor:
        """Prepare image tensor for model input."""
        # Resize to model input size (typically 512x512 or 1024x1024)
        model_size = 512  # Default; should come from the model config
        resized = cv2.resize(image, (model_size, model_size))

        # Convert HWC uint8 -> NCHW float in [0, 1]
        tensor = torch.from_numpy(resized.transpose(2, 0, 1)).float()
        tensor = tensor.unsqueeze(0) / 255.0

        # Move to device
        tensor = tensor.to(self.device_manager.get_device())

        return tensor

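    # Illustrative note (not part of the original commit): many segmentation
    # backbones expect channel-wise normalization on top of the [0, 1] scaling
    # above. If the chosen checkpoint was trained that way, a sketch would be:
    #
    #   mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    #   std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    #   tensor = (tensor - mean.to(tensor.device)) / std.to(tensor.device)
    #
    # (These are the common ImageNet statistics; whether they apply here
    # depends on the specific model.)
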
    def _combine_masks(self, mask1: np.ndarray, mask2: np.ndarray) -> np.ndarray:
        """Combine two masks by taking their union."""
        # Convert to float for blending
        m1 = mask1.astype(np.float32) / 255.0
        m2 = mask2.astype(np.float32) / 255.0

        # Combine using the per-pixel maximum (union)
        combined = np.maximum(m1, m2)

        # Convert back to uint8
        return (combined * 255).astype(np.uint8)

    def _extract_foreground(self, image: np.ndarray,
                            alpha: np.ndarray) -> np.ndarray:
        """Extract foreground using the alpha matte."""
        if len(alpha.shape) == 2:
            alpha = np.expand_dims(alpha, axis=2)

        if alpha.shape[2] == 1:
            alpha = np.repeat(alpha, 3, axis=2)

        # Premultiply alpha
        foreground = image.astype(np.float32) * alpha

        return foreground.astype(np.uint8)

    def _resize_background(self, background: np.ndarray,
                           target_shape: Tuple[int, int]) -> np.ndarray:
        """Resize background to cover the target shape, then center-crop."""
        h, w = target_shape
        bg_h, bg_w = background.shape[:2]

        if (bg_h, bg_w) == (h, w):
            return background

        # Scale so the background covers the entire image; round up so the
        # resized image is never smaller than the crop window
        scale = max(h / bg_h, w / bg_w)
        new_h = max(h, int(np.ceil(bg_h * scale)))
        new_w = max(w, int(np.ceil(bg_w * scale)))

        # Resize
        resized = cv2.resize(background, (new_w, new_h),
                             interpolation=cv2.INTER_LINEAR)

        # Center crop
        start_y = (new_h - h) // 2
        start_x = (new_w - w) // 2
        return resized[start_y:start_y + h, start_x:start_x + w]

    def _apply_effects(self, image: np.ndarray,
                       mask: np.ndarray) -> np.ndarray:
        """Apply configured effects to image."""
        result = image.copy()

        for effect in self.config.apply_effects:
            if effect == EffectType.BOKEH:
                result = self.background_effects.apply_bokeh(result)
            elif effect == EffectType.VIGNETTE:
                result = self.background_effects.add_vignette(result)
            elif effect == EffectType.FILM_GRAIN:
                result = self.background_effects.add_film_grain(result)

        return result

    def _apply_post_effects(self, image: np.ndarray,
                            mask: np.ndarray) -> np.ndarray:
        """Apply post-composite effects."""
        result = image.copy()

        for effect in self.config.apply_effects:
            if effect == EffectType.SHADOW:
                result = self.background_effects.add_shadow(result, mask)
            elif effect == EffectType.REFLECTION:
                result = self.background_effects.add_reflection(result, mask)
            elif effect == EffectType.GLOW:
                result = self.background_effects.add_glow(result, mask)
            elif effect == EffectType.CHROMATIC_ABERRATION:
                result = self.background_effects.chromatic_aberration(result)

        return result

    def _postprocess_image(self, image: np.ndarray,
                           alpha: np.ndarray) -> np.ndarray:
        """Apply final postprocessing."""
        result = image.copy()

        # Contrast correction based on quality preset
        if self.config.quality_preset in ["high", "ultra"]:
            # Equalize the lightness channel in LAB space
            lab = cv2.cvtColor(result, cv2.COLOR_BGR2LAB)
            l, a, b = cv2.split(lab)
            l = cv2.equalizeHist(l)
            result = cv2.cvtColor(cv2.merge([l, a, b]), cv2.COLOR_LAB2BGR)

        # Sharpen for the ultra preset
        if self.config.quality_preset == "ultra":
            kernel = np.array([[-1, -1, -1],
                               [-1,  9, -1],
                               [-1, -1, -1]])
            result = cv2.filter2D(result, -1, kernel)

        return result

    def _calculate_quality_score(self, image: np.ndarray,
                                 alpha: np.ndarray,
                                 metrics: Dict) -> float:
        """Calculate an overall quality score in [0, 1]."""
        scores = []

        # Edge quality
        scores.append(metrics.get('edge_clarity', 0.5))

        # Alpha matte quality: higher contrast means cleaner
        # foreground/background separation
        alpha_std = np.std(alpha)
        scores.append(min(alpha_std * 2, 1.0))

        # Overall image quality
        scores.append(metrics.get('overall_quality', 0.5))

        return float(np.mean(scores))

    def _load_image(self, source: Union[np.ndarray, str, Path]) -> np.ndarray:
        """Load an image from an array or a file path."""
        if isinstance(source, np.ndarray):
            return source

        path = Path(source)
        if not path.exists():
            raise FileNotFoundError(f"Image not found: {path}")

        image = cv2.imread(str(path))
        if image is None:
            raise ValueError(f"Failed to load image: {path}")

        return image

    def _generate_cache_key(self, image: np.ndarray,
                            params: Dict) -> str:
        """Generate a cache key from the image content and parameters."""
        hasher = hashlib.md5()
        hasher.update(image.tobytes())
        # default=str keeps hashing robust if params contain
        # non-JSON-serializable values
        hasher.update(json.dumps(params, sort_keys=True, default=str).encode())
        return hasher.hexdigest()

    def _cache_result(self, key: str, result: PipelineResult):
        """Cache a processing result, evicting old entries when over budget."""
        self.cache[key] = result

        # Estimate cache memory usage
        cache_memory = sum(
            r.output_image.nbytes if r.output_image is not None else 0
            for r in self.cache.values()
        )

        max_bytes = self.config.cache_size_mb * 1024 * 1024

        if cache_memory > max_bytes:
            # Evict the oldest quarter of entries (dicts preserve
            # insertion order)
            for old_key in list(self.cache.keys())[:len(self.cache) // 4]:
                del self.cache[old_key]

    def _update_stage(self, stage: PipelineStage):
        """Update the current processing stage and notify callbacks."""
        self.current_stage = stage

        if self.config.stage_callback:
            self.config.stage_callback(stage, {
                'timestamp': time.time(),
                'memory_usage': self.memory_monitor.get_usage()
            })

        if self.config.progress_callback:
            # Divide by (n - 1) so the COMPLETE stage maps to progress 1.0
            progress = list(PipelineStage).index(stage) / (len(PipelineStage) - 1)
            self.config.progress_callback(progress, stage.value)

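    # Illustrative sketch (not part of the original commit): a callback pair
    # wired to the hooks above; both function names are hypothetical.
    #
    #   def print_progress(fraction: float, label: str) -> None:
    #       print(f"{fraction:5.0%} {label}")
    #
    #   def log_stage(stage: PipelineStage, info: dict) -> None:
    #       print(f"stage={stage.value} mem={info['memory_usage']}")
    #
    #   config = PipelineConfig(progress_callback=print_progress,
    #                           stage_callback=log_stage)
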
    def _update_statistics(self, result: PipelineResult):
        """Update running processing statistics."""
        if 'total_processed' not in self.processing_stats:
            self.processing_stats['total_processed'] = 0
            self.processing_stats['total_time'] = 0
            self.processing_stats['avg_quality'] = 0

        self.processing_stats['total_processed'] += 1
        self.processing_stats['total_time'] += result.processing_time
        self.processing_stats['avg_time'] = (
            self.processing_stats['total_time'] /
            self.processing_stats['total_processed']
        )

        # Incrementally update the average quality
        n = self.processing_stats['total_processed']
        old_avg = self.processing_stats['avg_quality']
        self.processing_stats['avg_quality'] = (
            (old_avg * (n - 1) + result.quality_score) / n
        )

    def _fallback_processing(self, image: np.ndarray,
                             background: Optional[np.ndarray]) -> PipelineResult:
        """Fallback processing when the main pipeline fails."""
        from ..processing.fallback import ProcessingFallback

        result = PipelineResult(success=False)
        fallback = ProcessingFallback()

        try:
            # Basic segmentation
            mask = fallback.basic_segmentation(image)

            # Basic matting
            alpha = fallback.basic_matting(image, mask)

            # Simple composite if a background was provided
            if background is not None:
                background = self._resize_background(background, image.shape[:2])
                output = self.compositing_engine.composite(
                    image, background, alpha
                )
            else:
                output = image

            result.success = True
            result.output_image = output
            result.alpha_matte = alpha
            result.metadata['fallback_used'] = True

        except Exception as e:
            self.logger.error(f"Fallback processing also failed: {e}")
            result.errors.append(str(e))

        return result

    def process_batch(self, images: List[Union[np.ndarray, str, Path]],
                      background: Optional[Union[np.ndarray, str, Path]] = None,
                      **kwargs) -> List[PipelineResult]:
        """
        Process multiple images in batch.

        Args:
            images: List of input images
            background: Optional background for all images
            **kwargs: Additional processing parameters

        Returns:
            List of PipelineResults
        """
        results = []
        total = len(images)

        self.logger.info(f"Processing batch of {total} images")

        # Submit all items to the thread pool
        futures = [
            self.executor.submit(self.process_image, image, background, **kwargs)
            for image in images
        ]

        # Collect results in submission order
        for i, future in enumerate(futures):
            try:
                result = future.result(timeout=30)
                results.append(result)

                if self.config.progress_callback:
                    progress = (i + 1) / total
                    self.config.progress_callback(
                        progress,
                        f"Processed {i + 1}/{total}"
                    )

            except Exception as e:
                self.logger.error(f"Batch item {i} failed: {e}")
                results.append(PipelineResult(
                    success=False,
                    errors=[str(e)]
                ))

        return results

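    # Illustrative sketch (not part of the original commit): batch usage with
    # per-item error handling; the file names are hypothetical placeholders.
    #
    #   results = pipeline.process_batch(
    #       ["a.jpg", "b.jpg", "c.jpg"],
    #       background="studio.png",
    #   )
    #   ok = [r for r in results if r.success]
    #   failed = [r.errors for r in results if not r.success]
    #
    # Note that future.result(timeout=30) above means any single image taking
    # longer than 30 s is reported as a failed item.
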
    def get_statistics(self) -> Dict[str, Any]:
        """Get processing statistics."""
        return {
            **self.processing_stats,
            'cache_size': len(self.cache),
            'current_stage': self.current_stage.value,
            'is_processing': self.is_processing,
            'device': str(self.device_manager.get_device()),
            'model_type': self.config.model_type.value
        }

    def clear_cache(self):
        """Clear the result cache."""
        self.cache.clear()
        self.logger.info("Cache cleared")

    def shutdown(self):
        """Shutdown the pipeline and clean up resources."""
        self.executor.shutdown(wait=True)
        self.clear_cache()

        # Release model memory
        if hasattr(self, 'model'):
            del self.model
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        self.logger.info("Pipeline shutdown complete")