import os
import io
import gc
import base64

import torch
import gradio as gr
import numpy as np
import pandas as pd
import pymupdf
from PIL import Image
from pypdf import PdfReader
from dotenv import load_dotenv
from huggingface_hub.utils import HfHubHTTPError

from welcome_text import WELCOME_INTRO
from doctr.io import DocumentFile
from doctr.models import ocr_predictor
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration

import chromadb
from chromadb.utils import embedding_functions
from chromadb.utils.data_loaders import ImageLoader

from langchain_core.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEndpoint

from utils import extract_pdfs, extract_images, clean_text, image_to_bytes

# ─────────────────────────────────────────────────────────────────────────────
# Load .env
load_dotenv()
HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")

# OCR + multimodal image description setup
ocr_model = ocr_predictor(
    "db_resnet50", "crnn_mobilenet_v3_large", pretrained=True, assume_straight_pages=True
)

processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
vision_model = LlavaNextForConditionalGeneration.from_pretrained(
    "llava-hf/llava-v1.6-mistral-7b-hf",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
).to("cpu")

def get_image_description(image: Image.Image) -> str:
    """Generate a one-sentence description of `image` via LlavaNext."""
    torch.cuda.empty_cache()
    gc.collect()
    prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
    inputs = processor(text=prompt, images=image, return_tensors="pt").to("cpu")
    output = vision_model.generate(**inputs, max_new_tokens=100)
    return processor.decode(output[0], skip_special_tokens=True)
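
# Illustrative usage of the default captioner (kept as a comment so nothing heavy
# runs at import time; the file name below is hypothetical):
#
#   page_img = Image.open("sample_page.png").convert("RGB")
#   print(get_image_description(page_img))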

# Vector DB setup
# (chromadb, the text splitter, and image_to_bytes are imported at the top of the file)

# 1) Create one shared embedding function (all-MiniLM-L6-v2, 384-dim)
SHARED_EMB_FN = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="all-MiniLM-L6-v2"
)

def get_vectordb(text: str, images: list[Image.Image], img_names: list[str]):
    """
    Build an in-memory ChromaDB instance with two collections:
      • text_db  (chunks of the PDF text)
      • image_db (image descriptions + raw image bytes)
    Returns the Chroma client for later querying.
    """
    # ——— 1) Init & wipe old ————————————————
    client = chromadb.EphemeralClient()
    existing = [getattr(c, "name", c) for c in client.list_collections()]
    for col in ("text_db", "image_db"):
        if col in existing:
            client.delete_collection(col)

    # ——— 2) Create fresh collections —————————
    text_col = client.get_or_create_collection(
        name="text_db",
        embedding_function=SHARED_EMB_FN,
        data_loader=ImageLoader(),  # loader only matters for images, benign here
    )
    img_col = client.get_or_create_collection(
        name="image_db",
        embedding_function=SHARED_EMB_FN,
        metadata={"hnsw:space": "cosine"},
        data_loader=ImageLoader(),
    )

    # ——— 3) Add images if any ———————————————
    if images:
        descs = []
        metas = []
        for idx, img in enumerate(images):
            # build a one-line caption (or fall back on failure)
            try:
                caption = get_image_description(img)
            except Exception:
                caption = "⚠️ could not describe image"
            descs.append(f"{img_names[idx]}: {caption}")
            metas.append({"image": image_to_bytes(img)})
        img_col.add(
            ids=[str(i) for i in range(len(images))],
            documents=descs,
            metadatas=metas,
        )

    # ——— 4) Chunk & add text ———————————————
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    docs = splitter.create_documents([text])
    text_col.add(
        ids=[str(i) for i in range(len(docs))],
        documents=[d.page_content for d in docs],
    )
    return client
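
# Sketch of querying the client returned above (the query string is illustrative;
# collection names match those created in get_vectordb):
#
#   client = get_vectordb("some extracted text", images=[], img_names=[])
#   hits = client.get_collection("text_db").query(
#       query_texts=["What is the document about?"], n_results=3
#   )
#   print(hits["documents"][0])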

# Text extraction
def result_to_text(result, as_text=False):
    """Flatten a doctr OCR result into cleaned page strings (or one joined string)."""
    pages = []
    for pg in result.pages:
        txt = " ".join(w.value for block in pg.blocks for line in block.lines for w in line.words)
        pages.append(clean_text(txt))
    return "\n\n".join(pages) if as_text else pages
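
# Sketch of running the module-level OCR pipeline on a PDF and flattening the result
# (the path is hypothetical):
#
#   doc = DocumentFile.from_pdf("example.pdf")
#   full_text = result_to_text(ocr_model(doc), as_text=True)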

OCR_CHOICES = {
    # doctr detection + recognition model pairs
    "db_resnet50 + crnn_mobilenet_v3_large": ("db_resnet50", "crnn_mobilenet_v3_large"),
    "db_resnet50 + crnn_vgg16_bn": ("db_resnet50", "crnn_vgg16_bn"),
}

def extract_data_from_pdfs(
    docs,
    session,
    include_images,   # "Include Images" or "Exclude Images"
    do_ocr,           # "Get Text With OCR" or "Get Available Text Only"
    ocr_choice,       # key into OCR_CHOICES
    vlm_choice,       # HF repo ID for LlavaNext
    progress=gr.Progress(),
):
    """
    1) Dynamically instantiate the chosen OCR pipeline (if any)
    2) Dynamically instantiate the chosen vision-language model
    3) Override the global get_image_description to use that model for captions
    4) Extract text & images, index into ChromaDB
    """
    if not docs:
        raise gr.Error("No documents to process")

    # ——— 1) Set up OCR if requested ————————————————
    if do_ocr == "Get Text With OCR":
        db_m, crnn_m = OCR_CHOICES[ocr_choice]
        local_ocr = ocr_predictor(db_m, crnn_m, pretrained=True, assume_straight_pages=True)
    else:
        local_ocr = None

    # ——— 2) Set up vision-language model —————————————
    proc = LlavaNextProcessor.from_pretrained(vlm_choice)
    vis = LlavaNextForConditionalGeneration.from_pretrained(
        vlm_choice,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    ).to("cpu")

    # ——— 3) Monkey-patch global get_image_description ————
    def describe(img: Image.Image) -> str:
        torch.cuda.empty_cache()
        gc.collect()
        prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
        inputs = proc(text=prompt, images=img, return_tensors="pt").to("cpu")
        output = vis.generate(**inputs, max_new_tokens=100)
        return proc.decode(output[0], skip_special_tokens=True)

    global get_image_description
    get_image_description = describe

    # ——— 4) Extract text & images ————————————————
    progress(0.2, "Extracting text and images…")
    all_text, images, names = "", [], []
    for path in docs:
        if local_ocr:
            pdf = DocumentFile.from_pdf(path)
            res = local_ocr(pdf)
            all_text += result_to_text(res, as_text=True) + "\n\n"
        else:
            reader = PdfReader(path)
            txt = "\n\n".join((page.extract_text() or "") for page in reader.pages)
            all_text += "\n\n" + txt + "\n\n"
        if include_images == "Include Images":
            imgs = extract_images([path])
            images.extend(imgs)
            names.extend([os.path.basename(path)] * len(imgs))

    # ——— 5) Index into vector DB ————————————————
    progress(0.6, "Indexing in vector DB…")
    vdb = get_vectordb(all_text, images, names)
    session["processed"] = True
    sample_imgs = images[:4] if include_images == "Include Images" else []
    return (
        vdb,
        session,
        gr.Row(visible=True),
        all_text[:2000] + "...",
        sample_imgs,
        "<h3>Done!</h3>",
    )
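
# The function above is normally driven by the "Extract" button in the UI below; a
# direct call would look roughly like this (a sketch — the path is hypothetical and
# the default gr.Progress() only reports progress inside a Gradio event):
#
#   vdb, sess, _, preview, sample_imgs, status = extract_data_from_pdfs(
#       ["example.pdf"], {}, "Exclude Images", "Get Available Text Only",
#       "db_resnet50 + crnn_mobilenet_v3_large", "llava-hf/llava-v1.6-mistral-7b-hf",
#   )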

# Chat function
def conversation(
    vdb, question: str, num_ctx, img_ctx,
    history: list, temp: float, max_tok: int, model_id: str
):
    # 0) Cast the context sliders to ints
    num_ctx = int(num_ctx)
    img_ctx = int(img_ctx)

    # 1) Guard: must have extracted first
    if vdb is None:
        raise gr.Error("Please extract data first")

    # 2) Instantiate the chosen HF endpoint
    llm = HuggingFaceEndpoint(
        repo_id=model_id,
        temperature=temp,
        max_new_tokens=max_tok,
        huggingfacehub_api_token=HF_TOKEN,
    )

    # 3) Query text collection
    text_col = vdb.get_collection("text_db")
    docs = text_col.query(
        query_texts=[question],
        n_results=num_ctx,
        include=["documents"],
    )["documents"][0]

    # 4) Query image collection
    img_col = vdb.get_collection("image_db")
    img_q = img_col.query(
        query_texts=[question],
        n_results=img_ctx,
        include=["metadatas", "documents"],
    )

    # Decode stored base64 image bytes back into PIL images
    images, img_descs = [], img_q["documents"][0] or ["No images found"]
    for meta in img_q["metadatas"][0]:
        b64 = meta.get("image", "")
        try:
            images.append(Image.open(io.BytesIO(base64.b64decode(b64))))
        except Exception:
            pass
    img_desc = "\n".join(img_descs)

    # 5) Build prompt
    prompt = PromptTemplate(
        template="""
Context:
{text}

Included Images:
{img_desc}

Question:
{q}

Answer:
""",
        input_variables=["text", "img_desc", "q"],
    )
    context = "\n\n".join(docs)
    user_input = prompt.format(text=context, img_desc=img_desc, q=question)

    # 6) Call the model with error handling
    try:
        answer = llm.invoke(user_input)
    except HfHubHTTPError as e:
        if e.response.status_code == 404:
            answer = f"❌ Model `{model_id}` not hosted on HF Inference API."
        else:
            answer = f"⚠️ HF API error: {e}"
    except Exception as e:
        answer = f"⚠️ Unexpected error: {e}"

    # 7) Append to history
    new_history = history + [
        {"role": "user", "content": question},
        {"role": "assistant", "content": answer},
    ]

    # 8) Return updated history, docs, images
    return new_history, docs, images
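
# The chat history uses OpenAI-style message dicts, matching gr.Chatbot(type="messages")
# below. A direct call would look roughly like this (sketch; assumes `vdb` came from
# extract_data_from_pdfs and HF_TOKEN is set):
#
#   history, ctx_docs, ctx_imgs = conversation(
#       vdb, "Summarize the document", num_ctx=3, img_ctx=2,
#       history=[], temp=0.4, max_tok=200, model_id="HuggingFaceH4/zephyr-7b-beta",
#   )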

# ─────────────────────────────────────────────────────────────────────────────
# Gradio UI
CSS = """
footer {visibility:hidden;}
"""

MODEL_OPTIONS = [
    "HuggingFaceH4/zephyr-7b-beta",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "openchat/openchat-3.5-0106",
    "google/gemma-7b-it",
    "deepseek-ai/deepseek-llm-7b-chat",
    "microsoft/Phi-3-mini-4k-instruct",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "Qwen/Qwen1.5-7B-Chat",
    "tiiuae/falcon-7b-instruct",   # Falcon 7B Instruct
    "bigscience/bloomz-7b1",       # BLOOMZ 7B
    "facebook/opt-2.7b",
]

with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
    vdb_state = gr.State()
    session_state = gr.State({})

    # ─── Welcome Screen ─────────────────────────────────────────────
    with gr.Column(visible=True) as welcome_col:
        gr.Markdown(
            f"<div style='text-align: center'>\n{WELCOME_INTRO}\n</div>",
            elem_id="welcome_md",
        )
        start_btn = gr.Button("🚀 Start")

    # ─── Main App (hidden until Start is clicked) ───────────────────
    with gr.Column(visible=False) as app_col:
        gr.Markdown("## 📚 Multimodal Chat-PDF Playground")
        with gr.Tabs():
            # Tab 1: Upload & Extract
            with gr.TabItem("1. Upload & Extract"):
                docs = gr.File(
                    file_count="multiple",
                    file_types=[".pdf"],
                    label="Upload PDFs",
                )
                include_dd = gr.Radio(
                    ["Include Images", "Exclude Images"],
                    value="Exclude Images",
                    label="Images",
                )
                ocr_mode_dd = gr.Radio(
                    ["Get Text With OCR", "Get Available Text Only"],
                    value="Get Available Text Only",
                    label="OCR",
                )
                ocr_dd = gr.Dropdown(
                    choices=list(OCR_CHOICES.keys()),
                    value="db_resnet50 + crnn_mobilenet_v3_large",
                    label="OCR Model",
                )
                vlm_dd = gr.Dropdown(
                    # LLaVA-NeXT checkpoints compatible with LlavaNextProcessor
                    choices=[
                        "llava-hf/llava-v1.6-mistral-7b-hf",
                        "llava-hf/llava-v1.6-vicuna-7b-hf",
                    ],
                    value="llava-hf/llava-v1.6-mistral-7b-hf",
                    label="Vision-Language Model",
                )
                extract_btn = gr.Button("Extract")
                preview_text = gr.Textbox(lines=10, label="Sample Text", interactive=False)
                preview_img = gr.Gallery(label="Sample Images", rows=2, value=[])
                with gr.Row(visible=False) as done_row:
                    done_html = gr.HTML()

                extract_btn.click(
                    extract_data_from_pdfs,
                    inputs=[docs, session_state, include_dd, ocr_mode_dd, ocr_dd, vlm_dd],
                    outputs=[vdb_state, session_state, done_row, preview_text, preview_img, done_html],
                )
            # Tab 2: Chat
            with gr.TabItem("2. Chat"):
                with gr.Row():
                    with gr.Column(scale=3):
                        chat = gr.Chatbot(type="messages", label="Chat")
                        msg = gr.Textbox(
                            placeholder="Ask about your PDF...",
                            label="Your question",
                        )
                        send = gr.Button("Send")
                    with gr.Column(scale=1):
                        model_dd = gr.Dropdown(
                            MODEL_OPTIONS,
                            value=MODEL_OPTIONS[0],
                            label="Choose Chat Model",
                        )
                        num_ctx = gr.Slider(1, 20, value=3, label="Text Contexts")
                        img_ctx = gr.Slider(1, 10, value=2, label="Image Contexts")
                        temp = gr.Slider(0.1, 1.0, step=0.1, value=0.4, label="Temperature")
                        max_tok = gr.Slider(10, 1000, step=10, value=200, label="Max Tokens")

                ctx_df = gr.Dataframe()
                rel_gallery = gr.Gallery(label="Relevant Images", rows=2, value=[])

                send.click(
                    conversation,
                    inputs=[vdb_state, msg, num_ctx, img_ctx, chat, temp, max_tok, model_dd],
                    outputs=[chat, ctx_df, rel_gallery],
                )
        # Footer inside app_col
        gr.HTML("<center>Made with ❤️ by Zamal</center>")

    # ─── Wire the Start button ──────────────────────────────────────
    start_btn.click(
        fn=lambda: (gr.update(visible=False), gr.update(visible=True)),
        inputs=[], outputs=[welcome_col, app_col],
    )

if __name__ == "__main__":
    demo.launch()