Spaces:
Runtime error
Runtime error
| import torch.nn as nn | |
| from transformers import RobertaModel | |
| class CommentClassifier(nn.Module): | |
| def __init__(self, dropout=0.3): | |
| super(CommentClassifier, self).__init__() | |
| self.roberta = RobertaModel.from_pretrained("roberta-base") | |
| self.dropout = nn.Dropout(dropout) | |
| self.classifier = nn.Linear(self.roberta.config.hidden_size + 3, 1) # +3 for metadata features | |
| def forward(self, input_ids, attention_mask, metadata_features): | |
| outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask) | |
| pooled_output = outputs.pooler_output | |
| combined = torch.cat((pooled_output, metadata_features), dim=1) | |
| x = self.dropout(combined) | |
| return self.classifier(x) | |