|
|
""" |
|
|
Temporal stability and frame correction module for BackgroundFX Pro. |
|
|
Fixes 1134/1135 frame misalignment and ensures temporal coherence. |
|
|
""" |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from typing import Dict, List, Optional, Tuple, Any |
|
|
from dataclasses import dataclass |
|
|
from collections import deque |
|
|
import cv2 |
|
|
from scipy import signal |
|
|
from scipy.ndimage import binary_dilation, binary_erosion |
|
|
import logging |
|
|
|
|
|
# Module-level logger; TemporalStabilizer uses it to warn about frames that
# the anomaly detector flags as 1134/1135 anomalies.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
@dataclass
class TemporalConfig:
    """Settings consumed by :class:`TemporalStabilizer`.

    Every field has a default, so ``TemporalConfig()`` is the standard setup;
    override individual values by keyword (or positionally, in field order).
    """

    # Frames averaged per output mask when the adaptive window is off; the
    # stabilizer's frame buffer holds twice this many entries.
    window_size: int = 7
    # NOTE(review): not read by any code in this module.
    motion_threshold: float = 0.15
    # NOTE(review): not read by any code in this module.
    stability_weight: float = 0.8
    # Weight of the raw (current) mask inside the dilated edge band when
    # blending it with the temporally averaged mask.
    edge_preservation: float = 0.9
    # Lower clip applied to per-pixel confidence before it is used as the
    # blending weight of the averaged mask.
    min_confidence: float = 0.7
    # NOTE(review): not read by any code in this module.
    max_correction_frames: int = 5
    # Run the anomaly detector / 1134-1135 correction on every frame.
    enable_1134_fix: bool = True
    # Apply a directional blur along the estimated global motion.
    enable_motion_blur_comp: bool = True
    # Shrink/grow the averaging window based on recent motion magnitude.
    adaptive_window: bool = True
    # NOTE(review): not read by any code in this module.
    use_optical_flow: bool = True
|
|
|
|
|
|
|
|
class FrameBuffer:
    """Fixed-capacity history of frames, masks and per-frame metadata.

    All internal deques share the same ``maxlen`` and :meth:`add` appends to
    every one of them, so entries at the same position describe the same
    frame. When the buffer is full the oldest entry is dropped.
    """

    def __init__(self, max_size: int = 10):
        self.frames = deque(maxlen=max_size)
        self.masks = deque(maxlen=max_size)
        self.features = deque(maxlen=max_size)
        self.timestamps = deque(maxlen=max_size)
        # motion_vectors[i] is the global (dx, dy) shift from the previously
        # buffered frame to frames[i]; zeros when no previous frame existed.
        self.motion_vectors = deque(maxlen=max_size)

    def add(self, frame: np.ndarray, mask: np.ndarray,
            features: Optional[Dict] = None, timestamp: float = 0.0):
        """Append a frame with its mask and metadata.

        Args:
            frame: Image for this time step; copied before storing.
            mask: Mask for ``frame``; copied before storing.
            features: Optional per-frame statistics (stored as ``{}`` if None).
            timestamp: Stored verbatim alongside the frame.
        """
        self.frames.append(frame.copy())
        self.masks.append(mask.copy())
        self.features.append(features or {})
        self.timestamps.append(timestamp)

        if len(self.frames) > 1:
            # frames[-2] is the previous frame, frames[-1] the copy just added.
            self.motion_vectors.append(
                self._calculate_motion(self.frames[-2], frame)
            )
        else:
            self.motion_vectors.append(np.zeros((2,)))

    @staticmethod
    def _to_gray(frame: np.ndarray) -> np.ndarray:
        """Return a single-channel view/copy of ``frame``.

        2-D frames are returned unchanged, ``(h, w, 1)`` frames are squeezed,
        4-channel frames are treated as BGRA and everything else as BGR.
        """
        if frame.ndim == 2:
            return frame
        channels = frame.shape[2]
        if channels == 1:
            return np.ascontiguousarray(frame[:, :, 0])
        if channels == 4:
            return cv2.cvtColor(frame, cv2.COLOR_BGRA2GRAY)
        return cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

    def _calculate_motion(self, prev_frame: np.ndarray,
                          curr_frame: np.ndarray) -> np.ndarray:
        """Estimate the global (dx, dy) shift between two frames.

        Uses ``cv2.phaseCorrelate`` on grayscale float32 versions of the
        frames; the correlation response value is discarded.
        """
        prev_gray = self._to_gray(prev_frame).astype(np.float32)
        curr_gray = self._to_gray(curr_frame).astype(np.float32)
        shift, _ = cv2.phaseCorrelate(prev_gray, curr_gray)
        return np.array(shift)

    def get_window(self, size: int) -> Tuple[List, List, List]:
        """Return the ``size`` most recent (frames, masks, features).

        ``size`` is capped at the number of buffered frames. A non-positive
        ``size`` yields three empty lists (previously ``size == 0`` returned
        the whole buffer because ``seq[-0:]`` is the full sequence).
        """
        size = min(size, len(self.frames))
        if size <= 0:
            return [], [], []
        return (
            list(self.frames)[-size:],
            list(self.masks)[-size:],
            list(self.features)[-size:],
        )
