File size: 11,498 Bytes
a3a2e41
 
 
 
 
 
6ffc50e
a3a2e41
3a03985
a3a2e41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2ed9cf
a3a2e41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49ba373
b2ed9cf
a3a2e41
 
 
 
 
 
b2ed9cf
a3a2e41
 
 
 
 
 
b2ed9cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a3a2e41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch

# Probe for FlashAttention-3 (module `flash_attn_interface`). ImportError is
# caught rather than only ModuleNotFoundError because an installed but
# incompatible binary fails with a plain ImportError; either way we fall back.
try:
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = True
    print(f'FLASH_ATTN_3_AVAILABLE:{FLASH_ATTN_3_AVAILABLE}')
except ImportError:
    # Assign the flag before reporting it so the message can reference it.
    FLASH_ATTN_3_AVAILABLE = False
    print(f'failed FLASH_ATTN_3_AVAILABLE:{FLASH_ATTN_3_AVAILABLE}')

# Probe for FlashAttention-2 (module `flash_attn`).
try:
    import flash_attn
    FLASH_ATTN_2_AVAILABLE = True
except ImportError:
    FLASH_ATTN_2_AVAILABLE = False

import warnings

# Public API of this module: the names exported by `from <module> import *`.
__all__ = [
    'flash_attention',
    'attention',
    'attention_with_weights',
]


def _cu_seqlens(lens, device):
    """Return int32 cumulative offsets ``[0, l0, l0 + l1, ...]`` of ``lens`` on ``device``.

    This is the ``cu_seqlens_*`` argument layout of the varlen flash-attention kernels.
    """
    return torch.cat([lens.new_zeros([1]), lens]).cumsum(
        0, dtype=torch.int32).to(device, non_blocking=True)


def _fa3_tokens_first(ret, total_q, nheads):
    """Return the FA3 attention output laid out as ``(total_q, nheads, headdim)``.

    Accepts either ``(out, softmax_lse)`` or a bare ``out`` (FA3 wheels differ),
    and transposes a heads-first ``(nheads, total_q, headdim)`` result.
    Any other rank or shape raises ``RuntimeError``.
    """
    out = ret[0] if isinstance(ret, (tuple, list)) else ret
    if out.dim() != 3:
        raise RuntimeError(
            f"Unexpected FA3 output rank {out.dim()} with shape {tuple(out.shape)}; "
            f"expected a 3D tensor."
        )
    if out.shape[0] == total_q:
        return out
    if out.shape[0] == nheads and out.shape[1] == total_q:
        return out.transpose(0, 1).contiguous()
    raise RuntimeError(
        f"Unexpected FA3 output shape {tuple(out.shape)}; "
        f"expected (total_q, nheads, headdim) or (nheads, total_q, headdim)"
    )


def flash_attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    version=None,
):
    """Variable-length attention via FlashAttention-3 or FlashAttention-2.

    q:              [B, Lq, Nq, C1].
    k:              [B, Lk, Nk, C1].
    v:              [B, Lk, Nk, C2]. Nq must be divisible by Nk.
    q_lens:         [B]. Valid query length per sample (None = all Lq).
    k_lens:         [B]. Valid key length per sample (None = all Lk).
    dropout_p:      float. Dropout probability (FA2 path only).
    softmax_scale:  float. The scaling of QK^T before applying softmax.
    q_scale:        Optional multiplier applied to the packed queries.
    causal:         bool. Whether to apply causal attention mask.
    window_size:    (left, right). If not (-1, -1), apply sliding window local
                    attention (FA2 path only).
    deterministic:  bool. If True, slightly slower and uses more memory.
    dtype:          torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
    version:        None or int. None prefers FA3 when installed; 2 selects FA2
                    when it is installed; 3 warns if FA3 is unavailable.

    Returns a [B, Lq, Nq, C2] tensor in q's original dtype.

    NOTE: ``dropout_p`` and ``window_size`` are not forwarded to the FA3 kernel.
    The output is unflattened to (B, Lq), so the packed query length must equal
    B * Lq. Shorter ``q_lens`` therefore fail at the reshape.
    """
    half_dtypes = (torch.float16, torch.bfloat16)
    assert dtype in half_dtypes
    assert q.device.type == 'cuda' and q.size(-1) <= 256

    # params
    b, lq, nheads, lk, out_dtype = q.size(0), q.size(1), q.size(2), k.size(1), q.dtype

    def half(x):
        # Only cast tensors that are not already fp16/bf16.
        return x if x.dtype in half_dtypes else x.to(dtype)

    # Pack queries into [sum(q_lens), Nq, C1], dropping per-sample padding.
    if q_lens is None:
        q = half(q.flatten(0, 1))
        q_lens = torch.tensor(
            [lq] * b, dtype=torch.int32).to(
                device=q.device, non_blocking=True)
    else:
        q = half(torch.cat([row[:n] for row, n in zip(q, q_lens)]))

    # Pack keys/values the same way using k_lens.
    if k_lens is None:
        k = half(k.flatten(0, 1))
        v = half(v.flatten(0, 1))
        k_lens = torch.tensor(
            [lk] * b, dtype=torch.int32).to(
                device=k.device, non_blocking=True)
    else:
        k = half(torch.cat([row[:n] for row, n in zip(k, k_lens)]))
        v = half(torch.cat([row[:n] for row, n in zip(v, k_lens)]))

    # The kernels need q/k/v in a single dtype.
    q = q.to(v.dtype)
    k = k.to(v.dtype)

    if q_scale is not None:
        q = q * q_scale

    if version == 3 and not FLASH_ATTN_3_AVAILABLE:
        warnings.warn(
            'Flash attention 3 is not available, use flash attention 2 instead.'
        )

    # Honour an explicit request for FA2 when it is installed; otherwise
    # prefer FA3 whenever it is available.
    use_fa3 = FLASH_ATTN_3_AVAILABLE and not (version == 2 and FLASH_ATTN_2_AVAILABLE)

    cu_seqlens_q = _cu_seqlens(q_lens, q.device)
    cu_seqlens_k = _cu_seqlens(k_lens, k.device)

    if use_fa3:
        ret = flash_attn_interface.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            seqused_q=None,
            seqused_k=None,
            max_seqlen_q=lq,
            max_seqlen_k=lk,
            softmax_scale=softmax_scale,
            causal=causal,
            deterministic=deterministic,
        )
        # Validate against the actual packed query length, not B * Lq.
        x = _fa3_tokens_first(ret, q.size(0), nheads).unflatten(0, (b, lq))
    else:
        assert FLASH_ATTN_2_AVAILABLE
        x = flash_attn.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=lq,
            max_seqlen_k=lk,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic).unflatten(0, (b, lq))

    # output
    return x.type(out_dtype)


