Spaces:
Sleeping
Sleeping
| # bp_phi/runner.py | |
| import json | |
| import os | |
| os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" | |
| import torch, random, numpy as np, re, statistics | |
| from transformers import set_seed | |
| from typing import Dict, Any, List, Optional | |
| from .workspace import Workspace, RandomWorkspace | |
| from .llm_iface import LLM | |
| from .prompts_en import EN_TASKS | |
| from .metrics import expected_calibration_error, auc_nrp, stability_duration, counterfactual_consistency | |
DEBUG = 1


def dbg(*args):
    """Print a debug line to stdout when the module-level ``DEBUG`` flag is truthy.

    The positional ``args`` are printed space-separated after a "[DEBUG]" tag,
    and stdout is flushed right away so each message shows up immediately.
    """
    if not DEBUG:
        return
    print("[DEBUG]", *args, flush=True)
# System prompt passed with every generation call in run_trial
# (llm.generate_json(SYSTEM_META, ...)). It asks the model for a single JSON
# object with the keys parse_meta reads back: answer, confidence (clamped to
# [0, 1] on parsing), reason, used_slots and evicted.
# The exact text is part of the experiment's input; edit deliberately.
SYSTEM_META = """You are a structured reasoning assistant.
Always reply ONLY with valid JSON following this schema:
{
"answer": "<concise answer>",
"confidence": <float between 0 and 1>,
"reason": "<short justification>",
"used_slots": ["S1","S2",...],
"evicted": ["S3",...]
}
"""
def step_user_prompt(base_prompt: str, workspace_snapshot: dict, distractor: Optional[str] = None) -> str:
    """Build the user prompt for one reasoning step.

    Args:
        base_prompt: The task text.
        workspace_snapshot: Mapping whose optional ``"slots"`` entry is a list of
            dicts with ``"key"`` and ``"content"`` (assumed to be the shape of
            ``Workspace.snapshot()`` — confirm against the workspace module).
        distractor: Optional distractor sentence appended to the task line.

    Returns:
        A prompt containing the task (plus distractor, if any), a compact
        workspace summary with each slot's content truncated to 40 characters,
        and the JSON-only instruction.
    """
    ws_desc = "; ".join(
        f"{slot['key']}={slot['content'][:40]}"
        for slot in workspace_snapshot.get("slots", [])
    )
    dstr = f" | Distractor: {distractor}" if distractor else ""
    # The workspace view and the distractor must actually reach the model;
    # otherwise workspace ablations and distractor items have no effect.
    prompt = (
        f"Current task: {base_prompt}{dstr}\n"
        f"Workspace: {ws_desc or '(empty)'}\n"
        "Respond ONLY with JSON, no extra text."
    )
    dbg("USER PROMPT:", prompt)
    return prompt
def _empty_meta() -> Dict[str, Any]:
    """Return a fresh default meta record, used when no usable JSON is found."""
    return {"answer": "", "confidence": 0.0, "reason": "", "used_slots": [], "evicted": []}


def _extract_json_object(raw_text: str) -> Optional[Dict[str, Any]]:
    """Return the first JSON object embedded in ``raw_text``, or None.

    A fenced ```json (or bare ```) block is tried first, then the whole text.
    Every "{" is tried as a start position for ``json.JSONDecoder.raw_decode``,
    which copes with nested objects and braces inside string values (a
    non-greedy ``{.*?}`` regex stops at the first "}").
    """
    decoder = json.JSONDecoder()
    fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", raw_text, re.DOTALL)
    sources = ([fenced.group(1)] if fenced else []) + [raw_text]
    for text in sources:
        start = text.find("{")
        while start != -1:
            try:
                obj, _end = decoder.raw_decode(text, start)
            except json.JSONDecodeError:
                obj = None
            if isinstance(obj, dict):
                return obj
            start = text.find("{", start + 1)
    return None


def _coerce_confidence(value: Any) -> float:
    """Convert ``value`` to a float clamped to [0, 1]; unparseable or NaN values give 0.0."""
    import math

    try:
        conf = float(value)
    except (TypeError, ValueError, OverflowError):
        return 0.0
    if math.isnan(conf):
        # max/min would otherwise silently turn NaN into 1.0.
        return 0.0
    return max(0.0, min(1.0, conf))