|
|
|
|
|
|
|
|
class TemporalStabilizer:
    """Temporally stabilize a stream of per-frame masks.

    For every call to :meth:`process_frame` the stabilizer:

    1. optionally repairs masks that its ``FrameAnomalyDetector`` flags
       (the "1134/1135" fix, gated by ``config.enable_1134_fix``);
    2. records the frame, mask and simple mask statistics in a
       ``FrameBuffer`` holding ``2 * config.window_size`` entries;
    3. once at least three frames are buffered, replaces the mask with a
       recency- and motion-weighted average of recent masks, keeps the
       current mask dominant near its own edges, and optionally applies a
       directional blur along the estimated global motion.

    Masks are multiplied by 255 before Canny and the averaged mask is clipped
    to ``[0, 1]``, so masks are presumably float maps in ``[0, 1]``
    -- TODO confirm with callers.
    """

    # Maximum number of entries kept in ``correction_cache``. Keys are built
    # from the strictly increasing frame counter, so an entry is never looked
    # up again once its frame has passed; without a bound the cache retained
    # one full mask per anomalous frame for the whole video.
    _MAX_CACHED_CORRECTIONS = 32

    def __init__(self, config: Optional["TemporalConfig"] = None):
        self.config = config or TemporalConfig()
        self.buffer = FrameBuffer(max_size=self.config.window_size * 2)
        # Record of applied corrections (most recent 100).
        self.correction_history = deque(maxlen=100)
        # Index of the frame currently being processed (1 for the first).
        self.frame_counter = 0
        # Mask returned by the previous process_frame() call.
        self.last_stable_mask = None
        # Kept for API compatibility; not updated by this class.
        self.motion_accumulator = np.zeros((2,))
        self.anomaly_detector = FrameAnomalyDetector()
        # Corrected masks keyed by "<frame_idx>_correction" (bounded size).
        self.correction_cache = {}

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #
    def process_frame(self, frame: np.ndarray, mask: np.ndarray,
                      confidence: Optional[np.ndarray] = None) -> np.ndarray:
        """Return a temporally stabilized version of ``mask``.

        Args:
            frame: Image the mask belongs to; used for anomaly detection and
                stored in the frame buffer for motion estimation.
            mask: Segmentation mask for ``frame``. It is not modified.
            confidence: Optional confidence (per pixel or scalar). It is
                clipped to ``[config.min_confidence, 1]`` and used as the
                weight of the averaged mask versus the raw mask.

        Returns:
            While fewer than three frames are buffered, the (possibly
            anomaly-corrected) mask itself; afterwards the stabilized mask.
        """
        self.frame_counter += 1

        if self.config.enable_1134_fix:
            mask = self._fix_1134_1135_issue(frame, mask, self.frame_counter)

        features = self._extract_features(frame, mask)
        self.buffer.add(frame, mask, features, self.frame_counter)

        # Not enough history to average over yet.
        if len(self.buffer.frames) < 3:
            self.last_stable_mask = mask.copy()
            return mask

        stabilized_mask = self._stabilize_mask(mask, confidence)

        if self.config.enable_motion_blur_comp:
            stabilized_mask = self._compensate_motion_blur(frame, stabilized_mask)

        self.last_stable_mask = stabilized_mask.copy()
        return stabilized_mask

    def reset(self):
        """Forget all per-sequence state so a new sequence can be processed."""
        self.buffer = FrameBuffer(max_size=self.config.window_size * 2)
        self.correction_history.clear()
        self.frame_counter = 0
        self.last_stable_mask = None
        self.motion_accumulator = np.zeros((2,))
        self.correction_cache.clear()
        # The detector's area/edge statistics describe the previous sequence;
        # keeping them would compare the new sequence against stale history.
        self.anomaly_detector.history.clear()

    # ------------------------------------------------------------------ #
    # 1134/1135 correction
    # ------------------------------------------------------------------ #
    def _fix_1134_1135_issue(self, frame: np.ndarray, mask: np.ndarray,
                             frame_idx: int) -> np.ndarray:
        """Return a corrected mask if the detector flags this frame, else ``mask``."""
        if not self.anomaly_detector.is_anomaly(frame, mask, frame_idx):
            return mask

        logger.warning("Frame %s: Detected 1134/1135 anomaly", frame_idx)

        cache_key = f"{frame_idx}_correction"
        if cache_key in self.correction_cache:
            return self.correction_cache[cache_key]

        corrected_mask = self._apply_1134_correction(frame, mask, frame_idx)
        self._cache_correction(cache_key, corrected_mask)
        self.correction_history.append({
            'frame': frame_idx,
            'type': '1134_1135',
            'applied': True,
        })
        return corrected_mask

    def _cache_correction(self, key: str, corrected_mask: np.ndarray) -> None:
        """Store a corrected mask, evicting the oldest entries beyond the cap."""
        self.correction_cache[key] = corrected_mask
        while len(self.correction_cache) > self._MAX_CACHED_CORRECTIONS:
            # dicts preserve insertion order, so the first key is the oldest.
            del self.correction_cache[next(iter(self.correction_cache))]

    def _apply_1134_correction(self, frame: np.ndarray, mask: np.ndarray,
                               frame_idx: int) -> np.ndarray:
        """Correct a flagged mask.

        Frames 1134/1135 get border/morphology cleanup followed by a blend
        with the two previously buffered masks. Any other flagged frame is
        pulled 40 % towards the last stable mask when its mean absolute
        difference from it exceeds 0.3.
        """
        if frame_idx in (1134, 1135):
            mask = self._fix_edge_artifacts(mask)
            if len(self.buffer.masks) >= 2:
                mask = self._blend_with_history(mask)
        elif self.last_stable_mask is not None:
            diff = np.abs(mask - self.last_stable_mask)
            if np.mean(diff) > 0.3:
                alpha = 0.6
                mask = alpha * mask + (1 - alpha) * self.last_stable_mask
        return mask

    def _blend_with_history(self, mask: np.ndarray) -> np.ndarray:
        """Blend ``mask`` 50/30/20 with the last two buffered masks, clipped to [0, 1].

        Called before the current frame is buffered, so ``masks[-1]`` is the
        previous frame's mask. (Previously ``masks[-2]`` was only used when
        more than two masks were buffered -- an off-by-one.)
        """
        masks = self.buffer.masks
        prev_mask = masks[-1]
        prev_prev_mask = masks[-2] if len(masks) >= 2 else prev_mask
        blended = 0.5 * mask + 0.3 * prev_mask + 0.2 * prev_prev_mask
        return np.clip(blended, 0, 1)

    def _fix_edge_artifacts(self, mask: np.ndarray) -> np.ndarray:
        """Damp saturated image borders and clean the mask morphologically.

        Each 10-pixel border strip whose mean exceeds 0.8 is halved, then a
        5x5 elliptical opening followed by a closing is applied. Works on a
        copy: the in-place halving used to modify the caller's array.
        """
        if np.issubdtype(mask.dtype, np.floating):
            mask = mask.copy()
        else:
            # Integer masks cannot be scaled by 0.5 in place.
            mask = mask.astype(np.float32)

        border_size = 10
        threshold = 0.8
        regions = (
            np.s_[:border_size, :],   # top
            np.s_[-border_size:, :],  # bottom
            np.s_[:, :border_size],   # left
            np.s_[:, -border_size:],  # right
        )
        # Measure every border before touching any, so halving one strip does
        # not change the mean of the strips that share its corners.
        saturated = [mask[region].mean() > threshold for region in regions]
        for region, is_saturated in zip(regions, saturated):
            if is_saturated:
                mask[region] *= 0.5

        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
        return mask

    # ------------------------------------------------------------------ #
    # Temporal averaging
    # ------------------------------------------------------------------ #
    def _stabilize_mask(self, mask: np.ndarray,
                        confidence: Optional[np.ndarray] = None) -> np.ndarray:
        """Weighted temporal average of buffered masks, clipped to [0, 1]."""
        if self.config.adaptive_window:
            window_size = self._adaptive_window_size()
        else:
            window_size = self.config.window_size
        _, masks, features = self.buffer.get_window(window_size)

        if len(masks) < 2:
            return mask

        weights = self._compute_temporal_weights(masks, features)
        stabilized = np.zeros_like(mask, dtype=np.float32)
        for past_mask, weight in zip(masks, weights):
            if not isinstance(past_mask, np.ndarray):
                # Objects exposing .numpy(), e.g. torch tensors.
                past_mask = past_mask.numpy()
            stabilized += past_mask * weight

        if confidence is not None:
            conf_weight = np.clip(confidence, self.config.min_confidence, 1.0)
            stabilized = stabilized * conf_weight + mask * (1 - conf_weight)

        stabilized = self._preserve_edges(mask, stabilized)
        return np.clip(stabilized, 0, 1)

    def _adaptive_window_size(self) -> int:
        """Pick the averaging window from the mean magnitude of recent motion.

        Uses up to the last five motion vectors: mean magnitude below 5 grows
        the window by 2 (capped at 11), above 20 shrinks it by 2 (floor 3).
        """
        if len(self.buffer.motion_vectors) < 2:
            return self.config.window_size

        recent_motion = np.array(list(self.buffer.motion_vectors)[-5:])
        motion_mag = np.linalg.norm(recent_motion, axis=1).mean()

        if motion_mag < 5:
            return min(self.config.window_size + 2, 11)
        if motion_mag > 20:
            return max(3, self.config.window_size - 2)
        return self.config.window_size

    def _compute_temporal_weights(self, masks: List[np.ndarray],
                                  features: List[Dict]) -> np.ndarray:
        """Normalized weights favouring recent, low-motion frames.

        A Gaussian over frame age (sigma = n / 3, peak on the newest frame)
        is multiplied by ``exp(-|motion| / 10)`` per frame, then normalized to
        sum to 1. ``features`` is accepted for interface stability but unused.
        """
        n = len(masks)
        weights = np.ones(n, dtype=np.float32)

        temporal_sigma = n / 3.0
        for i in range(n):
            weights[i] *= np.exp(-((i - n + 1) ** 2) / (2 * temporal_sigma ** 2))

        if len(self.buffer.motion_vectors) >= n:
            motions = list(self.buffer.motion_vectors)[-n:]
            for i, motion in enumerate(motions):
                weights[i] *= np.exp(-np.linalg.norm(motion) / 10.0)

        return weights / (weights.sum() + 1e-8)

    def _preserve_edges(self, original: np.ndarray,
                        stabilized: np.ndarray) -> np.ndarray:
        """Favour ``original`` inside a 3x3-dilated band around its Canny edges."""
        edges = cv2.Canny((original * 255).astype(np.uint8), 50, 150) / 255.0
        edge_band = cv2.dilate(edges, np.ones((3, 3), np.uint8), iterations=1) > 0

        alpha = self.config.edge_preservation
        result = stabilized.copy()
        result[edge_band] = (
            alpha * original[edge_band] + (1 - alpha) * stabilized[edge_band]
        )
        return result

    # ------------------------------------------------------------------ #
    # Motion blur compensation
    # ------------------------------------------------------------------ #
    def _compensate_motion_blur(self, frame: np.ndarray,
                                mask: np.ndarray) -> np.ndarray:
        """Blend in a copy of ``mask`` blurred along the latest motion vector.

        Skipped when the latest shift is under 2 pixels; the blur length is
        at most 9 and the blend factor at most 0.5.
        """
        if len(self.buffer.motion_vectors) < 2:
            return mask

        motion = self.buffer.motion_vectors[-1]
        motion_mag = np.linalg.norm(motion)
        if motion_mag < 2:
            return mask

        angle = np.arctan2(motion[1], motion[0])
        kernel_size = min(int(motion_mag), 9)
        if kernel_size > 1:
            kernel = self._create_motion_kernel(kernel_size, angle)
            mask_filtered = cv2.filter2D(mask, -1, kernel)
            blend_factor = min(motion_mag / 20.0, 0.5)
            mask = (1 - blend_factor) * mask + blend_factor * mask_filtered
        return mask

    def _create_motion_kernel(self, size: int, angle: float) -> np.ndarray:
        """Normalized ``size x size`` line kernel through the centre at ``angle``.

        Offsets from the centre are rounded symmetrically; the previous
        ``int()`` truncation skewed the line towards the top-left, which
        shifted the filtered mask.
        """
        kernel = np.zeros((size, size))
        center = size // 2
        cos_a, sin_a = np.cos(angle), np.sin(angle)

        for i in range(size):
            offset = i - center
            x = center + int(np.rint(offset * cos_a))
            y = center + int(np.rint(offset * sin_a))
            if 0 <= x < size and 0 <= y < size:
                kernel[y, x] = 1

        return kernel / (kernel.sum() + 1e-8)

    # ------------------------------------------------------------------ #
    # Features
    # ------------------------------------------------------------------ #
    def _extract_features(self, frame: np.ndarray,
                          mask: np.ndarray) -> Dict[str, Any]:
        """Simple mask statistics stored alongside each buffered frame.

        Returns mean, std, Canny edge density, the number of connected
        components of ``mask > 0.5`` and a normalized 10-bin histogram over
        ``[0, 1]``. ``frame`` is currently unused.
        """
        features = {}
        features['mean'] = np.mean(mask)
        features['std'] = np.std(mask)

        edges = cv2.Canny((mask * 255).astype(np.uint8), 50, 150)
        features['edge_density'] = np.mean(edges) / 255.0

        num_labels, _ = cv2.connectedComponents((mask > 0.5).astype(np.uint8))
        # Label 0 is the background.
        features['num_components'] = num_labels - 1

        hist, _ = np.histogram(mask.flatten(), bins=10, range=(0, 1))
        features['histogram'] = hist / (hist.sum() + 1e-8)
        return features
