import os
import glob
import re
import json
import torch
import torch.utils.data
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
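
# Load the ChatGLM2-6B tokenizer and model; the model runs in bfloat16 on a single GPU.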
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).bfloat16().cuda()
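
# Token ids of the four option letters; answers are scored by comparing their logits.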
choices = ["A", "B", "C", "D"]
choice_tokens = [tokenizer.encode(choice, add_special_tokens=False)[0] for choice in choices]
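

# Wrap a question in ChatGLM2's single-turn chat template ("问" = question, "答" = answer).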
def build_prompt(text):
    return "[Round {}]\n\n问:{}\n\n答:".format(1, text)
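

# Prompt appended after the model's free-form answer; in English:
# "In summary, the correct option among A, B, C and D is:"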
extraction_prompt = '综上所述,ABCD中正确的选项是:'

accuracy_dict, count_dict = {}, {}
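
# Evaluate every C-Eval validation split (one JSONL file per subject).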
with torch.no_grad():
    for entry in glob.glob("./CEval/val/**/*.jsonl", recursive=True):
        dataset = []
        with open(entry, encoding='utf-8') as file:
            for line in file:
                dataset.append(json.loads(line))
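        # Batch the examples; the default collate_fn collates each field across the batch
        # (strings become lists, integer labels become a tensor).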
        correct = 0
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
        for batch in tqdm(dataloader):
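            # Stage 1: let the model answer each question free-form with greedy decoding.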
            texts = batch["inputs_pretokenized"]
            queries = [build_prompt(query) for query in texts]
            inputs = tokenizer(queries, padding=True, return_tensors="pt", truncation=True, max_length=2048).to('cuda')
            outputs = model.generate(**inputs, do_sample=False, max_new_tokens=512)
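            # Keep only the newly generated tokens, dropping the echoed prompt.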
            intermediate_outputs = []
            for idx in range(len(outputs)):
                output = outputs.tolist()[idx][len(inputs["input_ids"][idx]):]
                response = tokenizer.decode(output)
                intermediate_outputs.append(response)
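            # Stage 2: append the extraction prompt and read the next-token logits of A/B/C/D.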
            answer_texts = [text + intermediate + "\n" + extraction_prompt
                            for text, intermediate in zip(texts, intermediate_outputs)]
            input_tokens = [build_prompt(answer_text) for answer_text in answer_texts]
            inputs = tokenizer(input_tokens, padding=True, return_tensors="pt", truncation=True, max_length=2048).to('cuda')
            outputs = model(**inputs, return_last_logit=True)
            logits = outputs.logits[:, -1]
            logits = logits[:, choice_tokens]
            preds = logits.argmax(dim=-1)
            correct += (preds.cpu() == batch["label"]).sum().item()
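        # Per-subject accuracy for this JSONL file.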
        accuracy = correct / len(dataset)
        print(entry, accuracy)
        accuracy_dict[entry] = accuracy
        count_dict[entry] = len(dataset)
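
# Overall accuracy, weighted by the number of examples in each subject.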
acc_total, count_total = 0.0, 0
for key in accuracy_dict:
    acc_total += accuracy_dict[key] * count_dict[key]
    count_total += count_dict[key]
print(acc_total / count_total)