#!/usr/bin/env python3
"""
OncoLife Symptom & Triage Assistant

A medical chatbot that performs both symptom assessment and clinical triage for chemotherapy patients.

Updated: Using BioMistral-7B base model for medical conversations.
REBUILD: Simplified to use only base model, no adapters.
RAG: Added document retrieval capabilities for PDFs and other reference materials (optional).
"""

import gradio as gr
import os
import json
from pathlib import Path
from transformers import AutoTokenizer, MistralForCausalLM
import torch
from spaces import GPU

# RAG imports (optional)
try:
    import chromadb
    from sentence_transformers import SentenceTransformer
    import PyPDF2
    import pdfplumber
    from langchain.text_splitter import RecursiveCharacterTextSplitter
    from langchain.embeddings import HuggingFaceEmbeddings
    import fitz  # PyMuPDF for better PDF handling
    RAG_AVAILABLE = True
except ImportError:
    print("⚠️ RAG libraries not available, running in instruction-only mode")
    RAG_AVAILABLE = False


# Force GPU detection for HF Spaces
def force_gpu_detection():
    """Force GPU detection for Hugging Face Spaces"""
    return torch.cuda.is_available()


class OncoLifeAssistant:
    def __init__(self):
        # BioMistral base model configuration
        BASE = "BioMistral/BioMistral-7B"
        print("Initializing OncoLife Symptom & Triage Assistant")
        print(f"Loading base model: {BASE}")

        # Force GPU detection first
        try:
            gpu_available = force_gpu_detection()
            print(f"GPU Detection: {gpu_available}")
        except Exception as e:
            print(f"⚠️ GPU detection error: {e}")
            gpu_available = torch.cuda.is_available()

        self._load_model(BASE, gpu_available)

        # Load the OncoLife instructions
        self._load_instructions()

        # Initialize RAG system (optional)
        self.rag_enabled = False
        if RAG_AVAILABLE:
            try:
                self._initialize_rag()
                self.rag_enabled = True
                print("✅ RAG system initialized successfully")
            except Exception as e:
                print(f"⚠️ RAG initialization failed: {e}")
                print("Continuing with instruction-only mode")
        else:
            print("Running in instruction-only mode (no RAG)")

    def _load_instructions(self):
        """Load the OncoLife instructions from the text file"""
        try:
            instructions_file = Path(__file__).parent / "oncolifebot_instructions.txt"
            if instructions_file.exists():
                with open(instructions_file, 'r', encoding='utf-8') as f:
                    self.instructions = f.read()
                print("✅ Loaded oncolifebot_instructions.txt")
            else:
                print("⚠️ oncolifebot_instructions.txt not found")
                self.instructions = ""
        except Exception as e:
            print(f"❌ Error loading instructions: {e}")
            self.instructions = ""

    def _initialize_rag(self):
        """Initialize the RAG system with document embeddings (lightweight version)"""
        try:
            print("Initializing lightweight RAG system...")

            # Use a smaller embedding model
            self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
            print("✅ Loaded embedding model")

            # Initialize ChromaDB with persistence disabled for memory efficiency
            self.chroma_client = chromadb.Client()
            self.collection = self.chroma_client.create_collection(
                name="oncolife_documents",
                metadata={"description": "OncoLife reference documents"}
            )
            print("✅ Initialized ChromaDB collection")

            # Load and process documents (limited to essential files)
            self._load_documents_lightweight()
        except Exception as e:
            print(f"❌ Error initializing RAG: {e}")
            self.embedding_model = None
            self.collection = None
            raise

    def _load_documents_lightweight(self):
        """Load only essential documents to save memory"""
        try:
            docs_path = Path(__file__).parent / "guideline-docs"
            print(f"Loading essential documents from: {docs_path}")
            if not docs_path.exists():
                print("⚠️ guideline-docs directory not found")
                return

            # Text splitter for chunking documents
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=500,  # Smaller chunks to save memory
                chunk_overlap=100,
                separators=["\n\n", "\n", ". ", " ", ""]
            )

            documents_loaded = 0

            # Process PDF files (essential medical guidelines)
            for pdf_file in docs_path.glob("*.pdf"):
                try:
                    print(f"Processing PDF: {pdf_file.name}")
                    text = self._extract_pdf_text(pdf_file)
                    if text:
                        chunks = text_splitter.split_text(text)
                        self._add_chunks_to_db(chunks, pdf_file.name)
                        documents_loaded += 1
                        print(f"✅ Added {len(chunks)} chunks from {pdf_file.name}")
                    else:
                        print(f"⚠️ No text extracted from {pdf_file.name}")
                except Exception as e:
                    print(f"❌ Error processing {pdf_file.name}: {e}")

            # Process JSON files (lightweight)
            for json_file in docs_path.glob("*.json"):
                try:
                    print(f"Processing JSON: {json_file.name}")
                    with open(json_file, 'r') as f:
                        data = json.load(f)
                    # Convert JSON to text representation
                    text = json.dumps(data, indent=2)
                    chunks = text_splitter.split_text(text)
                    self._add_chunks_to_db(chunks, json_file.name)
                    documents_loaded += 1
                    print(f"✅ Added {len(chunks)} chunks from {json_file.name}")
                except Exception as e:
                    print(f"❌ Error processing {json_file.name}: {e}")

            # Process text files (lightweight)
            for txt_file in docs_path.glob("*.txt"):
                try:
                    print(f"Processing TXT: {txt_file.name}")
                    with open(txt_file, 'r', encoding='utf-8') as f:
                        text = f.read()
                    chunks = text_splitter.split_text(text)
                    self._add_chunks_to_db(chunks, txt_file.name)
                    documents_loaded += 1
                    print(f"✅ Added {len(chunks)} chunks from {txt_file.name}")
                except Exception as e:
                    print(f"❌ Error processing {txt_file.name}: {e}")

            print(f"✅ RAG system initialized with {documents_loaded} documents")
        except Exception as e:
            print(f"❌ Error loading documents: {e}")

    def _extract_pdf_text(self, pdf_path):
        """Extract text from PDF using multiple methods"""
        try:
            # Try PyMuPDF first (better for complex PDFs)
            try:
                doc = fitz.open(pdf_path)
                text = ""
                for page in doc:
                    text += page.get_text()
                doc.close()
                if text.strip():
                    return text
            except Exception as e:
                print(f"PyMuPDF failed for {pdf_path.name}: {e}")

            # Fallback to pdfplumber
            try:
                with pdfplumber.open(pdf_path) as pdf:
                    text = ""
                    for page in pdf.pages:
                        page_text = page.extract_text()
                        if page_text:
                            text += page_text + "\n"
                    return text
            except Exception as e:
                print(f"pdfplumber failed for {pdf_path.name}: {e}")

            # Final fallback to PyPDF2
            try:
                with open(pdf_path, 'rb') as file:
                    reader = PyPDF2.PdfReader(file)
                    text = ""
                    for page in reader.pages:
                        text += (page.extract_text() or "") + "\n"
                    return text
            except Exception as e:
                print(f"PyPDF2 failed for {pdf_path.name}: {e}")

            return None
        except Exception as e:
            print(f"❌ Error extracting text from {pdf_path.name}: {e}")
            return None

    def _add_chunks_to_db(self, chunks, source_name):
        """Add document chunks to the vector database"""
        try:
            if not chunks or not self.collection:
                return

            # Generate embeddings
            embeddings = self.embedding_model.encode(chunks)

            # Add to ChromaDB
            self.collection.add(
                embeddings=embeddings.tolist(),
                documents=chunks,
                metadatas=[{"source": source_name, "chunk_id": i} for i in range(len(chunks))],
                ids=[f"{source_name}_chunk_{i}" for i in range(len(chunks))]
            )
        except Exception as e:
            print(f"❌ Error adding chunks to database: {e}")

    def _retrieve_relevant_documents(self, query, top_k=3):
        """Retrieve relevant document chunks for a query"""
        try:
            if not self.collection or not self.embedding_model or not self.rag_enabled:
                return []

            # Generate query embedding
            query_embedding = self.embedding_model.encode([query])

            # Search for similar documents
            results = self.collection.query(
                query_embeddings=query_embedding.tolist(),
                n_results=top_k
            )

            # Format results
            relevant_docs = []
            if results['documents']:
                for i, doc in enumerate(results['documents'][0]):
                    relevant_docs.append({
                        'content': doc,
                        'source': results['metadatas'][0][i]['source'],
                        'similarity': results['distances'][0][i] if 'distances' in results else None
                    })
            return relevant_docs
        except Exception as e:
            print(f"❌ Error retrieving documents: {e}")
            return []

    def _load_model(self, model_id, gpu_available):
        """Load the BioMistral base model with memory optimization"""
        try:
            print("Loading BioMistral base model...")

            # Determine device strategy
            if gpu_available and torch.cuda.is_available():
                device = "cuda"
                dtype = torch.float16
                print("Loading BioMistral model on GPU...")
            else:
                device = "cpu"
                dtype = torch.float32
                print("Loading BioMistral model on CPU...")

            # Load tokenizer
            print(f"Loading tokenizer: {model_id}")
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                trust_remote_code=True
            )

            # Load the model with memory optimization
            print(f"Loading model: {model_id}")
            self.model = MistralForCausalLM.from_pretrained(
                model_id,
                trust_remote_code=True,
                device_map="auto",
                torch_dtype=dtype,
                low_cpu_mem_usage=True,
                # Cap per-device memory usage
                max_memory={0: "8GB", "cpu": "16GB"} if gpu_available else {"cpu": "8GB"}
            )

            # Add pad token if not present
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            print(f"✅ BioMistral base model loaded successfully on {device.upper()}!")
        except Exception as e:
            print(f"❌ Error loading BioMistral model: {e}")
            self.model = None
            self.tokenizer = None

    def generate_oncolife_response(self, user_input, conversation_history):
        """Generate response using OncoLife instructions and optional RAG"""
        try:
            if self.model is None or self.tokenizer is None:
                return """❌ **Model Loading Error**
The OncoLife assistant model failed to load. This could be due to:
1. Model not available
2. Memory constraints
3. Network issues
Please check the Space logs for details."""

            print(f"Generating OncoLife response for: {user_input}")

            # Retrieve relevant documents using RAG (if available)
            context_text = ""
            if self.rag_enabled:
                try:
                    relevant_docs = self._retrieve_relevant_documents(user_input, top_k=2)
                    if relevant_docs:
                        context_text = "\n\n**Relevant Reference Information:**\n"
                        for i, doc in enumerate(relevant_docs):
                            context_text += f"\n--- Source: {doc['source']} ---\n{doc['content'][:300]}...\n"
                except Exception as e:
                    print(f"⚠️ RAG retrieval failed: {e}")

            # Create prompt using the loaded instructions and retrieved context
            system_prompt = f"""You are the OncoLife Symptom & Triage Assistant. Follow these instructions exactly:
{self.instructions}
{context_text}
Current user input: {user_input}"""

            # Format conversation history
            history_text = ""
            if conversation_history:
                for entry in conversation_history:
                    history_text += f"User: {entry['user']}\nAssistant: {entry['assistant']}\n\n"

            # Create full prompt
            prompt = f"{system_prompt}\n\nConversation History:\n{history_text}\nUser: {user_input}\nAssistant:"

            # Tokenize
            inputs = self.tokenizer(prompt, return_tensors="pt", padding=True)

            # Get the device the model is actually on
            model_device = next(self.model.parameters()).device
            print(f"Model device: {model_device}")

            # Move inputs to the same device as the model
            for key in inputs:
                if isinstance(inputs[key], torch.Tensor):
                    inputs[key] = inputs[key].to(model_device)
                    print(f"Moved {key} to {model_device}")

            # Ensure model is in eval mode
            self.model.eval()

            # Generate with proper device handling
            with torch.no_grad():
                try:
                    outputs = self.model.generate(
                        **inputs,
                        max_new_tokens=512,  # Longer responses for detailed medical assessment
                        temperature=0.7,
                        do_sample=True,
                        top_p=0.9,
                        pad_token_id=self.tokenizer.eos_token_id,
                        eos_token_id=self.tokenizer.eos_token_id
                    )
                except RuntimeError as e:
                    if "device" in str(e).lower():
                        print("Device error detected, trying CPU fallback...")
                        # Move everything to CPU and try again
                        self.model = self.model.to("cpu")
                        for key in inputs:
                            if isinstance(inputs[key], torch.Tensor):
                                inputs[key] = inputs[key].to("cpu")
                        outputs = self.model.generate(
                            **inputs,
                            max_new_tokens=512,
                            temperature=0.7,
                            do_sample=True,
                            top_p=0.9,
                            pad_token_id=self.tokenizer.eos_token_id,
                            eos_token_id=self.tokenizer.eos_token_id
                        )
                    else:
                        raise

            # Decode response
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Extract just the assistant's response
            if "Assistant:" in response:
                answer = response.split("Assistant:")[-1].strip()
            else:
                answer = response.strip()

            print("✅ OncoLife response generated successfully")
            return answer
        except Exception as e:
            print(f"❌ Error generating OncoLife response: {e}")
            return f"""❌ **Generation Error**
Error: {str(e)}
This could be due to:
1. Model compatibility issues
2. Memory constraints
3. Input format problems
Please try a simpler question or check the logs for more details."""

    def chat(self, message, history):
        """Main chat interface for OncoLife Assistant"""
        if not message.strip():
            return "Please describe your symptoms or concerns."

        # Convert history to the format expected by generate_oncolife_response.
        # gr.ChatInterface delivers history as a list of [user, assistant] pairs here.
        conversation_history = []
        for user_msg, assistant_msg in history:
            conversation_history.append({
                "user": user_msg,
                "assistant": assistant_msg
            })

        # Generate response using OncoLife instructions and optional RAG
        response = self.generate_oncolife_response(message, conversation_history)
        return response
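

# Retrieval check (hypothetical, for illustration only): the RAG step can be exercised
# on its own, assuming the guideline documents loaded successfully.
#   assistant = OncoLifeAssistant()
#   docs = assistant._retrieve_relevant_documents("fever after chemotherapy", top_k=2)
#   for d in docs:
#       print(d["source"], d["content"][:80])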


# Create interface
assistant = OncoLifeAssistant()

interface = gr.ChatInterface(
    fn=assistant.chat,
    title="OncoLife Symptom & Triage Assistant",
    description="I'm here to help assess your symptoms and determine if you need to contact your care team. I can access your medical guidelines and reference documents to provide accurate information.",
    examples=[
        ["I'm feeling nauseous and tired"],
        ["I have a fever of 101"],
        ["My neuropathy is getting worse"],
        ["I'm having trouble eating"],
        ["I feel dizzy and lightheaded"]
    ],
    theme=gr.themes.Soft()
)

if __name__ == "__main__":
    print("=" * 60)
    print("OncoLife Symptom & Triage Assistant")
    print("=" * 60)
    interface.launch(server_name="0.0.0.0", server_port=7860)
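
# Usage note (assumed entrypoint filename): run `python app.py` locally and open
# http://localhost:7860; on Hugging Face Spaces the interface is served automatically.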