Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os | |
| import torch | |
| from transformers import AutoProcessor, AutoModel | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| from langchain_community.vectorstores import FAISS | |
| from langchain.chains import RetrievalQA | |
| from langchain_community.llms import HuggingFacePipeline | |
| from PIL import Image | |
| from langchain_community.document_loaders import PyPDFLoader | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
class MultimodalRAG:
    """Retrieval-augmented question answering over a PDF, with optional image context.

    PDF pages are split into overlapping text chunks. The chunks are embedded
    with ``sentence-transformers/all-MiniLM-L6-v2`` and indexed in FAISS.
    Questions are answered by an LLM using a "stuff" chain over the top-2
    retrieved chunks.

    An optional image is encoded with CLIP (``openai/clip-vit-base-patch32``).
    Its features are currently reduced to a fixed placeholder description
    that is appended to the retrieval query.
    """

    # Maximum number of characters of each source chunk returned to callers.
    SOURCE_PREVIEW_CHARS = 1000

    def __init__(self, pdf_path=None):
        """Load all models and, if *pdf_path* exists, index it immediately.

        Args:
            pdf_path: Optional path to a PDF to load at construction time.
                A path that does not exist is silently ignored.
        """
        self.processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.vision_model = AutoModel.from_pretrained("openai/clip-vit-base-patch32")
        self.text_embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )
        self.pdf_path = pdf_path
        self.documents = []
        self.vector_store = None
        self.retriever = None
        self.qa_chain = None
        try:
            # NOTE(review): HuggingFacePipeline.from_model_id forwards
            # model_kwargs to model loading rather than to generate();
            # confirm temperature/max_length actually take effect here.
            self.llm = HuggingFacePipeline.from_model_id(
                model_id="google/flan-t5-large",
                task="text2text-generation",
                model_kwargs={"temperature": 0.7, "max_length": 512},
            )
        except Exception as e:
            print(f"Error loading flan-t5 model: {e}")
            # Fall back to a hosted model. This presumably requires OpenAI
            # credentials in the environment — TODO confirm for deployment.
            from langchain_community.llms import OpenAI

            self.llm = OpenAI(temperature=0.7)
        if pdf_path and os.path.exists(pdf_path):
            self.load_pdf(pdf_path)

    def load_pdf(self, pdf_path):
        """Chunk *pdf_path*, index the chunks, and build the QA chain.

        Everything is built in locals first. If loading fails, the previously
        loaded document (if any) stays usable.

        Args:
            pdf_path: Filesystem path of the PDF to load.

        Returns:
            A status message naming the loaded file.

        Raises:
            FileNotFoundError: if *pdf_path* does not exist.
            ValueError: if the PDF yields no text chunks (e.g. a PDF with no
                text layer), which FAISS cannot index.
        """
        if not os.path.exists(pdf_path):
            raise FileNotFoundError(f"PDF file not found: {pdf_path}")
        pages = PyPDFLoader(pdf_path).load()
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
        )
        chunks = text_splitter.split_documents(pages)
        if not chunks:
            raise ValueError(f"No extractable text found in PDF: {pdf_path}")

        vector_store = FAISS.from_documents(chunks, self.text_embeddings)
        retriever = vector_store.as_retriever(search_kwargs={"k": 2})
        qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=True,
        )

        # Commit only after every step has succeeded.
        self.pdf_path = pdf_path
        self.documents = chunks
        self.vector_store = vector_store
        self.retriever = retriever
        self.qa_chain = qa_chain
        return f"Successfully loaded and processed PDF: {pdf_path}"

    def process_image(self, image_path):
        """Return CLIP image features for *image_path*.

        Returns None, after printing a warning, if the path does not exist.
        """
        if not os.path.exists(image_path):
            print(f"Warning: Image path {image_path} does not exist")
            return None
        # The context manager closes the file handle Image.open keeps open.
        with Image.open(image_path) as image:
            inputs = self.processor(images=image, return_tensors="pt")
        with torch.no_grad():
            image_features = self.vision_model.get_image_features(**inputs)
        return image_features

    def generate_image_description(self, image_features):
        """Placeholder: ignore *image_features* and always return "an image"."""
        return "an image"

    def retrieve_related_documents(self, query_text, image_path=None):
        """Retrieve chunks for *query_text*, optionally enriched by an image.

        When *image_path* yields features, the generated image description is
        appended to the query before retrieval.

        Raises:
            ValueError: if no PDF has been loaded yet.
        """
        if self.retriever is None:
            raise ValueError("No PDF has been loaded; call load_pdf() first.")
        enhanced_query = query_text
        if image_path:
            image_features = self.process_image(image_path)
            if image_features is not None:
                description = self.generate_image_description(image_features)
                enhanced_query = f"{query_text} {description}"
        return self.retriever.get_relevant_documents(enhanced_query)

    @classmethod
    def _preview(cls, text):
        """Cut *text* to SOURCE_PREVIEW_CHARS, appending "..." only if it was cut."""
        if len(text) <= cls.SOURCE_PREVIEW_CHARS:
            return text
        return text[: cls.SOURCE_PREVIEW_CHARS] + "..."

    def answer_query(self, query_text, image_path=None):
        """Answer *query_text* from the loaded PDF.

        Args:
            query_text: The user's question.
            image_path: Optional image path used to enrich retrieval.

        Returns:
            ``(answer, sources)``, where sources are text previews of the
            chunks the answer was based on. If no PDF is loaded, returns a
            prompt message and an empty list.
        """
        if self.vector_store is None or self.qa_chain is None:
            return "Please upload a PDF document first.", []
        docs = self.retrieve_related_documents(query_text, image_path)
        # Answer from the documents already retrieved, which may include the
        # image hint. Calling the chain itself would retrieve a second time
        # using only the plain text query and discard this work.
        answer = self.qa_chain.combine_documents_chain.run(
            input_documents=docs, question=query_text
        )
        sources = [self._preview(doc.page_content) for doc in docs]
        return answer, sources
