| from torch import nn | |
class ClassificationModel(nn.Module):
    """Sequence classifier: an encoder backbone followed by a small MLP head.

    The head takes the hidden state at sequence position 0 and maps it through
    ``Linear -> ReLU -> Dropout -> Linear -> LogSoftmax``. The output is
    **log**-probabilities, so it should be paired with ``nn.NLLLoss``; feeding
    it to ``nn.CrossEntropyLoss`` would apply a second (log-)softmax.

    Args:
        base_model: Encoder invoked as
            ``base_model(input_ids=..., attention_mask=...)``; its result must
            expose a rank-3 ``.last_hidden_state`` tensor. Presumably a
            Hugging Face style encoder -- NOTE(review): confirm with callers.
        hidden_size: Size of the last dimension of ``last_hidden_state``; must
            match the backbone. Defaults to 768 (the previous hard-coded value).
        num_labels: Number of output classes. Defaults to 8.
        intermediate_size: Width of the head's hidden layer. Defaults to 256.
        dropout: Dropout probability inside the head. Defaults to 0.3.
    """

    def __init__(
        self,
        base_model,
        hidden_size: int = 768,
        num_labels: int = 8,
        intermediate_size: int = 256,
        dropout: float = 0.3,
    ):
        super().__init__()
        self.base_model = base_model
        # Layer order is identical to the original so state_dict keys
        # (classifier.0.*, classifier.3.*) remain compatible with checkpoints.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(intermediate_size, num_labels),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, input_ids, attention_mask):
        """Return per-class log-probabilities.

        Args:
            input_ids: Token ids forwarded unchanged to ``base_model``.
            attention_mask: Mask forwarded unchanged to ``base_model``.

        Returns:
            Tensor of shape ``(N, num_labels)`` holding log-probabilities, where
            ``N`` is the leading dimension of ``last_hidden_state``.
        """
        hidden_states = self.base_model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        # Take position 0 along the sequence axis: the first token, which is
        # commonly [CLS] for BERT-style tokenizers (assumption -- confirm).
        cls_output = hidden_states[:, 0, :]
        # LogSoftmax output: log-probabilities, not probabilities.
        log_probs = self.classifier(cls_output)
        return log_probs