from typing import Any, Dict, Optional, Tuple

from PIL import Image

from weave_prompt import PromptOptimizer, ImageEvaluator, PromptRefiner, ImageSimilarityMetric
from image_generators import MultiModelFalImageGenerator


class MultiModelPromptOptimizer:
    """Sequential multi-model prompt optimizer that finds the best model-prompt combination."""

    def __init__(self,
                 image_generator: MultiModelFalImageGenerator,
                 evaluator: ImageEvaluator,
                 refiner: PromptRefiner,
                 similarity_metric: ImageSimilarityMetric,
                 max_iterations: int = 10,
                 similarity_threshold: float = 0.95):
        """Initialize the multi-model optimizer.

        Args:
            image_generator: Multi-model image generator
            evaluator: Image evaluator for generating the initial prompt and analysis
            refiner: Prompt refinement strategy
            similarity_metric: Image similarity metric
            max_iterations: Maximum number of optimization iterations per model
            similarity_threshold: Target similarity threshold for early stopping
        """
        self.image_generator = image_generator
        self.evaluator = evaluator
        self.refiner = refiner
        self.similarity_metric = similarity_metric
        self.max_iterations = max_iterations
        self.similarity_threshold = similarity_threshold

        # Multi-model state
        self.target_img = None
        self.current_model_index = 0
        self.model_results = {}  # Results per model
        self.current_optimizer = None
        self.best_result = None

        # Create the optimizer for the first model
        self._create_current_optimizer()

    def _create_current_optimizer(self):
        """Create an optimizer for the current model."""
        if self.current_model_index < len(self.image_generator.selected_models):
            # Set the image generator to the current model
            self.image_generator.current_model_index = self.current_model_index

            # Create an individual optimizer for the current model
            self.current_optimizer = PromptOptimizer(
                image_generator=self.image_generator,
                evaluator=self.evaluator,
                refiner=self.refiner,
                similarity_metric=self.similarity_metric,
                max_iterations=self.max_iterations,
                similarity_threshold=self.similarity_threshold
            )

    def get_current_model_name(self) -> str:
        """Get the name of the currently active model."""
        # Ensure the image generator index is synchronized
        self.image_generator.current_model_index = self.current_model_index
        return self.image_generator.get_current_model_name()

    def get_progress_info(self) -> Dict[str, Any]:
        """Get current progress information."""
        total_models = len(self.image_generator.selected_models)

        info = {
            'current_model_index': self.current_model_index,
            'current_model_name': self.get_current_model_name(),
            'total_models': total_models,
            'models_completed': self.current_model_index,
            'overall_progress': self.current_model_index / total_models if total_models > 0 else 0,
            'is_last_model': self.current_model_index >= total_models - 1
        }

        if self.current_optimizer:
            info['current_iteration'] = len(self.current_optimizer.history)
            info['max_iterations'] = self.max_iterations
            info['model_progress'] = len(self.current_optimizer.history) / self.max_iterations

        return info

    def initialize(self, target_img: Image.Image) -> Tuple[bool, str, Image.Image]:
        """Initialize the multi-model optimization process.

        Args:
            target_img: Target image to optimize towards

        Returns:
            Tuple of (is_completed, current_prompt, current_generated_image)
        """
        self.target_img = target_img
        self.current_model_index = 0
        self.model_results = {}
        self.best_result = None

        # Reset the image generator to the first model
        self.image_generator.reset_to_first_model()
        self._create_current_optimizer()

        # Initialize the first model
        return self.current_optimizer.initialize(target_img)

    def step(self) -> Tuple[bool, str, Image.Image]:
        """Perform one optimization step.

        Returns:
            Tuple of (is_completed, current_prompt, current_generated_image)
        """
        if not self.current_optimizer:
            raise RuntimeError("Must call initialize() before step()")

        # Step the current model's optimizer
        is_model_completed, prompt, generated_image = self.current_optimizer.step()

        if is_model_completed:
            # Store results for the current model - use data from history to ensure consistency
            model_name = self.get_current_model_name()
            if len(self.current_optimizer.history) > 0:
                # Use the last step from history as the final result (ensures consistency)
                last_step = self.current_optimizer.history[-1]
                final_prompt = last_step['prompt']
                final_image = last_step['image']
                final_similarity = last_step['similarity']
            else:
                # Fall back to the step results if there is no history (shouldn't happen)
                final_prompt = prompt
                final_image = generated_image
                final_similarity = 0.0

            self.model_results[model_name] = {
                'final_prompt': final_prompt,
                'final_image': final_image,
                'final_similarity': final_similarity,
                'history': self.current_optimizer.history.copy(),
                'iterations': len(self.current_optimizer.history)
            }

            # Update the best result if this model did better
            if self.best_result is None or final_similarity > self.best_result['similarity']:
                self.best_result = {
                    'model_name': model_name,
                    'prompt': final_prompt,
                    'image': final_image,
                    'similarity': final_similarity
                }

            # Move to the next model
            self.current_model_index += 1
            if self.current_model_index < len(self.image_generator.selected_models):
                # Initialize the next model - ensure both indices are synchronized
                self.image_generator.current_model_index = self.current_model_index
                self._create_current_optimizer()
                return self.current_optimizer.initialize(self.target_img)
            else:
                # All models completed - return the best result
                return True, self.best_result['prompt'], self.best_result['image']

        return is_model_completed, prompt, generated_image

    def get_all_results(self) -> Dict[str, Dict[str, Any]]:
        """Get results from all completed models."""
        return self.model_results.copy()

    def get_best_result(self) -> Optional[Dict[str, Any]]:
        """Get the best result across all models, or None if no model has completed yet."""
        return self.best_result.copy() if self.best_result else None

    def history(self):
        """Get the history from the current optimizer, for compatibility with PromptOptimizer."""
        if self.current_optimizer:
            return self.current_optimizer.history
        return []
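

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal driver loop, assuming the generator, evaluator, refiner, and
# similarity metric have been constructed elsewhere; their constructors live in
# weave_prompt / image_generators and are not shown here. The helper below is
# hypothetical and relies only on the MultiModelPromptOptimizer API defined above.
def _example_optimization_loop(optimizer: "MultiModelPromptOptimizer",
                               target_img: Image.Image) -> Optional[Dict[str, Any]]:
    """Run the sequential multi-model optimization to completion (sketch)."""
    done, prompt, image = optimizer.initialize(target_img)
    while not done:
        done, prompt, image = optimizer.step()
        # Progress covers both the active model and the overall model sequence
        progress = optimizer.get_progress_info()
        print(f"[{progress['current_model_name']}] "
              f"iteration {progress.get('current_iteration', 0)}/{progress.get('max_iterations', '?')}, "
              f"overall {progress['overall_progress']:.0%}")
    # Once every selected model has finished, the best model/prompt/image combination is available
    return optimizer.get_best_result()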