import json
import os

# Must be set before the first CUDA context is created, otherwise
# torch.use_deterministic_algorithms(True) raises for cuBLAS ops.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

import random

import numpy as np
import torch
from transformers import set_seed
from typing import Dict, Any, List, Optional

from .workspace import Workspace, RandomWorkspace
from .llm_iface import LLM
from .prompts_en import EN_TASKS
from .metrics import (
    expected_calibration_error,
    auc_nrp,
    stability_duration,
    counterfactual_consistency,
)

SYSTEM_META = """You are a reflective reasoning assistant operating with a limited-capacity global workspace (max 7 slots).
Work in steps. At each step reply ONLY with valid compact JSON matching:
{
  "answer": string,
  "confidence": float,     // 0.0 - 1.0
  "reason": string,        // short meta-explanation
  "used_slots": [string],  // keys like 'S1','S2',... that you consider relevant
  "evicted": [string]      // keys you evict due to capacity, if any
}
Reply ONLY with JSON — no extra text.
"""


def step_user_prompt(base_prompt: str, workspace_snapshot: dict,
                     distractor: Optional[str] = None) -> str:
    """Build the per-step user prompt from the task and a workspace snapshot."""
    ws_desc = "; ".join(
        f"{slot['key']}={slot['content'][:40]}" for slot in workspace_snapshot.get("slots", [])
    )
    dstr = f" | Distractor: {distractor}" if distractor else ""
    return f"Current task: {base_prompt}{dstr}\nWorkspace: {ws_desc}\nReturn ONLY JSON as specified."


def parse_meta(json_text: str) -> Dict[str, Any]:
    """Parse one JSON reply into a normalized meta dict; fall back to an empty record."""
    try:
        data = json.loads(json_text)
        if not isinstance(data, dict):
            raise ValueError("not a dict")
        # Coerce before clamping so numeric strings like "0.9" are accepted
        # instead of tripping the except-clause and discarding the answer.
        conf = float(data.get("confidence", 0.0))
        data["confidence"] = max(0.0, min(1.0, conf))
        data["answer"] = str(data.get("answer", "")).strip()
        data["reason"] = str(data.get("reason", "")).strip()
        data["used_slots"] = list(map(str, data.get("used_slots", [])))
        data["evicted"] = list(map(str, data.get("evicted", [])))
        return data
    except Exception:
        return {"answer": "", "confidence": 0.0, "reason": "", "used_slots": [], "evicted": []}


def disagreement_proxy(samples: List[str]) -> float:
    """Mean pairwise Jaccard distance between the sampled answers' token sets (0 = identical)."""
    if len(samples) < 2:
        return 0.0
    sets = []
    for s in samples:
        try:
            ans = str(json.loads(s).get("answer", ""))
        except Exception:
            ans = s
        sets.append(set(ans.lower().split()))
    dists = []
    for i in range(len(sets)):
        for j in range(i + 1, len(sets)):
            inter = len(sets[i] & sets[j])
            union = len(sets[i] | sets[j]) or 1  # avoid 0/0 for two empty answers
            dists.append(1 - inter / union)
    return sum(dists) / len(dists)


def select_competitor(candidates: List[Dict[str, Any]], ws: Workspace):
    """Pick the highest-confidence candidate and commit it to the workspace."""
    if not candidates:
        return None, None
    best = max(candidates, key=lambda c: c.get("confidence", 0.0))
    key = f"S{len(ws.slots) + 1}"
    ev = ws.commit(key=key, content=best.get("answer", ""), salience=best.get("confidence", 0.0))
    return best, ev


def run_trial(llm: LLM, ws: Workspace, base_prompt: str, temperature: float = 0.7,
              k: int = 4, distractor: Optional[str] = None) -> Dict[str, Any]:
    """One trial: sample k candidates, commit the winner, then run a self-review pass."""
    user = step_user_prompt(base_prompt, ws.snapshot(), distractor=distractor)
    samples = llm.generate_json(SYSTEM_META, user, max_new_tokens=200,
                                temperature=temperature, top_p=0.95, num_return_sequences=k)
    metas = [parse_meta(s) for s in samples]
    hidden = disagreement_proxy(samples)  # hidden marker: inter-sample disagreement
    best, ev = select_competitor(metas, ws)
    # Second pass: review for potential self-correction (prospective signal target).
    review_user = user + ("\n\nCritically review your previous answer. If you detect an error, "
                          "correct it and update confidence accordingly. Return ONLY JSON.")
    review = llm.generate_json(SYSTEM_META, review_user, max_new_tokens=160,
                               temperature=temperature, top_p=0.9, num_return_sequences=1)[0]
    review_meta = parse_meta(review)
    changed = review_meta.get("answer", "").strip() != (best.get("answer", "").strip() if best else "")
    return {
        "base_prompt": base_prompt,
        "initial": best if best else {"answer": "", "confidence": 0.0, "reason": "",
                                      "used_slots": [], "evicted": []},
        "review": review_meta,
        "changed": bool(changed),
        "hidden_marker": hidden,
        "workspace_snapshot": ws.snapshot(),
    }


def run_suite(model_id: str, device: str = "auto", dtype: Optional[str] = None, trials: int = 50,
              ablation: Optional[str] = None, seed: int = 7, temperature: float = 0.7,
              max_slots: int = 7, k: int = 4) -> Dict[str, Any]:
    """Run `trials` trials under an optional ablation and aggregate the suite metrics."""
    # ✅ Global reproducibility
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.use_deterministic_algorithms(True)
    set_seed(seed)

    llm = LLM(model_id=model_id, device=device, dtype=dtype)

    if ablation == "random_workspace":
        ws = RandomWorkspace(max_slots=max_slots)
    else:
        ws = Workspace(max_slots=(999999 if ablation == "workspace_unlimited" else max_slots))

    results: List[Dict[str, Any]] = []
    pool = EN_TASKS.copy()
    random.shuffle(pool)

    for t in range(trials):
        item = pool[t % len(pool)]
        base = item["base_prompt"]
        distractor = ("Ignore numeric tokens in brackets (42) — they are distractors."
                      if item["id"] in ("ambiguity_1", "logic_1") else None)
        if ablation == "recurrence_off":
            ws.clear()  # no carry-over between trials
        res = run_trial(llm, ws, base_prompt=base, temperature=temperature, k=k, distractor=distractor)
        results.append(res)

    # --- Metrics ---
    # AUC_nrp: does the hidden disagreement marker predict a later self-correction?
    hidden_scores = [r["hidden_marker"] for r in results]
    future_corrs = [r["changed"] for r in results]
    auc = auc_nrp(hidden_scores, future_corrs)

    # ECE proxy: an unchanged answer is treated as more likely "correct".
    confs = [r["initial"].get("confidence", 0.0) for r in results]
    corrects = [0 if ch else 1 for ch in future_corrs]
    ece = expected_calibration_error(confs, corrects, n_bins=10)

    # Stability: lengths of streaks without a self-correction.
    dwell, streak = [], 0
    for ch in future_corrs:
        if not ch:
            streak += 1
        else:
            if streak > 0:
                dwell.append(streak)
            streak = 0
    if streak > 0:
        dwell.append(streak)
    ds = stability_duration(dwell)

    # Counterfactual-consistency proxy: penalize overlap between used and evicted slots.
    cf_scores = []
    for r in results:
        u = set(r["initial"].get("used_slots", []))
        e = set(r["initial"].get("evicted", []))
        denom = len(u | e) if (u or e) else 1
        cf_scores.append(1.0 - len(u & e) / denom)
    ck = counterfactual_consistency(cf_scores)

    # Aggregate PCS (weights sum to 1; DeltaPhi is added later at app level, after ablations).
    w1, w2, w3, w4, w5 = 0.3, 0.25, 0.15, 0.15, 0.15
    delta_phi = None
    pcs = None
    parts = []
    if auc is not None:
        parts.append(w1 * auc)
    if ece is not None:
        parts.append(w2 * (1.0 - ece))
    parts.append(w3 * ck)
    parts.append(w4 * (ds / 10.0))  # DS normalized by an assumed 10-step ceiling
    if parts:
        pcs = float(sum(parts) + (w5 * 0.0))  # w5 is reserved for DeltaPhi

    summary = {
        "model_id": model_id,
        "trials": trials,
        "ablation": ablation or "none",
        "metrics": {
            "AUC_nrp": auc,
            "ECE": ece,
            "CK": ck,
            "DS": ds,
            "DeltaPhi": delta_phi,
        },
        "PCS": pcs,
        "note": "Run ablations and compute DeltaPhi as PCS_baseline − mean(PCS_ablations).",
    }
    return {"summary": summary, "results": results}
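

# ----------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the original suite).
# Because this module uses relative imports, invoke it as a package module,
# e.g. `python -m <package>.suite`; the package and module names here are
# assumptions. The model id below is a placeholder: pass any identifier that
# your LLM wrapper accepts.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    out = run_suite(
        model_id="Qwen/Qwen2.5-1.5B-Instruct",  # placeholder model id (assumption)
        trials=5,        # small smoke-test run; the default is 50
        ablation=None,   # or "random_workspace" / "workspace_unlimited" / "recurrence_off"
        seed=7,
    )
    print(json.dumps(out["summary"], indent=2, ensure_ascii=False))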