jixin0101 committed
Commit 80a6063 · verified · 1 Parent(s): b0fe135

Speed up by storing cross-attention scores only at the last timestep

Files changed (1):
  1. pipeline_objectclear.py (+56 -37)
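In short: the recording hook previously installed in `__init__` ran on every denoising step; this commit installs it only for the final timestep and restores the original attention processors right after. A minimal, self-contained sketch of that pattern follows (the `TinyAttention` class and helper names are hypothetical stand-ins, not the pipeline's actual code):

import torch

class TinyAttention(torch.nn.Module):
    # Hypothetical stand-in for diffusers' Attention module.
    def get_attention_scores(self, query, key):
        return torch.softmax(query @ key.transpose(-1, -2), dim=-1)

def install_recording_hook(module, store, name):
    # Keep a handle to the original method so it can be restored later.
    module.old_get_attention_scores = module.get_attention_scores

    def recording(query, key):
        probs = module.old_get_attention_scores(query, key)
        store[name] = probs  # recorded only while the hook is installed
        return probs

    module.get_attention_scores = recording

def restore_hook(module):
    # Undo the patch so later calls take the original (fast) path.
    module.get_attention_scores = module.old_get_attention_scores
    del module.old_get_attention_scores

attn, store, timesteps = TinyAttention(), {}, range(10)
for i, t in enumerate(timesteps):
    if i == len(timesteps) - 1:  # hook only at the last step
        install_recording_hook(attn, store, "target_layer")
    attn.get_attention_scores(torch.randn(1, 4, 8), torch.randn(1, 4, 8))
restore_hook(attn)
assert list(store) == ["target_layer"]  # exactly one recorded map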
pipeline_objectclear.py CHANGED
@@ -14,6 +14,7 @@
 
 import inspect
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+import os
 
 import numpy as np
 import PIL.Image
@@ -58,8 +59,8 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusio
 from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
 from dataclasses import dataclass
 
-from model import CLIPImageEncoder, PostfuseModule
-from utils import attention_guided_fusion
+from ..models import CLIPImageEncoder, PostfuseModule
+from ..utils import attention_guided_fusion
 import gc
 import torch.nn.functional as F
 
@@ -334,6 +335,7 @@ def retrieve_timesteps(
 class ObjectClearPipelineOutput(StableDiffusionXLPipelineOutput):
     attns: Optional[List[PIL.Image.Image]] = None
 
+
 class ObjectClearPipeline(
     DiffusionPipeline,
     StableDiffusionMixin,
@@ -428,7 +430,7 @@ class ObjectClearPipeline(
         requires_aesthetics_score: bool = False,
         force_zeros_for_empty_prompt: bool = True,
         add_watermarker: Optional[bool] = None,
-        apply_attention_guided_fusion: bool = False,
+        apply_attention_guided_fusion: bool = True,
     ):
         super().__init__()
 
@@ -463,9 +465,7 @@ class ObjectClearPipeline(
 
         if self.config.apply_attention_guided_fusion:
             self.cross_attention_scores = {}
-            self.unet = self.unet_store_cross_attention_scores(
-                self.unet, self.cross_attention_scores
-            )
+            self.original_state = None
 
 
     @classmethod
@@ -486,14 +486,17 @@ class ObjectClearPipeline(
         )
 
         postfuse_module = PostfuseModule(embed_dim=2048, embed_dim_img=768)
+        sub_folder = "postfuse_module"
         filename = "model.safetensors"
-
-        safetensor_path = hf_hub_download(
-            repo_id="jixin0101/ObjectClear",
-            filename=filename,
-            subfolder="postfuse_module",
-            cache_dir=cache_dir
-        )
+        if pretrained_model_name_or_path == "jixin0101/ObjectClear":
+            safetensor_path = hf_hub_download(
+                repo_id="jixin0101/ObjectClear",
+                filename=filename,
+                subfolder="postfuse_module",
+                cache_dir=cache_dir
+            )
+        else:
+            safetensor_path = os.path.join(pretrained_model_name_or_path, sub_folder, filename)
         state_dict_postfuse = load_file(safetensor_path)
         postfuse_module.load_state_dict(state_dict_postfuse)
 
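The hunk above also decouples weight loading from the Hub: when `pretrained_model_name_or_path` is not the `jixin0101/ObjectClear` repo id, the postfuse weights are read from a local directory instead. The resolution logic in isolation (the wrapper function is a hypothetical illustration; `hf_hub_download` is the real `huggingface_hub` API):

import os
from huggingface_hub import hf_hub_download

def resolve_postfuse_checkpoint(pretrained_model_name_or_path, cache_dir=None):
    # Hub repo id -> download; anything else is treated as a local directory.
    sub_folder, filename = "postfuse_module", "model.safetensors"
    if pretrained_model_name_or_path == "jixin0101/ObjectClear":
        return hf_hub_download(
            repo_id="jixin0101/ObjectClear",
            filename=filename,
            subfolder=sub_folder,
            cache_dir=cache_dir,
        )
    return os.path.join(pretrained_model_name_or_path, sub_folder, filename)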
@@ -537,7 +540,7 @@ class ObjectClearPipeline(
 
         return image_embeds, uncond_image_embeds
 
-    def unet_store_cross_attention_scores(self, unet, attention_scores):
+    def unet_store_cross_attention_scores(self, unet, attention_scores, applicable_layers=None):
         from diffusers.models.attention_processor import (
             Attention,
             AttnProcessor,
@@ -545,34 +548,25 @@ class ObjectClearPipeline(
         )
         import types
 
-        UNET_LAYER_NAMES = [
-            "down_blocks.0",
-            "down_blocks.1",
-            "down_blocks.2",
-            "mid_block",
-            "up_blocks.1",
-            "up_blocks.2",
-            "up_blocks.3",
-        ]
-
-        start_layer = 0
-        end_layer = 2
-        applicable_layers = UNET_LAYER_NAMES[start_layer:end_layer]
+        TARGET_LAYER = "down_blocks.1.attentions.0.transformer_blocks.0.attn2"
+        original_state = {}
 
         def make_new_get_attention_scores_fn(name):
             def new_get_attention_scores(module, query, key, attention_mask=None):
                 attention_probs = module.old_get_attention_scores(
                     query, key, attention_mask
                 )
-                attention_scores[name] = attention_probs
+                if name == TARGET_LAYER:
+                    attention_scores[name] = attention_probs
                 return attention_probs
-
             return new_get_attention_scores
 
         for name, module in unet.named_modules():
-            if isinstance(module, Attention) and "attn2" in name:
-                if not any(layer in name for layer in applicable_layers):
-                    continue
+            if isinstance(module, Attention) and name == TARGET_LAYER and "attn2" in name:
+                original_state[name] = {
+                    "processor": module.processor,
+                    "get_attention_scores": module.get_attention_scores
+                }
                 if isinstance(module.processor, AttnProcessor2_0):
                     module.set_processor(AttnProcessor())
                     module.old_get_attention_scores = module.get_attention_scores
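The selection is also narrowed: the old code hooked every cross-attention (`attn2`) module under `down_blocks.0` and `down_blocks.1`, while the new code hooks exactly one layer and snapshots its processor and method for later restoration. (The extra `"attn2" in name` check is redundant once `name == TARGET_LAYER`, but harmless.) A rough sketch of the narrowed filter, with illustrative module names:

TARGET_LAYER = "down_blocks.1.attentions.0.transformer_blocks.0.attn2"

def old_selection(names):
    applicable = ["down_blocks.0", "down_blocks.1"]  # UNET_LAYER_NAMES[0:2]
    return [n for n in names if "attn2" in n and any(l in n for l in applicable)]

def new_selection(names):
    return [n for n in names if n == TARGET_LAYER]

names = [
    "down_blocks.1.attentions.0.transformer_blocks.0.attn1",
    "down_blocks.1.attentions.0.transformer_blocks.0.attn2",
    "down_blocks.1.attentions.0.transformer_blocks.1.attn2",
    "down_blocks.2.attentions.0.transformer_blocks.0.attn2",
]
print(old_selection(names))  # every attn2 under down_blocks.0/1 stored scores each step
print(new_selection(names))  # exactly one layer, hooked only at the last step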
@@ -581,6 +575,19 @@ class ObjectClearPipeline(
                     )
                     module.get_attention_scores = module.new_get_attention_scores
 
+        return unet, original_state
+
+    def unet_restore_attention_processor(self, unet, original_state):
+        from diffusers.models.attention_processor import Attention
+
+        for name, module in unet.named_modules():
+            if isinstance(module, Attention) and "attn2" in name and name in original_state:
+                module.get_attention_scores = original_state[name]["get_attention_scores"]
+                module.set_processor(original_state[name]["processor"])
+                if hasattr(module, "old_get_attention_scores"):
+                    delattr(module, "old_get_attention_scores")
+                if hasattr(module, "new_get_attention_scores"):
+                    delattr(module, "new_get_attention_scores")
         return unet
 
     def resize_attn_map_divide2(self, attn_map, mask, fuse_index):
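Restoring the processor is what makes this a speedup rather than just a cleanup: `AttnProcessor2_0` routes through PyTorch's fused `scaled_dot_product_attention`, which never materializes the attention probability matrix, so recording scores requires swapping in the slower, unfused `AttnProcessor`. A small illustration of the two paths with plain tensors (no diffusers involved):

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 64, 64)  # (batch, heads, tokens, head_dim)
k = torch.randn(1, 8, 64, 64)
v = torch.randn(1, 8, 64, 64)

# Fused path (what AttnProcessor2_0 uses): no tokens x tokens
# probability tensor is ever allocated.
out_fused = F.scaled_dot_product_attention(q, k, v)

# Unfused path (what AttnProcessor computes): the full attention
# matrix exists and can be recorded, at extra memory and compute cost.
probs = torch.softmax(q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5, dim=-1)
out_unfused = probs @ v

torch.testing.assert_close(out_fused, out_unfused, rtol=1e-4, atol=1e-4)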
@@ -1426,7 +1433,7 @@ class ObjectClearPipeline(
                 on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before
                 resizing to the original image size for inpainting. This is useful when the masked area is small while
                 the image is large and contain information irrelevant for inpainting, such as background.
-            strength (`float`, *optional*, defaults to 1.0):
+            strength (`float`, *optional*, defaults to 0.9999):
                 Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
                 between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
                 `strength`. The number of denoising steps depends on the amount of noise initially added. When
@@ -1871,6 +1878,12 @@ class ObjectClearPipeline(
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
+                # Inject cross-attention storage logic at the last timestep
+                if i == len(timesteps) - 1 and self.config.apply_attention_guided_fusion:
+                    self.unet, self.original_state = self.unet_store_cross_attention_scores(
+                        self.unet,
+                        self.cross_attention_scores
+                    )
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
 
@@ -1924,8 +1937,8 @@ class ObjectClearPipeline(
                     )
 
                     latents = (1 - init_mask) * init_latents_proper + init_mask * latents
-
-                    if i == len(timesteps) - 1:
+
+                    if i == len(timesteps) - 1 and self.config.apply_attention_guided_fusion:
                         attn_key, attn_map = next(iter(self.cross_attention_scores.items()))
                         attn_map = self.resize_attn_map_divide2(attn_map, mask, fuse_index)
                         init_latents_proper = image_latents
@@ -1934,7 +1947,13 @@ class ObjectClearPipeline(
                         else:
                             init_mask = attn_map
                         attn_map = init_mask
-                        self.clear_cross_attention_scores(self.cross_attention_scores)
+
+                        self.unet = self.unet_restore_attention_processor(
+                            self.unet,
+                            self.original_state
+                        )
+
+                        self.clear_cross_attention_scores(self.cross_attention_scores)
 
                 if num_channels_unet == 4:
                     init_latents_proper = image_latents
@@ -2057,4 +2076,4 @@ class ObjectClearPipeline(
         else:
             if not return_dict:
                 return (image,)
-        return ObjectClearPipelineOutput(images=image)
+        return ObjectClearPipelineOutput(images=image)
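For context, a hedged usage sketch of the pipeline after this change; the call signature below follows the SDXL inpainting pipelines this class extends and should be checked against the ObjectClear README before use:

# Hypothetical usage; argument names are assumptions, not confirmed API.
# from pipeline_objectclear import ObjectClearPipeline
# import PIL.Image
#
# pipe = ObjectClearPipeline.from_pretrained("jixin0101/ObjectClear").to("cuda")
# out = pipe(
#     image=PIL.Image.open("scene.png"),
#     mask_image=PIL.Image.open("object_mask.png"),
# )
# out.images[0].save("cleared.png")  # inpainted result
# out.attns  # attention maps exposed by ObjectClearPipelineOutput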