File size: 9,805 Bytes
411a994
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
from __future__ import annotations

import os
import json
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
from loguru import logger
from sklearn.neighbors import NearestNeighbors
from sentence_transformers import SentenceTransformer

from app.utils.config import settings
from app.utils.helpers import normalize_gender, clean_diagnosis


@dataclass
class SimilarCase:
    """One semantic-search hit returned by CameroonMedicalData.search_similar_cases."""
    summary_id: str  # source row's summary_id, coerced to str ("" when absent)
    diagnosis: Optional[str]  # populated from the row's diagnosis_norm column
    age: Optional[float]  # patient_age as float; None when the source value is NaN
    gender: Optional[str]  # populated from the row's patient_gender_norm column
    summary_snippet: str  # first 140 chars of summary_text, "..." appended when truncated
    similarity_score: float  # 1 - cosine distance between query and summary embeddings


class CameroonMedicalData:
    """
    Load, clean, analyze and search medical summaries specialized for the
    Cameroonian context.

    Designed for ~45k rows. Sentence embeddings are cached on disk (keyed by
    the embedding model name) and lightweight statistics are computed on
    demand from the cleaned DataFrame.
    """

    # Age buckets shared by stats_disease() and age_gender_distribution().
    _AGE_BINS = [-1, 18, 35, 60, 200]
    _AGE_LABELS = ["0-18", "19-35", "36-60", "60+"]

    def __init__(self, csv_path: Optional[str] = None):
        """
        Args:
            csv_path: Path to the summaries CSV. Defaults to
                settings.CAMEROON_DATA_CSV; when neither resolves to an
                existing file, an empty DataFrame is used and a warning is
                logged instead of raising.
        """
        self.csv_path = csv_path or settings.CAMEROON_DATA_CSV
        if not self.csv_path or not os.path.exists(self.csv_path):
            logger.warning("CameroonMedicalData: CSV path missing or not found. Set CAMEROON_DATA_CSV in .env")
            self.df = pd.DataFrame()
        else:
            self.df = self._load_csv(self.csv_path, settings.CAMEROON_MAX_ROWS)
        self._cleaned: bool = False
        self._model: Optional[SentenceTransformer] = None
        self._embeddings: Optional[np.ndarray] = None
        self._nn: Optional[NearestNeighbors] = None
        self._cache_dir = settings.CAMEROON_CACHE_DIR
        os.makedirs(self._cache_dir, exist_ok=True)

    # ----------------------- Data Loading & Cleaning -----------------------
    def _load_csv(self, path: str, limit: Optional[int]) -> pd.DataFrame:
        """Read the CSV at `path`, keeping only the first `limit` rows when set (> 0)."""
        df = pd.read_csv(path)
        if limit and limit > 0:
            df = df.head(limit)
        return df

    def clean(self) -> None:
        """
        Normalize self.df in place: parse dates, fill missing text fields,
        add `patient_gender_norm` / `diagnosis_norm` columns, coerce vitals to
        numeric, and drop rows that carry neither summary text nor a
        normalized diagnosis. Sets the `_cleaned` flag so callers run this
        at most once.

        Raises:
            ValueError: if any expected column is missing from the CSV.
        """
        if self.df.empty:
            self._cleaned = True
            return

        df = self.df.copy()

        # Required schema; fail fast with the full list of missing columns.
        expected_cols = [
            "summary_id","patient_id","patient_age","patient_gender","diagnosis",
            "body_temp_c","blood_pressure_systolic","heart_rate","summary_text","date_recorded"
        ]
        missing = [c for c in expected_cols if c not in df.columns]
        if missing:
            raise ValueError(f"Missing required columns: {missing}")

        # Unparseable dates become NaT rather than raising.
        df["date_recorded"] = pd.to_datetime(df["date_recorded"], errors="coerce")

        # Text fields are length-checked/normalized below, so make them real strings.
        df["patient_gender"] = df["patient_gender"].fillna("")
        df["diagnosis"] = df["diagnosis"].fillna("")
        df["summary_text"] = df["summary_text"].fillna("")

        # Normalized variants live in *_norm columns; the originals are preserved.
        df["patient_gender_norm"] = df["patient_gender"].apply(lambda v: normalize_gender(str(v)))
        df["diagnosis_norm"] = df["diagnosis"].apply(lambda v: clean_diagnosis(str(v)))

        # Vitals may arrive as strings; coerce, with invalid values becoming NaN.
        for col in ["patient_age","body_temp_c","blood_pressure_systolic","heart_rate"]:
            df[col] = pd.to_numeric(df[col], errors="coerce")

        # A row with neither free text nor a diagnosis carries no usable signal.
        df = df[~((df["summary_text"].str.len() == 0) & (df["diagnosis_norm"].isna()))]

        self.df = df.reset_index(drop=True)
        self._cleaned = True

    # ----------------------------- Statistics -----------------------------
    def stats_overview(self) -> Dict[str, Any]:
        """Return dataset-wide stats: row count, top-20 diagnoses, age describe(), gender split."""
        if not self._cleaned:
            self.clean()
        if self.df.empty:
            return {"total_rows": 0}

        df = self.df
        # value_counts(dropna=True) already excludes NaN keys, so no further
        # dropna() is needed after head() (the original call was redundant).
        top_diagnoses = df["diagnosis_norm"].value_counts(dropna=True).head(20).to_dict()
        age_desc = df["patient_age"].describe().fillna(0).to_dict()

        return {
            "total_rows": int(len(df)),
            "top_diagnoses": top_diagnoses,
            "age_stats": age_desc,
            "gender_distribution": df["patient_gender_norm"].value_counts(dropna=True).to_dict(),
        }

    def stats_disease(self, disease_name: str) -> Dict[str, Any]:
        """
        Return case count, age/gender distributions and frequent summary terms
        for one disease, matched against `diagnosis_norm` after lowercasing
        the query (assumes clean_diagnosis emits lowercase — TODO confirm).
        """
        if not self._cleaned:
            self.clean()
        if self.df.empty:
            return {"disease": disease_name, "total_cases": 0}

        subset = self.df[self.df["diagnosis_norm"] == disease_name.lower()]

        return {
            "disease": disease_name,
            "total_cases": int(len(subset)),
            "age_distribution": self._age_buckets(subset["patient_age"]),
            "gender_distribution": subset["patient_gender_norm"].value_counts().to_dict(),
            # Very simple proxy for symptoms: frequent tokens in summary_text.
            "common_symptoms": self._extract_common_terms(subset["summary_text"].tolist(), top_k=15),
        }

    def seasonal_patterns(self) -> Dict[str, int]:
        """Return case counts keyed by english lowercase month name; all 12 keys always present."""
        if not self._cleaned:
            self.clean()
        if self.df.empty:
            return {}
        dated = self.df.dropna(subset=["date_recorded"]).copy()
        dated["month"] = dated["date_recorded"].dt.month
        counts = dated["month"].value_counts().sort_index()
        months = ["january","february","march","april","may","june","july","august","september","october","november","december"]
        # Emit every month (0 when absent) so callers never handle missing keys.
        return {months[i-1]: int(counts.get(i, 0)) for i in range(1, 13)}

    def age_gender_distribution(self) -> Dict[str, Any]:
        """Return dataset-wide age-bucket and gender distributions."""
        if not self._cleaned:
            self.clean()
        if self.df.empty:
            return {"age_buckets": {}, "gender_distribution": {}}

        return {
            "age_buckets": self._age_buckets(self.df["patient_age"]),
            "gender_distribution": self.df["patient_gender_norm"].value_counts().to_dict(),
        }

    def _age_buckets(self, ages: pd.Series) -> Dict[str, int]:
        """Bucket an age series into the shared labelled ranges (NaN ages are excluded)."""
        cut = pd.cut(ages, bins=self._AGE_BINS, labels=self._AGE_LABELS)
        return cut.value_counts().reindex(self._AGE_LABELS, fill_value=0).to_dict()

    # --------------------------- Semantic Similarity ---------------------------
    def _ensure_embeddings(self) -> None:
        """
        Lazily build (or load from the disk cache) summary-text embeddings and
        the cosine nearest-neighbour index. No-op when already built; leaves
        `_nn` as None when the dataset is empty.
        """
        if self._embeddings is not None and self._nn is not None:
            return
        if not self._cleaned:
            self.clean()
        if self.df.empty:
            self._embeddings = np.zeros((0, 384), dtype=np.float32)
            self._nn = None
            return

        # Load the model lazily — it is expensive and not every caller searches.
        if self._model is None:
            model_name = settings.CAMEROON_EMBEDDINGS_MODEL
            logger.info(f"Loading sentence-transformers model: {model_name}")
            self._model = SentenceTransformer(model_name)

        # Key the cache file by model name so switching models cannot silently
        # reuse embeddings produced by a different model (the row-count check
        # below would not catch that).
        model_tag = str(settings.CAMEROON_EMBEDDINGS_MODEL).replace("/", "_")
        cache_file = os.path.join(self._cache_dir, f"embeddings_{model_tag}.npy")
        if os.path.exists(cache_file):
            try:
                self._embeddings = np.load(cache_file)
            except Exception:
                # Corrupt/unreadable cache: fall through and re-encode below.
                self._embeddings = None

        # Re-encode when there is no usable cache or the row count changed.
        if self._embeddings is None or len(self._embeddings) != len(self.df):
            texts = self.df["summary_text"].astype(str).tolist()
            self._embeddings = self._model.encode(texts, batch_size=64, show_progress_bar=False, normalize_embeddings=True)
            np.save(cache_file, self._embeddings)

        self._nn = NearestNeighbors(n_neighbors=10, metric="cosine")
        self._nn.fit(self._embeddings)

    def search_similar_cases(self, query_text: str, top_k: int = 10) -> List[SimilarCase]:
        """
        Return up to `top_k` cases whose summary text is semantically closest
        to `query_text` (cosine similarity on normalized sentence embeddings).
        Returns [] for blank queries, non-positive top_k, or an empty dataset.
        """
        # Guard top_k <= 0 explicitly: kneighbors(n_neighbors=0) would raise.
        if not query_text or not query_text.strip() or top_k <= 0:
            return []
        self._ensure_embeddings()
        if self._model is None or self._nn is None or self._embeddings is None or self.df.empty:
            return []

        q = self._model.encode([query_text], normalize_embeddings=True)
        n_neighbors = min(top_k, len(self.df))
        distances, indices = self._nn.kneighbors(q, n_neighbors=n_neighbors)

        results: List[SimilarCase] = []
        for dist, idx in zip(distances[0], indices[0]):
            row = self.df.iloc[int(idx)]
            text = str(row.get("summary_text", ""))
            snippet = text[:140] + ("..." if len(text) > 140 else "")
            diagnosis = row.get("diagnosis_norm")
            gender = row.get("patient_gender_norm")
            age = row.get("patient_age")
            results.append(SimilarCase(
                summary_id=str(row.get("summary_id", "")),
                # Map non-string values (e.g. NaN) to None so the declared
                # Optional[str] contract actually holds.
                diagnosis=diagnosis if isinstance(diagnosis, str) else None,
                age=float(age) if pd.notna(age) else None,
                gender=gender if isinstance(gender, str) else None,
                summary_snippet=snippet,
                # Similarity = 1 - cosine distance.
                similarity_score=float(1.0 - dist),
            ))
        return results

    # ----------------------------- Utils -----------------------------
    def _extract_common_terms(self, texts: List[str], top_k: int = 20) -> List[str]:
        """
        Return the `top_k` most frequent alphabetic tokens (length >= 3) over
        `texts`. Very naive bag-of-words; in production consider medical
        entity extraction.
        """
        from collections import Counter
        tokens: List[str] = []
        for t in texts:
            for w in str(t).lower().replace(",", " ").replace(".", " ").split():
                if len(w) >= 3 and w.isalpha():
                    tokens.append(w)
        return [w for w, _ in Counter(tokens).most_common(top_k)]


# Module-level singleton holder for the shared dataset instance.
_singleton: Optional[CameroonMedicalData] = None


def get_cameroon_data() -> CameroonMedicalData:
    """Return the process-wide CameroonMedicalData instance, creating it on first call."""
    global _singleton
    if _singleton is not None:
        return _singleton
    _singleton = CameroonMedicalData()
    return _singleton