Spaces:
Sleeping
Sleeping
File size: 10,400 Bytes
94b55f0 a246b15 5697c10 a246b15 5697c10 94b55f0 602d806 527f685 602d806 89cecf3 3649694 602d806 5dfd724 9b1e831 9c66171 5697c10 b9c715a 5697c10 9b1e831 5697c10 e7f8afe 9b1e831 9c66171 5697c10 6efb913 73f30e5 5697c10 73f30e5 a129662 5697c10 a129662 5697c10 cd6d6e9 d76e71e 511619b a129662 5697c10 cd6d6e9 f24b36e 5697c10 8777d11 a129662 8e6343b 40ef153 5697c10 0d01d71 6efb913 5697c10 ec28a2a 5697c10 f1d7f41 5697c10 0d01d71 5697c10 ec28a2a 5697c10 9b1e831 602d806 068f2e8 602d806 9b1e831 602d806 a2d6d06 5697c10 602d806 9c66171 602d806 5697c10 603af32 5697c10 f24b36e 5697c10 f24b36e 5697c10 603af32 5697c10 f24b36e 5697c10 4fd529e 5697c10 602d806 5697c10 0d01d71 4fd529e 5697c10 602d806 5697c10 4fd529e a315ebf 5697c10 75f63ac 59741c0 75f63ac 5697c10 dad1e49 d546c80 5697c10 9c9913c 5697c10 9357d80 0d01d71 75f63ac 602d806 0d01d71 59741c0 5697c10 fa73ad0 5697c10 75f63ac a315ebf 602d806 5697c10 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 |
import base64
import importlib.util
import os
import shutil
import tempfile
from io import BytesIO
from urllib.parse import urlparse
from urllib.request import urlopen
from urllib.request import urlretrieve

