import os
import io
import base64
import gc
import shutil

import gradio as gr
import numpy as np
import pandas as pd
import pymupdf
import spaces
import torch
from dotenv import load_dotenv
from PIL import Image
from pypdf import PdfReader

import chromadb
from chromadb.config import Settings, DEFAULT_TENANT, DEFAULT_DATABASE
from chromadb.utils import embedding_functions
from chromadb.utils.data_loaders import ImageLoader

from doctr.io import DocumentFile
from doctr.models import ocr_predictor
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration

from huggingface_hub.utils import HfHubHTTPError
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEndpoint

from utils import extract_pdfs, extract_images, clean_text, image_to_bytes
from welcome_text import WELCOME_INTRO
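# Pipeline overview (matches the functions defined below):
#   1. Upload PDFs → extract text (pypdf or doctr OCR) and embedded images.
#   2. Caption each image with a LLaVA-Next vision-language model.
#   3. Index text chunks and image captions in a persistent ChromaDB store.
#   4. At chat time, retrieve the top text/image contexts and ask a hosted
#      Hugging Face model to answer from them.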
# ─────────────────────────────────────────────────────────────────────────────
# Load .env (expects HUGGINGFACE_TOKEN for the hosted-inference calls)
load_dotenv()
HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")

# OCR + multimodal image description setup.
# The doctr OCR pipeline is cheap enough to build at import time; the Llava
# processor/model are deliberately *not* loaded here — they are lazy-loaded
# inside get_image_description() so the GPU is only touched in the ZeroGPU worker.
processor = None
vision_model = None

ocr_model = ocr_predictor(
    "db_resnet50", "crnn_mobilenet_v3_large", pretrained=True, assume_straight_pages=True
)

# On-disk ChromaDB location; wiped at startup so every run starts with a clean index.
PERSIST_DIR = "./chroma_db"
if os.path.exists(PERSIST_DIR):
    shutil.rmtree(PERSIST_DIR)
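# NOTE: get_vectordb() below also recreates this folder on every extraction,
# so the index never outlives the current session of the Space.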
def get_image_description(image: Image.Image) -> str:
    """
    Lazy-loads the Llava processor + model inside the GPU worker,
    runs captioning, and returns a one-sentence description.
    """
    global processor, vision_model

    # On first call, instantiate + move to CUDA
    if processor is None or vision_model is None:
        processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
        vision_model = LlavaNextForConditionalGeneration.from_pretrained(
            "llava-hf/llava-v1.6-mistral-7b-hf",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        ).to("cuda")

    torch.cuda.empty_cache()
    gc.collect()

    prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
    inputs = processor(images=image, text=prompt, return_tensors="pt").to("cuda")
    output = vision_model.generate(**inputs, max_new_tokens=100)
    return processor.decode(output[0], skip_special_tokens=True)
# Vector DB setup

# One shared embedding function (all-MiniLM-L6-v2, 384-dim) used by both collections
SHARED_EMB_FN = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="all-MiniLM-L6-v2"
)
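# Both collections share this text embedding function: image captions are
# embedded as plain text, so a text query can retrieve matching images too.
# The raw image bytes live in collection metadata; image_to_bytes() is assumed
# to return a base64 string, since conversation() decodes it with base64.b64decode().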
def get_vectordb(text: str, images: list[Image.Image], img_names: list[str]):
    """
    Build a *persistent* ChromaDB instance on disk, with two collections:
      • text_db  (chunks of the PDF text)
      • image_db (image descriptions + raw image bytes)
    """
    # 1) Make or clean the on-disk folder
    shutil.rmtree(PERSIST_DIR, ignore_errors=True)
    os.makedirs(PERSIST_DIR, exist_ok=True)

    # 2) Open a persistent client against that folder
    client = chromadb.PersistentClient(
        path=PERSIST_DIR,
        settings=Settings(),
        tenant=DEFAULT_TENANT,
        database=DEFAULT_DATABASE,
    )

    # 3) Create / wipe collections
    for col in ("text_db", "image_db"):
        if col in [c.name for c in client.list_collections()]:
            client.delete_collection(col)

    text_col = client.get_or_create_collection(
        name="text_db",
        embedding_function=SHARED_EMB_FN,
    )
    img_col = client.get_or_create_collection(
        name="image_db",
        embedding_function=SHARED_EMB_FN,
        metadata={"hnsw:space": "cosine"},
    )

    # 4) Add images: caption becomes the document, raw bytes go into metadata
    if images:
        descs, metas = [], []
        for idx, img in enumerate(images):
            try:
                cap = get_image_description(img)
            except Exception:
                cap = "⚠️ could not describe image"
            descs.append(f"{img_names[idx]}: {cap}")
            metas.append({"image": image_to_bytes(img)})
        img_col.add(
            ids=[str(i) for i in range(len(images))],
            documents=descs,
            metadatas=metas,
        )

    # 5) Chunk & add text
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    docs = splitter.create_documents([text])
    text_col.add(
        ids=[str(i) for i in range(len(docs))],
        documents=[d.page_content for d in docs],
    )

    return client
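# The returned client is not kept around; conversation() later reopens the same
# on-disk store with chromadb.PersistentClient(path=...), which is why the
# persist directory is stashed in the Gradio session state.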
# Text extraction
def result_to_text(result, as_text=False):
    """Flatten a doctr OCR result into cleaned per-page text."""
    pages = []
    for pg in result.pages:
        txt = " ".join(w.value for block in pg.blocks for line in block.lines for w in line.words)
        pages.append(clean_text(txt))
    return "\n\n".join(pages) if as_text else pages


OCR_CHOICES = {
    "db_resnet50 + crnn_mobilenet_v3_large": ("db_resnet50", "crnn_mobilenet_v3_large"),
    "db_resnet50 + crnn_resnet31": ("db_resnet50", "crnn_resnet31"),
}
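# OCR_CHOICES maps the labels shown in the UI dropdown to doctr
# (detection_model, recognition_model) pairs passed to ocr_predictor().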
# ZeroGPU: GPU work (Llava captioning) happens inside this call
@spaces.GPU
def extract_data_from_pdfs(
    docs: list[str],
    session: dict,
    include_images: str,
    do_ocr: str,
    ocr_choice: str,
    vlm_choice: str,
    progress=gr.Progress(),
):
    if not docs:
        raise gr.Error("No documents to process")

    # 1) OCR pipeline if requested
    if do_ocr == "Get Text With OCR":
        db_m, crnn_m = OCR_CHOICES[ocr_choice]
        local_ocr = ocr_predictor(db_m, crnn_m, pretrained=True, assume_straight_pages=True)
    else:
        local_ocr = None

    # 2) Vision–language model selected in the UI
    proc = LlavaNextProcessor.from_pretrained(vlm_choice)
    vis = (
        LlavaNextForConditionalGeneration
        .from_pretrained(vlm_choice, torch_dtype=torch.float16, low_cpu_mem_usage=True)
        .to("cuda")
    )

    # 3) Monkey-patch the caption fn so indexing uses the chosen VLM
    def describe(img):
        torch.cuda.empty_cache()
        gc.collect()
        prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
        inp = proc(images=img, text=prompt, return_tensors="pt").to("cuda")
        out = vis.generate(**inp, max_new_tokens=100)
        return proc.decode(out[0], skip_special_tokens=True)

    global get_image_description
    get_image_description = describe

    # 4) Extract text & images
    progress(0.2, "Extracting text and images…")
    all_text = ""
    images, names = [], []
    for path in docs:
        if local_ocr:
            pdf = DocumentFile.from_pdf(path)
            res = local_ocr(pdf)
            all_text += result_to_text(res, as_text=True) + "\n\n"
        else:
            # Pull the embedded text layer from every page of the PDF
            reader = PdfReader(path)
            all_text += "\n\n".join((pg.extract_text() or "") for pg in reader.pages) + "\n\n"
        if include_images == "Include Images":
            imgs = extract_images([path])
            images.extend(imgs)
            names.extend([os.path.basename(path)] * len(imgs))

    # 5) Build + persist the vector DB
    progress(0.6, "Indexing in vector DB…")
    client = get_vectordb(all_text, images, names)

    # 6) Mark session and return UI outputs
    session["processed"] = True
    session["persist_directory"] = PERSIST_DIR
    sample_imgs = images[:4] if include_images == "Include Images" else []
    return (
        session,                  # gr.State
        all_text[:2000] + "...",  # text preview
        sample_imgs,              # gallery preview
        "<h3>Done!</h3>",
    )
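# NOTE: describe() above closes over the user-selected processor/model and
# replaces the module-level get_image_description, so the captions indexed by
# get_vectordb() come from the VLM chosen in the UI rather than the default.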
# Chat function
def conversation(
    session: dict,
    question: str,
    num_ctx: int,
    img_ctx: int,
    history: list,
    temp: float,
    max_tok: int,
    model_id: str,
):
    persist_dir = session.get("persist_directory")
    if not session.get("processed") or not persist_dir:
        raise gr.Error("Please extract data first")

    # 1) Reopen the same persistent client (new API)
    client = chromadb.PersistentClient(
        path=persist_dir,
        settings=Settings(),
        tenant=DEFAULT_TENANT,
        database=DEFAULT_DATABASE,
    )

    # 2) Text retrieval
    text_col = client.get_collection("text_db")
    docs = text_col.query(
        query_texts=[question],
        n_results=int(num_ctx),
        include=["documents"],
    )["documents"][0]

    # 3) Image retrieval: captions for the prompt, decoded bytes for the gallery
    img_col = client.get_collection("image_db")
    img_q = img_col.query(
        query_texts=[question],
        n_results=int(img_ctx),
        include=["metadatas", "documents"],
    )
    img_descs = img_q["documents"][0] or ["No images found"]
    images = []
    for meta in img_q["metadatas"][0]:
        b64 = meta.get("image", "")
        try:
            images.append(Image.open(io.BytesIO(base64.b64decode(b64))))
        except Exception:
            pass
    img_desc = "\n".join(img_descs)

    # 4) Build prompt & call LLM
    llm = HuggingFaceEndpoint(
        repo_id=model_id,
        task="text-generation",
        temperature=temp,
        max_new_tokens=max_tok,
        huggingfacehub_api_token=HF_TOKEN,
    )
    prompt = PromptTemplate(
        template="""
Context:
{text}
Included Images:
{img_desc}
Question:
{q}
Answer:
""",
        input_variables=["text", "img_desc", "q"],
    )
    inp = prompt.format(text="\n\n".join(docs), img_desc=img_desc, q=question)

    try:
        answer = llm.invoke(inp)
    except HfHubHTTPError as e:
        answer = "❌ Model not hosted" if e.response.status_code == 404 else f"⚠️ HF error: {e}"
    except Exception as e:
        answer = f"⚠️ Unexpected error: {e}"

    new_history = history + [
        {"role": "user", "content": question},
        {"role": "assistant", "content": answer},
    ]
    return new_history, docs, images
# ─────────────────────────────────────────────────────────────────────────────
# Gradio UI
CSS = """
footer {visibility:hidden;}
"""

MODEL_OPTIONS = [
    "HuggingFaceH4/zephyr-7b-beta",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "openchat/openchat-3.5-0106",
    "google/gemma-7b-it",
    "deepseek-ai/deepseek-llm-7b-chat",
    "microsoft/Phi-3-mini-4k-instruct",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "Qwen/Qwen1.5-7B-Chat",
    "tiiuae/falcon-7b-instruct",  # Falcon 7B Instruct
    "bigscience/bloomz-7b1",      # BLOOMZ 7B
    "facebook/opt-2.7b",
]
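# These repo ids are called through HuggingFaceEndpoint (hosted inference); if a
# model is not hosted, the 404 handler in conversation() reports "Model not hosted"
# instead of crashing the chat.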
with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
    session_state = gr.State({})

    with gr.Column(visible=True) as welcome_col:
        gr.Markdown(f"<div style='text-align:center'>{WELCOME_INTRO}</div>")
        start_btn = gr.Button("🚀 Start")

    with gr.Column(visible=False) as app_col:
        gr.Markdown("## 📚 Multimodal Chat-PDF Playground")
        extract_event = None

        with gr.Tabs() as tabs:
            with gr.TabItem("1. Upload & Extract"):
                docs = gr.File(file_count="multiple", file_types=[".pdf"], label="Upload PDFs")
                include_dd = gr.Radio(
                    ["Include Images", "Exclude Images"], value="Exclude Images", label="Images"
                )
                ocr_radio = gr.Radio(
                    ["Get Text With OCR", "Get Available Text Only"],
                    value="Get Available Text Only",
                    label="OCR",
                )
                ocr_dd = gr.Dropdown(
                    list(OCR_CHOICES.keys()),
                    value=list(OCR_CHOICES.keys())[0],
                    label="OCR Model",
                )
                vlm_dd = gr.Dropdown(
                    ["llava-hf/llava-v1.6-mistral-7b-hf", "llava-hf/llava-v1.5-mistral-7b"],
                    value="llava-hf/llava-v1.6-mistral-7b-hf",
                    label="Vision-Language Model",
                )
                extract_btn = gr.Button("Extract")
                preview_text = gr.Textbox(lines=10, label="Sample Text", interactive=False)
                preview_img = gr.Gallery(label="Sample Images", rows=2, value=[])
                preview_html = gr.HTML()

                extract_event = extract_btn.click(
                    fn=extract_data_from_pdfs,
                    inputs=[docs, session_state, include_dd, ocr_radio, ocr_dd, vlm_dd],
                    outputs=[session_state, preview_text, preview_img, preview_html],
                )

            with gr.TabItem("2. Chat", visible=False) as chat_tab:
                with gr.Row():
                    with gr.Column(scale=3):
                        chat = gr.Chatbot(type="messages", label="Chat")
                        msg = gr.Textbox(placeholder="Ask about your PDF...", label="Your question")
                        send = gr.Button("Send")
                    with gr.Column(scale=1):
                        model_dd = gr.Dropdown(MODEL_OPTIONS, value=MODEL_OPTIONS[0], label="Choose Chat Model")
                        num_ctx = gr.Slider(1, 20, value=3, label="Text Contexts")
                        img_ctx = gr.Slider(1, 10, value=2, label="Image Contexts")
                        temp = gr.Slider(0.1, 1.0, step=0.1, value=0.4, label="Temperature")
                        max_tok = gr.Slider(10, 1000, step=10, value=200, label="Max Tokens")

                # Components that receive the retrieved contexts alongside the chat
                retrieved_txt = gr.Dataframe(label="Retrieved Text")
                retrieved_imgs = gr.Gallery(label="Relevant Images", rows=2, value=[])

                send.click(
                    fn=conversation,
                    inputs=[session_state, msg, num_ctx, img_ctx, chat, temp, max_tok, model_dd],
                    outputs=[chat, retrieved_txt, retrieved_imgs],
                )

        # Unhide the Chat tab once extraction completes
        extract_event.then(
            fn=lambda: gr.update(visible=True),
            inputs=[],
            outputs=[chat_tab],
        )

        gr.HTML("<center>Made with ❤️ by Zamal</center>")

    start_btn.click(
        fn=lambda: (gr.update(visible=False), gr.update(visible=True)),
        outputs=[welcome_col, app_col],
    )


if __name__ == "__main__":
    demo.launch()