File size: 10,400 Bytes
94b55f0
a246b15
5697c10
a246b15
5697c10
94b55f0
602d806
527f685
602d806
89cecf3
3649694
602d806
 
 
5dfd724
9b1e831
9c66171
5697c10
 
 
 
 
 
 
b9c715a
5697c10
 
 
 
9b1e831
 
5697c10
 
 
 
 
e7f8afe
9b1e831
9c66171
5697c10
 
 
 
6efb913
 
 
 
 
73f30e5
5697c10
 
73f30e5
a129662
 
5697c10
 
a129662
 
5697c10
 
 
 
 
 
 
 
 
 
 
cd6d6e9
d76e71e
511619b
a129662
5697c10
 
cd6d6e9
 
f24b36e
5697c10
 
 
 
8777d11
a129662
8e6343b
40ef153
 
5697c10
 
0d01d71
6efb913
5697c10
 
 
 
 
ec28a2a
 
5697c10
 
 
 
 
 
 
f1d7f41
5697c10
 
0d01d71
5697c10
 
 
 
ec28a2a
5697c10
 
 
9b1e831
602d806
 
068f2e8
602d806
9b1e831
602d806
a2d6d06
5697c10
602d806
9c66171
602d806
 
5697c10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
603af32
5697c10
f24b36e
5697c10
 
f24b36e
 
5697c10
 
 
 
603af32
5697c10
 
 
 
 
 
 
f24b36e
5697c10
 
 
 
 
 
 
 
 
4fd529e
 
5697c10
 
 
 
 
 
602d806
5697c10
 
 
0d01d71
4fd529e
 
5697c10
 
 
 
 
602d806
5697c10
 
4fd529e
a315ebf
5697c10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75f63ac
59741c0
75f63ac
5697c10
 
 
dad1e49
d546c80
5697c10
 
9c9913c
5697c10
9357d80
0d01d71
75f63ac
 
 
 
 
 
 
 
 
 
 
 
602d806
0d01d71
 
 
59741c0
5697c10
 
fa73ad0
5697c10
75f63ac
 
a315ebf
602d806
 
5697c10
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
import os
import base64
import tempfile
from io import BytesIO
from urllib.request import urlretrieve

import gradio as gr
from gradio_pdf import PDF
import torch

from pdf2image import convert_from_path
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm

from colpali_engine.models import ColQwen2, ColQwen2Processor

# -----------------------------
# Globals
# -----------------------------
api_key = os.getenv("OPENAI_API_KEY", "")  # <- use env var
ds = []          # list of document embeddings (torch tensors)
images = []      # list of PIL images (page-order)
current_pdf_path = None  # last (indexed) pdf path for preview

