# Hugging Face Spaces page header (scraped): Sleeping
import base64
import importlib.util
import os
import shutil
import tempfile
from io import BytesIO
from urllib.parse import urlsplit
from urllib.request import urlopen, urlretrieve

import gradio as gr
import torch
from colpali_engine.models import ColQwen2, ColQwen2Processor
from gradio_pdf import PDF
from pdf2image import convert_from_path, pdfinfo_from_path
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
# -----------------------------
# Module-level state
# -----------------------------
# OpenAI key, read once from the environment at import time ("" when unset).
api_key = os.environ.get("OPENAI_API_KEY", "")

# Active search index, kept in page order: ds[i] is the embedding tensor of
# the page image images[i].
ds = []
images = []

# Path of the most recently indexed PDF (used for the preview widget).
current_pdf_path = None
# -----------------------------
# Model & processor
# -----------------------------
_MODEL_NAME = "vidore/colqwen2-v1.0"

# Prefer CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device_map = "cuda:0"
elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
    device_map = "mps"
else:
    device_map = "cpu"

# FlashAttention-2 requires a CUDA device and the optional `flash_attn`
# package; transformers rejects the request otherwise, which would crash the
# app at import time on CPU/MPS. Fall back to the transformers default (None).
_use_flash_attn = (
    device_map.startswith("cuda") and importlib.util.find_spec("flash_attn") is not None
)

model = ColQwen2.from_pretrained(
    _MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map=device_map,
    attn_implementation="flash_attention_2" if _use_flash_attn else None,
).eval()
processor = ColQwen2Processor.from_pretrained(_MODEL_NAME)
# -----------------------------
# Utilities
# -----------------------------
def encode_image_to_base64(image: Image.Image) -> str:
    """Encode a PIL image as a base64 JPEG string.

    JPEG cannot store alpha or palette data, so images in any mode other than
    RGB or L are converted to RGB before encoding.

    Args:
        image: The image to encode.

    Returns:
        The base64 text of the JPEG-encoded image.
    """
    if image.mode not in ("RGB", "L"):
        # Pillow raises OSError ("cannot write mode RGBA as JPEG") otherwise.
        image = image.convert("RGB")
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
def query_gpt(query: str, retrieved_images: list[tuple[Image.Image, str]]) -> str:
    """Answer *query* with an OpenAI model, grounded in the retrieved pages.

    Each retrieved page image is sent immediately after its caption (e.g.
    "Page 3"), so the model can cite the page numbers the prompt asks for.

    Args:
        query: The user's natural-language question.
        retrieved_images: ``(image, caption)`` pairs for the pages to ground on.

    Returns:
        The model's text answer. Otherwise, a human-readable message when no
        usable OPENAI_API_KEY is configured or the API call fails.
    """
    key = api_key.strip()
    # Cheap sanity check: OpenAI secret keys start with "sk".
    if not key.startswith("sk"):
        return "Set OPENAI_API_KEY in your environment to get a custom response."

    # NOTE(review): the prompt asks for both a "short response" and "detailed
    # and extensive answers"; consider settling on one.
    prompt = "\n".join(
        [
            "You are a smart assistant designed to answer questions about a PDF document.",
            "You are given relevant information in the form of PDF pages. Use them to construct a short response to the question, and cite your sources (page numbers, etc).",
            'Each page image is preceded by its label (e.g. "Page 3"); use these labels when citing.',
            "If it is not possible to answer using the provided pages, do not attempt to provide an answer and simply say the answer is not present within the documents.",
            "Give detailed and extensive answers, only containing info in the pages you are given.",
            "You can answer using information contained in plots and figures if necessary.",
            "Answer in the same language as the query.",
            f"Query: {query}",
            "PDF pages:",
        ]
    )

    try:
        from openai import OpenAI  # optional dependency; absence is reported below

        content = [{"type": "input_text", "text": prompt}]
        for image, caption in retrieved_images:
            # Label each page before its image so citations can reference it.
            content.append({"type": "input_text", "text": caption})
            content.append(
                {
                    "type": "input_image",
                    "image_url": f"data:image/jpeg;base64,{encode_image_to_base64(image)}",
                }
            )
        client = OpenAI(api_key=key)
        response = client.responses.create(
            model="gpt-5-mini",
            input=[{"role": "user", "content": content}],
        )
        return response.output_text
    except Exception as e:
        # UI boundary: log the cause and show a friendly message instead of crashing.
        print(e)
        return "OpenAI API connection failure. Verify that OPENAI_API_KEY is set and valid (sk-***)."
def _ensure_model_device():
    """Make sure the global model sits on the preferred device.

    Preference order is CUDA (``cuda:0``), then Apple MPS, then CPU. The model
    is only moved when its current device string differs from the target.

    Returns:
        The target device string.
    """
    if torch.cuda.is_available():
        target = "cuda:0"
    else:
        mps_backend = getattr(torch.backends, "mps", None)
        target = "mps" if mps_backend and torch.backends.mps.is_available() else "cpu"

    if str(model.device) == target:
        return target
    model.to(target)
    return target
