Update modeling_super_linear.py

modeling_super_linear.py  CHANGED  (+4 -4)
@@ -561,6 +561,7 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
 
     def forward(self,
                 inputs_embeds: torch.Tensor = None,
+                prediction_len: int = None,
                 attention_mask: Optional[torch.Tensor] = None,
                 past_key_values: Optional[Tuple] = None,
                 use_cache: bool = True,
@@ -572,18 +573,17 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
             raise ValueError("Pass the time-series as `inputs_embeds`")
 
         # backbone expects (B, C, L)
-        preds = self.backbone(inputs_embeds)
+        preds = self.backbone(inputs_embeds, inf_pred_len=prediction_len)
         return CausalLMOutputWithCrossAttentions(loss=None, logits=preds, past_key_values=None, hidden_states=None, attentions=None)
 
 
-    def prepare_inputs_for_generation(self, inputs_embeds, past_key_values=None, **kwargs):
+    def prepare_inputs_for_generation(self, inputs_embeds, past_key_values=None, prediction_len=None, **kwargs):
        if past_key_values is not None:
            # only feed the last new step
            inputs_embeds = inputs_embeds[:, -1:, :]
-        return {"inputs_embeds": inputs_embeds, "past_key_values": past_key_values}
+        return {"inputs_embeds": inputs_embeds, "past_key_values": past_key_values, "prediction_len": prediction_len}
 
     def _reorder_cache(self, past, beam_idx, **kwargs):
         return past  # backbone keeps no KV cache
 
 
-
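The net effect is that a per-call forecast horizon now flows from the generation entry point down to the backbone: prepare_inputs_for_generation forwards a prediction_len kwarg, and forward() hands it to the backbone as inf_pred_len. A minimal usage sketch follows; the checkpoint id, channel count, context length, and horizon are illustrative placeholders, and loading with trust_remote_code=True is assumed since the model class ships inside the repo:

# Minimal usage sketch. Checkpoint id and tensor sizes below are
# hypothetical placeholders, not values taken from this repo.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "some-org/super-linear-checkpoint",  # placeholder checkpoint id
    trust_remote_code=True,              # model class lives in the repo
)
model.eval()

# The backbone expects (B, C, L): here 1 series, 7 channels, 512 past steps.
context = torch.randn(1, 7, 512)

# prediction_len reaches the backbone as inf_pred_len inside forward();
# the same kwarg is threaded through prepare_inputs_for_generation, so it
# would also survive a model.generate(...) call.
with torch.no_grad():
    out = model(inputs_embeds=context, prediction_len=96)

forecast = out.logits  # predictions for the requested 96-step horizon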