from __future__ import annotations

import json
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import numpy as np
import pandas as pd
from loguru import logger
from sentence_transformers import SentenceTransformer
from sklearn.neighbors import NearestNeighbors

from app.utils.config import settings
from app.utils.helpers import clean_diagnosis, normalize_gender


@dataclass
class SimilarCase:
    summary_id: str
    diagnosis: Optional[str]
    age: Optional[float]
    gender: Optional[str]
    summary_snippet: str
    similarity_score: float


class CameroonMedicalData:
    """
    Load, clean, analyze, and search medical summaries specialized for the
    Cameroonian context. Designed for ~45k rows; caches embeddings and
    lightweight stats.
    """

    def __init__(self, csv_path: Optional[str] = None):
        self.csv_path = csv_path or settings.CAMEROON_DATA_CSV
        if not self.csv_path or not os.path.exists(self.csv_path):
            logger.warning(
                "CameroonMedicalData: CSV path missing or not found. "
                "Set CAMEROON_DATA_CSV in .env"
            )
            self.df = pd.DataFrame()
        else:
            self.df = self._load_csv(self.csv_path, settings.CAMEROON_MAX_ROWS)
        self._cleaned: bool = False
        self._model: Optional[SentenceTransformer] = None
        self._embeddings: Optional[np.ndarray] = None
        self._nn: Optional[NearestNeighbors] = None
        self._cache_dir = settings.CAMEROON_CACHE_DIR
        os.makedirs(self._cache_dir, exist_ok=True)

    # ----------------------- Data Loading & Cleaning -----------------------
    def _load_csv(self, path: str, limit: Optional[int]) -> pd.DataFrame:
        df = pd.read_csv(path)
        if limit and limit > 0:
            df = df.head(limit)
        return df

    def clean(self) -> None:
        if self.df.empty:
            self._cleaned = True
            return
        df = self.df.copy()
        # Validate that all required columns are present
        expected_cols = [
            "summary_id", "patient_id", "patient_age", "patient_gender", "diagnosis",
            "body_temp_c", "blood_pressure_systolic", "heart_rate", "summary_text",
            "date_recorded",
        ]
        missing = [c for c in expected_cols if c not in df.columns]
        if missing:
            raise ValueError(f"Missing required columns: {missing}")
        # Parse dates; unparseable values become NaT
        df["date_recorded"] = pd.to_datetime(df["date_recorded"], errors="coerce")
        # Handle missing values
        df["patient_gender"] = df["patient_gender"].fillna("")
        df["diagnosis"] = df["diagnosis"].fillna("")
        df["summary_text"] = df["summary_text"].fillna("")
        # Normalize gender and diagnosis into *_norm columns
        df["patient_gender_norm"] = df["patient_gender"].apply(lambda v: normalize_gender(str(v)))
        df["diagnosis_norm"] = df["diagnosis"].apply(lambda v: clean_diagnosis(str(v)))
        # Coerce numeric vitals; invalid entries become NaN
        for col in ["patient_age", "body_temp_c", "blood_pressure_systolic", "heart_rate"]:
            df[col] = pd.to_numeric(df[col], errors="coerce")
        # Drop rows with no summary text and no diagnosis
        df = df[~((df["summary_text"].str.len() == 0) & (df["diagnosis_norm"].isna()))]
        self.df = df.reset_index(drop=True)
        self._cleaned = True

    # ----------------------------- Statistics -----------------------------
    def stats_overview(self) -> Dict[str, Any]:
        if not self._cleaned:
            self.clean()
        if self.df.empty:
            return {"total_rows": 0}
        df = self.df
        top_diagnoses = df["diagnosis_norm"].value_counts(dropna=True).head(20).to_dict()
        age_desc = df["patient_age"].describe().fillna(0).to_dict()
        return {
            "total_rows": int(len(df)),
            "top_diagnoses": top_diagnoses,
            "age_stats": age_desc,
            "gender_distribution": df["patient_gender_norm"].value_counts(dropna=True).to_dict(),
        }

    def stats_disease(self, disease_name: str) -> Dict[str, Any]:
        if not self._cleaned:
            self.clean()
        if self.df.empty:
            return {"disease": disease_name, "total_cases": 0}
        df = self.df
        # diagnosis_norm is lowercased by clean_diagnosis, so match on lowercase
        mask = df["diagnosis_norm"] == disease_name.lower()
        subset = df[mask]
        total = int(len(subset))
        # Age buckets: (-1, 18] -> 0-18, (18, 35] -> 19-35, (35, 60] -> 36-60, (60, 200] -> 60+
        bins = [-1, 18, 35, 60, 200]
        labels = ["0-18", "19-35", "36-60", "60+"]
        ages = pd.cut(subset["patient_age"], bins=bins, labels=labels)
        age_dist = ages.value_counts().reindex(labels, fill_value=0).to_dict()
        gender_dist = subset["patient_gender_norm"].value_counts().to_dict()
        # Common symptom terms (very simple proxy: frequent tokens in summary_text)
        common_symptoms = self._extract_common_terms(subset["summary_text"].tolist(), top_k=15)
        return {
            "disease": disease_name,
            "total_cases": total,
            "age_distribution": age_dist,
            "gender_distribution": gender_dist,
            "common_symptoms": common_symptoms,
        }

    def seasonal_patterns(self) -> Dict[str, int]:
        if not self._cleaned:
            self.clean()
        if self.df.empty:
            return {}
        df = self.df.dropna(subset=["date_recorded"]).copy()
        df["month"] = df["date_recorded"].dt.month
        counts = df["month"].value_counts().sort_index()
        # Map month numbers to English lowercase names for consistency
        months = [
            "january", "february", "march", "april", "may", "june",
            "july", "august", "september", "october", "november", "december",
        ]
        return {months[i - 1]: int(counts.get(i, 0)) for i in range(1, 13)}

    def age_gender_distribution(self) -> Dict[str, Any]:
        if not self._cleaned:
            self.clean()
        if self.df.empty:
            return {"age_buckets": {}, "gender_distribution": {}}
        df = self.df
        bins = [-1, 18, 35, 60, 200]
        labels = ["0-18", "19-35", "36-60", "60+"]
        ages = pd.cut(df["patient_age"], bins=bins, labels=labels)
        age_dist = ages.value_counts().reindex(labels, fill_value=0).to_dict()
        gender_dist = df["patient_gender_norm"].value_counts().to_dict()
        return {"age_buckets": age_dist, "gender_distribution": gender_dist}

    # --------------------------- Semantic Similarity ---------------------------
    def _ensure_embeddings(self) -> None:
        if self._embeddings is not None and self._nn is not None:
            return
        if not self._cleaned:
            self.clean()
        if self.df.empty:
            # 384 dims matches the MiniLM-class models typically used here
            self._embeddings = np.zeros((0, 384), dtype=np.float32)
            self._nn = None
            return
        # Load model lazily
        if self._model is None:
            model_name = settings.CAMEROON_EMBEDDINGS_MODEL
            logger.info(f"Loading sentence-transformers model: {model_name}")
            self._model = SentenceTransformer(model_name)
        cache_file = os.path.join(self._cache_dir, "embeddings.npy")
        if os.path.exists(cache_file):
            try:
                self._embeddings = np.load(cache_file)
            except Exception:
                self._embeddings = None
        # Recompute if the cache is absent or stale (row count mismatch)
        if self._embeddings is None or len(self._embeddings) != len(self.df):
            texts = self.df["summary_text"].astype(str).tolist()
            self._embeddings = self._model.encode(
                texts, batch_size=64, show_progress_bar=False, normalize_embeddings=True
            )
            np.save(cache_file, self._embeddings)
        # Build NN index over cosine distance
        self._nn = NearestNeighbors(n_neighbors=10, metric="cosine")
        self._nn.fit(self._embeddings)

    def search_similar_cases(self, query_text: str, top_k: int = 10) -> List[SimilarCase]:
        if not query_text or not query_text.strip():
            return []
        self._ensure_embeddings()
        if self._model is None or self._nn is None or self._embeddings is None or self.df.empty:
            return []
        q = self._model.encode([query_text], normalize_embeddings=True)
        distances, indices = self._nn.kneighbors(q, n_neighbors=min(top_k, len(self.df)))
        distances = distances[0]
        indices = indices[0]
        results: List[SimilarCase] = []
        for dist, idx in zip(distances, indices):
            row = self.df.iloc[int(idx)]
            # similarity = 1 - cosine distance
            sim = float(1.0 - dist)
            text = str(row.get("summary_text", ""))
            snippet = text[:140] + ("..." if len(text) > 140 else "")
            results.append(SimilarCase(
                summary_id=str(row.get("summary_id", "")),
                diagnosis=row.get("diagnosis_norm"),
                age=float(row.get("patient_age")) if pd.notna(row.get("patient_age")) else None,
                gender=row.get("patient_gender_norm"),
                summary_snippet=snippet,
                similarity_score=sim,
            ))
        return results

    # ----------------------------- Utils -----------------------------
    def _extract_common_terms(self, texts: List[str], top_k: int = 20) -> List[str]:
        # Very naive bag-of-words; in production consider medical entity extraction.
        from collections import Counter

        tokens: List[str] = []
        for t in texts:
            for w in str(t).lower().replace(",", " ").replace(".", " ").split():
                if len(w) >= 3 and w.isalpha():
                    tokens.append(w)
        return [w for w, _ in Counter(tokens).most_common(top_k)]


# Singleton accessor
_singleton: Optional[CameroonMedicalData] = None


def get_cameroon_data() -> CameroonMedicalData:
    global _singleton
    if _singleton is None:
        _singleton = CameroonMedicalData()
    return _singleton
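

# Illustrative smoke test, not part of the module's API. It assumes
# CAMEROON_DATA_CSV points at a readable CSV with the expected columns;
# the query string below is a made-up example.
if __name__ == "__main__":
    data = get_cameroon_data()
    data.clean()
    print(json.dumps(data.stats_overview(), indent=2, default=str))
    for case in data.search_similar_cases("persistent fever and joint pain", top_k=3):
        print(case.summary_id, case.diagnosis, round(case.similarity_score, 3))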