Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import sys | |
| sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) | |
| import torch | |
| import numpy as np | |
| import gradio as gr | |
| import spaces | |
| from PIL import Image | |
| from diffusers import DDPMScheduler | |
| from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler | |
| from module.ip_adapter.utils import load_adapter_to_pipe | |
| from pipelines.sdxl_instantir import InstantIRPipeline | |
def resize_img(input_image, max_side=1280, min_side=1024, size=None,
               pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64):
    """Resize an image to an explicit size or to a snapped, aspect-preserving size.

    When ``size`` is given, the image is resized to exactly that
    ``(width, height)``. Otherwise it is scaled so that its longer side equals
    ``max_side`` and each side is then rounded down to a multiple of
    ``base_pixel_number``.

    Args:
        input_image: Image object exposing ``.size`` and ``.resize`` (a PIL image).
        max_side: Length of the longer side after scaling; also the edge length
            of the square canvas used when ``pad_to_max_side`` is set.
        min_side: Unused; kept so existing callers passing it keep working.
        size: Optional explicit ``(width, height)`` target.
        pad_to_max_side: If true, paste the result centred on a white
            ``max_side`` x ``max_side`` canvas.
        mode: Resampling filter forwarded to ``resize``.
        base_pixel_number: Granularity the output sides are snapped down to.

    Returns:
        The resized (and optionally padded) image.
    """
    w, h = input_image.size
    if size is not None:
        w_resize_new, h_resize_new = size
    else:
        ratio = max_side / max(h, w)
        # Very elongated inputs can round to 0 px on the short side, which
        # `resize` rejects; keep at least one pixel / one snapping unit.
        scaled_w = max(round(ratio * w), 1)
        scaled_h = max(round(ratio * h), 1)
        input_image = input_image.resize([scaled_w, scaled_h], mode)
        w_resize_new = max(scaled_w // base_pixel_number, 1) * base_pixel_number
        h_resize_new = max(scaled_h // base_pixel_number, 1) * base_pixel_number
    input_image = input_image.resize([w_resize_new, h_resize_new], mode)

    if pad_to_max_side:
        # White canvas with three channels; assumes the image is 3-channel and
        # no larger than max_side on either side -- TODO confirm for `size=` callers.
        res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
        offset_x = (max_side - w_resize_new) // 2
        offset_y = (max_side - h_resize_new) // 2
        res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image)
        input_image = Image.fromarray(res)
    return input_image
from huggingface_hub import hf_hub_download

# Download the InstantIR weight files; with local_dir="." they are requested
# under ./models/, which is where `instantir_path` points.
for _weight_file in (
    "models/adapter.pt",
    "models/aggregator.pt",
    "models/previewer_lora_weights.bin",
):
    hf_hub_download(repo_id="InstantX/InstantIR", filename=_weight_file, local_dir=".")

instantir_path = "./models"

sdxl_repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
dinov2_repo_id = "facebook/dinov2-large"
lcm_repo_id = "latent-consistency/lcm-lora-sdxl"

# Half precision on GPU, full precision on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Load pretrained models.
print("Initializing pipeline...")
pipe = InstantIRPipeline.from_pretrained(sdxl_repo_id, torch_dtype=torch_dtype)

# Image prompt projector.
print("Loading LQ-Adapter...")
load_adapter_to_pipe(pipe, f"{instantir_path}/adapter.pt", dinov2_repo_id)

# Prepare previewers: first from the InstantIR weights, then from the LCM LoRA.
# Presumably these register the "previewer" and "lcm" adapters that
# instantir_restore switches between -- confirm in prepare_previewers.
lora_alpha = pipe.prepare_previewers(instantir_path)
print(f"use lora alpha {lora_alpha}")
lora_alpha = pipe.prepare_previewers(lcm_repo_id, use_lcm=True)
print(f"use lora alpha {lora_alpha}")

pipe.to(device=device, dtype=torch_dtype)
pipe.scheduler = DDPMScheduler.from_pretrained(sdxl_repo_id, subfolder="scheduler")
# The single-step previewer scheduler shares the main scheduler's config.
lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)

