|
|
|
|
|
|
|
|
from fastapi import HTTPException, status |
|
|
|
|
|
from src.core import prompt_builder |
|
|
from src.core.state import AppState |
|
|
from src.data.medical_kb import search_medical_kb |
|
|
from src.models.account import Account |
|
|
from src.services import local_llm_service |
|
|
from src.services.gemini import gemini_chat |
|
|
from src.services.guard import SafetyGuard |
|
|
from src.utils.logger import logger |
|
|
from src.utils.rotator import APIKeyRotator |
|
|
|
|
|
|
|
|
|
|
|
def _validate_user_query(message: str, safety_guard: SafetyGuard | None):
    """
    Check the user's query against the safety guard.

    If no guard is available the check is skipped entirely.

    Args:
        message: The raw user query.
        safety_guard: The guard instance, or ``None`` to skip validation.

    Raises:
        HTTPException: 400 if the guard rejects the query; 500 if the guard
            itself raises while evaluating the query.
    """
    if not safety_guard:
        return

    # Only the guard call belongs in the try block: the 400 raised below for
    # an unsafe query must not be caught and re-reported as a 500.
    try:
        is_safe, reason = safety_guard.check_user_query(message)
    except Exception as e:
        logger().error(f"Safety guard failed on user query: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to validate user query safety."
        ) from e

    if not is_safe:
        logger().warning(f"Safety guard blocked user query: {reason}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Query blocked for safety reasons: {reason}"
        )

    logger().info(f"User query passed safety validation: {reason}")
|
|
|
|
|
def _validate_model_response(
    query: str,
    response: str,
    safety_guard: SafetyGuard | None
) -> str:
    """
    Screen a generated answer with the safety guard.

    Args:
        query: The user query that the answer responds to.
        response: The generated answer text.
        safety_guard: The guard instance, or ``None`` to skip screening.

    Returns:
        ``response`` unchanged when it passes, when no guard is available, or
        when the guard itself raises (fail-open, logged). Otherwise a fixed
        refusal message.
    """
    if not safety_guard:
        return response

    refusal = (
        "I apologize, but I cannot provide a response to that query as it may contain unsafe content. "
        "Please consult with a qualified healthcare professional for medical advice."
    )

    try:
        passed, verdict = safety_guard.check_model_answer(query, response)
        if passed:
            logger().info(f"AI response passed safety validation: {verdict}")
            return response
        logger().warning(f"Safety guard blocked AI response: {verdict}")
        return refusal
    except Exception as exc:
        logger().error(f"Safety guard failed on model response: {exc}")
        # Deliberate fail-open: a guard outage lets the answer through rather
        # than failing the request.
        logger().warning("Safety guard failed, allowing response through (fail-open)")
        return response
|
|
|
|
|
async def _retrieve_context(
    state: AppState,
    session_id: str,
    patient_id: str,
    message: str
) -> str:
    """
    Fetch the enhanced medical context for a question (the RAG entry point).

    Delegates to ``state.memory_manager.get_enhanced_context``, passing the
    NVIDIA key rotator stored on ``state``.

    Raises:
        HTTPException: 500 if context retrieval raises for any reason.
    """
    try:
        manager = state.memory_manager
        context = await manager.get_enhanced_context(
            session_id=session_id,
            patient_id=patient_id,
            question=message,
            nvidia_rotator=state.nvidia_rotator,
        )
    except Exception as exc:
        logger().error(f"Error getting medical context: {exc}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to build medical context."
        ) from exc
    else:
        return context
