fix model correction_bias
Browse files- modeling_ernie4_5_moe.py +8 -2
modeling_ernie4_5_moe.py
CHANGED
|
@@ -483,8 +483,14 @@ class Ernie4_5_MoeMLP(nn.Module):
|
|
| 483 |
S, H = x.shape
|
| 484 |
E = gate_logits.shape[1]
|
| 485 |
device = x.device
|
| 486 |
-
|
| 487 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 488 |
expert_id = topk_idx
|
| 489 |
y = x.new_zeros((E, capacity, H))
|
| 490 |
scatter_index = x.new_full((k, S), -1, dtype=torch.int32)
|
|
|
|
| 483 |
S, H = x.shape
|
| 484 |
E = gate_logits.shape[1]
|
| 485 |
device = x.device
|
| 486 |
+
|
| 487 |
+
if self.use_correction_bias:
|
| 488 |
+
_, topk_idx = torch.topk(gate_logits + self.moe_statics.e_score_correction_bias, k, dim=-1)
|
| 489 |
+
topk_prob = torch.gather(gate_logits, dim=-1, index=topk_idx)
|
| 490 |
+
else:
|
| 491 |
+
topk_prob, topk_idx = torch.topk(gate_logits, k, dim=-1)
|
| 492 |
+
|
| 493 |
+
combine_weights = topk_prob
|
| 494 |
expert_id = topk_idx
|
| 495 |
y = x.new_zeros((E, capacity, H))
|
| 496 |
scatter_index = x.new_full((k, S), -1, dtype=torch.int32)
|