Spaces:
Runtime error
Runtime error
fix: sinkformer gradient
Browse files
src/dalle_mini/model/modeling.py
CHANGED
|
@@ -215,8 +215,25 @@ def dot_product_attention_weights(
|
|
| 215 |
# normalize the attention weights
|
| 216 |
attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
|
| 217 |
for i in range(sinkhorn_iters - 1):
|
|
|
|
| 218 |
axis = -2 if i % 2 == 0 else -1
|
| 219 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
|
| 221 |
# apply attention dropout
|
| 222 |
if not deterministic and dropout_rate > 0.0:
|
|
@@ -396,6 +413,7 @@ class FlaxBartAttention(FlaxBartAttention):
|
|
| 396 |
query_states,
|
| 397 |
key_states,
|
| 398 |
bias=attention_bias,
|
|
|
|
| 399 |
dropout_rng=dropout_rng,
|
| 400 |
dropout_rate=self.dropout,
|
| 401 |
broadcast_dropout=True,
|
|
|
|
| 215 |
# normalize the attention weights
|
| 216 |
attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
|
| 217 |
for i in range(sinkhorn_iters - 1):
|
| 218 |
+
# TODO: this is unstable, requires lse space
|
| 219 |
axis = -2 if i % 2 == 0 else -1
|
| 220 |
+
if mask is not None:
|
| 221 |
+
attn_weights = jnp.where(
|
| 222 |
+
mask > 0,
|
| 223 |
+
attn_weights
|
| 224 |
+
/ (
|
| 225 |
+
1e-5
|
| 226 |
+
+ jax.lax.stop_gradient(
|
| 227 |
+
jnp.sum(attn_weights, axis=axis, where=mask, keepdims=True)
|
| 228 |
+
)
|
| 229 |
+
),
|
| 230 |
+
0.0,
|
| 231 |
+
)
|
| 232 |
+
else:
|
| 233 |
+
attn_weights = attn_weights / (
|
| 234 |
+
1e-5
|
| 235 |
+
+ jax.lax.stop_gradient(jnp.sum(attn_weights, axis=axis, keepdims=True))
|
| 236 |
+
)
|
| 237 |
|
| 238 |
# apply attention dropout
|
| 239 |
if not deterministic and dropout_rate > 0.0:
|
|
|
|
| 413 |
query_states,
|
| 414 |
key_states,
|
| 415 |
bias=attention_bias,
|
| 416 |
+
mask=attention_mask,
|
| 417 |
dropout_rng=dropout_rng,
|
| 418 |
dropout_rate=self.dropout,
|
| 419 |
broadcast_dropout=True,
|