|
|
|
|
|
def _add_disclaimer(response_text: str) -> str: |
|
|
"""Adds a standard medical disclaimer if one is not already present.""" |
|
|
if "disclaimer" not in response_text.lower() and "consult" not in response_text.lower(): |
|
|
disclaimer = "\n\n⚠️ **Important Disclaimer:** This information is for educational purposes only and should not replace professional medical advice, diagnosis, or treatment. Always consult with qualified healthcare professionals." |
|
|
return response_text + disclaimer |
|
|
return response_text |
|
|
|
|
|
async def _persist_exchange(
    state: AppState,
    session_id: str,
    patient_id: str,
    account_id: str,
    question: str,
    answer: str
):
    """
    Hand a completed question/answer exchange to the memory manager for storage.

    ``account_id`` is forwarded as ``doctor_id``, and both key rotators held on
    ``state`` are passed through. A falsy result from the memory manager is
    logged as a warning. Exceptions are not caught here.
    """
    manager = state.memory_manager
    stored = await manager.process_medical_exchange(
        session_id=session_id,
        patient_id=patient_id,
        doctor_id=account_id,
        question=question,
        answer=answer,
        gemini_rotator=state.gemini_rotator,
        nvidia_rotator=state.nvidia_rotator,
    )
    if stored:
        return
    logger().warning(f"Failed to process and store medical exchange for session {session_id}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def generate_llm_response(
    account: Account,
    message: str,
    rotator: APIKeyRotator,
    medical_context: str = ""
) -> str | None:
    """
    Generate a medical answer with an LLM and append the standard disclaimer.

    Uses the local model when ``local_llm_service.model_loaded`` is set and
    Gemini (via ``rotator``) otherwise. Safety screening is not done here;
    callers are responsible for it.

    Args:
        account: The requesting account, used to tailor the prompt.
        message: The user's question.
        rotator: API key rotator for the Gemini backend.
        medical_context: Retrieved context to include in the prompt.

    Returns:
        The answer with a disclaimer appended, or ``None`` if the backend
        returned no usable text (empty or whitespace-only).
    """
    prompt = prompt_builder.medical_response_prompt(
        account=account,
        user_message=message,
        medical_context=medical_context
    )

    if local_llm_service.model_loaded:
        backend = "Local LLM"
        # NOTE(review): get_inference is called synchronously inside a
        # coroutine and may block the event loop during inference.
        response_text = local_llm_service.get_inference(prompt=prompt)
    else:
        backend = "Gemini"
        response_text = await gemini_chat(prompt, rotator)

    # Treat whitespace-only output like no output, so the caller's fallback runs
    # instead of returning a reply that is nothing but the disclaimer.
    if not response_text or not response_text.strip():
        return None

    response_with_disclaimer = _add_disclaimer(response_text)
    # Report the backend that actually produced the text (previously always "Gemini").
    logger().info(f"{backend} response generated, length: {len(response_with_disclaimer)} chars")
    return response_with_disclaimer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def generate_chat_response(
    state: AppState,
    message: str,
    session_id: str,
    patient_id: str,
    account_id: str
) -> str:
    """
    Run the chat pipeline for one message and return the final answer text.

    Steps: build the safety guard (optional), screen the query, resolve the
    account, retrieve RAG context, generate an answer (using the local
    knowledge-base fallback when the LLM returns nothing), screen the answer,
    then persist the exchange.

    Raises:
        HTTPException: propagated from query validation and context retrieval;
            404 if the account does not exist; 500 if response generation
            (including the fallback) raises.
    """
    logger().info(f"Starting response pipeline for session {session_id}")

    # The guard is optional: if it cannot be built, the pipeline continues
    # without safety screening (the validators skip when given None).
    safety_guard: SafetyGuard | None = None
    try:
        safety_guard = SafetyGuard(state.nvidia_rotator)
    except Exception as e:
        logger().warning(f"Safety guard failed to be created, ignoring: {e}")

    _validate_user_query(message, safety_guard)

    # Resolve the account before RAG retrieval, so an unknown account fails
    # fast with a 404 instead of first spending a context-retrieval call.
    account = state.memory_manager.get_account(account_id)
    if not account:
        logger().error(f"Account not found for account_id: {account_id}")
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Account not found")

    medical_context = await _retrieve_context(state, session_id, patient_id, message)

    try:
        response_text = await generate_llm_response(
            message=message,
            account=account,
            rotator=state.gemini_rotator,
            medical_context=medical_context
        )
        if not response_text:
            logger().warning("LLM response failed, using fallback.")
            response_text = _generate_fallback_response(message=message, account=account)
    except Exception as e:
        logger().error(f"Error generating medical response: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to generate AI response."
        ) from e

    final_response = _validate_model_response(message, response_text, safety_guard)

    await _persist_exchange(
        state=state,
        session_id=session_id,
        patient_id=patient_id,
        account_id=account_id,
        question=message,
        answer=final_response
    )

    return final_response
|
|
|
|
|
# (keywords, key-point lines) pairs, checked in order. The first entry whose
# keyword occurs as a substring of the lowercased question supplies its points.
# NOTE(review): substring matching can misfire (e.g. "hot" in "shot",
# "cold" in "scold").
_FALLBACK_TOPIC_POINTS: tuple[tuple[tuple[str, ...], tuple[str, ...]], ...] = (
    (
        ("fever", "temperature", "hot"),
        (
            "\n\n**Key Points about Fever:**",
            "• Normal body temperature is around 98.6°F (37°C)",
            "• Fever is often a sign of infection or inflammation",
            "• Monitor for other symptoms that accompany fever",
            "• Seek medical attention for high fevers (>103°F/39.4°C) or persistent fevers",
        ),
    ),
    (
        ("headache", "head pain", "migraine"),
        (
            "\n\n**Key Points about Headaches:**",
            "• Tension headaches are the most common type",
            "• Migraines often have specific triggers and symptoms",
            "• Sudden, severe headaches require immediate medical attention",
            "• Keep a headache diary to identify patterns",
        ),
    ),
    (
        ("cough", "cold", "respiratory"),
        (
            "\n\n**Key Points about Respiratory Symptoms:**",
            "• Dry vs. productive cough have different implications",
            "• Most colds resolve within 7-10 days",
            "• Persistent cough may indicate underlying conditions",
            "• Monitor for difficulty breathing or chest pain",
        ),
    ),
    (
        ("hypertension", "blood pressure", "high bp"),
        (
            "\n\n**Key Points about Hypertension:**",
            "• Often called the 'silent killer' due to lack of symptoms",
            "• Regular monitoring is essential",
            "• Lifestyle modifications can help control blood pressure",
            "• Medication may be necessary for some individuals",
        ),
    ),
    (
        ("diabetes", "blood sugar", "glucose"),
        (
            "\n\n**Key Points about Diabetes:**",
            "• Type 1: Autoimmune, requires insulin",
            "• Type 2: Often lifestyle-related, may be managed with diet/exercise",
            "• Regular blood sugar monitoring is crucial",
            "• Complications can affect multiple organ systems",
        ),
    ),
)