# Single module-level instance shared by the upload and query callbacks.
# Constructing it loads the CLIP, embedding and LLM models at import time.
rag_system = MultimodalRAG(pdf_path=None)
def upload_pdf(pdf_file):
    """Index an uploaded PDF into the shared RAG system.

    Args:
        pdf_file: Value from the ``gr.File`` component. Depending on the
            Gradio version this is a path string or a tempfile-like object
            exposing the path as ``.name``. May be None.

    Returns:
        A status string for the UI. Errors are reported, not raised.
    """
    if pdf_file is None:
        return "No file uploaded"
    try:
        # Accept both a plain path and an object carrying the path in .name;
        # resolving it inside the try keeps a bad value from crashing the handler.
        file_path = getattr(pdf_file, "name", pdf_file)
        return rag_system.load_pdf(file_path)
    except Exception as e:
        return f"Error processing PDF: {str(e)}"
def save_image(image):
    """Write *image* to a fresh temporary JPEG file and return its path.

    Each call gets a unique file, instead of a fixed ``temp_image.jpg`` in
    the working directory. This stops concurrent requests from overwriting
    or deleting each other's images. The caller must remove the file.

    Args:
        image: A PIL image, or None.

    Returns:
        The temporary file path, or None when *image* is None.
    """
    import tempfile

    if image is None:
        return None
    # JPEG cannot store alpha or palette modes; normalize to RGB first.
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")
    fd, temp_path = tempfile.mkstemp(suffix=".jpg")
    os.close(fd)
    try:
        image.save(temp_path)
    except Exception:
        # Do not leave an empty temp file behind when saving fails.
        os.remove(temp_path)
        raise
    return temp_path
def process_query(query, pdf_file, image=None):
    """Answer *query* against the loaded PDF, optionally using *image*.

    Args:
        query: The question text (None or blank is rejected).
        pdf_file: The ``gr.File`` value; None means nothing was uploaded.
        image: Optional PIL image saved to a temp file for the query.

    Returns:
        ``(answer, sources)`` for the UI. Errors are reported in the answer
        string with an empty source list rather than raised.
    """
    if not query or not query.strip():
        return "Please enter a question", []
    if pdf_file is None:
        return "Please upload a PDF document first", []
    image_path = None
    try:
        if image is not None:
            image_path = save_image(image)
        answer, sources = rag_system.answer_query(query, image_path)
        return answer, sources
    except Exception as e:
        return f"Error processing query: {str(e)}", []
    finally:
        # Remove the temporary image on both the success and failure paths.
        if image_path and os.path.exists(image_path):
            os.remove(image_path)
with gr.Blocks(title="Multimodal RAG System") as demo:
    gr.Markdown("# Multimodal RAG System")
    gr.Markdown("Upload a PDF document and ask questions about it. You can also add an image for multimodal context.")

    with gr.Row():
        # Left column: document upload and indexing status.
        with gr.Column(scale=1):
            pdf_input = gr.File(label="Upload PDF Document")
            upload_button = gr.Button("Process PDF")
            status_output = gr.Textbox(label="Status")

        # Right column: optional image, the question, and the results.
        with gr.Column(scale=2):
            image_input = gr.Image(label="Optional: Upload an Image", type="pil")
            query_input = gr.Textbox(label="Ask a question")
            submit_button = gr.Button("Submit Question")
            answer_output = gr.Textbox(label="Answer")
            sources_output = gr.JSON(label="Sources")

    # Event wiring, registered inside the Blocks context in the same order
    # as the components above; listeners do not affect layout.
    upload_button.click(upload_pdf, [pdf_input], [status_output])
    submit_button.click(
        process_query,
        [query_input, pdf_input, image_input],
        [answer_output, sources_output],
    )
if __name__ == "__main__":
    # Listen on all interfaces and request a public Gradio share link.
    # NOTE(review): this exposes the app publicly; confirm that is intended.
    launch_options = {"share": True, "server_name": "0.0.0.0"}
    demo.launch(**launch_options)