import streamlit as st
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from diffusers import UniPCMultistepScheduler
import torch
from PIL import Image
import numpy as np
import cv2
import time
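# Pick the compute device and dtype up front; the rest of the script assumes
# these values (float16 only pays off on GPU; fall back to float32 on CPU)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32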
# App title and config
st.set_page_config(
    page_title="AI Image Generator with ControlNet",
    page_icon="🎨",
    layout="wide",
    initial_sidebar_state="expanded"
)
# Header
st.title("🎨 AI Image Generator with ControlNet")
st.markdown("Generate stunning images guided by Stable Diffusion and ControlNet. Upload a reference image or use edge detection to control the output.")
# Sidebar for controls
with st.sidebar:
    st.image("https://huggingface.co/front/assets/huggingface_logo-noborder.svg", width=200)
    st.markdown("### Configuration")
    # Model selection
    model_choice = st.selectbox(
        "Select ControlNet Type",
        ("Canny Edge", "Depth Map", "OpenPose (Human Pose)"),
        index=0
    )
    # Parameters
    prompt = st.text_area("Prompt", "a beautiful landscape with mountains and lake, highly detailed, digital art")
    negative_prompt = st.text_area("Negative Prompt", "blurry, low quality, distorted")
    num_images = st.slider("Number of images to generate", 1, 4, 1)
    steps = st.slider("Number of inference steps", 20, 100, 50)
    guidance_scale = st.slider("Guidance scale", 1.0, 20.0, 7.5)
    seed = st.number_input("Seed", value=42, min_value=0, max_value=1000000)
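    # Optional convenience (a sketch, not in the original controls): derive a
    # fresh seed from the clock each run; note this makes runs non-reproducible
    if st.checkbox("Randomize seed", value=False):
        seed = int(time.time()) % 1_000_000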
    # Upload control image
    uploaded_file = st.file_uploader("Upload control image", type=["jpg", "png", "jpeg"])
    # Advanced options
    with st.expander("Advanced Options"):
        strength = st.slider("Control strength", 0.1, 2.0, 1.0)
        low_threshold = st.slider("Canny low threshold", 1, 255, 100)
        high_threshold = st.slider("Canny high threshold", 1, 255, 200)
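# cv2.Canny expects low <= high; quietly swap the values if the sliders cross
if low_threshold > high_threshold:
    low_threshold, high_threshold = high_threshold, low_threshold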
# Initialize models (cached so Streamlit reruns don't reload the weights)
@st.cache_resource
def load_models(model_type):
    if model_type == "Canny Edge":
        controlnet = ControlNetModel.from_pretrained(
            "lllyasviel/sd-controlnet-canny",
            torch_dtype=DTYPE
        )
    elif model_type == "Depth Map":
        controlnet = ControlNetModel.from_pretrained(
            "lllyasviel/sd-controlnet-depth",
            torch_dtype=DTYPE
        )
    else:  # OpenPose
        controlnet = ControlNetModel.from_pretrained(
            "lllyasviel/sd-controlnet-openpose",
            torch_dtype=DTYPE
        )
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        controlnet=controlnet,
        torch_dtype=DTYPE,
        safety_checker=None
    )
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    if DEVICE == "cuda":
        # enable_model_cpu_offload manages device placement itself;
        # calling .to("cuda") as well would conflict with it
        pipe.enable_model_cpu_offload()
    else:
        pipe = pipe.to(DEVICE)
    return pipe
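# Optional upgrade for the "Depth Map" branch: a real monocular depth model via
# the transformers depth-estimation pipeline (a sketch; assumes `transformers`
# is installed and uses the "Intel/dpt-large" checkpoint). Not called by default.
@st.cache_resource
def load_depth_estimator():
    from transformers import pipeline as hf_pipeline
    return hf_pipeline("depth-estimation", model="Intel/dpt-large")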
# Process control image based on model type
def process_control_image(image, model_type):
    # PIL loads RGB; drop any alpha channel before converting to an ndarray
    image = np.array(image.convert("RGB"))
    if model_type == "Canny Edge":
        image = cv2.Canny(image, low_threshold, high_threshold)
        image = image[:, :, None]
        image = np.concatenate([image, image, image], axis=2)
    elif model_type == "Depth Map":
        # A real depth map needs a monocular depth estimator (e.g. MiDaS/DPT);
        # a plain grayscale conversion stands in for it in this demo
        image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)  # the array is RGB, not BGR
        image = np.stack([image] * 3, axis=-1)
    else:  # OpenPose
        # A real pose map needs a pose detector (e.g. controlnet_aux's
        # OpenposeDetector); the already-RGB image is passed through unchanged
        pass
    return Image.fromarray(image)
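# SD v1.5 expects spatial dims divisible by 8 and works best near its 512 px
# training size; a small helper to snap uploads to a safe size before preprocessing
def resize_for_sd(image, base=512):
    w, h = image.size
    scale = base / max(w, h)
    new_w = max(64, int(w * scale) // 8 * 8)
    new_h = max(64, int(h * scale) // 8 * 8)
    return image.resize((new_w, new_h), Image.LANCZOS)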
# Main content
col1, col2 = st.columns([1, 1])
with col1:
    st.markdown("### Control Image")
    if uploaded_file is not None:
        control_image = Image.open(uploaded_file)
        processed_image = process_control_image(resize_for_sd(control_image), model_choice)
        st.image(processed_image, caption="Processed Control Image", use_container_width=True)
    else:
        st.info("Please upload an image to use as the control")
with col2:
    st.markdown("### Generated Images")
    if st.button("Generate Images"):
        if uploaded_file is None:
            st.warning("Please upload a control image first")
        else:
            with st.spinner("Generating images... Please wait"):
                start_time = time.time()
                # Load models
                pipe = load_models(model_choice)
                # Generator for reproducibility
                generator = torch.Generator(device=DEVICE).manual_seed(seed)
                # Generate images; the single control image is reused
                # for every prompt in the batch
                images = pipe(
                    [prompt] * num_images,
                    negative_prompt=[negative_prompt] * num_images,
                    image=processed_image,
                    num_inference_steps=steps,
                    generator=generator,
                    guidance_scale=guidance_scale,
                    controlnet_conditioning_scale=strength
                ).images
                # Display results
                for i, img in enumerate(images):
                    st.image(img, caption=f"Image {i+1}", use_container_width=True)
                # Show performance info
                end_time = time.time()
                st.success(f"Generated {num_images} image(s) in {end_time - start_time:.2f} seconds")
# Footer
st.markdown("---")
st.caption("Built with Streamlit, Stable Diffusion v1.5, and ControlNet.")