# WeavePrompt / weave_prompt.py
# Commit fb2f0a7 (verified) by kevin1kevin1k: "Upload folder using huggingface_hub"
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union
import PIL.Image as Image
class ImageGenerator(ABC):
    """Interface for text-to-image backends.

    A concrete subclass wraps one particular model and turns a text prompt
    into a rendered picture.
    """

    @abstractmethod
    def generate(self, prompt: str, **kwargs) -> Image.Image:
        """Render ``prompt`` into an image.

        Args:
            prompt: Text describing the desired picture.
            **kwargs: Options specific to the underlying model.

        Returns:
            The rendered picture as a PIL ``Image.Image``.
        """
        ...
class ImageSimilarityMetric(ABC):
    """Interface for scoring how closely two images resemble each other."""

    @abstractmethod
    def compute(self, generated_img: Image.Image, target_img: Image.Image) -> float:
        """Score the resemblance of ``generated_img`` to ``target_img``.

        Args:
            generated_img: Candidate picture whose quality is being measured.
            target_img: Reference picture the candidate is measured against.

        Returns:
            A float score; larger values mean the two pictures are closer.
        """
        ...
class ImageEvaluator(ABC):
    """Interface for a vision-language component that describes and compares images."""

    @abstractmethod
    def generate_initial_prompt(self, target_img: Image.Image) -> str:
        """Produce a first text prompt describing ``target_img``.

        Args:
            target_img: Picture to be described.

        Returns:
            A prompt intended as the starting point of optimization.
        """
        ...

    @abstractmethod
    def analyze_differences(self, generated_img: Image.Image, target_img: Image.Image) -> Dict[str, Any]:
        """Describe how ``generated_img`` departs from ``target_img``.

        Args:
            generated_img: Picture produced from the current prompt.
            target_img: Reference picture being approximated.

        Returns:
            A dictionary of findings (for example missing elements or
            stylistic mismatches); its schema is defined by the implementation.
        """
        ...
class PromptRefiner(ABC):
    """Interface for strategies that rewrite a prompt using evaluation feedback."""

    @abstractmethod
    def refine_prompt(self, current_prompt: str, analysis: Dict[str, Any], similarity_score: float) -> str:
        """Derive the next prompt from the current one and its evaluation.

        Args:
            current_prompt: Prompt used in this iteration (PMT_i).
            analysis: Findings returned by ``ImageEvaluator.analyze_differences``.
            similarity_score: Score the current prompt's image achieved.

        Returns:
            The rewritten prompt for the next iteration (PMT_{i+1}).
        """
        ...
class PromptOptimizer:
    """Main class that orchestrates the prompt optimization process.

    Each in-progress step renders the current prompt, scores the result against
    the target, asks the evaluator for a difference analysis and, unless the
    score reaches the threshold, asks the refiner for the next prompt.

    Typical use::

        done, prompt, img = optimizer.initialize(target)
        while not done:
            done, prompt, img = optimizer.step()

    Once a step has reported completion, further ``step()`` calls return that
    same final result without calling the generator, evaluator or refiner again.
    """

    def __init__(self,
                 image_generator: ImageGenerator,
                 evaluator: ImageEvaluator,
                 refiner: PromptRefiner,
                 similarity_metric: ImageSimilarityMetric,
                 max_iterations: int = 10,
                 similarity_threshold: float = 0.95):
        """Initialize the optimizer.

        Args:
            image_generator: Text-to-image generator to use
            evaluator: Image evaluator for generating initial prompt and analysis
            refiner: Prompt refinement strategy
            similarity_metric: Image similarity metric
            max_iterations: Number of scored steps allowed; the call after the
                last one renders the final prompt and reports completion
            similarity_threshold: Score at or above which optimization stops
                early (compared with ``>=``)
        """
        # Configuration
        self.image_generator = image_generator
        self.evaluator = evaluator
        self.refiner = refiner
        self.similarity_metric = similarity_metric
        self.max_iterations = max_iterations
        self.similarity_threshold = similarity_threshold
        # Optimization state
        self.target_img: Optional[Image.Image] = None
        self.current_prompt: Optional[str] = None
        self.iteration: int = 0
        # (prompt, image) reported when the run completed; replayed by later
        # step() calls so completion is terminal. None while still running.
        self._final_result: Optional[tuple] = None
        # Progress tracking
        self.history: List[Dict[str, Any]] = []

    def initialize(self, target_img: Image.Image) -> tuple[bool, str, Image.Image]:
        """Initialize the optimization process with a target image.

        Resets all progress (iteration counter, history, cached final result),
        asks the evaluator for an initial prompt, then runs the first step.

        Args:
            target_img: Target image to optimize towards

        Returns:
            Tuple of (is_completed, current_prompt, current_generated_image),
            as returned by ``step()``.
        """
        self.target_img = target_img
        self.current_prompt = self.evaluator.generate_initial_prompt(target_img)
        self.iteration = 0
        self.history = []
        self._final_result = None
        return self.step()

    def step(self) -> tuple[bool, str, Image.Image]:
        """Perform one optimization step.

        Returns:
            Tuple of (is_completed, current_prompt, current_generated_image)
            is_completed: True if this step reached the similarity threshold,
                if it was called after ``max_iterations`` scored steps, or if
                the run had already completed.
            current_prompt: ``self.current_prompt`` after this step. While the
                run is in progress this is the *refined* prompt for the next
                step, not the prompt that produced the returned image.
            current_generated_image: The image generated during this step, or
                the cached final image once the run has completed.

        Raises:
            RuntimeError: If ``initialize()`` has not been called.
        """
        if self.target_img is None or self.current_prompt is None:
            raise RuntimeError("Must call initialize() before step()")

        # Completion is terminal: replay the final result instead of calling
        # the (costly, possibly non-deterministic) generator again.
        if self._final_result is not None:
            final_prompt, final_img = self._final_result
            return True, final_prompt, final_img

        if self.iteration >= self.max_iterations:
            # Budget exhausted: render the last refined prompt exactly once.
            final_img = self.image_generator.generate(self.current_prompt)
            self._final_result = (self.current_prompt, final_img)
            return True, self.current_prompt, final_img

        # Generate image with current prompt
        generated_img = self.image_generator.generate(self.current_prompt)
        # Evaluate similarity
        similarity = self.similarity_metric.compute(generated_img, self.target_img)
        # Analyze differences
        analysis = self.evaluator.analyze_differences(generated_img, self.target_img)
        # Track progress
        self.history.append({
            'iteration': self.iteration,
            'prompt': self.current_prompt,
            'similarity': similarity,
            'analysis': analysis,
            'image': generated_img
        })

        # Check if we've reached target similarity
        is_completed = similarity >= self.similarity_threshold
        if is_completed:
            self._final_result = (self.current_prompt, generated_img)
        else:
            # Refine prompt for the next step
            self.current_prompt = self.refiner.refine_prompt(
                self.current_prompt, analysis, similarity)
        self.iteration += 1
        return is_completed, self.current_prompt, generated_img