Nick Sorros committed
Commit e842824 · 1 Parent(s): d89f434
Update WellcomeBertMesh with transformers-based trained model
- config.json +0 -0
- model.py +19 -23
- pytorch_model.bin +2 -2
config.json
CHANGED
The diff for this file is too large to render. See raw diff.
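config.json grows because the model hyperparameters move out of __init__ keyword arguments and into the config itself (see the model.py diff below). A minimal sketch of building an equivalent config in Python; the field names come from the new __init__, while the values shown are the old keyword defaults and may not match the actual config.json:

from transformers import BertConfig

# Sketch only: PretrainedConfig stores unknown kwargs as attributes,
# which is what lets BertMesh.__init__ read them back off the config.
# Values are the previous keyword-argument defaults, not necessarily
# what this commit's config.json contains.
config = BertConfig(
    pretrained_model="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
    num_labels=28761,
    hidden_size=1024,
    dropout=0,
    multilabel_attention=True,
)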
model.py
CHANGED
@@ -1,4 +1,4 @@
-from transformers import AutoModel, PreTrainedModel
+from transformers import AutoModel, PreTrainedModel, BertConfig
 import torch
 
 
@@ -16,34 +16,33 @@ class MultiLabelAttention(torch.nn.Module):
 
 
 class BertMesh(PreTrainedModel):
+    config_class = BertConfig
+
     def __init__(
         self,
         config,
-        pretrained_model="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
-        num_labels=28761,
-        hidden_size=1024,
-        dropout=0,
-        multilabel_attention=True,
     ):
         super().__init__(config=config)
-        self.config.auto_map = {"AutoModel": "model.BertMesh"}
-        self.pretrained_model = pretrained_model
-        self.num_labels = num_labels
-        self.hidden_size = hidden_size
-        self.dropout = dropout
-        self.multilabel_attention = multilabel_attention
-
-        self.bert = AutoModel.from_pretrained(pretrained_model)  # 768
+        self.config.auto_map = {"AutoModel": "model.BertMesh"}
+        self.pretrained_model = self.config.pretrained_model
+        self.num_labels = self.config.num_labels
+        self.hidden_size = getattr(self.config, "hidden_size", 512)
+        self.dropout = getattr(self.config, "dropout", 0.1)
+        self.multilabel_attention = getattr(self.config, "multilabel_attention", False)
+
+        self.bert = AutoModel.from_pretrained(self.pretrained_model)  # 768
         self.multilabel_attention_layer = MultiLabelAttention(
-            768, num_labels
+            768, self.num_labels
         )  # num_labels, 768
-        self.linear_1 = torch.nn.Linear(768, hidden_size)  # num_labels, 512
-        self.linear_2 = torch.nn.Linear(hidden_size, 1)  # num_labels, 1
-        self.linear_out = torch.nn.Linear(hidden_size, num_labels)
+        self.linear_1 = torch.nn.Linear(768, self.hidden_size)  # num_labels, 512
+        self.linear_2 = torch.nn.Linear(self.hidden_size, 1)  # num_labels, 1
+        self.linear_out = torch.nn.Linear(self.hidden_size, self.num_labels)
         self.dropout_layer = torch.nn.Dropout(self.dropout)
 
-    def forward(self, input_ids,
-
+    def forward(self, input_ids, **kwargs):
+        if type(input_ids) is list:
+            # coming from tokenizer
+            input_ids = torch.tensor(input_ids)
         if self.multilabel_attention:
             hidden_states = self.bert(input_ids=input_ids)[0]
             attention_outs = self.multilabel_attention_layer(hidden_states)
@@ -57,6 +56,3 @@ class BertMesh(PreTrainedModel):
         outs = self.dropout_layer(outs)
         outs = torch.sigmoid(self.linear_out(outs))
         return outs
-
-    def _init_weights(self, module):
-        pass
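Because the commit keeps auto_map pointing AutoModel at model.BertMesh, the model loads from the Hub with remote code enabled, and the new list handling in forward() lets tokenizer output pass straight through. A usage sketch, assuming the repo id Wellcome/WellcomeBertMesh and an illustrative 0.5 decision threshold (neither is stated in this commit):

from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Wellcome/WellcomeBertMesh")
model = AutoModel.from_pretrained(
    "Wellcome/WellcomeBertMesh", trust_remote_code=True  # runs model.BertMesh
)

# Padding gives equal-length lists, which forward() converts to a tensor.
inputs = tokenizer(["Malaria vaccine trials in Kenya"], padding="max_length")
probs = model(inputs["input_ids"])   # sigmoid scores, shape (batch, num_labels)
mesh_idx = (probs > 0.5).nonzero()   # indices of predicted MeSH terms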
pytorch_model.bin
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:1c80db3a392fe08b3faa111d46e48fef56eb2c0efe862f0a80cc7fe4da55baea
+size 647442531