Spaces:
Runtime error
Runtime error
| import tempfile | |
| import unittest | |
| import numpy as np | |
| import pytest | |
| import torch | |
| from diffusers import DiffusionPipeline | |
| from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor | |
| from diffusers.utils.testing_utils import torch_device | |
class AttnAddedKVProcessorTests(unittest.TestCase):
    """Tests for an `Attention` layer driven by `AttnAddedKVProcessor`.

    The single test compares the layer built with and without
    `only_cross_attention` and checks that the two configurations differ both
    in which projections exist and in the output they produce.
    """

    def get_constructor_arguments(self, only_cross_attention: bool = False):
        """Build the keyword arguments used to construct the `Attention` layer.

        Args:
            only_cross_attention: Forwarded unchanged to `Attention`. It also
                selects the cross attention dim: 12 when set, otherwise the
                same value as the query dim.

        Returns:
            A dict of constructor keyword arguments, including a freshly
            created `AttnAddedKVProcessor` under the "processor" key.
        """
        query_dim = 10

        if only_cross_attention:
            cross_attention_dim = 12
        else:
            # when only cross attention is not set, the cross attention dim must be the same as the query dim
            cross_attention_dim = query_dim

        return {
            "query_dim": query_dim,
            "cross_attention_dim": cross_attention_dim,
            "heads": 2,
            "dim_head": 4,
            "added_kv_proj_dim": 6,
            "norm_num_groups": 1,
            "only_cross_attention": only_cross_attention,
            "processor": AttnAddedKVProcessor(),
        }

    def get_forward_arguments(self, query_dim, added_kv_proj_dim):
        """Build random inputs for a forward pass of the layer.

        Args:
            query_dim: Size of dimension 1 of the generated `hidden_states`.
            added_kv_proj_dim: Size of the last dimension of the generated
                `encoder_hidden_states`.

        Returns:
            A dict with `hidden_states` of shape (2, query_dim, 3, 2),
            `encoder_hidden_states` of shape (2, 4, added_kv_proj_dim) and
            `attention_mask` set to None.
        """
        batch_size = 2

        hidden_states = torch.rand(batch_size, query_dim, 3, 2)
        encoder_hidden_states = torch.rand(batch_size, 4, added_kv_proj_dim)
        attention_mask = None

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "attention_mask": attention_mask,
        }

    def test_only_cross_attention(self):
        """`only_cross_attention=True` drops `to_k`/`to_v` and changes the output."""
        # self and cross attention
        # Reseed before each construction so both configurations start from
        # the same RNG state.
        torch.manual_seed(0)

        constructor_args = self.get_constructor_arguments(only_cross_attention=False)
        attn = Attention(**constructor_args)

        self.assertIsNotNone(attn.to_k)
        self.assertIsNotNone(attn.to_v)

        forward_args = self.get_forward_arguments(
            query_dim=constructor_args["query_dim"], added_kv_proj_dim=constructor_args["added_kv_proj_dim"]
        )

        self_and_cross_attn_out = attn(**forward_args)

        # only cross attention
        torch.manual_seed(0)

        constructor_args = self.get_constructor_arguments(only_cross_attention=True)
        attn = Attention(**constructor_args)

        self.assertIsNone(attn.to_k)
        self.assertIsNone(attn.to_v)

        forward_args = self.get_forward_arguments(
            query_dim=constructor_args["query_dim"], added_kv_proj_dim=constructor_args["added_kv_proj_dim"]
        )

        only_cross_attn_out = attn(**forward_args)

        # Every element of the two outputs is expected to differ.
        self.assertTrue((only_cross_attn_out != self_and_cross_attn_out).all())
class DeprecatedAttentionBlockTests(unittest.TestCase):
    """Checks that a tiny Stable Diffusion pipeline produces matching images
    when loaded normally, when loaded with `device_map="balanced"`, and after
    a save/reload round trip of the latter.
    """

    # Checkpoint used for both the plain and the device-mapped load.
    _REPO_ID = "hf-internal-testing/tiny-stable-diffusion-torch"

    # NOTE(review): this helper takes only `pytestconfig`, so it cannot work
    # as a regular instance method. It is exposed as a static method here; it
    # may originally have been a pytest fixture -- confirm against the
    # project's history before relying on it.
    @staticmethod
    def is_dist_enabled(pytestconfig):
        """Return True when pytest's `dist` option equals "loadfile"."""
        return pytestconfig.getoption("dist") == "loadfile"

    @staticmethod
    def _generate_images(pipe):
        """Run `pipe` on a fixed prompt with a CPU generator seeded to 0 and return its images."""
        return pipe(
            "foo",
            num_inference_steps=2,
            generator=torch.Generator("cpu").manual_seed(0),
            output_type="np",
        ).images

    def test_conversion_when_using_device_map(self):
        """Outputs agree across a plain load, a device-mapped load and a reload from disk."""
        pipe = DiffusionPipeline.from_pretrained(self._REPO_ID, safety_checker=None)
        pre_conversion = self._generate_images(pipe)

        # the initial conversion succeeds
        pipe = DiffusionPipeline.from_pretrained(self._REPO_ID, device_map="balanced", safety_checker=None)
        conversion = self._generate_images(pipe)

        with tempfile.TemporaryDirectory() as tmpdir:
            # save the converted model
            pipe.save_pretrained(tmpdir)

            # can also load the converted weights
            pipe = DiffusionPipeline.from_pretrained(tmpdir, device_map="balanced", safety_checker=None)

            # Run inference before the temporary directory is removed.
            after_conversion = self._generate_images(pipe)

        self.assertTrue(
            np.allclose(pre_conversion, conversion, atol=1e-3),
            "images differ between the plain load and the device_map load",
        )
        self.assertTrue(
            np.allclose(conversion, after_conversion, atol=1e-3),
            "images differ between the device_map load and the reloaded pipeline",
        )