Spaces:
Runtime error
Runtime error
File size: 7,762 Bytes
1282f37 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 |
from typing import Any, Dict, List, Optional, Tuple

from PIL import Image

from image_generators import MultiModelFalImageGenerator
from weave_prompt import ImageEvaluator, ImageSimilarityMetric, PromptOptimizer, PromptRefiner


class MultiModelPromptOptimizer:
    """Sequential multi-model prompt optimizer that finds the best model-prompt combination.

    The models listed in ``image_generator.selected_models`` are optimized one
    after another, each with its own ``PromptOptimizer``. Each model's final
    result is stored in ``model_results``. The result with the highest final
    similarity is kept in ``best_result`` and returned once every model has
    finished.
    """

    def __init__(self,
                 image_generator: MultiModelFalImageGenerator,
                 evaluator: ImageEvaluator,
                 refiner: PromptRefiner,
                 similarity_metric: ImageSimilarityMetric,
                 max_iterations: int = 10,
                 similarity_threshold: float = 0.95):
        """Initialize the multi-model optimizer.

        Args:
            image_generator: Multi-model image generator.
            evaluator: Image evaluator for generating initial prompt and analysis.
            refiner: Prompt refinement strategy.
            similarity_metric: Image similarity metric.
            max_iterations: Maximum number of optimization iterations per model.
            similarity_threshold: Target similarity threshold for early stopping.
        """
        self.image_generator = image_generator
        self.evaluator = evaluator
        self.refiner = refiner
        self.similarity_metric = similarity_metric
        self.max_iterations = max_iterations
        self.similarity_threshold = similarity_threshold

        # Multi-model state
        self.target_img = None
        self.current_model_index = 0
        self.model_results = {}  # model name -> final result dict
        self.current_optimizer = None
        self.best_result = None

        # Create the optimizer for the first model (no-op if no models are selected).
        self._create_current_optimizer()

    def _all_models_completed(self) -> bool:
        """Return True once the model index has moved past the last selected model."""
        return self.current_model_index >= len(self.image_generator.selected_models)

    def _create_current_optimizer(self):
        """Create a fresh PromptOptimizer for the model at ``current_model_index``.

        Does nothing when the index is past the last selected model, so the
        previous optimizer (and its history) remains available.
        """
        if self.current_model_index < len(self.image_generator.selected_models):
            # Point the shared image generator at the model being optimized.
            self.image_generator.current_model_index = self.current_model_index
            self.current_optimizer = PromptOptimizer(
                image_generator=self.image_generator,
                evaluator=self.evaluator,
                refiner=self.refiner,
                similarity_metric=self.similarity_metric,
                max_iterations=self.max_iterations,
                similarity_threshold=self.similarity_threshold,
            )

    def get_current_model_name(self) -> str:
        """Return the name of the currently active model.

        After every model has completed, ``current_model_index`` points one past
        the last model. In that case the last model's name is returned instead
        of indexing out of range.
        """
        total_models = len(self.image_generator.selected_models)
        index = self.current_model_index
        if total_models and index >= total_models:
            index = total_models - 1
        # Keep the generator's index synchronized before asking it for the name.
        self.image_generator.current_model_index = index
        return self.image_generator.get_current_model_name()

    def get_progress_info(self) -> Dict[str, Any]:
        """Return a dict describing overall and per-model progress.

        Per-model keys (``current_iteration``, ``max_iterations``,
        ``model_progress``) are included only when an optimizer exists.
        """
        total_models = len(self.image_generator.selected_models)
        info = {
            'current_model_index': self.current_model_index,
            'current_model_name': self.get_current_model_name(),
            'total_models': total_models,
            'models_completed': self.current_model_index,
            'overall_progress': self.current_model_index / total_models if total_models > 0 else 0,
            'is_last_model': self.current_model_index >= total_models - 1,
        }
        if self.current_optimizer is not None:
            iterations = len(self.current_optimizer.history)
            info['current_iteration'] = iterations
            info['max_iterations'] = self.max_iterations
            # Guard against max_iterations == 0 to avoid ZeroDivisionError.
            info['model_progress'] = iterations / self.max_iterations if self.max_iterations > 0 else 0
        return info

    def initialize(self, target_img: Image.Image) -> Tuple[bool, str, Image.Image]:
        """Initialize the multi-model optimization process.

        Args:
            target_img: Target image to optimize towards.

        Returns:
            Tuple of (is_completed, current_prompt, current_generated_image), as
            returned by the first model's optimizer ``initialize()``.

        Raises:
            ValueError: If the image generator has no selected models.
        """
        if not self.image_generator.selected_models:
            # Without this check a stale optimizer from a previous run, or None,
            # would be used.
            raise ValueError("No models selected for optimization")

        self.target_img = target_img
        self.current_model_index = 0
        self.model_results = {}
        self.best_result = None

        # Reset image generator to first model
        self.image_generator.reset_to_first_model()
        self._create_current_optimizer()

        return self.current_optimizer.initialize(target_img)

    def _record_model_result(self, prompt: str, generated_image: Image.Image) -> None:
        """Store the just-finished model's final result and update ``best_result``.

        The last history entry is preferred over the values returned by the
        final ``step()`` call, so the stored result matches the stored history.
        """
        model_name = self.get_current_model_name()
        history = self.current_optimizer.history
        if history:
            last_step = history[-1]
            final_prompt = last_step['prompt']
            final_image = last_step['image']
            final_similarity = last_step['similarity']
        else:
            # Fallback to the step results if no history was recorded (not expected).
            final_prompt = prompt
            final_image = generated_image
            final_similarity = 0.0

        self.model_results[model_name] = {
            'final_prompt': final_prompt,
            'final_image': final_image,
            'final_similarity': final_similarity,
            'history': history.copy(),
            'iterations': len(history),
        }

        if self.best_result is None or final_similarity > self.best_result['similarity']:
            self.best_result = {
                'model_name': model_name,
                'prompt': final_prompt,
                'image': final_image,
                'similarity': final_similarity,
            }

    def step(self) -> Tuple[bool, str, Image.Image]:
        """Perform one optimization step.

        While a model is still being optimized, the values from its optimizer's
        ``step()`` are returned. When a model finishes, its result is recorded:
        - if models remain, the next model is initialized and the values from
          its optimizer's ``initialize()`` are returned;
        - otherwise ``(True, best_prompt, best_image)`` is returned.
        Calling ``step()`` again after completion keeps returning the best
        result without stepping any optimizer.

        Returns:
            Tuple of (is_completed, current_prompt, current_generated_image).

        Raises:
            RuntimeError: If ``initialize()`` has not been called.
        """
        # __init__ already creates an optimizer, so check target_img to detect
        # a missing initialize() call.
        if self.current_optimizer is None or self.target_img is None:
            raise RuntimeError("Must call initialize() before step()")

        if self._all_models_completed():
            # Every model already finished; don't re-step the last optimizer.
            return True, self.best_result['prompt'], self.best_result['image']

        is_model_completed, prompt, generated_image = self.current_optimizer.step()
        if not is_model_completed:
            return is_model_completed, prompt, generated_image

        self._record_model_result(prompt, generated_image)

        # Move to next model
        self.current_model_index += 1
        if not self._all_models_completed():
            self.image_generator.current_model_index = self.current_model_index
            self._create_current_optimizer()
            # NOTE(review): the next model's initialize() result is passed through
            # unchanged; confirm its is_completed flag is always False here.
            return self.current_optimizer.initialize(self.target_img)

        # All models completed - return best result
        return True, self.best_result['prompt'], self.best_result['image']

    def get_all_results(self) -> Dict[str, Dict[str, Any]]:
        """Return a shallow copy of the results of all completed models, keyed by model name."""
        return self.model_results.copy()

    def get_best_result(self) -> Optional[Dict[str, Any]]:
        """Return a copy of the best result across all models, or None if no model has finished."""
        return self.best_result.copy() if self.best_result else None

    @property
    def history(self):
        """Return the current optimizer's history (empty list if none), for compatibility."""
        if self.current_optimizer is not None:
            return self.current_optimizer.history
        return []
|