| import numpy | |
| import torch | |
| import torch.nn as nn | |
class LCF_Pooler(nn.Module):
    """Pool each sequence to the hidden state at the middle of its local-context span.

    For every sample ``i``, the positions ``j`` whose ``lcf_vec[i, j]`` row
    satisfies ``sum(lcf_vec[i, j] - 1.0) == 0`` are collected in ascending
    order. The hidden state at the middle one of those positions (index
    ``len(ids) // 2``) is passed through ``Linear(hidden_size, hidden_size)``
    followed by ``tanh``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states, lcf_vec):
        """Select the middle local-context token of each sample and project it.

        Args:
            hidden_states: tensor where ``hidden_states[i, j]`` is a feature
                vector of size ``config.hidden_size``. Assumed to be
                (batch, seq_len, hidden_size); TODO confirm with callers.
            lcf_vec: tensor where ``lcf_vec[i, j]`` is a vector, because the
                criterion sums over the last dim. Presumably
                (batch, seq_len, dim), with all-ones rows marking local-context
                tokens; confirm against the caller that builds it.

        Returns:
            Tensor of shape (batch, hidden_size), on ``hidden_states``' device.

        Raises:
            ValueError: if some sample has no position satisfying the criterion.
        """
        # Same criterion as the previous numpy loop: a position is selected
        # when its lcf row minus 1.0 sums to exactly zero.
        # NOTE(review): this also accepts rows such as [0, 2]. If the intent is
        # "every element equals 1", use (lcf_vec == 1.0).all(dim=-1) instead.
        selected = (lcf_vec - 1.0).sum(dim=-1) == 0  # bool, (batch, seq_len)
        counts = selected.sum(dim=-1)
        if bool((counts == 0).any()):
            empty_rows = torch.nonzero(counts == 0).flatten().tolist()
            raise ValueError(
                f"LCF_Pooler: no local-context positions found for batch rows {empty_rows}"
            )

        # 1-based rank of each selected position among the selected positions
        # of its row. The element at list index len(ids) // 2 has rank
        # len(ids) // 2 + 1.
        ranks = selected.long().cumsum(dim=-1)
        target = (torch.div(counts, 2, rounding_mode="floor") + 1).unsqueeze(-1)
        is_middle = selected & (ranks == target)  # exactly one True per row
        positions = is_middle.long().argmax(dim=-1).to(hidden_states.device)

        batch_index = torch.arange(hidden_states.size(0), device=hidden_states.device)
        # Advanced indexing keeps the autograd graph intact, so gradients reach
        # hidden_states. The old detach()/numpy round trip silently dropped them.
        pooled_output = hidden_states[batch_index, positions]
        # Match the projection's dtype. The old code always produced float32,
        # which is identical for float32 weights and avoids a crash for half weights.
        pooled_output = pooled_output.to(self.dense.weight.dtype)
        pooled_output = self.dense(pooled_output)
        pooled_output = self.activation(pooled_output)
        return pooled_output