import logging

logger = logging.getLogger(__name__)
class ContextualWeightOverrideAgent:
    def __init__(self):
        self.context_overrides = {
            # Example: for outdoor scenes, model_1 is penalized and model_5 is boosted.
            "outdoor": {
                "model_1": 0.8,  # Reduce model_1's weight by 20% for outdoor scenes
                "model_5": 1.2,  # Boost model_5's weight by 20% for outdoor scenes
            },
            "low_light": {
                "model_2": 0.7,
                "model_7": 1.3,
            },
            "sunny": {
                "model_3": 0.9,
                "model_4": 1.1,
            },
            # Add more contexts and their model-specific weight adjustments here.
        }
    def get_overrides(self, context_tags: list[str]) -> dict:
        """Return the combined weight overrides for the given context tags."""
        combined_overrides = {}
        for tag in context_tags:
            if tag in self.context_overrides:
                for model_id, multiplier in self.context_overrides[tag].items():
                    # When a model appears under multiple tags, its multipliers are
                    # combined multiplicatively for a simple cumulative effect.
                    combined_overrides[model_id] = combined_overrides.get(model_id, 1.0) * multiplier
        return combined_overrides
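
A minimal usage sketch of get_overrides, using the example contexts defined above (the tag combination itself is hypothetical):

agent = ContextualWeightOverrideAgent()
overrides = agent.get_overrides(["outdoor", "sunny"])
# -> {"model_1": 0.8, "model_5": 1.2, "model_3": 0.9, "model_4": 1.1}
# Had a model appeared under both tags, its multipliers would combine
# multiplicatively (e.g., 0.8 * 0.9 = 0.72).

Multiplying rather than overwriting means every matching context contributes, and the order of the tags does not matter.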
class ModelWeightManager:
    def __init__(self):
        self.base_weights = {
            "model_1": 0.15,   # SwinV2 based
            "model_2": 0.15,   # ViT based
            "model_3": 0.15,   # SDXL dataset
            "model_4": 0.15,   # SDXL + FLUX
            "model_5": 0.15,   # ViT based
            "model_5b": 0.10,  # ViT based, newer dataset
            "model_6": 0.10,   # Swin, Midj + SDXL
            "model_7": 0.05,   # ViT
        }
        self.situation_weights = {
            "high_confidence": 1.2,  # Boost weights for high-confidence predictions
            "low_confidence": 0.8,   # Reduce weights for low-confidence predictions
            "conflict": 0.5,         # Reduce weights when models disagree
            "consensus": 1.5,        # Boost weights when models agree
        }
        self.context_override_agent = ContextualWeightOverrideAgent()
    def adjust_weights(self, predictions, confidence_scores, context_tags: list[str] | None = None):
        """Dynamically adjust weights based on prediction patterns and optional context."""
        adjusted_weights = self.base_weights.copy()

        # 1. Apply contextual overrides first.
        if context_tags:
            overrides = self.context_override_agent.get_overrides(context_tags)
            for model_id, multiplier in overrides.items():
                adjusted_weights[model_id] = adjusted_weights.get(model_id, 0.0) * multiplier

        # 2. Apply situation-based adjustments (consensus, conflict, confidence).
        # Consensus: all valid predictions agree, so boost every weight.
        if self._has_consensus(predictions):
            for model in adjusted_weights:
                adjusted_weights[model] *= self.situation_weights["consensus"]

        # Conflict: valid predictions disagree, so damp every weight.
        if self._has_conflicts(predictions):
            for model in adjusted_weights:
                adjusted_weights[model] *= self.situation_weights["conflict"]

        # Per-model confidence adjustments; skip models without a tracked weight.
        for model, confidence in confidence_scores.items():
            if model not in adjusted_weights:
                continue
            if confidence > 0.8:
                adjusted_weights[model] *= self.situation_weights["high_confidence"]
            elif confidence < 0.5:
                adjusted_weights[model] *= self.situation_weights["low_confidence"]

        return self._normalize_weights(adjusted_weights)
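
    # Worked example of the multiplier chain above (hypothetical numbers): a base
    # weight of 0.15 with an "outdoor" override of 0.8, a consensus boost of 1.5,
    # and a high-confidence boost of 1.2 becomes 0.15 * 0.8 * 1.5 * 1.2 = 0.216
    # before normalization; _normalize_weights then rescales the full set to sum to 1.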
    def _has_consensus(self, predictions):
        """Check whether all models agree on a prediction."""
        # Ignore failed predictions (None or "Error") when checking for agreement.
        valid_predictions = [p for p in predictions.values() if p is not None and p != "Error"]
        return len(valid_predictions) > 0 and len(set(valid_predictions)) == 1

    def _has_conflicts(self, predictions):
        """Check whether models made conflicting predictions."""
        # Ignore failed predictions (None or "Error") when checking for disagreement.
        valid_predictions = [p for p in predictions.values() if p is not None and p != "Error"]
        return len(valid_predictions) > 1 and len(set(valid_predictions)) > 1
    def _normalize_weights(self, weights):
        """Normalize weights so they sum to 1."""
        total = sum(weights.values())
        if total == 0:
            # All weights collapsed to zero (e.g., via aggressive multipliers on
            # models with no base weight); fall back to equal weights.
            logger.warning("All weights became zero after adjustments. Falling back to equal weights.")
            return {k: 1.0 / len(self.base_weights) for k in self.base_weights}
        return {k: v / total for k, v in weights.items()}
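
A minimal end-to-end sketch (the model IDs match the base weights above; the predictions, confidence values, and context tag are hypothetical):

if __name__ == "__main__":
    manager = ModelWeightManager()

    # Three models vote "AI", one votes "REAL": a conflict, so every weight
    # is damped by 0.5 before the per-model confidence adjustments.
    predictions = {
        "model_1": "AI",
        "model_2": "AI",
        "model_3": "REAL",
        "model_4": "AI",
    }
    confidence_scores = {
        "model_1": 0.92,  # > 0.8 -> high-confidence boost
        "model_2": 0.65,  # no adjustment
        "model_3": 0.40,  # < 0.5 -> low-confidence penalty
        "model_4": 0.85,  # > 0.8 -> high-confidence boost
    }

    weights = manager.adjust_weights(predictions, confidence_scores, context_tags=["outdoor"])
    print(weights)  # Dict over all eight base models, normalized to sum to 1.0

Because the final step renormalizes, a multiplier applied uniformly to every model (such as the conflict damping here) cancels out; only the relative adjustments, i.e. the contextual overrides and per-model confidence factors, change the final distribution.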