from typing import Optional, List

import numpy as np
import torch
from cv2 import dilate
from diffusers import DDIMScheduler, StableDiffusionPipeline
from tqdm import tqdm

from src.attention_based_segmentation import Segmentor
from src.attention_utils import show_cross_attention
from src.prompt_to_prompt_controllers import DummyController, AttentionStore

def get_stable_diffusion_model(args):
    device = torch.device(f'cuda:{args.gpu_id}') if torch.cuda.is_available() else torch.device('cpu')
    if args.real_image_path != "":
        # Editing a real image requires deterministic, invertible sampling, so swap in DDIM.
        scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
                                  clip_sample=False, set_alpha_to_one=False)
        ldm_stable = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4",
                                                             use_auth_token=args.auth_token,
                                                             scheduler=scheduler).to(device)
    else:
        ldm_stable = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4",
                                                             use_auth_token=args.auth_token).to(device)
    return ldm_stable

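# Usage sketch (illustrative; the `args` fields mirror the argparse options
# referenced above). An empty `real_image_path` keeps the pipeline's default
# scheduler; a non-empty path swaps in DDIM for real-image inversion:
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(gpu_id=0, real_image_path="", auth_token=True)
#   ldm_stable = get_stable_diffusion_model(args)
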
def get_stable_diffusion_config(args):
    return {
        "low_resource": args.low_resource,
        "num_diffusion_steps": args.num_diffusion_steps,
        "guidance_scale": args.guidance_scale,
        "max_num_words": args.max_num_words
    }

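# The returned dict is the `model_config` consumed by DiffusionModelWrapper.
# Illustrative values (7.5 guidance and 50 steps are the usual Stable
# Diffusion defaults):
#
#   {"low_resource": False, "num_diffusion_steps": 50,
#    "guidance_scale": 7.5, "max_num_words": 77}
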
def generate_original_image(args, ldm_stable, ldm_stable_config, prompts, latent, uncond_embeddings):
    g_cpu = torch.Generator(device=ldm_stable.device).manual_seed(args.seed)
    controller = AttentionStore(ldm_stable_config["low_resource"])
    diffusion_model_wrapper = DiffusionModelWrapper(args, ldm_stable, ldm_stable_config, controller,
                                                    generator=g_cpu)
    image, x_t, orig_all_latents, _ = diffusion_model_wrapper.forward(prompts,
                                                                      latent=latent,
                                                                      uncond_embeddings=uncond_embeddings)
    # Segment the stored attention maps into a background mask for the edited word.
    # The "+ 1" offsets past the tokenizer's start-of-text token.
    orig_mask = Segmentor(controller, prompts, args.num_segments, args.background_segment_threshold,
                          background_nouns=args.background_nouns)\
        .get_background_mask(args.prompt.split(' ').index("{word}") + 1)
    average_attention = controller.get_average_attention()
    return image, x_t, orig_all_latents, orig_mask, average_attention

class DiffusionModelWrapper:
    def __init__(self, args, model, model_config, controller=None, prompt_mixing=None, generator=None):
        self.args = args
        self.model = model
        self.model_config = model_config
        self.controller = controller
        if self.controller is None:
            self.controller = DummyController()
        self.prompt_mixing = prompt_mixing
        self.device = model.device
        self.generator = generator
        self.height = 512
        self.width = 512
        self.diff_step = 0
        # Hook every cross-attention layer in the UNet so the controller sees each attention map.
        self.register_attention_control()

    def diffusion_step(self, latents, context, t, other_context=None):
        if self.model_config["low_resource"]:
            # Run the unconditional and text-conditioned passes sequentially to halve peak memory.
            # encoder_hidden_states is a (context, other_context) tuple, unpacked by the
            # hooked attention forward registered below.
            self.uncond_pred = True
            noise_pred_uncond = self.model.unet(latents, t, encoder_hidden_states=(context[0], None))["sample"]
            self.uncond_pred = False
            noise_prediction_text = self.model.unet(latents, t, encoder_hidden_states=(context[1], other_context))["sample"]
        else:
            latents_input = torch.cat([latents] * 2)
            noise_pred = self.model.unet(latents_input, t, encoder_hidden_states=(context, other_context))["sample"]
            noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
        # Classifier-free guidance: push the prediction along the text-conditioned direction.
        noise_pred = noise_pred_uncond + self.model_config["guidance_scale"] * (noise_prediction_text - noise_pred_uncond)
        latents = self.model.scheduler.step(noise_pred, t, latents)["prev_sample"]
        latents = self.controller.step_callback(latents)
        return latents

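    # A self-contained sketch of the guidance combine above, on dummy tensors
    # (shapes are illustrative, not the UNet's real output shape):
    #
    #   eps_uncond = torch.zeros(1, 4, 64, 64)
    #   eps_text = torch.ones(1, 4, 64, 64)
    #   eps = eps_uncond + 7.5 * (eps_text - eps_uncond)  # 7.5 everywhere
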
    def latent2image(self, latents):
        # 0.18215 is the SD v1 VAE scaling factor; undo it before decoding.
        latents = 1 / 0.18215 * latents
        image = self.model.vae.decode(latents)['sample']
        image = (image / 2 + 0.5).clamp(0, 1)  # [-1, 1] -> [0, 1]
        image = image.cpu().permute(0, 2, 3, 1).numpy()  # NCHW -> NHWC
        image = (image * 255).astype(np.uint8)
        return image

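    # The result is a uint8 NHWC array, so each item converts straight to PIL
    # (sketch, assuming Pillow is installed):
    #
    #   from PIL import Image
    #   Image.fromarray(image[0]).save("sample.png")
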
    def init_latent(self, latent, batch_size):
        if latent is None:
            latent = torch.randn(
                (1, self.model.unet.in_channels, self.height // 8, self.width // 8),
                generator=self.generator, device=self.model.device
            )
        # Every prompt in the batch starts from the same initial noise.
        latents = latent.expand(batch_size, self.model.unet.in_channels, self.height // 8, self.width // 8).to(self.device)
        return latent, latents

    def register_attention_control(self):
        def ca_forward(model_self, place_in_unet):
            # In newer diffusers, to_out is a ModuleList [Linear, Dropout]; take the Linear.
            to_out = model_self.to_out
            if type(to_out) is torch.nn.modules.container.ModuleList:
                to_out = model_self.to_out[0]

            def forward(x, context=None, mask=None):
                batch_size, sequence_length, dim = x.shape
                h = model_self.heads
                q = model_self.to_q(x)
                is_cross = context is not None
                context = context if is_cross else (x, None)
                k = model_self.to_k(context[0])
                if is_cross and self.prompt_mixing is not None:
                    # Prompt mixing may substitute the context that produces the values V.
                    v_context = self.prompt_mixing.get_context_for_v(self.diff_step, context[0], context[1])
                    v = model_self.to_v(v_context)
                else:
                    v = model_self.to_v(context[0])
                q = model_self.reshape_heads_to_batch_dim(q)
                k = model_self.reshape_heads_to_batch_dim(k)
                v = model_self.reshape_heads_to_batch_dim(v)
                sim = torch.einsum("b i d, b j d -> b i j", q, k) * model_self.scale
                if mask is not None:
                    mask = mask.reshape(batch_size, -1)
                    max_neg_value = -torch.finfo(sim.dtype).max
                    mask = mask[:, None, :].repeat(h, 1, 1)
                    sim.masked_fill_(~mask, max_neg_value)
                # attention, what we cannot get enough of
                attn = sim.softmax(dim=-1)
                if self.enable_attn_controller_changes:
                    attn = self.controller(attn, is_cross, place_in_unet)
                if is_cross and self.prompt_mixing is not None and context[1] is not None:
                    attn = self.prompt_mixing.get_cross_attn(self, self.diff_step, attn, place_in_unet, batch_size)
                if not is_cross and (not self.model_config["low_resource"] or not self.uncond_pred) and self.prompt_mixing is not None:
                    attn = self.prompt_mixing.get_self_attn(self, self.diff_step, attn, place_in_unet, batch_size)
                out = torch.einsum("b i j, b j d -> b i d", attn, v)
                out = model_self.reshape_batch_dim_to_heads(out)
                return to_out(out)

            return forward

        def register_recr(net_, count, place_in_unet):
            if net_.__class__.__name__ == 'CrossAttention':
                net_.forward = ca_forward(net_, place_in_unet)
                return count + 1
            elif hasattr(net_, 'children'):
                for net__ in net_.children():
                    count = register_recr(net__, count, place_in_unet)
            return count

        cross_att_count = 0
        sub_nets = self.model.unet.named_children()
        for net in sub_nets:
            if "down" in net[0]:
                cross_att_count += register_recr(net[1], 0, "down")
            elif "up" in net[0]:
                cross_att_count += register_recr(net[1], 0, "up")
            elif "mid" in net[0]:
                cross_att_count += register_recr(net[1], 0, "mid")
        self.controller.num_att_layers = cross_att_count

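    # Any object with a writable `num_att_layers`, a `step_callback`, and a
    # `__call__(attn, is_cross, place_in_unet)` returning an attention map can
    # serve as a controller here. A minimal counting controller, as a sketch
    # built on DummyController:
    #
    #   class CountingController(DummyController):
    #       def __init__(self):
    #           super().__init__()
    #           self.calls = 0
    #
    #       def __call__(self, attn, is_cross, place_in_unet):
    #           self.calls += 1
    #           return attn
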
    def get_text_embedding(self, prompt: List[str], max_length=None, truncation=True):
        text_input = self.model.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.model.tokenizer.model_max_length if max_length is None else max_length,
            truncation=truncation,
            return_tensors="pt",
        )
        text_embeddings = self.model.text_encoder(text_input.input_ids.to(self.device))[0]
        max_length = text_input.input_ids.shape[-1]
        return text_embeddings, max_length

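    # Sketch: with SD v1.4's CLIP text encoder, a single prompt pads to the
    # tokenizer's model_max_length of 77 and embeds to shape (1, 77, 768)
    # (`wrapper` is a hypothetical DiffusionModelWrapper instance):
    #
    #   emb, max_len = wrapper.get_text_embedding(["a photo of a cat"])
    #   # emb.shape == (1, 77, 768); max_len == 77
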
    def forward(self, prompt: List[str], latent: Optional[torch.FloatTensor] = None,
                other_prompt: List[str] = None, post_background=False, orig_all_latents=None, orig_mask=None,
                uncond_embeddings=None, start_time=51, return_type='image'):
        self.enable_attn_controller_changes = True
        batch_size = len(prompt)
        text_embeddings, max_length = self.get_text_embedding(prompt)
        if uncond_embeddings is None:
            uncond_embeddings_, _ = self.get_text_embedding([""] * batch_size, max_length=max_length, truncation=False)
        else:
            # Per-step embeddings (e.g. from null-text inversion) are indexed inside the loop.
            uncond_embeddings_ = None

        other_context = None
        if other_prompt is not None:
            other_text_embeddings, _ = self.get_text_embedding(other_prompt)
            other_context = other_text_embeddings

        latent, latents = self.init_latent(latent, batch_size)

        # Set timesteps; the default start_time=51 covers the full 50-step schedule.
        self.model.scheduler.set_timesteps(self.model_config["num_diffusion_steps"])
        all_latents = []
        object_mask = None
        self.diff_step = 0
        for i, t in enumerate(tqdm(self.model.scheduler.timesteps[-start_time:])):
            if uncond_embeddings_ is None:
                context = [uncond_embeddings[i].expand(*text_embeddings.shape), text_embeddings]
            else:
                context = [uncond_embeddings_, text_embeddings]
            if not self.model_config["low_resource"]:
                context = torch.cat(context)

            self.down_cross_index = 0
            self.mid_cross_index = 0
            self.up_cross_index = 0
            latents = self.diffusion_step(latents, context, t, other_context)

            if post_background and self.diff_step == self.args.background_blend_timestep:
                # Restore the original background: segment the current attention maps,
                # take the union with the original mask, and copy the original latents
                # everywhere outside it.
                object_mask = Segmentor(self.controller,
                                        prompt,
                                        self.args.num_segments,
                                        self.args.background_segment_threshold,
                                        background_nouns=self.args.background_nouns)\
                    .get_background_mask(self.args.prompt.split(' ').index("{word}") + 1)
                self.enable_attn_controller_changes = False
                # np.bool8 was removed in recent NumPy, so cast with plain bool;
                # "+" on boolean arrays is a union (logical OR).
                mask = object_mask.astype(bool) + orig_mask.astype(bool)
                mask = torch.from_numpy(mask).float().cuda()
                shape = (1, 1, mask.shape[0], mask.shape[1])
                mask = torch.nn.Upsample(size=(64, 64), mode='nearest')(mask.view(shape))
                # Dilate slightly so the blend seam falls just outside the object.
                mask_dilated = dilate(mask.cpu().numpy()[0, 0], np.ones((3, 3), np.uint8), iterations=1)
                mask = torch.from_numpy(mask_dilated).float().cuda().view(1, 1, 64, 64)
                latents = mask * latents + (1 - mask) * orig_all_latents[self.diff_step]

            all_latents.append(latents)
            self.diff_step += 1

        if return_type == 'image':
            image = self.latent2image(latents)
        else:
            image = latents
        return image, latent, all_latents, object_mask

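    # Sketch of the background-preserving second pass (illustrative; `x_t`,
    # `orig_all_latents`, and `orig_mask` come from a prior
    # generate_original_image() run with the same seed and config):
    #
    #   wrapper = DiffusionModelWrapper(args, ldm_stable, config, controller)
    #   image, _, _, object_mask = wrapper.forward(
    #       prompts, latent=x_t, post_background=True,
    #       orig_all_latents=orig_all_latents, orig_mask=orig_mask)
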
    def show_last_cross_attention(self, res: int, from_where: List[str], prompts, select: int = 0):
        show_cross_attention(self.controller, res, from_where, prompts, tokenizer=self.model.tokenizer, select=select)
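

if __name__ == "__main__":
    # Minimal end-to-end sketch: generate an image and its background mask.
    # Assumes a CUDA GPU and Hugging Face access to CompVis/stable-diffusion-v1-4;
    # all values below are illustrative, and the field names mirror the argparse
    # options referenced throughout this file.
    from types import SimpleNamespace

    args = SimpleNamespace(
        gpu_id=0,
        real_image_path="",
        auth_token=True,
        low_resource=False,
        num_diffusion_steps=50,
        guidance_scale=7.5,
        max_num_words=77,
        seed=42,
        prompt="a photo of a {word} on a table",
        num_segments=5,
        background_segment_threshold=0.35,
        background_nouns=["table"],
    )
    ldm_stable = get_stable_diffusion_model(args)
    config = get_stable_diffusion_config(args)
    prompts = [args.prompt.format(word="cat")]
    image, x_t, orig_all_latents, orig_mask, _ = generate_original_image(
        args, ldm_stable, config, prompts, latent=None, uncond_embeddings=None)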