# Load weights.
print("Loading checkpoint...")
# NOTE(review): torch.load unpickles the file; it is only as trustworthy as
# the downloaded checkpoint.
aggregator_state_dict = torch.load(f"{instantir_path}/aggregator.pt", map_location="cpu")
pipe.aggregator.load_state_dict(aggregator_state_dict, strict=True)
pipe.aggregator.to(device=device, dtype=torch_dtype)
# Largest value of a signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024

# Default positive prompt, used when the user supplies no prompt of their own.
PROMPT = (
    "Photorealistic, highly detailed, hyper detailed photo - realistic maximum detail, 32k, "
    "ultra HD, extreme meticulous detailing, skin pore detailing, "
    "hyper sharpness, perfect without deformations, "
    "taken using a Canon EOS R camera, Cinematic, High Contrast, Color Grading. "
)

# Negative prompt passed to the pipeline on every run.
NEG_PROMPT = (
    "blurry, out of focus, unclear, depth of field, over-smooth, "
    "sketch, oil painting, cartoon, CG Style, 3D render, unreal engine, "
    "dirty, messy, worst quality, low quality, frames, painting, illustration, drawing, art, "
    "watermark, signature, jpeg artifacts, deformed, lowres"
)
def unpack_pipe_out(preview_row, index):
    """Return the image of the gallery entry selected by ``index``.

    Args:
        preview_row: Sequence of gallery entries whose first element is the
            image, or ``None``/empty when no restoration has run yet.
        index: Position of the wanted entry; floats such as ``3.0`` are
            accepted because UI sliders may deliver them.

    Returns:
        The first element of the selected entry, or ``None`` when the gallery
        is empty or ``index`` is out of range.
    """
    # The slider can be released before any restoration has filled the gallery.
    if not preview_row:
        return None
    try:
        return preview_row[int(index)][0]
    except IndexError:
        return None
def dynamic_preview_slider(sampling_steps):
    """Build the preview slider so it spans step indices 0 .. sampling_steps - 1.

    Args:
        sampling_steps: Number of sampling steps currently set in the UI; may
            be ``None`` when the number field has been cleared.

    Returns:
        A ``gr.Slider`` positioned on the last step index.
    """
    # An empty or zero step count would otherwise raise or yield a maximum of
    # -1, below the slider's minimum of 0.
    last_index = max(int(sampling_steps or 1) - 1, 0)
    return gr.Slider(label="Restoration Previews", value=last_index, minimum=0, maximum=last_index, step=1)
def dynamic_guidance_slider(sampling_steps):
    """Build the "Start Free Rendering" slider spanning 0 .. sampling_steps.

    The slider is positioned on ``sampling_steps``, i.e. at its maximum.
    """
    slider_kwargs = {
        "label": "Start Free Rendering",
        "value": sampling_steps,
        "minimum": 0,
        "maximum": sampling_steps,
        "step": 1,
    }
    return gr.Slider(**slider_kwargs)
def show_final_preview(preview_row):
    """Return the image of the last gallery entry.

    Args:
        preview_row: Sequence of gallery entries whose first element is the
            image, or ``None``/empty when there are no previews.

    Returns:
        The first element of the last entry, or ``None`` when there is none.
    """
    # The output image can change (e.g. be cleared) while the gallery is empty.
    if not preview_row:
        return None
    return preview_row[-1][0]
def _as_step_fraction(value, steps):
    """Normalise a step index or a fraction to a fraction of ``steps`` in [0, 1]."""
    # Integers, and floats above 1.0, are read as step indices; any other
    # float is taken to be a fraction already.
    if isinstance(value, int) or value > 1.0:
        value = value / steps
    return min(max(value, 0.0), 1.0)


