| from transformers import PretrainedConfig | |
class TotalClassifierConfig(PretrainedConfig):
    """Configuration container for the ``total_classifier`` model type.

    This class only stores hyperparameters as attributes; it performs no
    validation and builds no model. Any extra keyword arguments are forwarded
    unchanged to ``PretrainedConfig.__init__``.

    Args:
        backbone: Identifier of the CNN backbone (looks like a timm model
            name -- TODO confirm against the model code).
        feature_dim: Feature dimension (presumably the size of the per-image
            feature vector fed to the RNN -- confirm against the model code).
        cnn_dropout: Dropout rate associated with the CNN part.
        in_chans: Number of input image channels.
        rnn_type: Name of the recurrent layer type, e.g. ``"GRU"``.
        rnn_num_layers: Number of recurrent layers.
        rnn_dropout: Dropout rate associated with the RNN part.
        num_classes: Number of output classes.
        seq_len: Sequence length (presumably the number of images per
            sequence -- confirm against the model code).
        linear_dropout: Dropout rate associated with the linear head.
        image_size: Two-element image size; the axis order (height/width)
            is not established here -- confirm against the model code.
            A ``list`` is converted to a ``tuple``.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "total_classifier"

    def __init__(
        self,
        backbone: str = "tf_efficientnetv2_b0",
        feature_dim: int = 192,
        cnn_dropout: float = 0.1,
        in_chans: int = 1,
        rnn_type: str = "GRU",
        rnn_num_layers: int = 1,
        rnn_dropout: float = 0.0,
        num_classes: int = 117,
        seq_len: int = 512,
        linear_dropout: float = 0.1,
        image_size: tuple[int, int] = (256, 256),
        **kwargs,
    ):
        # CNN feature extractor settings.
        self.backbone = backbone
        self.feature_dim = feature_dim
        self.cnn_dropout = cnn_dropout
        self.in_chans = in_chans

        # Recurrent part settings.
        self.rnn_type = rnn_type
        self.rnn_num_layers = rnn_num_layers
        self.rnn_dropout = rnn_dropout

        # Classification head and input geometry.
        self.num_classes = num_classes
        self.seq_len = seq_len
        self.linear_dropout = linear_dropout

        # JSON has no tuple type, so a config restored from its serialized
        # form supplies a list here. Coerce it back so the attribute has the
        # same type (and compares equal to the default) whether the config
        # was built in code or reloaded. Other values are left untouched.
        if isinstance(image_size, list):
            image_size = tuple(image_size)
        self.image_size = image_size

        # Custom attributes are set before delegating, as in the original.
        super().__init__(**kwargs)