import gradio as gr
import fitz  # PyMuPDF
import torch
import cv2
import os
import tempfile
import shutil
import logging
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
import faiss
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Check CUDA
logger.info(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
# BitsAndBytes config for quantized model loading
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
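# NF4 ("normal float 4") stores the weights in 4 bits while computing in
# float16, which roughly quarters weight memory versus fp16 at a small
# accuracy cost; useful on the modest GPUs that free Spaces provide.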
# Load Qwen model
try:
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Omni-3B", trust_remote_code=True)
    # Note: recent transformers releases ship a dedicated class for the Omni
    # checkpoints (Qwen2_5OmniForConditionalGeneration plus a processor); the
    # AutoModelForCausalLM route below relies on trust_remote_code resolving
    # a usable class and may fail on some versions.
    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen2.5-Omni-3B",
        device_map="auto",
        quantization_config=bnb_config,
        trust_remote_code=True,
    ).eval()
    logger.info("Qwen model loaded.")
except Exception as e:
    logger.error(f"Failed to load Qwen: {e}")
    model, tokenizer = None, None
# Load SentenceTransformer for RAG
try:
    embed_model = SentenceTransformer('paraphrase-MiniLM-L3-v2')
    logger.info("Embedding model loaded.")
except Exception as e:
    logger.error(f"Failed to load embedding model: {e}")
    embed_model = None
# Global index state
chunks = []
index = None
# PDF text chunking
def extract_chunks_from_pdf(pdf_path, chunk_size=1000, overlap=200):
    try:
        with fitz.open(pdf_path) as doc:  # context manager closes the document
            text = "".join(page.get_text() for page in doc)
        return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size - overlap)]
    except Exception as e:
        logger.error(f"PDF error: {e}")
        return ["Error extracting content."]
# Build FAISS index
def build_faiss_index(chunks):
    if embed_model is None:
        return None
    try:
        embeddings = embed_model.encode(chunks, convert_to_numpy=True)
        index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(embeddings)
        return index
    except Exception as e:
        logger.error(f"FAISS index error: {e}")
        return None
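# IndexFlatL2 does exact (brute-force) nearest-neighbour search, which is fine
# at this scale. A sketch of a cosine-similarity variant, if preferred:
#     faiss.normalize_L2(embeddings)                  # in-place unit-length rows
#     index = faiss.IndexFlatIP(embeddings.shape[1])  # inner product == cosine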
# RAG retrieval
def rag_query(query, chunks, index, top_k=3):
    try:
        q_emb = embed_model.encode([query], convert_to_numpy=True)
        D, I = index.search(q_emb, top_k)  # D: squared L2 distances, I: chunk indices
        return "\n\n".join(chunks[i] for i in I[0])
    except Exception as e:
        logger.error(f"RAG query error: {e}")
        return "Error retrieving context."
# Qwen chat
def chat_with_qwen(text, image=None):
    if not model or not tokenizer:
        return "Model not loaded."
    try:
        messages = [{"role": "user", "content": text}]
        if image:
            # Qwen-VL-style interleaved content; the exact schema is defined
            # by the checkpoint's remote code.
            messages[0]["content"] = [{"image": image}, {"text": text}]
        # `.chat()` is the convenience helper some Qwen checkpoints ship via
        # trust_remote_code; it is not part of the core transformers API.
        response, _ = model.chat(tokenizer, messages, history=None)
        return response
    except Exception as e:
        logger.error(f"Chat error: {e}")
        return f"Chat error: {e}"
# Extract representative frames
def extract_video_frames(video_path, max_frames=2):
    try:
        cap = cv2.VideoCapture(video_path)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            cap.release()
            return []
        frame_indices = [int(i * total_frames / max_frames) for i in range(max_frames)]
        frames = []
        for idx in frame_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            success, frame = cap.read()
            if success:
                frames.append(frame)
        cap.release()
        return frames
    except Exception as e:
        logger.error(f"Frame extraction error: {e}")
        return []
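# Frames are sampled at evenly spaced positions: with max_frames=2 and a
# 300-frame clip, the indices are [0, 150], i.e. the first frame and the
# midpoint of the video.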
# Multimodal chat logic
def multimodal_chat(message, history, image=None, video=None, pdf=None):
    global chunks, index
    if not model:
        return "Model not available."
    try:
        # PDF + question
        if pdf and message:
            pdf_path = pdf.name if hasattr(pdf, 'name') else (pdf if isinstance(pdf, str) else None)
            if not pdf_path:
                return "Invalid PDF input."
            chunks = extract_chunks_from_pdf(pdf_path)
            index = build_faiss_index(chunks)
            if index is not None:
                context = rag_query(message, chunks, index)
                user_prompt = f"Context:\n{context}\n\nQuestion: {message}"
                return chat_with_qwen(user_prompt)
            else:
                return "Failed to process PDF."
        # Image + question
        if image and message:
            return chat_with_qwen(message, image)
        # Video + question
        if video and message:
            with tempfile.TemporaryDirectory() as temp_dir:
                video_path = os.path.join(temp_dir, "video.mp4")
                shutil.copy(video.name if hasattr(video, 'name') else video, video_path)
                frames = extract_video_frames(video_path)
                if not frames:
                    return "Could not extract video frames."
                temp_img_path = os.path.join(temp_dir, "frame.jpg")
                # OpenCV reads and writes BGR, so the frame is saved as-is;
                # converting to RGB before imwrite would swap the channels.
                cv2.imwrite(temp_img_path, frames[0])
                return chat_with_qwen(message, temp_img_path)
        # Text only
        if message:
            return chat_with_qwen(message)
        return "Please enter a question and optionally upload a file."
    except Exception as e:
        logger.error(f"Chat error: {e}")
        return f"Error: {e}"
# Gradio UI
with gr.Blocks(css="""
body { background-color: #f3f6fc; }
.gradio-container { font-family: 'Segoe UI', sans-serif; }
h1 {
    background: linear-gradient(to right, #667eea, #764ba2);
    color: white !important;
    padding: 1rem; border-radius: 12px; margin-bottom: 0.5rem;
}
.gr-box {
    background-color: white; border-radius: 12px;
    box-shadow: 0 0 10px rgba(0,0,0,0.05); padding: 16px;
}
footer { display: none !important; }
""") as demo:
| gr.Markdown(""" | |
| <h1 style='text-align: center;'>Multimodal Chatbot powered by Qwen-2.5-Omni-3B</h1> | |
| <p style='text-align: center;'>Ask your own questions with optional image, video, or PDF context.</p> | |
| """) | |
| chatbot = gr.Chatbot(show_label=False, height=450) | |
| state = gr.State([]) | |
| with gr.Row(): | |
| txt = gr.Textbox(show_label=False, placeholder="Type your question...", scale=5) | |
| send_btn = gr.Button("🚀 Send", scale=1) | |
| with gr.Row(): | |
| image_input = gr.Image(type="filepath", label="Upload Image") | |
| video_input = gr.Video(label="Upload Video") | |
| pdf_input = gr.File(file_types=[".pdf"], label="Upload PDF") | |
| def user_send(message, history, image, video, pdf): | |
| if not message and not image and not video and not pdf: | |
| return "", history, history | |
| response = multimodal_chat(message, history, image, video, pdf) | |
| history.append((message, response)) | |
| return "", history, history | |
| send_btn.click(user_send, [txt, state, image_input, video_input, pdf_input], [txt, chatbot, state]) | |
| txt.submit(user_send, [txt, state, image_input, video_input, pdf_input], [txt, chatbot, state]) | |
| logger.info("Launching Gradio app") | |
| demo.launch() | |