Spaces:
Runtime error
Runtime error
| import fal_client | |
| from PIL import Image | |
| from typing import Dict, Any | |
| import requests | |
| from io import BytesIO | |
| from weave_prompt import ImageGenerator | |
| from typing import List, Tuple | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
# Text-to-image models offered through fal.ai.
# Keys are the human-readable display names; values are the fal.ai endpoint
# identifiers handed to fal_client. Insertion order is kept as-is so any
# consumer that lists the keys sees the same ordering.
AVAILABLE_MODELS = {
    # --- FLUX family ---
    "FLUX.1 [pro]": "fal-ai/flux-pro",
    "FLUX.1 [dev]": "fal-ai/flux/dev",
    "FLUX.1 [schnell]": "fal-ai/flux/schnell",
    "FLUX.1 with LoRAs": "fal-ai/flux-lora",
    # --- Google ---
    "Imagen 4": "fal-ai/imagen4/preview",
    "Imagen 4 Ultra": "fal-ai/imagen4/preview/ultra",
    "Gemini 2.5 Flash Image": "fal-ai/gemini-25-flash-image",
    # --- Others ---
    "Stable Diffusion 3.5 Large": "fal-ai/stable-diffusion-v35-large",
    "Qwen Image": "fal-ai/qwen-image",
}
class FalImageGenerator(ImageGenerator):
    """Generate images from text prompts through one fal.ai model endpoint.

    The generator submits the prompt with ``fal_client.subscribe``, prints any
    progress log messages it receives, then downloads the first image listed
    in the result and returns it as a PIL image.
    """

    # Class-level fallback so instances created without running ``__init__``
    # still have a finite download timeout.
    download_timeout: float = 60.0

    def __init__(self, model_name: str = "fal-ai/flux-pro", *, download_timeout: float = 60.0):
        """Store the endpoint to call and the download timeout.

        Args:
            model_name: fal.ai endpoint identifier passed to
                ``fal_client.subscribe`` (e.g. ``"fal-ai/flux-pro"``).
            download_timeout: Seconds to wait for the HTTP download of the
                generated image before ``requests`` raises a timeout error.
        """
        self.model_name = model_name
        self.download_timeout = download_timeout

    def _on_queue_update(self, update):
        """Print log messages carried by in-progress queue updates.

        Updates that are not ``fal_client.InProgress`` are ignored.
        """
        if isinstance(update, fal_client.InProgress):
            # ``logs`` is treated as optional: a None/empty value is skipped
            # instead of raising TypeError while iterating.
            for log in update.logs or []:
                print(log["message"])

    def generate(self, prompt: str, **kwargs) -> Image.Image:
        """Generate an image from a text prompt using fal_client.

        Args:
            prompt: Text prompt sent as the ``"prompt"`` argument.
            **kwargs: Extra entries merged into the request ``arguments``
                dict alongside the prompt.

        Returns:
            The first image of the result, opened with PIL.

        Raises:
            ValueError: If the result contains no images.
            requests.RequestException: If downloading the image fails or
                times out.
        """
        result = fal_client.subscribe(
            self.model_name,
            arguments={
                "prompt": prompt,
                **kwargs,
            },
            with_logs=True,
            on_queue_update=self._on_queue_update,
        )
        print(result)
        return self._extract_image_from_result(result)

    def _extract_image_from_result(self, result: Dict[str, Any]) -> Image.Image:
        """Download and open the first image referenced by a fal_client result.

        Args:
            result: Result mapping; expected to hold an ``"images"`` list whose
                entries each carry a ``"url"`` key.

        Returns:
            The downloaded image opened with ``PIL.Image.open``.

        Raises:
            ValueError: If ``result`` is empty or has no images.
            requests.RequestException: On a bad HTTP status, connection
                failure, or timeout while downloading.
        """
        images = result.get('images') if result else None
        # Covers a missing key, an explicit None, and an empty list alike.
        if not images:
            raise ValueError("No image found in the result")

        image_url = images[0]['url']
        # Without a timeout a stalled connection would block forever.
        response = requests.get(image_url, timeout=self.download_timeout)
        response.raise_for_status()  # Raise an exception for bad status codes
        return Image.open(BytesIO(response.content))
class MultiModelFalImageGenerator(ImageGenerator):
    """Image generator that cycles through several fal.ai models.

    One ``FalImageGenerator`` is kept per selected display name that exists in
    ``AVAILABLE_MODELS``; a cursor (``current_model_index``) tracks which
    selected model ``generate`` uses.
    """

    def __init__(self, selected_models: List[str] = None):
        """Set up per-model generators for the requested display names.

        Args:
            selected_models: Display names (keys of ``AVAILABLE_MODELS``).
                When omitted, only ``"FLUX.1 [pro]"`` is selected. Names not
                present in ``AVAILABLE_MODELS`` get no generator.
        """
        if selected_models is None:
            selected_models = ["FLUX.1 [pro]"]  # single-model default

        self.selected_models = selected_models
        self.current_model_index = 0
        # Map each recognised display name to a generator bound to its endpoint.
        self.generators = {
            display_name: FalImageGenerator(AVAILABLE_MODELS[display_name])
            for display_name in selected_models
            if display_name in AVAILABLE_MODELS
        }

    def get_current_model_name(self) -> str:
        """Return the display name the cursor currently points at.

        If the cursor has moved past the end of the selection, the first
        selected model is returned, or ``"Unknown"`` when nothing is selected.
        """
        models = self.selected_models
        cursor = self.current_model_index
        if cursor >= len(models):
            return models[0] if models else "Unknown"
        return models[cursor]

    def switch_to_next_model(self) -> bool:
        """Advance the cursor by one model.

        Returns:
            True if the cursor still points at a selected model afterwards,
            False once it has moved past the last one.
        """
        self.current_model_index += 1
        still_in_range = self.current_model_index < len(self.selected_models)
        return still_in_range

    def reset_to_first_model(self):
        """Move the cursor back to the first selected model."""
        self.current_model_index = 0

    def generate(self, prompt: str, **kwargs) -> Image.Image:
        """Generate an image with the model the cursor points at.

        Raises:
            ValueError: If no generator exists for the current model name.
        """
        active_name = self.get_current_model_name()
        if active_name not in self.generators:
            raise ValueError(f"Model {active_name} not available")
        delegate = self.generators[active_name]
        return delegate.generate(prompt, **kwargs)

    def generate_with_model(self, model_name: str, prompt: str, **kwargs) -> Image.Image:
        """Generate an image with an explicitly named model.

        Raises:
            ValueError: If no generator exists for ``model_name``.
        """
        if model_name not in self.generators:
            raise ValueError(f"Model {model_name} not available")
        delegate = self.generators[model_name]
        return delegate.generate(prompt, **kwargs)