Update chatNT.py #3
by Yanisadel · opened

chatNT.py CHANGED
@@ -28,7 +28,6 @@ class RotaryEmbeddingConfig:
 class PerceiverResamplerConfig:
     """
     Parameters to initialize an PerceiverResampler model.
-
     Args:
         emb_layer_norm_before: Whether to use layer norm before the first attention
             layer.
@@ -93,9 +92,7 @@ class PerceiverResamplerConfig:
 class GptConfig:
     """
     Parameters to initialize a Gpt model.
-
     NOTE: the pad token is not defined
-
     Args:
         vocab_size: Token vocabulary.
         eos_token_id: used to stop sentence generation
@@ -191,7 +188,6 @@ class GptConfig:
 class NucleotideTransformerConfig:
     """
     Parameters to initialize an NT model.
-
     Args:
         alphabet_size: Token vocabulary.
         pad_token_id: ID of pad token.
@@ -364,21 +360,20 @@ class ChatNTConfig(PretrainedConfig):
         return output


-class ChatNTDecoder(nn.Module):
+class TorchBioBrainDecoder(nn.Module):
     def __init__(
         self,
         gpt_config: GptConfig,
         seq_token_id: int,
     ):
         """
-        Initializes the [...]
+        Initializes the BioBrain decoder, using a GPT model for text generation with
         bio embeddings.
-
         Args:
             gpt_config: Configuration for the GPT model
             seq_token_id: Index of the SEQ token
         """
-        super(ChatNTDecoder, self).__init__()
+        super(TorchBioBrainDecoder, self).__init__()
         self.gpt_config = gpt_config
         self.seq_token_id = seq_token_id

@@ -390,13 +385,11 @@ class ChatNTDecoder(nn.Module):
     ) -> torch.Tensor:
         """
         Forward pass through the model.
-
         Args:
             english_token_ids: Tensor of English token IDs with shape
                 (batch_size, num_english_tokens).
             projected_bio_embeddings: Optional tensor of bio embeddings with shape
                 (batch_size, num_bio_sequences, ?, embed_dim).
-
         Returns:
             torch.Tensor: The logits from the GPT model,
                 shaped (batch_size, num_english_tokens, vocab_size).
@@ -452,13 +445,11 @@ class ChatNTDecoder(nn.Module):
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Inserts resampled embeddings in input_embeddings, starting at the SEQ token
-
         Args:
             tokens (torch.Tensor): Shape (batch_size, num_tokens)
             input_embeddings (torch.Tensor): Shape (batch_size, num_tokens, embed_dim)
             resampled_embeddings (torch.Tensor):
                 Shape (batch_size, num_bio_sequences, bio_sequence_length, embed_dim)
-
         Returns:
             Tuple[torch.Tensor, torch.Tensor]:
                 - input_embeddings with resampled_embeddings inserted at the SEQ token
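For intuition about this insertion step, here is a minimal standalone sketch, assuming exactly one SEQ token per row and a single bio sequence; it is a hypothetical helper for reading purposes, not the code from this PR:

# Hypothetical sketch: splice resampled embeddings in at the SEQ token position.
import torch

def insert_at_seq_token(
    tokens: torch.Tensor,                # (batch_size, num_tokens)
    input_embeddings: torch.Tensor,      # (batch_size, num_tokens, embed_dim)
    resampled_embeddings: torch.Tensor,  # (batch_size, resampled_length, embed_dim)
    seq_token_id: int,
) -> torch.Tensor:
    rows = []
    for b in range(tokens.shape[0]):
        pos = int((tokens[b] == seq_token_id).nonzero()[0])  # first SEQ position
        rows.append(
            torch.cat(
                (
                    input_embeddings[b, :pos],
                    resampled_embeddings[b],
                    input_embeddings[b, pos + 1 :],
                ),
                dim=0,
            )
        )
    # (batch_size, num_tokens - 1 + resampled_length, embed_dim)
    return torch.stack(rows, dim=0)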
@@ -521,11 +512,9 @@ class ChatNTDecoder(nn.Module):
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Removes the logits corresponding to the unused embeddings.
-
         Args:
             tokens: Input english tokens.
             logits: Input logits.
-
         Returns:
             Cleaned logits, last values will be equal to 0.
         """
@@ -582,7 +571,7 @@ class ChatNTDecoder(nn.Module):
         return logits_acc, tokens_acc


-class ChatNT(PreTrainedModel):
+class TorchMultiOmicsModel(PreTrainedModel):
     config_class = ChatNTConfig

     def __init__(self, config: ChatNTConfig) -> None:
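Because the renamed class still subclasses PreTrainedModel and keeps config_class = ChatNTConfig, loading through transformers with remote code should be unaffected. A hedged usage sketch; the repo id is an assumption, so check the model card for the exact id and input preparation:

# Usage sketch: load the custom architecture defined in chatNT.py.
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "InstaDeepAI/ChatNT",    # assumed repo id
    trust_remote_code=True,  # needed so transformers imports chatNT.py from the repo
)
print(type(model).__name__)  # expected to be the class registered for ChatNTConfig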
@@ -625,11 +614,11 @@ class ChatNT(PreTrainedModel):
         # Correct seq_token_id
         self.seq_token_id -= 1

-        self.[...]
-        self.[...]
+        self.biobrain_encoder = TorchBioBrainEncoder(nt_config=self.nt_config)
+        self.biobrain_decoder = TorchBioBrainDecoder(
             gpt_config=self.gpt_config, seq_token_id=self.seq_token_id
         )
-        self.projection_model = MultiModalPerceiverResamplerProjection(
+        self.projection_model = TorchMultiModalPerceiverResamplerProjection(
             perceiver_resampler_config=self.perceiver_resampler_config,
             input_embed_dim=self.nt_config.embed_dim,
             embed_dim=self.gpt_config.embed_dim,
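These three submodules define the model's overall dataflow: encode bio tokens, project and resample them into the GPT embedding space, then decode them together with the English prompt. A schematic sketch of that composition with simplified duck-typed arguments; a reading aid, not the PR code:

# Schematic of the encoder -> projection -> decoder composition (hypothetical).
import torch

def chatnt_dataflow_sketch(encoder, projection, decoder,
                           english_tokens: torch.Tensor,
                           bio_tokens: torch.Tensor) -> torch.Tensor:
    num_bio_sequences = bio_tokens.shape[1]
    bio_embeddings = [
        encoder(bio_token_ids=bio_tokens[:, i]) for i in range(num_bio_sequences)
    ]
    projected = torch.stack(
        [
            projection(
                bio_token_ids=bio_tokens[:, i],
                bio_embeddings=bio_embeddings[i],
                english_token_ids=english_tokens,
            )
            for i in range(num_bio_sequences)
        ],
        dim=1,
    )
    return decoder(english_token_ids=english_tokens,
                   projected_bio_embeddings=projected)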
@@ -645,27 +634,21 @@ class ChatNT(PreTrainedModel):
         projected_bio_embeddings: torch.Tensor = None,
     ) -> dict[str, torch.Tensor]:
         """
-
         Args:
             multi_omics_tokens_ids (Tuple[torch.Tensor, torch.Tensor]):
                 english_tokens_ids: Represents the prompt tokens (english tokens)
                     Shape (batch_size, num_english_tokens)
-
                 bio_tokens_ids: Represents the bio sequences tokens
                     Shape (batch_size, num_bio_sequences, num_bio_tokens)
-
             projection_english_tokens_ids (torch.Tensor):
                 Shape (batch_size, num_english_tokens)
-
             projected_bio_embeddings (projected_bio_embeddings, optional):
                 Shape (batch_size, num_bio_sequencse, ?, embed_dim).
                 Defaults to None.
-
         Returns:
             dict[str, torch.Tensor] containing:
                 - logits:
                     Shape (batch_size, num_tokens, vocab_size)
-
                 - projected_bio_embeddings:
                     Shape (batch_size, num_bio_sequences, ?, embed_dim)
         """
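For reference, the documented shapes can be exercised with dummy tensors; all sizes below are arbitrary placeholders and the commented-out call is only indicative:

# Dummy inputs matching the documented shapes (arbitrary sizes, illustration only).
import torch

batch_size, num_english_tokens = 2, 16
num_bio_sequences, num_bio_tokens = 1, 128

english_tokens_ids = torch.randint(0, 100, (batch_size, num_english_tokens))
bio_tokens_ids = torch.randint(0, 10, (batch_size, num_bio_sequences, num_bio_tokens))
multi_omics_tokens_ids = (english_tokens_ids, bio_tokens_ids)
projection_english_tokens_ids = english_tokens_ids.clone()

# outs = model(multi_omics_tokens_ids, projection_english_tokens_ids)
# outs["logits"]: (batch_size, num_tokens, vocab_size)
# outs["projected_bio_embeddings"]: (batch_size, num_bio_sequences, ?, embed_dim)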
@@ -702,7 +685,7 @@ class ChatNT(PreTrainedModel):
         if projected_bio_embeddings is None:
             # Compute bio sequences embeddings
             bio_embeddings_list = [
-                self.[...]
+                self.biobrain_encoder(bio_token_ids=bio_token_ids[:, bio_seq_num])
                 for bio_seq_num in range(num_bio_sequences)
             ]

@@ -718,7 +701,7 @@ class ChatNT(PreTrainedModel):
         projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)

         # decode
-        logits = self.[...]
+        logits = self.biobrain_decoder(
             english_token_ids=english_token_ids,
             projected_bio_embeddings=projected_bio_embeddings,
         )
@@ -741,7 +724,6 @@ class TorchRotaryEmbedding(torch.nn.Module):
     def _create_sinusoidal_positions(self, device: torch.device) -> torch.Tensor:
         """
         Create the sines and cosines for the RoPE.
-
         Returns:
             Sinusoidal positions of shape (self.max_seq_len, self.dim).
         """
@@ -774,11 +756,9 @@ class TorchRotaryEmbedding(torch.nn.Module):
     def _rotate_every_two(self, x: torch.Tensor) -> torch.Tensor:
         """
         Prepare a tensor to apply the RoPE mechanism.
-
         Args:
             x: Tensor of shape (batch_size, seq_len, num_heads, head_dim),
                 typically this is the key or query tensor.
-
         Returns:
             The even indices in the last dimension have their sign flipped.
             Tensor of shape (batch_size, seq_len, num_heads, head_dim).
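For readers unfamiliar with RoPE, this is the usual rotate-every-two operation that such a docstring describes; a generic illustrative helper, not necessarily byte-for-byte what this file implements:

# Generic rotate-every-two helper used by RoPE implementations (illustrative).
import torch

def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
    # x: (batch_size, seq_len, num_heads, head_dim) with an even head_dim
    x1 = x[..., ::2]   # even positions of the last dimension
    x2 = x[..., 1::2]  # odd positions of the last dimension
    rotated = torch.stack((-x2, x1), dim=-1)  # each pair (a, b) becomes (-b, a)
    return rotated.flatten(-2)                # back to (..., head_dim)

x = torch.randn(1, 4, 2, 8)
assert rotate_every_two(x).shape == x.shape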
@@ -795,12 +775,10 @@ class TorchRotaryEmbedding(torch.nn.Module):
     ) -> torch.Tensor:
         """
         Applies rotary embeddings to x.
-
         Args:
             x: Tensor of shape (batch_size, seq_len, num_heads, head_dim),
                 typically this is the key or query tensor.
             sincos: Tuple of sine and cosine tensors for position encoding.
-
         Returns:
             RoPE embeddings tensor.
         """
@@ -818,12 +796,10 @@ class TorchRotaryEmbedding(torch.nn.Module):
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Applies rotary embeddings to k and q.
-
         Args:
             k: key tensor of shape (batch_size, seq_len, num_heads, head_dim),
             q: value tensor of shape (batch_size, seq_len, num_heads, head_dim),
             positions: optional positions offset useful when caching,
-
         Returns:
             RoPE embeddings for the keys and values.
         """
@@ -1141,11 +1117,9 @@ def build_causal_attention_mask(
     """
     Builds a batch of causal masks of shape (batch_size, 1, seq_len, seq_len) to feed
     to an attention layer.
-
     Args:
         batch_size: Batch size.
         seq_len: Length of the sequences.
-
     Returns:
         Batch of causal masks.
     """
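A minimal equivalent of such a causal-mask builder, assuming 1.0 marks positions that may be attended to; illustrative only, not the PR's exact code:

# Illustrative causal mask builder: (batch_size, 1, seq_len, seq_len), 1.0 = attend.
import torch

def causal_mask(batch_size: int, seq_len: int) -> torch.Tensor:
    mask = torch.tril(torch.ones((seq_len, seq_len)))  # lower-triangular matrix
    return mask[None, None, :, :].repeat(batch_size, 1, 1, 1)

print(causal_mask(2, 4).shape)  # torch.Size([2, 1, 4, 4])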
@@ -1498,12 +1472,12 @@ class RobertaLMHead(nn.Module):
         return {"embeddings": embeddings, "logits": logits}


-class NucleotideTransformer(nn.Module):
+class TorchNucleotideTransformer(nn.Module):
     def __init__(
         self,
         nt_config: NucleotideTransformerConfig,
     ):
-        super(NucleotideTransformer, self).__init__()
+        super(TorchNucleotideTransformer, self).__init__()
         self.nt_config = nt_config

         # Other cases are not implemented
@@ -1551,13 +1525,11 @@ class NucleotideTransformer(nn.Module):
     ) -> torch.Tensor:
         """
         Computes the embeddings based on the input tokens.
-
         Args:
             tokens: Input tokens out of the tokenizer of shape (batch_size, seq_len).
             attention_mask: Attention mask of shape (batch_size, 1, seq_len, seq_len).
                 If no mask is provided, a mask by default which equals 1 over all non
                 pad tokens and 0 over pad tokens is computed.
-
         Returns:
             Dictionary containing the final embeddings and logits.
         """
@@ -1585,11 +1557,9 @@ def build_padding_attention_mask(
 ) -> torch.Tensor:
     """
     Builds a padding mask from a sequence of tokens by masking <pad> in the attention.
-
     Args:
         tokens: Batch of sequences of shape (batch_size, seq_len).
         pad_token_id: Int corresponding to the <pad> token to mask.
-
     Returns:
         Batch of attention masks, masking out <pad> tokens.
     """
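Likewise, a padding mask of the documented shape can be sketched as follows; the broadcasting mirrors build_perceiver_padding_attention_mask at the end of this file, although the real build_padding_attention_mask may differ in detail:

# Illustrative padding-mask builder: (batch_size, 1, seq_len, seq_len), True = attend.
import torch

def padding_attention_mask(tokens: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    keep = tokens != pad_token_id                 # (batch_size, seq_len)
    mask = keep[:, None, None, :]                 # (batch_size, 1, 1, seq_len)
    return mask.repeat(1, 1, tokens.shape[1], 1)  # (batch_size, 1, seq_len, seq_len)

tokens = torch.tensor([[5, 7, 0, 0]])  # 0 stands in for the <pad> id here
print(padding_attention_mask(tokens, pad_token_id=0).int())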
@@ -1599,14 +1569,14 @@ def build_padding_attention_mask(
     return padding_mask


-class ChatNTEncoder(nn.Module):
+class TorchBioBrainEncoder(nn.Module):
     def __init__(
         self,
         nt_config: NucleotideTransformerConfig,
     ):
-        super(ChatNTEncoder, self).__init__()
+        super(TorchBioBrainEncoder, self).__init__()
         self.nt_config = nt_config
-        self.nt_model = NucleotideTransformer(self.nt_config)
+        self.nt_model = TorchNucleotideTransformer(self.nt_config)

     def forward(
         self,
@@ -1616,7 +1586,6 @@ class ChatNTEncoder(nn.Module):
         Args:
             bio_token_ids (torch.Tensor):
                 Shape (batch_size, num_bio_tokens)
-
         Returns:
             torch.Tensor:
                 Shape (batch_size, num_bio_tokens, embed_dim)
@@ -1626,7 +1595,7 @@ class ChatNTEncoder(nn.Module):
         return bio_embeddings


-class MultiModalPerceiverResamplerBlock(nn.Module):
+class TorchMultiModalPerceiverResamplerBlock(nn.Module):
     def __init__(
         self,
         num_heads: int,
@@ -1714,7 +1683,7 @@ class MultiModalPerceiverResamplerBlock(nn.Module):
         return {"embeddings": x}


-class MultiModalPerceiverResampler(nn.Module):
+class TorchMultiModalPerceiverResampler(nn.Module):
     """
     Perceiver Resampler model, made of successive PerceiverResamplerBlocks.
     """
@@ -1726,7 +1695,6 @@ class MultiModalPerceiverResampler(nn.Module):
     ):
         """
         Initialize a Perceiver Resampler model.
-
         Args:
             config: Dataclass containing model hyperparameters.
             name: Name for module (custom will break weight loading).
@@ -1736,7 +1704,7 @@ class MultiModalPerceiverResampler(nn.Module):
         self.name = name
         self.layers = nn.ModuleList(
             [
-                MultiModalPerceiverResamplerBlock(
+                TorchMultiModalPerceiverResamplerBlock(
                     num_heads=self.config.attention_heads,
                     embed_dim=self.config.embed_dim,
                     key_size=self.config.key_size,
@@ -1823,7 +1791,7 @@ class MultiModalPerceiverResampler(nn.Module):
         return outs


-class MultiModalPerceiverResamplerProjection(nn.Module):
+class TorchMultiModalPerceiverResamplerProjection(nn.Module):
     def __init__(
         self,
         perceiver_resampler_config: PerceiverResamplerConfig,
@@ -1843,7 +1811,7 @@ class MultiModalPerceiverResamplerProjection(nn.Module):

         self.bio_projection = nn.Linear(input_embed_dim, embed_dim)
         self.token_embedding = nn.Embedding(english_vocab_size, embed_dim)
-        self.perceiver_resampler = MultiModalPerceiverResampler(config=self.config)
+        self.perceiver_resampler = TorchMultiModalPerceiverResampler(config=self.config)

     def forward(
         self,
@@ -1855,10 +1823,8 @@ class MultiModalPerceiverResamplerProjection(nn.Module):
         Args:
             bio_token_ids (torch.Tensor):
                 Shape (batch_size, num_bio_tokens)
-
             bio_embeddings (torch.Tensor):
                 Shape (batch_size, num_bio_tokens, embed_dim)
-
             english_token_ids (torch.Tensor):
                 Shape (batch_size, num_english_tokens)
         """
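Based on the attributes initialized above (bio_projection, token_embedding, perceiver_resampler), the forward plausibly composes them roughly as sketched below; the resampler's exact call signature and masking are assumptions, so treat this only as a reading aid:

# Hypothetical composition of the projection module's forward pass.
import torch
import torch.nn as nn

def projection_forward_sketch(
    bio_projection: nn.Linear,
    token_embedding: nn.Embedding,
    perceiver_resampler: nn.Module,
    bio_token_ids: torch.Tensor,      # (batch_size, num_bio_tokens)
    bio_embeddings: torch.Tensor,     # (batch_size, num_bio_tokens, input_embed_dim)
    english_token_ids: torch.Tensor,  # (batch_size, num_english_tokens)
) -> torch.Tensor:
    projected_bio = bio_projection(bio_embeddings)           # map to GPT embed_dim
    english_embeddings = token_embedding(english_token_ids)  # embed the prompt
    # The resampler cross-attends between bio and English representations; its real
    # argument names are defined elsewhere in this file.
    return perceiver_resampler(projected_bio, english_embeddings)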
@@ -1901,3 +1867,4 @@ def build_perceiver_padding_attention_mask(
     padding_mask = padding_mask[:, None, None, :]
     padding_mask = padding_mask.repeat(1, 1, resampled_length, 1)  # noqa
     return padding_mask
+