def _as_str_list(value: Any) -> List[str]:
    """Normalize a slot list: None -> [], one string -> [it], iterable -> list of str."""
    if value is None:
        return []
    if isinstance(value, str):
        # list(map(str, "S1")) would split the string into characters.
        return [value]
    try:
        return [str(item) for item in value]
    except TypeError:  # a non-iterable scalar such as 3
        return [str(value)]


def parse_meta(raw_text: str) -> Dict[str, Any]:
    """Extract and sanitize the model's JSON meta record from ``raw_text``.

    Handles markdown code fences and surrounding prose. Each field is coerced
    independently, so one malformed field (e.g. ``"confidence": "0.8"`` or
    ``"used_slots": null``) falls back to its own default instead of discarding
    the whole record.

    Returns:
        The decoded dict with ``answer`` and ``reason`` as stripped str,
        ``confidence`` as a float in [0, 1], and ``used_slots``/``evicted`` as
        lists of str; any extra keys from the model's JSON are kept. When no
        JSON object is found, a record with empty/zero values is returned.
    """
    dbg("RAW MODEL OUTPUT:", raw_text)
    data = _extract_json_object(raw_text or "")
    if data is None:
        dbg("❌ JSON not found in text.")
        return _empty_meta()
    answer = data.get("answer")
    reason = data.get("reason")
    data["confidence"] = _coerce_confidence(data.get("confidence", 0.0))
    data["answer"] = "" if answer is None else str(answer).strip()
    data["reason"] = "" if reason is None else str(reason).strip()
    data["used_slots"] = _as_str_list(data.get("used_slots"))
    data["evicted"] = _as_str_list(data.get("evicted"))
    dbg("PARSED META:", data)
    return data
def _answer_tokens(sample: str) -> set:
    """Return the lower-cased whitespace tokens of a raw sample's ``answer``.

    The JSON object is taken from a ```json fence if present, else from the span
    between the first "{" and the last "}". If that does not decode to a dict,
    the whole raw sample text is tokenized instead.
    """
    fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", sample, re.DOTALL)
    if fenced:
        candidate = fenced.group(1)
    else:
        start, end = sample.find("{"), sample.rfind("}")
        candidate = sample[start:end + 1] if start != -1 and end > start else sample
    try:
        data = json.loads(candidate)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        data = None
    if isinstance(data, dict):
        answer = data.get("answer", "")
        text = "" if answer is None else str(answer)
    else:
        text = sample
    return set(text.lower().split())


def disagreement_proxy(samples: List[str]) -> float:
    """Mean pairwise Jaccard distance between the answers of the raw ``samples``.

    Answers are compared as lower-cased whitespace token sets. Two empty token
    sets count as identical (distance 0), following the usual Jaccard convention.

    Returns:
        A value in [0, 1]; 0.0 when fewer than two samples are given.
    """
    from itertools import combinations

    if len(samples) < 2:
        return 0.0
    token_sets = [_answer_tokens(s) for s in samples]
    dists = []
    for left, right in combinations(token_sets, 2):
        union = left | right
        # Empty-vs-empty is agreement, not maximal disagreement.
        dists.append((1 - len(left & right) / len(union)) if union else 0.0)
    avg_dist = sum(dists) / len(dists)
    dbg("DISAGREEMENT PROXY:", avg_dist)
    return avg_dist
def select_competitor(candidates: List[Dict[str, Any]], ws: Workspace):
    """Pick the highest-confidence candidate and commit its answer to ``ws``.

    The slot key is "S<n>" with n = current number of slots + 1, and the
    candidate's confidence is used as the commit salience. On ties the earliest
    candidate wins (``max`` keeps the first maximal element).

    Returns:
        ``(winner, commit_result)`` where ``commit_result`` is whatever
        ``ws.commit`` returns, or ``(None, None)`` when ``candidates`` is empty.
    """
    def _confidence(candidate):
        return candidate.get("confidence", 0.0)

    if candidates:
        winner = max(candidates, key=_confidence)
        dbg("SELECTED CANDIDATE:", winner)
        # NOTE(review): if len(ws.slots) stops growing once the workspace is full,
        # this key may be reused across commits — confirm against Workspace.commit.
        slot_key = "S" + str(len(ws.slots) + 1)
        commit_result = ws.commit(
            key=slot_key,
            content=winner.get("answer", ""),
            salience=_confidence(winner),
        )
        return winner, commit_result
    return None, None
