Spaces:
Running
Running
| from __future__ import annotations | |
| import re | |
| from dataclasses import dataclass | |
| from typing import Tuple | |
| import gradio as gr | |
| import requests | |
| import xmltodict | |
| from PyPDF2 import PdfReader | |
| from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline | |
| from transformers.pipelines.question_answering import QuestionAnsweringPipeline | |
# Model name handed to ``from_pretrained`` when the QA pipeline is built.
QA_MODEL_NAME = "ixa-ehu/SciBERT-SQuAD-QuAC"
# File the downloaded PDF is written to before its text is extracted.
TEMP_PDF_PATH = "/tmp/arxiv_paper.pdf"
# Accepted arXiv PDF links. Group 3 is the paper id that callers read via
# ``.group(3)``, so the three capturing groups keep their positions.
# - The dots of the host name are escaped; a bare "." matched any character.
# - The host/path segment may appear only once.
# - An optional "vN" version suffix is captured as part of the id.
# - The ".pdf" extension is optional.
ARXIV_URL_PATTERN = (
    r"(https?)://(arxiv\.org/pdf/)([0-9]+\.[0-9]+(?:v[0-9]+)?)(?:\.pdf)?"
)
def is_valid_url(url: str) -> bool:
    """Tell whether *url* matches ``ARXIV_URL_PATTERN`` from start to end."""
    # A match object is always truthy, so ``bool`` maps match/None to True/False.
    return bool(re.fullmatch(ARXIV_URL_PATTERN, url))
@dataclass
class PaperMetaData:
    """Identifier, title, summary and extracted text of one arXiv paper.

    ``@dataclass`` generates the keyword constructor that ``from_api`` uses.
    """

    arxiv_id: str
    title: str
    summary: str
    text: str

    @staticmethod
    def _clean_field(text: str) -> str:
        """Collapse every run of whitespace, newlines included, to one space."""
        # "\n" is part of "\s", so a single substitution covers both cases.
        return re.sub(r"\s+", " ", text)

    @classmethod
    def from_api(
        cls, arxiv_id: str, text: str, timeout: float = 30.0
    ) -> "PaperMetaData":
        """Build a record from the arXiv export API plus already-extracted text.

        Args:
            arxiv_id: Identifier placed in the ``id_list`` query parameter.
            text: Paper text, stored unchanged on the returned record.
            timeout: Seconds ``requests`` waits on the API before giving up.

        Returns:
            A record whose title and summary come from the API response with
            their whitespace normalised.

        Raises:
            requests.HTTPError: If the API answers with an error status.
            KeyError: If the parsed response lacks ``feed``/``entry`` or the
                entry lacks ``title``/``summary``.
        """
        paper_url = f"http://export.arxiv.org/api/query?id_list={arxiv_id}"
        response = requests.get(paper_url, timeout=timeout)
        # Fail on HTTP errors here instead of trying to parse an error body.
        response.raise_for_status()
        paper_dict = xmltodict.parse(response.content)["feed"]["entry"]
        return cls(
            arxiv_id=arxiv_id,
            title=cls._clean_field(paper_dict["title"]),
            summary=cls._clean_field(paper_dict["summary"]),
            text=text,
        )
