Upload 13 files

- grewtse/__init__.py +3 -0
- grewtse/evaluators/__init__.py +4 -0
- grewtse/evaluators/evaluator.py +403 -0
- grewtse/evaluators/metrics.py +104 -0
- grewtse/pipeline.py +295 -0
- grewtse/preprocessing/__init__.py +3 -0
- grewtse/preprocessing/conllu_parser.py +348 -0
- grewtse/preprocessing/grew_dependencies.py +83 -0
- grewtse/preprocessing/reconstruction.py +71 -0
- grewtse/utils/__init__.py +0 -0
- grewtse/utils/validation.py +24 -0
- grewtse/visualise/__init__.py +3 -0
- grewtse/visualise/visualiser.py +132 -0
grewtse/__init__.py
ADDED
@@ -0,0 +1,3 @@
+from .pipeline import GrewTSEPipe
+
+__all__ = ["GrewTSEPipe"]
grewtse/evaluators/__init__.py
ADDED
@@ -0,0 +1,4 @@
+from .evaluator import GrewTSEvaluator
+from .evaluator import Evaluator
+
+__all__ = ["GrewTSEvaluator", "Evaluator"]
grewtse/evaluators/evaluator.py
ADDED
@@ -0,0 +1,403 @@
+import ast
+
+from transformers import AutoModelForMaskedLM, AutoModelForCausalLM, AutoTokenizer
+from transformers import PreTrainedModel, PreTrainedTokenizerBase
+from typing import Tuple, NamedTuple, List, Any
+import torch.nn.functional as F
+from tqdm import tqdm
+import pandas as pd
+import itertools
+import logging
+import torch
+import math
+
+from grewtse.utils.validation import load_and_validate_mp_dataset
+from grewtse.evaluators.metrics import (
+    compute_normalised_surprisal_difference,
+    compute_average_surprisal_difference,
+    compute_entropy,
+    compute_surprisal,
+    compute_mean,
+    calculate_all_metrics,
+)
+
+EVAL_TEMPLATE = {
+    "sentence_id": None,
+    "match_id": None,
+    "original_text": None,
+    "prompt_text": None,
+    "form_grammatical": None,
+    "p_grammatical": None,
+    "I_grammatical": None,
+    "form_ungrammatical": None,
+    "p_ungrammatical": None,
+    "I_ungrammatical": None,
+    "entropy": None,
+    "entropy_norm": None,
+}
+
+
+class TooManyMasksException(Exception):
+    def __init__(self, message: str):
+        self.message = message
+        super().__init__(f"TMM Exception: {message}")
+
+
+class Prediction(NamedTuple):
+    token: str
+    prob: float
+    surprisal: float
+
+
+class GrewTSEvaluator:
+    """
+    An evaluation class designed for rapid syntactic evaluation of models
+    available on the Hugging Face platform.
+    """
+
+    def __init__(self):
+        self.evaluator = Evaluator()
+        self.evaluation_dataset = None
+
+    def evaluate_model(
+        self,
+        mp_dataset: pd.DataFrame,
+        model_repo: str,
+        model_type: str,  # can be 'encoder' or 'decoder'
+        entropy_topk: int = 100,
+        row_limit: int | None = None,
+    ) -> pd.DataFrame:
+        """
+        Generic evaluation function for encoder or decoder models.
+        """
+
+        # --- Prepare dataset ---
+        mp_dataset_iter = mp_dataset.itertuples()
+        if row_limit:
+            mp_dataset_iter = itertools.islice(mp_dataset_iter, row_limit)
+        n = len(mp_dataset) if not row_limit else row_limit
+
+        # --- Load model & tokenizer ---
+        is_encoder = model_type == "encoder"
+        model, tokenizer = self.evaluator.setup_parameters(model_repo, is_encoder)
+        results = []
+
+        # --- Evaluate each row ---
+        for row in tqdm(mp_dataset_iter, total=n):
+            row_results = self._init_row_results(row)
+
+            try:
+                if is_encoder:
+                    self._evaluate_encoder_row(row, row_results)
+                else:
+                    self._evaluate_decoder_row(row, row_results)
+
+            except TooManyMasksException:
+                logging.error(f"Too many masks in {row.sentence_id}")
+                continue
+            except Exception as e:
+                raise RuntimeError(f"Model/tokeniser issue: {e}") from e
+
+            # --- Entropy ---
+            entropy, entropy_norm = self.evaluator.get_entropy(entropy_topk, True)
+            row_results["entropy"] = entropy
+            row_results["entropy_norm"] = entropy_norm
+
+            results.append(row_results)
+
+        # keep any extra columns (e.g. ood_*) produced per row
+        results_df = pd.DataFrame(results)
+        self.evaluation_dataset = results_df
+        return results_df
+
+    def evaluate_from_minimal_pairs(
+        self,
+        mp_dataset_filepath: str,
+        model_repo: str,
+        model_type: str,
+        entropy_topk: int = 100,
+        row_limit: int | None = None,
+    ) -> pd.DataFrame:
+        mp_dataset = load_and_validate_mp_dataset(mp_dataset_filepath)
+        self.mp_dataset = mp_dataset
+        return self.evaluate_model(
+            mp_dataset, model_repo, model_type, entropy_topk, row_limit
+        )
+
+    # --- Helper functions ---
+    def _init_row_results(self, row):
+        row_results = EVAL_TEMPLATE.copy()
+        row_results.update(row._asdict())
+        return row_results
+
+    def _evaluate_encoder_row(self, row, row_results):
+        prob_gram, prob_ungram = self.evaluator.run_masked_prediction(
+            row.masked_text,
+            row.form_grammatical,
+            row.form_ungrammatical,
+        )
+
+        row_results["p_grammatical"] = prob_gram
+        row_results["p_ungrammatical"] = prob_ungram
+        row_results["I_grammatical"] = compute_surprisal(prob_gram)
+        row_results["I_ungrammatical"] = compute_surprisal(prob_ungram)
+
+        # namedtuple rows expose fields as attributes, so check with hasattr
+        if hasattr(row, "ood_minimal_pairs"):
+            ood_pairs = row.ood_minimal_pairs
+            # OOD pairs round-trip through CSV as a string representation
+            if isinstance(ood_pairs, str):
+                ood_pairs = ast.literal_eval(ood_pairs)
+            all_ood_probs_gram = []
+            all_ood_probs_ungram = []
+            for pair in ood_pairs:
+                ood_prob_gram, ood_prob_ungram = self.evaluator.run_masked_prediction(
+                    row.masked_text, pair[0], pair[1]
+                )
+                all_ood_probs_gram.append(ood_prob_gram)
+                all_ood_probs_ungram.append(ood_prob_ungram)
+
+            avg_ood_prob_gram = compute_mean(all_ood_probs_gram)
+            avg_ood_prob_ungram = compute_mean(all_ood_probs_ungram)
+
+            row_results["ood_p_grammatical"] = avg_ood_prob_gram
+            row_results["ood_p_ungrammatical"] = avg_ood_prob_ungram
+            row_results["ood_I_grammatical"] = compute_surprisal(avg_ood_prob_gram)
+            row_results["ood_I_ungrammatical"] = compute_surprisal(avg_ood_prob_ungram)
+
+    def _evaluate_decoder_row(self, row, row_results):
+        prob_gram, prob_ungram = self.evaluator.run_next_word_prediction(
+            row.prompt_text, row.form_grammatical, row.form_ungrammatical
+        )
+        row_results["p_grammatical"] = prob_gram
+        row_results["p_ungrammatical"] = prob_ungram
+        row_results["I_grammatical"] = compute_surprisal(prob_gram)
+        row_results["I_ungrammatical"] = compute_surprisal(prob_ungram)
+
+        if hasattr(row, "ood_minimal_pairs"):
+            ood_pairs = row.ood_minimal_pairs
+            if isinstance(ood_pairs, str):
+                ood_pairs = ast.literal_eval(ood_pairs)
+            all_ood_probs_gram = []
+            all_ood_probs_ungram = []
+            for pair in ood_pairs:
+                ood_prob_gram, ood_prob_ungram = (
+                    self.evaluator.run_next_word_prediction(
+                        row.prompt_text, pair[0], pair[1]
+                    )
+                )
+                all_ood_probs_gram.append(ood_prob_gram)
+                all_ood_probs_ungram.append(ood_prob_ungram)
+
+            avg_ood_prob_gram = compute_mean(all_ood_probs_gram)
+            avg_ood_prob_ungram = compute_mean(all_ood_probs_ungram)
+
+            row_results["ood_p_grammatical"] = avg_ood_prob_gram
+            row_results["ood_p_ungrammatical"] = avg_ood_prob_ungram
+            row_results["ood_I_grammatical"] = compute_surprisal(avg_ood_prob_gram)
+            row_results["ood_I_ungrammatical"] = compute_surprisal(avg_ood_prob_ungram)
+
+    def get_norm_avg_surprisal_difference(self) -> float:
+        if self.evaluation_dataset is None:
+            raise ValueError("Please evaluate a model first.")
+        return compute_normalised_surprisal_difference(
+            self.evaluation_dataset["p_grammatical"],
+            self.evaluation_dataset["p_ungrammatical"],
+        )
+
+    def get_avg_surprisal_difference(self, is_ood: bool = False) -> float:
+        p_grammatical_col = "p_grammatical" if not is_ood else "ood_p_grammatical"
+        p_ungrammatical_col = "p_ungrammatical" if not is_ood else "ood_p_ungrammatical"
+        if self.evaluation_dataset is None:
+            raise ValueError("Please evaluate a model first.")
+        return compute_average_surprisal_difference(
+            self.evaluation_dataset[p_grammatical_col],
+            self.evaluation_dataset[p_ungrammatical_col],
+        )
+
+    def get_all_metrics(self):
+        if self.evaluation_dataset is not None:
+            return calculate_all_metrics(self.evaluation_dataset)
+        else:
+            raise ValueError("Please evaluate a model first.")
+
+
+class Evaluator:
+    def __init__(self):
+        self.tokeniser: PreTrainedTokenizerBase | None = None
+        self.model: PreTrainedModel | None = None
+
+        self.mask_token_index: int = -1
+        self.mask_probs: torch.Tensor | None = None
+        self.logits: torch.Tensor | None = None
+
+    def setup_parameters(
+        self, model_name: str, is_mlm: bool = True
+    ) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
+        self.tokeniser = AutoTokenizer.from_pretrained(model_name)
+        if is_mlm:
+            self.model = AutoModelForMaskedLM.from_pretrained(model_name)
+        else:
+            self.model = AutoModelForCausalLM.from_pretrained(model_name)
+
+        # set to eval mode, disabling things like dropout
+        self.model.eval()
+
+        return self.model, self.tokeniser
+
+    def run_masked_prediction(
+        self, sentence: str, grammatical_word: str, ungrammatical_word: str
+    ) -> Tuple[float, float]:
+        if not self.model or not self.tokeniser:
+            raise RuntimeError("Model and tokenizer must be loaded before prediction.")
+
+        mask_token = self.tokeniser.mask_token
+        sentence_masked = sentence.replace("[MASK]", mask_token)
+
+        if sentence_masked.count(mask_token) != 1:
+            raise TooManyMasksException("Only single-mask sentences are supported.")
+
+        masked_ids = self.tokeniser.encode(sentence_masked, add_special_tokens=False)
+        mask_index = masked_ids.index(self.tokeniser.mask_token_id)
+
+        device = next(self.model.parameters()).device
+        g_ids = self.tokeniser.encode(grammatical_word, add_special_tokens=False)
+        u_ids = self.tokeniser.encode(ungrammatical_word, add_special_tokens=False)
+
+        g_prob = self._compute_masked_joint_probability(
+            masked_ids, mask_index, g_ids, device
+        )
+        u_prob = self._compute_masked_joint_probability(
+            masked_ids, mask_index, u_ids, device
+        )
+
+        return g_prob, u_prob
+
+    def _compute_masked_joint_probability(
+        self, input_ids: List[int], mask_index: int, word_ids: List[int], device
+    ) -> float:
+        # fill a multi-token word left to right: predict at the mask, commit the
+        # token, insert a fresh mask for the next subword, and accumulate the
+        # conditional probabilities in log space
+        input_ids_tensor = torch.tensor([input_ids], device=device)
+        log_prob = 0.0
+        index = mask_index
+
+        for i, tid in enumerate(word_ids):
+            with torch.no_grad():
+                logits = self.model(input_ids_tensor).logits
+
+            probs = F.softmax(logits[:, index, :], dim=-1)
+            token_prob = probs[0, tid].item()
+            log_prob += math.log(token_prob + 1e-12)
+
+            if i == 0:
+                self.mask_probs = probs
+
+            # Replace mask with predicted token
+            input_ids_tensor[0, index] = tid
+
+            # Insert new mask if more tokens remain
+            if i < len(word_ids) - 1:
+                input_ids_tensor = torch.cat(
+                    [
+                        input_ids_tensor[:, : index + 1],
+                        torch.tensor([[self.tokeniser.mask_token_id]], device=device),
+                        input_ids_tensor[:, index + 1 :],
+                    ],
+                    dim=1,
+                )
+
+                index += 1
+
+        return math.exp(log_prob)
+
+    def run_next_word_prediction(
+        self, context: str, grammatical_word: str, ungrammatical_word: str
+    ) -> Tuple[float, float]:
+        if not self.model or not self.tokeniser:
+            raise RuntimeError("Model and tokenizer must be loaded before prediction.")
+
+        context_ids = self.tokeniser.encode(context, add_special_tokens=False)
+        device = next(self.model.parameters()).device
+
+        g_ids = self.tokeniser.encode(grammatical_word, add_special_tokens=False)
+        u_ids = self.tokeniser.encode(ungrammatical_word, add_special_tokens=False)
+
+        g_prob = self._compute_next_word_joint_probability(context_ids, g_ids, device)
+        u_prob = self._compute_next_word_joint_probability(context_ids, u_ids, device)
+
+        return g_prob, u_prob
+
+    def _compute_next_word_joint_probability(
+        self, input_ids: List[int], word_ids: List[int], device
+    ) -> float:
+        input_ids_tensor = torch.tensor([input_ids], device=device)
+        log_prob = 0.0
+
+        for i, tid in enumerate(word_ids):
+            with torch.no_grad():
+                logits = self.model(input_ids_tensor).logits
+
+            index = input_ids_tensor.shape[1] - 1  # last token position
+            probs = F.softmax(logits[:, index, :], dim=-1)
+            token_prob = probs[0, tid].item()
+            log_prob += math.log(token_prob + 1e-12)
+
+            if i == 0:
+                self.mask_probs = probs
+
+            # Append predicted token to context
+            input_ids_tensor = torch.cat(
+                [input_ids_tensor, torch.tensor([[tid]], device=device)], dim=1
+            )
+
+        return math.exp(log_prob)
+
+    def get_entropy(self, k: int = 100, normalise: bool = False):
+        """Compute entropy over the prediction distribution.
+
+        Args:
+            k: Number of top tokens to consider.
+            normalise: Whether to also return a normalised entropy.
+
+        Returns:
+            The entropy value, or an (entropy, normalised entropy) tuple
+            when ``normalise`` is True.
+        Raises:
+            ValueError: If no probabilities are available.
+        """
+        if self.mask_probs is None:
+            raise ValueError("No output probabilities available. Run evaluation first.")
+        return compute_entropy(self.mask_probs, k, normalise)
+
+    def _get_mask_index(self, inputs: Any) -> int:
+        if "input_ids" not in inputs:
+            raise ValueError("Missing 'input_ids' in inputs.")
+        elif self.tokeniser.mask_token_id is None:
+            raise ValueError("The tokeniser does not have a defined mask_token_id.")
+
+        input_ids = inputs["input_ids"]
+        mask_positions = torch.where(input_ids == self.tokeniser.mask_token_id)
+
+        if len(mask_positions[0]) == 0:
+            raise ValueError("No mask token found in input_ids.")
+        elif len(mask_positions[0]) > 1:
+            raise ValueError("Multiple mask tokens found; expected only one.")
+
+        return (
+            mask_positions[1].item()
+            if len(mask_positions) > 1
+            else mask_positions[0].item()
+        )
+
+    def _get_mask_probabilities(
+        self, mask_token_index: int, logits: Any
+    ) -> torch.Tensor:
+        mask_logits = logits[0, mask_token_index, :]
+        probs = F.softmax(mask_logits, dim=-1)  # shape: (vocab_size,)
+        return probs
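For orientation, a minimal usage sketch of the evaluator above (not part of the upload): the CSV path and model repo are illustrative assumptions, and any Hugging Face masked-LM repo would do for the "encoder" path.

from grewtse.evaluators import GrewTSEvaluator

evaluator = GrewTSEvaluator()
results = evaluator.evaluate_from_minimal_pairs(
    "minimal_pairs.csv",  # assumed path to a CSV exported from GrewTSEPipe
    "bert-base-cased",    # illustrative Hugging Face model repo
    "encoder",            # masked prediction; "decoder" runs next-word prediction
    entropy_topk=100,
    row_limit=50,         # evaluate only the first 50 rows
)
print(evaluator.get_all_metrics())               # accuracy / precision / recall / f1
print(evaluator.get_avg_surprisal_difference())  # mean I(ungrammatical) - I(grammatical)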
grewtse/evaluators/metrics.py
ADDED
@@ -0,0 +1,104 @@
+import pandas as pd
+import numpy as np
+import math
+from typing import List
+
+
+def compute_mean(list_of_values: List[float]) -> float:
+    return sum(list_of_values) / len(list_of_values)
+
+
+def compute_surprisal(p: float) -> float:
+    return -math.log2(p) if p and p > 0 else float("inf")
+
+
+def compute_avg_surprisal(probs: pd.Series) -> float:
+    as_surprisal = probs.apply(compute_surprisal)
+    return as_surprisal.mean()
+
+
+def compute_average_surprisal_difference(
+    correct_form_probs: pd.Series, wrong_form_probs: pd.Series
+) -> float:
+    correct_form_avg_surp = compute_avg_surprisal(correct_form_probs)
+    wrong_form_avg_surp = compute_avg_surprisal(wrong_form_probs)
+    return wrong_form_avg_surp - correct_form_avg_surp
+
+
+def compute_normalised_surprisal_difference(
+    correct_form_probs: pd.Series, wrong_form_probs: pd.Series
+) -> float:
+    correct_form_avg_surp = compute_avg_surprisal(correct_form_probs)
+    wrong_form_avg_surp = compute_avg_surprisal(wrong_form_probs)
+    return (wrong_form_avg_surp - correct_form_avg_surp) / correct_form_avg_surp
+
+
+def compute_entropy(probs, k=None, normalise=False):
+    probs = np.array(probs, dtype=np.float64)
+
+    # remove zeros to avoid log(0)
+    probs = probs[probs > 0]
+
+    # get top-k probabilities
+    if k is not None:
+        probs = np.sort(probs)[::-1][:k]
+        probs = probs / probs.sum()  # renormalise to sum to 1
+
+    H = -np.sum(probs * np.log(probs))
+
+    if normalise:
+        n = len(probs)
+        return H, 1 - H / np.log(n)
+    else:
+        return H
+
+
+def get_predictions(df: pd.DataFrame) -> np.ndarray:
+    """
+    Convert probabilities to binary predictions.
+    Predicts grammatical (1) if p_grammatical > p_ungrammatical, else ungrammatical (0).
+    """
+    predictions = (df["p_grammatical"] > df["p_ungrammatical"]).astype(int)
+    return predictions.values
+
+
+def calculate_accuracy(df: pd.DataFrame) -> float:
+    """
+    Calculate accuracy: proportion of correct predictions.
+    Assumes the model should always predict the grammatical form (label = 1).
+    """
+    predictions = get_predictions(df)
+    # True labels: all should be grammatical (1)
+    true_labels = np.ones(len(df), dtype=int)
+
+    correct = np.sum(predictions == true_labels)
+    total = len(predictions)
+
+    return correct / total if total > 0 else 0.0
+
+
+def calculate_all_metrics(df: pd.DataFrame) -> dict:
+    predictions = get_predictions(df)
+    true_labels = np.ones(len(df), dtype=int)
+
+    # Calculate confusion matrix components; with all-ones labels,
+    # fp and tn are always zero by construction
+    tp = np.sum((predictions == 1) & (true_labels == 1))
+    fp = np.sum((predictions == 1) & (true_labels == 0))
+    fn = np.sum((predictions == 0) & (true_labels == 1))
+    tn = np.sum((predictions == 0) & (true_labels == 0))
+
+    total = len(predictions)
+
+    # Calculate metrics
+    accuracy = (tp + tn) / total if total > 0 else 0.0
+    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
+    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
+    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
+
+    return {
+        "accuracy": round(accuracy, 2),
+        "precision": round(precision, 2),
+        "recall": round(recall, 2),
+        "f1": round(f1, 2),
+        "true_positives": int(tp),
+        "false_positives": int(fp),
+        "false_negatives": int(fn),
+        "true_negatives": int(tn),
+    }
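To make the conventions above concrete, a quick toy check with illustrative values: surprisal is measured in bits (base-2 log), entropy in nats (natural log), and normalise returns a pair.

import numpy as np
from grewtse.evaluators.metrics import compute_surprisal, compute_entropy

print(compute_surprisal(0.5))   # 1.0 bit; halving a probability adds one bit
print(compute_surprisal(0.25))  # 2.0 bits

probs = np.array([0.25, 0.25, 0.25, 0.25])
print(compute_entropy(probs))       # uniform 4-way split: ln(4) ~ 1.386 nats
print(compute_entropy(probs, k=2))  # top-2 renormalised: ln(2) ~ 0.693
H, H_norm = compute_entropy(probs, normalise=True)
print(H, H_norm)                    # ~1.386 and 0.0, since H_norm = 1 - H/ln(n)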
grewtse/pipeline.py
ADDED
@@ -0,0 +1,295 @@
+import os
+import pandas as pd
+import random
+import logging
+from grewtse.preprocessing import ConlluParser
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s [%(levelname)s] %(message)s",
+    handlers=[logging.FileHandler("app.log"), logging.StreamHandler()],
+)
+
+
+class GrewTSEPipe:
+    """
+    Main pipeline controller for generating prompt-based or masked minimal-pair
+    datasets derived from UD treebanks.
+
+    This class acts as a high-level interface for the Grew-TSE workflow:
+
+    1. Parse treebanks to build lexical item datasets.
+    2. Generate masked or prompt-based datasets using GREW.
+    3. Create minimal pairs for syntactic evaluation.
+    """
+
+    def __init__(self):
+        self.parser = ConlluParser()
+
+        self.treebank_paths: list[str] = []
+        self.lexical_items: pd.DataFrame | None = None
+        self.grew_generated_dataset: pd.DataFrame | None = None
+        self.mp_dataset: pd.DataFrame | None = None
+        self.exception_dataset: pd.DataFrame | None = None
+        self.evaluation_results: pd.DataFrame | None = None
+
+    # 1. Initial step: parse a treebank from a .conllu file
+    def parse_treebank(
+        self, filepaths: str | list[str], reset: bool = False
+    ) -> pd.DataFrame:
+        """
+        Parse one or more treebanks and create a lexical item set,
+        i.e. a dataset of words and their features.
+
+        Args:
+            filepaths: Path or list of paths to treebank files.
+            reset: If True, clears existing lexical_items before parsing.
+        """
+
+        if isinstance(filepaths, str):
+            filepaths = [filepaths]  # wrap single path in list
+
+        try:
+            if reset or self.lexical_items is None:
+                self.lexical_items = pd.DataFrame()
+                self.treebank_paths = []
+
+            self.lexical_items = self.parser.build_lexicon(filepaths)
+            self.treebank_paths = filepaths
+
+            return self.lexical_items
+        except Exception as e:
+            raise RuntimeError(f"Issue parsing treebank: {e}") from e
+
+    def load_lexicon(self, filepath: str, treebank_paths: list[str]) -> None:
+        """
+        Load a previously generated lexicon (typically returned from parse_treebank)
+        from disk and attach it to the pipeline.
+
+        Use this method to resume processing with an existing LI_set that was
+        generated earlier and saved as a CSV. It loads the LI_set, validates the
+        required columns, sets the appropriate index, and updates the pipeline
+        and parser with the loaded data.
+
+        Args:
+            filepath (str):
+                Path to the CSV file containing the LI_set to load. The file must
+                contain the columns ``"sentence_id"`` and ``"token_id"``.
+            treebank_paths (list[str]):
+                Paths to the treebanks associated with the LI_set. These are stored
+                so the pipeline can later reference the corresponding treebanks when
+                generating or analysing data.
+
+        Raises:
+            FileNotFoundError:
+                If the CSV file cannot be found at the given ``filepath``.
+            ValueError:
+                If the required index columns (``"sentence_id"``, ``"token_id"``) are missing.
+
+        Example:
+            >>> pipe = GrewTSEPipe()
+            >>> pipe.load_lexicon("output/li_set.csv", ["treebank1.conllu", "treebank2.conllu"])
+        """
+        if not os.path.exists(filepath):
+            raise FileNotFoundError(f"LI_set file not found: {filepath}")
+
+        li_df = pd.read_csv(filepath)
+
+        required_cols = {"sentence_id", "token_id"}
+        missing = required_cols - set(li_df.columns)
+        if missing:
+            raise ValueError(
+                f"Missing required columns in LI_set: {', '.join(missing)}"
+            )
+
+        li_df.set_index(["sentence_id", "token_id"], inplace=True)
+
+        self.lexical_items = li_df
+        self.parser.lexicon = li_df
+        self.treebank_paths = treebank_paths
+
+    def generate_masked_dataset(
+        self, query: str, target_node: str, mask_token: str = "[MASK]"
+    ) -> pd.DataFrame:
+        """
+        Once a treebank has been parsed, and if models are tested on masked
+        language modelling (MLM), e.g. encoder models, generate a masked dataset
+        (default token: [MASK]) by providing a GREW query that isolates a
+        particular construction and a target node identifying the element of
+        that construction to test.
+
+        :param query: the GREW query that specifies a construction. Test queries at https://universal.grew.fr/
+        :param target_node: the variable defined in your GREW query that represents the target word
+        :return: a DataFrame consisting of the sentence ID in the given treebank, the ID of the masked token, the matched token itself, the original text, and the masked text.
+        """
+
+        if not self.is_treebank_parsed():
+            raise ValueError(
+                "Cannot create masked dataset: no treebank or invalid treebank filepath provided."
+            )
+
+        results = self.parser.build_masked_dataset(
+            self.treebank_paths, query, target_node, mask_token
+        )
+        self.grew_generated_dataset = results["masked"]
+        self.exception_dataset = results["exception"]
+        return self.grew_generated_dataset
+
+    def generate_prompt_dataset(self, query: str, target_node: str) -> pd.DataFrame:
+        """
+        Once a treebank has been parsed, and if models are tested on next-token
+        prediction (NTP), e.g. decoder models, generate a prompt dataset by
+        providing a GREW query that isolates a particular construction and a
+        target node identifying the element of that construction to test.
+
+        :param query: the GREW query that specifies a construction. Test queries at https://universal.grew.fr/
+        :param target_node: the variable defined in your GREW query that represents the target word
+        :return: a DataFrame consisting of the sentence ID in the given treebank, the ID of the target token, the matched token itself, the original text, and the created prompt.
+        """
+
+        if not self.is_treebank_parsed():
+            raise ValueError(
+                "Cannot create prompt dataset: no treebank or invalid treebank filepath provided."
+            )
+
+        prompt_dataset = self.parser.build_prompt_dataset(
+            self.treebank_paths, query, target_node
+        )
+        self.grew_generated_dataset = prompt_dataset
+        return prompt_dataset
+
+    def generate_minimal_pair_dataset(
+        self,
+        morph_features: dict,
+        upos_features: dict | None,
+        ood_pairs: int | None = None,
+        has_leading_whitespace: bool = True,
+    ) -> pd.DataFrame:
+        """
+        After generating a masked or prompt dataset, extend it with minimal pairs
+        by specifying the feature(s) to change. You can also request additional
+        'OOD' pairs per example, and control whether each minimal-pair item starts
+        with a leading whitespace.
+
+        NOTE: morph_features and upos_features expect lowercase keys; values remain as in the treebank.
+
+        :param morph_features: the morphological features from the UD treebank to adjust for the second element of the minimal pair, e.g. { 'case': 'Dat' } may convert the original target item, e.g. German 'Hunde' (dog.PLUR.NOM / dog.PLUR.ACC), to the dative 'Hunden' (dog.PLUR.DAT), yielding the minimal pair (Hunde, Hunden). The exact keys and values depend on the treebank you're working with.
+        :param upos_features: the universal part-of-speech tags from the UD treebank to adjust for the second element of the minimal pair, e.g. { 'upos': 'VERB' } will only search for verbs
+        :param ood_pairs: the number of alternative (likely semantically implausible) minimal pairs to provide for each example, or None for none. These may help in evaluating generalisation performance.
+        :param has_leading_whitespace: whether an additional whitespace is included at the beginning of each element in the minimal pair, e.g. (' is', ' are')
+        :return: a DataFrame containing the masked sentences or prompts as well as the minimal pairs
+        """
+
+        if self.grew_generated_dataset is None:
+            raise ValueError(
+                "Cannot generate minimal pairs: treebank must be parsed and masked first."
+            )
+
+        def convert_row_to_feature(row):
+            return self.parser.to_syntactic_feature(
+                row["sentence_id"],
+                row["match_id"] - 1,
+                row["match_token"],
+                morph_features,
+                upos_features,
+            )
+
+        alternative_row = self.grew_generated_dataset.apply(
+            convert_row_to_feature, axis=1
+        )
+        # work on a copy so the GREW-generated dataset is left untouched
+        self.mp_dataset = self.grew_generated_dataset.copy()
+        self.mp_dataset["form_ungrammatical"] = alternative_row
+
+        self.mp_dataset = self.mp_dataset.rename(
+            columns={"match_token": "form_grammatical"}
+        )
+
+        # rule 1: drop any rows where we don't find a minimal pair (henceforth MP)
+        self.mp_dataset = self.mp_dataset.dropna(subset=["form_ungrammatical"])
+
+        # rule 2: don't include MPs where the minimal pairs are the same string
+        self.mp_dataset = self.mp_dataset[
+            self.mp_dataset["form_grammatical"] != self.mp_dataset["form_ungrammatical"]
+        ]
+
+        # add leading whitespace if requested.
+        # this is useful for models that expect a leading space, such as many decoder models
+        if has_leading_whitespace:
+            self.mp_dataset["form_grammatical"] = (
+                " " + self.mp_dataset["form_grammatical"]
+            )
+            self.mp_dataset["form_ungrammatical"] = (
+                " " + self.mp_dataset["form_ungrammatical"]
+            )
+
+        # handle the assigning of the out-of-distribution pairs
+        if ood_pairs:
+            # assign additional pairs for OOD data
+            all_grammatical = self.mp_dataset["form_grammatical"].to_list()
+            all_ungrammatical = self.mp_dataset["form_ungrammatical"].to_list()
+
+            # combine both into one vocabulary of (grammatical, ungrammatical) pairs
+            words = set(zip(all_grammatical, all_ungrammatical))
+
+            def pick_words(row):
+                excluded = (row["form_grammatical"], row["form_ungrammatical"])
+                available = list(words - {excluded})
+                return random.sample(available, ood_pairs)
+
+            # Apply function to each row
+            self.mp_dataset["ood_minimal_pairs"] = self.mp_dataset.apply(
+                pick_words, axis=1
+            )
+
+        return self.mp_dataset
+
+    def get_morphological_features(self) -> pd.DataFrame:
+        """
+        Get the lexical items with the ``feats__`` prefix stripped from the
+        morphological feature columns. The treebank's respective webpage lists
+        the same feature inventory. A treebank must first be parsed in order
+        to use this function.
+
+        :return: a DataFrame of lexical items whose morphological feature columns are unprefixed.
+        """
+
+        if not self.is_treebank_parsed():
+            raise ValueError("Cannot get features: You must parse a treebank first.")
+
+        morph_df = self.lexical_items.copy()
+        morph_df.columns = [
+            col.replace("feats__", "") if col.startswith("feats__") else col
+            for col in morph_df.columns
+        ]
+
+        return morph_df
+
+    def is_treebank_parsed(self) -> bool:
+        return self.lexical_items is not None
+
+    def is_dataset_masked(self) -> bool:
+        return self.grew_generated_dataset is not None
+
+    def is_model_evaluated(self) -> bool:
+        return self.evaluation_results is not None
+
+    def get_lexical_items(self) -> pd.DataFrame:
+        return self.lexical_items
+
+    def get_masked_dataset(self) -> pd.DataFrame:
+        return self.grew_generated_dataset
+
+    def get_minimal_pair_dataset(self) -> pd.DataFrame:
+        return self.mp_dataset
+
+    def get_exceptions_dataset(self):
+        return self.exception_dataset
+
+    def get_num_exceptions(self):
+        if self.exception_dataset is not None:
+            return self.exception_dataset.shape[0]
+        else:
+            return -1
+
+    def are_minimal_pairs_generated(self) -> bool:
+        return (
+            self.is_treebank_parsed()
+            and self.is_dataset_masked()
+            and ("form_ungrammatical" in self.mp_dataset.columns)
+        )
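An end-to-end sketch of the pipeline above, not part of the upload: the treebank path, GREW query, and feature values are illustrative assumptions for a German UD treebank, and the relation label in the query depends on the treebank schema in use.

from grewtse import GrewTSEPipe

pipe = GrewTSEPipe()
pipe.parse_treebank("de_gsd-ud-train.conllu")  # assumed treebank file

# GREW query isolating a verb V with a nominal subject; V is the target node
query = "pattern { V [upos=VERB]; N [upos=NOUN]; V -[nsubj]-> N }"
masked = pipe.generate_masked_dataset(query, "V")

# flip the number feature to derive the ungrammatical member of each pair
mp = pipe.generate_minimal_pair_dataset({"number": "Plur"}, {}, ood_pairs=2)
print(mp[["form_grammatical", "form_ungrammatical"]].head())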
grewtse/preprocessing/__init__.py
ADDED
@@ -0,0 +1,3 @@
+from .conllu_parser import ConlluParser
+
+__all__ = ["ConlluParser"]
grewtse/preprocessing/conllu_parser.py
ADDED
|
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from grewtse.preprocessing.grew_dependencies import match_dependencies
|
| 2 |
+
from grewtse.preprocessing.reconstruction import (
|
| 3 |
+
perform_token_surgery,
|
| 4 |
+
recursive_match_token,
|
| 5 |
+
)
|
| 6 |
+
from conllu import parse_incr, Token
|
| 7 |
+
from typing import Any
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import numpy as np
|
| 10 |
+
import logging
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class ConlluParser:
|
| 14 |
+
"""
|
| 15 |
+
A class designed to parse .conllu files for Grew-TSE, that is, the standard format for UD treebanks.
|
| 16 |
+
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
def __init__(self) -> None:
|
| 20 |
+
self.lexicon: pd.DataFrame = None
|
| 21 |
+
|
| 22 |
+
def build_lexicon(self, filepaths: list[str] | str) -> pd.DataFrame:
|
| 23 |
+
"""
|
| 24 |
+
Create a DataFrame that contains the set of all words with their features as generated from a UD treebank.
|
| 25 |
+
This is essential for the subsequent generation of minimal pairs.
|
| 26 |
+
This was not designed to handle treebanks that assign differing names to features, so please ensure multiple treebank files are all from the same treebank or treebank schema.
|
| 27 |
+
|
| 28 |
+
:param filepaths: a list of strings corresponding to the UD treebank files e.g. ["german_treebank_part_A.conllu", "german_treebank_part_B.conllu"].
|
| 29 |
+
:return: a DataFrame with all words and their features.
|
| 30 |
+
|
| 31 |
+
"""
|
| 32 |
+
rows = []
|
| 33 |
+
|
| 34 |
+
if isinstance(filepaths, str):
|
| 35 |
+
filepaths = [filepaths] # wrap single path in list
|
| 36 |
+
|
| 37 |
+
for conllu_path in filepaths:
|
| 38 |
+
with open(conllu_path, "r", encoding="utf-8") as f:
|
| 39 |
+
for tokenlist in parse_incr(f):
|
| 40 |
+
# get the sentence ID in the dataset
|
| 41 |
+
sent_id = tokenlist.metadata["sent_id"]
|
| 42 |
+
logging.info(f"Parsing Sentence: {sent_id}")
|
| 43 |
+
|
| 44 |
+
# iterate over each token
|
| 45 |
+
for token in tokenlist:
|
| 46 |
+
# check if it's worth saving to our lexical item dataset
|
| 47 |
+
is_valid_token = is_valid_for_lexicon(token)
|
| 48 |
+
if not is_valid_token:
|
| 49 |
+
continue
|
| 50 |
+
|
| 51 |
+
# from the token object create a dict and append
|
| 52 |
+
row = build_token_row(token, sent_id)
|
| 53 |
+
rows.append(row)
|
| 54 |
+
|
| 55 |
+
lexicon_df = pd.DataFrame(rows)
|
| 56 |
+
|
| 57 |
+
# make sure our nan values are interpreted as such
|
| 58 |
+
lexicon_df.replace("nan", np.nan, inplace=True)
|
| 59 |
+
|
| 60 |
+
# create the (Sentence ID, Token ID) primary key
|
| 61 |
+
lexicon_df.set_index(["sentence_id", "token_id"], inplace=True)
|
| 62 |
+
|
| 63 |
+
self.lexicon = lexicon_df
|
| 64 |
+
|
| 65 |
+
return lexicon_df
|
| 66 |
+
|
| 67 |
+
def to_syntactic_feature(
|
| 68 |
+
self,
|
| 69 |
+
sentence_id: str,
|
| 70 |
+
token_id: str,
|
| 71 |
+
token: str,
|
| 72 |
+
alt_morph_constraints: dict,
|
| 73 |
+
alt_universal_constraints: dict,
|
| 74 |
+
) -> str | None:
|
| 75 |
+
"""
|
| 76 |
+
The most important function for the finding of minimal pairs. Converts a given lexical item taken from a UD treebank sentence
|
| 77 |
+
to another lexical item of the same lemma but with the specified differing feature(s).
|
| 78 |
+
|
| 79 |
+
:param sentence_id: the ID in the treebank of the sentence.
|
| 80 |
+
:param token_id: the token index in the list of tokens corresponding to the isolated target word.
|
| 81 |
+
:param token: the token string itself that is the isolated target word.
|
| 82 |
+
:param alt_morph_constraints: the alternative morphological feature(s) for the target word.
|
| 83 |
+
:param alt_universal_constraints: the alternative UPOS feature(s) for the target word.
|
| 84 |
+
:return: a string representing the converted target word.
|
| 85 |
+
"""
|
| 86 |
+
|
| 87 |
+
# distinguish morphological from universal features
|
| 88 |
+
# todo: find a better way to do this
|
| 89 |
+
prefix = "feats__"
|
| 90 |
+
# prefix = ''
|
| 91 |
+
alt_morph_constraints = {
|
| 92 |
+
prefix + key: value for key, value in alt_morph_constraints.items()
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
token_features = self.get_features(sentence_id, token_id)
|
| 96 |
+
|
| 97 |
+
token_features.update(alt_morph_constraints)
|
| 98 |
+
token_features.update(alt_universal_constraints)
|
| 99 |
+
lexical_items = self.lexicon
|
| 100 |
+
|
| 101 |
+
# get only those items which are the same lemma
|
| 102 |
+
lemma = self.get_lemma(sentence_id, token_id)
|
| 103 |
+
lemma_mask = lexical_items["lemma"] == lemma
|
| 104 |
+
lexical_items = lexical_items[lemma_mask]
|
| 105 |
+
logging.info(f"Looking for form {lemma}")
|
| 106 |
+
|
| 107 |
+
lexical_items = construct_candidate_set(lexical_items, token_features)
|
| 108 |
+
# ensure that it doesn't allow minimal pairs with different start cases e.g business, Business
|
| 109 |
+
filtered = lexical_items[
|
| 110 |
+
lexical_items["form"].apply(lambda w: is_same_start_case(w, token))
|
| 111 |
+
]
|
| 112 |
+
if not filtered.empty:
|
| 113 |
+
return filtered["form"].iloc[0]
|
| 114 |
+
else:
|
| 115 |
+
return None
|
| 116 |
+
|
| 117 |
+
def get_lexicon(self) -> pd.DataFrame:
|
| 118 |
+
return self.lexicon
|
| 119 |
+
|
| 120 |
+
# this shouldn't be hard coded
|
| 121 |
+
def get_feature_names(self) -> list:
|
| 122 |
+
return self.lexicon.columns[4:].to_list()
|
| 123 |
+
|
| 124 |
+
# todo: add more safety
|
| 125 |
+
def get_features(self, sentence_id: str, token_id: int) -> dict:
|
| 126 |
+
return self.lexicon.loc[(sentence_id, token_id)][
|
| 127 |
+
self.get_feature_names()
|
| 128 |
+
].to_dict()
|
| 129 |
+
|
| 130 |
+
def get_lemma(self, sentence_id: str, token_id: str) -> str:
|
| 131 |
+
return self.lexicon.loc[(sentence_id, token_id)]["lemma"]
|
| 132 |
+
|
| 133 |
+
def get_candidate_set(
|
| 134 |
+
self, universal_constraints: dict, morph_constraints: dict
|
| 135 |
+
) -> pd.DataFrame:
|
| 136 |
+
has_parsed_conllu = self.lexicon is not None
|
| 137 |
+
if not has_parsed_conllu:
|
| 138 |
+
raise ValueError("Please parse a ConLLU file first.")
|
| 139 |
+
|
| 140 |
+
morph_constraints = {f"feats__{k}": v for k, v in morph_constraints.items()}
|
| 141 |
+
are_morph_features_valid = all(
|
| 142 |
+
f in self.lexicon.columns for f in morph_constraints.keys()
|
| 143 |
+
)
|
| 144 |
+
are_universal_features_valid = all(
|
| 145 |
+
f in self.lexicon.columns for f in universal_constraints.keys()
|
| 146 |
+
)
|
| 147 |
+
if not are_morph_features_valid or not are_universal_features_valid:
|
| 148 |
+
raise KeyError(
|
| 149 |
+
"Features provided for candidate set are not valid features in the dataset."
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
all_constraints = {**universal_constraints, **morph_constraints}
|
| 153 |
+
candidate_set = construct_candidate_set(self.lexicon, all_constraints)
|
| 154 |
+
return candidate_set
|
| 155 |
+
|
| 156 |
+
def build_prompt_dataset(
|
| 157 |
+
self,
|
| 158 |
+
filepaths: list[str],
|
| 159 |
+
grew_query: str,
|
| 160 |
+
dependency_node: str,
|
| 161 |
+
encoding: str = "utf-8",
|
| 162 |
+
):
|
| 163 |
+
prompt_cutoff_token = "[PROMPT_CUTOFF]"
|
| 164 |
+
results = self.build_masked_dataset(
|
| 165 |
+
filepaths, grew_query, dependency_node, prompt_cutoff_token, encoding
|
| 166 |
+
)
|
| 167 |
+
prompt_dataset = results["masked"]
|
| 168 |
+
|
| 169 |
+
def substring_up_to_token(s: str, token: str) -> str:
|
| 170 |
+
idx = s.find(token)
|
| 171 |
+
return s[:idx].strip() if idx != -1 else s.strip()
|
| 172 |
+
|
| 173 |
+
prompt_dataset["prompt_text"] = prompt_dataset["masked_text"].apply(
|
| 174 |
+
lambda x: substring_up_to_token(x, prompt_cutoff_token)
|
| 175 |
+
)
|
| 176 |
+
prompt_dataset = prompt_dataset.drop(["masked_text"], axis=1)
|
| 177 |
+
return prompt_dataset
|
| 178 |
+
|
| 179 |
+
def build_masked_dataset(
|
| 180 |
+
self,
|
| 181 |
+
filepaths: list[str],
|
| 182 |
+
grew_query: str,
|
| 183 |
+
dependency_node: str,
|
| 184 |
+
mask_token: str,
|
| 185 |
+
encoding: str = "utf-8",
|
| 186 |
+
):
|
| 187 |
+
masked_dataset = []
|
| 188 |
+
exception_dataset = []
|
| 189 |
+
|
| 190 |
+
try:
|
| 191 |
+
for filepath in filepaths:
|
| 192 |
+
get_tokens_to_mask = match_dependencies(
|
| 193 |
+
filepath, grew_query, dependency_node
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
with open(filepath, "r", encoding=encoding) as data_file:
|
| 197 |
+
for sentence in parse_incr(data_file):
|
| 198 |
+
|
| 199 |
+
sentence_id = sentence.metadata["sent_id"]
|
| 200 |
+
sentence_text = sentence.metadata["text"]
|
| 201 |
+
|
| 202 |
+
if sentence_id in get_tokens_to_mask:
|
| 203 |
+
for i in range(len(sentence)):
|
| 204 |
+
sentence[i]["index"] = i
|
| 205 |
+
|
| 206 |
+
token_to_mask_id = get_tokens_to_mask[sentence_id]
|
| 207 |
+
|
| 208 |
+
try:
|
| 209 |
+
t_match = [
|
| 210 |
+
tok
|
| 211 |
+
for tok in sentence
|
| 212 |
+
if tok.get("id") == token_to_mask_id
|
| 213 |
+
][0]
|
| 214 |
+
t_match_form = t_match["form"]
|
| 215 |
+
t_match_index = t_match["index"]
|
| 216 |
+
sentence_as_str_list = [t["form"] for t in sentence]
|
| 217 |
+
except KeyError:
|
| 218 |
+
logging.info(
|
| 219 |
+
"There was a mismatch for the GREW-based ID and the Conllu ID."
|
| 220 |
+
)
|
| 221 |
+
exception_dataset.append(
|
| 222 |
+
{
|
| 223 |
+
"sentence_id": sentence_id,
|
| 224 |
+
"match_id": None,
|
| 225 |
+
"all_tokens": None,
|
| 226 |
+
"match_token": None,
|
| 227 |
+
"original_text": sentence_text,
|
| 228 |
+
}
|
| 229 |
+
)
|
| 230 |
+
continue
|
| 231 |
+
|
| 232 |
+
try:
|
| 233 |
+
matched_token_start_index = recursive_match_token(
|
| 234 |
+
sentence_text, # the original string
|
| 235 |
+
sentence_as_str_list.copy(), # the string as a list of tokens
|
| 236 |
+
t_match_index, # the index of the token to be replaced
|
| 237 |
+
[
|
| 238 |
+
"_",
|
| 239 |
+
" ",
|
| 240 |
+
], # todo: skip lines where we don't encounter accounted for tokens
|
| 241 |
+
)
|
| 242 |
+
except ValueError:
|
| 243 |
+
exception_dataset.append(
|
| 244 |
+
{
|
| 245 |
+
"sentence_id": sentence_id,
|
| 246 |
+
"match_id": token_to_mask_id,
|
| 247 |
+
"all_tokens": sentence_as_str_list,
|
| 248 |
+
"match_token": t_match_form,
|
| 249 |
+
"original_text": sentence_text,
|
| 250 |
+
}
|
| 251 |
+
)
|
| 252 |
+
continue
|
| 253 |
+
|
| 254 |
+
# let's replace the matched token with a MASK token
|
| 255 |
+
masked_sentence = perform_token_surgery(
|
| 256 |
+
sentence_text,
|
| 257 |
+
t_match_form,
|
| 258 |
+
mask_token,
|
| 259 |
+
matched_token_start_index,
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
# the sentence ID and match ID are together a primary key
|
| 263 |
+
masked_dataset.append(
|
| 264 |
+
{
|
| 265 |
+
"sentence_id": sentence_id,
|
| 266 |
+
"match_id": token_to_mask_id,
|
| 267 |
+
"match_token": t_match_form,
|
| 268 |
+
"original_text": sentence_text,
|
| 269 |
+
"masked_text": masked_sentence,
|
| 270 |
+
}
|
| 271 |
+
)
|
| 272 |
+
except Exception as e:
|
| 273 |
+
print(f"Issue building dataset: {e}")
|
| 274 |
+
|
| 275 |
+
masked_dataset_df = pd.DataFrame(masked_dataset)
|
| 276 |
+
exception_dataset_df = pd.DataFrame(exception_dataset)
|
| 277 |
+
|
| 278 |
+
return {"masked": masked_dataset_df, "exception": exception_dataset_df}
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def construct_candidate_set(
|
| 282 |
+
lexicon: pd.DataFrame, target_features: dict
|
| 283 |
+
) -> pd.DataFrame:
|
| 284 |
+
"""
|
| 285 |
+
This constructs a list of words which have the same feature set as the
|
| 286 |
+
target features which are passed as an argument. These resulting words are termed 'candidates'.
|
| 287 |
+
|
| 288 |
+
:param lexicon: the DataFrame consisting of all lexical items and their features
|
| 289 |
+
:param target_features: the differing features of the candidates.
|
| 290 |
+
:return: a DataFrame containing the candidate subset of the lexicon.
|
| 291 |
+
|
| 292 |
+
"""
|
| 293 |
+
|
| 294 |
+
# optionally restrict search to a certain type of lexical item
|
| 295 |
+
subset = lexicon
|
| 296 |
+
|
| 297 |
+
# continuously filter the dataframe so as to be left
|
| 298 |
+
# only with those lexical items which match the target
|
| 299 |
+
# features
|
| 300 |
+
# this includes cases
|
| 301 |
+
for feat, value in target_features.items():
|
| 302 |
+
# ensure feature is a valid feature in feature set
|
| 303 |
+
if feat not in subset.columns:
|
| 304 |
+
print(subset.columns)
|
| 305 |
+
raise KeyError("Invalid feature provided to confound set: {}".format(feat))
|
| 306 |
+
|
| 307 |
+
# slim the mask down using each feature
|
| 308 |
+
# interesting edge case: np.nan == np.nan returns false!
|
| 309 |
+
mask = (subset[feat] == value) | (subset[feat].isna() & pd.isna(value))
|
| 310 |
+
subset = subset[mask]
|
| 311 |
+
|
| 312 |
+
return subset
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def is_same_start_case(s1, s2):
|
| 316 |
+
if not s1 or not s2:
|
| 317 |
+
return False
|
| 318 |
+
return s1[0].isupper() == s2[0].isupper()
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def is_valid_for_lexicon(token: Token) -> bool:
|
| 322 |
+
punctuation = [".", ",", "!", "?", "*"]
|
| 323 |
+
|
| 324 |
+
# skip multiword tokens, malformed entries and punctuation
|
| 325 |
+
is_punctuation = token.get("form") in punctuation
|
| 326 |
+
is_valid_type = isinstance(token, dict)
|
| 327 |
+
has_valid_id = isinstance(token.get("id"), int)
|
| 328 |
+
return is_valid_type and has_valid_id and not is_punctuation
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def build_token_row(token: Token, sentence_id: str) -> dict[str, Any]:
|
| 332 |
+
# get all token features such as Person, Mood, etc
|
| 333 |
+
feats = token.get("feats") or {}
|
| 334 |
+
|
| 335 |
+
row = {
|
| 336 |
+
"sentence_id": sentence_id,
|
| 337 |
+
"token_id": token.get("id") - 1, # IDs are reduced by one to start at 0
|
| 338 |
+
"form": token.get("form"),
|
| 339 |
+
"lemma": token.get("lemma"),
|
| 340 |
+
"upos": token.get("upos"),
|
| 341 |
+
"xpos": token.get("xpos"),
|
| 342 |
+
}
|
| 343 |
+
|
| 344 |
+
# add each morphological feature as a column
|
| 345 |
+
for feat_name, feat_value in feats.items():
|
| 346 |
+
row["feats__" + feat_name.lower()] = feat_value
|
| 347 |
+
|
| 348 |
+
return row
|
grewtse/preprocessing/grew_dependencies.py
ADDED
@@ -0,0 +1,83 @@
from grewpy import Corpus, Request, set_config


def match_dependencies(
    filepaths: list[str] | str, grew_query: str, dependency_node: str
) -> dict:
    set_config("sud")  # "ud" or "basic" are also possible
    dep_matches = {}

    if isinstance(filepaths, str):
        filepaths = [filepaths]  # wrap single path in list

    try:
        for corpus_path in filepaths:
            # run the GREW request on the corpus
            corpus = Corpus(str(corpus_path))
            request = Request(grew_query)
            occurrences = corpus.search(request)

            # map each matching sentence to the node matched by the pattern
            for occ in occurrences:
                sent_id = occ["sent_id"]

                # node IDs may arrive as integer or float string representations
                node_id_str = occ["matching"]["nodes"][dependency_node]
                try:
                    object_node_id = float(node_id_str)
                except ValueError:
                    raise ValueError(
                        f"Cannot convert node ID '{node_id_str}' to a number."
                    )

                # one match per sentence
                # todo: handle multiple matches per sentence
                dep_matches[sent_id] = object_node_id
    except KeyError as e:
        raise KeyError(
            f"You must provide a dependency node name which exists in your GREW pattern. Missing key: {e}"
        )
    except Exception as e:
        raise ValueError(f"Invalid GREW query: {e}")

    return dep_matches
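
A minimal usage sketch of match_dependencies, assuming a hypothetical treebank path; the node name passed as dependency_node ("V" here) must occur in the GREW pattern:

matches = match_dependencies(
    "corpora/example.conllu",
    'pattern { V [upos="VERB"] }',
    "V",
)
# matches maps each sent_id to the ID of the node bound to V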
grewtse/preprocessing/reconstruction.py
ADDED
@@ -0,0 +1,71 @@
def perform_token_surgery(
    sentence: str,
    original_token: str,
    replacement_token: str,
    start_index: int,
) -> str:
    # splice the replacement token into the sentence at the
    # character offset where the original token begins
    t_len = len(original_token)

    return sentence[:start_index] + replacement_token + sentence[start_index + t_len :]

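
# A minimal usage sketch: splice a mask token over "dog", which starts at
# character offset 4 (the offset normally comes from recursive_match_token
# below):
#
#   masked = perform_token_surgery("The dog barks.", "dog", "[MASK]", 4)
#   # masked == "The [MASK] barks."
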

def recursive_match_token(
    full_sentence: str,
    token_list: list[str],
    token_list_mask_index: int,
    skippable_tokens: list[str],
) -> int:
    # returns the character offset, within full_sentence, of the token
    # sitting at token_list_mask_index; token_list is consumed in place

    # ensure we can retrieve another token
    n_remaining_tokens = len(token_list)
    if n_remaining_tokens == 0:
        raise ValueError(
            "Mask index not reached but token list has been iterated for sentence: {}".format(
                full_sentence
            )
        )
    t = token_list[0]

    # find the index of the first occurrence of the token t
    match_index = full_sentence.find(t)
    is_match_found = match_index != -1
    has_reached_mask_token = token_list_mask_index == 0

    # BASE CASE
    if has_reached_mask_token and is_match_found:
        # we're at the end
        return match_index
    # RECURSIVE CASE
    elif is_match_found:
        sliced_sentence = full_sentence[match_index + len(t) :]
        token_list.pop(0)

        return (
            match_index
            + len(t)
            + recursive_match_token(
                sliced_sentence,
                token_list,
                token_list_mask_index - 1,
                skippable_tokens,
            )
        )
    else:
        # no match found; is t irrelevant?
        if t in skippable_tokens:
            # need to watch out with the slicing here; tests are important
            sliced_sentence = full_sentence[len(t) - 1 :]
            token_list.pop(0)
            return recursive_match_token(
                sliced_sentence,
                token_list,
                token_list_mask_index - 1,
                skippable_tokens,
            )
        else:
            raise ValueError(
                "Token not found in string nor has it been specified as skippable: {}".format(
                    t
                )
            )
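
A minimal usage sketch of recursive_match_token: find the character offset of the token at mask index 2. The function consumes token_list in place, so pass a copy if the list is reused:

tokens = ["The", "dog", "barks", "."]
offset = recursive_match_token("The dog barks.", tokens, 2, [])
# offset == 8, the start index of "barks"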
grewtse/utils/__init__.py
ADDED
File without changes
grewtse/utils/validation.py
ADDED
@@ -0,0 +1,24 @@
import pandas as pd


def load_and_validate_mp_dataset(filepath: str) -> pd.DataFrame:
    required_columns = {
        "sentence_id",
        "match_id",
        "original_text",
    }

    try:
        df = pd.read_csv(filepath)

        missing = required_columns - set(df.columns)
        if missing:
            raise ValueError(f"Missing required column(s): {', '.join(missing)}")

        return df

    except FileNotFoundError:
        raise FileNotFoundError(f"File not found: {filepath}")

    except pd.errors.ParserError as e:
        raise pd.errors.ParserError(f"Error parsing CSV file: {e}")
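
A minimal usage sketch (hypothetical path); a ValueError is raised if any required column is absent:

df = load_and_validate_mp_dataset("output/masked_dataset.csv")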
grewtse/visualise/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .visualiser import GrewTSEVisualiser

__all__ = ["GrewTSEVisualiser"]
grewtse/visualise/visualiser.py
ADDED
@@ -0,0 +1,132 @@
import pandas as pd
from plotnine import (
    labs,
    theme,
    theme_bw,
    guides,
    position_nudge,
    aes,
    geom_violin,
    geom_line,
    geom_jitter,
    scale_x_discrete,
    ggplot,
)
from pathlib import Path
import math


class GrewTSEVisualiser:
    """
    A basic visualisation class that creates a violin plot based on a syntactic evaluation.
    """

    def __init__(self) -> None:
        self.data = None

    def visualise_syntactic_performance(
        self,
        results: pd.DataFrame,
        title: str,
        target_x_label: str,
        alt_x_label: str,
        x_axis_label: str,
        y_axis_label: str,
        filename: str,
    ) -> None:
        """
        Visualise a syntactic performance evaluation result.

        :param results: the results DataFrame created by the GrewTSEvaluator.
        :param title: give the diagram a main title.
        :param target_x_label: a label for the original target word, the first element of the minimal pair, e.g. 'Accusative'.
        :param alt_x_label: a label for the second element of the minimal pair, e.g. 'Dative'.
        :param x_axis_label: give the X axis a title.
        :param y_axis_label: give the Y axis a title.
        :param filename: a filename under which to save the visualisation.
        :return: None; the plot is saved to filename.
        """

        visualise_slope(
            filename,
            results,
            target_x_label,
            alt_x_label,
            x_axis_label,
            y_axis_label,
            title,
        )


def visualise_slope(
    path: Path,
    results: pd.DataFrame,
    target_x_label: str,
    alt_x_label: str,
    x_axis_label: str,
    y_axis_label: str,
    title: str,
):
    lsize = 0.65
    fill_alpha = 0.7

    # X-axis: the two minimal-pair conditions (e.g. Acc, Gen); Y-axis: surprisal
    # keep only rows that actually have an ungrammatical counterpart;
    # .copy() avoids mutating a view of the caller's DataFrame
    filtered_df = results[
        results["form_ungrammatical"].notna()
        & (results["form_ungrammatical"].str.strip() != "")
    ].copy()

    filtered_df["subject_id"] = filtered_df.index

    # melt the dataframe into long format: one row per (item, condition)
    df_long = pd.melt(
        filtered_df,
        id_vars=["subject_id"],
        value_vars=["p_grammatical", "p_ungrammatical"],
        var_name="source",
        value_name="log_prob",
    )

    # map source to fixed x-axis labels
    df_long["x_label"] = df_long["source"].map(
        {"p_grammatical": target_x_label, "p_ungrammatical": alt_x_label}
    )

    def surprisal(p: float) -> float:
        return -math.log2(p)

    def confidence(p: float) -> float:  # currently unused
        return math.log2(p)

    df_long["surprisal"] = df_long["log_prob"].apply(surprisal)

    p = (
        ggplot(df_long, aes(x="x_label", y="surprisal", fill="x_label"))
        + scale_x_discrete(limits=[target_x_label, alt_x_label])
        + geom_jitter(aes(color="x_label"), width=0.01, alpha=0.7)
        + geom_line(aes(group="subject_id"), color="gray", alpha=0.7, size=0.2)
        + geom_violin(
            df_long[df_long["x_label"] == target_x_label],
            aes(x="x_label", y="surprisal", group="x_label"),
            position=position_nudge(x=-0.2),
            style="left-right",
            alpha=fill_alpha,
            size=lsize,
        )
        + geom_violin(
            df_long[df_long["x_label"] == alt_x_label],
            aes(x="x_label", y="surprisal", group="x_label"),
            position=position_nudge(x=0.2),
            style="right-left",
            alpha=fill_alpha,
            size=lsize,
        )
        + guides(fill=False)
        + theme_bw()
        + theme(figure_size=(8, 4), legend_position="none")
        + labs(x=x_axis_label, y=y_axis_label, title=title)
    )
    p.save(path, width=14, height=8, dpi=300)
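
A minimal usage sketch, assuming a results DataFrame with the columns visualise_slope reads (p_grammatical, p_ungrammatical, form_ungrammatical); the values and filename are hypothetical:

import pandas as pd

results = pd.DataFrame(
    {
        "p_grammatical": [0.6, 0.4],
        "p_ungrammatical": [0.1, 0.2],
        "form_ungrammatical": ["Hunde", "Hauses"],
    }
)

GrewTSEVisualiser().visualise_syntactic_performance(
    results,
    title="Case agreement",
    target_x_label="Accusative",
    alt_x_label="Dative",
    x_axis_label="Condition",
    y_axis_label="Surprisal (bits)",
    filename="case_agreement.png",
)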