lirannoc committed · verified
Commit 0948651 · 1 Parent(s): 868b7c0

Update modeling_super_linear.py

Files changed (1)
  1. modeling_super_linear.py +4 -4
modeling_super_linear.py CHANGED
@@ -561,6 +561,7 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
 
     def forward(self,
                 inputs_embeds: torch.Tensor = None,
+                prediction_len: int = None,
                 attention_mask: Optional[torch.Tensor] = None,
                 past_key_values: Optional[Tuple] = None,
                 use_cache: bool = True,
@@ -572,18 +573,17 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
             raise ValueError("Pass the time‑series as `inputs_embeds`")
 
         # backbone expects (B, C, L)
-        preds = self.backbone(inputs_embeds)
+        preds = self.backbone(inputs_embeds, inf_pred_len=prediction_len)
         return CausalLMOutputWithCrossAttentions(loss=None,logits=preds,past_key_values=None,hidden_states=None,attentions=None,)
 
 
-    def prepare_inputs_for_generation(self, inputs_embeds, past_key_values=None, **kwargs):
+    def prepare_inputs_for_generation(self, inputs_embeds, past_key_values=None, prediction_len=None, **kwargs):
         if past_key_values is not None:
             # only feed the last new step
             inputs_embeds = inputs_embeds[:, -1:, :]
-        return {"inputs_embeds": inputs_embeds, "past_key_values": past_key_values}
+        return {"inputs_embeds": inputs_embeds, "past_key_values": past_key_values, "prediction_len": prediction_len}
 
     def _reorder_cache(self, past, beam_idx, **kwargs):
         return past  # backbone keeps no KV cache
 
 
-
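
With this change, callers can request a forecast horizon directly through `forward`, which hands it to the backbone as `inf_pred_len`, and generation passes it along via `prepare_inputs_for_generation`. A minimal usage sketch follows; the checkpoint id, tensor shapes, and the assumption that the repo exposes `SuperLinearForCausalLM` through `AutoModelForCausalLM` with `trust_remote_code` are illustrative, not taken from this commit.

import torch
from transformers import AutoModelForCausalLM

# Assumed checkpoint id; substitute the actual SuperLinear repository.
model = AutoModelForCausalLM.from_pretrained("<super-linear-repo-id>", trust_remote_code=True)

# Per the in-code comment, the backbone expects (B, C, L):
# here 2 samples, 7 channels, 512 context steps (illustrative shapes).
context = torch.randn(2, 7, 512)

# New in this commit: the horizon is passed as `prediction_len` and
# forwarded to the backbone as `inf_pred_len`.
out = model(inputs_embeds=context, prediction_len=96)
forecast = out.logits  # the forecast is returned in the `logits` field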