Spaces:
Sleeping
Sleeping
| # src/core/memory_manager.py | |
| from src.data.connection import ActionFailed | |
| from src.data.repositories import account as account_repo | |
| from src.data.repositories import information as info_repo | |
| from src.data.repositories import medical_memory as memory_repo | |
| from src.data.repositories import patient as patient_repo | |
| from src.data.repositories import session as session_repo | |
| from src.models.account import Account | |
| from src.models.patient import Patient | |
| from src.models.session import Message, Session | |
| from src.services import reranker, summariser | |
| from src.services.nvidia import nvidia_chat | |
| from src.utils.embeddings import EmbeddingClient | |
| from src.utils.logger import logger | |
| from src.utils.rotator import APIKeyRotator | |
class MemoryManager:
    """
    Service-layer facade that coordinates data access and business logic for
    user accounts, chat sessions, and long-term medical memory.
    """

    def __init__(self, embedder: EmbeddingClient, max_sessions_per_user: int = 10):
        """Keep a reference to the embedding client and the per-user session cap."""
        self.embedder = embedder
        self.max_sessions_per_user = max_sessions_per_user
| # --- Account Management Facade --- | |
| def create_account( | |
| self, | |
| name: str = "Anonymous", | |
| role: str = "Other", | |
| specialty: str | None = None | |
| ) -> str | None: | |
| """Creates a new user account.""" | |
| try: | |
| return account_repo.create_account(name=name, role=role, specialty=specialty) | |
| except ActionFailed as e: | |
| logger().error(f"Failed to create account in MemoryManager: {e}") | |
| return None | |
| def get_account(self, user_id: str) -> Account | None: | |
| """Retrieves a user account by its ID.""" | |
| try: | |
| return account_repo.get_account(user_id) | |
| except ActionFailed as e: | |
| logger().error(f"Failed to get account '{user_id}' in MemoryManager: {e}") | |
| return None | |
| def get_all_accounts(self, limit: int = 50) -> list[Account]: | |
| """Retrieves a list of all accounts.""" | |
| try: | |
| return account_repo.get_all_accounts(limit=limit) | |
| except ActionFailed as e: | |
| logger().error(f"Failed to get all accounts in MemoryManager: {e}") | |
| return [] | |
| def search_accounts(self, query: str, limit: int = 10) -> list[Account]: | |
| """Searches for accounts by name.""" | |
| try: | |
| return account_repo.search_accounts(query, limit=limit) | |
| except ActionFailed as e: | |
| logger().error(f"Failed to search accounts in MemoryManager: {e}") | |
| return [] | |
| # --- Patient Management Facade --- | |
| def create_patient(self, **kwargs) -> str | None: | |
| """Creates a new patient record.""" | |
| try: | |
| return patient_repo.create_patient(**kwargs) | |
| except ActionFailed as e: | |
| logger().error(f"Failed to create patient in MemoryManager: {e}") | |
| return None | |
| def get_patient_by_id(self, patient_id: str) -> Patient | None: | |
| """Retrieves a patient by their unique ID.""" | |
| try: | |
| return patient_repo.get_patient_by_id(patient_id) | |
| except ActionFailed as e: | |
| logger().error(f"Failed to get patient '{patient_id}' in MemoryManager: {e}") | |
| return None | |
| def update_patient_profile(self, patient_id: str, updates: dict) -> int: | |
| """Updates a patient's profile.""" | |
| try: | |
| return patient_repo.update_patient_profile(patient_id, updates) | |
| except ActionFailed as e: | |
| logger().error(f"Failed to update patient '{patient_id}' in MemoryManager: {e}") | |
| return 0 | |
| def search_patients(self, query: str, limit: int = 10) -> list[Patient]: | |
| """Searches for patients by name.""" | |
| try: | |
| return patient_repo.search_patients(query, limit=limit) | |
| except ActionFailed as e: | |
| logger().error(f"Failed to search patients in MemoryManager: {e}") | |
| return [] | |
| # --- Session Management Facade --- | |
| def create_session(self, user_id: str, patient_id: str, title: str = "New Chat") -> Session | None: | |
| """Creates a new chat session for a user.""" | |
| try: | |
| return session_repo.create_session(user_id, patient_id, title) | |
| except ActionFailed as e: | |
| logger().error(f"Failed to create session in MemoryManager: {e}") | |
| return None | |
| def get_session(self, session_id: str) -> Session | None: | |
| """Retrieves a single chat session by its ID.""" | |
| try: | |
| return session_repo.get_session(session_id) | |
| except ActionFailed as e: | |
| logger().error(f"Failed to get session '{session_id}' in MemoryManager: {e}") | |
| return None | |
| def get_user_sessions(self, user_id: str) -> list[Session]: | |
| """Retrieves all sessions for a specific user.""" | |
| try: | |
| return session_repo.get_user_sessions(user_id, limit=self.max_sessions_per_user) | |
| except ActionFailed as e: | |
| logger().error(f"Failed to get user sessions for '{user_id}': {e}") | |
| return [] | |
| def update_session_title(self, session_id: str, title: str) -> bool: | |
| """Updates the title of a session.""" | |
| try: | |
| return session_repo.update_session_title(session_id, title) | |
| except ActionFailed as e: | |
| logger().error(f"Failed to update title for session '{session_id}': {e}") | |
| return False | |
| def list_patient_sessions(self, patient_id: str) -> list[Session]: | |
| """Retrieves all sessions for a specific patient.""" | |
| try: | |
| return session_repo.list_patient_sessions(patient_id, limit=self.max_sessions_per_user) | |
| except ActionFailed as e: | |
| logger().error(f"Failed to get sessions for patient '{patient_id}': {e}") | |
| return [] | |
| def delete_session(self, session_id: str) -> bool: | |
| """Deletes a chat session.""" | |
| try: | |
| return session_repo.delete_session(session_id) | |
| except ActionFailed as e: | |
| logger().error(f"Failed to delete session '{session_id}' in MemoryManager: {e}") | |
| return False | |
| def get_session_messages(self, session_id: str, limit: int | None = None) -> list[Message]: | |
| """Gets messages from a specific chat session.""" | |
| try: | |
| return session_repo.get_session_messages(session_id, limit) | |
| except ActionFailed as e: | |
| logger().error(f"Failed to get messages for session '{session_id}': {e}") | |
| return [] | |
| # --- Core Business Logic --- | |
    async def process_medical_exchange(
        self,
        session_id: str,
        patient_id: str,
        doctor_id: str,
        question: str,
        answer: str,
        gemini_rotator: APIKeyRotator,
        nvidia_rotator: APIKeyRotator
    ) -> str | None:
        """
        Persist one doctor/patient Q&A exchange end-to-end.

        Appends both messages to the session, produces an AI summary of the
        exchange, embeds that summary (when an embedder is configured), stores
        it in long-term medical memory, and auto-titles the session on its
        first exchange.

        Args:
            session_id: Session receiving the two new messages.
            patient_id: Patient the memory entry is attached to.
            doctor_id: Doctor (account) recorded on the memory entry.
            question: The doctor's question text.
            answer: The assistant's answer text.
            gemini_rotator: Key rotator for the Gemini summarisation path.
            nvidia_rotator: Key rotator for the NVIDIA fallback/title path.

        Returns:
            The generated summary string, or None if any step failed.
        """
        try:
            # 1. Record both sides of the exchange in the session transcript.
            session_repo.add_message(session_id, question, sent_by_user=True)
            session_repo.add_message(session_id, answer, sent_by_user=False)
            # 2. Summarise the exchange (Gemini first, NVIDIA fallback — see
            #    _generate_summary; always yields a non-None string).
            summary = await self._generate_summary(
                question=question,
                answer=answer,
                gemini_rotator=gemini_rotator,
                nvidia_rotator=nvidia_rotator
            )
            # 3. Embed the summary for later semantic retrieval. Embedding
            #    failure is non-fatal: the memory is stored without a vector.
            embedding = None
            if self.embedder:
                try:
                    embedding = self.embedder.embed([summary])[0]
                except Exception as e:
                    logger().warning(f"Failed to generate embedding for summary: {e}")
            # 4. Persist the summary (+ optional embedding) to long-term memory.
            memory_repo.create_memory(
                patient_id=patient_id,
                doctor_id=doctor_id,
                session_id=session_id,
                summary=summary,
                embedding=embedding
            )
            # 5. If this was the session's first Q&A pair, derive a title from
            #    the question (best-effort; errors are swallowed by the helper).
            await self._update_session_title_if_first_message(
                session_id=session_id,
                question=question,
                nvidia_rotator=nvidia_rotator
            )
            return summary
        except ActionFailed as e:
            # Known database-layer failure: log with session context.
            logger().error(f"Database error processing medical exchange for session '{session_id}': {e}")
            return None
        except Exception as e:
            # Safety net so one bad exchange never crashes the caller.
            logger().error(f"Unexpected error processing medical exchange: {e}")
            return None
