heerjtdev committed · verified
Commit 4fc6352 · 1 Parent(s): bb64eee

Update HF_LayoutLM_with_Passage.py

Files changed (1)
  1. HF_LayoutLM_with_Passage.py +0 -763
HF_LayoutLM_with_Passage.py CHANGED
@@ -1,766 +1,3 @@
- #
- # import json
- # import argparse
- # import os
- # import random
- # import torch
- # import torch.nn as nn
- # from torch.utils.data import Dataset, DataLoader, random_split
- # from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Model
- # from TorchCRF import CRF
- # from torch.optim import AdamW
- # from tqdm import tqdm
- # from sklearn.metrics import precision_recall_fscore_support
- #
- #
- # # --- Configuration for Augmentation ---
- # MAX_BBOX_DIMENSION = 999
- # MAX_SHIFT = 30
- # AUGMENTATION_FACTOR = 1
- #
- #
- # # -------------------------------------
- #
- #
- # # -------------------------
- # # Step 1: Preprocessing (Label Studio → BIO + bboxes)
- # # -------------------------
- # def preprocess_labelstudio(input_path, output_path):
- #     with open(input_path, "r", encoding="utf-8") as f:
- #         data = json.load(f)
- #
- #     processed = []
- #     total_items = len(data) # Added for potential verbose logging
- #     print(f"🔄 Starting preprocessing of {total_items} documents...")
- #
- #     for item in data:
- #         words = item["data"]["original_words"]
- #         bboxes = item["data"]["original_bboxes"]
- #         labels = ["O"] * len(words)
- #
- #         if "annotations" in item:
- #             for ann in item["annotations"]:
- #                 for res in ann["result"]:
- #                     # Check if the result item is a span annotation
- #                     if "value" in res and "labels" in res["value"]:
- #                         text = res["value"]["text"]
- #                         tag = res["value"]["labels"][0]
- #                         # Some tokenizers may split words, so we must find a consecutive word match.
- #                         text_tokens = text.split()
- #
- #                         for i in range(len(words) - len(text_tokens) + 1):
- #                             if words[i:i + len(text_tokens)] == text_tokens:
- #                                 labels[i] = f"B-{tag}"
- #                                 for j in range(1, len(text_tokens)):
- #                                     labels[i + j] = f"I-{tag}"
- #                                 break # Move to next annotation if a match is found
- #
- #         processed.append({"tokens": words, "labels": labels, "bboxes": bboxes})
- #
- #     with open(output_path, "w", encoding="utf-8") as f:
- #         json.dump(processed, f, indent=2, ensure_ascii=False)
- #
- #     print(f"✅ Preprocessed data saved to {output_path}")
- #     return output_path
- #
- #
- # # -------------------------
- # # Step 1.5: Bounding Box Augmentation
- # # -------------------------
- #
- # def translate_bbox(bbox, shift_x, shift_y):
- #     """
- #     Translates a single bounding box [x_min, y_min, x_max, y_max] by (shift_x, shift_y)
- #     and clamps the coordinates to the valid range [0, MAX_BBOX_DIMENSION].
- #     """
- #     x_min, y_min, x_max, y_max = bbox
- #
- #     new_x_min = x_min + shift_x
- #     new_y_min = y_min + shift_y
- #     new_x_max = x_max + shift_x
- #     new_y_max = y_max + shift_y
- #
- #     # Clamp the new coordinates
- #     new_x_min = max(0, min(new_x_min, MAX_BBOX_DIMENSION))
- #     new_y_min = max(0, min(new_y_min, MAX_BBOX_DIMENSION))
- #     new_x_max = max(0, min(new_x_max, MAX_BBOX_DIMENSION))
- #     new_y_max = max(0, min(new_y_max, MAX_BBOX_DIMENSION))
- #
- #     # Safety check
- #     if new_x_min > new_x_max: new_x_min = new_x_max
- #     if new_y_min > new_y_max: new_y_min = new_y_max
- #
- #     return [new_x_min, new_y_min, new_x_max, new_y_max]
- #
- #
- # def augment_sample(sample):
- #     """
- #     Generates a new sample by translating all bounding boxes.
- #     """
- #     shift_x = random.randint(-MAX_SHIFT, MAX_SHIFT)
- #     shift_y = random.randint(-MAX_SHIFT, MAX_SHIFT)
- #
- #     new_sample = sample.copy()
- #
- #     # Ensure tokens and labels are copied (they remain unchanged)
- #     new_sample["tokens"] = sample["tokens"]
- #     new_sample["labels"] = sample["labels"]
- #
- #     # Translate all bounding boxes
- #     new_bboxes = [translate_bbox(bbox, shift_x, shift_y) for bbox in sample["bboxes"]]
- #     new_sample["bboxes"] = new_bboxes
- #
- #     return new_sample
- #
- #
- # def augment_and_save_dataset(input_json_path, output_json_path):
- #     """
- #     Loads preprocessed data, performs augmentation, and saves the result.
- #     """
- #     print(f"🔄 Loading preprocessed data from {input_json_path} for augmentation...")
- #     with open(input_json_path, 'r', encoding="utf-8") as f:
- #         training_data = json.load(f)
- #
- #     augmented_data = []
- #     original_count = len(training_data)
- #
- #     print(f"🔄 Starting augmentation (Factor: {AUGMENTATION_FACTOR}, {original_count} documents)...")
- #
- #     for i, original_sample in enumerate(training_data):
- #         # 1. Add the original sample
- #         augmented_data.append(original_sample)
- #
- #         # 2. Generate augmented samples
- #         for _ in range(AUGMENTATION_FACTOR):
- #             if "tokens" in original_sample and "labels" in original_sample and "bboxes" in original_sample:
- #                 augmented_data.append(augment_sample(original_sample))
- #             else:
- #                 print(f"Warning: Skipping augmentation for sample {i} due to missing keys.")
- #
- #     augmented_count = len(augmented_data)
- #     print(f"Dataset Augmentation: Original samples: {original_count}, Total samples: {augmented_count}")
- #
- #     # Save the augmented dataset
- #     with open(output_json_path, 'w', encoding="utf-8") as f:
- #         json.dump(augmented_data, f, indent=2, ensure_ascii=False)
- #
- #     print(f"✅ Augmented data saved to {output_json_path}")
- #     return output_json_path
- #
- #
- # # -------------------------
- # # Step 2: Dataset Class (Unchanged)
- # # -------------------------
- # class LayoutDataset(Dataset):
- #     def __init__(self, json_path, tokenizer, label2id, max_len=512):
- #         with open(json_path, "r", encoding="utf-8") as f:
- #             self.data = json.load(f)
- #         self.tokenizer = tokenizer
- #         self.label2id = label2id
- #         self.max_len = max_len
- #
- #     def __len__(self):
- #         return len(self.data)
- #
- #     def __getitem__(self, idx):
- #         item = self.data[idx]
- #         words, bboxes, labels = item["tokens"], item["bboxes"], item["labels"]
- #
- #         # Tokenize
- #         encodings = self.tokenizer(
- #             words,
- #             boxes=bboxes,
- #             padding="max_length",
- #             truncation=True,
- #             max_length=self.max_len,
- #             return_offsets_mapping=True,
- #             return_tensors="pt"
- #         )
- #
- #         # Align labels to word pieces
- #         word_ids = encodings.word_ids(batch_index=0)
- #         label_ids = []
- #         for word_id in word_ids:
- #             if word_id is None:
- #                 label_ids.append(self.label2id["O"]) # [CLS], [SEP], padding
- #             else:
- #                 label_ids.append(self.label2id.get(labels[word_id], self.label2id["O"]))
- #
- #         encodings.pop("offset_mapping")
- #         encodings["labels"] = torch.tensor(label_ids)
- #
- #         return {key: val.squeeze(0) for key, val in encodings.items()}
- #
- #
- # # -------------------------
- # # Step 3: Model Architecture (Unchanged)
- # # -------------------------
- # class LayoutLMv3CRF(nn.Module):
- #     def __init__(self, model_name, num_labels):
- #         super().__init__()
- #         self.layoutlm = LayoutLMv3Model.from_pretrained(model_name)
- #         # self.layoutlm = LayoutLMv3Model.from_pretrained("heerjtdev/edugenius")
- #         self.dropout = nn.Dropout(0.1)
- #         self.classifier = nn.Linear(self.layoutlm.config.hidden_size, num_labels)
- #         self.crf = CRF(num_labels)
- #
- #     def forward(self, input_ids, bbox, attention_mask, labels=None):
- #         outputs = self.layoutlm(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask)
- #         sequence_output = self.dropout(outputs.last_hidden_state)
- #         emissions = self.classifier(sequence_output)
- #
- #         if labels is not None:
- #             # Training mode: calculate loss
- #             log_likelihood = self.crf(emissions, labels, mask=attention_mask.bool())
- #             return -log_likelihood.mean()
- #         else:
- #             # Inference mode: decode best path
- #             best_paths = self.crf.viterbi_decode(emissions, mask=attention_mask.bool())
- #             return best_paths
- #
- #
- # # -------------------------
- # # Step 4: Training + Evaluation (Unchanged)
- # # -------------------------
- # def train_one_epoch(model, dataloader, optimizer, device):
- #     model.train()
- #     total_loss = 0
- #     for batch in tqdm(dataloader, desc="Training"):
- #         batch = {k: v.to(device) for k, v in batch.items()}
- #         labels = batch.pop("labels")
- #         optimizer.zero_grad()
- #         loss = model(**batch, labels=labels)
- #         loss.backward()
- #         optimizer.step()
- #         total_loss += loss.item()
- #     return total_loss / len(dataloader)
- #
- #
- # def evaluate(model, dataloader, device, id2label):
- #     model.eval()
- #     all_preds, all_labels = [], []
- #     with torch.no_grad():
- #         for batch in tqdm(dataloader, desc="Evaluating"):
- #             batch = {k: v.to(device) for k, v in batch.items()}
- #             labels = batch.pop("labels").cpu().numpy()
- #             preds = model(**batch)
- #             for p, l, mask in zip(preds, labels, batch["attention_mask"].cpu().numpy()):
- #                 valid = mask == 1
- #                 l = l[valid].tolist()
- #                 all_labels.extend(l)
- #                 all_preds.extend(p[:len(l)])
- #
- #     # Exclude the "O" label and other special tokens if necessary, but using 'micro' average
- #     # on all valid tokens is typically fine for the initial evaluation.
- #     precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average="micro", zero_division=0)
- #     return precision, recall, f1
- #
- #
- # # -------------------------
- # # Step 5: Main Pipeline (Training) - MODIFIED LABELS + FILE PATH FIX
- # # -------------------------
- # def main(args):
- #     # LABELS UPDATED: Added SECTION_HEADING and PASSAGE
- #     labels = [
- #         "O",
- #         "B-QUESTION", "I-QUESTION",
- #         "B-OPTION", "I-OPTION",
- #         "B-ANSWER", "I-ANSWER",
- #         "B-SECTION_HEADING", "I-SECTION_HEADING",
- #         "B-PASSAGE", "I-PASSAGE"
- #     ]
- #     label2id = {l: i for i, l in enumerate(labels)}
- #     id2label = {i: l for l, i in label2id.items()}
- #
- #     # --- FIX for FileNotFoundError: Use a temporary directory for intermediate files ---
- #     TEMP_DIR = "temp_intermediate_files"
- #     os.makedirs(TEMP_DIR, exist_ok=True)
- #     print(f"\n--- SETUP PHASE: Created temp directory: {TEMP_DIR} ---")
- #
- #     # 1. Preprocess and save the initial training data
- #     print("\n--- START PHASE: PREPROCESSING ---")
- #
- #     # FIX: Prepend the directory path to the file name
- #     initial_bio_json = os.path.join(TEMP_DIR, "training_data_bio_bboxes.json")
- #     preprocess_labelstudio(args.input, initial_bio_json)
- #
- #     # 2. Augment the dataset with translated bboxes
- #     print("\n--- START PHASE: AUGMENTATION ---")
- #
- #     # FIX: Prepend the directory path to the file name
- #     augmented_bio_json = os.path.join(TEMP_DIR, "augmented_training_data_bio_bboxes.json")
- #     final_data_path = augment_and_save_dataset(initial_bio_json, augmented_bio_json)
- #
- #     # Clean up the intermediary file (optional)
- #     # import shutil
- #     # shutil.rmtree(TEMP_DIR)
- #
- #     # 3. Load and split augmented dataset
- #     print("\n--- START PHASE: MODEL/DATASET SETUP ---")
- #     #MODEL_ID = "heerjtdev/edugenius"
- #     tokenizer = LayoutLMv3TokenizerFast.from_pretrained("microsoft/layoutlmv3-base")
- #     #tokenizer = LayoutLMv3TokenizerFast.from_pretrained(MODEL_ID)
- #
- #     dataset = LayoutDataset(final_data_path, tokenizer, label2id, max_len=args.max_len)
- #     val_size = int(0.2 * len(dataset))
- #     train_size = len(dataset) - val_size
- #
- #     # Use a fixed seed for reproducibility in split
- #     torch.manual_seed(42)
- #     train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
- #     train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
- #     val_loader = DataLoader(val_dataset, batch_size=args.batch_size)
- #
- #     # 4. Initialize and load model
- #     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- #     print(f"Using device: {device}")
- #     # Num_labels is based on the updated 'labels' list
- #     model = LayoutLMv3CRF("microsoft/layoutlmv3-base", num_labels=len(labels)).to(device)
- #     # model = LayoutLMv3CRF(MODEL_ID, num_labels=len(labels)).to(device)
- #     ckpt_path = "checkpoints/layoutlmv3_crf_passage.pth"
- #     os.makedirs("checkpoints", exist_ok=True)
- #     if os.path.exists(ckpt_path):
- #         # NOTE: Loading an old checkpoint will likely fail now because num_labels has changed,
- #         # unless the old checkpoint had the *exact* same number of labels.
- #         # It is recommended to start training from scratch.
- #         # print(f"🔄 Loading checkpoint from {ckpt_path}")
- #         # model.load_state_dict(torch.load(ckpt_path, map_location=device))
- #         print(f"⚠️ Starting fresh training. Old checkpoint {ckpt_path} may be incompatible with new label count.")
- #
- #     optimizer = AdamW(model.parameters(), lr=args.lr)
- #
- #     # 5. Training loop
- #     for epoch in range(args.epochs):
- #         print(f"\n--- START PHASE: EPOCH {epoch + 1}/{args.epochs} TRAINING ---")
- #         avg_loss = train_one_epoch(model, train_loader, optimizer, device)
- #
- #         print(f"\n--- START PHASE: EPOCH {epoch + 1}/{args.epochs} EVALUATION ---")
- #         precision, recall, f1 = evaluate(model, val_loader, device, id2label)
- #
- #         print(
- #             f"Epoch {epoch + 1}/{args.epochs} | Loss: {avg_loss:.4f} | P: {precision:.3f} R: {recall:.3f} F1: {f1:.3f}")
- #         torch.save(model.state_dict(), ckpt_path)
- #         print(f"💾 Model saved at {ckpt_path}")
- #
- #
- #
- #
- # # -------------------------
- # # Step 7: Main Execution (Unchanged)
- # # -------------------------
- # if __name__ == "__main__":
- #     parser = argparse.ArgumentParser(description="LayoutLMv3 Fine-tuning and Inference Script.")
- #     parser.add_argument("--mode", type=str, required=True, choices=["train", "infer"],
- #                         help="Select mode: 'train' or 'infer'")
- #     parser.add_argument("--input", type=str, help="Path to input file (Label Studio JSON for train, PDF for infer).")
- #     parser.add_argument("--batch_size", type=int, default=4)
- #     parser.add_argument("--epochs", type=int, default=5)
- #     parser.add_argument("--lr", type=float, default=5e-5)
- #     parser.add_argument("--max_len", type=int, default=512)
- #     args = parser.parse_args()
- #
- #     if args.mode == "train":
- #         if not args.input:
- #             parser.error("--input is required for 'train' mode.")
- #         main(args)
-
-
- # import json
- # import argparse
- # import os
- # import random
- # import torch
- # import torch.nn as nn
- # from torch.utils.data import Dataset, DataLoader, random_split
- # # Using LayoutLMv3TokenizerFast, LayoutLMv3Model
- # from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Model
- # from transformers.utils import cached_file
- # from safetensors.torch import load_file
- # from TorchCRF import CRF
- # from torch.optim import AdamW
- # from tqdm import tqdm
- # from sklearn.metrics import precision_recall_fscore_support
- #
- # # --- Configuration for Augmentation ---
- # MAX_BBOX_DIMENSION = 1000
- # MAX_SHIFT = 30
- # AUGMENTATION_FACTOR = 1
- #
- # # -------------------------------------
- #
- # # --- Hugging Face Model ID ---
- # HF_MODEL_ID = "heerjtdev/edugenius"
- #
- #
- # # -----------------------------
- #
- #
- # # -------------------------
- # # Step 1: Preprocessing (Label Studio → BIO + bboxes)
- # # -------------------------
- # def preprocess_labelstudio(input_path, output_path):
- #     with open(input_path, "r", encoding="utf-8") as f:
- #         data = json.load(f)
- #
- #     processed = []
- #     total_items = len(data) # Added for potential verbose logging
- #     print(f"🔄 Starting preprocessing of {total_items} documents...")
- #
- #     for item in data:
- #         words = item["data"]["original_words"]
- #         bboxes = item["data"]["original_bboxes"]
- #         labels = ["O"] * len(words)
- #
- #         if "annotations" in item:
- #             for ann in item["annotations"]:
- #                 for res in ann["result"]:
- #                     # Check if the result item is a span annotation
- #                     if "value" in res and "labels" in res["value"]:
- #                         text = res["value"]["text"]
- #                         tag = res["value"]["labels"][0]
- #                         # Some tokenizers may split words, so we must find a consecutive word match.
- #                         text_tokens = text.split()
- #
- #                         for i in range(len(words) - len(text_tokens) + 1):
- #                             if words[i:i + len(text_tokens)] == text_tokens:
- #                                 labels[i] = f"B-{tag}"
- #                                 for j in range(1, len(text_tokens)):
- #                                     labels[i + j] = f"I-{tag}"
- #                                 break # Move to next annotation if a match is found
- #
- #         processed.append({"tokens": words, "labels": labels, "bboxes": bboxes})
- #
- #     with open(output_path, "w", encoding="utf-8") as f:
- #         json.dump(processed, f, indent=2, ensure_ascii=False)
- #
- #     print(f"✅ Preprocessed data saved to {output_path}")
- #     return output_path
- #
- #
- # # -------------------------
- # # Step 1.5: Bounding Box Augmentation
- # # -------------------------
- #
- # def translate_bbox(bbox, shift_x, shift_y):
- #     """
- #     Translates a single bounding box [x_min, y_min, x_max, y_max] by (shift_x, shift_y)
- #     and clamps the coordinates to the valid range [0, MAX_BBOX_DIMENSION].
- #     """
- #     x_min, y_min, x_max, y_max = bbox
- #
- #     new_x_min = x_min + shift_x
- #     new_y_min = y_min + shift_y
- #     new_x_max = x_max + shift_x
- #     new_y_max = y_max + shift_y
- #
- #     # Clamp the new coordinates
- #     new_x_min = max(0, min(new_x_min, MAX_BBOX_DIMENSION))
- #     new_y_min = max(0, min(new_y_min, MAX_BBOX_DIMENSION))
- #     new_x_max = max(0, min(new_x_max, MAX_BBOX_DIMENSION))
- #     new_y_max = max(0, min(new_y_max, MAX_BBOX_DIMENSION))
- #
- #     # Safety check
- #     if new_x_min > new_x_max: new_x_min = new_x_max
- #     if new_y_min > new_y_max: new_y_min = new_y_max
- #
- #     return [new_x_min, new_y_min, new_x_max, new_y_max]
- #
- #
- # def augment_sample(sample):
- #     """
- #     Generates a new sample by translating all bounding boxes.
- #     """
- #     shift_x = random.randint(-MAX_SHIFT, MAX_SHIFT)
- #     shift_y = random.randint(-MAX_SHIFT, MAX_SHIFT)
- #
- #     new_sample = sample.copy()
- #
- #     # Ensure tokens and labels are copied (they remain unchanged)
- #     new_sample["tokens"] = sample["tokens"]
- #     new_sample["labels"] = sample["labels"]
- #
- #     # Translate all bounding boxes
- #     new_bboxes = [translate_bbox(bbox, shift_x, shift_y) for bbox in sample["bboxes"]]
- #     new_sample["bboxes"] = new_bboxes
- #
- #     return new_sample
- #
- #
- # def augment_and_save_dataset(input_json_path, output_json_path):
- #     """
- #     Loads preprocessed data, performs augmentation, and saves the result.
- #     """
- #     print(f"🔄 Loading preprocessed data from {input_json_path} for augmentation...")
- #     with open(input_json_path, 'r', encoding="utf-8") as f:
- #         training_data = json.load(f)
- #
- #     augmented_data = []
- #     original_count = len(training_data)
- #
- #     print(f"🔄 Starting augmentation (Factor: {AUGMENTATION_FACTOR}, {original_count} documents)...")
- #
- #     for i, original_sample in enumerate(training_data):
- #         # 1. Add the original sample
- #         augmented_data.append(original_sample)
- #
- #         # 2. Generate augmented samples
- #         for _ in range(AUGMENTATION_FACTOR):
- #             if "tokens" in original_sample and "labels" in original_sample and "bboxes" in original_sample:
- #                 augmented_data.append(augment_sample(original_sample))
- #             else:
- #                 print(f"Warning: Skipping augmentation for sample {i} due to missing keys.")
- #
- #     augmented_count = len(augmented_data)
- #     print(f"Dataset Augmentation: Original samples: {original_count}, Total samples: {augmented_count}")
- #
- #     # Save the augmented dataset
- #     with open(output_json_path, 'w', encoding="utf-8") as f:
- #         json.dump(augmented_data, f, indent=2, ensure_ascii=False)
- #
- #     print(f"✅ Augmented data saved to {output_json_path}")
- #     return output_json_path
- #
- #
- # # -------------------------
- # # Step 2: Dataset Class
- # # -------------------------
- # class LayoutDataset(Dataset):
- #     def __init__(self, json_path, tokenizer, label2id, max_len=512):
- #         with open(json_path, "r", encoding="utf-8") as f:
- #             self.data = json.load(f)
- #         self.tokenizer = tokenizer
- #         self.label2id = label2id
- #         self.max_len = max_len
- #
- #     def __len__(self):
- #         return len(self.data)
- #
- #     def __getitem__(self, idx):
- #         item = self.data[idx]
- #         words, bboxes, labels = item["tokens"], item["bboxes"], item["labels"]
- #
- #         # Tokenize
- #         encodings = self.tokenizer(
- #             words,
- #             boxes=bboxes,
- #             padding="max_length",
- #             truncation=True,
- #             max_length=self.max_len,
- #             return_offsets_mapping=True,
- #             return_tensors="pt"
- #         )
- #
- #         # Align labels to word pieces
- #         word_ids = encodings.word_ids(batch_index=0)
- #         label_ids = []
- #         for word_id in word_ids:
- #             if word_id is None:
- #                 label_ids.append(self.label2id["O"]) # [CLS], [SEP], padding
- #             else:
- #                 label_ids.append(self.label2id.get(labels[word_id], self.label2id["O"]))
- #
- #         encodings.pop("offset_mapping")
- #         encodings["labels"] = torch.tensor(label_ids)
- #
- #         return {key: val.squeeze(0) for key, val in encodings.items()}
- #
- #
- # # -------------------------
- # # Step 3: Model Architecture (PATCHED TO LOAD WEIGHTS CORRECTLY)
- # # -------------------------
- # class LayoutLMv3CRF(nn.Module):
- #     def __init__(self, model_name, num_labels, device):
- #         super().__init__()
- #
- #         # 1. Initialize the LayoutLMv3 model using the base class
- #         # We start by initializing from the base configuration to ensure all weights are present
- #         self.layoutlm = LayoutLMv3Model.from_pretrained("microsoft/layoutlmv3-base")
- #
- #         # 2. Try to load the fine-tuned weights from the Hugging Face Hub/Cache
- #         try:
- #             # This resolves the path to the downloaded model.safetensors in the cache
- #             # Assumes you have renamed your file on the Hugging Face Hub to 'model.safetensors'
- #             weights_path = cached_file(model_name, "model.safetensors")
- #             fine_tuned_weights = load_file(weights_path)
- #
- #             # 3. Strip the Mismatching Prefix (Assuming 'layoutlm.' prefix from a previous wrapper)
- #             new_state_dict = {}
- #             prefix_to_strip = "layoutlm."
- #
- #             for key, value in fine_tuned_weights.items():
- #                 if key.startswith(prefix_to_strip):
- #                     new_key = key[len(prefix_to_strip):]
- #                     new_state_dict[new_key] = value
- #                 else:
- #                     new_state_dict[key] = value
- #
- #             # 4. Load the fixed state dictionary into the LayoutLMv3Model
- #             # strict=False allows us to ignore classifier/CRF weights not in LayoutLMv3Model
- #             print("🔄 Successfully loaded and stripped keys. Loading base LayoutLMv3 weights...")
- #
- #             # Load only the weights for the transformer body
- #             missing_keys, unexpected_keys = self.layoutlm.load_state_dict(new_state_dict, strict=False)
- #
- #             print(f"Weights loading done: {len(missing_keys)} missing, {len(unexpected_keys)} unexpected keys.")
- #
- #         except Exception as e:
- #             print(f"❌ Fine-tuned weights could not be loaded directly and mapped. Starting with random weights.")
- #             print(f"Error: {e}")
- #             # Fallback: Load the LayoutLMv3 component directly from the Hub ID (will result in random weights for layers)
- #             self.layoutlm = LayoutLMv3Model.from_pretrained(model_name)
- #
- #         # 5. Initialize the new heads (CRF layer and Classifier)
- #         self.dropout = nn.Dropout(0.1)
- #         self.classifier = nn.Linear(self.layoutlm.config.hidden_size, num_labels)
- #         self.crf = CRF(num_labels)
- #
- #     def forward(self, input_ids, bbox, attention_mask, labels=None):
- #         outputs = self.layoutlm(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask)
- #         sequence_output = self.dropout(outputs.last_hidden_state)
- #         emissions = self.classifier(sequence_output)
- #
- #         if labels is not None:
- #             # Training mode: calculate loss
- #             log_likelihood = self.crf(emissions, labels, mask=attention_mask.bool())
- #             return -log_likelihood.mean()
- #         else:
- #             # Inference mode: decode best path
- #             best_paths = self.crf.viterbi_decode(emissions, mask=attention_mask.bool())
- #             return best_paths
- #
- #
- # # -------------------------
- # # Step 4: Training + Evaluation
- # # -------------------------
- # def train_one_epoch(model, dataloader, optimizer, device):
- #     model.train()
- #     total_loss = 0
- #     for batch in tqdm(dataloader, desc="Training"):
- #         batch = {k: v.to(device) for k, v in batch.items()}
- #         labels = batch.pop("labels")
- #         optimizer.zero_grad()
- #         loss = model(**batch, labels=labels)
- #         loss.backward()
- #         optimizer.step()
- #         total_loss += loss.item()
- #     return total_loss / len(dataloader)
- #
- #
- # def evaluate(model, dataloader, device, id2label):
- #     model.eval()
- #     all_preds, all_labels = [], []
- #     with torch.no_grad():
- #         for batch in tqdm(dataloader, desc="Evaluating"):
- #             batch = {k: v.to(device) for k, v in batch.items()}
- #             labels = batch.pop("labels").cpu().numpy()
- #             # The model returns a list of lists of predicted labels in inference mode
- #             preds = model(**batch)
- #             for p, l, mask in zip(preds, labels, batch["attention_mask"].cpu().numpy()):
- #                 valid = mask == 1
- #                 l = l[valid].tolist()
- #                 all_labels.extend(l)
- #                 # Ensure pred length matches label length for the unmasked tokens
- #                 all_preds.extend(p[:len(l)])
- #
- #     # Exclude the "O" label and other special tokens if necessary, but using 'micro' average
- #     precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average="micro", zero_division=0)
- #     return precision, recall, f1
- #
- #
- # # -------------------------
- # # Step 5: Main Pipeline (Training) - MODIFIED MODEL/TOKENIZER LOADING
- # # -------------------------
- # def main(args):
- #     # LABELS UPDATED: Added SECTION_HEADING and PASSAGE
- #     labels = [
- #         "O",
- #         "B-QUESTION", "I-QUESTION",
- #         "B-OPTION", "I-OPTION",
- #         "B-ANSWER", "I-ANSWER",
- #         "B-SECTION_HEADING", "I-SECTION_HEADING",
- #         "B-PASSAGE", "I-PASSAGE"
- #     ]
- #     label2id = {l: i for i, l in enumerate(labels)}
- #     id2label = {i: l for l, i in label2id.items()}
- #
- #     # --- SETUP: Use a temporary directory for intermediate files ---
- #     TEMP_DIR = "temp_intermediate_files"
- #     os.makedirs(TEMP_DIR, exist_ok=True)
- #     print(f"\n--- SETUP PHASE: Created temp directory: {TEMP_DIR} ---")
- #
- #     # 1. Preprocess
- #     print("\n--- START PHASE: PREPROCESSING ---")
- #     initial_bio_json = os.path.join(TEMP_DIR, "training_data_bio_bboxes.json")
- #     preprocess_labelstudio(args.input, initial_bio_json)
- #
- #     # 2. Augment
- #     print("\n--- START PHASE: AUGMENTATION ---")
- #     augmented_bio_json = os.path.join(TEMP_DIR, "augmented_training_data_bio_bboxes.json")
- #     final_data_path = augment_and_save_dataset(initial_bio_json, augmented_bio_json)
- #
- #     # 3. Load and split augmented dataset
- #     print("\n--- START PHASE: MODEL/DATASET SETUP ---")
- #
- #     # Load tokenizer from the specified Hugging Face ID
- #     tokenizer = LayoutLMv3TokenizerFast.from_pretrained(HF_MODEL_ID)
- #
- #     dataset = LayoutDataset(final_data_path, tokenizer, label2id, max_len=args.max_len)
- #     val_size = int(0.2 * len(dataset))
- #     train_size = len(dataset) - val_size
- #
- #     # Use a fixed seed for reproducibility in split
- #     torch.manual_seed(42)
- #     train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
- #     train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
- #     val_loader = DataLoader(val_dataset, batch_size=args.batch_size)
- #
- #     # 4. Initialize and load model
- #     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- #     print(f"Using device: {device}")
- #
- #     # Pass the Hugging Face ID and device to the custom model wrapper
- #     model = LayoutLMv3CRF(HF_MODEL_ID, num_labels=len(labels), device=device).to(device)
- #
- #     ckpt_path = "checkpoints/layoutlmv3_crf_passage.pth"
- #     os.makedirs("checkpoints", exist_ok=True)
- #     if os.path.exists(ckpt_path):
- #         print(f"⚠️ Starting fresh training. Old checkpoint {ckpt_path} may be incompatible with new label count.")
- #
- #     optimizer = AdamW(model.parameters(), lr=args.lr)
- #
- #     # 5. Training loop
- #     for epoch in range(args.epochs):
- #         print(f"\n--- START PHASE: EPOCH {epoch + 1}/{args.epochs} TRAINING ---")
- #         avg_loss = train_one_epoch(model, train_loader, optimizer, device)
- #
- #         print(f"\n--- START PHASE: EPOCH {epoch + 1}/{args.epochs} EVALUATION ---")
- #         precision, recall, f1 = evaluate(model, val_loader, device, id2label)
- #
- #         print(
- #             f"Epoch {epoch + 1}/{args.epochs} | Loss: {avg_loss:.4f} | P: {precision:.3f} R: {recall:.3f} F1: {f1:.3f}")
- #         torch.save(model.state_dict(), ckpt_path)
- #         print(f"💾 Model saved at {ckpt_path}")
- #
- #
- # # -------------------------
- # # Step 7: Main Execution
- # # -------------------------
- # if __name__ == "__main__":
- #     parser = argparse.ArgumentParser(description="LayoutLMv3 Fine-tuning and Inference Script.")
- #     parser.add_argument("--mode", type=str, required=True, choices=["train", "infer"],
- #                         help="Select mode: 'train' or 'infer'")
- #     parser.add_argument("--input", type=str, help="Path to input file (Label Studio JSON for train, PDF for infer).")
- #     parser.add_argument("--batch_size", type=int, default=4)
- #     parser.add_argument("--epochs", type=int, default=5)
- #     parser.add_argument("--lr", type=float, default=5e-5)
- #     parser.add_argument("--max_len", type=int, default=512)
- #     args = parser.parse_args()
- #
- #     if args.mode == "train":
- #         if not args.input:
- #             parser.error("--input is required for 'train' mode.")
- #         main(args)
-
 
 import json
 import argparse