def _frame_head_weighted_saliency(attn_weights, total_video_latent_frames, eps=1e-8):
    """Per-key saliency from batch 0 of ``attn_weights`` [B, H, Lq, Lk], split into frames.

    Queries and keys are cut into ``total_video_latent_frames`` equal chunks
    (``Lq // F`` queries and ``Lk // F`` keys per frame). Within each frame:
      * each head is weighted by inverse mean entropy of its attention over
        the frame's keys (sharper heads get more weight, weights sum to 1);
      * each key takes the max attention over the frame's queries, per head;
      * those per-head maxima are combined with the head weights.
    Keys beyond ``F * (Lk // F)`` are left at 0. Returns a [Lk] tensor.
    """
    _, _, q_len, k_len = attn_weights.shape
    keys_per_frame = k_len // total_video_latent_frames
    queries_per_frame = q_len // total_video_latent_frames

    saliency_per_key = torch.zeros((k_len,), device=attn_weights.device, dtype=attn_weights.dtype)

    for frame in range(total_video_latent_frames):
        k0, k1 = frame * keys_per_frame, (frame + 1) * keys_per_frame
        q0, q1 = frame * queries_per_frame, (frame + 1) * queries_per_frame

        chunk = attn_weights[0, :, q0:q1, k0:k1]                               # [H, La, Lv]

        # Renormalise within the key slice so each (head, query) row is a distribution.
        p = chunk / (chunk.sum(dim=-1, keepdim=True) + eps)                    # [H, La, Lv]
        entropy = -(p * (p + eps).log()).sum(dim=-1).mean(dim=1)               # [H]

        # Lower entropy -> larger (positive) head weight.
        inv_entropy = 1.0 / (entropy + 1e-6)                                   # [H]
        head_w = inv_entropy / (inv_entropy.sum() + eps)                       # [H]

        # Strongest query response per key, then weighted sum over heads.
        per_head = torch.amax(chunk, dim=1)                                    # [H, Lv]
        saliency_per_key[k0:k1] = (per_head * head_w[:, None]).sum(dim=0)      # [Lv]

    return saliency_per_key