def run_trial(llm: LLM, ws: Workspace, base_prompt: str, temperature: float = 0.7, k: int = 4,
              distractor: Optional[str] = None) -> Dict[str, Any]:
    """Run one generate -> select -> self-review cycle.

    Steps:
      1. Sample ``k`` JSON answers for the step prompt built from ``base_prompt``,
         the current workspace snapshot and the optional distractor.
      2. Score disagreement between the raw samples (``hidden_marker``).
      3. Commit the most confident parsed sample to ``ws``.
      4. Show that answer back to the model, ask it to review it, and record
         whether the reviewed answer differs (``changed``).

    Returns:
        Dict with keys base_prompt, initial, review, changed, hidden_marker and
        workspace_snapshot (taken after the commit in step 3).
    """
    dbg("=== RUN TRIAL:", base_prompt)
    user = step_user_prompt(base_prompt, ws.snapshot(), distractor=distractor)
    samples = llm.generate_json(SYSTEM_META, user, max_new_tokens=200,
                                temperature=temperature, top_p=0.95, num_return_sequences=k)
    dbg("RAW SAMPLES:", samples)
    metas = [parse_meta(s) for s in samples]
    hidden = disagreement_proxy(samples)
    best, _ = select_competitor(metas, ws)

    # generate_json is called with only a system and a user prompt (no chat
    # history), so the answer under review must be included explicitly.
    # Otherwise the "review" is just one more independent sample.
    review_parts = [user]
    if best:
        review_parts.append("Your previous answer was:\n" + json.dumps(best, ensure_ascii=False))
    review_parts.append(
        "Critically review your previous answer. If you detect an error, correct it and update confidence accordingly. Return ONLY JSON."
    )
    review_user = "\n\n".join(review_parts)
    review_outputs = llm.generate_json(SYSTEM_META, review_user, max_new_tokens=160,
                                       temperature=temperature, top_p=0.9, num_return_sequences=1)
    # Guard against an empty output list instead of raising IndexError.
    review_meta = parse_meta(review_outputs[0] if review_outputs else "")

    initial = best if best else {"answer": "", "confidence": 0.0, "reason": "", "used_slots": [], "evicted": []}
    changed = review_meta.get("answer", "").strip() != initial.get("answer", "").strip()
    dbg("REVIEW CHANGED:", changed)
    return {
        "base_prompt": base_prompt,
        "initial": initial,
        "review": review_meta,
        "changed": bool(changed),
        "hidden_marker": hidden,
        "workspace_snapshot": ws.snapshot(),
    }
