Update modeling_atomformer.py
Browse files- modeling_atomformer.py +2 -2
modeling_atomformer.py
CHANGED
|
@@ -2516,7 +2516,7 @@ class AtomformerEncoder(nn.Module):
|
|
| 2516 |
for blk in self.blocks:
|
| 2517 |
input_embeds, pos_embeds = blk(input_embeds, pos_embeds, attention_mask)
|
| 2518 |
|
| 2519 |
-
return input_embeds
|
| 2520 |
|
| 2521 |
|
| 2522 |
class AtomformerPreTrainedModel(PreTrainedModel): # type: ignore
|
|
@@ -2550,7 +2550,7 @@ class AtomformerModel(AtomformerPreTrainedModel):
|
|
| 2550 |
) -> torch.Tensor:
|
| 2551 |
"""Forward function call for the transformer model."""
|
| 2552 |
output: torch.Tensor = self.encoder(input_ids, coords, attention_mask)
|
| 2553 |
-
return output
|
| 2554 |
|
| 2555 |
|
| 2556 |
class AtomformerForMaskedAM(AtomformerPreTrainedModel):
|
|
|
|
| 2516 |
for blk in self.blocks:
|
| 2517 |
input_embeds, pos_embeds = blk(input_embeds, pos_embeds, attention_mask)
|
| 2518 |
|
| 2519 |
+
return input_embeds, pos_embeds
|
| 2520 |
|
| 2521 |
|
| 2522 |
class AtomformerPreTrainedModel(PreTrainedModel): # type: ignore
|
|
|
|
| 2550 |
) -> torch.Tensor:
|
| 2551 |
"""Forward function call for the transformer model."""
|
| 2552 |
output: torch.Tensor = self.encoder(input_ids, coords, attention_mask)
|
| 2553 |
+
return output[0][:, :-1]
|
| 2554 |
|
| 2555 |
|
| 2556 |
class AtomformerForMaskedAM(AtomformerPreTrainedModel):
|