|
|
|
|
|
|
|
|
class FrameAnomalyDetector:
    """Flag frames whose mask looks inconsistent, including the 1134/1135 case.

    A frame is anomalous when its index is a key of ``anomaly_patterns``
    (1134 and 1135), or -- once three frames have been recorded -- when the
    foreground-area fraction of its mask (pixels > 0.5) changes by more than
    50 %, or its Canny edge-pixel ratio by more than 60 %, relative to the
    mean of the last three recorded frames.
    """

    # Number of previously recorded frames the current frame is compared to.
    _COMPARE_WINDOW = 3
    # Relative change in foreground-area fraction counted as an anomaly.
    _AREA_CHANGE_LIMIT = 0.5
    # Relative change in edge-pixel ratio counted as an anomaly.
    _EDGE_CHANGE_LIMIT = 0.6

    def __init__(self):
        # Frame indices that are always anomalous. The per-pattern thresholds
        # are stored but not read by is_anomaly().
        self.anomaly_patterns = {
            1134: {'edge_threshold': 0.7, 'area_change': 0.3},
            1135: {'edge_threshold': 0.7, 'area_change': 0.3}
        }
        # Statistics ({'frame_idx', 'area', 'edge_ratio'}) of the most recent
        # ten evaluated frames.
        self.history = deque(maxlen=10)

    def is_anomaly(self, frame: np.ndarray, mask: np.ndarray,
                   frame_idx: int) -> bool:
        """Return True if ``mask`` at ``frame_idx`` should be corrected.

        ``frame`` is accepted for interface compatibility but unused.

        Frames listed in ``anomaly_patterns`` are flagged without being
        recorded, so known-bad masks do not skew the baseline. Every other
        frame is recorded in ``history`` whether or not it is anomalous;
        previously anomalous frames were skipped, so after a lasting change
        in mask area the detector compared against stale history forever and
        flagged every subsequent frame.
        """
        if frame_idx in self.anomaly_patterns:
            return True

        area = float(np.sum(mask > 0.5) / mask.size)
        edge_ratio = self._compute_edge_ratio(mask)

        anomalous = False
        if len(self.history) >= self._COMPARE_WINDOW:
            # deque supports indexing but not slicing (history[-3:] raised
            # TypeError), so materialise it first.
            recent = list(self.history)[-self._COMPARE_WINDOW:]
            area_change = self._relative_change(area, [h['area'] for h in recent])
            edge_change = self._relative_change(
                edge_ratio, [h['edge_ratio'] for h in recent]
            )
            anomalous = (area_change > self._AREA_CHANGE_LIMIT
                         or edge_change > self._EDGE_CHANGE_LIMIT)

        self.history.append({
            'frame_idx': frame_idx,
            'area': area,
            'edge_ratio': edge_ratio,
        })
        return anomalous

    @staticmethod
    def _relative_change(value: float, reference: List[float]) -> float:
        """Return ``|value - mean(reference)| / mean(reference)``.

        Returns 0.0 (no change detectable) when ``reference`` is empty or its
        mean is not positive, matching the original ``mean > 0`` guard.
        """
        if not reference:
            return 0.0
        baseline = float(np.mean(reference))
        if baseline <= 0:
            return 0.0
        return abs(value - baseline) / baseline

    def _compute_edge_ratio(self, mask: np.ndarray) -> float:
        """Fraction of pixels marked by Canny (thresholds 50/150) on ``mask * 255``."""
        edges = cv2.Canny((mask * 255).astype(np.uint8), 50, 150)
        return float(np.sum(edges > 0) / edges.size)
