File size: 6,299 Bytes
c6eb9ce
 
 
 
 
fb2f0a7
c6eb9ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb2f0a7
c6eb9ce
 
 
 
 
 
 
 
fb2f0a7
c6eb9ce
 
 
 
 
 
 
fb2f0a7
c6eb9ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb2f0a7
c6eb9ce
fb2f0a7
c6eb9ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166

from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union
import PIL.Image as Image

class ImageGenerator(ABC):
    """Interface that every text-to-image backend has to implement."""

    @abstractmethod
    def generate(self, prompt: str, **kwargs) -> Image.Image:
        """Render an image for the supplied text prompt.

        Args:
            prompt: Text description the image is generated from.
            **kwargs: Extra options specific to the concrete model.

        Returns:
            The generated picture as a PIL image.
        """

class ImageSimilarityMetric(ABC):
    """Interface for scoring how closely two images resemble each other."""

    @abstractmethod
    def compute(self, generated_img: Image.Image, target_img: Image.Image) -> float:
        """Score the resemblance of a generated image to a target image.

        Args:
            generated_img: Image under evaluation.
            target_img: Reference image it is compared with.

        Returns:
            A similarity score; larger values mean the images are more alike.
        """

class ImageEvaluator(ABC):
    """Interface for describing a target image and comparing images against it."""

    @abstractmethod
    def generate_initial_prompt(self, target_img: Image.Image) -> str:
        """Derive a starting prompt from the target image (VLM-backed).

        Args:
            target_img: Image whose content should be described.

        Returns:
            A first prompt that describes the target image.
        """

    @abstractmethod
    def analyze_differences(self, generated_img: Image.Image, target_img: Image.Image) -> Dict[str, Any]:
        """Report how a generated image differs from the target (VLM-backed).

        Args:
            generated_img: Image produced from the current prompt.
            target_img: Reference image it is compared with.

        Returns:
            A dictionary of analysis results, e.g. missing elements or
            style differences.
        """

class PromptRefiner(ABC):
    """Interface for strategies that turn one prompt into an improved one."""

    @abstractmethod
    def refine_prompt(self,
                      current_prompt: str,
                      analysis: Dict[str, Any],
                      similarity_score: float) -> str:
        """Produce the next prompt from the current one and its evaluation.

        Args:
            current_prompt: Prompt used in the current iteration (PMT_i).
            analysis: Difference analysis produced by an ImageEvaluator.
            similarity_score: Similarity score measured for the current prompt.

        Returns:
            The refined prompt for the next iteration (PMT_{i+1}).
        """

class PromptOptimizer:
    """Orchestrate the generate -> score -> analyze -> refine loop.

    Call ``initialize()`` once with a target image (it also runs the first
    step), then call ``step()`` repeatedly until it reports completion.
    Every scored step is appended to ``history``.
    """

    def __init__(self,
                 image_generator: ImageGenerator,
                 evaluator: ImageEvaluator,
                 refiner: PromptRefiner,
                 similarity_metric: ImageSimilarityMetric,
                 max_iterations: int = 10,
                 similarity_threshold: float = 0.95):
        """Store the collaborators and reset the optimization state.

        Args:
            image_generator: Text-to-image generator to use
            evaluator: Image evaluator for generating initial prompt and analysis
            refiner: Prompt refinement strategy
            similarity_metric: Image similarity metric
            max_iterations: Maximum number of optimization iterations
            similarity_threshold: Target similarity threshold for early stopping
        """
        # Configuration
        self.image_generator = image_generator
        self.evaluator = evaluator
        self.refiner = refiner
        self.similarity_metric = similarity_metric
        self.max_iterations = max_iterations
        self.similarity_threshold = similarity_threshold
        # Optimization state (populated by initialize())
        self.target_img: Optional[Image.Image] = None
        self.current_prompt: Optional[str] = None
        self.iteration: int = 0
        # Progress tracking: one dict per scored step
        self.history: List[Dict[str, Any]] = []

    def initialize(self, target_img: Image.Image) -> tuple[bool, str, Image.Image]:
        """Start a new optimization run for ``target_img`` and run the first step.

        The initial prompt is requested from the evaluator *before* any state
        is modified. If that call raises, the optimizer keeps the target,
        prompt, iteration counter and history of the previous run instead of
        ending up with a new target paired with a stale prompt.

        Args:
            target_img: Target image to optimize towards
        Returns:
            Tuple of (is_completed, current_prompt, current_generated_image),
            i.e. the result of the first ``step()``.
        """
        initial_prompt = self.evaluator.generate_initial_prompt(target_img)
        # Commit the new run only once the prompt is available.
        self.target_img = target_img
        self.current_prompt = initial_prompt
        self.iteration = 0
        self.history = []
        return self.step()

    def step(self) -> tuple[bool, str, Image.Image]:
        """Perform one optimization step.

        Returns:
            Tuple of (is_completed, current_prompt, current_generated_image)
            is_completed: True if the similarity threshold was reached or the
                iteration budget is exhausted.
            current_prompt: The prompt held by the optimizer after this step.
                When the step refined the prompt, this is the *refined*
                prompt that the next step will use.
            current_generated_image: The image generated during this step. On
                a refining step it was produced from the prompt used before
                refinement.

        Raises:
            RuntimeError: If ``initialize()`` has not set a target image and
                prompt yet.
        """
        if self.target_img is None or self.current_prompt is None:
            raise RuntimeError("Must call initialize() before step()")

        if self.iteration >= self.max_iterations:
            # Budget exhausted: render the latest prompt one last time. This
            # image is neither scored, analyzed, nor recorded in history.
            final_img = self.image_generator.generate(self.current_prompt)
            return True, self.current_prompt, final_img

        # Generate image with current prompt
        generated_img = self.image_generator.generate(self.current_prompt)
        # Evaluate similarity
        similarity = self.similarity_metric.compute(generated_img, self.target_img)
        # Analyze differences
        analysis = self.evaluator.analyze_differences(generated_img, self.target_img)
        # Track progress
        self.history.append({
            'iteration': self.iteration,
            'prompt': self.current_prompt,
            'similarity': similarity,
            'analysis': analysis,
            'image': generated_img
        })

        # Check if we've reached target similarity
        is_completed = similarity >= self.similarity_threshold
        if not is_completed:
            # Refine prompt; the iteration counter only advances on refinement.
            self.current_prompt = self.refiner.refine_prompt(
                self.current_prompt, analysis, similarity)
            self.iteration += 1
        return is_completed, self.current_prompt, generated_img