hlfby06 committed on
Commit
25a6467
·
verified ·
1 Parent(s): da07396

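fix model correction_bias: when use_correction_bias is set, pick the top-k experts from gate_logits + e_score_correction_bias, but gather the combine weights from the uncorrected gate_logits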

Browse files
Files changed (1) hide show
  1. modeling_ernie4_5_moe.py +8 -2
modeling_ernie4_5_moe.py CHANGED
@@ -483,8 +483,14 @@ class Ernie4_5_MoeMLP(nn.Module):
483
  S, H = x.shape
484
  E = gate_logits.shape[1]
485
  device = x.device
486
- topk_prob, topk_idx = torch.topk(gate_logits, k, dim=-1)
487
- combine_weights = topk_prob
 
 
 
 
 
 
488
  expert_id = topk_idx
489
  y = x.new_zeros((E, capacity, H))
490
  scatter_index = x.new_full((k, S), -1, dtype=torch.int32)
 
483
  S, H = x.shape
484
  E = gate_logits.shape[1]
485
  device = x.device
486
+
487
+ if self.use_correction_bias:
488
+ _, topk_idx = torch.topk(gate_logits + self.moe_statics.e_score_correction_bias, k, dim=-1)
489
+ topk_prob = torch.gather(gate_logits, dim=-1, index=topk_idx)
490
+ else:
491
+ topk_prob, topk_idx = torch.topk(gate_logits, k, dim=-1)
492
+
493
+ combine_weights = topk_prob
494
  expert_id = topk_idx
495
  y = x.new_zeros((E, capacity, H))
496
  scatter_index = x.new_full((k, S), -1, dtype=torch.int32)