DharavathSri commited on
Commit
235ff3e
·
verified ·
1 Parent(s): 18018f2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +198 -0
app.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
3
+ from diffusers import UniPCMultistepScheduler
4
+ import torch
5
+ from PIL import Image
6
+ import numpy as np
7
+ import cv2
8
+ import time
9
+
10
+ # App title and config
11
+ st.set_page_config(
12
+ page_title="AI Image Generator with ControlNet",
13
+ page_icon="🎨",
14
+ layout="wide",
15
+ initial_sidebar_state="expanded"
16
+ )
17
+
18
+ # Custom CSS for styling
19
+ st.markdown("""
20
+ <style>
21
+ .main {
22
+ background-color: #f5f5f5;
23
+ }
24
+ .stButton>button {
25
+ background-color: #4CAF50;
26
+ color: white;
27
+ border-radius: 8px;
28
+ padding: 10px 24px;
29
+ font-weight: bold;
30
+ }
31
+ .stButton>button:hover {
32
+ background-color: #45a049;
33
+ }
34
+ .stSelectbox, .stSlider, .stTextInput {
35
+ margin-bottom: 20px;
36
+ }
37
+ .header {
38
+ color: #4CAF50;
39
+ text-align: center;
40
+ }
41
+ .footer {
42
+ text-align: center;
43
+ margin-top: 30px;
44
+ color: #777;
45
+ font-size: 0.9em;
46
+ }
47
+ .image-container {
48
+ display: flex;
49
+ justify-content: space-around;
50
+ flex-wrap: wrap;
51
+ gap: 20px;
52
+ margin-top: 20px;
53
+ }
54
+ .image-card {
55
+ border-radius: 10px;
56
+ box-shadow: 0 4px 8px rgba(0,0,0,0.1);
57
+ padding: 15px;
58
+ background: white;
59
+ }
60
+ </style>
61
+ """, unsafe_allow_html=True)
62
+
63
+ # Header
64
+ st.markdown("<h1 class='header'>🎨 AI Image Generator with ControlNet</h1>", unsafe_allow_html=True)
65
+ st.markdown("Generate stunning images guided by Stable Diffusion and ControlNet. Upload a reference image or use edge detection to control the output.")
66
+
67
+ # Sidebar for controls
68
+ with st.sidebar:
69
+ st.image("https://huggingface.co/front/assets/huggingface_logo-noborder.svg", width=200)
70
+ st.markdown("### Configuration")
71
+
72
+ # Model selection
73
+ model_choice = st.selectbox(
74
+ "Select ControlNet Type",
75
+ ("Canny Edge", "Depth Map", "OpenPose (Human Pose)"),
76
+ index=0
77
+ )
78
+
79
+ # Parameters
80
+ prompt = st.text_area("Prompt", "a beautiful landscape with mountains and lake, highly detailed, digital art")
81
+ negative_prompt = st.text_area("Negative Prompt", "blurry, low quality, distorted")
82
+ num_images = st.slider("Number of images to generate", 1, 4, 1)
83
+ steps = st.slider("Number of inference steps", 20, 100, 50)
84
+ guidance_scale = st.slider("Guidance scale", 1.0, 20.0, 7.5)
85
+ seed = st.number_input("Seed", value=42, min_value=0, max_value=1000000)
86
+
87
+ # Upload control image
88
+ uploaded_file = st.file_uploader("Upload control image", type=["jpg", "png", "jpeg"])
89
+
90
+ # Advanced options
91
+ with st.expander("Advanced Options"):
92
+ strength = st.slider("Control strength", 0.1, 2.0, 1.0)
93
+ low_threshold = st.slider("Canny low threshold", 1, 255, 100)
94
+ high_threshold = st.slider("Canny high threshold", 1, 255, 200)
95
+
96
+ # Initialize models (cached)
97
+ @st.cache_resource
98
+ def load_models(model_type):
99
+ if model_type == "Canny Edge":
100
+ controlnet = ControlNetModel.from_pretrained(
101
+ "lllyasviel/sd-controlnet-canny",
102
+ torch_dtype=torch.float16
103
+ )
104
+ elif model_type == "Depth Map":
105
+ controlnet = ControlNetModel.from_pretrained(
106
+ "lllyasviel/sd-controlnet-depth",
107
+ torch_dtype=torch.float16
108
+ )
109
+ else: # OpenPose
110
+ controlnet = ControlNetModel.from_pretrained(
111
+ "lllyasviel/sd-controlnet-openpose",
112
+ torch_dtype=torch.float16
113
+ )
114
+
115
+ pipe = StableDiffusionControlNetPipeline.from_pretrained(
116
+ "runwayml/stable-diffusion-v1-5",
117
+ controlnet=controlnet,
118
+ torch_dtype=torch.float16,
119
+ safety_checker=None
120
+ ).to("cuda")
121
+
122
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
123
+ pipe.enable_model_cpu_offload()
124
+ return pipe
125
+
126
+ # Process control image based on model type
127
+ def process_control_image(image, model_type):
128
+ image = np.array(image)
129
+
130
+ if model_type == "Canny Edge":
131
+ image = cv2.Canny(image, low_threshold, high_threshold)
132
+ image = image[:, :, None]
133
+ image = np.concatenate([image, image, image], axis=2)
134
+ elif model_type == "Depth Map":
135
+ # Using MiDaS for depth estimation - would need additional imports
136
+ # This is simplified for demo purposes
137
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
138
+ image = np.stack([image]*3, axis=-1)
139
+ else: # OpenPose
140
+ # Would need OpenPose processing - simplified for demo
141
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
142
+
143
+ return Image.fromarray(image)
144
+
145
+ # Main content
146
+ col1, col2 = st.columns([1, 1])
147
+
148
+ with col1:
149
+ st.markdown("### Control Image")
150
+ if uploaded_file is not None:
151
+ control_image = Image.open(uploaded_file)
152
+ processed_image = process_control_image(control_image, model_choice)
153
+ st.image(processed_image, caption="Processed Control Image", use_column_width=True)
154
+ else:
155
+ st.info("Please upload an image to use as control")
156
+
157
+ with col2:
158
+ st.markdown("### Generated Images")
159
+ if st.button("Generate Images"):
160
+ if uploaded_file is None:
161
+ st.warning("Please upload a control image first")
162
+ else:
163
+ with st.spinner("Generating images... Please wait"):
164
+ start_time = time.time()
165
+
166
+ # Load models
167
+ pipe = load_models(model_choice)
168
+
169
+ # Generator for reproducibility
170
+ generator = torch.Generator(device="cuda").manual_seed(seed)
171
+
172
+ # Generate images
173
+ images = pipe(
174
+ [prompt] * num_images,
175
+ negative_prompt=[negative_prompt] * num_images,
176
+ image=processed_image,
177
+ num_inference_steps=steps,
178
+ generator=generator,
179
+ guidance_scale=guidance_scale,
180
+ controlnet_conditioning_scale=strength
181
+ ).images
182
+
183
+ # Display results
184
+ st.markdown(f"<div class='image-container'>", unsafe_allow_html=True)
185
+ for i, img in enumerate(images):
186
+ st.image(img, caption=f"Image {i+1}", use_column_width=True)
187
+ st.markdown("</div>", unsafe_allow_html=True)
188
+
189
+ # Show performance info
190
+ end_time = time.time()
191
+ st.success(f"Generated {num_images} images in {end_time - start_time:.2f} seconds")
192
+
193
+ # Footer
194
+ st.markdown("""
195
+ <div class='footer'>
196
+ <p>Powered by Stable Diffusion and ControlNet | Deployed on Hugging Face Spaces</p>
197
+ </div>
198
+ """, unsafe_allow_html=True)