# NOTE(review): decorator restored from the "uncomment to use ZeroGPU" hint and
# the otherwise unused `import spaces` -- confirm against the deployed file.
@spaces.GPU
def instantir_restore(
    lq,  # A low-quality PIL image to be restored
    prompt="",  # Optional: A text prompt guiding creative restoration
    steps=30,  # Number of denoising steps (controls generation detail and time)
    cfg_scale=7.0,  # Classifier-Free Guidance scale; higher = more prompt adherence
    guidance_end=1.0,  # When to stop guidance and allow free generation (0.0 - 1.0 or 0 - steps)
    creative_restoration=False,  # Toggle creative mode (uses LCM adapter)
    seed=3407,  # Seed for reproducibility
    height=1024,  # Target height for output image
    width=1024,  # Target width for output image
    preview_start=0.0,  # When to start showing previews (fraction or step index)
    progress=gr.Progress(track_tqdm=True)  # Progress tracker for Gradio
):
    """
    Restore or creatively re-generate a low-quality image using the InstantIR pipeline.

    This function takes a degraded image and applies a guided diffusion model to restore it.
    Optionally, a text prompt can be provided to guide a creative re-interpretation of the image.

    Args:
        lq (PIL.Image): The input low-quality image to restore.
        prompt (str, optional): Text description to guide restoration or creative re-generation.
            An empty prompt falls back to the built-in default prompt.
        steps (int): Number of inference steps (1 to 1000); more steps generally yield better results.
        cfg_scale (float): Guidance scale for prompt adherence; higher means stronger influence.
        guidance_end (float or int): Defines when to stop guidance during diffusion. Integers and
            values above 1.0 are read as step indices, other floats as a fraction of the steps.
        creative_restoration (bool): Whether to enable imaginative regeneration via LCM adapter.
        seed (int): Random seed for reproducible results.
        height (int): Output image height; used if input is square.
        width (int): Output image width; used if input is square.
        preview_start (float or int): Step or ratio when previewing starts; interpreted like guidance_end.
        progress (gr.Progress): Progress tracker for UI feedback.

    Returns:
        Tuple[PIL.Image, List[List[Union[PIL.Image, str]]]]:
            - The final restored image.
            - A list of preview images from intermediate steps with labels.

    Raises:
        gr.Error: If no image was supplied or the step count is out of range.
    """
    # The timestep grid below is laid out on this many training timesteps.
    train_timesteps = 1000

    if lq is None:
        raise gr.Error("Please upload an image to restore.")
    # UI number fields and API clients may deliver floats; range() and
    # manual_seed() need real integers.
    steps = int(steps)
    if not 1 <= steps <= train_timesteps:
        raise gr.Error(f"Steps must be between 1 and {train_timesteps}.")
    seed = int(seed)

    # Activate the adapter matching the requested mode, if not already active.
    adapter_name = "lcm" if creative_restoration else "previewer"
    if adapter_name not in pipe.unet.active_adapters():
        pipe.unet.set_adapter(adapter_name)

    guidance_end = _as_step_fraction(guidance_end, steps)
    preview_start = _as_step_fraction(preview_start, steps)

    # Square inputs go to the requested size; others keep their aspect ratio.
    w, h = lq.size
    target_size = (int(width), int(height)) if w == h else None
    lq = [resize_img(lq.convert("RGB"), size=target_size)]

    generator = torch.Generator(device=device).manual_seed(seed)
    # Evenly spaced timesteps, highest first.
    stride = train_timesteps // steps
    offset = pipe.scheduler.config.steps_offset
    timesteps = [i * stride + offset for i in reversed(range(steps))]

    prompt = prompt or PROMPT
    neg_prompt = NEG_PROMPT

    out = pipe(
        prompt=[prompt]*len(lq),
        image=lq,
        num_inference_steps=steps,
        generator=generator,
        timesteps=timesteps,
        negative_prompt=[neg_prompt]*len(lq),
        guidance_scale=cfg_scale,
        control_guidance_end=guidance_end,
        preview_start=preview_start,
        previewer_scheduler=lcm_scheduler,
        return_dict=False,
        save_preview_row=True,
    )
    # Label each preview entry in place so the gallery gets (image, caption) pairs.
    for i, preview_img in enumerate(out[1]):
        preview_img.append(f"preview_{i}")
    return out[0][0], out[1]
# Text prompt samples.
# NOTE(review): not referenced in the visible part of this file; the
# gr.Examples widget below is given its own list of image paths.
examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
]

