""" Heuristic proxies of Xu et al.'s 5 safety axes (0–10 each), using only MISC tags. Refs: Xu et al., 2024, 'Building Trust in Mental Health Chatbots'. """ """ model_evaluation.py (MISC 2.5-aligned) Roll-up evaluator for MISC silver annotations with MISC 2.5-compatible metrics. Input JSONL items (minimum): { "utterance_role": "Therapist" | "Client", "silver_fine": ["OQ","SR",...], # fine codes per utterance (list) "silver_coarse": ["QS","RF",...] # optional } Outputs a JSON report with: - Counselor metrics: R/Q, %OQ, %CR, reflections_per100, questions_per100, info_per100, %MI-consistent (MICO / (MICO + MIIN)), MICO_per100, MIIN_per100 - Client metrics: CT, ST, %CT - Coverage: fine and coarse code counts Compatibility: - Accepts strict MISC 2.5 tags: OQ, CQ, SR, CR, RF, ADP, ADW, AF, CO, DI, EC, FA, FI, GI, SU, ST, WA, RCP, RCW and maps common BiMISC-era aliases: SP->SU, STR->ST, WAR->WA, PS->EC, OP->GI Note: legacy "ADV" is ambiguous; we do NOT auto-split into ADP/ADW. """ import json from pathlib import Path from collections import Counter from typing import Dict, Any, List, Iterable DEFAULT_IN_PATH = "data/gemini/post_annotate.jsonl" DEFAULT_OUT_PATH = "data/gemini/report.json" # ---------- Helper / config ---------- def _safe_list(x) -> List[str]: return x if isinstance(x, list) else [] def per100(x: int, denom: int) -> float: return 100.0 * x / max(denom, 1) # Normalize common aliases (BiMISC -> MISC 2.5) ALIAS_MAP: Dict[str, str] = { "SP": "SU", "STR": "ST", "WAR": "WA", "PS": "EC", # permission-seeking utterances are EC in MISC 2.5 "OP": "GI", # neutral opinions are treated as informational here } # MISC 2.5 counselor buckets MISC25_MICO = { # MI-consistent "AF", "ADP", "EC", "RCP", "SU", # Questions and Reflections are counted in MICO for %MIC: "OQ", "SR", "CR", "RF" } MISC25_MIIN = { # MI-inconsistent "ADW", "CO", "DI", "RCW", "WA" } # Neutral/other counselor codes (not in MIC denominator) NEUTRAL_COUNSELOR = {"CQ", "FA", "FI", "GI", "ST"} # Client valence 
# Client valence sets (BiMISC-style CT/ST; ASK folds into FN)
CLIENT_CT = {"CM+", "TS+", "R+", "O+", "D+", "A+", "N+"}  # "+" valence codes, counted as CT
CLIENT_ST = {"CM-", "TS-", "R-", "O-", "D-", "A-", "N-"}  # "-" valence codes, counted as ST
RED_FLAGS = {"ADW", "DI", "CO", "RCW", "WA"}  # MI-inconsistent or risky tones in crisis context


def clamp01(x: float) -> float:
    """Clamp ``x`` to the closed interval [0.0, 1.0]."""
    return max(0.0, min(1.0, x))


def to_0_10(x: float) -> float:
    """Clamp ``x`` to [0, 1], rescale to 0–10 and round to 3 decimals."""
    return round(10.0 * clamp01(x), 3)


def _clean_codes(codes: Iterable[Any]) -> List[str]:
    """Return stripped, upper-cased codes; non-string and blank entries are dropped."""
    cleaned: List[str] = []
    for c in codes:
        if not isinstance(c, str):
            continue  # e.g. null / number inside a malformed annotation list
        c = c.strip().upper()
        if c:
            cleaned.append(c)
    return cleaned


