Spaces:
Build error
Build error
update trainer
Browse files
weakly_supervised_parser/model/trainer.py
CHANGED
|
@@ -10,7 +10,7 @@ from pytorch_lightning import Trainer, seed_everything
|
|
| 10 |
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
| 11 |
from transformers import AutoTokenizer, logging
|
| 12 |
|
| 13 |
-
from onnxruntime import InferenceSession
|
| 14 |
from scipy.special import softmax
|
| 15 |
|
| 16 |
from weakly_supervised_parser.model.data_module_loader import DataModule
|
|
@@ -98,7 +98,10 @@ class InsideOutsideStringClassifier:
|
|
| 98 |
)
|
| 99 |
|
| 100 |
def load_model(self, pre_trained_model_path):
|
| 101 |
-
|
|
|
|
|
|
|
|
|
|
| 102 |
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)
|
| 103 |
|
| 104 |
def preprocess_function(self, data):
|
|
|
|
| 10 |
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
| 11 |
from transformers import AutoTokenizer, logging
|
| 12 |
|
| 13 |
+
from onnxruntime import InferenceSession, SessionOptions
|
| 14 |
from scipy.special import softmax
|
| 15 |
|
| 16 |
from weakly_supervised_parser.model.data_module_loader import DataModule
|
|
|
|
| 98 |
)
|
| 99 |
|
| 100 |
def load_model(self, pre_trained_model_path):
|
| 101 |
+
options = SessionOptions()
|
| 102 |
+
options.intra_op_num_threads = 1
|
| 103 |
+
options.inter_op_num_threads = 1
|
| 104 |
+
self.model = InferenceSession(pre_trained_model_path, options, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
|
| 105 |
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)
|
| 106 |
|
| 107 |
def preprocess_function(self, data):
|