Spaces:
Sleeping
Sleeping
| import torch | |
| from torch import nn | |
| from transformers import PreTrainedModel | |
| from transformers import PretrainedConfig | |
class CustomClassificationConfig(PretrainedConfig):
    """Configuration holding the layer sizes used by ``CustomClassifier``.

    Args:
        input_dim: Size of each input feature vector (in_features of the
            first linear layer).
        hidden_dim: Size of the hidden layers.
        num_classes: Number of output logits per example.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "custom_classifier"

    def __init__(
        self,
        input_dim: int = 32,
        hidden_dim: int = 64,
        num_classes: int = 2,
        **kwargs,
    ):
        # Let the base config consume its own keyword arguments first.
        super().__init__(**kwargs)
        # Model-specific sizes, stored as plain instance attributes.
        self.input_dim, self.hidden_dim, self.num_classes = (
            input_dim,
            hidden_dim,
            num_classes,
        )
class CustomClassifier(PreTrainedModel):
    """MLP classifier: two Linear+ReLU layers followed by a linear head.

    Layer sizes come from ``CustomClassificationConfig``:
    ``input_dim -> hidden_dim -> hidden_dim -> num_classes``.
    """

    config_class = CustomClassificationConfig

    def __init__(self, config):
        super().__init__(config)
        self.encoder = nn.Sequential(
            nn.Linear(config.input_dim, config.hidden_dim),
            nn.ReLU(),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.ReLU(),
        )
        self.classifier = nn.Linear(config.hidden_dim, config.num_classes)
        # transformers' end-of-__init__ hook (weight initialization and related
        # setup used by from_pretrained/save_pretrained). NOTE(review): depending
        # on the installed transformers version this may re-initialize Linear
        # weights with the library's default scheme instead of PyTorch's.
        self.post_init()

    def forward(self, input_ids=None, labels=None, **kwargs):
        """Compute class logits and, when ``labels`` is given, the loss.

        Args:
            input_ids: Dense feature tensor of shape ``(batch_size, input_dim)``
                (not token ids, despite the name). It is cast to the model's
                parameter dtype, so integer or mismatched-float inputs work.
            labels: Optional targets for ``torch.nn.functional.cross_entropy``
                (typically class indices of shape ``(batch_size,)``).
            **kwargs: Ignored; presumably accepted so extra batch keys from a
                data collator / Trainer do not raise.

        Returns:
            dict with ``"loss"`` (scalar tensor, or ``None`` when no labels are
            given) and ``"logits"`` of shape ``(batch_size, num_classes)``.

        Raises:
            ValueError: if ``input_ids`` is ``None``.
        """
        if input_ids is None:
            raise ValueError(
                "CustomClassifier.forward requires `input_ids` "
                "(a (batch_size, input_dim) feature tensor)."
            )
        # nn.Linear requires the input dtype to match its weights; .to() is a
        # no-op (returns the same tensor) when the dtype already matches.
        features = input_ids.to(dtype=self.classifier.weight.dtype)
        hidden = self.encoder(features)
        logits = self.classifier(hidden)

        loss = None
        if labels is not None:
            # Same defaults as nn.CrossEntropyLoss() (mean reduction).
            loss = nn.functional.cross_entropy(logits, labels)
        # NOTE(review): "loss" is present (as None) even without labels; kept
        # for backward compatibility with callers indexing out["loss"].
        return {"loss": loss, "logits": logits}