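# Generate KILT-format predictions for ELI5: for each question in kilt_input_file,
# retrieve Wikipedia passages with a DPR question encoder and a FAISS index, generate a
# long-form answer with a BART model fine-tuned on ELI5, and write one KILT record per
# question to kilt_output_file.
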
import argparse
import json
import os

import torch
from datasets import load_dataset
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DPRQuestionEncoder

from common import articles_to_paragraphs, kilt_wikipedia_columns
from common import kilt_wikipedia_paragraph_columns as columns

def eval_generate(args):
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # DPR question encoder used to embed queries for retrieval
    question_tokenizer = AutoTokenizer.from_pretrained(args.question_encoder_name)
    question_model = DPRQuestionEncoder.from_pretrained(args.question_encoder_name).to(device)
    _ = question_model.eval()

    # BART seq2seq model fine-tuned on ELI5, used to generate long-form answers
    eli5_tokenizer = AutoTokenizer.from_pretrained('vblagoje/bart_eli5')
    eli5_model = AutoModelForSeq2SeqLM.from_pretrained('vblagoje/bart_eli5').to(device)
    _ = eli5_model.eval()

    min_snippet_length = 20
    topk = 21
    min_chars_per_passage = 200
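    # Retrieval budget: topk (21) passages are fetched per question; after dropping
    # snippets shorter than min_snippet_length words, only the top topk / 3 (7)
    # passages are passed to the generator further below.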

    kilt_wikipedia = load_dataset("kilt_wikipedia", split="full")
    kilt_wikipedia_paragraphs = kilt_wikipedia.map(articles_to_paragraphs, batched=True,
                                                   remove_columns=kilt_wikipedia_columns,
                                                   batch_size=256,
                                                   cache_file_name="./data/wiki_kilt_paragraphs_full.arrow",
                                                   desc="Expanding wiki articles into paragraphs")

    # use paragraphs that are not simple fragments or very short sentences
    kilt_wikipedia_paragraphs = kilt_wikipedia_paragraphs.filter(
        lambda x: (x["end_character"] - x["start_character"]) > min_chars_per_passage)
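
    # Attach the precomputed FAISS index of passage embeddings (built separately) to the
    # "embeddings" column; device=0 places the index on the first GPU.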
    kilt_wikipedia_paragraphs.load_faiss_index("embeddings", args.index_file_name, device=0)

    def embed_questions_for_retrieval(questions):
        query = question_tokenizer(questions, max_length=128, padding=True, truncation=True, return_tensors="pt")
        with torch.no_grad():
            q_reps = question_model(query["input_ids"].to(device),
                                    query["attention_mask"].to(device)).pooler_output
        return q_reps.cpu().numpy()

    def query_index(question):
        question_embedding = embed_questions_for_retrieval([question])
        scores, wiki_passages = kilt_wikipedia_paragraphs.get_nearest_examples("embeddings", question_embedding,
                                                                               k=topk)
        # get_nearest_examples returns a dict of column name -> list of topk values;
        # regroup it into one dict per retrieved passage
        retrieved_examples = []
        for i in range(topk):
            retrieved_examples.append({k: wiki_passages[k][i] for k in columns})
        return retrieved_examples

    def create_kilt_datapoint(q_id, query, answer, res_list):
        # make a KILT data point
        # see https://github.com/facebookresearch/KILT#kilt-data-format
        provenance = [{
            "wikipedia_id": r["wikipedia_id"],  # *mandatory*
            "title": r["title"],
            "section": r["section"],
            "start_paragraph_id": r["start_paragraph_id"],
            "start_character": r["start_character"],
            "end_paragraph_id": r["end_paragraph_id"],
            "end_character": r["end_character"],
            "text": r["text"],
            "bleu_score": None,  # wrt original evidence
            "meta": None  # dataset/task specific
        } for r in res_list]

        output = [{"answer": answer, "provenance": provenance}]
        return {"id": q_id,
                "input": query,
                "output": output,  # each element is an answer or provenance (can have multiple of each)
                "meta": None  # dataset/task specific
                }
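
    # Illustrative (abridged) shape of a record produced above; values are placeholders:
    # {"id": "q1", "input": "why ...?",
    #  "output": [{"answer": "generated answer ...",
    #              "provenance": [{"wikipedia_id": "1234", "title": "...", ...}]}],
    #  "meta": None}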

    kilt_output = []
    with open(args.kilt_input_file, "r") as f:
        kilt_items = [json.loads(x) for x in f.read().strip().split("\n")]

    progress_bar = tqdm(range(len(kilt_items)), desc="Creating KILT response document")
    for idx, item in enumerate(kilt_items):
        query = item["input"]
        res_list = query_index(query)
        res_list = [res for res in res_list if len(res["text"].split()) > min_snippet_length][:int(topk / 3)]
        documents = [res["text"] for res in res_list]
        conditioned_doc = "<P> " + " <P> ".join([d for d in documents])
        query_and_docs = "question: {} context: {}".format(query, conditioned_doc)

        # generate a long-form answer with deterministic beam search
        model_input = eli5_tokenizer(query_and_docs, truncation=True, padding=True, return_tensors="pt")
        generated_answers_encoded = eli5_model.generate(input_ids=model_input["input_ids"].to(device),
                                                        attention_mask=model_input["attention_mask"].to(device),
                                                        min_length=50,
                                                        max_length=250,
                                                        do_sample=False,
                                                        early_stopping=True,
                                                        num_beams=8,
                                                        temperature=1.0,
                                                        top_k=None,
                                                        top_p=None,
                                                        no_repeat_ngram_size=3,
                                                        num_return_sequences=1)
        answer = eli5_tokenizer.batch_decode(generated_answers_encoded, skip_special_tokens=True,
                                             clean_up_tokenization_spaces=True)

        kilt_example = create_kilt_datapoint(item["id"], query, answer[0], res_list)
        kilt_output.append(kilt_example)
        progress_bar.update(1)

    with open(args.kilt_output_file, "w") as fp:
        for kilt_example in kilt_output:
            json.dump(kilt_example, fp)
            fp.write("\n")
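    # The output is JSON Lines, one KILT-formatted record per input question, which should
    # be suitable for scoring with the evaluation scripts in the KILT repository.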


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--kilt_input_file', default="./eli5-dev-kilt.jsonl", type=str)
    parser.add_argument('--kilt_output_file', default="./eli5-predicted_retrieval.jsonl", type=str)
    parser.add_argument(
        "--question_encoder_name",
        default="vblagoje/dpr-question_encoder-single-lfqa-base",
        help="Question encoder to use",
    )
    parser.add_argument(
        "--index_file_name",
        default="../data/kilt_dpr_wikipedia_first.faiss",
        help="Faiss index with passage embeddings",
    )

    args = parser.parse_args()
    assert os.path.isfile(args.kilt_input_file), f"Input file {args.kilt_input_file} couldn't be loaded"
    eval_generate(args)
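
# Example invocation (illustrative; assumes this file is saved as eval_generate.py and
# that the default paths/model names above exist locally):
#   python eval_generate.py \
#       --kilt_input_file ./eli5-dev-kilt.jsonl \
#       --kilt_output_file ./eli5-predicted_retrieval.jsonl \
#       --question_encoder_name vblagoje/dpr-question_encoder-single-lfqa-base \
#       --index_file_name ../data/kilt_dpr_wikipedia_first.faiss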