#!/usr/bin/env python3
"""
OncoLife Symptom & Triage Assistant

A medical chatbot that performs both symptom assessment and clinical triage
for chemotherapy patients.

Updated: Using BioMistral-7B base model for medical conversations.
REBUILD: Simplified to use only base model, no adapters.
RAG: Added document retrieval capabilities for PDFs and other reference materials (optional).
"""

import gradio as gr
import os
import json
from pathlib import Path
from transformers import AutoTokenizer, MistralForCausalLM
import torch
from spaces import GPU

# RAG imports (optional)
try:
    import chromadb
    from sentence_transformers import SentenceTransformer
    import PyPDF2
    import pdfplumber
    from langchain.text_splitter import RecursiveCharacterTextSplitter
    from langchain.embeddings import HuggingFaceEmbeddings
    import fitz  # PyMuPDF for better PDF handling

    RAG_AVAILABLE = True
except ImportError:
    print("⚠️ RAG libraries not available, running in instruction-only mode")
    RAG_AVAILABLE = False


# Force GPU detection for HF Spaces
@GPU
def force_gpu_detection():
    """Force GPU detection for Hugging Face Spaces"""
    return torch.cuda.is_available()


class OncoLifeAssistant:
    def __init__(self):
        # BioMistral base model configuration
        BASE = "BioMistral/BioMistral-7B"

        print("🔄 Initializing OncoLife Symptom & Triage Assistant")
        print(f"📦 Loading base model: {BASE}")

        # Force GPU detection first
        try:
            gpu_available = force_gpu_detection()
            print(f"🖥️ GPU Detection: {gpu_available}")
        except Exception as e:
            print(f"⚠️ GPU detection error: {e}")
            gpu_available = torch.cuda.is_available()

        self._load_model(BASE, gpu_available)

        # Load the OncoLife instructions
        self._load_instructions()

        # Initialize RAG system (optional)
        self.rag_enabled = False
        if RAG_AVAILABLE:
            try:
                self._initialize_rag()
                self.rag_enabled = True
                print("✅ RAG system initialized successfully")
            except Exception as e:
                print(f"⚠️ RAG initialization failed: {e}")
                print("🔄 Continuing with instruction-only mode")
        else:
            print("🔄 Running in instruction-only mode (no RAG)")

    def _load_instructions(self):
        """Load the OncoLife instructions from the text file"""
        try:
            instructions_file = Path(__file__).parent / "oncolifebot_instructions.txt"
            if instructions_file.exists():
                with open(instructions_file, 'r') as f:
                    self.instructions = f.read()
                print("✅ Loaded oncolifebot_instructions.txt")
            else:
                print("⚠️ oncolifebot_instructions.txt not found")
                self.instructions = ""
        except Exception as e:
            print(f"❌ Error loading instructions: {e}")
            self.instructions = ""

    def _initialize_rag(self):
        """Initialize the RAG system with document embeddings (lightweight version)"""
        try:
            print("🔍 Initializing lightweight RAG system...")

            # Use a smaller embedding model
            self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
            print("✅ Loaded embedding model")

            # Initialize ChromaDB with persistence disabled for memory efficiency
            self.chroma_client = chromadb.Client()
            self.collection = self.chroma_client.create_collection(
                name="oncolife_documents",
                metadata={"description": "OncoLife reference documents"}
            )
            print("✅ Initialized ChromaDB collection")

            # Load and process documents (limited to essential files)
            self._load_documents_lightweight()

        except Exception as e:
            print(f"❌ Error initializing RAG: {e}")
            self.embedding_model = None
            self.collection = None
            raise e
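    # Hypothetical convenience helper (a sketch only; nothing in this file calls it,
    # and the name _index_single_text is not part of the original app). It shows how
    # one extra reference string could be pushed through the same
    # chunk -> embed -> store pipeline that _load_documents_lightweight uses below.
    def _index_single_text(self, raw_text, source_name):
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=100,
            separators=["\n\n", "\n", ". ", " ", ""]
        )
        chunks = splitter.split_text(raw_text)  # list of ~500-character strings
        self._add_chunks_to_db(chunks, source_name)
        return len(chunks)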
    def _load_documents_lightweight(self):
        """Load only essential documents to save memory"""
        try:
            docs_path = Path(__file__).parent / "guideline-docs"
            print(f"📚 Loading essential documents from: {docs_path}")

            if not docs_path.exists():
                print("⚠️ guideline-docs directory not found")
                return

            # Text splitter for chunking documents
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=500,  # Smaller chunks to save memory
                chunk_overlap=100,
                separators=["\n\n", "\n", ". ", " ", ""]
            )

            documents_loaded = 0

            # Process PDF files (essential medical guidelines)
            for pdf_file in docs_path.glob("*.pdf"):
                try:
                    print(f"📄 Processing PDF: {pdf_file.name}")
                    text = self._extract_pdf_text(pdf_file)
                    if text:
                        chunks = text_splitter.split_text(text)
                        self._add_chunks_to_db(chunks, pdf_file.name)
                        documents_loaded += 1
                        print(f"✅ Added {len(chunks)} chunks from {pdf_file.name}")
                    else:
                        print(f"⚠️ No text extracted from {pdf_file.name}")
                except Exception as e:
                    print(f"❌ Error processing {pdf_file.name}: {e}")

            # Process JSON files (lightweight)
            for json_file in docs_path.glob("*.json"):
                try:
                    print(f"📄 Processing JSON: {json_file.name}")
                    with open(json_file, 'r') as f:
                        data = json.load(f)
                    # Convert JSON to text representation
                    text = json.dumps(data, indent=2)
                    chunks = text_splitter.split_text(text)
                    self._add_chunks_to_db(chunks, json_file.name)
                    documents_loaded += 1
                    print(f"✅ Added {len(chunks)} chunks from {json_file.name}")
                except Exception as e:
                    print(f"❌ Error processing {json_file.name}: {e}")

            # Process text files (lightweight)
            for txt_file in docs_path.glob("*.txt"):
                try:
                    print(f"📄 Processing TXT: {txt_file.name}")
                    with open(txt_file, 'r', encoding='utf-8') as f:
                        text = f.read()
                    chunks = text_splitter.split_text(text)
                    self._add_chunks_to_db(chunks, txt_file.name)
                    documents_loaded += 1
                    print(f"✅ Added {len(chunks)} chunks from {txt_file.name}")
                except Exception as e:
                    print(f"❌ Error processing {txt_file.name}: {e}")

            print(f"✅ RAG system initialized with {documents_loaded} documents")

        except Exception as e:
            print(f"❌ Error loading documents: {e}")

    def _extract_pdf_text(self, pdf_path):
        """Extract text from PDF using multiple methods"""
        try:
            # Try PyMuPDF first (better for complex PDFs)
            try:
                doc = fitz.open(pdf_path)
                text = ""
                for page in doc:
                    text += page.get_text()
                doc.close()
                if text.strip():
                    return text
            except Exception as e:
                print(f"PyMuPDF failed for {pdf_path.name}: {e}")

            # Fallback to pdfplumber
            try:
                with pdfplumber.open(pdf_path) as pdf:
                    text = ""
                    for page in pdf.pages:
                        if page.extract_text():
                            text += page.extract_text() + "\n"
                    return text
            except Exception as e:
                print(f"pdfplumber failed for {pdf_path.name}: {e}")

            # Final fallback to PyPDF2
            try:
                with open(pdf_path, 'rb') as file:
                    reader = PyPDF2.PdfReader(file)
                    text = ""
                    for page in reader.pages:
                        text += page.extract_text() + "\n"
                    return text
            except Exception as e:
                print(f"PyPDF2 failed for {pdf_path.name}: {e}")

            return None

        except Exception as e:
            print(f"❌ Error extracting text from {pdf_path.name}: {e}")
            return None

    def _add_chunks_to_db(self, chunks, source_name):
        """Add document chunks to the vector database"""
        try:
            if not chunks or not self.collection:
                return

            # Generate embeddings
            embeddings = self.embedding_model.encode(chunks)

            # Add to ChromaDB
            self.collection.add(
                embeddings=embeddings.tolist(),
                documents=chunks,
                metadatas=[{"source": source_name, "chunk_id": i} for i in range(len(chunks))],
                ids=[f"{source_name}_chunk_{i}" for i in range(len(chunks))]
            )

        except Exception as e:
            print(f"❌ Error adding chunks to database: {e}")
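    # For orientation only: chromadb's Collection.query() returns a dict of nested
    # lists, one inner list per query string. The values shown here are illustrative,
    # but the shape is what _retrieve_relevant_documents below indexes into:
    #
    #   {
    #       "ids":       [["guideline.pdf_chunk_0", ...]],
    #       "documents": [["chunk text ...", ...]],
    #       "metadatas": [[{"source": "guideline.pdf", "chunk_id": 0}, ...]],
    #       "distances": [[0.31, ...]],
    #   }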
    def _retrieve_relevant_documents(self, query, top_k=3):
        """Retrieve relevant document chunks for a query"""
        try:
            if not self.collection or not self.embedding_model or not self.rag_enabled:
                return []

            # Generate query embedding
            query_embedding = self.embedding_model.encode([query])

            # Search for similar documents
            results = self.collection.query(
                query_embeddings=query_embedding.tolist(),
                n_results=top_k
            )

            # Format results
            relevant_docs = []
            if results['documents']:
                for i, doc in enumerate(results['documents'][0]):
                    relevant_docs.append({
                        'content': doc,
                        'source': results['metadatas'][0][i]['source'],
                        'similarity': results['distances'][0][i] if 'distances' in results else None
                    })

            return relevant_docs

        except Exception as e:
            print(f"❌ Error retrieving documents: {e}")
            return []

    def _load_model(self, model_id, gpu_available):
        """Load the BioMistral base model with memory optimization"""
        try:
            print("🔄 Loading BioMistral base model...")

            # Determine device strategy
            if gpu_available and torch.cuda.is_available():
                device = "cuda"
                dtype = torch.float16
                print("🖥️ Loading BioMistral model on GPU...")
            else:
                device = "cpu"
                dtype = torch.float32
                print("💻 Loading BioMistral model on CPU...")

            # Load tokenizer
            print(f"📝 Loading tokenizer: {model_id}")
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                trust_remote_code=True
            )

            # Load the model with memory optimization
            print(f"📦 Loading model: {model_id}")
            self.model = MistralForCausalLM.from_pretrained(
                model_id,
                trust_remote_code=True,
                device_map="auto",
                torch_dtype=dtype,
                low_cpu_mem_usage=True,  # Add memory optimization
                max_memory={0: "8GB", "cpu": "16GB"} if gpu_available else {"cpu": "8GB"}
            )

            # Add pad token if not present
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            print(f"✅ BioMistral base model loaded successfully on {device.upper()}!")

        except Exception as e:
            print(f"❌ Error loading BioMistral model: {e}")
            self.model = None
            self.tokenizer = None
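    # Prompt layout used by generate_oncolife_response below (documentation only;
    # whitespace is approximate). Everything is concatenated into one flat string and
    # the text after the final "Assistant:" marker is returned as the reply:
    #
    #   You are the OncoLife Symptom & Triage Assistant. Follow these instructions exactly:
    #   <contents of oncolifebot_instructions.txt>
    #   <optional "Relevant Reference Information" chunks retrieved via RAG>
    #   Current user input: <message>
    #
    #   Conversation History:
    #   User: ... / Assistant: ...
    #   User: <message>
    #   Assistant: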
    def generate_oncolife_response(self, user_input, conversation_history):
        """Generate response using OncoLife instructions and optional RAG"""
        try:
            if self.model is None or self.tokenizer is None:
                return """❌ **Model Loading Error**

The OncoLife assistant model failed to load. This could be due to:
1. Model not available
2. Memory constraints
3. Network issues

Please check the Space logs for details."""

            print(f"🔄 Generating OncoLife response for: {user_input}")

            # Retrieve relevant documents using RAG (if available)
            context_text = ""
            if self.rag_enabled:
                try:
                    relevant_docs = self._retrieve_relevant_documents(user_input, top_k=2)
                    if relevant_docs:
                        context_text = "\n\n**Relevant Reference Information:**\n"
                        for i, doc in enumerate(relevant_docs):
                            context_text += f"\n--- Source: {doc['source']} ---\n{doc['content'][:300]}...\n"
                except Exception as e:
                    print(f"⚠️ RAG retrieval failed: {e}")

            # Create prompt using the loaded instructions and retrieved context
            system_prompt = f"""You are the OncoLife Symptom & Triage Assistant. Follow these instructions exactly:

{self.instructions}

{context_text}

Current user input: {user_input}"""

            # Format conversation history
            history_text = ""
            if conversation_history:
                for entry in conversation_history:
                    history_text += f"User: {entry['user']}\nAssistant: {entry['assistant']}\n\n"

            # Create full prompt
            prompt = f"{system_prompt}\n\nConversation History:\n{history_text}\nUser: {user_input}\nAssistant:"

            # Tokenize
            inputs = self.tokenizer(prompt, return_tensors="pt", padding=True)

            # Get the device the model is actually on
            model_device = next(self.model.parameters()).device
            print(f"🔧 Model device: {model_device}")

            # Move inputs to the same device as the model
            for key in inputs:
                if isinstance(inputs[key], torch.Tensor):
                    inputs[key] = inputs[key].to(model_device)
                    print(f"📦 Moved {key} to {model_device}")

            # Ensure model is in eval mode
            self.model.eval()

            # Generate with proper device handling
            with torch.no_grad():
                try:
                    outputs = self.model.generate(
                        **inputs,
                        max_new_tokens=512,  # Longer responses for detailed medical assessment
                        temperature=0.7,
                        do_sample=True,
                        top_p=0.9,
                        pad_token_id=self.tokenizer.eos_token_id,
                        eos_token_id=self.tokenizer.eos_token_id
                    )
                except RuntimeError as e:
                    if "device" in str(e).lower():
                        print("🔄 Device error detected, trying CPU fallback...")
                        # Move everything to CPU and try again
                        self.model = self.model.to("cpu")
                        for key in inputs:
                            if isinstance(inputs[key], torch.Tensor):
                                inputs[key] = inputs[key].to("cpu")
                        outputs = self.model.generate(
                            **inputs,
                            max_new_tokens=512,
                            temperature=0.7,
                            do_sample=True,
                            top_p=0.9,
                            pad_token_id=self.tokenizer.eos_token_id,
                            eos_token_id=self.tokenizer.eos_token_id
                        )
                    else:
                        raise e

            # Decode response
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Extract just the assistant's response
            if "Assistant:" in response:
                answer = response.split("Assistant:")[-1].strip()
            else:
                answer = response.strip()

            print("✅ OncoLife response generated successfully")
            return answer

        except Exception as e:
            print(f"❌ Error generating OncoLife response: {e}")
            return f"""❌ **Generation Error**

Error: {str(e)}

This could be due to:
1. Model compatibility issues
2. Memory constraints
3. Input format problems

Please try a simpler question or check the logs for more details."""

    def chat(self, message, history):
        """Main chat interface for OncoLife Assistant"""
        if not message.strip():
            return "Please describe your symptoms or concerns."

        # Convert history to the format expected by generate_oncolife_response
        conversation_history = []
        for user_msg, assistant_msg in history:
            conversation_history.append({
                "user": user_msg,
                "assistant": assistant_msg
            })

        # Generate response using OncoLife instructions and optional RAG
        response = self.generate_oncolife_response(message, conversation_history)
        return response


# Create interface
assistant = OncoLifeAssistant()

interface = gr.ChatInterface(
    fn=assistant.chat,
    title="🏥 OncoLife Symptom & Triage Assistant",
    description=(
        "I'm here to help assess your symptoms and determine if you need to contact "
        "your care team. I can access your medical guidelines and reference documents "
        "to provide accurate information."
    ),
    examples=[
        ["I'm feeling nauseous and tired"],
        ["I have a fever of 101"],
        ["My neuropathy is getting worse"],
        ["I'm having trouble eating"],
        ["I feel dizzy and lightheaded"]
    ],
    theme=gr.themes.Soft()
)

if __name__ == "__main__":
    print("=" * 60)
    print("OncoLife Symptom & Triage Assistant")
    print("=" * 60)
    interface.launch(server_name="0.0.0.0", server_port=7860)
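# Quick manual check (a sketch; assumes this file is saved as app.py, which the file
# itself does not state). With dependencies installed, the assistant can be exercised
# without the Gradio UI from a Python shell:
#
#   from app import assistant
#   print(assistant.chat("I have a fever of 101", history=[]))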