# Copyright 2023 Adobe Research. All rights reserved.
# To view a copy of the license, visit LICENSE.md.
import os

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "False"

from PIL import Image
import torch
import gradio as gr
from lavis.models import load_model_and_preprocess
from diffusers import DDIMScheduler
from src.utils.ddim_inv import DDIMInversion
from src.utils.edit_directions import construct_direction
from src.utils.scheduler import DDIMInverseScheduler
from src.utils.edit_pipeline import EditingPipeline
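
# The Gradio demo below exposes two editing paths: a "Real Image" tab (BLIP captions the
# upload, DDIM inversion recovers its noise latent, and the editing pipeline re-samples
# with a task-specific edit direction) and a "Synthetic Images" tab that instead starts
# from a random latent and a user-supplied prompt.
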
def main():
    NUM_DDIM_STEPS = 50
    XA_GUIDANCE = 0.1
    DIR_SCALE = 1.0
    MODEL_NAME = "CompVis/stable-diffusion-v1-4"
    NEGATIVE_GUIDANCE_SCALE = 5.0
    # Fall back to CPU (and full precision) when no GPU is available.
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    TORCH_DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
    print(f"Using {DEVICE}")

    # BLIP captioner used to generate the source prompt for uploaded images.
    model_blip, vis_processors, _ = load_model_and_preprocess(
        name="blip_caption", model_type="base_coco", is_eval=True, device=DEVICE
    )

    # Editing pipeline: forward sampling with the edit direction applied.
    pipe = EditingPipeline.from_pretrained(MODEL_NAME, torch_dtype=TORCH_DTYPE, safety_checker=None).to(DEVICE)
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

    # DDIM inversion pipeline: maps a real image back to its initial noise latent.
    inv_pipe = DDIMInversion.from_pretrained(MODEL_NAME, torch_dtype=TORCH_DTYPE, safety_checker=None).to(DEVICE)
    inv_pipe.scheduler = DDIMInverseScheduler.from_config(inv_pipe.scheduler.config)

    TASKS = ["dog2cat", "cat2dog", "horse2zebra", "zebra2horse", "horse2llama", "dog2capy"]
    TASK_OPTIONS = ["Dog to Cat", "Cat to Dog", "Horse to Zebra", "Zebra to Horse", "Horse to Llama", "Dog to Capy"]
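    # Each task string selects an edit direction; construct_direction (src.utils.edit_directions)
    # is expected to return the corresponding text-embedding offset (target concept minus source
    # concept) that is applied during sampling to steer the edit.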
    def edit_real_image(og_img, task, seed, xa_guidance, num_ddim_steps, dir_scale):
        torch.cuda.manual_seed(seed)

        # Caption the input image with BLIP; the caption is reused as the inversion prompt.
        curr_img = og_img.resize((512, 512), Image.Resampling.LANCZOS)
        _image = vis_processors["eval"](curr_img).unsqueeze(0).to(DEVICE)
        prompt_str = model_blip.generate({"image": _image})[0]

        # Invert the image to its initial noise latent, using the step count from the UI.
        x_inv, _, _ = inv_pipe(
            prompt_str,
            guidance_scale=1,
            num_inversion_steps=num_ddim_steps,
            img=curr_img,
            torch_dtype=TORCH_DTYPE,
        )

        # Re-sample from the inverted latent while applying the scaled edit direction.
        task_str = TASKS[task]
        rec_pil, edit_pil = pipe(
            prompt_str,
            num_inference_steps=num_ddim_steps,
            x_in=x_inv[0].unsqueeze(0),
            edit_dir=construct_direction(task_str) * dir_scale,
            guidance_amount=xa_guidance,
            guidance_scale=NEGATIVE_GUIDANCE_SCALE,
            negative_prompt=prompt_str,  # use the unedited prompt as the negative prompt
        )
        return prompt_str, edit_pil[0]
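    # Pre-fills the "Real Image" tab: loads a bundled cat photo, runs the Cat to Dog edit once,
    # and returns values for every input and output widget wired up below.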
    def edit_real_image_example():
        test_img = Image.open("./assets/test_images/cats/cat_4.png")
        seed = 42
        task = 1
        prompt_str, edited_img = edit_real_image(test_img, task, seed, XA_GUIDANCE, NUM_DDIM_STEPS, DIR_SCALE)
        return test_img, seed, "Cat to Dog", prompt_str, edited_img, XA_GUIDANCE, NUM_DDIM_STEPS, DIR_SCALE
    def edit_synthetic_image(seed, task, prompt_str, xa_guidance, num_ddim_steps):
        torch.cuda.manual_seed(seed)
        # Start from a random Gaussian latent instead of an inverted real image.
        x = torch.randn((1, 4, 64, 64), device=DEVICE)
        task_str = TASKS[task]
        rec_pil, edit_pil = pipe(
            prompt_str,
            num_inference_steps=num_ddim_steps,
            x_in=x,
            edit_dir=construct_direction(task_str),
            guidance_amount=xa_guidance,
            guidance_scale=NEGATIVE_GUIDANCE_SCALE,
            negative_prompt="",  # use the empty string as the negative prompt
        )
        return rec_pil[0], edit_pil[0]
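    # Pre-fills the "Synthetic Images" tab: runs the Cat to Dog edit on a sample prompt and
    # returns both the unedited generation and the edited result.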
    def edit_synth_image_example():
        seed = 42
        task = 1
        xa_guidance = XA_GUIDANCE
        num_ddim_steps = NUM_DDIM_STEPS
        prompt_str = "A cute white cat sitting on top of the fridge"
        recon_img, edited_img = edit_synthetic_image(seed, task, prompt_str, xa_guidance, num_ddim_steps)
        return seed, "Cat to Dog", xa_guidance, num_ddim_steps, prompt_str, recon_img, edited_img
    with gr.Blocks() as demo:
        gr.Markdown("""
        ### Zero-shot Image-to-Image Translation (https://github.com/pix2pixzero/pix2pix-zero)
        Gaurav Parmar, Krishna Kumar Singh, Richard Zhang, Yijun Li, Jingwan Lu, Jun-Yan Zhu <br/>
        - For real images:
            - Upload an image of a dog, cat, or horse.
            - Choose one of the task options to turn it into another animal!
        - Changing parameters:
            - Increase the edit direction scale if the result does not look enough like the target animal.
            - Increase the number of DDIM steps if the image quality is too low.
            - Increase the cross-attention guidance to better preserve the original image structure. <br/>
        - For synthetic images:
            - Enter a prompt about dogs, cats, or horses.
            - Choose a task option.
        """)
        with gr.Tab("Real Image"):
            with gr.Row():
                seed = gr.Number(value=42, precision=0, label="Seed", interactive=True)
                real_xa_guidance = gr.Number(value=XA_GUIDANCE, label="Cross Attention Guidance", interactive=True)
                real_num_ddim_steps = gr.Number(value=NUM_DDIM_STEPS, precision=0, label="Num DDIM steps", interactive=True)
                real_edit_dir_scale = gr.Number(value=DIR_SCALE, label="Edit Direction Scale", interactive=True)
                real_generate_button = gr.Button("Generate")
                real_load_sample_button = gr.Button("Load Example")
            with gr.Row():
                task_name = gr.Radio(
                    label="Task Name",
                    choices=TASK_OPTIONS,
                    value=TASK_OPTIONS[0],
                    type="index",
                    show_label=True,
                    interactive=True,
                )
            with gr.Row():
                recon_text = gr.Textbox(lines=1, label="Reconstructed Text", interactive=False)
            with gr.Row():
                input_image = gr.Image(label="Input Image", type="pil", interactive=True)
                output_image = gr.Image(label="Output Image", type="pil", interactive=False)
        with gr.Tab("Synthetic Images"):
            with gr.Row():
                synth_seed = gr.Number(value=42, precision=0, label="Seed", interactive=True)
                synth_prompt = gr.Textbox(lines=1, label="Prompt", interactive=True)
                synth_generate_button = gr.Button("Generate")
                synth_load_sample_button = gr.Button("Load Example")
            with gr.Row():
                synth_task_name = gr.Radio(
                    label="Task Name",
                    choices=TASK_OPTIONS,
                    value=TASK_OPTIONS[0],
                    type="index",
                    show_label=True,
                    interactive=True,
                )
                synth_xa_guidance = gr.Number(value=XA_GUIDANCE, label="Cross Attention Guidance", interactive=True)
                synth_num_ddim_steps = gr.Number(value=NUM_DDIM_STEPS, precision=0, label="Num DDIM steps", interactive=True)
            with gr.Row():
                synth_input_image = gr.Image(label="Input Image", type="pil", interactive=False)
                synth_output_image = gr.Image(label="Output Image", type="pil", interactive=False)
        real_generate_button.click(
            fn=edit_real_image,
            inputs=[input_image, task_name, seed, real_xa_guidance, real_num_ddim_steps, real_edit_dir_scale],
            outputs=[recon_text, output_image],
        )
        real_load_sample_button.click(
            fn=edit_real_image_example,
            inputs=[],
            outputs=[input_image, seed, task_name, recon_text, output_image, real_xa_guidance, real_num_ddim_steps, real_edit_dir_scale],
        )
        synth_generate_button.click(
            fn=edit_synthetic_image,
            inputs=[synth_seed, synth_task_name, synth_prompt, synth_xa_guidance, synth_num_ddim_steps],
            outputs=[synth_input_image, synth_output_image],
        )
        synth_load_sample_button.click(
            fn=edit_synth_image_example,
            inputs=[],
            # the example targets the synthetic tab's own seed widget, not the real-image one
            outputs=[synth_seed, synth_task_name, synth_xa_guidance, synth_num_ddim_steps, synth_prompt, synth_input_image, synth_output_image],
        )
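    # Process one request at a time: all requests share a single set of pipelines on one
    # device, so serializing them presumably avoids concurrent GPU access.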
    demo.queue(concurrency_count=1)
    demo.launch(share=False, server_name="0.0.0.0")


if __name__ == "__main__":
    main()