# modeling_phobert_attn.py
import torch
import torch.nn as nn
from transformers import AutoModel


class PhoBERT_Attention(nn.Module):
    """PhoBERT encoder with additive attention pooling over token states."""

    def __init__(self, num_classes=2, dropout=0.3):
        super().__init__()
        # PhoBERT is an XLM-RoBERTa-style encoder pretrained on Vietnamese.
        self.xlm_roberta = AutoModel.from_pretrained("vinai/phobert-base")
        hidden = self.xlm_roberta.config.hidden_size
        # Scores each token state; a softmax over tokens turns the scores
        # into pooling weights.
        self.attention = nn.Linear(hidden, 1)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden, num_classes)
    def forward(self, input_ids, attention_mask):
        out = self.xlm_roberta(input_ids=input_ids, attention_mask=attention_mask)
        H = out.last_hidden_state                              # [B, T, H]
        scores = self.attention(H)                             # [B, T, 1]
        # Mask padding positions so they receive zero attention weight
        # after the softmax.
        scores = scores.masked_fill(attention_mask.unsqueeze(-1) == 0, float("-inf"))
        attn = torch.softmax(scores, dim=1)                    # [B, T, 1]
        ctx = (attn * H).sum(dim=1)                            # [B, H]
        logits = self.fc(self.dropout(ctx))                    # [B, C]
        return logits
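

# A minimal usage sketch, assuming the matching "vinai/phobert-base"
# tokenizer; the sample sentence, batch size, and max_length are
# illustrative only. Note that PhoBERT expects word-segmented
# Vietnamese input for best results.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base")
    model = PhoBERT_Attention(num_classes=2)
    model.eval()

    batch = tokenizer(
        ["Xin chào"],
        padding=True,
        truncation=True,
        max_length=256,
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(batch["input_ids"], batch["attention_mask"])
    print(logits.shape)  # torch.Size([1, 2])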