| async def get_enhanced_context( | |
| self, | |
| session_id: str, | |
| patient_id: str, | |
| question: str, | |
| nvidia_rotator: APIKeyRotator | |
| ) -> str: | |
| """ | |
| Builds a rich, multi-source context string for a new question, combining | |
| short-term memory, long-term semantic memory, information from the knowledge base, and current conversation. | |
| """ | |
| context_parts = [] | |
| # 1. Get recent summaries (Short-Term Memory) | |
| try: | |
| recent_memories = memory_repo.get_recent_memories(patient_id, limit=3) | |
| if recent_memories: | |
| # Use NVIDIA to reason about relevance | |
| relevant_stm = await self._filter_summaries_for_relevance( | |
| question=question, | |
| summaries=[mem.summary for mem in recent_memories], | |
| nvidia_rotator=nvidia_rotator | |
| ) | |
| if relevant_stm: | |
| context_parts.append("Recent relevant medical context (STM):\n" + "\n".join(relevant_stm)) | |
| except ActionFailed as e: | |
| logger().warning(f"Could not retrieve recent memories for enhanced context: {e}") | |
| # 2. Get semantically similar summaries (Long-Term Memory) | |
| if self.embedder and self.embedder.is_available(): | |
| try: | |
| query_embedding = self.embedder.embed([question])[0] | |
| if query_embedding: | |
| ltm_results = memory_repo.search_memories_semantic( | |
| patient_id=patient_id, | |
| query_embedding=query_embedding, | |
| limit=2 | |
| ) | |
| if ltm_results: | |
| ltm_summaries = [result.summary for result in ltm_results] | |
| context_parts.append("Semantically relevant medical history (LTM):\n" + "\n".join(ltm_summaries)) | |
| except (ActionFailed, Exception) as e: | |
| logger().warning(f"Failed to perform LTM semantic search: {e}") | |
| # 3. Consult knowledge base | |
| info = await self._consult_knowledge_base( | |
| question=question, | |
| nvidia_rotator=nvidia_rotator | |
| ) | |
| if info: | |
| context_parts.append(info) | |
| # 4. Get current conversation context | |
| try: | |
| session = session_repo.get_session(session_id) | |
| if session and session.messages: | |
| session_context = "\n".join([ | |
| f"{'User' if msg.sent_by_user else 'Assistant'}: {msg.content}" | |
| for msg in session.messages[-10:] # Get last 10 messages | |
| ]) | |
| context_parts.append("Current conversation:\n" + session_context) | |
| except ActionFailed as e: | |
| logger().warning(f"Could not retrieve current session context: {e}") | |
| return "\n\n".join(filter(None, context_parts)) | |
| # --- Private Helper Methods --- | |
    async def _consult_knowledge_base(
        self,
        question: str,
        nvidia_rotator: APIKeyRotator
    ) -> str:
        """
        Query the knowledge base for context relevant to `question`.

        Pipeline: embed the question → fetch candidate chunks via semantic
        search → rerank with NVIDIA → format the top hits with their sources.

        Args:
            question: The question to look up supporting material for.
            nvidia_rotator: Key rotator used by the reranker service.

        Returns:
            A formatted context section, or "" when the embedder is
            unavailable, nothing relevant is found, or any step fails
            (failures are logged, never raised).
        """
        if not self.embedder or not self.embedder.is_available():
            logger().warning("Embedder not available, skipping knowledge base consultation.")
            return ""
        try:
            # 1. Embed the user's question
            query_embedding = self.embedder.embed([question])[0]
            if not query_embedding:
                logger().warning("Failed to generate query embedding.")
                return ""
            # 2. Retrieve initial candidates from MongoDB
            initial_chunks = info_repo.search_chunks_semantic(
                query_embedding=query_embedding,
                limit=10  # Retrieve more candidates for the reranker to process
            )
            if not initial_chunks:
                logger().info("No relevant chunks found in the knowledge base.")
                return ""
            # 3. Rerank the results for semantic relevance
            reranked_chunks = await reranker.rerank_documents(
                query=question,
                documents=initial_chunks,
                rotator=nvidia_rotator,
                top_k=3  # Keep the top 3 most relevant results
            )
            if not reranked_chunks:
                logger().warning("Reranking failed to return any chunks.")
                return ""
            # 4. Format the final response: header + one "[Source: ...]" entry
            #    per chunk, separated by blank lines.
            context_header = "Consulted Knowledge Base for context:"
            formatted_chunks = []
            for chunk in reranked_chunks:
                source = chunk.metadata.source
                content = chunk.content.strip()
                formatted_chunks.append(f"[Source: {source}]\n{content}")
            return f"{context_header}\n\n" + "\n\n".join(formatted_chunks)
        except ActionFailed as e:
            logger().error(f"A database error occurred while consulting the knowledge base: {e}")
        except Exception as e:
            logger().error(f"An unexpected error occurred during knowledge base consultation: {e}")
        # Fall-through for both handlers: degrade to an empty context section.
        return ""
