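"""Prepare an image/caption dataset for DreamBooth fine-tuning.

Streams samples from the LAION-2B-en-aesthetic dataset, downloads and
preprocesses the images, saves image/caption pairs (plus optional metadata),
and prints a suggested training command. A small synthetic demo dataset is
created as a fallback if the LAION pipeline fails.
"""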
import os
import numpy as np
from datasets import load_dataset
from PIL import Image, ImageOps, ImageFilter
from tqdm import tqdm
import random
import requests
import io
import time

def download_image(url, timeout=10, retries=2):
    """Download image from URL with retry mechanism"""
    for attempt in range(retries):
        try:
            headers = {
                'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
            }
            response = requests.get(url, timeout=timeout, headers=headers)
            if response.status_code == 200:
                image = Image.open(io.BytesIO(response.content))
                return image
            else:
                return None
        except Exception as e:
            if attempt == retries - 1:  # Last attempt
                print(f"Failed to download {url}: {e}")
                return None
            time.sleep(0.5)  # Brief pause before retry
    return None

def preprocess_image(image, target_size=512, quality_threshold=0.7):
    """Preprocess image with various enhancements"""
    if image is None:
        return None
    try:
        # Convert to RGB if needed
        if image.mode != 'RGB':
            image = image.convert('RGB')
        # Filter out low-quality (too small) images
        width, height = image.size
        if min(width, height) < target_size * quality_threshold:
            return None
        # Center crop to square if not already
        if width != height:
            size = min(width, height)
            left = (width - size) // 2
            top = (height - size) // 2
            image = image.crop((left, top, left + size, top + size))
        # Resize to target size
        image = image.resize((target_size, target_size), Image.Resampling.LANCZOS)
        # Enhance image quality: slightly sharpen, then auto-adjust levels
        image = image.filter(ImageFilter.UnsharpMask(radius=0.5, percent=120, threshold=3))
        image = ImageOps.autocontrast(image, cutoff=1)
        return image
    except Exception as e:
        print(f"Error preprocessing image: {e}")
        return None

def clean_prompt(prompt):
    """Clean and normalize prompts"""
    if not prompt:
        return None
    # Collapse all whitespace (including non-breaking spaces) to single spaces
    prompt = ' '.join(prompt.split())
    # Strip leading/trailing punctuation artifacts
    prompt = prompt.strip(' .,;:')
    # Filter out very short or very long prompts
    words = prompt.split()
    if len(words) < 3 or len(words) > 50:
        return None
    return prompt

def prepare_dreambooth_data():
    # Load dataset
    print("Loading LAION dataset...")
    dataset = load_dataset("laion/laion2B-en-aesthetic", split="train", streaming=True)
    # Create directory structure
    data_dir = "./laion_dataset"
    os.makedirs(data_dir, exist_ok=True)
    valid_samples = 0
    processed_count = 0
    max_samples = 1000  # Limit total samples to process
    print(f"Starting to process up to {max_samples} samples...")
    # Process images with preprocessing
    for idx, sample in enumerate(tqdm(dataset, desc="Processing LAION samples")):
        if processed_count >= max_samples:
            break
        processed_count += 1
        try:
            # Get URL and text from LAION format
            image_url = sample.get('URL', '')
            text_prompt = sample.get('TEXT', '')
            if not image_url or not text_prompt:
                continue
            # Clean prompt first
            prompt = clean_prompt(text_prompt)
            if prompt is None:
                continue
            # Download image from URL
            print(f"Downloading image {valid_samples + 1}: {image_url[:50]}...")
            image = download_image(image_url)
            if image is None:
                continue
            # Preprocess downloaded image
            processed_image = preprocess_image(image)
            if processed_image is None:
                continue
            # Save processed image
            image_path = os.path.join(data_dir, f"image_{valid_samples:04d}.jpg")
            processed_image.save(image_path, "JPEG", quality=95, optimize=True)
            # Save cleaned caption
            caption_path = os.path.join(data_dir, f"image_{valid_samples:04d}.txt")
            with open(caption_path, 'w', encoding='utf-8') as f:
                f.write(prompt)
            # Optional: Add metadata file
            metadata_path = os.path.join(data_dir, f"image_{valid_samples:04d}_meta.txt")
            with open(metadata_path, 'w', encoding='utf-8') as f:
                f.write(f"URL: {image_url}\n")
                f.write(f"Aesthetic: {sample.get('aesthetic', 'N/A')}\n")
                f.write(f"Width: {sample.get('WIDTH', 'N/A')}\n")
                f.write(f"Height: {sample.get('HEIGHT', 'N/A')}\n")
            valid_samples += 1
            # Stop if we have enough samples
            if valid_samples >= 100:  # Adjust this number as needed
                break
        except Exception as e:
            print(f"Error processing sample {idx}: {e}")
            continue
    print(f"Processed {processed_count} samples, saved {valid_samples} valid images to {data_dir}")
    return data_dir

def create_demo_dataset():
    """Create demo dataset as last resort"""
    print("Creating demo dataset...")
    data_dir = "./demo_dataset"
    os.makedirs(data_dir, exist_ok=True)
    demo_prompts = [
        "a beautiful landscape with mountains",
        "portrait of a person with detailed features",
        "abstract colorful digital artwork",
        "modern architecture building design",
        "natural forest scene with trees",
        "urban cityscape at sunset",
        "artistic oil painting style",
        "vintage photography aesthetic",
        "minimalist geometric composition",
        "vibrant surreal art piece"
    ]
    for idx, prompt in enumerate(demo_prompts):
        # Create a vertical gradient background between two random colors
        color1 = (random.randint(50, 200), random.randint(50, 200), random.randint(50, 200))
        color2 = (random.randint(100, 255), random.randint(100, 255), random.randint(100, 255))
        gradient = np.linspace(color1, color2, 512).astype(np.uint8)  # (512, 3)
        pixels = np.repeat(gradient[:, np.newaxis, :], 512, axis=1)   # (512, 512, 3)
        image = Image.fromarray(pixels)
        # Save files
        image_path = os.path.join(data_dir, f"image_{idx:04d}.jpg")
        image.save(image_path, "JPEG", quality=95)
        caption_path = os.path.join(data_dir, f"image_{idx:04d}.txt")
        with open(caption_path, 'w', encoding='utf-8') as f:
            f.write(prompt)
    print(f"Created {len(demo_prompts)} demo samples")
    return data_dir

# Main execution with fallback
def main():
    # Try the LAION pipeline first; fall back to the demo dataset if it fails
    try:
        data_dir = prepare_dreambooth_data()
        # Also fall back if no images were actually saved (e.g. all downloads failed)
        if not any(name.endswith('.jpg') for name in os.listdir(data_dir)):
            raise RuntimeError("no valid images downloaded")
    except Exception as e:
        print(f"LAION preparation failed ({e}), falling back to demo dataset")
        data_dir = create_demo_dataset()
    # Generate training command
    training_command = f"""
accelerate launch \\
  --deepspeed_config_file ds_config.json \\
  diffusers/examples/dreambooth/train_dreambooth.py \\
  --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \\
  --instance_data_dir="{data_dir}" \\
  --instance_prompt="a high quality image" \\
  --output_dir="./laion-model" \\
  --resolution=512 \\
  --train_batch_size=1 \\
  --gradient_accumulation_steps=1 \\
  --gradient_checkpointing \\
  --learning_rate=5e-6 \\
  --lr_scheduler="constant" \\
  --lr_warmup_steps=0 \\
  --max_train_steps=400 \\
  --mixed_precision="fp16" \\
  --checkpointing_steps=100 \\
  --checkpoints_total_limit=1 \\
  --report_to="tensorboard" \\
  --logging_dir="./laion-model/logs"
"""
    print(f"\n✅ Dataset prepared in: {data_dir}")
    print("🚀 Run this command to train:")
    print(training_command)

if __name__ == "__main__":
    main()
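# Note: the training command above passes --deepspeed_config_file ds_config.json,
# but that file is not created by this script. A minimal DeepSpeed config along
# these lines could work (an assumption -- adjust to your hardware and setup):
#
#   {
#     "train_micro_batch_size_per_gpu": 1,
#     "gradient_accumulation_steps": 1,
#     "fp16": {"enabled": true},
#     "zero_optimization": {"stage": 2}
#   }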