def clean_text(text: str) -> str:
    """Tidy text extracted from a PDF.

    Three substitutions are applied in order: the control characters
    ``\\x02`` and ``\\x03`` are removed, a hyphen followed by whitespace is
    removed (re-joining words split across lines), and each remaining
    newline becomes a space.
    """
    substitutions = (
        (r"\x03|\x02", ""),
        (r"-\s+", ""),
        (r"\n", " "),
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text
class PDFPaper:
    """An arXiv paper addressed by the URL of its PDF."""

    def __init__(self, url: str):
        """Validate *url* and keep it together with the arXiv id it contains.

        Raises:
            ValueError: If *url* does not fully match ``ARXIV_URL_PATTERN``.
        """
        # One match serves both validation and id extraction.
        url_match = re.fullmatch(ARXIV_URL_PATTERN, url)
        if url_match is None:
            raise ValueError("The URL provided is not a valid arxiv PDF url.")
        self.url = url
        self.arxiv_id = url_match.group(3)

    def _download(
        self, download_path: str = TEMP_PDF_PATH, timeout: float = 60.0
    ) -> None:
        """Fetch the PDF at ``self.url`` and write its bytes to *download_path*.

        Args:
            download_path: File that receives the PDF content.
            timeout: Seconds ``requests`` waits on the server; without a
                timeout a stalled connection would block indefinitely.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        pdf_r = requests.get(self.url, timeout=timeout)
        pdf_r.raise_for_status()
        with open(download_path, "wb") as pdf_file:
            pdf_file.write(pdf_r.content)

    def read_text(self, pdf_path: str = TEMP_PDF_PATH) -> str:
        """Download the PDF to *pdf_path* and return its cleaned text.

        The text of all pages is joined with single spaces and then passed
        through ``clean_text``.
        """
        self._download(pdf_path)
        reader = PdfReader(pdf_path)
        pdf_text = " ".join([page.extract_text() for page in reader.pages])
        return clean_text(pdf_text)

    def get_paper_full_data(self) -> PaperMetaData:
        """Return the paper's API metadata combined with its extracted text."""
        return PaperMetaData.from_api(arxiv_id=self.arxiv_id, text=self.read_text())
def get_paper_data(url: str) -> Tuple[str, str, str]:
    """Load the paper behind *url* and return ``(title, summary, text)``."""
    paper = PDFPaper(url=url)
    metadata = paper.get_paper_full_data()
    return (metadata.title, metadata.summary, metadata.text)
# Pipelines already built in this process, keyed by model name.
_QA_PIPELINE_CACHE = {}


def get_qa_pipeline(qa_model_name: str = QA_MODEL_NAME) -> QuestionAnsweringPipeline:
    """Return a question-answering pipeline for *qa_model_name*.

    The tokenizer and model are loaded the first time a given name is
    requested; later calls with the same name get the same pipeline object
    back, so answering a question does not reload the checkpoint.

    Args:
        qa_model_name: Name passed to both ``from_pretrained`` calls.

    Returns:
        The pipeline created by ``pipeline("question-answering", ...)``.
    """
    qa_pipeline = _QA_PIPELINE_CACHE.get(qa_model_name)
    if qa_pipeline is None:
        tokenizer = AutoTokenizer.from_pretrained(qa_model_name)
        model = AutoModelForQuestionAnswering.from_pretrained(qa_model_name)
        qa_pipeline = pipeline("question-answering", model=model, tokenizer=tokenizer)
        _QA_PIPELINE_CACHE[qa_model_name] = qa_pipeline
    return qa_pipeline
def get_answer(question: str, context: str) -> str:
    """Answer *question* from *context* using the default QA pipeline.

    Returns the ``"answer"`` entry of the pipeline's prediction.
    """
    answerer = get_qa_pipeline()
    return answerer(question=question, context=context)["answer"]
# Gradio UI: components are created in display order inside the Blocks
# context, and each button is wired to its callback right after its outputs.
# NOTE(review): the emoji in some labels below look mis-encoded; they are kept
# exactly as found — confirm the intended characters.
with gr.Blocks() as demo:
    gr.Markdown("# arXiv Paper Q&A\nImport an arXiv paper and ask questions about it!")

    # Section 1: import a paper from its arXiv PDF URL.
    gr.Markdown("## π Import the paper on arXiv")
    arxiv_url = gr.Textbox(
        label="arXiv Paper URL",
        placeholder="Insert here the URL of a paper on arXiv",
    )
    fetch_document_button = gr.Button("Import Paper")
    paper_title = gr.Textbox(label="Paper Title")
    paper_summary = gr.Textbox(label="Paper Summary")
    paper_text = gr.Textbox(label="Paper Text")
    fetch_document_button.click(
        fn=get_paper_data,
        inputs=arxiv_url,
        outputs=[paper_title, paper_summary, paper_text],
    )

    # Section 2: ask a question. The summary box, not the full paper text,
    # is what gets passed to the model as context.
    gr.Markdown("## π€¨ Ask a question about the paper")
    question = gr.Textbox(label="Ask a question about the paper:")
    ask_button = gr.Button("Ask me π€")
    answer = gr.Textbox(label="Answer:")
    ask_button.click(
        fn=get_answer,
        inputs=[question, paper_summary],
        outputs=answer,
    )

demo.launch()