Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import os | |
| import json | |
| from dataclasses import dataclass | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import numpy as np | |
| import pandas as pd | |
| from loguru import logger | |
| from sklearn.neighbors import NearestNeighbors | |
| from sentence_transformers import SentenceTransformer | |
| from app.utils.config import settings | |
| from app.utils.helpers import normalize_gender, clean_diagnosis | |
@dataclass
class SimilarCase:
    """One hit returned by ``CameroonMedicalData.search_similar_cases``.

    Declared as a dataclass so it can be built with keyword arguments,
    which is how the search code constructs it.
    """

    # Identifier of the matched summary, always stringified by the caller.
    summary_id: str
    # Normalised diagnosis of the matched row, if any.
    diagnosis: Optional[str]
    # Patient age as a float, or None when the source value is missing.
    age: Optional[float]
    # Normalised patient gender of the matched row, if any.
    gender: Optional[str]
    # Leading part of the summary text (truncated by the caller).
    summary_snippet: str
    # 1 - cosine distance between the query and the matched summary.
    similarity_score: float
class CameroonMedicalData:
    """
    Load, clean, analyze and search medical summaries specialized for the Cameroonian context.

    Designed for ~45k rows. Sentence embeddings are cached on disk (as
    ``embeddings.npy`` inside ``settings.CAMEROON_CACHE_DIR``) and the
    nearest-neighbour index is kept in memory once built.
    """

    # Right-closed bin edges used for every age histogram: (-1, 18], (18, 35],
    # (35, 60], (60, 200]. Kept in one place so all statistics agree.
    _AGE_BINS = [-1, 18, 35, 60, 200]
    _AGE_LABELS = ["0-18", "19-35", "36-60", "60+"]

    # English lowercase month names, index 0 == January.
    _MONTH_NAMES = [
        "january", "february", "march", "april", "may", "june",
        "july", "august", "september", "october", "november", "december",
    ]

    def __init__(self, csv_path: Optional[str] = None):
        """Load the CSV (if available) and prepare lazy search state.

        Args:
            csv_path: Path to the source CSV. Falls back to
                ``settings.CAMEROON_DATA_CSV`` when omitted or empty.

        A missing path is not an error: a warning is logged and the instance
        works on an empty DataFrame.
        """
        self.csv_path = csv_path or settings.CAMEROON_DATA_CSV
        if not self.csv_path or not os.path.exists(self.csv_path):
            logger.warning("CameroonMedicalData: CSV path missing or not found. Set CAMEROON_DATA_CSV in .env")
            self.df = pd.DataFrame()
        else:
            self.df = self._load_csv(self.csv_path, settings.CAMEROON_MAX_ROWS)

        # Model, embeddings and index are created on first search.
        self._cleaned: bool = False
        self._model: Optional[SentenceTransformer] = None
        self._embeddings: Optional[np.ndarray] = None
        self._nn: Optional[NearestNeighbors] = None
        self._cache_dir = settings.CAMEROON_CACHE_DIR
        os.makedirs(self._cache_dir, exist_ok=True)

    # ----------------------- Data Loading & Cleaning -----------------------
    def _load_csv(self, path: str, limit: Optional[int]) -> pd.DataFrame:
        """Read ``path`` into a DataFrame, keeping at most ``limit`` rows.

        Args:
            path: CSV file to read.
            limit: Maximum number of data rows to keep. ``None``, zero or a
                negative value means "no limit".

        Returns:
            The loaded DataFrame.
        """
        # Let the parser stop early instead of reading everything and
        # truncating afterwards.
        nrows = limit if limit and limit > 0 else None
        return pd.read_csv(path, nrows=nrows)

    def clean(self) -> None:
        """Validate and normalise ``self.df`` in place (via a replaced copy).

        Adds ``patient_gender_norm`` and ``diagnosis_norm`` columns, parses
        ``date_recorded``, coerces the vitals to numeric, and drops rows that
        have neither summary text nor a diagnosis. Sets ``self._cleaned``.

        Raises:
            ValueError: If any required column is absent from the data.
        """
        if self.df.empty:
            self._cleaned = True
            return

        df = self.df.copy()

        # Validate that every required column is present.
        expected_cols = [
            "summary_id", "patient_id", "patient_age", "patient_gender", "diagnosis",
            "body_temp_c", "blood_pressure_systolic", "heart_rate", "summary_text", "date_recorded",
        ]
        missing = [c for c in expected_cols if c not in df.columns]
        if missing:
            raise ValueError(f"Missing required columns: {missing}")

        # Parse dates; unparseable values become NaT.
        df["date_recorded"] = pd.to_datetime(df["date_recorded"], errors="coerce")

        # Handle missing values in the text columns.
        for col in ("patient_gender", "diagnosis", "summary_text"):
            df[col] = df[col].fillna("")

        # Normalize gender and diagnosis through the project helpers.
        df["patient_gender_norm"] = df["patient_gender"].apply(lambda v: normalize_gender(str(v)))
        df["diagnosis_norm"] = df["diagnosis"].apply(lambda v: clean_diagnosis(str(v)))

        # Coerce numeric vitals; non-numeric values become NaN.
        for col in ["patient_age", "body_temp_c", "blood_pressure_systolic", "heart_rate"]:
            df[col] = pd.to_numeric(df[col], errors="coerce")

        # Drop rows with no summary text and no diagnosis. An empty normalised
        # diagnosis counts as "no diagnosis" just like a missing one, so the
        # rule holds whichever of the two clean_diagnosis produces for blanks.
        no_text = df["summary_text"].str.len() == 0
        no_diagnosis = df["diagnosis_norm"].isna() | (df["diagnosis_norm"] == "")
        df = df[~(no_text & no_diagnosis)]

        self.df = df.reset_index(drop=True)
        self._cleaned = True

    # ----------------------------- Statistics -----------------------------
    def stats_overview(self) -> Dict[str, Any]:
        """Return dataset-wide counts.

        Returns:
            ``{"total_rows": 0}`` for an empty dataset, otherwise a dict with
            ``total_rows``, ``top_diagnoses`` (20 most frequent),
            ``age_stats`` (``describe()`` output with NaN replaced by 0) and
            ``gender_distribution``.
        """
        if not self._cleaned:
            self.clean()
        if self.df.empty:
            return {"total_rows": 0}

        df = self.df
        top_diagnoses = df["diagnosis_norm"].value_counts(dropna=True).head(20).to_dict()
        age_desc = df["patient_age"].describe().fillna(0).to_dict()
        return {
            "total_rows": int(len(df)),
            "top_diagnoses": top_diagnoses,
            "age_stats": age_desc,
            "gender_distribution": df["patient_gender_norm"].value_counts(dropna=True).to_dict(),
        }

    def stats_disease(self, disease_name: str) -> Dict[str, Any]:
        """Return statistics for the rows whose normalised diagnosis matches.

        Args:
            disease_name: Diagnosis to look up. Compared, after stripping
                surrounding whitespace and lower-casing, against
                ``diagnosis_norm`` by exact equality.

        Returns:
            A dict with ``disease`` (echoed as given), ``total_cases`` and,
            for a non-empty dataset, ``age_distribution``,
            ``gender_distribution`` and ``common_symptoms``.
        """
        if not self._cleaned:
            self.clean()
        if self.df.empty:
            return {"disease": disease_name, "total_cases": 0}

        df = self.df
        subset = df[df["diagnosis_norm"] == disease_name.strip().lower()]

        # Common symptom terms (very simple proxy: frequent tokens in summary_text)
        common_symptoms = self._extract_common_terms(subset["summary_text"].tolist(), top_k=15)
        return {
            "disease": disease_name,
            "total_cases": int(len(subset)),
            "age_distribution": self._age_bucket_counts(subset["patient_age"]),
            "gender_distribution": subset["patient_gender_norm"].value_counts().to_dict(),
            "common_symptoms": common_symptoms,
        }

    def seasonal_patterns(self) -> Dict[str, int]:
        """Count records per calendar month across all years.

        Returns:
            An empty dict for an empty dataset; otherwise a dict keyed by
            English lowercase month name (all twelve present) with the number
            of rows whose ``date_recorded`` falls in that month. Rows without
            a valid date are ignored.
        """
        if not self._cleaned:
            self.clean()
        if self.df.empty:
            return {}

        dated = self.df.dropna(subset=["date_recorded"])
        month_counts = dated["date_recorded"].dt.month.value_counts()
        return {
            name: int(month_counts.get(number, 0))
            for number, name in enumerate(self._MONTH_NAMES, start=1)
        }

    def age_gender_distribution(self) -> Dict[str, Any]:
        """Return the age-bucket histogram and gender counts for all rows.

        Returns:
            ``{"age_buckets": {...}, "gender_distribution": {...}}``; both
            inner dicts are empty when the dataset is empty.
        """
        if not self._cleaned:
            self.clean()
        if self.df.empty:
            return {"age_buckets": {}, "gender_distribution": {}}

        df = self.df
        return {
            "age_buckets": self._age_bucket_counts(df["patient_age"]),
            "gender_distribution": df["patient_gender_norm"].value_counts().to_dict(),
        }

    # --------------------------- Semantic Similarity ---------------------------
    def _ensure_embeddings(self) -> None:
        """Make sure the model, the embedding matrix and the NN index exist.

        Embeddings are loaded from the on-disk cache when it is readable and
        has one 2-D row per DataFrame row; otherwise they are recomputed and
        the cache is rewritten. For an empty dataset an empty matrix is stored
        and no index is built.
        """
        if self._embeddings is not None and self._nn is not None:
            return
        if not self._cleaned:
            self.clean()
        if self.df.empty:
            # NOTE(review): width 384 is hard-coded; presumably the output
            # size of the configured model - confirm.
            self._embeddings = np.zeros((0, 384), dtype=np.float32)
            self._nn = None
            return

        # Load model lazily
        if self._model is None:
            model_name = settings.CAMEROON_EMBEDDINGS_MODEL
            logger.info(f"Loading sentence-transformers model: {model_name}")
            self._model = SentenceTransformer(model_name)

        cache_file = os.path.join(self._cache_dir, "embeddings.npy")
        embeddings = self._embeddings
        if os.path.exists(cache_file):
            try:
                embeddings = np.load(cache_file)
            except Exception as exc:
                # Best effort: an unreadable cache only costs a recompute.
                logger.warning(f"Ignoring unreadable embeddings cache {cache_file}: {exc}")
                embeddings = None

        # The cache is only trusted when it is a matrix with one row per record.
        usable = (
            isinstance(embeddings, np.ndarray)
            and embeddings.ndim == 2
            and len(embeddings) == len(self.df)
        )
        if not usable:
            texts = self.df["summary_text"].astype(str).tolist()
            embeddings = self._model.encode(
                texts, batch_size=64, show_progress_bar=False, normalize_embeddings=True
            )
            try:
                np.save(cache_file, embeddings)
            except OSError as exc:
                # Failing to persist the cache must not break the search.
                logger.warning(f"Could not write embeddings cache {cache_file}: {exc}")
        self._embeddings = embeddings

        # Build NN index
        self._nn = NearestNeighbors(n_neighbors=10, metric="cosine")
        self._nn.fit(self._embeddings)

    def search_similar_cases(self, query_text: str, top_k: int = 10) -> List[SimilarCase]:
        """Return the summaries most similar to ``query_text``.

        Args:
            query_text: Free-text query. Empty or whitespace-only input
                yields no results.
            top_k: Maximum number of results; capped at the number of rows.
                Values below 1 yield no results.

        Returns:
            ``SimilarCase`` items in the order returned by the index, each
            with ``similarity_score = 1 - cosine distance``.
        """
        if not query_text or query_text.strip() == "":
            return []
        if top_k <= 0:
            # kneighbors rejects a non-positive neighbour count.
            return []

        self._ensure_embeddings()
        if self._model is None or self._nn is None or self._embeddings is None or self.df.empty:
            return []

        query_vec = self._model.encode([query_text], normalize_embeddings=True)
        distances, indices = self._nn.kneighbors(query_vec, n_neighbors=min(top_k, len(self.df)))

        results: List[SimilarCase] = []
        for dist, idx in zip(distances[0], indices[0]):
            row = self.df.iloc[int(idx)]
            text = str(row.get("summary_text", ""))
            snippet = text[:140] + ("..." if len(text) > 140 else "")
            age = row.get("patient_age")
            results.append(SimilarCase(
                summary_id=str(row.get("summary_id", "")),
                diagnosis=row.get("diagnosis_norm"),
                age=float(age) if pd.notna(age) else None,
                gender=row.get("patient_gender_norm"),
                summary_snippet=snippet,
                # similarity = 1 - cosine distance
                similarity_score=float(1.0 - dist),
            ))
        return results

    # ----------------------------- Utils -----------------------------
    @classmethod
    def _age_bucket_counts(cls, ages: pd.Series) -> Dict[str, int]:
        """Count ``ages`` per age bucket; every label is present, NaN is skipped."""
        buckets = pd.cut(ages, bins=cls._AGE_BINS, labels=cls._AGE_LABELS)
        return buckets.value_counts().reindex(cls._AGE_LABELS, fill_value=0).to_dict()

    def _extract_common_terms(self, texts: List[str], top_k: int = 20) -> List[str]:
        """Return the ``top_k`` most frequent words across ``texts``.

        Text is lower-cased, commas and periods become spaces, and only
        purely alphabetic whitespace-separated tokens of length >= 3 are
        counted. Ties keep first-seen order.
        """
        # Very naive bag-of-words; in production consider medical entity extraction.
        from collections import Counter

        counts: Counter = Counter()
        for text in texts:
            words = str(text).lower().replace(",", " ").replace(".", " ").split()
            counts.update(w for w in words if len(w) >= 3 and w.isalpha())
        return [word for word, _ in counts.most_common(top_k)]
# Module-level holder for the lazily created shared instance.
_singleton: Optional[CameroonMedicalData] = None


def get_cameroon_data() -> CameroonMedicalData:
    """Return the shared ``CameroonMedicalData``, building it on first use.

    The instance is constructed with default arguments and stored in the
    module-level ``_singleton`` so later calls return the same object.
    """
    global _singleton
    instance = _singleton
    if instance is None:
        instance = CameroonMedicalData()
        _singleton = instance
    return instance