def _fallback_generic_intro(question_lower: str) -> str:
    """Pick a generic opening sentence based on how the question is phrased."""
    if "what is" in question_lower or "define" in question_lower:
        return "I understand you're asking about a medical topic. While I don't have specific information about this particular condition or symptom, I can provide some general guidance."
    if "how to" in question_lower or "treatment" in question_lower:
        return "I understand you're asking about treatment or management of a medical condition. This is an area where professional medical advice is particularly important."
    if "symptom" in question_lower or "sign" in question_lower:
        return "I understand you're asking about symptoms or signs of a medical condition. Remember that symptoms can vary between individuals and may indicate different conditions."
    return "Thank you for your medical question. While I can provide general information, it's important to consult with healthcare professionals for personalized medical advice."


def _fallback_role_context(role: str | None) -> str:
    """Return the audience-specific note for an account role (case-insensitive)."""
    role_lower = (role or "").lower()
    if role_lower in ("physician", "doctor", "nurse"):
        return "\n\n**Professional Context:** As a healthcare professional, you're likely familiar with these concepts. Remember to always follow your institution's protocols and guidelines, and consider the latest clinical evidence in your practice."
    if role_lower in ("medical student", "student"):
        return "\n\n**Educational Context:** As a medical student, this information can help with your studies. Always verify information with your professors and clinical supervisors, and use this as a starting point for further research."
    if role_lower == "patient":
        return "\n\n**Patient Context:** As a patient, this information is for educational purposes only. Please discuss any concerns with your healthcare provider, and don't make treatment decisions based solely on this information."
    return "\n\n**General Context:** This information is provided for educational purposes. Always consult with qualified healthcare professionals for medical advice."


def _fallback_specialty_context(specialty: str | None) -> str | None:
    """Return the specialty-specific note, or ``None`` if the specialty has none."""
    if not specialty:
        return None
    specialty_lower = specialty.lower()
    if specialty_lower in ("cardiology", "cardiac"):
        return "\n\n**Cardiology Perspective:** Given your interest in cardiology, consider how this information relates to cardiovascular health and patient care. Many conditions can have cardiac implications."
    if specialty_lower in ("pediatrics", "pediatric"):
        return "\n\n**Pediatric Perspective:** In pediatric care, remember that children may present differently than adults and may require specialized approaches. Consider age-appropriate considerations."
    if specialty_lower in ("emergency", "er"):
        return "\n\n**Emergency Medicine Perspective:** In emergency settings, rapid assessment and intervention are crucial. Consider the urgency and severity of presenting symptoms."
    return None


def _generate_fallback_response(
    message: str,
    account: Account
) -> str:
    """
    Build a fallback answer from the local knowledge base and canned guidance.

    Used when the LLM produced no answer. The response consists of:
    1. the knowledge-base result (if any);
    2. key points for the first matching topic, or, only when the knowledge
       base found nothing, a generic opener chosen by question phrasing;
    3. a note tailored to ``account.role`` and, if recognised, to
       ``account.specialty``;
    4. the standard disclaimer.

    The parts are joined with newlines.

    Args:
        message: The user's question.
        account: The requesting account; ``role`` and ``specialty`` are read.

    Returns:
        The assembled fallback response text.
    """
    kb_info = search_medical_kb(message)

    logger().info("Generating backup response")

    question_lower = message.lower()
    response_parts: list[str] = []

    if kb_info:
        response_parts.append(f"Based on your question about medical topics, here's what I found:\n\n{kb_info}")

    topic_points = next(
        (
            points
            for keywords, points in _FALLBACK_TOPIC_POINTS
            if any(word in question_lower for word in keywords)
        ),
        None,
    )
    if topic_points is not None:
        response_parts.extend(topic_points)
    elif not kb_info:
        # The generic openers are for when nothing specific was found. Adding one
        # after a KB hit would, e.g., say "I don't have specific information"
        # right after presenting the KB result.
        response_parts.append(_fallback_generic_intro(question_lower))

    response_parts.append(_fallback_role_context(account.role))

    specialty_note = _fallback_specialty_context(account.specialty)
    if specialty_note:
        response_parts.append(specialty_note)

    response_parts.append("\n\n⚠️ **Important Disclaimer:** This information is for educational purposes only and should not replace professional medical advice, diagnosis, or treatment. Always consult with qualified healthcare professionals.")

    return "\n".join(response_parts)
|
|
|