|
|
|
|
|
|
|
|
class OpticalFlowTracker:
    """Dense (Farneback) optical-flow tracker used to warp masks between frames.

    :meth:`track` computes the flow from the previously seen frame to the
    current one; :meth:`warp_mask` resamples a mask along a flow field.
    """

    def __init__(self):
        # Grayscale version of the last frame passed to track().
        self.prev_gray = None
        # Most recent flow field returned by track(), or None.
        self.flow = None
        # Shi-Tomasi corner and Lucas-Kanade parameters; stored for callers
        # but not used by the methods of this class.
        self.feature_params = dict(
            maxCorners=100,
            qualityLevel=0.3,
            minDistance=7,
            blockSize=7
        )
        self.lk_params = dict(
            winSize=(15, 15),
            maxLevel=2,
            criteria=(cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 10, 0.03)
        )

    @staticmethod
    def _to_gray(frame: np.ndarray) -> np.ndarray:
        """Return a single-channel version of ``frame``.

        2-D frames are returned unchanged, ``(h, w, 1)`` frames are squeezed,
        4-channel frames are treated as BGRA and everything else as BGR.
        """
        if frame.ndim == 2:
            return frame
        channels = frame.shape[2]
        if channels == 1:
            return np.ascontiguousarray(frame[:, :, 0])
        if channels == 4:
            return cv2.cvtColor(frame, cv2.COLOR_BGRA2GRAY)
        return cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

    def track(self, frame: np.ndarray) -> Optional[np.ndarray]:
        """Return the flow from the previous frame to ``frame``.

        Returns None for the first frame, and also when the frame size
        changed since the previous call (Farneback requires equal sizes);
        in both cases ``frame`` becomes the new reference.
        """
        gray = self._to_gray(frame)

        if self.prev_gray is None or self.prev_gray.shape != gray.shape:
            self.prev_gray = gray
            return None

        # Positional args: pyr_scale, levels, winsize, iterations, poly_n,
        # poly_sigma, flags.
        flow = cv2.calcOpticalFlowFarneback(
            self.prev_gray, gray, None,
            0.5, 3, 15, 3, 5, 1.2, 0
        )
        self.prev_gray = gray
        self.flow = flow
        return flow

    def warp_mask(self, mask: np.ndarray,
                  flow: Optional[np.ndarray] = None) -> np.ndarray:
        """Resample ``mask`` by sampling each pixel at ``(x, y) - flow(x, y)``.

        Args:
            mask: Array with the same height/width as ``flow``.
            flow: Flow field of shape ``(h, w, 2)``; defaults to the last
                field computed by :meth:`track`.

        Raises:
            ValueError: If no flow is given and none has been tracked yet.
        """
        if flow is None:
            flow = self.flow
        if flow is None:
            raise ValueError(
                "No optical flow available; call track() first or pass flow"
            )

        h, w = flow.shape[:2]
        grid_x, grid_y = np.meshgrid(np.arange(w), np.arange(h))
        map_x = (grid_x - flow[:, :, 0]).astype(np.float32)
        map_y = (grid_y - flow[:, :, 1]).astype(np.float32)
        return cv2.remap(mask, map_x, map_y, cv2.INTER_LINEAR)

    def reset(self):
        """Forget the reference frame and the last flow field."""
        self.prev_gray = None
        self.flow = None
|
|
|
|
|
|
|
|
|
|
|
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = [
    'TemporalStabilizer',
    'TemporalConfig',
    'FrameBuffer',
    'FrameAnomalyDetector',
    'OpticalFlowTracker'
]