def attention_with_weights(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    average_for_q=False,
    total_video_latent_frames=21,
):
    """Compute attention with explicit attention weights for visualization.

    q: [B, Lq, Nq, C]; k: [B, Lk, Nk, C]; v: [B, Lk, Nk, C2]. If Nq != Nk and
    Nq is divisible by Nk, each KV head is shared by Nq // Nk consecutive query
    heads (grouped-query attention).

    Keys at positions >= k_lens are masked. Query rows at positions >= q_lens,
    and any row whose keys are all masked, get all-zero attention weights
    (so a zero output) instead of NaN.

    Returns ``(out, weights)``, both cast to q's dtype:
      * out: [B, Lq, Nq, C2].
      * weights: if ``average_for_q``, the max over keys averaged over batch and
        heads, shape [Lq]. Otherwise a per-key saliency of shape [Lk], computed
        per video latent frame (see ``_frame_head_weighted_saliency``).

    Only batch size 1 is supported (asserted).
    """
    out_dtype = q.dtype
    b, lq, lk = q.size(0), q.size(1), k.size(1)

    # Sequence lengths default to "no padding"; keep them on their tensors' devices.
    if q_lens is None:
        q_lens = torch.tensor([lq] * b, dtype=torch.int32, device=q.device)
    else:
        q_lens = q_lens.to(q.device)
    if k_lens is None:
        k_lens = torch.tensor([lk] * b, dtype=torch.int32, device=k.device)
    else:
        k_lens = k_lens.to(k.device)

    if q_scale is not None:
        q = q * q_scale

    # Grouped-query attention: expand KV heads to match the query heads.
    n_q_heads, n_kv_heads = q.size(2), k.size(2)
    if n_q_heads != n_kv_heads and n_kv_heads > 0 and n_q_heads % n_kv_heads == 0:
        rep = n_q_heads // n_kv_heads
        k = k.repeat_interleave(rep, dim=2)
        v = v.repeat_interleave(rep, dim=2)

    scale = softmax_scale if softmax_scale is not None else q.size(-1) ** -0.5
    scores = torch.einsum('blhd,bshd->bhls', q, k) * scale                     # [B, Nq, Lq, Lk]

    # Combined "do not attend" mask, broadcast to [B, 1, Lq, Lk].
    k_pad = torch.arange(lk, device=k.device).unsqueeze(0) >= k_lens.unsqueeze(1)  # [B, Lk]
    q_pad = torch.arange(lq, device=q.device).unsqueeze(0) >= q_lens.unsqueeze(1)  # [B, Lq]
    masked = k_pad[:, None, None, :] | q_pad[:, None, :, None]
    if causal:
        # NOTE(review): top-left aligned (row i sees keys 0..i). flash-attn >= 2.1
        # aligns causal masks bottom-right when Lq != Lk; confirm which is intended.
        masked = masked | torch.triu(
            torch.ones(lq, lk, device=q.device, dtype=torch.bool), diagonal=1)[None, None]
    scores.masked_fill_(masked, float('-inf'))

    attn_weights = torch.softmax(scores, dim=-1)                               # [B, Nq, Lq, Lk]
    # A fully masked row (padded query, or no valid key) is all -inf, and softmax
    # yields NaN there. Zero it so NaN does not leak into `out` or the reductions.
    attn_weights = attn_weights.masked_fill(masked.all(dim=-1, keepdim=True), 0.0)
    assert attn_weights.shape[0] == 1, "Batch size > 1 not supported for attention visualization."

    # Reduce the weights before returning to keep the returned tensor small.
    if average_for_q:
        avg_attn_weights = attn_weights.amax(dim=3).mean(dim=(0, 1))           # [Lq]
    else:
        avg_attn_weights = _frame_head_weighted_saliency(
            attn_weights, total_video_latent_frames)                           # [Lk]

    out = torch.einsum('bhls,bshd->blhd', attn_weights, v)                     # [B, Lq, Nq, C2]
    return out.to(out_dtype), avg_attn_weights.to(out_dtype)


def attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    fa_version=None,
):
    """Attention that uses flash attention when installed, else PyTorch SDPA.

    Inputs follow ``flash_attention``: q [B, Lq, Nq, C1], k [B, Lk, Nk, C1],
    v [B, Lk, Nk, C2]. ``fa_version`` is forwarded as ``version``.

    On the ``scaled_dot_product_attention`` fallback:
      * ``q_lens`` / ``k_lens`` are ignored (a warning is emitted), so padded
        tokens take part in attention;
      * ``window_size`` and ``deterministic`` are not used;
      * ``q_scale`` and ``softmax_scale`` are applied, matching the flash path;
      * q/k/v are computed in ``dtype`` and the result is cast back to q's
        original dtype, as the flash path does.

    Returns a [B, Lq, Nq, C2] tensor.
    """
    if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
        return flash_attention(
            q=q,
            k=k,
            v=v,
            q_lens=q_lens,
            k_lens=k_lens,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            q_scale=q_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic,
            dtype=dtype,
            version=fa_version,
        )

    if q_lens is not None or k_lens is not None:
        warnings.warn(
            'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
        )

    out_dtype = q.dtype

    q = q.to(dtype)
    if q_scale is not None:
        q = q * q_scale

    # [B, L, N, C] -> [B, N, L, C], the layout SDPA expects.
    q = q.transpose(1, 2)
    k = k.transpose(1, 2).to(dtype)
    v = v.transpose(1, 2).to(dtype)

    if softmax_scale is not None:
        # SDPA's default scale is 1/sqrt(head_dim). Folding the requested scale
        # into q gives softmax(softmax_scale * q @ k^T) without depending on the
        # `scale` keyword, which older PyTorch releases lack.
        q = q * (softmax_scale * q.size(-1) ** 0.5)

    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=None, is_causal=causal, dropout_p=dropout_p)

    return out.transpose(1, 2).contiguous().to(out_dtype)