Spaces:
Running
Running
Add application file
Browse files
README.md
CHANGED
|
@@ -1,13 +1,13 @@
|
|
| 1 |
---
|
| 2 |
title: IllusionDiffusion
|
| 3 |
emoji: 🔥
|
| 4 |
-
colorFrom:
|
| 5 |
colorTo: pink
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 4.36.1
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
-
license:
|
| 11 |
hf_oauth: true
|
| 12 |
disable_embedding: true
|
| 13 |
short_description: Generate stunning high quality illusion artwork
|
|
|
|
| 1 |
---
|
| 2 |
title: IllusionDiffusion
|
| 3 |
emoji: 🔥
|
| 4 |
+
colorFrom: green
|
| 5 |
colorTo: pink
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 4.36.1
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
hf_oauth: true
|
| 12 |
disable_embedding: true
|
| 13 |
short_description: Generate stunning high quality illusion artwork
|
app.py
CHANGED
|
@@ -28,15 +28,15 @@ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionS
|
|
| 28 |
BASE_MODEL = "SG161222/Realistic_Vision_V5.1_noVAE"
|
| 29 |
|
| 30 |
# Initialize both pipelines
|
| 31 |
-
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.
|
| 32 |
-
controlnet = ControlNetModel.from_pretrained("monster-labs/control_v1p_sd15_qrcode_monster", torch_dtype=torch.
|
| 33 |
|
| 34 |
# Initialize the safety checker conditionally
|
| 35 |
SAFETY_CHECKER_ENABLED = os.environ.get("SAFETY_CHECKER", "0") == "1"
|
| 36 |
safety_checker = None
|
| 37 |
feature_extractor = None
|
| 38 |
if SAFETY_CHECKER_ENABLED:
|
| 39 |
-
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker").to("
|
| 40 |
feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
| 41 |
|
| 42 |
main_pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
|
@@ -45,33 +45,11 @@ main_pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
|
| 45 |
vae=vae,
|
| 46 |
safety_checker=safety_checker,
|
| 47 |
feature_extractor=feature_extractor,
|
| 48 |
-
torch_dtype=torch.
|
| 49 |
-
).to("
|
| 50 |
|
| 51 |
-
# Function to check NSFW images
|
| 52 |
-
#def check_nsfw_images(images: list[Image.Image]) -> tuple[list[Image.Image], list[bool]]:
|
| 53 |
-
# if SAFETY_CHECKER_ENABLED:
|
| 54 |
-
# safety_checker_input = feature_extractor(images, return_tensors="pt").to("cuda")
|
| 55 |
-
# has_nsfw_concepts = safety_checker(
|
| 56 |
-
# images=[images],
|
| 57 |
-
# clip_input=safety_checker_input.pixel_values.to("cuda")
|
| 58 |
-
# )
|
| 59 |
-
# return images, has_nsfw_concepts
|
| 60 |
-
# else:
|
| 61 |
-
# return images, [False] * len(images)
|
| 62 |
-
|
| 63 |
-
#main_pipe.unet = torch.compile(main_pipe.unet, mode="reduce-overhead", fullgraph=True)
|
| 64 |
-
#main_pipe.unet.to(memory_format=torch.channels_last)
|
| 65 |
-
#main_pipe.unet = torch.compile(main_pipe.unet, mode="reduce-overhead", fullgraph=True)
|
| 66 |
-
#model_id = "stabilityai/sd-x2-latent-upscaler"
|
| 67 |
image_pipe = StableDiffusionControlNetImg2ImgPipeline(**main_pipe.components)
|
| 68 |
|
| 69 |
-
|
| 70 |
-
#image_pipe.unet = torch.compile(image_pipe.unet, mode="reduce-overhead", fullgraph=True)
|
| 71 |
-
#upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(model_id, torch_dtype=torch.float16)
|
| 72 |
-
#upscaler.to("cuda")
|
| 73 |
-
|
| 74 |
-
|
| 75 |
# Sampler map
|
| 76 |
SAMPLER_MAP = {
|
| 77 |
"DPM++ Karras SDE": lambda config: DPMSolverMultistepScheduler.from_config(config, use_karras=True, algorithm_type="sde-dpmsolver++"),
|
|
@@ -113,7 +91,6 @@ def common_upscale(samples, width, height, upscale_method, crop=False):
|
|
| 113 |
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
| 114 |
|
| 115 |
def upscale(samples, upscale_method, scale_by):
|
| 116 |
-
#s = samples.copy()
|
| 117 |
width = round(samples["images"].shape[3] * scale_by)
|
| 118 |
height = round(samples["images"].shape[2] * scale_by)
|
| 119 |
s = common_upscale(samples["images"], width, height, upscale_method, "disabled")
|
|
@@ -135,7 +112,6 @@ def convert_to_base64(pil_image):
|
|
| 135 |
return temp_file.name
|
| 136 |
|
| 137 |
# Inference function
|
| 138 |
-
@spaces.GPU
|
| 139 |
def inference(
|
| 140 |
control_image: Image.Image,
|
| 141 |
prompt: str,
|
|
@@ -155,16 +131,12 @@ def inference(
|
|
| 155 |
start_time_formatted = time.strftime("%H:%M:%S", start_time_struct)
|
| 156 |
print(f"Inference started at {start_time_formatted}")
|
| 157 |
|
| 158 |
-
# Generate the initial image
|
| 159 |
-
#init_image = init_pipe(prompt).images[0]
|
| 160 |
-
|
| 161 |
-
# Rest of your existing code
|
| 162 |
control_image_small = center_crop_resize(control_image)
|
| 163 |
control_image_large = center_crop_resize(control_image, (1024, 1024))
|
| 164 |
|
| 165 |
main_pipe.scheduler = SAMPLER_MAP[sampler](main_pipe.scheduler.config)
|
| 166 |
my_seed = random.randint(0, 2**32 - 1) if seed == -1 else seed
|
| 167 |
-
generator = torch.Generator(device="
|
| 168 |
|
| 169 |
out = main_pipe(
|
| 170 |
prompt=prompt,
|
|
@@ -230,7 +202,6 @@ with gr.Blocks() as app:
|
|
| 230 |
'''
|
| 231 |
)
|
| 232 |
|
| 233 |
-
|
| 234 |
state_img_input = gr.State()
|
| 235 |
state_img_output = gr.State()
|
| 236 |
with gr.Row():
|
|
@@ -282,7 +253,7 @@ with gr.Blocks(css=css) as app_with_history:
|
|
| 282 |
with gr.Tab("Past generations"):
|
| 283 |
user_history.render()
|
| 284 |
|
| 285 |
-
app_with_history.queue(max_size=20,api_open=False
|
| 286 |
|
| 287 |
if __name__ == "__main__":
|
| 288 |
app_with_history.launch(max_threads=400,share=True)
|
|
|
|
| 28 |
BASE_MODEL = "SG161222/Realistic_Vision_V5.1_noVAE"
|
| 29 |
|
| 30 |
# Initialize both pipelines
|
| 31 |
+
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float32)
|
| 32 |
+
controlnet = ControlNetModel.from_pretrained("monster-labs/control_v1p_sd15_qrcode_monster", torch_dtype=torch.float32)
|
| 33 |
|
| 34 |
# Initialize the safety checker conditionally
|
| 35 |
SAFETY_CHECKER_ENABLED = os.environ.get("SAFETY_CHECKER", "0") == "1"
|
| 36 |
safety_checker = None
|
| 37 |
feature_extractor = None
|
| 38 |
if SAFETY_CHECKER_ENABLED:
|
| 39 |
+
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker").to("cpu")
|
| 40 |
feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
| 41 |
|
| 42 |
main_pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
|
|
|
| 45 |
vae=vae,
|
| 46 |
safety_checker=safety_checker,
|
| 47 |
feature_extractor=feature_extractor,
|
| 48 |
+
torch_dtype=torch.float32,
|
| 49 |
+
).to("cpu")
|
| 50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
image_pipe = StableDiffusionControlNetImg2ImgPipeline(**main_pipe.components)
|
| 52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
# Sampler map
|
| 54 |
SAMPLER_MAP = {
|
| 55 |
"DPM++ Karras SDE": lambda config: DPMSolverMultistepScheduler.from_config(config, use_karras=True, algorithm_type="sde-dpmsolver++"),
|
|
|
|
| 91 |
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
| 92 |
|
| 93 |
def upscale(samples, upscale_method, scale_by):
|
|
|
|
| 94 |
width = round(samples["images"].shape[3] * scale_by)
|
| 95 |
height = round(samples["images"].shape[2] * scale_by)
|
| 96 |
s = common_upscale(samples["images"], width, height, upscale_method, "disabled")
|
|
|
|
| 112 |
return temp_file.name
|
| 113 |
|
| 114 |
# Inference function
|
|
|
|
| 115 |
def inference(
|
| 116 |
control_image: Image.Image,
|
| 117 |
prompt: str,
|
|
|
|
| 131 |
start_time_formatted = time.strftime("%H:%M:%S", start_time_struct)
|
| 132 |
print(f"Inference started at {start_time_formatted}")
|
| 133 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
control_image_small = center_crop_resize(control_image)
|
| 135 |
control_image_large = center_crop_resize(control_image, (1024, 1024))
|
| 136 |
|
| 137 |
main_pipe.scheduler = SAMPLER_MAP[sampler](main_pipe.scheduler.config)
|
| 138 |
my_seed = random.randint(0, 2**32 - 1) if seed == -1 else seed
|
| 139 |
+
generator = torch.Generator(device="cpu").manual_seed(my_seed)
|
| 140 |
|
| 141 |
out = main_pipe(
|
| 142 |
prompt=prompt,
|
|
|
|
| 202 |
'''
|
| 203 |
)
|
| 204 |
|
|
|
|
| 205 |
state_img_input = gr.State()
|
| 206 |
state_img_output = gr.State()
|
| 207 |
with gr.Row():
|
|
|
|
| 253 |
with gr.Tab("Past generations"):
|
| 254 |
user_history.render()
|
| 255 |
|
| 256 |
+
app_with_history.queue(max_size=20,api_open=False)
|
| 257 |
|
| 258 |
if __name__ == "__main__":
|
| 259 |
app_with_history.launch(max_threads=400,share=True)
|