DanielGallagherIRE committed
Commit e2a0b30 · verified · 1 Parent(s): 34dc7a3

Upload 13 files

grewtse/__init__.py ADDED
@@ -0,0 +1,3 @@
from .pipeline import GrewTSEPipe

__all__ = ["GrewTSEPipe"]
grewtse/evaluators/__init__.py ADDED
@@ -0,0 +1,4 @@
from .evaluator import GrewTSEvaluator
from .evaluator import Evaluator

__all__ = ["GrewTSEvaluator", "Evaluator"]
grewtse/evaluators/evaluator.py ADDED
@@ -0,0 +1,403 @@
import ast

from transformers import AutoModelForMaskedLM, AutoModelForCausalLM, AutoTokenizer
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from typing import Tuple, NamedTuple, List, Any
import torch.nn.functional as F
from tqdm import tqdm
import pandas as pd
import itertools
import logging
import torch
import math

from grewtse.utils.validation import load_and_validate_mp_dataset
from grewtse.evaluators.metrics import (
    compute_normalised_surprisal_difference,
    compute_average_surprisal_difference,
    compute_entropy,
    compute_surprisal,
    compute_mean,
    calculate_all_metrics,
)

EVAL_TEMPLATE = {
    "sentence_id": None,
    "match_id": None,
    "original_text": None,
    "prompt_text": None,
    "form_grammatical": None,
    "p_grammatical": None,
    "I_grammatical": None,
    "form_ungrammatical": None,
    "p_ungrammatical": None,
    "I_ungrammatical": None,
    "entropy": None,
    "entropy_norm": None,
}


class TooManyMasksException(Exception):
    def __init__(self, message: str):
        self.message = message
        super().__init__(f"TMM Exception: {message}")


class Prediction(NamedTuple):
    token: str
    prob: float
    surprisal: float


class GrewTSEvaluator:
    """
    An evaluation class designed for rapid syntactic evaluation of models hosted on the Hugging Face Hub.
    """

    def __init__(self):
        self.evaluator = Evaluator()
        self.evaluation_dataset = None

    def evaluate_model(
        self,
        mp_dataset: pd.DataFrame,
        model_repo: str,
        model_type: str,  # either 'encoder' or 'decoder'
        entropy_topk: int = 100,
        row_limit: int | None = None,
    ) -> pd.DataFrame:
        """
        Generic evaluation function for encoder or decoder models.
        """

        # --- Prepare dataset ---
        mp_dataset_iter = mp_dataset.itertuples()
        if row_limit:
            mp_dataset_iter = itertools.islice(mp_dataset_iter, row_limit)
        n = len(mp_dataset) if not row_limit else row_limit

        # --- Load model & tokenizer ---
        is_encoder = model_type == "encoder"
        model, tokenizer = self.evaluator.setup_parameters(model_repo, is_encoder)
        results = []

        # --- Evaluate each row ---
        for row in tqdm(mp_dataset_iter, total=n):
            row_results = self._init_row_results(row)

            try:
                if is_encoder:
                    self._evaluate_encoder_row(row, row_results)
                else:
                    self._evaluate_decoder_row(row, row_results)

            except TooManyMasksException:
                logging.error(f"Too many masks in {row.sentence_id}")
                continue
            except Exception as e:
                raise RuntimeError(f"Model/tokeniser issue: {e}") from e

            # --- Entropy ---
            entropy, entropy_norm = self.evaluator.get_entropy(entropy_topk, True)
            row_results["entropy"] = entropy
            row_results["entropy_norm"] = entropy_norm

            results.append(row_results)

        results_df = pd.DataFrame(results, columns=EVAL_TEMPLATE.keys())
        self.evaluation_dataset = results_df
        return results_df

    def evaluate_from_minimal_pairs(
        self,
        mp_dataset_filepath: str,
        model_repo: str,
        model_type: str,
        entropy_topk: int = 100,
        row_limit: int | None = None,
    ) -> pd.DataFrame:
        mp_dataset = load_and_validate_mp_dataset(mp_dataset_filepath)
        self.mp_dataset = mp_dataset
        return self.evaluate_model(
            mp_dataset, model_repo, model_type, entropy_topk, row_limit
        )

    def is_model_evaluated(self) -> bool:
        return self.evaluation_dataset is not None

    # --- Helper functions ---
    def _init_row_results(self, row):
        row_results = EVAL_TEMPLATE.copy()
        row_results.update(row._asdict())
        return row_results

    def _evaluate_encoder_row(self, row, row_results):
        prob_gram, prob_ungram = self.evaluator.run_masked_prediction(
            row.masked_text,
            row.form_grammatical,
            row.form_ungrammatical,
        )

        row_results["p_grammatical"] = prob_gram
        row_results["p_ungrammatical"] = prob_ungram
        row_results["I_grammatical"] = compute_surprisal(prob_gram)
        row_results["I_ungrammatical"] = compute_surprisal(prob_ungram)

        # itertuples() yields namedtuples, so check for the field by name
        if hasattr(row, "ood_minimal_pairs"):
            ood_pairs = row.ood_minimal_pairs
            # the column holds a list in memory but a string when read back from CSV
            if isinstance(ood_pairs, str):
                ood_pairs = ast.literal_eval(ood_pairs)
            all_ood_probs_gram = []
            all_ood_probs_ungram = []
            for pair in ood_pairs:
                ood_prob_gram, ood_prob_ungram = self.evaluator.run_masked_prediction(
                    row.masked_text, pair[0], pair[1]
                )
                all_ood_probs_gram.append(ood_prob_gram)
                all_ood_probs_ungram.append(ood_prob_ungram)

            avg_ood_prob_gram = compute_mean(all_ood_probs_gram)
            avg_ood_prob_ungram = compute_mean(all_ood_probs_ungram)

            row_results["ood_p_grammatical"] = avg_ood_prob_gram
            row_results["ood_p_ungrammatical"] = avg_ood_prob_ungram
            row_results["ood_I_grammatical"] = compute_surprisal(avg_ood_prob_gram)
            row_results["ood_I_ungrammatical"] = compute_surprisal(avg_ood_prob_ungram)

    def _evaluate_decoder_row(self, row, row_results):
        prob_gram, prob_ungram = self.evaluator.run_next_word_prediction(
            row.prompt_text, row.form_grammatical, row.form_ungrammatical
        )
        row_results["p_grammatical"] = prob_gram
        row_results["p_ungrammatical"] = prob_ungram
        row_results["I_grammatical"] = compute_surprisal(prob_gram)
        row_results["I_ungrammatical"] = compute_surprisal(prob_ungram)

        # itertuples() yields namedtuples, so check for the field by name
        if hasattr(row, "ood_minimal_pairs"):
            ood_pairs = row.ood_minimal_pairs
            # the column holds a list in memory but a string when read back from CSV
            if isinstance(ood_pairs, str):
                ood_pairs = ast.literal_eval(ood_pairs)
            all_ood_probs_gram = []
            all_ood_probs_ungram = []
            for pair in ood_pairs:
                ood_prob_gram, ood_prob_ungram = (
                    self.evaluator.run_next_word_prediction(
                        row.prompt_text, pair[0], pair[1]
                    )
                )
                all_ood_probs_gram.append(ood_prob_gram)
                all_ood_probs_ungram.append(ood_prob_ungram)

            avg_ood_prob_gram = compute_mean(all_ood_probs_gram)
            avg_ood_prob_ungram = compute_mean(all_ood_probs_ungram)

            row_results["ood_p_grammatical"] = avg_ood_prob_gram
            row_results["ood_p_ungrammatical"] = avg_ood_prob_ungram
            row_results["ood_I_grammatical"] = compute_surprisal(avg_ood_prob_gram)
            row_results["ood_I_ungrammatical"] = compute_surprisal(avg_ood_prob_ungram)

    def get_norm_avg_surprisal_difference(self) -> float:
        if not self.is_model_evaluated():
            raise KeyError("Please evaluate a model first.")
        return compute_normalised_surprisal_difference(
            self.evaluation_dataset["p_grammatical"],
            self.evaluation_dataset["p_ungrammatical"],
        )

    def get_avg_surprisal_difference(self, is_ood: bool = False) -> float:
        p_grammatical_col = "p_grammatical" if not is_ood else "ood_p_grammatical"
        p_ungrammatical_col = "p_ungrammatical" if not is_ood else "ood_p_ungrammatical"
        if not self.is_model_evaluated():
            raise KeyError("Please evaluate a model first.")
        return compute_average_surprisal_difference(
            self.evaluation_dataset[p_grammatical_col],
            self.evaluation_dataset[p_ungrammatical_col],
        )

    def get_all_metrics(self):
        if self.evaluation_dataset is not None:
            return calculate_all_metrics(self.evaluation_dataset)
        else:
            raise ValueError("Please evaluate a model first.")


class Evaluator:
    def __init__(self):
        self.tokeniser: PreTrainedTokenizerBase | None = None
        self.model: PreTrainedModel | None = None

        self.mask_token_index: int = -1
        self.mask_probs: torch.Tensor | None = None
        self.logits: torch.Tensor | None = None

    def setup_parameters(
        self, model_name: str, is_mlm: bool = True
    ) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
        if is_mlm:
            self.tokeniser = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForMaskedLM.from_pretrained(model_name)
        else:
            self.tokeniser = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForCausalLM.from_pretrained(model_name)

        # set to eval mode, disabling things like dropout
        self.model.eval()

        return self.model, self.tokeniser

    def run_masked_prediction(
        self, sentence: str, grammatical_word: str, ungrammatical_word: str
    ) -> Tuple[float, float]:
        if not self.model or not self.tokeniser:
            raise RuntimeError("Model and tokenizer must be loaded before prediction.")

        mask_token = self.tokeniser.mask_token
        sentence_masked = sentence.replace("[MASK]", mask_token)

        if sentence_masked.count(mask_token) != 1:
            raise TooManyMasksException("Only single-mask sentences are supported.")

        masked_ids = self.tokeniser.encode(sentence_masked, add_special_tokens=False)
        mask_index = masked_ids.index(self.tokeniser.mask_token_id)

        device = next(self.model.parameters()).device
        g_ids = self.tokeniser.encode(grammatical_word, add_special_tokens=False)
        u_ids = self.tokeniser.encode(ungrammatical_word, add_special_tokens=False)

        g_prob = self._compute_masked_joint_probability(
            masked_ids, mask_index, g_ids, device
        )
        u_prob = self._compute_masked_joint_probability(
            masked_ids, mask_index, u_ids, device
        )

        return g_prob, u_prob

    def _compute_masked_joint_probability(
        self, input_ids: List[int], mask_index: int, word_ids: List[int], device
    ) -> float:
        input_ids_tensor = torch.tensor([input_ids], device=device)
        log_prob = 0.0
        index = mask_index

        for i, tid in enumerate(word_ids):
            with torch.no_grad():
                logits = self.model(input_ids_tensor).logits

            probs = F.softmax(logits[:, index, :], dim=-1)
            token_prob = probs[0, tid].item()
            log_prob += math.log(token_prob + 1e-12)

            if i == 0:
                self.mask_probs = probs

            # Replace mask with predicted token
            input_ids_tensor[0, index] = tid

            # Insert new mask if more tokens remain
            if i < len(word_ids) - 1:
                input_ids_tensor = torch.cat(
                    [
                        input_ids_tensor[:, : index + 1],
                        torch.tensor([[self.tokeniser.mask_token_id]], device=device),
                        input_ids_tensor[:, index + 1 :],
                    ],
                    dim=1,
                )

                # debugging
                tokens_after_insertion = self.tokeniser.convert_ids_to_tokens(
                    input_ids_tensor[0].tolist()
                )

                index += 1

        return math.exp(log_prob)

    def run_next_word_prediction(
        self, context: str, grammatical_word: str, ungrammatical_word: str
    ) -> Tuple[float, float]:
        if not self.model or not self.tokeniser:
            raise RuntimeError("Model and tokenizer must be loaded before prediction.")

        context_ids = self.tokeniser.encode(context, add_special_tokens=False)
        device = next(self.model.parameters()).device

        g_ids = self.tokeniser.encode(grammatical_word, add_special_tokens=False)
        u_ids = self.tokeniser.encode(ungrammatical_word, add_special_tokens=False)

        g_prob = self._compute_next_word_joint_probability(context_ids, g_ids, device)
        u_prob = self._compute_next_word_joint_probability(context_ids, u_ids, device)

        return g_prob, u_prob

    def _compute_next_word_joint_probability(
        self, input_ids: List[int], word_ids: List[int], device
    ) -> float:
        input_ids_tensor = torch.tensor([input_ids], device=device)
        # debugging
        tokens_after_insertion = self.tokeniser.convert_ids_to_tokens(
            input_ids_tensor[0].tolist()
        )
        log_prob = 0.0

        for i, tid in enumerate(word_ids):
            with torch.no_grad():
                logits = self.model(input_ids_tensor).logits

            index = input_ids_tensor.shape[1] - 1  # last token position
            probs = F.softmax(logits[:, index, :], dim=-1)
            token_prob = probs[0, tid].item()
            log_prob += math.log(token_prob + 1e-12)

            if i == 0:
                self.mask_probs = probs

            # Append predicted token to context
            input_ids_tensor = torch.cat(
                [input_ids_tensor, torch.tensor([[tid]], device=device)], dim=1
            )

            # debugging
            tokens_after_insertion = self.tokeniser.convert_ids_to_tokens(
                input_ids_tensor[0].tolist()
            )

        return math.exp(log_prob)

    def get_entropy(
        self, k: int = 100, normalise: bool = False
    ) -> float | Tuple[float, float]:
        """Compute entropy over the prediction distribution.

        Args:
            k: Number of top tokens to consider.
            normalise: Whether to also return a normalised entropy score.

        Returns:
            The entropy value, or an (entropy, normalised entropy) tuple if ``normalise`` is True.

        Raises:
            ValueError: If no probabilities are available.
        """
        if self.mask_probs is None:
            raise ValueError("No output probabilities available. Run evaluation first.")
        return compute_entropy(self.mask_probs, k, normalise)

    def _get_mask_index(self, inputs: Any) -> int:
        if "input_ids" not in inputs:
            raise ValueError("Missing 'input_ids' in inputs.")
        elif self.tokeniser.mask_token_id is None:
            raise ValueError("The tokeniser does not have a defined mask_token_id.")

        input_ids = inputs["input_ids"]
        mask_positions = torch.where(input_ids == self.tokeniser.mask_token_id)

        if len(mask_positions[0]) == 0:
            raise ValueError("No mask token found in input_ids.")
        elif len(mask_positions[0]) > 1:
            raise ValueError("Multiple mask tokens found; expected only one.")

        return (
            mask_positions[1].item()
            if len(mask_positions) > 1
            else mask_positions[0].item()
        )

    def _get_mask_probabilities(
        self, mask_token_index: int, logits: Any
    ) -> torch.Tensor:
        mask_logits = logits[0, mask_token_index, :]
        probs = F.softmax(mask_logits, dim=-1)  # shape: (vocab_size,)
        return probs
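A minimal usage sketch for the evaluator above, assuming a minimal-pairs CSV produced by GrewTSEPipe; the file name and model repo are placeholders:

import pandas as pd
from grewtse.evaluators import GrewTSEvaluator

mp_dataset = pd.read_csv("minimal_pairs.csv")  # hypothetical output of GrewTSEPipe

evaluator = GrewTSEvaluator()
results = evaluator.evaluate_model(
    mp_dataset,
    model_repo="bert-base-multilingual-cased",  # any masked LM on the Hugging Face Hub
    model_type="encoder",                       # use "decoder" for causal LMs
    entropy_topk=100,
)
print(evaluator.get_avg_surprisal_difference())
print(evaluator.get_all_metrics())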
grewtse/evaluators/metrics.py ADDED
@@ -0,0 +1,104 @@
import pandas as pd
import numpy as np
import math
from typing import List


def compute_mean(list_of_values: List[float]) -> float:
    return sum(list_of_values) / len(list_of_values)


def compute_surprisal(p: float) -> float:
    return -math.log2(p) if p and p > 0 else float("inf")


def compute_avg_surprisal(probs: pd.Series) -> float:
    as_surprisal = probs.apply(compute_surprisal)
    return as_surprisal.mean()


def compute_average_surprisal_difference(
    correct_form_probs: pd.Series, wrong_form_probs: pd.Series
) -> float:
    correct_form_avg_surp = compute_avg_surprisal(correct_form_probs)
    wrong_form_avg_surp = compute_avg_surprisal(wrong_form_probs)
    return wrong_form_avg_surp - correct_form_avg_surp


def compute_normalised_surprisal_difference(
    correct_form_probs: pd.Series, wrong_form_probs: pd.Series
) -> float:
    correct_form_avg_surp = compute_avg_surprisal(correct_form_probs)
    wrong_form_avg_surp = compute_avg_surprisal(wrong_form_probs)
    return (wrong_form_avg_surp - correct_form_avg_surp) / correct_form_avg_surp


def compute_entropy(probs, k=None, normalise=False):
    probs = np.array(probs, dtype=np.float64)

    # remove zeros to avoid log(0)
    probs = probs[probs > 0]

    # get top-k probabilities
    if k is not None:
        probs = np.sort(probs)[::-1][:k]
        probs = probs / probs.sum()  # renormalise to sum to 1

    H = -np.sum(probs * np.log(probs))

    if normalise:
        n = len(probs)
        return H, 1 - H / np.log(n)
    else:
        return H


def get_predictions(df: pd.DataFrame) -> np.ndarray:
    """
    Convert probabilities to binary predictions.
    Predicts grammatical (1) if p_grammatical > p_ungrammatical, else ungrammatical (0).
    """
    predictions = (df["p_grammatical"] > df["p_ungrammatical"]).astype(int)
    return predictions.values


def calculate_accuracy(df: pd.DataFrame) -> float:
    """
    Calculate accuracy: proportion of correct predictions.
    Assumes the model should always predict the grammatical form (label = 1).
    """
    predictions = get_predictions(df)
    # True labels: all should be grammatical (1)
    true_labels = np.ones(len(df), dtype=int)

    correct = np.sum(predictions == true_labels)
    total = len(predictions)

    return correct / total if total > 0 else 0.0


def calculate_all_metrics(df: pd.DataFrame) -> dict:
    predictions = get_predictions(df)
    true_labels = np.ones(len(df), dtype=int)

    # Calculate confusion matrix components
    tp = np.sum((predictions == 1) & (true_labels == 1))
    fp = np.sum((predictions == 1) & (true_labels == 0))
    fn = np.sum((predictions == 0) & (true_labels == 1))
    tn = np.sum((predictions == 0) & (true_labels == 0))

    total = len(predictions)

    # Calculate metrics
    accuracy = (tp + tn) / total if total > 0 else 0.0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = (
        2 * (precision * recall) / (precision + recall)
        if (precision + recall) > 0
        else 0.0
    )

    return {
        "accuracy": round(accuracy, 2),
        "precision": round(precision, 2),
        "recall": round(recall, 2),
        "f1": round(f1, 2),
        "true_positives": int(tp),
        "false_positives": int(fp),
        "false_negatives": int(fn),
        "true_negatives": int(tn),
    }
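A small worked example of the surprisal metrics above, with made-up probabilities: surprisal is -log2(p), so p = 0.5 gives 1 bit, and the average surprisal difference is the mean ungrammatical surprisal minus the mean grammatical surprisal.

import pandas as pd
from grewtse.evaluators.metrics import (
    compute_surprisal,
    compute_average_surprisal_difference,
)

p_gram = pd.Series([0.5, 0.25, 0.125])       # surprisals 1, 2, 3 bits; mean 2
p_ungram = pd.Series([0.25, 0.125, 0.0625])  # surprisals 2, 3, 4 bits; mean 3

print(compute_surprisal(0.5))                                  # 1.0
print(compute_average_surprisal_difference(p_gram, p_ungram))  # 1.0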
grewtse/pipeline.py ADDED
@@ -0,0 +1,295 @@
import os
import pandas as pd
import random
import logging
from grewtse.preprocessing import ConlluParser

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.FileHandler("app.log"), logging.StreamHandler()],
)


class GrewTSEPipe:
    """
    Main pipeline controller for generating prompt- or masked-based minimal-pair datasets derived from UD treebanks.

    This class acts as a high-level interface for the Grew-TSE workflow:

    1. Parse treebanks to build lexical item datasets.
    2. Generate masked or prompt-based datasets using GREW.
    3. Create minimal pairs for syntactic evaluation.
    """

    def __init__(self):
        self.parser = ConlluParser()

        self.treebank_paths: list[str] = []
        self.lexical_items: pd.DataFrame | None = None
        self.grew_generated_dataset: pd.DataFrame | None = None
        self.mp_dataset: pd.DataFrame | None = None
        self.exception_dataset: pd.DataFrame | None = None
        self.evaluation_results: pd.DataFrame | None = None

    # 1. Initial step: parse a treebank from a .conllu file
    def parse_treebank(
        self, filepaths: str | list[str], reset: bool = False
    ) -> pd.DataFrame:
        """
        Parse one or more treebanks and create a lexical item set.
        A lexical item set is a dataset of words and their features.

        Args:
            filepaths: Path or list of paths to treebank files.
            reset: If True, clears existing lexical_items before parsing.
        """

        if isinstance(filepaths, str):
            filepaths = [filepaths]  # wrap single path in list

        try:
            if reset or self.lexical_items is None:
                self.lexical_items = pd.DataFrame()
                self.treebank_paths = []

            self.lexical_items = self.parser.build_lexicon(filepaths)
            self.treebank_paths = filepaths

            return self.lexical_items
        except Exception as e:
            raise RuntimeError(f"Issue parsing treebank: {e}") from e

    def load_lexicon(self, filepath: str, treebank_paths: list[str]) -> None:
        """
        Load a previously generated lexicon (typically returned by parse_treebank) from disk and attach it to the pipeline.

        Use this method to resume processing with an existing LI_set that was generated earlier
        and saved as a CSV. It loads the LI_set, validates the required columns, sets the
        appropriate index, and updates the pipeline and parser with the loaded data.

        Args:
            filepath (str):
                Path to the CSV file containing the LI_set to load. The file must contain the
                columns ``"sentence_id"`` and ``"token_id"``.
            treebank_paths (list[str]):
                A list of paths to the treebanks associated with the LI_set. These paths are stored
                so the pipeline can later reference the corresponding treebanks when generating or
                analysing data.

        Raises:
            FileNotFoundError:
                If the CSV file cannot be found at the given ``filepath``.
            ValueError:
                If the required index columns (``"sentence_id"``, ``"token_id"``) are missing.

        Example:
            >>> pipe = GrewTSEPipe()
            >>> pipe.load_lexicon("output/li_set.csv", ["treebank1.conllu", "treebank2.conllu"])
        """
        if not os.path.exists(filepath):
            raise FileNotFoundError(f"LI_set file not found: {filepath}")

        li_df = pd.read_csv(filepath)

        required_cols = {"sentence_id", "token_id"}
        missing = required_cols - set(li_df.columns)
        if missing:
            raise ValueError(
                f"Missing required columns in LI_set: {', '.join(missing)}"
            )

        li_df.set_index(["sentence_id", "token_id"], inplace=True)

        self.lexical_items = li_df
        self.parser.lexicon = li_df
        self.treebank_paths = treebank_paths

    def generate_masked_dataset(
        self, query: str, target_node: str, mask_token: str = "[MASK]"
    ) -> pd.DataFrame:
        """
        Once a treebank has been parsed, use this function to generate a masked dataset (default mask token ``[MASK]``)
        for masked language modelling (MLM), e.g. for encoder models. Provide a GREW query that isolates a particular
        construction and a target node that identifies the element in that construction that you want to test.

        :param query: the GREW query that specifies a construction. Test queries at https://universal.grew.fr/
        :param target_node: the variable defined in your GREW query that represents the target word
        :return: a DataFrame consisting of the sentence ID in the given treebank, the index of the token to be masked, the matched token itself, the original text, and the masked text.
        """

        if not self.is_treebank_parsed():
            raise ValueError(
                "Cannot create masked dataset: no treebank or invalid treebank filepath provided."
            )

        results = self.parser.build_masked_dataset(
            self.treebank_paths, query, target_node, mask_token
        )
        self.grew_generated_dataset = results["masked"]
        self.exception_dataset = results["exception"]
        return self.grew_generated_dataset

    def generate_prompt_dataset(self, query: str, target_node: str) -> pd.DataFrame:
        """
        Once a treebank has been parsed, use this function to generate a prompt dataset for next-token prediction (NTP),
        e.g. for decoder models. Provide a GREW query that isolates a particular construction and a target node that
        identifies the element in that construction that you want to test.

        :param query: the GREW query that specifies a construction. Test queries at https://universal.grew.fr/
        :param target_node: the variable defined in your GREW query that represents the target word
        :return: a DataFrame consisting of the sentence ID in the given treebank, the index of the target token, the matched token itself, the original text, and the created prompt.
        """

        if not self.is_treebank_parsed():
            raise ValueError(
                "Cannot create prompt dataset: no treebank or invalid treebank filepath provided."
            )

        prompt_dataset = self.parser.build_prompt_dataset(
            self.treebank_paths, query, target_node
        )
        self.grew_generated_dataset = prompt_dataset
        return prompt_dataset

    def generate_minimal_pair_dataset(
        self,
        morph_features: dict,
        upos_features: dict | None,
        ood_pairs: int | None = None,
        has_leading_whitespace: bool = True,
    ) -> pd.DataFrame:
        """
        After generating a masked or prompt dataset, create the corresponding minimal-pair dataset by specifying the
        feature(s) you would like to change. You can also request additional 'OOD' pairs and control whether a leading
        whitespace is added to each minimal-pair item.

        NOTE: morph_features and upos_features expect lowercase keys; values remain as in the treebank.

        :param morph_features: the morphological features from the UD treebank to adjust for the second element of the minimal pair, e.g. ``{'case': 'Dat'}`` may convert the original target item, e.g. German 'Hunde' (dog.PLUR.NOM / dog.PLUR.ACC), to the dative 'Hunden' (dog.PLUR.DAT), giving the minimal pair (Hunde, Hunden). The exact keys and values depend on the treebank you are working with.
        :param upos_features: the universal part-of-speech tags from the UD treebank to adjust for the second element of the minimal pair, e.g. ``{'upos': 'VERB'}`` will only search for verbs
        :param ood_pairs: the number of additional alternative (likely semantically implausible) minimal pairs to provide for each example. These may help in evaluating generalisation performance.
        :param has_leading_whitespace: whether an additional whitespace is included at the beginning of each element in the minimal pair, e.g. (' is', ' are')
        :return: a DataFrame containing the masked sentences or prompts as well as the minimal pairs
        """

        if self.grew_generated_dataset is None:
            raise ValueError(
                "Cannot generate minimal pairs: treebank must be parsed and masked first."
            )

        def convert_row_to_feature(row):
            return self.parser.to_syntactic_feature(
                row["sentence_id"],
                row["match_id"] - 1,
                row["match_token"],
                morph_features,
                upos_features or {},  # tolerate upos_features=None
            )

        alternative_row = self.grew_generated_dataset.apply(
            convert_row_to_feature, axis=1
        )
        self.mp_dataset = self.grew_generated_dataset.copy()
        self.mp_dataset["form_ungrammatical"] = alternative_row

        self.mp_dataset = self.mp_dataset.rename(
            columns={"match_token": "form_grammatical"}
        )

        # rule 1: drop any rows where we don't find a minimal pair (henceforth MP)
        self.mp_dataset = self.mp_dataset.dropna(subset=["form_ungrammatical"])

        # rule 2: don't include MPs where the minimal pairs are the same string
        self.mp_dataset = self.mp_dataset[
            self.mp_dataset["form_grammatical"] != self.mp_dataset["form_ungrammatical"]
        ]

        # add leading whitespace if requested.
        # this is useful for models whose tokenisers expect a leading whitespace, such as many decoder models
        if has_leading_whitespace:
            self.mp_dataset["form_grammatical"] = (
                " " + self.mp_dataset["form_grammatical"]
            )
            self.mp_dataset["form_ungrammatical"] = (
                " " + self.mp_dataset["form_ungrammatical"]
            )

        # handle the assigning of the out-of-distribution pairs
        if ood_pairs:
            # assign additional pairs for OOD data
            all_grammatical = self.mp_dataset["form_grammatical"].to_list()
            all_ungrammatical = self.mp_dataset["form_ungrammatical"].to_list()

            # combine both into one vocabulary
            words = set(zip(all_grammatical, all_ungrammatical))

            def pick_words(row):
                excluded = (row["form_grammatical"], row["form_ungrammatical"])
                available = list(words - {excluded})
                return random.sample(available, ood_pairs)

            # Apply function to each row
            self.mp_dataset["ood_minimal_pairs"] = self.mp_dataset.apply(
                pick_words, axis=1
            )

        return self.mp_dataset

    def get_morphological_features(self) -> pd.DataFrame:
        """
        Get the lexical item set with the ``feats__`` prefix stripped from its feature columns,
        so the morphological features available in a given treebank can be inspected.
        Alternatively, you can go to the treebank's respective webpage to find this information.
        A treebank must first be parsed in order to use this function.

        :return: a DataFrame whose columns include each morphological feature in the treebank.
        """

        if not self.is_treebank_parsed():
            raise ValueError("Cannot get features: You must parse a treebank first.")

        morph_df = self.lexical_items.copy()
        morph_df.columns = [
            col.replace("feats__", "") if col.startswith("feats__") else col
            for col in morph_df.columns
        ]

        return morph_df

    def is_treebank_parsed(self) -> bool:
        return self.lexical_items is not None

    def is_dataset_masked(self) -> bool:
        return self.grew_generated_dataset is not None

    def is_model_evaluated(self) -> bool:
        return self.evaluation_results is not None

    def get_lexical_items(self) -> pd.DataFrame:
        return self.lexical_items

    def get_masked_dataset(self) -> pd.DataFrame:
        return self.grew_generated_dataset

    def get_minimal_pair_dataset(self) -> pd.DataFrame:
        return self.mp_dataset

    def get_exceptions_dataset(self):
        return self.exception_dataset

    def get_num_exceptions(self):
        if self.exception_dataset is not None:
            return self.exception_dataset.shape[0]
        else:
            return -1

    def are_minimal_pairs_generated(self) -> bool:
        return (
            self.is_treebank_parsed()
            and self.is_dataset_masked()
            and self.mp_dataset is not None
            and ("form_ungrammatical" in self.mp_dataset.columns)
        )
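A sketch of the full pipeline, assuming a local UD treebank file; the GREW query, feature values and file names below are illustrative only:

from grewtse import GrewTSEPipe

pipe = GrewTSEPipe()
pipe.parse_treebank("de_gsd-ud-train.conllu")  # hypothetical treebank path

# mask the verb isolated by the (illustrative) GREW pattern
pipe.generate_masked_dataset(
    query="pattern { V [upos=VERB] }",
    target_node="V",
)

mp = pipe.generate_minimal_pair_dataset(
    morph_features={"number": "Plur"},  # lowercase keys, treebank-style values
    upos_features={},
    ood_pairs=2,
)
mp.to_csv("minimal_pairs.csv", index=False)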
grewtse/preprocessing/__init__.py ADDED
@@ -0,0 +1,3 @@
from .conllu_parser import ConlluParser

__all__ = ["ConlluParser"]
grewtse/preprocessing/conllu_parser.py ADDED
@@ -0,0 +1,348 @@
from grewtse.preprocessing.grew_dependencies import match_dependencies
from grewtse.preprocessing.reconstruction import (
    perform_token_surgery,
    recursive_match_token,
)
from conllu import parse_incr, Token
from typing import Any
import pandas as pd
import numpy as np
import logging


class ConlluParser:
    """
    A parser for .conllu files (the standard format for UD treebanks) used by Grew-TSE.
    """

    def __init__(self) -> None:
        self.lexicon: pd.DataFrame | None = None

    def build_lexicon(self, filepaths: list[str] | str) -> pd.DataFrame:
        """
        Create a DataFrame that contains the set of all words with their features as generated from a UD treebank.
        This is essential for the subsequent generation of minimal pairs.
        This was not designed to handle treebanks that assign differing names to features, so please ensure multiple treebank files all come from the same treebank or treebank schema.

        :param filepaths: a list of strings corresponding to the UD treebank files, e.g. ["german_treebank_part_A.conllu", "german_treebank_part_B.conllu"].
        :return: a DataFrame with all words and their features.
        """
        rows = []

        if isinstance(filepaths, str):
            filepaths = [filepaths]  # wrap single path in list

        for conllu_path in filepaths:
            with open(conllu_path, "r", encoding="utf-8") as f:
                for tokenlist in parse_incr(f):
                    # get the sentence ID in the dataset
                    sent_id = tokenlist.metadata["sent_id"]
                    logging.info(f"Parsing Sentence: {sent_id}")

                    # iterate over each token
                    for token in tokenlist:
                        # check if it's worth saving to our lexical item dataset
                        is_valid_token = is_valid_for_lexicon(token)
                        if not is_valid_token:
                            continue

                        # from the token object create a dict and append
                        row = build_token_row(token, sent_id)
                        rows.append(row)

        lexicon_df = pd.DataFrame(rows)

        # make sure our nan values are interpreted as such
        lexicon_df.replace("nan", np.nan, inplace=True)

        # create the (Sentence ID, Token ID) primary key
        lexicon_df.set_index(["sentence_id", "token_id"], inplace=True)

        self.lexicon = lexicon_df

        return lexicon_df

    def to_syntactic_feature(
        self,
        sentence_id: str,
        token_id: str,
        token: str,
        alt_morph_constraints: dict,
        alt_universal_constraints: dict,
    ) -> str | None:
        """
        The central function for finding minimal pairs. Converts a given lexical item taken from a UD treebank sentence
        into another lexical item with the same lemma but with the specified differing feature(s).

        :param sentence_id: the ID in the treebank of the sentence.
        :param token_id: the token index in the list of tokens corresponding to the isolated target word.
        :param token: the token string itself that is the isolated target word.
        :param alt_morph_constraints: the alternative morphological feature(s) for the target word.
        :param alt_universal_constraints: the alternative UPOS feature(s) for the target word.
        :return: a string representing the converted target word, or None if no candidate is found.
        """

        # distinguish morphological from universal features
        # todo: find a better way to do this
        prefix = "feats__"
        alt_morph_constraints = {
            prefix + key: value for key, value in alt_morph_constraints.items()
        }

        token_features = self.get_features(sentence_id, token_id)

        token_features.update(alt_morph_constraints)
        token_features.update(alt_universal_constraints)
        lexical_items = self.lexicon

        # get only those items which are the same lemma
        lemma = self.get_lemma(sentence_id, token_id)
        lemma_mask = lexical_items["lemma"] == lemma
        lexical_items = lexical_items[lemma_mask]
        logging.info(f"Looking for form {lemma}")

        lexical_items = construct_candidate_set(lexical_items, token_features)
        # ensure that it doesn't allow minimal pairs with different start cases, e.g. business, Business
        filtered = lexical_items[
            lexical_items["form"].apply(lambda w: is_same_start_case(w, token))
        ]
        if not filtered.empty:
            return filtered["form"].iloc[0]
        else:
            return None

    def get_lexicon(self) -> pd.DataFrame:
        return self.lexicon

    # this shouldn't be hard coded
    def get_feature_names(self) -> list:
        return self.lexicon.columns[4:].to_list()

    # todo: add more safety
    def get_features(self, sentence_id: str, token_id: int) -> dict:
        return self.lexicon.loc[(sentence_id, token_id)][
            self.get_feature_names()
        ].to_dict()

    def get_lemma(self, sentence_id: str, token_id: str) -> str:
        return self.lexicon.loc[(sentence_id, token_id)]["lemma"]

    def get_candidate_set(
        self, universal_constraints: dict, morph_constraints: dict
    ) -> pd.DataFrame:
        has_parsed_conllu = self.lexicon is not None
        if not has_parsed_conllu:
            raise ValueError("Please parse a CoNLL-U file first.")

        morph_constraints = {f"feats__{k}": v for k, v in morph_constraints.items()}
        are_morph_features_valid = all(
            f in self.lexicon.columns for f in morph_constraints.keys()
        )
        are_universal_features_valid = all(
            f in self.lexicon.columns for f in universal_constraints.keys()
        )
        if not are_morph_features_valid or not are_universal_features_valid:
            raise KeyError(
                "Features provided for candidate set are not valid features in the dataset."
            )

        all_constraints = {**universal_constraints, **morph_constraints}
        candidate_set = construct_candidate_set(self.lexicon, all_constraints)
        return candidate_set

    def build_prompt_dataset(
        self,
        filepaths: list[str],
        grew_query: str,
        dependency_node: str,
        encoding: str = "utf-8",
    ):
        prompt_cutoff_token = "[PROMPT_CUTOFF]"
        results = self.build_masked_dataset(
            filepaths, grew_query, dependency_node, prompt_cutoff_token, encoding
        )
        prompt_dataset = results["masked"]

        def substring_up_to_token(s: str, token: str) -> str:
            idx = s.find(token)
            return s[:idx].strip() if idx != -1 else s.strip()

        prompt_dataset["prompt_text"] = prompt_dataset["masked_text"].apply(
            lambda x: substring_up_to_token(x, prompt_cutoff_token)
        )
        prompt_dataset = prompt_dataset.drop(["masked_text"], axis=1)
        return prompt_dataset

    def build_masked_dataset(
        self,
        filepaths: list[str],
        grew_query: str,
        dependency_node: str,
        mask_token: str,
        encoding: str = "utf-8",
    ):
        masked_dataset = []
        exception_dataset = []

        try:
            for filepath in filepaths:
                get_tokens_to_mask = match_dependencies(
                    filepath, grew_query, dependency_node
                )

                with open(filepath, "r", encoding=encoding) as data_file:
                    for sentence in parse_incr(data_file):

                        sentence_id = sentence.metadata["sent_id"]
                        sentence_text = sentence.metadata["text"]

                        if sentence_id in get_tokens_to_mask:
                            for i in range(len(sentence)):
                                sentence[i]["index"] = i

                            token_to_mask_id = get_tokens_to_mask[sentence_id]

                            try:
                                t_match = [
                                    tok
                                    for tok in sentence
                                    if tok.get("id") == token_to_mask_id
                                ][0]
                                t_match_form = t_match["form"]
                                t_match_index = t_match["index"]
                                sentence_as_str_list = [t["form"] for t in sentence]
                            except (KeyError, IndexError):
                                logging.info(
                                    "There was a mismatch between the GREW-based ID and the CoNLL-U ID."
                                )
                                exception_dataset.append(
                                    {
                                        "sentence_id": sentence_id,
                                        "match_id": None,
                                        "all_tokens": None,
                                        "match_token": None,
                                        "original_text": sentence_text,
                                    }
                                )
                                continue

                            try:
                                matched_token_start_index = recursive_match_token(
                                    sentence_text,  # the original string
                                    sentence_as_str_list.copy(),  # the string as a list of tokens
                                    t_match_index,  # the index of the token to be replaced
                                    [
                                        "_",
                                        " ",
                                    ],  # todo: skip lines where we encounter unaccounted-for tokens
                                )
                            except ValueError:
                                exception_dataset.append(
                                    {
                                        "sentence_id": sentence_id,
                                        "match_id": token_to_mask_id,
                                        "all_tokens": sentence_as_str_list,
                                        "match_token": t_match_form,
                                        "original_text": sentence_text,
                                    }
                                )
                                continue

                            # let's replace the matched token with a MASK token
                            masked_sentence = perform_token_surgery(
                                sentence_text,
                                t_match_form,
                                mask_token,
                                matched_token_start_index,
                            )

                            # the sentence ID and match ID are together a primary key
                            masked_dataset.append(
                                {
                                    "sentence_id": sentence_id,
                                    "match_id": token_to_mask_id,
                                    "match_token": t_match_form,
                                    "original_text": sentence_text,
                                    "masked_text": masked_sentence,
                                }
                            )
        except Exception as e:
            logging.error(f"Issue building dataset: {e}")

        masked_dataset_df = pd.DataFrame(masked_dataset)
        exception_dataset_df = pd.DataFrame(exception_dataset)

        return {"masked": masked_dataset_df, "exception": exception_dataset_df}


def construct_candidate_set(
    lexicon: pd.DataFrame, target_features: dict
) -> pd.DataFrame:
    """
    Construct the set of words whose features match the target features passed as an argument.
    These resulting words are termed 'candidates'.

    :param lexicon: the DataFrame consisting of all lexical items and their features
    :param target_features: the features the candidates must match.
    :return: a DataFrame containing the candidate subset of the lexicon.
    """

    # optionally restrict search to a certain type of lexical item
    subset = lexicon

    # continuously filter the dataframe so as to be left only with those lexical
    # items which match the target features; this includes cases where the
    # feature value is NaN on both sides
    for feat, value in target_features.items():
        # ensure feature is a valid feature in feature set
        if feat not in subset.columns:
            logging.debug(subset.columns)
            raise KeyError("Invalid feature provided to confound set: {}".format(feat))

        # slim the mask down using each feature
        # interesting edge case: np.nan == np.nan returns False!
        mask = (subset[feat] == value) | (subset[feat].isna() & pd.isna(value))
        subset = subset[mask]

    return subset


def is_same_start_case(s1, s2):
    if not s1 or not s2:
        return False
    return s1[0].isupper() == s2[0].isupper()


def is_valid_for_lexicon(token: Token) -> bool:
    punctuation = [".", ",", "!", "?", "*"]

    # skip multiword tokens, malformed entries and punctuation
    is_punctuation = token.get("form") in punctuation
    is_valid_type = isinstance(token, dict)
    has_valid_id = isinstance(token.get("id"), int)
    return is_valid_type and has_valid_id and not is_punctuation


def build_token_row(token: Token, sentence_id: str) -> dict[str, Any]:
    # get all token features such as Person, Mood, etc.
    feats = token.get("feats") or {}

    row = {
        "sentence_id": sentence_id,
        "token_id": token.get("id") - 1,  # IDs are reduced by one to start at 0
        "form": token.get("form"),
        "lemma": token.get("lemma"),
        "upos": token.get("upos"),
        "xpos": token.get("xpos"),
    }

    # add each morphological feature as a column
    for feat_name, feat_value in feats.items():
        row["feats__" + feat_name.lower()] = feat_value

    return row
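A sketch of using the parser directly, assuming a UD .conllu file is available locally; the path and feature names are placeholders whose availability depends on the treebank:

from grewtse.preprocessing import ConlluParser

parser = ConlluParser()
lexicon = parser.build_lexicon("en_ewt-ud-dev.conllu")  # hypothetical path

# all plural nouns in the lexicon, assuming the treebank uses the Number feature
candidates = parser.get_candidate_set(
    universal_constraints={"upos": "NOUN"},
    morph_constraints={"number": "Plur"},
)
print(candidates[["form", "lemma"]].head())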
grewtse/preprocessing/grew_dependencies.py ADDED
@@ -0,0 +1,83 @@
from grewpy import Corpus, Request, set_config


# An earlier version of match_dependencies, kept here for reference:
"""
def match_dependencies(
    filepaths: list[str] | str, grew_query: str, dependency_node: str
) -> dict:
    set_config("sud")  # ud or basic
    dep_matches = {}

    if isinstance(filepaths, str):
        filepaths = [filepaths]  # wrap single path in list

    try:
        for corpus_path in filepaths:
            # run the GREW request on the corpus
            print("Corpus Path ", corpus_path)
            corpus = Corpus(str(corpus_path))
            request = Request(grew_query)
            occurrences = corpus.search(request)

            # step 2
            for occ in occurrences:
                sent_id = occ["sent_id"]

                object_node_id = int(occ["matching"]["nodes"][dependency_node])

                # one match per sentence
                # todo: handle multiple matches per sentence
                dep_matches[sent_id] = object_node_id
    except KeyError:
        raise KeyError(
            "You must provide a dependency node name which exists in your GREW pattern."
        )
    except Exception as e:
        raise ValueError(f"Invalid GREW query: {e}")

    return dep_matches
"""


def match_dependencies(
    filepaths: list[str] | str, grew_query: str, dependency_node: str
) -> dict:
    set_config("sud")  # ud or basic
    dep_matches = {}

    if isinstance(filepaths, str):
        filepaths = [filepaths]  # wrap single path in list

    try:
        for corpus_path in filepaths:
            # run the GREW request on the corpus
            corpus = Corpus(str(corpus_path))
            request = Request(grew_query)
            occurrences = corpus.search(request)

            # step 2
            for occ in occurrences:
                sent_id = occ["sent_id"]

                # handle both integer and float string representations, e.g. "3" or "3.0"
                node_id_str = occ["matching"]["nodes"][dependency_node]
                try:
                    object_node_id = float(node_id_str)
                except ValueError:
                    raise ValueError(
                        f"Cannot convert node id '{node_id_str}' to a number."
                    )

                # one match per sentence
                # todo: handle multiple matches per sentence
                dep_matches[sent_id] = object_node_id
    except KeyError as e:
        raise KeyError(
            f"You must provide a dependency node name which exists in your GREW pattern. Missing key: {e}"
        )
    except Exception as e:
        raise ValueError(f"Invalid GREW query: {e}")

    return dep_matches
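A sketch of calling match_dependencies on its own; the treebank path and GREW pattern are illustrative, and "N" must be a node declared in the pattern:

from grewtse.preprocessing.grew_dependencies import match_dependencies

matches = match_dependencies(
    "en_ewt-ud-dev.conllu",       # hypothetical treebank path
    "pattern { N [upos=NOUN] }",  # illustrative GREW pattern
    "N",
)
# {sent_id: node_id}, one match kept per sentence
print(list(matches.items())[:5])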
grewtse/preprocessing/reconstruction.py ADDED
@@ -0,0 +1,71 @@
def perform_token_surgery(
    sentence: str,
    original_token: str,
    replacement_token: str,
    start_index: int,
) -> str:
    t_len = len(original_token)

    return sentence[:start_index] + replacement_token + sentence[start_index + t_len :]


def recursive_match_token(
    full_sentence: str,
    token_list: list[str],
    token_list_mask_index: int,
    skippable_tokens: list[str],
) -> int:
    # ensure we can retrieve another token
    n_remaining_tokens = len(token_list)
    if n_remaining_tokens == 0:
        raise ValueError(
            "Mask index not reached but token list has been iterated for sentence: {}".format(
                full_sentence
            )
        )
    t = token_list[0]

    # returns the index of the first occurrence
    # of the token t
    match_index = full_sentence.find(t)
    is_match_found = match_index != -1
    has_reached_mask_token = token_list_mask_index == 0

    # BASE CASE
    if has_reached_mask_token and is_match_found:
        # we're at the end
        return match_index
    # RECURSIVE CASE
    elif is_match_found:
        sliced_sentence = full_sentence[match_index + len(t) :]
        token_list.pop(0)

        return (
            match_index
            + len(t)
            + recursive_match_token(
                sliced_sentence,
                token_list,
                token_list_mask_index - 1,
                skippable_tokens,
            )
        )
    else:
        # no match found, is t irrelevant?
        if t in skippable_tokens:
            # need to watch out with the slicing here
            # tests are important
            sliced_sentence = full_sentence[len(t) - 1 :]
            token_list.pop(0)
            return recursive_match_token(
                sliced_sentence,
                token_list,
                token_list_mask_index - 1,
                skippable_tokens,
            )
        else:
            raise ValueError(
                "Token not found in string nor has it been specified as skippable: {}".format(
                    t
                )
            )
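A worked example of the two helpers above: recursive_match_token finds the character offset of the third token (index 2), and perform_token_surgery splices the mask in at that offset.

from grewtse.preprocessing.reconstruction import (
    perform_token_surgery,
    recursive_match_token,
)

sentence = "The dog barks ."
tokens = ["The", "dog", "barks", "."]

start = recursive_match_token(sentence, tokens.copy(), 2, ["_", " "])
print(start)  # 8, the offset of "barks"
print(perform_token_surgery(sentence, "barks", "[MASK]", start))
# "The dog [MASK] ."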
grewtse/utils/__init__.py ADDED
File without changes
grewtse/utils/validation.py ADDED
@@ -0,0 +1,24 @@
import pandas as pd


def load_and_validate_mp_dataset(filepath: str):
    required_columns = {
        "sentence_id",
        "match_id",
        "original_text",
    }

    try:
        df = pd.read_csv(filepath)

        missing = required_columns - set(df.columns)
        if missing:
            raise ValueError(f"Missing required column(s): {', '.join(missing)}")

        return df

    except FileNotFoundError:
        raise FileNotFoundError(f"File not found: {filepath}")

    except pd.errors.ParserError as e:
        raise pd.errors.ParserError(f"Error parsing CSV file: {e}")
grewtse/visualise/__init__.py ADDED
@@ -0,0 +1,3 @@
from .visualiser import GrewTSEVisualiser

__all__ = ["GrewTSEVisualiser"]
grewtse/visualise/visualiser.py ADDED
@@ -0,0 +1,132 @@
import pandas as pd
from plotnine import (
    labs,
    theme,
    theme_bw,
    guides,
    position_nudge,
    aes,
    geom_violin,
    geom_line,
    geom_jitter,
    scale_x_discrete,
    ggplot,
)
from pathlib import Path
import math


class GrewTSEVisualiser:
    """
    A basic visualisation class that creates a violin plot based on a syntactic evaluation.
    """

    def __init__(self) -> None:
        self.data = None

    def visualise_syntactic_performance(
        self,
        results: pd.DataFrame,
        title: str,
        target_x_label: str,
        alt_x_label: str,
        x_axis_label: str,
        y_axis_label: str,
        filename: str,
    ) -> None:
        """
        Visualise a syntactic performance evaluation result.

        :param results: the results DataFrame created by the GrewTSEvaluator.
        :param title: the main title of the diagram.
        :param target_x_label: a label for the original target word, i.e. the first word in the minimal pair, e.g. 'Accusative'.
        :param alt_x_label: a label for the second element in the minimal pair, e.g. 'Dative'.
        :param x_axis_label: the X-axis title.
        :param y_axis_label: the Y-axis title.
        :param filename: a filename under which to save the visualisation.
        :return: None; the figure is saved to ``filename``.
        """

        visualise_slope(
            filename,
            results,
            target_x_label,
            alt_x_label,
            x_axis_label,
            y_axis_label,
            title,
        )


def visualise_slope(
    path: Path,
    results: pd.DataFrame,
    target_x_label: str,
    alt_x_label: str,
    x_axis_label: str,
    y_axis_label: str,
    title: str,
):
    lsize = 0.65
    fill_alpha = 0.7

    # X-axis: the two minimal-pair conditions
    # Y-axis: surprisal
    filtered_df = results[
        results["form_ungrammatical"].notna()
        & (results["form_ungrammatical"].str.strip() != "")
    ].copy()  # copy to avoid SettingWithCopyWarning on the assignment below

    filtered_df["subject_id"] = filtered_df.index

    # Melt the dataframe
    df_long = pd.melt(
        filtered_df,
        id_vars=["subject_id"],
        value_vars=["p_grammatical", "p_ungrammatical"],
        var_name="source",
        value_name="log_prob",
    )

    # Map source to fixed x-axis labels
    df_long["x_label"] = df_long["source"].map(
        {"p_grammatical": target_x_label, "p_ungrammatical": alt_x_label}
    )

    def surprisal(p: float) -> float:
        return -math.log2(p)

    def confidence(p: float) -> float:
        return math.log2(p)

    df_long["surprisal"] = df_long["log_prob"].apply(surprisal)

    p = (
        ggplot(df_long, aes(x="x_label", y="surprisal", fill="x_label"))
        + scale_x_discrete(limits=[target_x_label, alt_x_label])
        + geom_jitter(aes(color="x_label"), width=0.01, alpha=0.7)
        # + geom_text(aes(label='label'), nudge_y=0.1)
        + geom_line(aes(group="subject_id"), color="gray", alpha=0.7, size=0.2)
        + geom_violin(
            df_long[df_long["x_label"] == target_x_label],
            aes(x="x_label", y="surprisal", group="x_label"),
            position=position_nudge(x=-0.2),
            style="left-right",
            alpha=fill_alpha,
            size=lsize,
        )
        + geom_violin(
            df_long[df_long["x_label"] == alt_x_label],
            aes(x="x_label", y="surprisal", group="x_label"),
            position=position_nudge(x=0.2),
            style="right-left",
            alpha=fill_alpha,
            size=lsize,
        )
        + guides(fill=False)
        + theme_bw()
        + theme(figure_size=(8, 4), legend_position="none")
        + labs(x=x_axis_label, y=y_axis_label, title=title)
    )
    p.save(path, width=14, height=8, dpi=300)
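A sketch of producing the violin plot; a tiny made-up DataFrame stands in for the evaluator's output, and the labels and filename are placeholders:

import pandas as pd
from grewtse.visualise import GrewTSEVisualiser

# a stand-in for the DataFrame returned by GrewTSEvaluator.evaluate_model
results = pd.DataFrame(
    {
        "form_grammatical": [" is", " has"],
        "form_ungrammatical": [" are", " have"],
        "p_grammatical": [0.6, 0.4],
        "p_ungrammatical": [0.1, 0.2],
    }
)

vis = GrewTSEVisualiser()
vis.visualise_syntactic_performance(
    results,
    title="Subject-verb agreement",
    target_x_label="Grammatical",
    alt_x_label="Ungrammatical",
    x_axis_label="Condition",
    y_axis_label="Surprisal (bits)",
    filename="agreement_violin.png",
)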