import gradio as gr
import torch
from colpali_engine.models import ColQwen2, ColQwen2Processor
from gradio_pdf import PDF
from pdf2image import convert_from_path
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
# -----------------------------
# Globals
# -----------------------------
# OpenAI API key taken from the environment; empty string when the variable is unset.
api_key = os.environ.get("OPENAI_API_KEY", "")
# Per-page embedding tensors produced by index_gpu (kept on CPU), page order.
ds: list = []
# Rendered page images, in the same order as `ds`.
images: list = []
# Path of the most recently indexed PDF (used for preview); None until one is indexed.
current_pdf_path = None
# -----------------------------
# Model & processor
# -----------------------------
# Prefer CUDA, then Apple MPS, then CPU.
device_map = (
    "cuda:0"
    if torch.cuda.is_available()
    else ("mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available() else "cpu")
)

# FlashAttention-2 only runs on CUDA and needs the optional `flash_attn` package;
# requesting it anywhere else makes from_pretrained fail. Fall back to the
# transformers default attention implementation (None) in that case.
_use_flash_attn = device_map.startswith("cuda") and importlib.util.find_spec("flash_attn") is not None

model = ColQwen2.from_pretrained(
    "vidore/colqwen2-v1.0",
    torch_dtype=torch.bfloat16,
    device_map=device_map,
    attn_implementation="flash_attention_2" if _use_flash_attn else None,
).eval()
processor = ColQwen2Processor.from_pretrained("vidore/colqwen2-v1.0")
# -----------------------------
# Utilities
# -----------------------------
def encode_image_to_base64(image: Image.Image) -> str:
    """Encode a PIL image as a base64 JPEG string.

    JPEG cannot store alpha or palette data, so images in any mode other than
    RGB or L (grayscale) are converted to RGB first; saving e.g. an RGBA or "P"
    image directly as JPEG raises ``OSError``.

    Args:
        image: The image to encode. It is not modified.

    Returns:
        The JPEG bytes, base64-encoded and decoded to an ASCII/UTF-8 ``str``.
    """
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
def query_gpt(query: str, retrieved_images: list[tuple[Image.Image, str]]) -> str:
    """Answer *query* with an OpenAI model, grounded in the retrieved page images.

    Each image is preceded by its caption (e.g. ``"Page 3"``) in the request so
    the model can actually cite the page numbers the prompt asks for.

    Args:
        query: The user's natural-language question.
        retrieved_images: ``(image, caption)`` pairs, best match first.

    Returns:
        The model's text answer, or a fixed help/error message when no usable
        API key is configured or the API call fails.
    """
    key = api_key.strip() if api_key else ""
    if not key.startswith("sk"):
        return "Set OPENAI_API_KEY in your environment to get a custom response."
    try:
        from openai import OpenAI

        client = OpenAI(api_key=key)
        prompt = "\n".join([
            "You are a smart assistant designed to answer questions about a PDF document.",
            "You are given relevant information in the form of PDF pages. Use them to construct a short response to the question, and cite your sources (page numbers, etc).",
            "If it is not possible to answer using the provided pages, do not attempt to provide an answer and simply say the answer is not present within the documents.",
            "Give detailed and extensive answers, only containing info in the pages you are given.",
            "You can answer using information contained in plots and figures if necessary.",
            "Answer in the same language as the query.",
            "Query: {query}",
            "PDF pages:",
        ])
        content = [{"type": "input_text", "text": prompt.format(query=query)}]
        for image, caption in retrieved_images:
            # Label every page image so citations refer to real page numbers.
            content.append({"type": "input_text", "text": caption})
            content.append({
                "type": "input_image",
                "image_url": f"data:image/jpeg;base64,{encode_image_to_base64(image)}",
            })
        response = client.responses.create(
            model="gpt-5-mini",
            input=[{"role": "user", "content": content}],
        )
        return response.output_text
    except Exception as e:  # UI boundary: report failure instead of crashing the request
        print(e)
        return "OpenAI API connection failure. Verify that OPENAI_API_KEY is set and valid (sk-***)."
def _ensure_model_device() -> str:
    """Move the global ``model`` to the best available device if it is not there.

    Returns:
        The target device string: ``"cuda:0"``, ``"mps"`` or ``"cpu"``.
    """
    if torch.cuda.is_available():
        dev = "cuda:0"
    elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        dev = "mps"
    else:
        dev = "cpu"
    target = torch.device(dev)
    current = model.device
    # A model placed with "mps" reports "mps:0", so a plain string comparison
    # would trigger a move on every call. Compare the device type, and the index
    # only when the target pins one.
    if current.type != target.type or (target.index is not None and current.index != target.index):
        model.to(dev)
    return dev
# -----------------------------
# Indexing helpers
# -----------------------------
def convert_files(pdf_path: str, max_pages: int = 500) -> list[Image.Image]:
    """Convert a single PDF path into a list of PIL Images (one per page).

    The page count is read with ``pdfinfo`` before rasterizing, so an oversized
    document is rejected without first rendering every page into memory.

    Args:
        pdf_path: Path to a local PDF file.
        max_pages: Documents with this many pages or more are rejected.

    Returns:
        The rendered pages, in document order.

    Raises:
        gr.Error: if the document has ``max_pages`` pages or more.
    """
    from pdf2image import pdfinfo_from_path

    limit_msg = f"The number of images in the dataset should be less than {max_pages}."
    page_count = pdfinfo_from_path(pdf_path).get("Pages")
    if isinstance(page_count, int) and page_count >= max_pages:
        raise gr.Error(limit_msg)
    imgs = convert_from_path(pdf_path, thread_count=4)
    # Re-check the actual output in case pdfinfo reported no usable page count.
    if len(imgs) >= max_pages:
        raise gr.Error(limit_msg)
    return imgs
def index_gpu(imgs: list[Image.Image]) -> str:
    """Embed a list of page images with ColQwen2 and store them in the globals.

    ``ds`` and ``images`` are replaced only after every page has been embedded.
    If embedding fails part-way, the previous index stays intact instead of
    leaving ``ds`` shorter than ``images``.

    Args:
        imgs: Page images in document order.

    Returns:
        A status message with the number of indexed pages.
    """
    global ds, images
    device = _ensure_model_device()
    embeddings: list[torch.Tensor] = []
    dataloader = DataLoader(
        imgs,
        batch_size=4,
        shuffle=False,
        collate_fn=processor.process_images,
    )
    with torch.no_grad():
        for batch_doc in tqdm(dataloader, desc="Indexing pages"):
            batch_doc = {k: v.to(device) for k, v in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
            # Keep embeddings on CPU so the index does not hold accelerator memory.
            embeddings.extend(torch.unbind(embeddings_doc.to("cpu")))
    ds = embeddings
    images = imgs
    return f"Indexed {len(images)} pages successfully."
def index_from_path(pdf_path: str) -> str:
    """Index the PDF at *pdf_path* on the local filesystem.

    Returns the status message produced by ``index_gpu``.
    """
    return index_gpu(convert_files(pdf_path))
def index_from_url(url: str, timeout: float = 60.0) -> tuple[str, str]:
    """
    Download a PDF from an http(s) URL and index it.

    Args:
        url: Direct link to the PDF. Only ``http`` and ``https`` are accepted so
            that user-supplied ``file://`` (or other scheme) URLs cannot be used
            to read local files.
        timeout: Socket timeout in seconds for the download.

    Returns:
        status message, saved pdf path

    Raises:
        ValueError: if the URL scheme is not http or https.
    """
    scheme = urlparse(url).scheme.lower()
    if scheme not in ("http", "https"):
        raise ValueError(f"Unsupported URL scheme {scheme!r}; expected http or https.")
    tmp_dir = tempfile.mkdtemp(prefix="colpali_")
    local_path = os.path.join(tmp_dir, "document.pdf")
    try:
        with urlopen(url, timeout=timeout) as response, open(local_path, "wb") as fh:
            shutil.copyfileobj(response, fh)
        status = index_from_path(local_path)
    except BaseException:
        # Do not leave an orphaned temp directory behind on failure.
        shutil.rmtree(tmp_dir, ignore_errors=True)
        raise
    return status, local_path
# -----------------------------
# Search (MCP tool-friendly)
# -----------------------------
def search(query: str, k: int = 5):
    """
    Search within a PDF document for the most relevant pages to answer a query and synthesizes a short grounded answer using only those pages.

    MCP tool description:
    - name: mcp_test_search
    - description: Search within a PDF document for the most relevant pages to answer a query and synthesizes a short grounded answer using only those pages.
    - input_schema:
        type: object
        properties:
            query: {type: string, description: "User query in natural language."}
            k: {type: integer, minimum: 1, maximum: 20, default: 5, description: "Number of top pages to retrieve."}
        required: ["query"]

    Args:
        query (str): Natural-language question to search for.
        k (int): Number of top pages to retrieve (1-20); clipped to the number of indexed pages.

    Returns:
        ai_response (str): Text answer to the query grounded in content from the PDF, with citations (page numbers), or a short message when no document is indexed or the query is empty.
    """
    # Always return a single string: the UI wires this to one Textbox and the
    # MCP tool advertises a string result.
    if not images or not ds:
        return "No document indexed yet. Upload a PDF or load the sample, then run Search."
    if not query or not query.strip():
        return "Please enter a query."

    # Clip k to what is actually indexed (ds and images should match in length).
    k = max(1, min(int(k), len(images), len(ds)))
    device = _ensure_model_device()
    print(query)

    # Encode the query.
    with torch.no_grad():
        batch_query = processor.process_queries([query]).to(model.device)
        embeddings_query = model(**batch_query)
    qs = list(torch.unbind(embeddings_query.to("cpu")))

    # Score the query against every page and keep the top-k.
    scores = processor.score(qs, ds, device=device)
    top_k_indices = scores[0].topk(k).indices.tolist()
    print(top_k_indices)

    # Gallery-style results with 1-based page numbering.
    results = [(images[idx], f"Page {idx + 1}") for idx in top_k_indices]

    ai_response = query_gpt(query, results)
    print(ai_response)
    return ai_response
# -----------------------------
# Gradio UI callbacks
# -----------------------------
def handle_upload(file) -> tuple[str, str | None]:
"""Index a user-uploaded PDF file."""
global current_pdf_path
if file is None:
return "Please upload a PDF.", None
path = getattr(file, "name", file)
status = index_from_path(path)
current_pdf_path = path
return status, path
def handle_url(url: str) -> tuple[str, str | None]:
"""Index a PDF from URL (e.g., a sample)."""
global current_pdf_path
if not url or not url.lower().endswith(".pdf"):
return "Please provide a direct PDF URL.", None
status, path = index_from_url(url)
current_pdf_path = path
return status, path
# Index a sample document at import time so Search works immediately.
# This is best-effort: a network or conversion failure is reported but must not
# prevent the app from launching (search() reports "No document indexed yet").
SAMPLE_PDF_URL = "https://www.ipcc.ch/report/ar6/syr/downloads/report/IPCC_AR6_SYR_SPM.pdf"
print("Uploading")
try:
    print(handle_url(SAMPLE_PDF_URL))
except Exception as exc:
    print(f"Sample preload failed: {exc}")
# -----------------------------
# Gradio App
# -----------------------------
# Gradio UI. Components render in the order they are created inside the
# `with` contexts, so statement order here defines the layout.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ColPali: Efficient Document Retrieval with Vision Language Models (ColQwen2) ๐")
    gr.Markdown(
        """Demo to test ColQwen2 (ColPali) on PDF documents.
ColPali is implemented from the [ColPali paper](https://arxiv.org/abs/2407.01449)."""
    )
    with gr.Row():
        # NOTE(review): the upload / URL indexing column is disabled. With it
        # commented out, the only indexed document is whatever the module-level
        # sample preload indexed. Re-enable it together with the two commented
        # wiring lines below.
        # with gr.Column(scale=2):
        # gr.Markdown("## 1๏ธโฃ Load a PDF")
        # pdf_input = gr.File(label="Upload PDF", file_types=[".pdf"])
        # index_btn = gr.Button("๐ฅ Index Uploaded PDF", variant="secondary")
        # url_box = gr.Textbox(
        # label="Or index from URL",
        # placeholder="https://example.com/file.pdf",
        # value="https://sist.sathyabama.ac.in/sist_coursematerial/uploads/SAR1614.pdf",
        # )
        # index_url_btn = gr.Button("๐ Load Sample / From URL", variant="secondary")
        # status_box = gr.Textbox(label="Status", interactive=False)
        # pdf_view = PDF(label="PDF Preview")
        with gr.Column(scale=3):
            gr.Markdown("## 2๏ธโฃ Search")
            # Module-level name `query` holds the Textbox component; it is
            # distinct from the `query` parameter of search()/query_gpt().
            query = gr.Textbox(placeholder="Enter your query here", label="Query")
            # Slider range 1-20 matches the k bounds documented for search().
            k_slider = gr.Slider(minimum=1, maximum=20, step=1, label="Number of results", value=5)
            search_button = gr.Button("๐ Search", variant="primary")
            output_text = gr.Textbox(label="AI Response", placeholder="Generated response based on retrieved documents")
    # Wiring
    # index_btn.click(handle_upload, inputs=[pdf_input], outputs=[status_box, pdf_view])
    # index_url_btn.click(handle_url, inputs=[url_box], outputs=[status_box, pdf_view])
    # search() returns a single string, shown in the one output Textbox.
    search_button.click(search, inputs=[query, k_slider], outputs=[output_text])
if __name__ == "__main__":
    # The sample document is indexed at import time by the module-level preload,
    # so there is nothing left to load here. Queue at most 5 pending events and
    # expose the app's functions (including search) through the MCP server.
    queued_app = demo.queue(max_size=5)
    queued_app.launch(debug=True, mcp_server=True)
|