import torch

from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img import (
    retrieve_timesteps,
    retrieve_latents,
)
from PIL import Image
from inversion_utils import get_ddpm_inversion_scheduler, create_xts
from config import get_config, get_num_steps_actual
from functools import partial
from compel import Compel, ReturnedEmbeddingsType

from model_handler import MODELS
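
# DDPM-inversion-based image editing on top of an SDXL img2img pipeline: the
# source image is inverted into a chain of noisy latents (x_ts), the scheduler
# is wrapped so denoising replays that chain, and a batched prompt
# [src, src, tgt] produces a reconstruction alongside the edited image.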
					
						
class Object(object):
    pass


# Plain attribute holder used in place of parsed command-line arguments.
args = Object()
args.images_paths = None
args.images_folder = None
args.force_use_cpu = False
args.folder_name = 'test_measure_time'
args.config_from_file = 'run_configs/noise_shift_guidance_1_5.yaml'
args.save_intermediate_results = False
args.batch_size = None
args.skip_p_to_p = True
args.only_p_to_p = False
args.fp16 = False
args.prompts_file = 'dataset_measure_time/dataset.json'
args.images_in_prompts_file = None
args.seed = 986
args.time_measure_n = 1

assert (
    args.batch_size is None or args.save_intermediate_results is False
), "save_intermediate_results is not implemented for batch_size > 1"

generator = None
device = "cuda"

pipeline = MODELS.base_pipe

config = get_config(args)
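
# Compel builds SDXL's dual-encoder prompt embeddings: token-level conditioning
# from both text encoders plus the pooled embedding from the second encoder.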
					
						
compel_proc = Compel(
    tokenizer=[pipeline.tokenizer, pipeline.tokenizer_2],
    text_encoder=[pipeline.text_encoder, pipeline.text_encoder_2],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=[False, True],
)


def run(
    input_image: Image.Image,
    src_prompt: str,
    tgt_prompt: str,
    seed: int,
    w1: float,
    w2: float,
    num_steps: int,
    start_step: int,
    guidance_scale: float,
):
    generator = torch.Generator().manual_seed(seed)

    config.num_steps_inversion = num_steps
    config.step_start = start_step
    num_steps_actual = get_num_steps_actual(config)
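
    # Build the truncated timestep schedule: the first `start_step` steps are
    # skipped, so denoising starts part-way through the full schedule.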
					
						
    num_steps_inversion = config.num_steps_inversion
    denoising_start = (num_steps_inversion - num_steps_actual) / num_steps_inversion

    timesteps, num_inference_steps = retrieve_timesteps(
        pipeline.scheduler, num_steps_inversion, device, None
    )
    timesteps, num_inference_steps = pipeline.get_timesteps(
        num_inference_steps=num_inference_steps,
        denoising_start=denoising_start,
        strength=0,
        device=device,
    )
    timesteps = timesteps.type(torch.int64)

    timesteps = [torch.tensor(t) for t in timesteps.tolist()]
    timesteps_len = len(timesteps)
    config.step_start = start_step + num_steps_actual - timesteps_len
    num_steps_actual = timesteps_len
    config.max_norm_zs = [-1] * (num_steps_actual - 1) + [15.5]
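
    # Freeze the sampling arguments on the pipeline so the final call below only
    # needs the image latent and the prompt embeddings.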
					
						
    pipeline.__call__ = partial(
        pipeline.__call__,
        num_inference_steps=num_steps_inversion,
        guidance_scale=guidance_scale,
        generator=generator,
        denoising_start=denoising_start,
        strength=0,
    )

    x_0_image = input_image
    x_0 = encode_image(x_0_image, pipeline)
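
    # Invert the source image: create the sequence of noisy latents x_ts from the
    # clean latent x_0, one per timestep in the truncated schedule.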
					
						
    x_ts = create_xts(1, None, 0, generator, pipeline.scheduler, timesteps, x_0, no_add_noise=False)
    x_ts = [xt.to(dtype=torch.float16) for xt in x_ts]
    latents = [x_ts[0]]
    x_ts_c_hat = [None]
    config.ws1 = [w1] * num_steps_actual
    config.ws2 = [w2] * num_steps_actual
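
    # Swap in the DDPM-inversion scheduler: each denoising step is steered by the
    # recorded latents so the source image can be reconstructed and edited.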
					
						
    pipeline.scheduler = get_ddpm_inversion_scheduler(
        pipeline.scheduler,
        config.step_function,
        config,
        timesteps,
        config.save_timesteps,
        latents,
        x_ts,
        x_ts_c_hat,
        args.save_intermediate_results,
        pipeline,
        x_0,
        v1s_images := [],
        v2s_images := [],
        deltas_images := [],
        v1_x0s := [],
        v2_x0s := [],
        deltas_x0s := [],
        "res12",
        image_name="im_name",
        time_measure_n=args.time_measure_n,
    )
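
    # Denoise a three-image batch with prompts [src, src, tgt]; the first two
    # reconstruct the source and the third is the edited result.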
					
						
    latent = latents[0].expand(3, -1, -1, -1)
    prompt = [src_prompt, src_prompt, tgt_prompt]
    conditioning, pooled = compel_proc(prompt)
    image = pipeline.__call__(
        image=latent,
        prompt_embeds=conditioning,
        pooled_prompt_embeds=pooled,
        eta=1,
    ).images
    return image[2]


def encode_image(image, pipe):
    image = pipe.image_processor.preprocess(image)
    originDtype = pipe.dtype
    image = image.to(device=device, dtype=originDtype)
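
    # The SDXL VAE may be configured to require fp32 (force_upcast); encode in
    # float32 and restore the original dtype afterwards.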
					
						
    if pipe.vae.config.force_upcast:
        image = image.float()
        pipe.vae.to(dtype=torch.float32)

    if isinstance(generator, list):
        init_latents = [
            retrieve_latents(pipe.vae.encode(image[i : i + 1]), generator=generator[i])
            for i in range(1)
        ]
        init_latents = torch.cat(init_latents, dim=0)
    else:
        init_latents = retrieve_latents(pipe.vae.encode(image), generator=generator)

    if pipe.vae.config.force_upcast:
        pipe.vae.to(originDtype)

    init_latents = init_latents.to(originDtype)
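
    # Scale by the VAE scaling factor so the latents follow the convention the
    # UNet expects.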
					
						
    init_latents = pipe.vae.config.scaling_factor * init_latents

    return init_latents.to(dtype=torch.float16)


def get_timesteps(pipe, num_inference_steps, strength, device, denoising_start=None):
    if denoising_start is None:
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
        t_start = max(num_inference_steps - init_timestep, 0)
    else:
        t_start = 0

    timesteps = pipe.scheduler.timesteps[t_start * pipe.scheduler.order :]

    if denoising_start is not None:
        discrete_timestep_cutoff = int(
            round(
                pipe.scheduler.config.num_train_timesteps
                - (denoising_start * pipe.scheduler.config.num_train_timesteps)
            )
        )

        num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
        if pipe.scheduler.order == 2 and num_inference_steps % 2 == 0:
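            # Second-order schedulers duplicate every timestep except the last, so an
            # even count would stop halfway through a duplicated timestep; bump it to
            # an odd value to stay aligned (mirrors the diffusers img2img pipeline).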
					
						
            num_inference_steps = num_inference_steps + 1

        timesteps = timesteps[-num_inference_steps:]
        return timesteps, num_inference_steps

    return timesteps, num_inference_steps - t_start
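

# Example invocation: a minimal sketch, assuming the pipeline has been loaded via
# model_handler. The image path, prompts, and parameter values below are
# illustrative placeholders, not values taken from this project.
if __name__ == "__main__":
    edited = run(
        input_image=Image.open("example.jpg").convert("RGB").resize((1024, 1024)),
        src_prompt="a photo of a cat sitting on a sofa",
        tgt_prompt="a photo of a dog sitting on a sofa",
        seed=args.seed,
        w1=1.5,
        w2=1.0,
        num_steps=4,
        start_step=1,
        guidance_scale=1.5,
    )
    edited.save("edited.png")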