Speed up by storing cross-attention scores only at the last timestep
pipeline_objectclear.py  CHANGED  (+56 -37)
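Note on the mechanism (editorial, not part of the commit): before this change the cross-attention hook was set up when the pipeline was constructed and captured scores for a slice of UNet layers on every denoising step; the diff below installs it only once the final timestep is reached, restricts it to a single target layer, and restores the original attention processors right after the fused mask is computed. Capturing the scores at all requires the eager `AttnProcessor`, because the default `AttnProcessor2_0` path goes through `torch.nn.functional.scaled_dot_product_attention` and never materializes the softmaxed attention matrix, so limiting both the layer and the step is where the speed-up comes from. The snippet below is a minimal, self-contained illustration of the same wrapping pattern on a standalone `diffusers` `Attention` block; the module sizes and the `make_hook` / `captured` / `"demo.attn2"` names are invented for the demo and do not appear in the pipeline.

import types
import torch
from diffusers.models.attention_processor import Attention, AttnProcessor

# A toy cross-attention block; sizes are arbitrary and only serve the demo.
attn = Attention(query_dim=64, cross_attention_dim=128, heads=4, dim_head=16)
attn.set_processor(AttnProcessor())  # eager path: routes through get_attention_scores

captured = {}  # plays the role of self.cross_attention_scores in the pipeline

def make_hook(name):
    def hooked(module, query, key, attention_mask=None):
        probs = module.old_get_attention_scores(query, key, attention_mask)
        captured[name] = probs  # stash the softmaxed attention map
        return probs
    return hooked

# Same monkey-patch as in the pipeline: keep the original method, wrap it, bind the wrapper.
attn.old_get_attention_scores = attn.get_attention_scores
attn.get_attention_scores = types.MethodType(make_hook("demo.attn2"), attn)

hidden_states = torch.randn(1, 32, 64)    # (batch, query tokens, query_dim)
encoder_states = torch.randn(1, 77, 128)  # (batch, text tokens, cross_attention_dim)
_ = attn(hidden_states, encoder_hidden_states=encoder_states)
print(captured["demo.attn2"].shape)       # torch.Size([4, 32, 77]) = (batch * heads, q, kv)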
@@ -14,6 +14,7 @@

 import inspect
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+import os

 import numpy as np
 import PIL.Image
@@ -58,8 +59,8 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusio
 from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
 from dataclasses import dataclass

-from models import CLIPImageEncoder, PostfuseModule
-from utils import attention_guided_fusion
+from ..models import CLIPImageEncoder, PostfuseModule
+from ..utils import attention_guided_fusion
 import gc
 import torch.nn.functional as F

@@ -334,6 +335,7 @@ def retrieve_timesteps(
 class ObjectClearPipelineOutput(StableDiffusionXLPipelineOutput):
     attns: Optional[List[PIL.Image.Image]] = None

+
 class ObjectClearPipeline(
     DiffusionPipeline,
     StableDiffusionMixin,
@@ -428,7 +430,7 @@ class ObjectClearPipeline(
         requires_aesthetics_score: bool = False,
         force_zeros_for_empty_prompt: bool = True,
         add_watermarker: Optional[bool] = None,
-        apply_attention_guided_fusion: bool =
+        apply_attention_guided_fusion: bool = True,
     ):
         super().__init__()

@@ -463,9 +465,7 @@ class ObjectClearPipeline(

         if self.config.apply_attention_guided_fusion:
             self.cross_attention_scores = {}
-            self.unet = self.unet_store_cross_attention_scores(
-                self.unet, self.cross_attention_scores
-            )
+            self.original_state = None


     @classmethod
@@ -486,14 +486,17 @@ class ObjectClearPipeline(
         )

         postfuse_module = PostfuseModule(embed_dim=2048, embed_dim_img=768)
+        sub_folder = "postfuse_module"
         filename = "model.safetensors"
-
-
-
-
-
-
-
+        if pretrained_model_name_or_path == "jixin0101/ObjectClear":
+            safetensor_path = hf_hub_download(
+                repo_id="jixin0101/ObjectClear",
+                filename=filename,
+                subfolder="postfuse_module",
+                cache_dir=cache_dir
+            )
+        else:
+            safetensor_path = os.path.join(pretrained_model_name_or_path, sub_folder, filename)
         state_dict_postfuse = load_file(safetensor_path)
         postfuse_module.load_state_dict(state_dict_postfuse)

@@ -537,7 +540,7 @@ class ObjectClearPipeline(

         return image_embeds, uncond_image_embeds

-    def unet_store_cross_attention_scores(self, unet, attention_scores):
+    def unet_store_cross_attention_scores(self, unet, attention_scores, applicable_layers=None):
         from diffusers.models.attention_processor import (
             Attention,
             AttnProcessor,
@@ -545,34 +548,25 @@ class ObjectClearPipeline(
         )
         import types

-        UNET_LAYER_NAMES = [
-            "down_blocks.0",
-            "down_blocks.1",
-            "down_blocks.2",
-            "mid_block",
-            "up_blocks.1",
-            "up_blocks.2",
-            "up_blocks.3",
-        ]
-
-        start_layer = 0
-        end_layer = 2
-        applicable_layers = UNET_LAYER_NAMES[start_layer:end_layer]
+        TARGET_LAYER = "down_blocks.1.attentions.0.transformer_blocks.0.attn2"
+        original_state = {}

         def make_new_get_attention_scores_fn(name):
             def new_get_attention_scores(module, query, key, attention_mask=None):
                 attention_probs = module.old_get_attention_scores(
                     query, key, attention_mask
                 )
-                attention_scores[name] = attention_probs
+                if name == TARGET_LAYER:
+                    attention_scores[name] = attention_probs
                 return attention_probs
-
             return new_get_attention_scores

         for name, module in unet.named_modules():
-            if isinstance(module, Attention) and "attn2" in name:
-                if not any(layer in name for layer in applicable_layers):
-                    continue
+            if isinstance(module, Attention) and name == TARGET_LAYER and "attn2" in name:
+                original_state[name] = {
+                    "processor": module.processor,
+                    "get_attention_scores": module.get_attention_scores
+                }
                 if isinstance(module.processor, AttnProcessor2_0):
                     module.set_processor(AttnProcessor())
                 module.old_get_attention_scores = module.get_attention_scores
@@ -581,6 +575,19 @@ class ObjectClearPipeline(
                 )
                 module.get_attention_scores = module.new_get_attention_scores

+        return unet, original_state
+
+    def unet_restore_attention_processor(self, unet, original_state):
+        from diffusers.models.attention_processor import Attention
+
+        for name, module in unet.named_modules():
+            if isinstance(module, Attention) and "attn2" in name and name in original_state:
+                module.get_attention_scores = original_state[name]["get_attention_scores"]
+                module.set_processor(original_state[name]["processor"])
+                if hasattr(module, "old_get_attention_scores"):
+                    delattr(module, "old_get_attention_scores")
+                if hasattr(module, "new_get_attention_scores"):
+                    delattr(module, "new_get_attention_scores")
         return unet

     def resize_attn_map_divide2(self, attn_map, mask, fuse_index):
@@ -1426,7 +1433,7 @@ class ObjectClearPipeline(
                 on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before
                 resizing to the original image size for inpainting. This is useful when the masked area is small while
                 the image is large and contain information irrelevant for inpainting, such as background.
-            strength (`float`, *optional*, defaults to
+            strength (`float`, *optional*, defaults to 0.9999):
                 Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
                 between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
                 `strength`. The number of denoising steps depends on the amount of noise initially added. When
@@ -1871,6 +1878,12 @@ class ObjectClearPipeline(
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
+                # Inject cross-attention storage logic at the last timestep
+                if i == len(timesteps) - 1 and self.config.apply_attention_guided_fusion:
+                    self.unet, self.original_state = self.unet_store_cross_attention_scores(
+                        self.unet,
+                        self.cross_attention_scores
+                    )
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

@@ -1924,8 +1937,8 @@ class ObjectClearPipeline(
                             )

                     latents = (1 - init_mask) * init_latents_proper + init_mask * latents
-
-                    if i == len(timesteps) - 1:
+
+                    if i == len(timesteps) - 1 and self.config.apply_attention_guided_fusion:
                         attn_key, attn_map = next(iter(self.cross_attention_scores.items()))
                         attn_map = self.resize_attn_map_divide2(attn_map, mask, fuse_index)
                         init_latents_proper = image_latents
@@ -1934,7 +1947,13 @@ class ObjectClearPipeline(
                         else:
                             init_mask = attn_map
                         attn_map = init_mask
-
+
+                        self.unet = self.unet_restore_attention_processor(
+                            self.unet,
+                            self.original_state
+                        )
+
+                        self.clear_cross_attention_scores(self.cross_attention_scores)

                 if num_channels_unet == 4:
                     init_latents_proper = image_latents
@@ -2057,4 +2076,4 @@ class ObjectClearPipeline(
         else:
             if not return_dict:
                 return (image,)
-            return ObjectClearPipelineOutput(images=image)
+            return ObjectClearPipelineOutput(images=image)
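One reading note on the `strength` docstring touched above: under the usual diffusers convention for strength-based inpainting (an assumption about the surrounding pipeline code, which this diff does not show), the number of denoising steps actually executed is derived from `strength` before the loop runs, which is why a default of 0.9999 keeps essentially the full schedule. The helper below is hypothetical and only mirrors that standard `get_timesteps()` arithmetic.

# Hypothetical helper mirroring the standard diffusers get_timesteps() arithmetic;
# not code from this commit.
def effective_steps(num_inference_steps: int, strength: float) -> int:
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    return num_inference_steps - t_start  # steps actually executed by the denoising loop

print(effective_steps(50, 0.9999))  # 49 -> nearly the full 50-step schedule
print(effective_steps(50, 0.5))     # 25 -> only the second half of the schedule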