# train.py
# --------------------------------------------------
# Training script with:
#   * bug-free evaluation (no IndexError)
#   * faster throughput (pre-tokenised data, DataLoader workers)
#   * higher GPU utilisation (larger batch, torch.compile, TF32)
# --------------------------------------------------
import time

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)

from mcqa_dataset import MCQADataset
from mcqa_bert import MCQABERT


# --------------------------------------------------------------------------
def run_split(model, loader):
    """
    Returns: (binary CLS accuracy, 4-way accuracy)
    Bug fixed: no more IndexError.
    """
    model.eval()
    bin_correct = tot = 0
    per_qid = {}  # {qid: [(cid, logit, gold_flag), ...]}
    with torch.no_grad():
        for batch in loader:
            ids = batch["input_ids"].cuda()
            mask = batch["attention_mask"].cuda()
            label = batch["label"].cuda()
            logits = model(ids, mask)                     # shape (B,)
            preds = (torch.sigmoid(logits) > 0.5).long()
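            # sigmoid(logit) > 0.5 is the same test as logit > 0: each
            # (question, choice) pair is scored as an independent binary decision.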
            bin_correct += (preds == label).sum().item()
            tot += len(label)
            # stash logits for the 4-way metric
            for qid, cid, logit, gold_flag in zip(
                batch["qid"], batch["cid"], logits.cpu(), batch["label"]
            ):
                per_qid.setdefault(qid, []).append(
                    (cid, logit.item(), gold_flag.item())
                )
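    # 4-way accuracy: group each question's options and predict the choice
    # with the highest logit. Assumes every qid has exactly one gold option.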
    correct4 = 0
    for qid, opts in per_qid.items():
        pred_cid = max(opts, key=lambda x: x[1])[0]            # highest logit
        gold_cid = [cid for cid, _, flag in opts if flag == 1][0]
        if pred_cid == gold_cid:
            correct4 += 1
    return bin_correct / tot, correct4 / len(per_qid)


# --------------------------------------------------------------------------
def main():
    # ---------------- data -----------------
    tok = AutoTokenizer.from_pretrained("bert-base-uncased")
    train_ds = MCQADataset("train_complete.jsonl", tok)
    val_ds = MCQADataset("valid_complete.jsonl", tok)
    test_ds = MCQADataset("test_complete.jsonl", tok)
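    # pin_memory=True plus .cuda(non_blocking=True) in the training loop lets
    # host-to-device copies overlap with compute; persistent_workers=True keeps
    # the worker processes alive across epochs instead of respawning them.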
    train_loader = DataLoader(
        train_ds, batch_size=64, shuffle=True,
        num_workers=4, pin_memory=True, persistent_workers=True
    )
    val_loader = DataLoader(
        val_ds, batch_size=128, num_workers=4,
        pin_memory=True, persistent_workers=True
    )
    test_loader = DataLoader(
        test_ds, batch_size=128, num_workers=4,
        pin_memory=True, persistent_workers=True
    )
    # ---------------- model -----------------
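    # TF32 runs fp32 matmuls with a 10-bit mantissa on Ampere-or-newer GPUs,
    # a sizeable speed-up for a precision cost that is negligible here.
    # Setting both flags below is belt-and-braces: "high" already enables
    # TF32 matmuls on recent PyTorch versions.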
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.set_float32_matmul_precision("high")
    model = MCQABERT().cuda()
    # Optional: compile for extra speed (PyTorch >= 2.0)
    if hasattr(torch, "compile"):
        model = torch.compile(model)
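    # NOTE: torch.compile captures and optimises the graph on the first
    # forward/backward pass, so the first training step is noticeably slower.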
    # AdamW: use the fused CUDA kernel when this PyTorch build supports it
    fused_ok = "fused" in torch.optim.AdamW.__init__.__code__.co_varnames
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=3e-5, fused=fused_ok
    )
    total_steps = len(train_loader) * 3          # 3 epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer, int(0.1 * total_steps), total_steps
    )
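    # LR rises linearly from 0 over the first 10% of steps, then decays
    # linearly back to 0 over the remainder.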
    # ---------------- training -----------------
    for epoch in range(1, 4):
        model.train()
        t0 = time.time()
        for batch in train_loader:
            ids = batch["input_ids"].cuda(non_blocking=True)
            mask = batch["attention_mask"].cuda(non_blocking=True)
            label = batch["label"].float().cuda(non_blocking=True)
            logits = model(ids, mask)
            loss = F.binary_cross_entropy_with_logits(logits, label)
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)
        dur = time.time() - t0
        bin_acc, mc_acc = run_split(model, val_loader)
        print(f"Epoch {epoch}: "
              f"val-CLS={bin_acc:.3f} | val-4way={mc_acc:.3f} "
              f"| time={dur/60:.1f} min")
    # ---------------- test -----------------
    _, test_acc = run_split(model, test_loader)
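    # max_memory_allocated() reports the peak memory occupied by tensors since
    # process start, i.e. the largest single-step footprint of training.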
    mem = torch.cuda.max_memory_allocated() / (1024 ** 3)
    print(f"Test 4-way accuracy = {test_acc:.3f}")
    print(f"Peak GPU memory = {mem:.1f} GB")


if __name__ == "__main__":
    main()