| 
							 | 
						import torch | 
					
					
						
						| 
							 | 
						from transformers import RagRetriever, RagTokenizer, RagSequenceForGeneration | 
					
					
						
						| 
							 | 
						from datasets import load_dataset | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
# NOTE(review): `dataset` is loaded here (wiki_dpr is a very large download) but
# is never used anywhere below — the retriever is constructed with
# use_dummy_dataset=True. Either remove this line or wire the dataset into the
# retriever; as written it only burns time, bandwidth, and disk.
dataset = load_dataset("wiki_dpr", "psgs_w100.nq.exact", trust_remote_code=True)
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						retriever = RagRetriever.from_pretrained( | 
					
					
						
						| 
							 | 
						    "facebook/rag-token-base",  | 
					
					
						
						| 
							 | 
						    use_dummy_dataset=True,  | 
					
					
						
						| 
							 | 
						    trust_remote_code=True | 
					
					
						
						| 
							 | 
						) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
# RagTokenizer bundles the DPR question-encoder tokenizer and the BART
# generator tokenizer that ship with the RAG checkpoint.
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-base") | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						def generate_answer(question): | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    inputs = tokenizer(question, return_tensors="pt") | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    input_ids = inputs["input_ids"] | 
					
					
						
						| 
							 | 
						    retrieved_doc_ids = retriever.retrieve(input_ids) | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    generated_ids = model.generate(input_ids, context_input_ids=retrieved_doc_ids["context_input_ids"]) | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    answer = tokenizer.decode(generated_ids[0], skip_special_tokens=True) | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    return answer | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						if __name__ == "__main__": | 
					
					
						
						| 
							 | 
						    question = "Who was the first president of the United States?" | 
					
					
						
						| 
							 | 
						    print(f"Question: {question}") | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    answer = generate_answer(question) | 
					
					
						
						| 
							 | 
						    print(f"Answer: {answer}") | 
					
					
						
						| 
							 | 
						
 |