# Uploaded by deepsodha ("Upload 25 files", commit beb5479, verified).
import json
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from shared.metrics import compute_rouge, compute_bleu, factuality_score
from shared.utils import print_banner
import torch
def run_eval_for_model(
    model_name,
    dataset,
    *,
    max_new_tokens=128,
    question_key="question",
    answer_key="answer",
):
    """Generate an answer for every dataset row with a seq2seq model and score it.

    Args:
        model_name: Model id or path passed to ``AutoTokenizer.from_pretrained``
            and ``AutoModelForSeq2SeqLM.from_pretrained``.
        dataset: Iterable of rows. Each row must be indexable by
            ``question_key`` (the model input) and ``answer_key`` (the reference).
        max_new_tokens: Maximum number of tokens generated per example.
        question_key: Row field used as the model prompt.
        answer_key: Row field used as the reference answer.

    Returns:
        A dict containing ``"model": model_name`` merged with the results of
        ``compute_rouge``, ``compute_bleu`` and ``factuality_score``. These are
        presumably dicts, since they are ``**``-unpacked; TODO confirm that their
        keys do not collide, because later entries would silently overwrite
        earlier ones.
    """
    print_banner(f"Evaluating {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    preds, refs = [], []
    # Inference only: disable autograd once for the whole loop, not once per row.
    with torch.no_grad():
        for row in dataset:
            inputs = tokenizer(
                row[question_key], return_tensors="pt", truncation=True
            )
            # Keep the input tensors on whatever device the model weights live on.
            inputs = inputs.to(model.device)
            outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
            # Batch size is 1, so the single generated sequence is outputs[0].
            preds.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
            refs.append(row[answer_key])

    rouge = compute_rouge(preds, refs)
    bleu = compute_bleu(preds, refs)
    fact = factuality_score(preds, refs)
    return {"model": model_name, **rouge, **bleu, **fact}
def evaluate_all(
    *,
    config_path="config.yaml",
    data_files="datasets/retail_sample.jsonl",
    split="train[:50]",
    output_path="models/retail_eval_results.json",
):
    """Evaluate every model listed in the YAML config and save the results as JSON.

    Args:
        config_path: YAML config read via ``load_yaml_config``. It must contain a
            ``"models"`` entry listing the model names to evaluate.
        data_files: JSON-lines file loaded with ``datasets.load_dataset("json")``.
        split: Split expression passed to ``load_dataset``. The default takes the
            first 50 rows of ``train``.
        output_path: Destination JSON file. Its parent directory is created if
            it is missing.

    Returns:
        The list of per-model result dicts, in config order.
    """
    from pathlib import Path

    from shared.utils import load_yaml_config

    cfg = load_yaml_config(config_path)
    dataset = load_dataset("json", data_files=data_files, split=split)
    results = [run_eval_for_model(m, dataset) for m in cfg["models"]]

    out = Path(output_path)
    # The original assumed "models/" already existed; create it so the write cannot fail.
    out.parent.mkdir(parents=True, exist_ok=True)
    # Use a context manager so the file is flushed and closed deterministically.
    with out.open("w", encoding="utf-8") as fh:
        json.dump(results, fh, indent=2)
    print(f"✅ Saved results to {output_path}")
    return results
if "__main__" == __name__:
    # Script entry point: run the full evaluation sweep. Its return value is
    # unused here; results are persisted to disk by evaluate_all itself.
    evaluate_all()