Spaces:
Runtime error
Runtime error
File size: 6,299 Bytes
c6eb9ce fb2f0a7 c6eb9ce fb2f0a7 c6eb9ce fb2f0a7 c6eb9ce fb2f0a7 c6eb9ce fb2f0a7 c6eb9ce fb2f0a7 c6eb9ce |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union
import PIL.Image as Image
class ImageGenerator(ABC):
    """Interface that every text-to-image backend must implement."""

    @abstractmethod
    def generate(self, prompt: str, **kwargs) -> Image.Image:
        """Render an image for the given text prompt.

        Args:
            prompt: Text description of the image to produce.
            **kwargs: Backend-specific generation options.

        Returns:
            The rendered picture as a PIL ``Image``.
        """
        ...
class ImageSimilarityMetric(ABC):
    """Interface for scoring how closely two images resemble each other."""

    @abstractmethod
    def compute(self, generated_img: Image.Image, target_img: Image.Image) -> float:
        """Score a generated image against a reference image.

        Args:
            generated_img: Candidate image produced by a generator.
            target_img: Reference image the candidate should match.

        Returns:
            A similarity score where larger values indicate a closer match.
        """
        ...
class ImageEvaluator(ABC):
    """Interface for a vision-language component that describes and compares images."""

    @abstractmethod
    def generate_initial_prompt(self, target_img: Image.Image) -> str:
        """Produce a starting text prompt that describes the target image.

        Args:
            target_img: Reference image to describe.

        Returns:
            A prompt intended to reproduce ``target_img``.
        """
        ...

    @abstractmethod
    def analyze_differences(self, generated_img: Image.Image, target_img: Image.Image) -> Dict[str, Any]:
        """Describe how a generated image departs from the target image.

        Args:
            generated_img: Candidate image to inspect.
            target_img: Reference image the candidate should match.

        Returns:
            A mapping of analysis results, for example missing elements or
            stylistic mismatches.
        """
        ...
class PromptRefiner(ABC):
    """Interface for strategies that rewrite a prompt between optimization rounds."""

    @abstractmethod
    def refine_prompt(
        self,
        current_prompt: str,
        analysis: Dict[str, Any],
        similarity_score: float,
    ) -> str:
        """Derive the next prompt from the current one and its evaluation.

        Args:
            current_prompt: Prompt used in this round (PMT_i).
            analysis: Difference report produced by an ``ImageEvaluator``.
            similarity_score: Score this round's image received.

        Returns:
            The prompt to use in the next round (PMT_{i+1}).
        """
        ...
class PromptOptimizer:
    """Main class that orchestrates the prompt optimization process.

    Each call to :meth:`step` generates an image from the current prompt,
    scores it against the target with ``similarity_metric``, asks
    ``evaluator`` for a difference analysis, records the outcome in
    ``history`` and, unless the similarity threshold was reached, asks
    ``refiner`` for the next prompt.
    """

    def __init__(self,
                 image_generator: ImageGenerator,
                 evaluator: ImageEvaluator,
                 refiner: PromptRefiner,
                 similarity_metric: ImageSimilarityMetric,
                 max_iterations: int = 10,
                 similarity_threshold: float = 0.95):
        """Initialize the optimizer.

        Args:
            image_generator: Text-to-image generator to use.
            evaluator: Image evaluator for generating the initial prompt and
                the per-step difference analysis.
            refiner: Prompt refinement strategy.
            similarity_metric: Image similarity metric.
            max_iterations: Maximum number of evaluated optimization steps.
            similarity_threshold: Similarity at or above which a step is
                reported as completed (early stopping).
        """
        # Configuration
        self.image_generator = image_generator
        self.evaluator = evaluator
        self.refiner = refiner
        self.similarity_metric = similarity_metric
        self.max_iterations = max_iterations
        self.similarity_threshold = similarity_threshold

        # Optimization state; populated by initialize().
        self.target_img: Optional[Image.Image] = None
        self.current_prompt: Optional[str] = None
        self.iteration: int = 0

        # Progress tracking: one record per evaluated step, with keys
        # 'iteration', 'prompt', 'similarity', 'analysis' and 'image'.
        self.history: List[Dict[str, Any]] = []

    def initialize(self, target_img: Image.Image) -> tuple[bool, str, Image.Image]:
        """Start a new optimization run toward ``target_img`` and perform its first step.

        Args:
            target_img: Target image to optimize towards.

        Returns:
            Tuple of (is_completed, current_prompt, current_generated_image),
            exactly as returned by :meth:`step`.
        """
        # Derive the initial prompt before touching any state. If the evaluator
        # raises, the optimizer keeps its previous, self-consistent state instead
        # of pairing the new target with the old run's prompt and history.
        initial_prompt = self.evaluator.generate_initial_prompt(target_img)

        self.target_img = target_img
        self.current_prompt = initial_prompt
        self.iteration = 0
        self.history = []
        return self.step()

    def step(self) -> tuple[bool, str, Image.Image]:
        """Perform one optimization step.

        Returns:
            Tuple of (is_completed, current_prompt, current_generated_image):

            - is_completed: True if this step's similarity reached the
              threshold, or if the iteration budget was already exhausted
              when this call was made.
            - current_prompt: The optimizer's prompt after this step. When the
              step did not complete, this is the *refined* prompt that the
              next step will use. The prompt that produced the returned image
              is recorded in ``history[-1]['prompt']``.
            - current_generated_image: The image generated during this call.

        Raises:
            RuntimeError: If :meth:`initialize` has not been called.
        """
        if self.target_img is None or self.current_prompt is None:
            raise RuntimeError("Must call initialize() before step()")

        # Budget exhausted: no more scoring or refinement. The current prompt
        # may be a refinement that was never rendered, so render it for the
        # caller. NOTE: the generator is invoked again on every such call.
        if self.iteration >= self.max_iterations:
            return True, self.current_prompt, self.image_generator.generate(self.current_prompt)

        # Generate an image with the current prompt and evaluate it.
        generated_img = self.image_generator.generate(self.current_prompt)
        similarity = self.similarity_metric.compute(generated_img, self.target_img)
        analysis = self.evaluator.analyze_differences(generated_img, self.target_img)

        # Record this step before the prompt is replaced by its refinement.
        self.history.append({
            'iteration': self.iteration,
            'prompt': self.current_prompt,
            'similarity': similarity,
            'analysis': analysis,
            'image': generated_img
        })

        # Refine only when the target similarity has not been reached.
        is_completed = similarity >= self.similarity_threshold
        if not is_completed:
            self.current_prompt = self.refiner.refine_prompt(
                self.current_prompt, analysis, similarity)

        self.iteration += 1
        return is_completed, self.current_prompt, generated_img