revi13 committed
Commit bde45d4 · verified · Parent: 960dad1

Update ip_adapter/attention_processor_faceid.py

ip_adapter/attention_processor_faceid.py CHANGED
@@ -392,7 +392,7 @@ class LoRAIPAttnProcessor2_0(nn.Module):
         head_dim = inner_dim // attn.heads
 
         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        query = query.to(dtype=key.dtype)  # ← add this
+        query = query.to(dtype=ip_key.dtype)  # ← add this
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
@@ -413,6 +413,8 @@ class LoRAIPAttnProcessor2_0(nn.Module):
         ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
+
+
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
         ip_hidden_states = F.scaled_dot_product_attention(
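For context, a minimal standalone sketch (not part of the commit) of the failure this cast works around: F.scaled_dot_product_attention requires query, key, and value to share a dtype, so when the IP-Adapter key/value projections run at a different precision than the query, an explicit cast is needed before the second attention call. The tensor shapes below are illustrative assumptions, not values from the file.

# Illustrative sketch, assuming SD-like shapes; not from the commit itself.
import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 1, 8, 77, 64  # assumed shapes
num_ip_tokens = 4  # assumed number of IP-Adapter tokens

# fp16 is typical for these pipelines on GPU; on CPU, swap in
# torch.bfloat16 if fp16 SDPA is unsupported by your PyTorch build.
query = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float32)
ip_key = torch.randn(batch, heads, num_ip_tokens, head_dim, dtype=torch.float16)
ip_value = torch.randn(batch, heads, num_ip_tokens, head_dim, dtype=torch.float16)

# Without this cast, SDPA raises a RuntimeError about mismatched dtypes.
# This is the cast the commit switches from key.dtype to ip_key.dtype.
query = query.to(dtype=ip_key.dtype)

ip_hidden_states = F.scaled_dot_product_attention(query, ip_key, ip_value)
print(ip_hidden_states.shape, ip_hidden_states.dtype)
# torch.Size([1, 8, 77, 64]) torch.float16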