# WeavePrompt / multi_model_optimizer.py
# Author: kevin1kevin1k
# Commit 1282f37 (verified): "Upload folder using huggingface_hub"
from typing import Any, Dict, List, Optional, Tuple

from PIL import Image

from weave_prompt import PromptOptimizer, ImageEvaluator, PromptRefiner, ImageSimilarityMetric
from image_generators import MultiModelFalImageGenerator


class MultiModelPromptOptimizer:
    """Sequential multi-model prompt optimizer that finds the best model-prompt combination.

    Models listed in ``image_generator.selected_models`` are optimized one after
    another, each with its own ``PromptOptimizer``. When a model's optimizer
    reports completion, its final prompt, image and similarity are stored in
    ``model_results`` and compared with ``best_result``. Optimization then moves
    on to the next model. After the last model finishes, ``step()`` reports
    completion and returns the best prompt/image found across all models.
    """

    def __init__(self,
                 image_generator: MultiModelFalImageGenerator,
                 evaluator: ImageEvaluator,
                 refiner: PromptRefiner,
                 similarity_metric: ImageSimilarityMetric,
                 max_iterations: int = 10,
                 similarity_threshold: float = 0.95):
        """Initialize the multi-model optimizer.

        Args:
            image_generator: Multi-model image generator
            evaluator: Image evaluator for generating initial prompt and analysis
            refiner: Prompt refinement strategy
            similarity_metric: Image similarity metric
            max_iterations: Maximum number of optimization iterations per model
            similarity_threshold: Target similarity threshold for early stopping
        """
        self.image_generator = image_generator
        self.evaluator = evaluator
        self.refiner = refiner
        self.similarity_metric = similarity_metric
        self.max_iterations = max_iterations
        self.similarity_threshold = similarity_threshold

        # Multi-model state
        self.target_img: Optional[Image.Image] = None  # set by initialize()
        self.current_model_index = 0  # index into image_generator.selected_models
        self.model_results: Dict[str, Dict[str, Any]] = {}  # results keyed by model name
        self.current_optimizer: Optional[PromptOptimizer] = None  # optimizer of the active model
        self.best_result: Optional[Dict[str, Any]] = None  # best result seen so far

        # Eagerly build the optimizer for the first model (no-op when no models are selected).
        self._create_current_optimizer()

    def _num_models(self) -> int:
        """Return how many models are scheduled for optimization."""
        return len(self.image_generator.selected_models)

    def _all_models_done(self) -> bool:
        """Return True once ``current_model_index`` has moved past the last model."""
        return self.current_model_index >= self._num_models()

    def _create_current_optimizer(self) -> None:
        """Create a fresh ``PromptOptimizer`` for the model at ``current_model_index``.

        Points the shared image generator at the current model first, so the new
        optimizer uses that model. Does nothing when the index is past the last model.
        """
        if self._all_models_done():
            return
        # Set the image generator to the current model
        self.image_generator.current_model_index = self.current_model_index
        self.current_optimizer = PromptOptimizer(
            image_generator=self.image_generator,
            evaluator=self.evaluator,
            refiner=self.refiner,
            similarity_metric=self.similarity_metric,
            max_iterations=self.max_iterations,
            similarity_threshold=self.similarity_threshold
        )

    def get_current_model_name(self) -> str:
        """Get the name of the currently active model.

        Re-synchronizes the image generator's model index with this optimizer's
        index before asking the generator for the name. After every model has
        finished, ``current_model_index`` is one past the last model. The index
        handed to the generator is therefore clamped to the last model (the
        generator presumably indexes ``selected_models`` with it — confirm).
        """
        index = self.current_model_index
        total_models = self._num_models()
        if total_models > 0:
            index = min(index, total_models - 1)
        self.image_generator.current_model_index = index
        return self.image_generator.get_current_model_name()

    def get_progress_info(self) -> Dict[str, Any]:
        """Get current progress information.

        Returns:
            Dict with the model index/name, total and completed model counts,
            overall progress fraction and an ``is_last_model`` flag. When an
            optimizer exists, it also includes ``current_iteration``,
            ``max_iterations`` and ``model_progress`` for the active model.
        """
        total_models = self._num_models()
        info = {
            'current_model_index': self.current_model_index,
            'current_model_name': self.get_current_model_name(),
            'total_models': total_models,
            'models_completed': self.current_model_index,
            'overall_progress': self.current_model_index / total_models if total_models > 0 else 0,
            'is_last_model': self.current_model_index >= total_models - 1,
        }
        if self.current_optimizer is not None:
            iterations = len(self.current_optimizer.history)
            info['current_iteration'] = iterations
            info['max_iterations'] = self.max_iterations
            # Guard against max_iterations == 0 to avoid ZeroDivisionError.
            info['model_progress'] = iterations / self.max_iterations if self.max_iterations > 0 else 0
        return info

    def initialize(self, target_img: Image.Image) -> Tuple[bool, str, Image.Image]:
        """Initialize the multi-model optimization process.

        Resets all per-run state, rewinds the image generator to its first model
        and initializes that model's optimizer.

        Args:
            target_img: Target image to optimize towards

        Returns:
            Tuple of (is_completed, current_prompt, current_generated_image), as
            returned by the first model's ``PromptOptimizer.initialize()``.

        Raises:
            ValueError: If the image generator has no selected models.
        """
        if self._num_models() == 0:
            raise ValueError("No models selected for optimization")

        self.target_img = target_img
        self.current_model_index = 0
        self.model_results = {}
        self.best_result = None

        # Reset image generator to first model
        self.image_generator.reset_to_first_model()
        self._create_current_optimizer()

        # Initialize first model
        return self.current_optimizer.initialize(target_img)

    def step(self) -> Tuple[bool, str, Image.Image]:
        """Perform one optimization step.

        Advances the active model's optimizer by one iteration. When that model
        finishes, its result is recorded and the next model's optimizer is created
        and initialized, and the return value is what that ``initialize()`` call
        returned. After the last model finishes, the best result across all models
        is returned with ``is_completed=True``. Further calls return that same best
        result without doing more work.

        Returns:
            Tuple of (is_completed, current_prompt, current_generated_image)

        Raises:
            RuntimeError: If ``initialize()`` has not been called yet.
        """
        # __init__ already builds an optimizer, so the target image is what proves
        # initialize() actually ran.
        if self.current_optimizer is None or self.target_img is None:
            raise RuntimeError("Must call initialize() before step()")

        if self._all_models_done() and self.best_result is not None:
            # Already finished: don't step the last model's optimizer past completion.
            return True, self.best_result['prompt'], self.best_result['image']

        # Step the current model optimizer
        is_model_completed, prompt, generated_image = self.current_optimizer.step()
        if not is_model_completed:
            return is_model_completed, prompt, generated_image

        self._record_model_result(prompt, generated_image)

        # Move to next model
        self.current_model_index += 1
        if self._all_models_done():
            # All models completed - return best result
            return True, self.best_result['prompt'], self.best_result['image']

        # _create_current_optimizer() also points the image generator at the new model.
        self._create_current_optimizer()
        # NOTE(review): the next model's initialize() result is returned as-is; if its
        # is_completed flag can be True, callers may treat the whole run as finished.
        return self.current_optimizer.initialize(self.target_img)

    def _record_model_result(self, prompt: str, generated_image: Image.Image) -> None:
        """Store the just-finished model's outcome and update ``best_result`` if better.

        The final prompt/image/similarity are taken from the last history entry so
        they match what the optimizer recorded. ``prompt``/``generated_image`` from
        the step are only a fallback when the history is empty (similarity 0.0).
        """
        model_name = self.get_current_model_name()
        history = self.current_optimizer.history

        if history:
            last_step = history[-1]
            final_prompt = last_step['prompt']
            final_image = last_step['image']
            final_similarity = last_step['similarity']
        else:
            # Fallback to step results if no history (shouldn't happen)
            final_prompt = prompt
            final_image = generated_image
            final_similarity = 0.0

        self.model_results[model_name] = {
            'final_prompt': final_prompt,
            'final_image': final_image,
            'final_similarity': final_similarity,
            'history': history.copy(),
            'iterations': len(history),
        }

        # Strictly greater: on ties the earlier model keeps the best slot.
        if self.best_result is None or final_similarity > self.best_result['similarity']:
            self.best_result = {
                'model_name': model_name,
                'prompt': final_prompt,
                'image': final_image,
                'similarity': final_similarity,
            }

    def get_all_results(self) -> Dict[str, Dict[str, Any]]:
        """Get results from all completed models (shallow copy keyed by model name)."""
        return self.model_results.copy()

    def get_best_result(self) -> Optional[Dict[str, Any]]:
        """Get the best result across all models, or None if no model has finished yet."""
        return self.best_result.copy() if self.best_result else None

    @property
    def history(self):
        """Get history from current optimizer for compatibility (empty list if none)."""
        if self.current_optimizer:
            return self.current_optimizer.history
        return []