import time

import cv2
import numpy as np
import streamlit as st
import torch
from diffusers import (
    ControlNetModel,
    StableDiffusionControlNetPipeline,
    UniPCMultistepScheduler,
)
from PIL import Image

# App title and config
st.set_page_config(
    page_title="AI Image Generator with ControlNet",
    page_icon="🎨",
    layout="wide",
    initial_sidebar_state="expanded",
)
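# Suggested setup (an assumption; the original app does not pin dependencies):
#     pip install streamlit diffusers transformers accelerate torch opencv-python numpy Pillow
# A CUDA-capable GPU is required for the fp16 pipelines below. Launch with:
#     streamlit run app.py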
", unsafe_allow_html=True) st.markdown("Generate stunning images guided by Stable Diffusion and ControlNet. Upload a reference image or use edge detection to control the output.") # Sidebar for controls with st.sidebar: st.image("https://huggingface.co/front/assets/huggingface_logo-noborder.svg", width=200) st.markdown("### Configuration") # Model selection model_choice = st.selectbox( "Select ControlNet Type", ("Canny Edge", "Depth Map", "OpenPose (Human Pose)"), index=0 ) # Parameters prompt = st.text_area("Prompt", "a beautiful landscape with mountains and lake, highly detailed, digital art") negative_prompt = st.text_area("Negative Prompt", "blurry, low quality, distorted") num_images = st.slider("Number of images to generate", 1, 4, 1) steps = st.slider("Number of inference steps", 20, 100, 50) guidance_scale = st.slider("Guidance scale", 1.0, 20.0, 7.5) seed = st.number_input("Seed", value=42, min_value=0, max_value=1000000) # Upload control image uploaded_file = st.file_uploader("Upload control image", type=["jpg", "png", "jpeg"]) # Advanced options with st.expander("Advanced Options"): strength = st.slider("Control strength", 0.1, 2.0, 1.0) low_threshold = st.slider("Canny low threshold", 1, 255, 100) high_threshold = st.slider("Canny high threshold", 1, 255, 200) # Initialize models (cached) @st.cache_resource def load_models(model_type): if model_type == "Canny Edge": controlnet = ControlNetModel.from_pretrained( "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16 ) elif model_type == "Depth Map": controlnet = ControlNetModel.from_pretrained( "lllyasviel/sd-controlnet-depth", torch_dtype=torch.float16 ) else: # OpenPose controlnet = ControlNetModel.from_pretrained( "lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16 ) pipe = StableDiffusionControlNetPipeline.from_pretrained( "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16, safety_checker=None ).to("cuda") pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) pipe.enable_model_cpu_offload() return pipe # Process control image based on model type def process_control_image(image, model_type): image = np.array(image) if model_type == "Canny Edge": image = cv2.Canny(image, low_threshold, high_threshold) image = image[:, :, None] image = np.concatenate([image, image, image], axis=2) elif model_type == "Depth Map": # Using MiDaS for depth estimation - would need additional imports # This is simplified for demo purposes image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) image = np.stack([image]*3, axis=-1) else: # OpenPose # Would need OpenPose processing - simplified for demo image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) return Image.fromarray(image) # Main content col1, col2 = st.columns([1, 1]) with col1: st.markdown("### Control Image") if uploaded_file is not None: control_image = Image.open(uploaded_file) processed_image = process_control_image(control_image, model_choice) st.image(processed_image, caption="Processed Control Image", use_column_width=True) else: st.info("Please upload an image to use as control") with col2: st.markdown("### Generated Images") if st.button("Generate Images"): if uploaded_file is None: st.warning("Please upload a control image first") else: with st.spinner("Generating images... 
Please wait"): start_time = time.time() # Load models pipe = load_models(model_choice) # Generator for reproducibility generator = torch.Generator(device="cuda").manual_seed(seed) # Generate images images = pipe( [prompt] * num_images, negative_prompt=[negative_prompt] * num_images, image=processed_image, num_inference_steps=steps, generator=generator, guidance_scale=guidance_scale, controlnet_conditioning_scale=strength ).images # Display results st.markdown(f"
", unsafe_allow_html=True) for i, img in enumerate(images): st.image(img, caption=f"Image {i+1}", use_column_width=True) st.markdown("
", unsafe_allow_html=True) # Show performance info end_time = time.time() st.success(f"Generated {num_images} images in {end_time - start_time:.2f} seconds") # Footer st.markdown(""" """, unsafe_allow_html=True)