# -----------------------------
# Model & processor
# -----------------------------
device_map = "cuda:0" if torch.cuda.is_available() else ("mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available() else "cpu")

model = ColQwen2.from_pretrained(
    "vidore/colqwen2-v1.0",
    torch_dtype=torch.bfloat16,
    device_map=device_map,
    attn_implementation="flash_attention_2"
).eval()
processor = ColQwen2Processor.from_pretrained("vidore/colqwen2-v1.0")


# -----------------------------
# Utilities
# -----------------------------
def encode_image_to_base64(image: Image.Image) -> str:
    """Encodes a PIL image to a base64 string."""
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def query_gpt(query: str, retrieved_images: list[tuple[Image.Image, str]]) -> str:
    """Calls OpenAI's GPT model with the query and image data."""
    if api_key and api_key.startswith("sk"):
        try:
            from openai import OpenAI

            base64_images = [encode_image_to_base64(im_caption[0]) for im_caption in retrieved_images]
            client = OpenAI(api_key=api_key.strip())
            PROMPT = """
You are a smart assistant designed to answer questions about a PDF document.
You are given relevant information in the form of PDF pages. Use them to construct a short response to the question, and cite your sources (page numbers, etc).
If it is not possible to answer using the provided pages, do not attempt to provide an answer and simply say the answer is not present within the documents.
Give detailed and extensive answers, only containing info in the pages you are given.
You can answer using information contained in plots and figures if necessary.
Answer in the same language as the query.

Query: {query}
PDF pages:
""".strip()

            response = client.responses.create(
                model="gpt-5-mini",
                input=[
                    {
                        "role": "user",
                        "content": (
                            [{"type": "input_text", "text": PROMPT.format(query=query)}] +
                            [{"type": "input_image",
                              "image_url": f"data:image/jpeg;base64,{im}"}
                             for im in base64_images]
                        )
                    }
                ],
                # max_tokens=500,
            )
            return response.output_text
        except Exception as e:
            print(e)
            return "OpenAI API connection failure. Verify that OPENAI_API_KEY is set and valid (sk-***)."
    return "Set OPENAI_API_KEY in your environment to get a custom response."


def _ensure_model_device():
    dev = "cuda:0" if torch.cuda.is_available() else ("mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available() else "cpu")
    if str(model.device) != dev:
        model.to(dev)
    return dev


# -----------------------------
# Indexing helpers
# -----------------------------
def convert_files(pdf_path: str) -> list[Image.Image]:
    """Convert a single PDF path into a list of PIL Images (pages)."""
    imgs = convert_from_path(pdf_path, thread_count=4)
    if len(imgs) >= 500:
        raise gr.Error("The number of images in the dataset should be less than 500.")
    return imgs


def index_gpu(imgs: list[Image.Image]) -> str:
    """Embed a list of images (pages) with ColPali and store in globals."""
    global ds, images
    device = _ensure_model_device()

    # reset previous dataset
    ds = []
    images = imgs

    dataloader = DataLoader(
        images,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda x: processor.process_images(x).to(model.device),
    )

    for batch_doc in tqdm(dataloader, desc="Indexing pages"):
        with torch.no_grad():
            batch_doc = {k: v.to(device) for k, v in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
        ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
    return f"Indexed {len(images)} pages successfully."


def index_from_path(pdf_path: str) -> str:
    """Public: index a local PDF file path."""
    imgs = convert_files(pdf_path)
    return index_gpu(imgs)


def index_from_url(url: str) -> tuple[str, str]:
    """
    Download a PDF from URL and index it.

    Returns:
        status message, saved pdf path
    """
    tmp_dir = tempfile.mkdtemp(prefix="colpali_")
    local_path = os.path.join(tmp_dir, "document.pdf")
    urlretrieve(url, local_path)
    status = index_from_path(local_path)
    return status, local_path


# -----------------------------
# Search (MCP tool-friendly)
# -----------------------------
def search(query: str, k: int = 5):
    """
    Search within a PDF document for the most relevant pages to answer a query and synthetizes a short grounded answer using only those pages.

    MCP tool description:
      - name: mcp_test_search
      - description: Search within a PDF document for the most relevant pages to answer a query and synthetizes a short grounded answer using only those pages.
      - input_schema:
          type: object
          properties:
            query: {type: string, description: "User query in natural language."}
            k: {type: integer, minimum: 1, maximum: 20, default: 5. description: "Number of top pages to retrieve."}
          required: ["query"]

    Args:
        query (str): Natural-language question to search for.
        k (int): Number of top results to return (1โ€“10).

    Returns:
        ai_response (str): Text answer to the query grounded in content from the PDF, with citations (page numbers).
    """
    global ds, images

    if not images or not ds:
        return [], "No document indexed yet. Upload a PDF or load the sample, then run Search."

    k = max(1, min(int(k), len(images)))
    device = _ensure_model_device()

    print(query)

    # Encode query
    qs = []
    with torch.no_grad():
        batch_query = processor.process_queries([query]).to(model.device)
        embeddings_query = model(**batch_query)
        qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))

    # Score and select top-k
    scores = processor.score(qs, ds, device=device)
    top_k_indices = scores[0].topk(k).indices.tolist()

    print(top_k_indices)

    # Build gallery results with 1-based page numbering
    results = []
    for idx in top_k_indices:
        page_num = idx + 1
        results.append((images[idx], f"Page {page_num}"))

    # Generate grounded response
    ai_response = query_gpt(query, results)
    print(ai_response)
    return ai_response


# -----------------------------
# Gradio UI callbacks
# -----------------------------
def handle_upload(file) -> tuple[str, str | None]:
    """Index a user-uploaded PDF file."""
    global current_pdf_path
    if file is None:
        return "Please upload a PDF.", None
    path = getattr(file, "name", file)
    status = index_from_path(path)
    current_pdf_path = path
    return status, path


def handle_url(url: str) -> tuple[str, str | None]:
    """Index a PDF from URL (e.g., a sample)."""
    global current_pdf_path
    if not url or not url.lower().endswith(".pdf"):
        return "Please provide a direct PDF URL.", None
    status, path = index_from_url(url)
    current_pdf_path = path
    return status, path


print("Uploading")
print(handle_url("https://www.ipcc.ch/report/ar6/syr/downloads/report/IPCC_AR6_SYR_SPM.pdf"))

# -----------------------------
# Gradio App
# -----------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ColPali: Efficient Document Retrieval with Vision Language Models (ColQwen2) ๐Ÿ“š")
    gr.Markdown(
        """Demo to test ColQwen2 (ColPali) on PDF documents.  
ColPali is implemented from the [ColPali paper](https://arxiv.org/abs/2407.01449)."""
    )

    with gr.Row():
        # with gr.Column(scale=2):
        #     gr.Markdown("## 1๏ธโƒฃ Load a PDF")
        #     pdf_input = gr.File(label="Upload PDF", file_types=[".pdf"])
        #     index_btn = gr.Button("๐Ÿ“ฅ Index Uploaded PDF", variant="secondary")
        #     url_box = gr.Textbox(
        #         label="Or index from URL",
        #         placeholder="https://example.com/file.pdf",
        #         value="https://sist.sathyabama.ac.in/sist_coursematerial/uploads/SAR1614.pdf",
        #     )
        #     index_url_btn = gr.Button("๐ŸŒ Load Sample / From URL", variant="secondary")
        #     status_box = gr.Textbox(label="Status", interactive=False)
        #     pdf_view = PDF(label="PDF Preview")

        with gr.Column(scale=3):
            gr.Markdown("## 2๏ธโƒฃ Search")
            query = gr.Textbox(placeholder="Enter your query here", label="Query")
            k_slider = gr.Slider(minimum=1, maximum=20, step=1, label="Number of results", value=5)
            search_button = gr.Button("๐Ÿ” Search", variant="primary")
            output_text = gr.Textbox(label="AI Response", placeholder="Generated response based on retrieved documents")

    # Wiring
    # index_btn.click(handle_upload, inputs=[pdf_input], outputs=[status_box, pdf_view])
    # index_url_btn.click(handle_url, inputs=[url_box], outputs=[status_box, pdf_view])
    search_button.click(search, inputs=[query, k_slider], outputs=[output_text])

if __name__ == "__main__":
    # Optional: pre-load the default sample at startup.
    # Comment these two lines if you prefer a "cold" start.
    # msg, path = index_from_url("https://sist.sathyabama.ac.in/sist_coursematerial/uploads/SAR1614.pdf")
    # print(msg, "->", path)

    demo.queue(max_size=5).launch(debug=True, mcp_server=True)