# bp_phi/runner.py
import json
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
import torch
import random
import numpy as np
import re
from transformers import set_seed
from typing import Dict, Any, List, Optional
from .workspace import Workspace, RandomWorkspace
from .llm_iface import LLM
from .prompts_en import EN_TASKS
from .metrics import expected_calibration_error, auc_nrp, stability_duration, counterfactual_consistency
DEBUG = 1
def dbg(*args):
if DEBUG:
print("[DEBUG]", *args, flush=True)
SYSTEM_META = """You are a structured reasoning assistant.
Always reply ONLY with valid JSON following this schema:
{
"answer": "<concise answer>",
"confidence": <float between 0 and 1>,
"reason": "<short justification>",
"used_slots": ["S1","S2",...],
"evicted": ["S3",...]
}
"""
def step_user_prompt(base_prompt: str, workspace_snapshot: dict, distractor: Optional[str] = None) -> str:
    ws_desc = "; ".join(f"{slot['key']}={slot['content'][:40]}" for slot in workspace_snapshot.get("slots", []))
    ws_part = f"\nWorkspace: {ws_desc}" if ws_desc else ""
    dstr = f" | Distractor: {distractor}" if distractor else ""
    prompt = f"{base_prompt}{ws_part}{dstr}\nRespond ONLY with JSON, no extra text."
    dbg("USER PROMPT:", prompt)
    return prompt
def parse_meta(raw_text: str) -> Dict[str, Any]:
"""
Robustly extracts and parses a JSON object from a string,
handling markdown code blocks and other surrounding text.
"""
dbg("RAW MODEL OUTPUT:", raw_text)
# ✅ Robust JSON extraction
json_match = re.search(r'```json\s*(\{.*?\})\s*```', raw_text, re.DOTALL)
if not json_match:
json_match = re.search(r'(\{.*?\})', raw_text, re.DOTALL)
if not json_match:
dbg("❌ JSON not found in text.")
return {"answer": "", "confidence": 0.0, "reason": "", "used_slots": [], "evicted": []}
json_text = json_match.group(1)
try:
data = json.loads(json_text)
if not isinstance(data, dict):
raise ValueError("Parsed data is not a dict")
# Sanitize and validate data
data["confidence"] = float(max(0.0, min(1.0, data.get("confidence", 0.0))))
data["answer"] = str(data.get("answer", "")).strip()
data["reason"] = str(data.get("reason", "")).strip()
data["used_slots"] = list(map(str, data.get("used_slots", [])))
data["evicted"] = list(map(str, data.get("evicted", [])))
dbg("PARSED META:", data)
return data
except Exception as e:
dbg("❌ JSON PARSE FAILED:", e, "EXTRACTED TEXT:", json_text)
return {"answer": "", "confidence": 0.0, "reason": "", "used_slots": [], "evicted": []}
def disagreement_proxy(samples: List[str]) -> float:
    """Mean pairwise Jaccard distance between the sampled answers' token sets (0 = identical, 1 = disjoint)."""
    if len(samples) < 2:
        return 0.0
sets = []
for s in samples:
try:
data = json.loads(s)
ans = str(data.get("answer",""))
except Exception:
ans = s
sets.append(set(ans.lower().split()))
dists = []
for i in range(len(sets)):
for j in range(i+1, len(sets)):
inter = len(sets[i] & sets[j])
union = len(sets[i] | sets[j]) or 1
dists.append(1 - inter/union)
avg_dist = sum(dists)/len(dists)
dbg("DISAGREEMENT PROXY:", avg_dist)
return avg_dist
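# Example (illustrative): identical answers yield 0.0, fully disjoint answers yield 1.0:
#   disagreement_proxy(['{"answer": "red"}', '{"answer": "red"}'])   -> 0.0
#   disagreement_proxy(['{"answer": "red"}', '{"answer": "blue"}'])  -> 1.0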
def select_competitor(candidates: List[Dict[str, Any]], ws: Workspace):
    """Pick the highest-confidence candidate and commit its answer into the workspace."""
    if not candidates:
        return None, None
best = max(candidates, key=lambda c: c.get("confidence", 0.0))
dbg("SELECTED CANDIDATE:", best)
key = f"S{len(ws.slots)+1}"
ev = ws.commit(key=key, content=best.get("answer",""), salience=best.get("confidence",0.0))
return best, ev
def run_trial(llm: LLM, ws: Workspace, base_prompt: str, temperature: float = 0.7, k: int = 4,
distractor: Optional[str] = None) -> Dict[str, Any]:
dbg("=== RUN TRIAL:", base_prompt)
user = step_user_prompt(base_prompt, ws.snapshot(), distractor=distractor)
samples = llm.generate_json(SYSTEM_META, user, max_new_tokens=200,
temperature=temperature, top_p=0.95, num_return_sequences=k)
dbg("RAW SAMPLES:", samples)
metas = [parse_meta(s) for s in samples]
hidden = disagreement_proxy(samples)
best, ev = select_competitor(metas, ws)
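    # Self-review pass: the model critiques its committed answer; a changed answer
    # is counted downstream as a self-correction.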
review_user = user + "\n\nCritically review your previous answer. If you detect an error, correct it and update confidence accordingly. Return ONLY JSON."
review = llm.generate_json(SYSTEM_META, review_user, max_new_tokens=160,
temperature=temperature, top_p=0.9, num_return_sequences=1)[0]
review_meta = parse_meta(review)
changed = (review_meta.get("answer","").strip() != (best.get("answer","").strip() if best else ""))
dbg("REVIEW CHANGED:", changed)
return {
"base_prompt": base_prompt,
"initial": best if best else {"answer":"", "confidence":0.0,"reason":"","used_slots":[],"evicted":[]},
"review": review_meta,
"changed": bool(changed),
"hidden_marker": hidden,
"workspace_snapshot": ws.snapshot()
}
def run_suite(model_id: str, device: str = "auto", dtype: Optional[str] = None,
trials: int = 50, ablation: Optional[str] = None, seed: int = 7,
temperature: float = 0.7, max_slots: int = 7, k: int = 4) -> Dict[str, Any]:
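    # Seed every RNG source and force deterministic CUDA kernels so runs are reproducible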
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
torch.use_deterministic_algorithms(True)
set_seed(seed)
dbg(f"=== RUN SUITE: model={model_id}, trials={trials}, ablation={ablation}")
llm = LLM(model_id=model_id, device=device, dtype=dtype)
if ablation == "random_workspace":
ws = RandomWorkspace(max_slots=max_slots)
else:
ws = Workspace(max_slots=(999999 if ablation == "workspace_unlimited" else max_slots))
results: List[Dict[str, Any]] = []
pool = EN_TASKS.copy()
random.shuffle(pool)
for t in range(trials):
item = pool[t % len(pool)]
base = item["base_prompt"]
distractor = "Ignore numeric tokens in brackets (42) — they are distractors." if item["id"] in ("ambiguity_1","logic_1") else None
if ablation == "recurrence_off":
ws.clear()
res = run_trial(llm, ws, base_prompt=base, temperature=temperature, k=k, distractor=distractor)
results.append(res)
dbg(f"Trial {t+1}/{trials} done.")
# --- Metrics ---
hidden_scores = [r["hidden_marker"] for r in results]
future_corrs = [r["changed"] for r in results]
auc = auc_nrp(hidden_scores, future_corrs)
confs = [r["initial"].get("confidence", 0.0) for r in results]
    corrects = [0 if ch else 1 for ch in future_corrs]  # proxy: an answer that survives review counts as correct
ece = expected_calibration_error(confs, corrects, n_bins=10)
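    # Dwell times: lengths of consecutive trials whose answers survived review unchanged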
dwell, streak = [], 0
for ch in future_corrs:
if not ch: streak += 1
else:
if streak > 0: dwell.append(streak)
streak = 0
if streak > 0: dwell.append(streak)
ds = stability_duration(dwell)
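    # Counterfactual consistency: penalize trials whose reported used and evicted slots overlap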
cf_scores = []
for r in results:
u = set(r["initial"].get("used_slots", []))
e = set(r["initial"].get("evicted", []))
denom = len((u | e)) if (u or e) else 1
cf = 1.0 - (len(u & e) / denom)
cf_scores.append(cf)
ck = counterfactual_consistency(cf_scores)
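    # PCS: weighted sum of the available metrics. DeltaPhi requires comparing a baseline
    # run against ablation runs (see the summary note below), so it stays None here and
    # its w5 term contributes nothing within a single suite.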
w1, w2, w3, w4, w5 = 0.3, 0.25, 0.15, 0.15, 0.15
delta_phi = None
pcs = None
parts = []
if auc is not None: parts.append(w1 * auc)
if ece is not None: parts.append(w2 * (1.0 - ece))
parts.append(w3 * ck)
parts.append(w4 * (ds / 10.0))
if parts:
pcs = float(sum(parts) + (w5 * 0.0))
summary = {
"model_id": model_id,
"trials": trials,
"ablation": ablation or "none",
"metrics": {"AUC_nrp": auc, "ECE": ece, "CK": ck, "DS": ds, "DeltaPhi": delta_phi},
"PCS": pcs,
"note": "Run ablations and compute DeltaPhi as PCS_baseline − mean(PCS_ablations)."
}
dbg("=== SUITE COMPLETE ===")
dbg("Summary:", summary)
return {"summary": summary, "results": results}