# Nicheformer / configuration_nicheformer.py
# aletlvl's picture
# Upload Nicheformer model
# a24db0c verified
from transformers import PretrainedConfig
class NicheformerConfig(PretrainedConfig):
    """Configuration holder for Nicheformer models.

    Every constructor argument is stored as an attribute of the same name.
    """

    model_type = "nicheformer"

    def __init__(
        self,
        dim_model=512,
        nheads=16,
        dim_feedforward=1024,
        nlayers=12,
        dropout=0.0,
        batch_first=True,
        masking_p=0.15,
        n_tokens=20340,
        context_length=1500,
        cls_classes=164,
        supervised_task=None,
        learnable_pe=True,
        specie=True,
        assay=True,
        modality=True,
        **kwargs
    ):
        """Create a Nicheformer configuration.

        Keyword arguments that are not listed below are passed through
        unchanged to ``PretrainedConfig.__init__``.

        Args:
            dim_model: Dimensionality of the model.
            nheads: Number of attention heads.
            dim_feedforward: Dimensionality of the MLPs in the attention
                blocks.
            nlayers: Number of transformer layers.
            dropout: Dropout probability.
            batch_first: Whether the batch dimension comes first.
            masking_p: Probability of masking tokens.
            n_tokens: Total number of tokens (excluding auxiliary ones).
            context_length: Length of the context window.
            cls_classes: Number of classification classes.
            supervised_task: Type of supervised task.
            learnable_pe: Whether to use learnable positional embeddings.
            specie: Whether to add the specie token.
            assay: Whether to add the assay token.
            modality: Whether to add the modality token.
        """
        # The base class consumes the remaining keyword arguments first;
        # the model-specific fields are attached afterwards.
        super().__init__(**kwargs)

        # Transformer layout.
        self.dim_model = dim_model
        self.nheads = nheads
        self.dim_feedforward = dim_feedforward
        self.nlayers = nlayers
        self.dropout = dropout
        self.batch_first = batch_first

        # Token masking.
        self.masking_p = masking_p

        # Token count and context window.
        self.n_tokens = n_tokens
        self.context_length = context_length

        # Classification / supervised settings.
        self.cls_classes = cls_classes
        self.supervised_task = supervised_task

        # Positional embeddings.
        self.learnable_pe = learnable_pe

        # Optional extra tokens.
        self.specie = specie
        self.assay = assay
        self.modality = modality