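"""Sequential multi-model prompt optimization.

Runs the single-model PromptOptimizer against each model selected in a
MultiModelFalImageGenerator, one model at a time, and keeps the best
model/prompt combination found across all of them.
"""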
from typing import Any, Dict, Optional, Tuple
from PIL import Image
from weave_prompt import PromptOptimizer, ImageEvaluator, PromptRefiner, ImageSimilarityMetric
from image_generators import MultiModelFalImageGenerator


class MultiModelPromptOptimizer:
    """Sequential multi-model prompt optimizer that finds the best model-prompt combination."""
    
    def __init__(self,
                 image_generator: MultiModelFalImageGenerator,
                 evaluator: ImageEvaluator,
                 refiner: PromptRefiner,
                 similarity_metric: ImageSimilarityMetric,
                 max_iterations: int = 10,
                 similarity_threshold: float = 0.95):
        """Initialize the multi-model optimizer.
        
        Args:
            image_generator: Multi-model image generator
            evaluator: Image evaluator for generating initial prompt and analysis
            refiner: Prompt refinement strategy
            similarity_metric: Image similarity metric
            max_iterations: Maximum number of optimization iterations per model
            similarity_threshold: Target similarity threshold for early stopping
        """
        self.image_generator = image_generator
        self.evaluator = evaluator
        self.refiner = refiner
        self.similarity_metric = similarity_metric
        self.max_iterations = max_iterations
        self.similarity_threshold = similarity_threshold
        
        # Multi-model state
        self.target_img = None
        self.current_model_index = 0
        self.model_results = {}  # Results per model
        self.current_optimizer = None
        self.best_result = None
        
        # Initialize individual optimizers for each model
        self._create_current_optimizer()
    
    def _create_current_optimizer(self):
        """Create optimizer for the current model."""
        if self.current_model_index < len(self.image_generator.selected_models):
            # Set the image generator to current model
            self.image_generator.current_model_index = self.current_model_index
            
            # Create individual optimizer for current model
            self.current_optimizer = PromptOptimizer(
                image_generator=self.image_generator,
                evaluator=self.evaluator,
                refiner=self.refiner,
                similarity_metric=self.similarity_metric,
                max_iterations=self.max_iterations,
                similarity_threshold=self.similarity_threshold
            )
    
    def get_current_model_name(self) -> str:
        """Get the name of the currently active model."""
        # Ensure the image generator index is synchronized
        self.image_generator.current_model_index = self.current_model_index
        return self.image_generator.get_current_model_name()
    
    def get_progress_info(self) -> Dict[str, Any]:
        """Get current progress information."""
        total_models = len(self.image_generator.selected_models)
        
        info = {
            'current_model_index': self.current_model_index,
            'current_model_name': self.get_current_model_name(),
            'total_models': total_models,
            'models_completed': self.current_model_index,
            'overall_progress': self.current_model_index / total_models if total_models > 0 else 0,
            'is_last_model': self.current_model_index >= total_models - 1
        }
        
        if self.current_optimizer:
            info['current_iteration'] = len(self.current_optimizer.history)
            info['max_iterations'] = self.max_iterations
            info['model_progress'] = len(self.current_optimizer.history) / self.max_iterations
        
        return info
    
    def initialize(self, target_img: Image.Image) -> Tuple[bool, str, Image.Image]:
        """Initialize the multi-model optimization process.
        
        Args:
            target_img: Target image to optimize towards
        Returns:
            Tuple of (is_completed, current_prompt, current_generated_image)
        """
        self.target_img = target_img
        self.current_model_index = 0
        self.model_results = {}
        self.best_result = None
        
        # Reset image generator to first model
        self.image_generator.reset_to_first_model()
        self._create_current_optimizer()
        
        # Initialize first model
        return self.current_optimizer.initialize(target_img)
    
    def step(self) -> Tuple[bool, str, Image.Image]:
        """Perform one optimization step.
        
        Returns:
            Tuple of (is_completed, current_prompt, current_generated_image)
        """
        if not self.current_optimizer:
            raise RuntimeError("Must call initialize() before step()")
        
        # Step the current model optimizer
        is_model_completed, prompt, generated_image = self.current_optimizer.step()
        
        if is_model_completed:
            # Store results for current model - use data from history to ensure consistency
            model_name = self.get_current_model_name()
            
            if len(self.current_optimizer.history) > 0:
                # Use the last step from history as the final result (ensures consistency)
                last_step = self.current_optimizer.history[-1]
                final_prompt = last_step['prompt']
                final_image = last_step['image']
                final_similarity = last_step['similarity']
            else:
                # Fallback to step results if no history (shouldn't happen)
                final_prompt = prompt
                final_image = generated_image
                final_similarity = 0.0
            
            self.model_results[model_name] = {
                'final_prompt': final_prompt,
                'final_image': final_image,
                'final_similarity': final_similarity,
                'history': self.current_optimizer.history.copy(),
                'iterations': len(self.current_optimizer.history)
            }
            
            # Update best result if this is better
            if self.best_result is None or final_similarity > self.best_result['similarity']:
                self.best_result = {
                    'model_name': model_name,
                    'prompt': final_prompt,
                    'image': final_image,
                    'similarity': final_similarity
                }
            
            # Move to next model
            self.current_model_index += 1
            
            if self.current_model_index < len(self.image_generator.selected_models):
                # Initialize next model - ensure both indices are synchronized
                self.image_generator.current_model_index = self.current_model_index
                self._create_current_optimizer()
                return self.current_optimizer.initialize(self.target_img)
            else:
                # All models completed - return best result
                return True, self.best_result['prompt'], self.best_result['image']
        
        return is_model_completed, prompt, generated_image
    
    def get_all_results(self) -> Dict[str, Dict[str, Any]]:
        """Get results from all completed models."""
        return self.model_results.copy()
    
    def get_best_result(self) -> Optional[Dict[str, Any]]:
        """Get the best result across all models."""
        return self.best_result.copy() if self.best_result else None
    
    @property
    def history(self):
        """Get history from current optimizer for compatibility."""
        if self.current_optimizer:
            return self.current_optimizer.history
        return []
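

# --- Usage sketch (not part of the original module) ---
# Minimal driver showing how MultiModelPromptOptimizer is intended to be
# stepped. It relies only on the interface defined above (initialize, step,
# get_progress_info, get_best_result); construction of the generator,
# evaluator, refiner, and similarity metric is omitted because their
# constructor signatures live in weave_prompt / image_generators and are
# not assumed here. `run_optimization` is a hypothetical helper.
def run_optimization(optimizer: MultiModelPromptOptimizer,
                     target_img: Image.Image) -> Optional[Dict[str, Any]]:
    """Drive the optimizer to completion and return the best result found."""
    done, prompt, image = optimizer.initialize(target_img)
    while not done:
        # Report which model is active and how far along it is.
        info = optimizer.get_progress_info()
        print(f"Model {info['current_model_name']} "
              f"({info['models_completed'] + 1}/{info['total_models']}), "
              f"iteration {info.get('current_iteration', 0)}/{optimizer.max_iterations}")
        done, prompt, image = optimizer.step()
    # After all models have run, this is the best model/prompt combination.
    return optimizer.get_best_result()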