Spaces:
Sleeping
Sleeping
Update ip_adapter/attention_processor_faceid.py
Browse files
ip_adapter/attention_processor_faceid.py
CHANGED
|
@@ -392,7 +392,7 @@ class LoRAIPAttnProcessor2_0(nn.Module):
|
|
| 392 |
head_dim = inner_dim // attn.heads
|
| 393 |
|
| 394 |
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 395 |
-
query = query.to(dtype=
|
| 396 |
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 397 |
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 398 |
|
|
@@ -413,6 +413,8 @@ class LoRAIPAttnProcessor2_0(nn.Module):
|
|
| 413 |
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 414 |
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 415 |
|
|
|
|
|
|
|
| 416 |
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
| 417 |
# TODO: add support for attn.scale when we move to Torch 2.1
|
| 418 |
ip_hidden_states = F.scaled_dot_product_attention(
|
|
|
|
| 392 |
head_dim = inner_dim // attn.heads
|
| 393 |
|
| 394 |
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 395 |
+
query = query.to(dtype=ip_key.dtype) # ← これを追加
|
| 396 |
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 397 |
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 398 |
|
|
|
|
| 413 |
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 414 |
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 415 |
|
| 416 |
+
|
| 417 |
+
|
| 418 |
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
| 419 |
# TODO: add support for attn.scale when we move to Torch 2.1
|
| 420 |
ip_hidden_states = F.scaled_dot_product_attention(
|