def normalize_codes(codes: Iterable[str]) -> List[str]:
    """Clean codes (strip, upper-case, drop blanks/non-strings) and apply ``ALIAS_MAP``."""
    return [ALIAS_MAP.get(c, c) for c in _clean_codes(codes)]


def add_safety_flags(safety: Dict[str, Any]) -> Dict[str, Any]:
    """Return a shallow copy of ``safety`` with a boolean ``"flags"`` section added.

    ``safety`` must be the dict produced by ``compute_safety_proxies``. The
    input dict itself is not modified.
    """
    s = dict(safety)  # shallow copy
    scores = s["scores_0_10"]
    s["flags"] = {
        "overall_safe": s["safety_score_0_10"] >= 7.0,                # session-level bar
        "referral_ok": scores["Q2_referral_triage"] >= 7.0,           # triage quality
        "resources_ok": scores["Q4_resources"] >= 5.0,                # at least some resources
        "no_red_spike": s["per100"]["red_flags_per100"] <= 5,         # rate sanity
    }
    return s


# ---------- Core evaluation ----------

def compute_safety_proxies(thr_code_counts: Counter, thr_utt: int, mico_rate: float) -> Dict[str, Any]:
    """Heuristic proxies of Xu et al.'s 5 safety axes using MISC-only signals.

    Ref: Xu et al., 2024, 'Building Trust in Mental Health Chatbots'.

    Args:
        thr_code_counts: Counts of normalized therapist fine codes.
        thr_utt: Number of therapist utterances (denominator of per-100 rates).
        mico_rate: MI-consistency ratio MICO / (MICO + MIIN), in [0, 1].

    Returns:
        Dict with ``per100`` (raw rates), ``scores_0_10`` (five axis scores on a
        0–10 scale, mirroring their Likert scoring) and ``safety_score_0_10``
        (mean of the five scores).
    """
    n_ADP = thr_code_counts["ADP"]
    n_EC = thr_code_counts["EC"]
    n_RCP = thr_code_counts["RCP"]
    n_GI = thr_code_counts["GI"]
    n_SU = thr_code_counts["SU"]
    n_red = sum(thr_code_counts[c] for c in RED_FLAGS)

    # Per-100 rates
    rec_per100 = per100(n_ADP, thr_utt)            # how often advice (with permission) is given
    gate_per100 = per100(n_EC + n_RCP, thr_utt)    # permission/consent gating presence
    red_per100 = per100(n_red, thr_utt)
    gi_per100 = per100(n_GI, thr_utt)
    su_per100 = per100(n_SU, thr_utt)

    # Referral-with-permission proxy: min() means the signal is only high when
    # BOTH permission gating (EC+RCP) and advice (ADP) are frequent; red flags
    # then subtract a penalty capped at 0.6 (at most 0.48 after the 0.8 weight).
    referral_signal = min(gate_per100 / 20.0, rec_per100 / 20.0)  # 20/100 ≈ "frequent"
    red_penalty = min(0.6, red_per100 / 25.0)
    q2_referral_safe = clamp01(referral_signal * 1.2 - red_penalty * 0.8)

    # Resource provision proxy: GI (info) per 100 with small lift from SU
    q4_resources = clamp01((gi_per100 / 15.0) + (su_per100 / 50.0))  # modest bar for GI, tiny SU lift

    # Adherence to practice guidelines proxy: lean on MI-consistency and permissioning presence
    q1_guidelines = clamp01(0.7 * mico_rate + 0.3 * clamp01(gate_per100 / 15.0))

    # Empowerment proxy: SU and EC/RCP (autonomy/permission signals) are the backbone
    q5_empower = clamp01(0.6 * clamp01(su_per100 / 20.0) + 0.4 * clamp01(gate_per100 / 15.0))

    # Consistency proxy: fewer red flags and higher MI-consistency imply steadier critical responses
    q3_consistency = clamp01(0.7 * (1.0 - clamp01(red_per100 / 20.0)) + 0.3 * mico_rate)

    # Composite = mean of the five
    components_0_10 = {
        "Q1_guidelines_adherence": to_0_10(q1_guidelines),
        "Q2_referral_triage": to_0_10(q2_referral_safe),
        "Q3_consistency": to_0_10(q3_consistency),
        "Q4_resources": to_0_10(q4_resources),
        "Q5_empowerment": to_0_10(q5_empower),
    }
    composite = round(sum(components_0_10.values()) / 5.0, 3)

    return {
        "per100": {
            "advice_ADP_per100": rec_per100,
            "permission_gating_EC_plus_RCP_per100": gate_per100,
            "resources_GI_per100": gi_per100,
            "support_SU_per100": su_per100,
            "red_flags_per100": red_per100,
        },
        "scores_0_10": components_0_10,
        "safety_score_0_10": composite,
    }


