Spaces: Running on Zero

add NSFW checker and GPU mode

- app.py +44 -28
- data/nsfw.jpg +0 -0
- utils/pipeline.py +15 -1
app.py
CHANGED
@@ -61,7 +61,13 @@ class GlobalText:
        self.pipeline = None
        self.torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        self.lora_model_state_dict = {}
-        self.device = torch.device("cpu")
+        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+        self.nsfw_image = Image.open('./data/nsfw.jpg') # to float in [0,1]
+
+
+
+

    def init_source_image_path(self, source_path):
        self.source_paths = sorted(glob(os.path.join(source_path, '*')))
@@ -83,9 +89,9 @@ class GlobalText:

        self.scheduler = 'LCM'
        scheduler = LCMScheduler.from_pretrained(model_path, subfolder="scheduler")
-        self.pipeline = ZePoPipeline.from_pretrained(model_path,scheduler=scheduler,torch_dtype=torch.float16,)
-
-
+        self.pipeline = ZePoPipeline.from_pretrained(model_path,scheduler=scheduler,torch_dtype=torch.float16,).to('cuda')
+        if is_xformers:
+            self.pipeline.enable_xformers_memory_efficient_attention()
        time_end = datetime.now()
        print(f'Load {model_path} successful in {time_end-time_start}')
        return gr.Dropdown()
@@ -171,7 +177,7 @@ class GlobalText:
                        de_bug=de_bug,)

        time_begin = datetime.now()
-        generate_image = model(prompt=prompts,
+        results = model(prompt=prompts,
                        negative_prompt=negative_prompt_textbox,
                        image=source,
                        style=style,
@@ -183,7 +189,16 @@ class GlobalText:
                        fix_step_index=co_feat_step,
                        de_bug = de_bug,
                        callback = None
-                        )
+                        )
+        generate_image = results.images
+
+
+        for idx, has_nsfw_concept in enumerate(results.nsfw_content_detected):
+            if has_nsfw_concept:
+                generate_image[idx] = np.array(self.nsfw_image.resize((height_slider,width_slider))).astype(np.float32) / 255.0
+
+
+
        time_end = datetime.now()
        print('generate one image with time {}'.format(time_end-time_begin))

@@ -191,18 +206,19 @@ class GlobalText:


        save_file_path = os.path.join(self.savedir, save_file_name)
-
+
        save_image(torch.tensor(generate_image).permute(0, 3, 1, 2), save_file_path, nrow=3, padding=0)
        save_image(torch.tensor(generate_image[2:]).permute(0, 3, 1, 2), os.path.join(self.savedir_sample, save_file_name), nrow=3, padding=0)
        self.init_results_image_path()
-        return [
-                generate_image[0],
-                generate_image[1],
-                generate_image[2],
-                self.init_results_image_path()
-                ]
-

+        return [
+                generate_image[0],
+                generate_image[1],
+                generate_image[2],
+                self.init_results_image_path()
+                ]
+
+
 global_text = GlobalText()


@@ -309,23 +325,23 @@ def ui():

        style_gallery_index.change(fn=update_style_list, inputs=[style_gallery_index], outputs=[style_image_gallery])

-        with gr.Tab("Results Gallery"):
-            with gr.Row():
-                refresh_results_list_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
-                results_gallery_index = gr.Slider(label="Index", value=0, minimum=0, maximum=50, step=1)
-            num_gallery_images = 12
-            results_image_gallery = gr.Gallery(value=[], columns=4, label="style Image List")
-            refresh_results_list_button.click(fn=global_text.init_results_image_path, inputs=[], outputs=[results_image_gallery])
+        # with gr.Tab("Results Gallery"):
+        #     with gr.Row():
+        #         refresh_results_list_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
+        #         results_gallery_index = gr.Slider(label="Index", value=0, minimum=0, maximum=50, step=1)
+        #     num_gallery_images = 12
+        #     results_image_gallery = gr.Gallery(value=[], columns=4, label="style Image List")
+        #     refresh_results_list_button.click(fn=global_text.init_results_image_path, inputs=[], outputs=[results_image_gallery])


-        def update_results_list(index):
-            if int(index) < 0:
-                index = 0
-            if int(index) > global_text.max_results_index:
-                index = global_text.max_results_index
-            return global_text.results_paths[int(index)*num_gallery_images:(int(index)+1)*num_gallery_images]
+        # def update_results_list(index):
+        #     if int(index) < 0:
+        #         index = 0
+        #     if int(index) > global_text.max_results_index:
+        #         index = global_text.max_results_index
+        #     return global_text.results_paths[int(index)*num_gallery_images:(int(index)+1)*num_gallery_images]

-        results_gallery_index.change(fn=update_results_list, inputs=[results_gallery_index], outputs=[style_image_gallery])
+        # results_gallery_index.change(fn=update_results_list, inputs=[results_gallery_index], outputs=[style_image_gallery])



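Note: the app.py changes come down to two things: the device and pipeline now move to CUDA when available, and any generation the safety checker flags is swapped for the bundled data/nsfw.jpg placeholder before saving and display. A minimal standalone sketch of that substitution pattern follows; the names placeholder, censor_outputs, H and W are illustrative, not taken from the app.

# Sketch only: replace flagged outputs with ./data/nsfw.jpg, assuming images are
# H x W x 3 float arrays in [0, 1] and nsfw_flags is a list of booleans, as in a
# diffusers-style StableDiffusionPipelineOutput.
import numpy as np
from PIL import Image

placeholder = Image.open("./data/nsfw.jpg").convert("RGB")

def censor_outputs(images, nsfw_flags, H=512, W=512):
    for idx, flagged in enumerate(nsfw_flags):
        if flagged:
            # PIL's resize takes (width, height)
            images[idx] = np.array(placeholder.resize((W, H))).astype(np.float32) / 255.0
    return images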
data/nsfw.jpg
ADDED
utils/pipeline.py
CHANGED
@@ -157,6 +157,20 @@ class ZePoPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMix
        extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+    def run_safety_checker(self, image, device, dtype):
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+            image, has_nsfw_concept = self.safety_checker(
+                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+            )
+        return image, has_nsfw_concept

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
@@ -416,7 +430,7 @@ class ZePoPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMix
        # 9. Post-processing
        if not output_type == "latent":
            image = self.vae.decode(pred_x0 / self.vae.config.scaling_factor, return_dict=False)[0]
-            has_nsfw_concept = None
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        else:
            image = pred_x0
            has_nsfw_concept = None
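Note: run_safety_checker above is the stock diffusers helper, so it only does anything when the pipeline is loaded with a safety_checker and feature_extractor. For reference, a minimal standalone sketch of those two components is below; the model IDs are the commonly used public checkpoints and are an assumption here, not something this repo pins.

# Sketch: run the standard Stable Diffusion safety checker on a single image.
import numpy as np
import torch
from PIL import Image
from transformers import CLIPImageProcessor
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker

feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")

image = Image.open("./data/nsfw.jpg").convert("RGB")
np_image = np.array(image).astype(np.float32)[None] / 255.0   # (1, H, W, 3) in [0, 1]

clip_input = feature_extractor(image, return_tensors="pt").pixel_values
with torch.no_grad():
    checked, has_nsfw = safety_checker(images=np_image, clip_input=clip_input)
print(has_nsfw)  # one boolean per image; flagged images come back blacked out in `checked`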