| 
							 | 
						from ml_collections import config_dict | 
					
					
						
						| 
							 | 
						import yaml | 
					
					
						
						| 
							 | 
						from diffusers.schedulers import ( | 
					
					
						
						| 
							 | 
						    DDIMScheduler, | 
					
					
						
						| 
							 | 
						    EulerAncestralDiscreteScheduler, | 
					
					
						
						| 
							 | 
						    EulerDiscreteScheduler, | 
					
					
						
						| 
							 | 
						    DDPMScheduler, | 
					
					
						
						| 
							 | 
						) | 
					
					
						
						| 
							 | 
						from inversion_utils import ( | 
					
					
						
						| 
							 | 
						    deterministic_ddim_step, | 
					
					
						
						| 
							 | 
						    deterministic_ddpm_step, | 
					
					
						
						| 
							 | 
						    deterministic_euler_step, | 
					
					
						
						| 
							 | 
						    deterministic_non_ancestral_euler_step, | 
					
					
						
						| 
							 | 
						) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						BREAKDOWNS = ["x_t_c_hat", "x_t_hat_c", "no_breakdown", "x_t_hat_c_with_zeros"] | 
					
					
						
						| 
							 | 
						SCHEDULERS = ["ddpm", "ddim", "euler", "euler_non_ancestral"] | 
					
					
						
						| 
							 | 
						MODELS = [ | 
					
					
						
						| 
							 | 
						    "stabilityai/sdxl-turbo", | 
					
					
						
						| 
							 | 
						    "stabilityai/stable-diffusion-xl-base-1.0", | 
					
					
						
						| 
							 | 
						    "CompVis/stable-diffusion-v1-4", | 
					
					
						
						| 
							 | 
						] | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def get_num_steps_actual(cfg): | 
					
					
						
						| 
							 | 
						    return ( | 
					
					
						
						| 
							 | 
						        cfg.num_steps_inversion | 
					
					
						
						| 
							 | 
						        - cfg.step_start | 
					
					
						
						| 
							 | 
						        + (1 if cfg.clean_step_timestep > 0 else 0) | 
					
					
						
						| 
							 | 
						        if cfg.timesteps is None | 
					
					
						
						| 
							 | 
						        else len(cfg.timesteps) + (1 if cfg.clean_step_timestep > 0 else 0) | 
					
					
						
						| 
							 | 
						    ) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def get_config(args): | 
					
					
						
						| 
							 | 
						    if args.config_from_file and args.config_from_file != "": | 
					
					
						
						| 
							 | 
						        with open(args.config_from_file, "r") as f: | 
					
					
						
						| 
							 | 
						            cfg = config_dict.ConfigDict(yaml.safe_load(f)) | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        num_steps_actual = get_num_steps_actual(cfg) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    else: | 
					
					
						
						| 
							 | 
						        cfg = config_dict.ConfigDict() | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        cfg.seed = 2 | 
					
					
						
						| 
							 | 
						        cfg.self_r = 0.5 | 
					
					
						
						| 
							 | 
						        cfg.cross_r = 0.9 | 
					
					
						
						| 
							 | 
						        cfg.eta = 1 | 
					
					
						
						| 
							 | 
						        cfg.scheduler_type = SCHEDULERS[0] | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        cfg.num_steps_inversion = 50   | 
					
					
						
						| 
							 | 
						        cfg.step_start = 20 | 
					
					
						
						| 
							 | 
						        cfg.timesteps = None | 
					
					
						
						| 
							 | 
						        cfg.noise_timesteps = None | 
					
					
						
						| 
							 | 
						        num_steps_actual = get_num_steps_actual(cfg) | 
					
					
						
						| 
							 | 
						        cfg.ws1 = [2] * num_steps_actual | 
					
					
						
						| 
							 | 
						        cfg.ws2 = [1] * num_steps_actual | 
					
					
						
						| 
							 | 
						        cfg.real_cfg_scale = 0 | 
					
					
						
						| 
							 | 
						        cfg.real_cfg_scale_save = 0 | 
					
					
						
						| 
							 | 
						        cfg.breakdown = BREAKDOWNS[1] | 
					
					
						
						| 
							 | 
						        cfg.noise_shift_delta = 1 | 
					
					
						
						| 
							 | 
						        cfg.max_norm_zs = [-1] * (num_steps_actual - 1) + [15.5] | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        cfg.clean_step_timestep = 0 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        cfg.model = MODELS[1] | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    if cfg.scheduler_type == "ddim": | 
					
					
						
						| 
							 | 
						        cfg.scheduler_class = DDIMScheduler | 
					
					
						
						| 
							 | 
						        cfg.step_function = deterministic_ddim_step | 
					
					
						
						| 
							 | 
						    elif cfg.scheduler_type == "ddpm": | 
					
					
						
						| 
							 | 
						        cfg.scheduler_class = DDPMScheduler | 
					
					
						
						| 
							 | 
						        cfg.step_function = deterministic_ddpm_step | 
					
					
						
						| 
							 | 
						    elif cfg.scheduler_type == "euler": | 
					
					
						
						| 
							 | 
						        cfg.scheduler_class = EulerAncestralDiscreteScheduler | 
					
					
						
						| 
							 | 
						        cfg.step_function = deterministic_euler_step | 
					
					
						
						| 
							 | 
						    elif cfg.scheduler_type == "euler_non_ancestral": | 
					
					
						
						| 
							 | 
						        cfg.scheduler_class = EulerDiscreteScheduler | 
					
					
						
						| 
							 | 
						        cfg.step_function = deterministic_non_ancestral_euler_step | 
					
					
						
						| 
							 | 
						    else: | 
					
					
						
						| 
							 | 
						        raise ValueError(f"Unknown scheduler type: {cfg.scheduler_type}") | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    with cfg.ignore_type(): | 
					
					
						
						| 
							 | 
						        if isinstance(cfg.max_norm_zs, (int, float)): | 
					
					
						
						| 
							 | 
						            cfg.max_norm_zs = [cfg.max_norm_zs] * num_steps_actual | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        if isinstance(cfg.ws1, (int, float)): | 
					
					
						
						| 
							 | 
						            cfg.ws1 = [cfg.ws1] * num_steps_actual | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        if isinstance(cfg.ws2, (int, float)): | 
					
					
						
						| 
							 | 
						            cfg.ws2 = [cfg.ws2] * num_steps_actual | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    if not hasattr(cfg, "update_eta"): | 
					
					
						
						| 
							 | 
						        cfg.update_eta = False | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    if not hasattr(cfg, "save_timesteps"): | 
					
					
						
						| 
							 | 
						        cfg.save_timesteps = None | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    if not hasattr(cfg, "scheduler_timesteps"): | 
					
					
						
						| 
							 | 
						        cfg.scheduler_timesteps = None | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    assert ( | 
					
					
						
						| 
							 | 
						        cfg.scheduler_type == "ddpm" or cfg.timesteps is None | 
					
					
						
						| 
							 | 
						    ), "timesteps must be None for ddim/euler" | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    cfg.max_norm_zs = [-1] * (num_steps_actual - 1) + [15.5] | 
					
					
						
						| 
							 | 
						    assert ( | 
					
					
						
						| 
							 | 
						        len(cfg.max_norm_zs) == num_steps_actual | 
					
					
						
						| 
							 | 
						    ), f"len(cfg.max_norm_zs) ({len(cfg.max_norm_zs)}) != num_steps_actual ({num_steps_actual})" | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    assert ( | 
					
					
						
						| 
							 | 
						        len(cfg.ws1) == num_steps_actual | 
					
					
						
						| 
							 | 
						    ), f"len(cfg.ws1) ({len(cfg.ws1)}) != num_steps_actual ({num_steps_actual})" | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    assert ( | 
					
					
						
						| 
							 | 
						        len(cfg.ws2) == num_steps_actual | 
					
					
						
						| 
							 | 
						    ), f"len(cfg.ws2) ({len(cfg.ws2)}) != num_steps_actual ({num_steps_actual})" | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    assert cfg.noise_timesteps is None or len(cfg.noise_timesteps) == ( | 
					
					
						
						| 
							 | 
						        num_steps_actual - (1 if cfg.clean_step_timestep > 0 else 0) | 
					
					
						
						| 
							 | 
						    ), f"len(cfg.noise_timesteps) ({len(cfg.noise_timesteps)}) != num_steps_actual ({num_steps_actual})" | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    assert cfg.save_timesteps is None or len(cfg.save_timesteps) == ( | 
					
					
						
						| 
							 | 
						        num_steps_actual - (1 if cfg.clean_step_timestep > 0 else 0) | 
					
					
						
						| 
							 | 
						    ), f"len(cfg.save_timesteps) ({len(cfg.save_timesteps)}) != num_steps_actual ({num_steps_actual})" | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    return cfg | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def get_config_name(config, args): | 
					
					
						
						| 
							 | 
						    if args.folder_name is not None and args.folder_name != "": | 
					
					
						
						| 
							 | 
						        return args.folder_name | 
					
					
						
						| 
							 | 
						    timesteps_str = ( | 
					
					
						
						| 
							 | 
						        f"step_start {config.step_start}" | 
					
					
						
						| 
							 | 
						        if config.timesteps is None | 
					
					
						
						| 
							 | 
						        else f"timesteps {config.timesteps}" | 
					
					
						
						| 
							 | 
						    ) | 
					
					
						
						| 
							 | 
						    return f"""\ | 
					
					
						
						| 
							 | 
						ws1 {config.ws1[0]} ws2 {config.ws2[0]} real_cfg_scale {config.real_cfg_scale} {timesteps_str} \ | 
					
					
						
						| 
							 | 
						real_cfg_scale_save {config.real_cfg_scale_save} seed {config.seed} max_norm_zs {config.max_norm_zs[-1]} noise_shift_delta {config.noise_shift_delta} \ | 
					
					
						
						| 
							 | 
						scheduler_type {config.scheduler_type} fp16 {args.fp16}\ | 
					
					
						
						| 
							 | 
						""" | 
					
					
						
						| 
							 | 
						
 |