"""Interfaces and driver for iteratively optimizing a text-to-image prompt
so that the generated image matches a target image."""

from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union

import PIL.Image as Image


class ImageGenerator(ABC):
    """Abstract base class for text-to-image models."""

    @abstractmethod
    def generate(self, prompt: str, **kwargs) -> Image.Image:
        """Generate an image from a text prompt.

        Args:
            prompt: The text prompt to generate from
            **kwargs: Additional model-specific parameters

        Returns:
            A PIL Image object
        """


class ImageSimilarityMetric(ABC):
    """Abstract base class for image similarity metrics."""

    @abstractmethod
    def compute(self, generated_img: Image.Image, target_img: Image.Image) -> float:
        """Compute similarity score between generated and target images.

        Args:
            generated_img: The generated image to evaluate
            target_img: The target image to compare against

        Returns:
            Similarity score (higher means more similar)
        """


class ImageEvaluator(ABC):
    """Abstract base class for VLM-based image analysis.

    Produces the initial prompt for a target image and describes how a
    generated image differs from that target.
    """

    @abstractmethod
    def generate_initial_prompt(self, target_img: Image.Image) -> str:
        """Generate initial prompt from target image using VLM.

        Args:
            target_img: The target image to analyze

        Returns:
            Initial prompt describing the target image
        """

    @abstractmethod
    def analyze_differences(self, generated_img: Image.Image,
                            target_img: Image.Image) -> Dict[str, Any]:
        """Analyze differences between generated and target images using VLM.

        Args:
            generated_img: The generated image to analyze
            target_img: The target image to compare against

        Returns:
            Dictionary containing analysis results (e.g. missing elements,
            style differences)
        """


class PromptRefiner(ABC):
    """Abstract base class for prompt refinement strategies."""

    @abstractmethod
    def refine_prompt(self, current_prompt: str, analysis: Dict[str, Any],
                      similarity_score: float) -> str:
        """Refine the current prompt based on image analysis.

        Args:
            current_prompt: The current prompt PMT_i
            analysis: Analysis results from ImageEvaluator
            similarity_score: Current similarity score

        Returns:
            Refined prompt PMT_{i+1}
        """


class PromptOptimizer:
    """Main class that orchestrates the prompt optimization process.

    Each call to :meth:`step` generates an image from the current prompt,
    scores it against the target, records the result in ``history`` and,
    unless the run is finished, asks the refiner for the next prompt.
    """

    def __init__(self,
                 image_generator: ImageGenerator,
                 evaluator: ImageEvaluator,
                 refiner: PromptRefiner,
                 similarity_metric: ImageSimilarityMetric,
                 max_iterations: int = 10,
                 similarity_threshold: float = 0.95):
        """Initialize the optimizer.

        Args:
            image_generator: Text-to-image generator to use
            evaluator: Image evaluator for generating initial prompt and analysis
            refiner: Prompt refinement strategy
            similarity_metric: Image similarity metric
            max_iterations: Maximum number of evaluated optimization steps
            similarity_threshold: Target similarity threshold for early stopping
        """
        # Configuration
        self.image_generator = image_generator
        self.evaluator = evaluator
        self.refiner = refiner
        self.similarity_metric = similarity_metric
        self.max_iterations = max_iterations
        self.similarity_threshold = similarity_threshold

        # Optimization state
        self.target_img: Optional[Image.Image] = None
        self.current_prompt: Optional[str] = None
        self.iteration: int = 0
        # Set once a step reports completion; later step() calls replay the
        # final result instead of doing more (costly) generation work.
        self._completed: bool = False

        # Progress tracking: one entry per evaluated step.
        self.history: List[Dict[str, Any]] = []

    def initialize(self, target_img: Image.Image) -> tuple[bool, str, Image.Image]:
        """Initialize the optimization process with a target image.

        Resets all optimization state, derives the initial prompt from the
        target via the evaluator, and runs the first step.

        Args:
            target_img: Target image to optimize towards

        Returns:
            Tuple of (is_completed, prompt, generated_image), as for :meth:`step`
        """
        self.target_img = target_img
        self.current_prompt = self.evaluator.generate_initial_prompt(target_img)
        self.iteration = 0
        self.history = []
        self._completed = False
        return self.step()

    def step(self) -> tuple[bool, str, Image.Image]:
        """Perform one optimization step.

        Returns:
            Tuple of (is_completed, prompt, generated_image)
            is_completed: True if optimization is complete (similarity reached
                the threshold or ``max_iterations`` steps have been evaluated)
            prompt: The prompt that produced ``generated_image``
            generated_image: The image generated from ``prompt``

        Raises:
            RuntimeError: If called before :meth:`initialize`.
        """
        if self.target_img is None or self.current_prompt is None:
            raise RuntimeError("Must call initialize() before step()")

        if self._completed or self.iteration >= self.max_iterations:
            self._completed = True
            if self.history:
                # Replay the final evaluated result rather than paying for
                # another generation that would never be scored.
                last = self.history[-1]
                return True, last['prompt'], last['image']
            # No evaluated step exists (max_iterations <= 0): return an
            # unscored image for the initial prompt.
            return True, self.current_prompt, self.image_generator.generate(self.current_prompt)

        # Keep the prompt used for this generation so the returned
        # (prompt, image) pair stays consistent even after refinement.
        prompt = self.current_prompt

        # Generate image with current prompt
        generated_img = self.image_generator.generate(prompt)

        # Evaluate similarity
        similarity = self.similarity_metric.compute(generated_img, self.target_img)

        # Analyze differences
        analysis = self.evaluator.analyze_differences(generated_img, self.target_img)

        # Track progress
        self.history.append({
            'iteration': self.iteration,
            'prompt': prompt,
            'similarity': similarity,
            'analysis': analysis,
            'image': generated_img,
        })

        # Completion is decided within this step, so the last image the caller
        # receives is one that was actually scored and recorded.
        reached_threshold = similarity >= self.similarity_threshold
        out_of_budget = self.iteration + 1 >= self.max_iterations
        is_completed = reached_threshold or out_of_budget

        if not is_completed:
            # Only refine when the refined prompt will actually be evaluated.
            self.current_prompt = self.refiner.refine_prompt(prompt, analysis, similarity)

        self._completed = is_completed
        self.iteration += 1
        return is_completed, prompt, generated_img