# WeavePrompt — image_generators.py
# Uploaded with huggingface_hub by kevin1kevin1k (commit 1282f37, verified)
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple

import fal_client
import requests
from dotenv import load_dotenv
from PIL import Image

from weave_prompt import ImageGenerator
load_dotenv()
# Available fal.ai models for text-to-image generation.
# Maps a human-readable display name (shown to users / passed to
# MultiModelFalImageGenerator) to the fal.ai endpoint ID that
# fal_client.subscribe expects.
AVAILABLE_MODELS: Dict[str, str] = {
# FLUX Models
"FLUX.1 [pro]": "fal-ai/flux-pro",
"FLUX.1 [dev]": "fal-ai/flux/dev",
"FLUX.1 [schnell]": "fal-ai/flux/schnell",
"FLUX.1 with LoRAs": "fal-ai/flux-lora",
# Google Models
"Imagen 4": "fal-ai/imagen4/preview",
"Imagen 4 Ultra": "fal-ai/imagen4/preview/ultra",
"Gemini 2.5 Flash Image": "fal-ai/gemini-25-flash-image",
# Other Models
"Stable Diffusion 3.5 Large": "fal-ai/stable-diffusion-v35-large",
"Qwen Image": "fal-ai/qwen-image"
}
class FalImageGenerator(ImageGenerator):
    """Generate images from text prompts via a single fal.ai model.

    Wraps ``fal_client.subscribe`` for one model endpoint and downloads
    the first image of the response as a PIL ``Image``.
    """

    # Upper bound (seconds) for downloading the generated image so a
    # stalled CDN response cannot hang the caller forever.
    _DOWNLOAD_TIMEOUT = 60

    def __init__(self, model_name: str = "fal-ai/flux-pro"):
        """Args:
            model_name: fal.ai model endpoint ID (e.g. ``"fal-ai/flux-pro"``).
        """
        self.model_name = model_name

    def _on_queue_update(self, update) -> None:
        """Print progress log messages while the generation request runs."""
        if isinstance(update, fal_client.InProgress):
            for log in update.logs:
                print(log["message"])

    def generate(self, prompt: str, **kwargs) -> Image.Image:
        """Generate an image from a text prompt using fal_client.

        Args:
            prompt: Text prompt for the model.
            **kwargs: Forwarded verbatim as extra model arguments
                (e.g. ``image_size``, ``num_inference_steps``).

        Returns:
            The first generated image as a PIL ``Image``.

        Raises:
            ValueError: If the model response contains no images.
            requests.HTTPError: If downloading the image fails.
        """
        result = fal_client.subscribe(
            self.model_name,
            arguments={"prompt": prompt, **kwargs},
            with_logs=True,
            on_queue_update=self._on_queue_update,
        )
        # NOTE(review): removed a leftover debug print(result) that dumped
        # the raw API response to stdout on every call.
        return self._extract_image_from_result(result)

    def _extract_image_from_result(self, result: Dict[str, Any]) -> Image.Image:
        """Download and return the first image referenced in *result*.

        Raises:
            ValueError: If *result* has no non-empty ``"images"`` list.
            requests.HTTPError: If the image download returns a bad status.
        """
        if not result or not result.get('images'):
            raise ValueError("No image found in the result")
        image_url = result['images'][0]['url']
        # timeout added: requests.get without one can block indefinitely.
        response = requests.get(image_url, timeout=self._DOWNLOAD_TIMEOUT)
        response.raise_for_status()  # Raise an exception for bad status codes
        return Image.open(BytesIO(response.content))
class MultiModelFalImageGenerator(ImageGenerator):
    """Generate images by rotating through several fal.ai models.

    Holds one ``FalImageGenerator`` per selected model and tracks which
    model is "current"; callers advance through the sequence explicitly
    with :meth:`switch_to_next_model`.
    """

    def __init__(self, selected_models: Optional[List[str]] = None):
        """Initialize with selected model names.

        Args:
            selected_models: List of model display names from
                AVAILABLE_MODELS keys. Defaults to ``["FLUX.1 [pro]"]``.
                Names not present in AVAILABLE_MODELS are skipped here
                (no generator is built for them); ``generate`` raises
                ValueError if it later lands on such a name.
        """
        if selected_models is None:
            selected_models = ["FLUX.1 [pro]"]  # Default to single model
        self.selected_models = selected_models
        # Index into selected_models of the currently active model.
        self.current_model_index = 0
        # Display name -> FalImageGenerator for each recognized model.
        self.generators: Dict[str, FalImageGenerator] = {}
        for model_name in selected_models:
            model_id = AVAILABLE_MODELS.get(model_name)
            if model_id is not None:
                self.generators[model_name] = FalImageGenerator(model_id)

    def get_current_model_name(self) -> str:
        """Return the display name of the currently active model.

        Falls back to the first model when the index has run past the
        end of the sequence, or ``"Unknown"`` if no models were selected.
        """
        if self.current_model_index < len(self.selected_models):
            return self.selected_models[self.current_model_index]
        return self.selected_models[0] if self.selected_models else "Unknown"

    def switch_to_next_model(self) -> bool:
        """Switch to the next model in the sequence.

        Returns:
            True if switched to next model, False if no more models.
        """
        self.current_model_index += 1
        return self.current_model_index < len(self.selected_models)

    def reset_to_first_model(self) -> None:
        """Reset to the first model in the sequence."""
        self.current_model_index = 0

    def generate(self, prompt: str, **kwargs) -> Image.Image:
        """Generate an image using the current model.

        Raises:
            ValueError: If the current model has no generator (e.g. its
                name was not a key of AVAILABLE_MODELS).
        """
        # Delegate so the lookup/raise logic lives in exactly one place;
        # the error message is identical to generate_with_model's.
        return self.generate_with_model(self.get_current_model_name(), prompt, **kwargs)

    def generate_with_model(self, model_name: str, prompt: str, **kwargs) -> Image.Image:
        """Generate an image using a specific model by display name.

        Raises:
            ValueError: If *model_name* has no generator.
        """
        if model_name not in self.generators:
            raise ValueError(f"Model {model_name} not available")
        return self.generators[model_name].generate(prompt, **kwargs)