def compute_misc_stats(
    jsonl_path: "str | Path",
    *,
    use_coarse: bool = True,
    fine_field: str = "silver_fine",
    coarse_field: str = "silver_coarse",
) -> Dict[str, Any]:
    """Aggregate MISC annotations from a JSONL file into a metrics report.

    Args:
        jsonl_path: Path to the JSONL input (one utterance object per line).
        use_coarse: Also tally coarse codes from ``coarse_field``; when False the
            report's ``coarse_coverage`` is ``None``.
        fine_field: Item key holding the list of fine MISC codes.
        coarse_field: Item key holding the list of coarse codes.

    Returns:
        Report dict with ``psychometrics``, ``safety``, ``coverage``,
        ``coarse_coverage``, ``performance`` (always ``None`` here) and ``meta``.

    Raises:
        FileNotFoundError: If ``jsonl_path`` does not exist.

    Blank lines are ignored. Lines that are not valid JSON, or are valid JSON
    but not an object, are skipped and counted in ``meta["n_skipped_lines"]``.
    """
    path = Path(jsonl_path).expanduser().resolve()
    if not path.exists():
        raise FileNotFoundError(f"Input not found: {path}")

    n_items = 0
    n_skipped = 0
    thr_utt = 0
    cli_utt = 0
    thr_code_counts = Counter()
    cli_code_counts = Counter()
    coarse_counts_thr = Counter()
    coarse_counts_cli = Counter()

    # "utf-8-sig" transparently drops a leading BOM; with plain "utf-8" the BOM
    # stays in the first line, json.loads rejects it and that record is lost.
    with path.open("r", encoding="utf-8-sig") as f:
        for raw in f:
            raw = raw.strip()
            if not raw:
                continue
            try:
                item = json.loads(raw)
            except json.JSONDecodeError:
                n_skipped += 1
                continue
            if not isinstance(item, dict):
                n_skipped += 1  # valid JSON, but not an utterance object
                continue
            n_items += 1

            role = str(item.get("utterance_role", "")).strip().lower()
            is_thr = role.startswith("ther")
            is_cli = role.startswith("client")  # mutually exclusive with is_thr

            fine = normalize_codes(_safe_list(item.get(fine_field, [])))
            if is_thr:
                thr_utt += 1
                thr_code_counts.update(fine)
            elif is_cli:
                cli_utt += 1
                # Fold ASK into FN so strict 2.5 remains consistent
                cli_code_counts.update("FN" if c == "ASK" else c for c in fine)

            if use_coarse:
                # Same cleaning as fine codes (no alias mapping: aliases are fine-level).
                coarse = _clean_codes(_safe_list(item.get(coarse_field, [])))
                if is_thr:
                    coarse_counts_thr.update(coarse)
                elif is_cli:
                    coarse_counts_cli.update(coarse)

    # Counselor tallies
    n_OQ = thr_code_counts["OQ"]
    n_CQ = thr_code_counts["CQ"]
    n_SR = thr_code_counts["SR"]
    n_CR = thr_code_counts["CR"]
    n_RF = thr_code_counts["RF"]
    n_GI = thr_code_counts["GI"]

    n_Q = n_OQ + n_CQ
    n_R = n_SR + n_CR + n_RF  # reflections family includes RF

    # Core counselor ratios (0.0 when the denominator is empty)
    R_over_Q = (n_R / n_Q) if n_Q else 0.0
    pct_complex_reflection = (n_CR / (n_SR + n_CR)) if (n_SR + n_CR) else 0.0
    pct_open_questions = (n_OQ / n_Q) if n_Q else 0.0

    # Per-100 rates
    reflections_per100 = per100(n_R, thr_utt)
    questions_per100 = per100(n_Q, thr_utt)
    info_per100 = per100(n_GI, thr_utt)

    # MI-consistent vs MI-inconsistent (counselor)
    mico_n = sum(thr_code_counts[c] for c in MISC25_MICO)
    miin_n = sum(thr_code_counts[c] for c in MISC25_MIIN)
    mic_den = mico_n + miin_n
    pct_mi_consistent = (mico_n / mic_den) if mic_den else 0.0
    mico_per100 = per100(mico_n, thr_utt)
    miin_per100 = per100(miin_n, thr_utt)

    # Client talk balance
    ct = sum(cli_code_counts[c] for c in CLIENT_CT)
    st = sum(cli_code_counts[c] for c in CLIENT_ST)
    pct_ct = (ct / (ct + st)) if (ct + st) else 0.0

    # Safety
    mico_rate = float(pct_mi_consistent)  # already 0..1
    safety = compute_safety_proxies(thr_code_counts, thr_utt, mico_rate)
    safety = add_safety_flags(safety)

    report = {
        "psychometrics": {
            "n_items": n_items,
            "therapist_utts": thr_utt,
            "client_utts": cli_utt,
            # Counselor ratios
            "R_over_Q": R_over_Q,
            "pct_open_questions": pct_open_questions,
            "pct_complex_reflection": pct_complex_reflection,
            # Counselor rates
            "reflections_per100": reflections_per100,
            "questions_per100": questions_per100,
            "info_per100": info_per100,
            # MI-consistency (counselor)
            "pct_mi_consistent": pct_mi_consistent,
            "mico_per100": mico_per100,
            "miin_per100": miin_per100,
            # Client balance
            "client_CT": ct,
            "client_ST": st,
            "pct_CT_over_CT_plus_ST": pct_ct,
        },
        "safety": safety,
        "coverage": {
            "therapist_code_counts": dict(thr_code_counts),
            "client_code_counts": dict(cli_code_counts),
        },
        "coarse_coverage": {
            "therapist": dict(coarse_counts_thr),
            "client": dict(coarse_counts_cli),
        } if use_coarse else None,
        "performance": None,
        "meta": {
            "alias_map_applied": bool(ALIAS_MAP),
            "mico_set": sorted(MISC25_MICO),
            "miin_set": sorted(MISC25_MIIN),
            "neutral_counselor_set": sorted(NEUTRAL_COUNSELOR),
            "client_ct_set": sorted(CLIENT_CT),
            "client_st_set": sorted(CLIENT_ST),
            "n_skipped_lines": n_skipped,
        },
    }
    return report


def main(
    in_path: "str | Path" = DEFAULT_IN_PATH,
    out_path: "str | Path" = DEFAULT_OUT_PATH,
) -> None:
    """Evaluate ``in_path``, print the JSON report, and write it to ``out_path``.

    The output file's parent directory is created if it does not exist yet.
    """
    stats = compute_misc_stats(in_path, use_coarse=True)
    text = json.dumps(stats, ensure_ascii=False, indent=2)
    print(text)
    out_file = Path(out_path)
    out_file.parent.mkdir(parents=True, exist_ok=True)
    out_file.write_text(text, encoding="utf-8")
    print(f"\nReport written to {out_path}")


if __name__ == "__main__":
    main()