# Stylesheet that centres an element with id "col-container".
# NOTE(review): gr.Blocks() below is created without css=..., so this is
# currently not applied.
css = """
#col-container {
margin: 0 auto;
max-width: 640px;
}
"""
# NOTE(review): the layout nesting below is reconstructed; verify it against
# the deployed UI.
with gr.Blocks() as demo:
    gr.Markdown(
    """
# InstantIR: Blind Image Restoration with Instant Generative Reference.
### **Official 🤗 Gradio demo of [InstantIR](https://arxiv.org/abs/2410.06551).**
### **InstantIR can not only help you restore your broken image, but also capable of imaginative re-creation following your text prompts. See advance usage for more details!**
## Basic usage: revitalize your image
1. Upload an image you want to restore;
2. Optionally, tune the `Steps` `CFG Scale` parameters. Typically higher steps lead to better results, but less than 50 is recommended for efficiency;
3. Click `InstantIR magic!`.
""")
    with gr.Row():
        with gr.Column():
            lq_img = gr.Image(label="Low-quality image", type="pil")
            with gr.Row():
                # precision=0 makes Gradio hand integers to the handlers; the
                # restore function feeds `steps` to range() and `seed` to
                # manual_seed(), both of which reject floats.
                steps = gr.Number(label="Steps", value=30, step=1, precision=0)
                cfg_scale = gr.Number(label="CFG Scale", value=7.0, step=0.1)
            with gr.Row():
                height = gr.Number(label="Height", value=1024, step=1, precision=0, visible=False)
                width = gr.Number(label="Width", value=1024, step=1, precision=0, visible=False)
                seed = gr.Number(label="Seed", value=42, step=1, precision=0)
            # guidance_start = gr.Slider(label="Guidance Start", value=1.0, minimum=0.0, maximum=1.0, step=0.05)
            guidance_end = gr.Slider(label="Start Free Rendering", value=30, minimum=0, maximum=30, step=1)
            # NOTE(review): unlike guidance_end, this slider's maximum is not
            # updated when `steps` changes.
            preview_start = gr.Slider(label="Preview Start", value=0, minimum=0, maximum=30, step=1)
            prompt = gr.Textbox(label="Restoration prompts (Optional)", placeholder="")
            mode = gr.Checkbox(label="Creative Restoration", value=False)
            with gr.Row():
                restore_btn = gr.Button("InstantIR magic!")
                clear_btn = gr.ClearButton()
            gr.Examples(
                examples = ["assets/lady.png", "assets/man.png", "assets/dog.png", "assets/panda.png", "assets/sculpture.png", "assets/cottage.png", "assets/Naruto.png", "assets/Konan.png"],
                inputs = [lq_img]
            )
        with gr.Column():
            output = gr.Image(label="InstantIR restored", type="pil")
            index = gr.Slider(label="Restoration Previews", value=29, minimum=0, maximum=29, step=1)
            preview = gr.Image(label="Preview", type="pil")

    # Hidden gallery that holds the intermediate previews between events.
    pipe_out = gr.Gallery(visible=False)
    clear_btn.add([lq_img, output, preview])

    restore_btn.click(
        instantir_restore, inputs=[
            lq_img, prompt, steps, cfg_scale, guidance_end,
            mode, seed, height, width, preview_start,
        ],
        outputs=[output, pipe_out], api_name="InstantIR"
    )
    # Keep the sliders in sync with the step count, and show the selected
    # (or, after a run, the final) preview.
    steps.change(dynamic_guidance_slider, inputs=steps, outputs=guidance_end, show_api=False)
    output.change(dynamic_preview_slider, inputs=steps, outputs=index, show_api=False)
    index.release(unpack_pipe_out, inputs=[pipe_out, index], outputs=preview, show_api=False)
    output.change(show_final_preview, inputs=pipe_out, outputs=preview, show_api=False)

    gr.Markdown(
    """
## Advance usage:
### Browse restoration variants:
1. After InstantIR processing, drag the `Restoration Previews` slider to explore other in-progress versions;
2. If you like one of them, set the `Start Free Rendering` slider to the same value to get a more refined result.
### Creative restoration:
1. Check the `Creative Restoration` checkbox;
2. Input your text prompts in the `Restoration prompts` textbox;
3. Set `Start Free Rendering` slider to a medium value (around half of the `steps`) to provide adequate room for InstantIR creation.
""")
    gr.Markdown(
    """
## Citation
If InstantIR is helpful to your work, please cite our paper via:
```
@article{huang2024instantir,
title={InstantIR: Blind Image Restoration with Instant Generative Reference},
author={Huang, Jen-Yuan and Wang, Haofan and Wang, Qixun and Bai, Xu and Ai, Hao and Xing, Peng and Huang, Jen-Tse},
journal={arXiv preprint arXiv:2410.06551},
year={2024}
}
```
""")

demo.queue().launch(ssr_mode=False, mcp_server=True)