def _seed_everything(seed: int) -> None:
    """Seed Python, NumPy, torch (CPU and CUDA) and transformers; enable torch determinism."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.use_deterministic_algorithms(True)
    set_seed(seed)


def _make_workspace(ablation: Optional[str], max_slots: int):
    """Return the workspace variant selected by ``ablation``."""
    if ablation == "random_workspace":
        return RandomWorkspace(max_slots=max_slots)
    # "workspace_unlimited" uses an effectively unbounded capacity.
    return Workspace(max_slots=(999999 if ablation == "workspace_unlimited" else max_slots))


def _dwell_lengths(changed_flags: List[bool]) -> List[int]:
    """Lengths of the maximal runs of consecutive trials whose answer survived review unchanged."""
    from itertools import groupby

    return [sum(1 for _ in run) for is_changed, run in groupby(changed_flags, key=bool) if not is_changed]


def _counterfactual_scores(results: List[Dict[str, Any]]) -> List[float]:
    """Per-trial 1 - |used & evicted| / |used | evicted| of the initial answer (1.0 if both empty)."""
    scores = []
    for r in results:
        used = set(r["initial"].get("used_slots", []))
        evicted = set(r["initial"].get("evicted", []))
        union = used | evicted
        scores.append((1.0 - len(used & evicted) / len(union)) if union else 1.0)
    return scores


def _pcs(auc: Optional[float], ece: Optional[float], ck: float, ds: float) -> float:
    """Weighted PCS score; AUC and ECE terms are skipped when those metrics are None."""
    w1, w2, w3, w4, _w5 = 0.3, 0.25, 0.15, 0.15, 0.15
    parts = []
    if auc is not None:
        parts.append(w1 * auc)
    if ece is not None:
        parts.append(w2 * (1.0 - ece))
    parts.append(w3 * ck)
    parts.append(w4 * (ds / 10.0))
    # The DeltaPhi term (weight _w5) needs ablation runs, so it contributes 0 here.
    return float(sum(parts))


def run_suite(model_id: str, device: str = "auto", dtype: Optional[str] = None,
              trials: int = 50, ablation: Optional[str] = None, seed: int = 7,
              temperature: float = 0.7, max_slots: int = 7, k: int = 4) -> Dict[str, Any]:
    """Run ``trials`` trials of the task suite against one model and summarize metrics.

    Args:
        model_id, device, dtype: Forwarded to ``LLM``.
        trials: Number of trials; tasks from ``EN_TASKS`` are shuffled once and cycled.
        ablation: None, "random_workspace" (RandomWorkspace), "workspace_unlimited"
            (capacity 999999) or "recurrence_off" (workspace cleared before every
            trial). Any other value runs the baseline configuration.
        seed: Seeds Python, NumPy, torch and transformers.
        temperature, k: Sampling temperature and samples per trial.
        max_slots: Workspace capacity.

    Returns:
        ``{"summary": ..., "results": [...]}``. The summary holds AUC_nrp, ECE, CK,
        DS, the weighted PCS score and a DeltaPhi placeholder (None).

    Raises:
        ValueError: if ``trials > 0`` but ``EN_TASKS`` is empty. This is checked
            before the model is loaded.
    """
    _seed_everything(seed)
    dbg(f"=== RUN SUITE: model={model_id}, trials={trials}, ablation={ablation}")
    if trials > 0 and not EN_TASKS:
        raise ValueError("EN_TASKS is empty; there are no task prompts to run.")
    llm = LLM(model_id=model_id, device=device, dtype=dtype)
    ws = _make_workspace(ablation, max_slots)

    results: List[Dict[str, Any]] = []
    pool = EN_TASKS.copy()
    # Shuffle after loading the model so the RNG sequence matches the original ordering.
    random.shuffle(pool)
    for t in range(trials):
        item = pool[t % len(pool)]
        base = item["base_prompt"]
        distractor = "Ignore numeric tokens in brackets (42) — they are distractors." if item["id"] in ("ambiguity_1", "logic_1") else None
        if ablation == "recurrence_off":
            ws.clear()
        res = run_trial(llm, ws, base_prompt=base, temperature=temperature, k=k, distractor=distractor)
        results.append(res)
        dbg(f"Trial {t+1}/{trials} done.")

    # --- Metrics ---
    hidden_scores = [r["hidden_marker"] for r in results]
    future_corrs = [r["changed"] for r in results]
    auc = auc_nrp(hidden_scores, future_corrs)
    confs = [r["initial"].get("confidence", 0.0) for r in results]
    # Proxy label: an answer that survives self-review unchanged counts as "correct".
    corrects = [0 if ch else 1 for ch in future_corrs]
    ece = expected_calibration_error(confs, corrects, n_bins=10)
    ds = stability_duration(_dwell_lengths(future_corrs))
    ck = counterfactual_consistency(_counterfactual_scores(results))
    delta_phi = None
    pcs = _pcs(auc, ece, ck, ds)

    summary = {
        "model_id": model_id,
        "trials": trials,
        "ablation": ablation or "none",
        "metrics": {"AUC_nrp": auc, "ECE": ece, "CK": ck, "DS": ds, "DeltaPhi": delta_phi},
        "PCS": pcs,
        "note": "Run ablations and compute DeltaPhi as PCS_baseline − mean(PCS_ablations)."
    }
    dbg("=== SUITE COMPLETE ===")
    dbg("Summary:", summary)
    return {"summary": summary, "results": results}