Spaces:
Build error
Build error
Upload weakly_supervised_parser/utils/populate_chart.py
Browse files
weakly_supervised_parser/utils/populate_chart.py
CHANGED
|
@@ -26,9 +26,9 @@ ptb_top_100_common = ['this', 'myself', 'shouldn', 'not', 'analysts', 'same', 'm
|
|
| 26 |
# ptb_most_common_first_token = RuleBasedHeuristic(corpus=ptb.retrieve_all_sentences()).augment_using_most_frequent_starting_token(N=1)[0][0].lower()
|
| 27 |
ptb_most_common_first_token = "the"
|
| 28 |
|
| 29 |
-
|
| 30 |
|
| 31 |
-
|
| 32 |
|
| 33 |
|
| 34 |
class PopulateCKYChart:
|
|
@@ -54,20 +54,20 @@ class PopulateCKYChart:
|
|
| 54 |
|
| 55 |
if predict_type == "inside":
|
| 56 |
|
| 57 |
-
if data.shape[0] > chunks:
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
else:
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
|
| 72 |
data["inside_scores"] = inside_scores
|
| 73 |
data.loc[
|
|
|
|
| 26 |
# ptb_most_common_first_token = RuleBasedHeuristic(corpus=ptb.retrieve_all_sentences()).augment_using_most_frequent_starting_token(N=1)[0][0].lower()
|
| 27 |
ptb_most_common_first_token = "the"
|
| 28 |
|
| 29 |
+
from pytorch_lightning import Trainer
|
| 30 |
|
| 31 |
+
trainer = Trainer(accelerator="auto", enable_progress_bar=False, max_epochs=-1)
|
| 32 |
|
| 33 |
|
| 34 |
class PopulateCKYChart:
|
|
|
|
| 54 |
|
| 55 |
if predict_type == "inside":
|
| 56 |
|
| 57 |
+
# if data.shape[0] > chunks:
|
| 58 |
+
# data_chunks = np.array_split(data, data.shape[0] // chunks)
|
| 59 |
+
# for data_chunk in data_chunks:
|
| 60 |
+
# inside_scores.extend(model.predict_proba(spans=data_chunk.rename(columns={"inside_sentence": "sentence"})[["sentence"]],
|
| 61 |
+
# scale_axis=scale_axis,
|
| 62 |
+
# predict_batch_size=predict_batch_size)[:, 1])
|
| 63 |
+
# else:
|
| 64 |
+
# inside_scores.extend(model.predict_proba(spans=data.rename(columns={"inside_sentence": "sentence"})[["sentence"]],
|
| 65 |
+
# scale_axis=scale_axis,
|
| 66 |
+
# predict_batch_size=predict_batch_size)[:, 1])
|
| 67 |
|
| 68 |
+
test_dataloader = DataModule(model_name_or_path="roberta-base", train_df=None, eval_df=None,
|
| 69 |
+
test_df=data.rename(columns={"inside_sentence": "sentence"})[["sentence"]])
|
| 70 |
+
inside_scores.extend(trainer.predict(model, dataloaders=test_dataloader)[0])
|
| 71 |
|
| 72 |
data["inside_scores"] = inside_scores
|
| 73 |
data.loc[
|