| async def _update_session_title_if_first_message( | |
| self, | |
| session_id: str, | |
| question: str, | |
| nvidia_rotator: APIKeyRotator | |
| ) -> None: | |
| """Updates the session title if it contains only the first Q&A pair.""" | |
| try: | |
| session = self.get_session(session_id) | |
| # Check if it's the first user message and first assistant response | |
| if session and len(session.messages) == 2: | |
| title = await summariser.summarise_title_with_nvidia(text=question, rotator=nvidia_rotator, max_words=5) | |
| if not title: | |
| title = question[:80] # Fallback to first 80 chars | |
| self.update_session_title(session_id=session_id, title=title) | |
| except Exception as e: | |
| logger().warning(f"Failed to auto-update session title for session '{session_id}': {e}") | |
| async def _generate_summary( | |
| self, | |
| question: str, | |
| answer: str, | |
| gemini_rotator: APIKeyRotator, | |
| nvidia_rotator: APIKeyRotator | |
| ) -> str: | |
| """Generates a summary of a Q&A exchange, falling back to a basic format if AI fails.""" | |
| try: | |
| summary = await summariser.summarise_qa_with_gemini( | |
| question=question, | |
| answer=answer, | |
| rotator=gemini_rotator | |
| ) | |
| if summary: return summary | |
| # Fallback to NVIDIA if Gemini fails | |
| summary = await summariser.summarise_qa_with_nvidia( | |
| question=question, | |
| answer=answer, | |
| rotator=nvidia_rotator | |
| ) | |
| if summary: return summary | |
| except Exception as e: | |
| logger().warning(f"Failed to generate AI summary: {e}") | |
| # Fallback for both exceptions and cases where services return None | |
| return summariser.summarise_fallback(question=question, answer=answer) | |
| async def _filter_summaries_for_relevance( | |
| self, | |
| question: str, | |
| summaries: list[str], | |
| nvidia_rotator: APIKeyRotator | |
| ) -> list[str]: | |
| """Uses an AI model to select only the most relevant summaries for a given question.""" | |
| if not summaries: | |
| return [] | |
| try: | |
| sys_prompt = "You are a medical AI assistant. Select only the most relevant recent medical context that directly relates to the new question. Return the selected items verbatim, separated by a newline. If none are relevant, return nothing." | |
| user_prompt = f"Question: {question}\n\nSelect relevant items from recent medical context:\n" + "\n".join(summaries) | |
| relevant_text = await nvidia_chat(sys_prompt, user_prompt, nvidia_rotator) | |
| return relevant_text.strip().split('\n') if relevant_text and relevant_text.strip() else [] | |
| except Exception as e: | |
| logger().warning(f"Failed to get AI reasoning for STM relevance: {e}") | |
| return summaries # Fallback to returning all summaries | |