# -----------------------------
# Indexing helpers
# -----------------------------
def convert_files(pdf_path: str, max_pages: int = 500) -> list[Image.Image]:
    """Rasterize every page of a single PDF into PIL images.

    The page count is read with ``pdfinfo`` first. Oversized documents are
    therefore rejected before any page is rendered, instead of after
    rasterizing the whole file.

    Args:
        pdf_path: Local path of the PDF to convert.
        max_pages: Documents with this many pages or more are rejected.

    Returns:
        One PIL image per page, in page order.

    Raises:
        gr.Error: If the PDF has ``max_pages`` pages or more.
    """
    page_count = pdfinfo_from_path(pdf_path)["Pages"]
    if page_count >= max_pages:
        raise gr.Error(f"The number of images in the dataset should be less than {max_pages}.")
    return convert_from_path(pdf_path, thread_count=4)
def index_gpu(imgs: list[Image.Image]) -> str:
    """Embed page images with the ColQwen2 model and make them the active index.

    Embeddings are collected in a local list. The module-level ``ds`` and
    ``images`` are replaced only after every batch has succeeded. A failure
    partway through therefore leaves the previous index intact, and ``ds``
    and ``images`` never disagree in length (``search`` indexes ``images``
    with positions ranked over ``ds``).

    Args:
        imgs: Page images, in page order.

    Returns:
        A status message with the number of indexed pages.
    """
    global ds, images
    device = _ensure_model_device()
    dataloader = DataLoader(
        imgs,
        batch_size=4,
        shuffle=False,
        collate_fn=processor.process_images,  # tensors are moved to `device` below
    )
    new_ds = []
    with torch.no_grad():
        for batch_doc in tqdm(dataloader, desc="Indexing pages"):
            batch_doc = {k: v.to(device) for k, v in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
            # Keep one CPU tensor per page so the index doesn't hold device memory.
            new_ds.extend(torch.unbind(embeddings_doc.to("cpu")))
    ds, images = new_ds, imgs
    return f"Indexed {len(images)} pages successfully."
def index_from_path(pdf_path: str) -> str:
    """Rasterize the local PDF at *pdf_path*, index its pages, and return a status message."""
    return index_gpu(convert_files(pdf_path))
def index_from_url(url: str, timeout: float = 60.0) -> tuple[str, str]:
    """Download a PDF over HTTP(S) into a fresh temporary directory and index it.

    Args:
        url: Absolute ``http://`` or ``https://`` URL of the PDF.
        timeout: Socket timeout in seconds for the download.

    Returns:
        ``(status message, local path of the downloaded PDF)``. On success the
        file is kept on disk so it can be previewed.

    Raises:
        ValueError: If *url* is not an http(s) URL. Other schemes are refused
            because the URL may come from a user. This includes ``file://``,
            which ``urllib`` would otherwise happily read from the local disk.
    """
    scheme = urlsplit(url).scheme.lower()
    if scheme not in ("http", "https"):
        raise ValueError(f"Only http(s) URLs are supported, got: {url!r}")

    tmp_dir = tempfile.mkdtemp(prefix="colpali_")
    local_path = os.path.join(tmp_dir, "document.pdf")
    try:
        with urlopen(url, timeout=timeout) as response, open(local_path, "wb") as out:
            shutil.copyfileobj(response, out)
        status = index_from_path(local_path)
    except Exception:
        # Don't leak the temporary directory when download or indexing fails.
        shutil.rmtree(tmp_dir, ignore_errors=True)
        raise
    return status, local_path
# -----------------------------
# Search (MCP tool-friendly)
# -----------------------------
def search(query: str, k: int = 5):
    """
    Search the indexed PDF document for the pages most relevant to a query and synthesize a short answer grounded only in those pages.

    MCP tool description:
    - name: mcp_test_search
    - description: Search the indexed PDF document for the pages most relevant to a query and synthesize a short answer grounded only in those pages.
    - input_schema:
        type: object
        properties:
          query: {type: string, description: "User query in natural language."}
          k: {type: integer, minimum: 1, maximum: 20, default: 5, description: "Number of top pages to retrieve."}
        required: ["query"]

    Args:
        query (str): Natural-language question to search for.
        k (int): Number of top pages to retrieve (1–20); clamped to the number of indexed pages.

    Returns:
        ai_response (str): Text answer to the query grounded in content from the PDF, with citations (page numbers), or a status message when no document is indexed or the query is empty.
    """
    # Always return a single string: the UI wires this to one Textbox output.
    if not images or not ds:
        return "No document indexed yet. Upload a PDF or load the sample, then run Search."
    if not query or not query.strip():
        return "Please enter a query."

    # Clamp k to [1, number of indexed pages]; ds and images should have equal
    # length, and min() keeps topk/indexing safe even if they do not.
    k = max(1, min(int(k), len(images), len(ds)))
    device = _ensure_model_device()
    print(query)

    # Encode the query into the same multi-vector space as the pages.
    with torch.no_grad():
        batch_query = processor.process_queries([query]).to(model.device)
        embeddings_query = model(**batch_query)
        qs = list(torch.unbind(embeddings_query.to("cpu")))

    # Score the single query against every page and keep the k best.
    scores = processor.score(qs, ds, device=device)
    top_k_indices = scores[0].topk(k).indices.tolist()
    print(top_k_indices)

    # Captions use 1-based page numbers so the model can cite them.
    results = [(images[idx], f"Page {idx + 1}") for idx in top_k_indices]

    ai_response = query_gpt(query, results)
    print(ai_response)
    return ai_response
| # ----------------------------- | |
| # Gradio UI callbacks | |
| # ----------------------------- | |
| def handle_upload(file) -> tuple[str, str | None]: | |
| """Index a user-uploaded PDF file.""" | |
| global current_pdf_path | |
| if file is None: | |
| return "Please upload a PDF.", None | |
| path = getattr(file, "name", file) | |
| status = index_from_path(path) | |
| current_pdf_path = path | |
| return status, path | |
| def handle_url(url: str) -> tuple[str, str | None]: | |
| """Index a PDF from URL (e.g., a sample).""" | |
| global current_pdf_path | |
| if not url or not url.lower().endswith(".pdf"): | |
| return "Please provide a direct PDF URL.", None | |
| status, path = index_from_url(url) | |
| current_pdf_path = path | |
| return status, path | |
# Index a sample document at import time. The upload/URL widgets in the UI
# are commented out, so this is the document `search` answers from.
_SAMPLE_PDF_URL = "https://www.ipcc.ch/report/ar6/syr/downloads/report/IPCC_AR6_SYR_SPM.pdf"
print("Uploading")
try:
    print(handle_url(_SAMPLE_PDF_URL))
except Exception as exc:
    # Best-effort: a download/indexing failure must not stop the app from
    # starting; `search` then reports that no document is indexed yet.
    print(f"Could not index sample PDF {_SAMPLE_PDF_URL}: {exc}")
# -----------------------------
# Gradio App
# -----------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ColPali: Efficient Document Retrieval with Vision Language Models (ColQwen2) 📚")
    gr.Markdown(
        "Demo to test ColQwen2 (ColPali) on PDF documents.\n"
        "ColPali is implemented from the [ColPali paper](https://arxiv.org/abs/2407.01449)."
    )
    with gr.Row():
        # NOTE: the upload / URL indexing column is disabled. To let users
        # index their own PDFs, re-enable it together with the two commented
        # wiring lines below (handle_upload / handle_url).
        # with gr.Column(scale=2):
        #     gr.Markdown("## 1️⃣ Load a PDF")
        #     pdf_input = gr.File(label="Upload PDF", file_types=[".pdf"])
        #     index_btn = gr.Button("📥 Index Uploaded PDF", variant="secondary")
        #     url_box = gr.Textbox(
        #         label="Or index from URL",
        #         placeholder="https://example.com/file.pdf",
        #         value="https://sist.sathyabama.ac.in/sist_coursematerial/uploads/SAR1614.pdf",
        #     )
        #     index_url_btn = gr.Button("🌐 Load Sample / From URL", variant="secondary")
        #     status_box = gr.Textbox(label="Status", interactive=False)
        #     pdf_view = PDF(label="PDF Preview")
        with gr.Column(scale=3):
            gr.Markdown("## 2️⃣ Search")
            query = gr.Textbox(placeholder="Enter your query here", label="Query")
            k_slider = gr.Slider(minimum=1, maximum=20, step=1, label="Number of results", value=5)
            search_button = gr.Button("🔍 Search", variant="primary")
            output_text = gr.Textbox(
                label="AI Response",
                placeholder="Generated response based on retrieved documents",
            )

    # Wiring
    # index_btn.click(handle_upload, inputs=[pdf_input], outputs=[status_box, pdf_view])
    # index_url_btn.click(handle_url, inputs=[url_box], outputs=[status_box, pdf_view])
    search_button.click(search, inputs=[query, k_slider], outputs=[output_text])
if __name__ == "__main__":
    # A sample PDF is already indexed at import time (the "Uploading" step
    # earlier in this module). To additionally pre-load a different default
    # document before launching, uncomment:
    # msg, path = index_from_url("https://sist.sathyabama.ac.in/sist_coursematerial/uploads/SAR1614.pdf")
    # print(msg, "->", path)
    queued_app = demo.queue(max_size=5)  # queue holds at most 5 pending events
    queued_app.launch(debug=True, mcp_server=True)