# ReFlex demo Space (app.py), running on Hugging Face ZeroGPU.
import argparse
import gc
import os
import random
import re
import shutil
import time
from distutils.util import strtobool
from typing import Any, Callable, Dict, List, Optional, Union

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import spaces
import torch
import yaml
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.utils.torch_utils import randn_tensor
from PIL import Image

from src.attn_utils.attn_utils import AttentionAdapter, AttnCollector
from src.attn_utils.flux_attn_processor import NewFluxAttnProcessor2_0
from src.attn_utils.seq_aligner import get_refinement_mapper
from src.callback.callback_fn import CallbackAll
from src.inversion.inverse import get_inversed_latent_list
from src.inversion.scheduling_flow_inverse import FlowMatchEulerDiscreteForwardScheduler
from src.pipeline.flux_pipeline import NewFluxPipeline
from src.transformer_utils.transformer_utils import FeatureCollector, FeatureReplace
from src.utils import (find_token_id_differences, find_word_token_indices,
                       get_flux_pipeline, mask_decode, mask_interpolate)
# Build the FLUX editing pipeline once at import time and keep it on the GPU.
pipe = get_flux_pipeline(pipeline_class=NewFluxPipeline)
pipe = pipe.to("cuda")
def fix_seed(random_seed):
    """Fix all random seeds so experiment results are reproducible."""
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)  # covers multi-GPU setups
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(random_seed)
    random.seed(random_seed)
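
# ZeroGPU allocation for the Spaces runtime. The `spaces` package is imported above but
# otherwise unused, so this decorator is an assumption on my part; the 120-second duration
# is a guess sized for inversion plus a 28-step edit.
@spaces.GPU(duration=120)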
def infer(
    input_image: Union[str, Image.Image],        # ⬅️ Main UI (uploaded image)
    target_prompt: Union[str, List[str]] = '',   # ⬅️ Main UI (text prompt)
    source_prompt: Union[str, List[str]] = '',   # ⬅️ Advanced accordion
    seed: int = 0,                                # ⬅️ Advanced accordion
    ca_steps: int = 10,                           # ⬅️ Advanced accordion
    sa_steps: int = 7,                            # ⬅️ Advanced accordion
    feature_steps: int = 5,                       # ⬅️ Advanced accordion
    attn_topk: int = 20,                          # ⬅️ Advanced accordion
    mask_image: Optional[Image.Image] = None,     # ⬅️ Advanced (optional upload)
    # Everything below is backend-related or defaults, not exposed in the UI.
    blend_word: str = '',
    results_dir: str = 'results',
    model: str = 'flux',
    ca_attn_layer_from: int = 13,
    ca_attn_layer_to: int = 45,
    sa_attn_layer_from: int = 20,
    sa_attn_layer_to: int = 45,
    feature_layer_from: int = 13,
    feature_layer_to: int = 20,
    flow_steps: int = 7,
    step_start: int = 0,
    num_inference_steps: int = 28,
    guidance_scale: float = 3.5,
    text_scale: float = 4.0,
    mid_step_index: int = 14,
    use_mask: bool = True,
    use_ca_mask: bool = True,
    mask_steps: int = 18,
    mask_dilation: int = 3,
    mask_nbins: int = 128,
):
    if isinstance(mask_image, Image.Image):
        # Ensure the mask is single channel.
        if mask_image.mode != "L":
            mask_image = mask_image.convert("L")

    fix_seed(seed)

    device = torch.device('cuda')
    attn_proc = NewFluxAttnProcessor2_0

    # Slice out the transformer layers used for cross-/self-attention injection
    # and for feature replacement.
    layer_order = range(57)
    ca_layer_list = layer_order[ca_attn_layer_from:ca_attn_layer_to]
    sa_layer_list = layer_order[feature_layer_to:sa_attn_layer_to]
    feature_layer_list = layer_order[feature_layer_from:feature_layer_to]

    source_img = input_image.resize((1024, 1024)).convert("RGB")
    result_img_dir = f"{results_dir}/seed_{seed}/{target_prompt}"

    prompts = [source_prompt, target_prompt]
    mask_path = mask_image
    print(prompts)

    mask = None
    if use_mask:
        if mask_path is not None:
            # A ground-truth mask was uploaded: use it directly.
            mask = torch.tensor(np.array(mask_path)).bool()
            mask = mask.to(device)
            # Increase the latent blending steps when the ground-truth mask is used.
            mask_steps = int(num_inference_steps * 0.9)
            source_ca_index = None
            target_ca_index = None
            use_ca_mask = False
        elif use_ca_mask and source_prompt:
            # No uploaded mask: derive one from cross-attention of the edited tokens.
            mask = None
            if blend_word and blend_word in source_prompt:
                editing_source_token_index = find_word_token_indices(source_prompt, blend_word, pipe.tokenizer_2)
                editing_target_token_index = None
            else:
                editing_tokens_info = find_token_id_differences(*prompts, pipe.tokenizer_2)
                editing_source_token_index = editing_tokens_info['prompt_1']['index']
                editing_target_token_index = editing_tokens_info['prompt_2']['index']

            if editing_source_token_index:
                source_ca_index = editing_source_token_index
                target_ca_index = None
            elif editing_target_token_index:
                source_ca_index = None
                target_ca_index = editing_target_token_index
            else:
                source_ca_index = None
                target_ca_index = None
                use_ca_mask = False
        else:
            source_ca_index = None
            target_ca_index = None
            use_ca_mask = False
    else:
        use_mask = False
        use_ca_mask = False
        source_ca_index = None
        target_ca_index = None

    if source_prompt:
        # Source prompt given: use I2T-CA injection.
        mappers, alphas = get_refinement_mapper(prompts, pipe.tokenizer_2, max_len=512)
        mappers = mappers.to(device=device)
        alphas = alphas.to(device=device, dtype=pipe.dtype)
        alphas = alphas[:, None, None, :]
        attn_adj_from = 1
    else:
        # No source prompt: skip I2T-CA injection.
        mappers = None
        alphas = None
        ca_steps = 0
        attn_adj_from = 3
    attn_controller = AttentionAdapter(
        ca_layer_list=ca_layer_list,
        sa_layer_list=sa_layer_list,
        ca_steps=ca_steps,
        sa_steps=sa_steps,
        method='replace_topk',
        topk=attn_topk,
        text_scale=text_scale,
        mappers=mappers,
        alphas=alphas,
        attn_adj_from=attn_adj_from,
        save_source_ca=source_ca_index is not None,
        save_target_ca=target_ca_index is not None,
    )
    attn_collector = AttnCollector(
        transformer=pipe.transformer,
        controller=attn_controller,
        attn_processor_class=NewFluxAttnProcessor2_0,
    )
    feature_controller = FeatureReplace(
        layer_list=feature_layer_list,
        feature_steps=feature_steps,
    )
    feature_collector = FeatureCollector(
        transformer=pipe.transformer,
        controller=feature_controller,
    )

    num_prompts = len(prompts)

    # Sample the initial noise and pack it into the FLUX latent layout.
    shape = (1, 16, 128, 128)
    generator = torch.Generator(device=device).manual_seed(seed)
    latents = randn_tensor(shape, device=device, generator=generator)
    latents = pipe._pack_latents(latents, *latents.shape)

    # Invert the source image with the original attention/transformer behaviour.
    attn_collector.restore_orig_attention()
    feature_collector.restore_orig_transformer()

    t0 = time.perf_counter()
    inv_latents = get_inversed_latent_list(
        pipe,
        source_img,
        random_noise=latents,
        num_inference_steps=num_inference_steps,
        backward_method="ode",
        use_prompt_for_inversion=False,
        guidance_scale_for_inversion=0,
        prompt_for_inversion='',
        flow_steps=flow_steps,
    )
    source_latents = inv_latents[::-1]
    target_latents = inv_latents[::-1]

    attn_collector.register_attention_control()
    feature_collector.register_transformer_control()
    callback_fn = CallbackAll(
        latents=source_latents,
        attn_collector=attn_collector,
        feature_collector=feature_collector,
        feature_inject_steps=feature_steps,
        mid_step_index=mid_step_index,
        step_start=step_start,
        use_mask=use_mask,
        use_ca_mask=use_ca_mask,
        source_ca_index=source_ca_index,
        target_ca_index=target_ca_index,
        mask_kwargs={'dilation': mask_dilation},
        mask_steps=mask_steps,
        mask=mask,
    )

    # Start the target branch from the inverted latent at `step_start` and the
    # source branch from the mid-step latent.
    init_latent = target_latents[step_start]
    init_latent = init_latent.repeat(num_prompts, 1, 1)
    init_latent[0] = source_latents[mid_step_index]

    os.makedirs(result_img_dir, exist_ok=True)

    pipe.scheduler = FlowMatchEulerDiscreteForwardScheduler.from_config(
        pipe.scheduler.config,
        step_start=step_start,
        margin_index_from_image=0,
    )

    attn_controller.reset()
    feature_controller.reset()
    attn_controller.text_scale = text_scale
    attn_controller.cur_step = step_start
    feature_controller.cur_step = step_start

    with torch.no_grad():
        images = pipe(
            prompts,
            latents=init_latent,
            num_images_per_prompt=1,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            callback_on_step_end=callback_fn,
            mid_step_index=mid_step_index,
            step_start=step_start,
            callback_on_step_end_tensor_inputs=['latents'],
        ).images

    t1 = time.perf_counter()
    print(f"Done in {t1 - t0:.1f}s.")
    source_img_path = os.path.join(result_img_dir, "source.png")
    source_img.save(source_img_path)

    final_image = input_image
    for i, img in enumerate(images[1:]):
        target_img_path = os.path.join(result_img_dir, f"target_{i}.png")
        img.save(target_img_path)
        final_image = img

    target_text_path = os.path.join(result_img_dir, "target_prompts.txt")
    with open(target_text_path, 'w') as file:
        file.write(target_prompt + '\n')

    source_text_path = os.path.join(result_img_dir, "source_prompt.txt")
    with open(source_text_path, 'w') as file:
        file.write(source_prompt + '\n')

    # Save a side-by-side overview: source image, reconstruction, and edited result.
    images = [source_img] + images
    fs = 3
    n = len(images)
    fig, ax = plt.subplots(1, n, figsize=(n * fs, 1 * fs))
    for i, img in enumerate(images):
        ax[i].imshow(img)
    ax[0].set_title('source')
    ax[1].set_title(source_prompt, fontsize=7)
    ax[2].set_title(target_prompt, fontsize=7)
    overall_img_path = os.path.join(result_img_dir, "overall.png")
    plt.savefig(overall_img_path, bbox_inches='tight')
    plt.close()

    mask_save_dir = os.path.join(result_img_dir, "mask")
    os.makedirs(mask_save_dir, exist_ok=True)
    if use_ca_mask:
        ca_mask_path = os.path.join(mask_save_dir, "mask_ca.png")
        mask_img = Image.fromarray((callback_fn.mask.cpu().float().numpy() * 255).astype(np.uint8)).convert('L')
        mask_img.save(ca_mask_path)

    # Free GPU memory and remove the temporary result files written above.
    del inv_latents
    del init_latent
    gc.collect()
    torch.cuda.empty_cache()
    shutil.rmtree(results_dir)

    return final_image, seed, gr.Button(visible=True)

MAX_SEED = np.iinfo(np.int32).max
def infer_example(input_image, target_prompt, source_prompt, seed, ca_steps, sa_steps, feature_steps, attn_topk, mask_image=None):
    # Wrapper for gr.Examples: drops the reuse-button update from infer's return value.
    img, seed, _ = infer(
        input_image=input_image,
        target_prompt=target_prompt,
        source_prompt=source_prompt,
        seed=seed,
        ca_steps=ca_steps,
        sa_steps=sa_steps,
        feature_steps=feature_steps,
        attn_topk=attn_topk,
        mask_image=mask_image,
    )
    return img, seed
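
# Gradio UI: image/mask upload, prompts, advanced sliders, examples, and event wiring.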
with gr.Blocks() as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""# ReFlex
Text-Guided Editing of Real Images in Rectified Flow via Mid-Step Feature Extraction and Attention Adaptation
[[blog]](https://wlaud1001.github.io/ReFlex/) | [[Github]](https://github.com/wlaud1001/ReFlex)
""")
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Upload the image for editing", type="pil")
                mask_image = gr.Image(label="Upload optional mask", type="pil")
                with gr.Row():
                    target_prompt = gr.Text(
                        label="Target Prompt",
                        show_label=False,
                        max_lines=1,
                        placeholder="Describe the edited image",
                        container=False,
                    )
                    with gr.Column():
                        source_prompt = gr.Text(
                            label="Source Prompt",
                            show_label=False,
                            max_lines=1,
                            placeholder="Enter source prompt (optional): describe the input image",
                            container=False,
                        )
                    run_button = gr.Button("Run", scale=10)
                with gr.Accordion("Advanced Settings", open=False):
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )
                    ca_steps = gr.Slider(
                        label="Cross-Attn (CA) Steps",
                        minimum=0,
                        maximum=20,
                        step=1,
                        value=10,
                    )
                    sa_steps = gr.Slider(
                        label="Self-Attn (SA) Steps",
                        minimum=0,
                        maximum=20,
                        step=1,
                        value=7,
                    )
                    feature_steps = gr.Slider(
                        label="Feature Injection Steps",
                        minimum=0,
                        maximum=20,
                        step=1,
                        value=5,
                    )
                    attn_topk = gr.Slider(
                        label="Attention Top-K",
                        minimum=1,
                        maximum=64,
                        step=1,
                        value=20,
                    )
            with gr.Column():
                result = gr.Image(label="Result", show_label=False, interactive=False)
                reuse_button = gr.Button("Reuse this image", visible=False)
    # Click-to-run examples. The first four edit without a source prompt, the next four
    # add a source prompt, and the last three also provide a ground-truth mask.
    examples = gr.Examples(
        examples=[
            [
                "data/images/bear.jpeg",
                "an image of Paddington the bear",
                "",
                0, 0, 12, 7, 20,
                None
            ],
            [
                "data/images/bird_painting.jpg",
                "a photo of an eagle in the sky",
                "",
                0, 0, 12, 7, 20,
                None
            ],
            [
                "data/images/dancing.jpeg",
                "a couple of silver robots dancing in the garden",
                "",
                0, 0, 12, 7, 20,
                None
            ],
            [
                "data/images/real_karate.jpeg",
                "a silver robot in the snow",
                "",
                0, 0, 12, 7, 20,
                None
            ],
            [
                "data/images/woman_book.jpg",
                "a woman sitting in the grass with a laptop",
                "a woman sitting in the grass with a book",
                0, 10, 7, 5, 20,
                None
            ],
            [
                "data/images/statue.jpg",
                "photo of a statue in side view",
                "photo of a statue in front view",
                0, 10, 7, 5, 60,
                None
            ],
            [
                "data/images/tennis.jpg",
                "a iron woman robot in a black tank top and pink shorts is about to hit a tennis ball",
                "a woman in a black tank top and pink shorts is about to hit a tennis ball",
                0, 10, 7, 5, 20,
                None
            ],
            [
                "data/images/owl_heart.jpg",
                "a cartoon painting of a cute owl with a circle on its body",
                "a cartoon painting of a cute owl with a heart on its body",
                0, 10, 7, 5, 20,
                None
            ],
            [
                "data/images/girl_mountain.jpg",
                "a woman with her arms outstretched in front of the NewYork",
                "a woman with her arms outstretched on top of a mountain",
                0, 10, 7, 5, 20,
                "data/masks/girl_mountain.jpg"
            ],
            [
                "data/images/santa.jpg",
                "the christmas illustration of a santa's angry face",
                "the christmas illustration of a santa's laughing face",
                0, 10, 7, 5, 20,
                "data/masks/santa.jpg"
            ],
            [
                "data/images/cat_mirror.jpg",
                "a tiger sitting next to a mirror",
                "a cat sitting next to a mirror",
                0, 10, 7, 5, 20,
                "data/masks/cat_mirror.jpg"
            ],
        ],
        inputs=[
            input_image,
            target_prompt,
            source_prompt,
            seed,
            ca_steps,
            sa_steps,
            feature_steps,
            attn_topk,
            mask_image,
        ],
        outputs=[result, seed],
        fn=infer_example,
        cache_examples="lazy",
    )
    # Run the edit on button click or when the target prompt is submitted.
    gr.on(
        triggers=[run_button.click, target_prompt.submit],
        fn=infer,
        inputs=[
            input_image,
            target_prompt,
            source_prompt,
            seed,
            ca_steps,
            sa_steps,
            feature_steps,
            attn_topk,
            mask_image,
        ],
        outputs=[result, seed, reuse_button],
    )
    # Send the edited result back to the input slot for iterative editing.
    reuse_button.click(
        fn=lambda image: image,
        inputs=[result],
        outputs=[input_image],
    )
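
# Note (assumption): when hosted on Hugging Face Spaces, Gradio serves the app directly and
# does not create a share link; share=True and debug=True mainly matter for local runs.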
demo.launch(share=True, debug=True)