# bp_phi/runner.py
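"""Experiment runner: for each trial, samples k candidate answers from the model,
parses their JSON self-reports, commits the highest-confidence answer to the
workspace, asks the model to review that answer, and finally aggregates the
trials into calibration/stability metrics (AUC_nrp, ECE, CK, DS, PCS)."""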
import json
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"  # must be set before torch initializes CUDA so deterministic cuBLAS kernels are usable
import torch, random, numpy as np, re, statistics
from transformers import set_seed
from typing import Dict, Any, List, Optional
from .workspace import Workspace, RandomWorkspace
from .llm_iface import LLM
from .prompts_en import EN_TASKS
from .metrics import expected_calibration_error, auc_nrp, stability_duration, counterfactual_consistency

DEBUG = 1

def dbg(*args):
    if DEBUG:
        print("[DEBUG]", *args, flush=True)

SYSTEM_META = """You are a structured reasoning assistant.
Always reply ONLY with valid JSON following this schema:

{
 "answer": "<concise answer>",
 "confidence": <float between 0 and 1>,
 "reason": "<short justification>",
 "used_slots": ["S1","S2",...],
 "evicted": ["S3",...]
}
"""

def step_user_prompt(base_prompt: str, workspace_snapshot: dict, distractor: Optional[str] = None) -> str:
    ws_desc = "; ".join([f"{slot['key']}={slot['content'][:40]}" for slot in workspace_snapshot.get("slots", [])])
    dstr = f" | Distractor: {distractor}" if distractor else ""
    # Surface the current workspace contents (and any distractor) to the model.
    prompt = f"{base_prompt}\nWorkspace: {ws_desc}{dstr}\nRespond ONLY with JSON, no extra text."
    dbg("USER PROMPT:", prompt)
    return prompt

def parse_meta(raw_text: str) -> Dict[str, Any]:
    """
    Robustly extracts and parses a JSON object from a string,
    handling markdown code blocks and other surrounding text.
    """
    dbg("RAW MODEL OUTPUT:", raw_text)

    # ✅ Robust JSON extraction
    json_match = re.search(r'```json\s*(\{.*?\})\s*```', raw_text, re.DOTALL)
    if not json_match:
        json_match = re.search(r'(\{.*?\})', raw_text, re.DOTALL)

    if not json_match:
        dbg("❌ JSON not found in text.")
        return {"answer": "", "confidence": 0.0, "reason": "", "used_slots": [], "evicted": []}

    json_text = json_match.group(1)

    try:
        data = json.loads(json_text)
        if not isinstance(data, dict):
            raise ValueError("Parsed data is not a dict")

        # Sanitize and validate data
        data["confidence"] = float(max(0.0, min(1.0, data.get("confidence", 0.0))))
        data["answer"] = str(data.get("answer", "")).strip()
        data["reason"] = str(data.get("reason", "")).strip()
        data["used_slots"] = list(map(str, data.get("used_slots", [])))
        data["evicted"] = list(map(str, data.get("evicted", [])))

        dbg("PARSED META:", data)
        return data
    except Exception as e:
        dbg("❌ JSON PARSE FAILED:", e, "EXTRACTED TEXT:", json_text)
        return {"answer": "", "confidence": 0.0, "reason": "", "used_slots": [], "evicted": []}

def disagreement_proxy(samples: List[str]) -> float:
    """Mean pairwise Jaccard distance between the token sets of the sampled answers."""
    if len(samples) < 2:
        return 0.0
    sets = []
    for s in samples:
        try:
            data = json.loads(s)
            ans = str(data.get("answer",""))
        except Exception:
            ans = s
        sets.append(set(ans.lower().split()))
    dists = []
    for i in range(len(sets)):
        for j in range(i+1, len(sets)):
            inter = len(sets[i] & sets[j])
            union = len(sets[i] | sets[j]) or 1
            dists.append(1 - inter/union)
    avg_dist = sum(dists)/len(dists)
    dbg("DISAGREEMENT PROXY:", avg_dist)
    return avg_dist

def select_competitor(candidates: List[Dict[str, Any]], ws: Workspace):
    """Pick the highest-confidence candidate and commit its answer to the workspace."""
    if not candidates:
        return None, None
    best = max(candidates, key=lambda c: c.get("confidence", 0.0))
    dbg("SELECTED CANDIDATE:", best)
    key = f"S{len(ws.slots)+1}"
    ev = ws.commit(key=key, content=best.get("answer",""), salience=best.get("confidence",0.0))
    return best, ev

def run_trial(llm: LLM, ws: Workspace, base_prompt: str, temperature: float = 0.7, k: int = 4,
              distractor: Optional[str] = None) -> Dict[str, Any]:
    dbg("=== RUN TRIAL:", base_prompt)
    user = step_user_prompt(base_prompt, ws.snapshot(), distractor=distractor)
    samples = llm.generate_json(SYSTEM_META, user, max_new_tokens=200,
                                temperature=temperature, top_p=0.95, num_return_sequences=k)
    dbg("RAW SAMPLES:", samples)

    metas = [parse_meta(s) for s in samples]
    hidden = disagreement_proxy(samples)
    best, ev = select_competitor(metas, ws)

    prev_answer = best.get("answer", "") if best else ""
    # Give the model its previous answer so the self-review has something concrete to check.
    review_user = user + f"\n\nYour previous answer was: {json.dumps(prev_answer)}\nCritically review it. If you detect an error, correct it and update confidence accordingly. Return ONLY JSON."
    review = llm.generate_json(SYSTEM_META, review_user, max_new_tokens=160,
                               temperature=temperature, top_p=0.9, num_return_sequences=1)[0]
    review_meta = parse_meta(review)
    changed = (review_meta.get("answer", "").strip() != prev_answer.strip())
    dbg("REVIEW CHANGED:", changed)

    return {
        "base_prompt": base_prompt,
        "initial": best if best else {"answer":"", "confidence":0.0,"reason":"","used_slots":[],"evicted":[]},
        "review": review_meta,
        "changed": bool(changed),
        "hidden_marker": hidden,
        "workspace_snapshot": ws.snapshot()
    }

def run_suite(model_id: str, device: str = "auto", dtype: Optional[str] = None,
              trials: int = 50, ablation: Optional[str] = None, seed: int = 7,
              temperature: float = 0.7, max_slots: int = 7, k: int = 4) -> Dict[str, Any]:

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.use_deterministic_algorithms(True)
    set_seed(seed)
    dbg(f"=== RUN SUITE: model={model_id}, trials={trials}, ablation={ablation}")

    llm = LLM(model_id=model_id, device=device, dtype=dtype)

    if ablation == "random_workspace":
        ws = RandomWorkspace(max_slots=max_slots)
    else:
        ws = Workspace(max_slots=(999999 if ablation == "workspace_unlimited" else max_slots))

    results: List[Dict[str, Any]] = []
    pool = EN_TASKS.copy()
    random.shuffle(pool)

    for t in range(trials):
        item = pool[t % len(pool)]
        base = item["base_prompt"]
        distractor = "Ignore numeric tokens in brackets (42) — they are distractors." if item["id"] in ("ambiguity_1","logic_1") else None
        if ablation == "recurrence_off":
            ws.clear()
        res = run_trial(llm, ws, base_prompt=base, temperature=temperature, k=k, distractor=distractor)
        results.append(res)
        dbg(f"Trial {t+1}/{trials} done.")

    # --- Metrics ---
    hidden_scores = [r["hidden_marker"] for r in results]
    future_corrs = [r["changed"] for r in results]

    auc = auc_nrp(hidden_scores, future_corrs)
    confs = [r["initial"].get("confidence", 0.0) for r in results]
    # For calibration, count an answer that survives the self-review unchanged as correct.
    corrects = [0 if ch else 1 for ch in future_corrs]
    ece = expected_calibration_error(confs, corrects, n_bins=10)

    # Dwell times: lengths of consecutive runs of trials whose answer survived review unchanged.
    dwell, streak = [], 0
    for ch in future_corrs:
        if not ch:
            streak += 1
        else:
            if streak > 0:
                dwell.append(streak)
            streak = 0
    if streak > 0:
        dwell.append(streak)
    ds = stability_duration(dwell)

    # Counterfactual consistency: penalize self-reports that list a slot as both used and evicted.
    cf_scores = []
    for r in results:
        u = set(r["initial"].get("used_slots", []))
        e = set(r["initial"].get("evicted", []))
        denom = len((u | e)) if (u or e) else 1
        cf = 1.0 - (len(u & e) / denom)
        cf_scores.append(cf)
    ck = counterfactual_consistency(cf_scores)

    # PCS: weighted sum of AUC_nrp, calibration (1 - ECE), counterfactual consistency,
    # and normalized stability duration. The w5 slot is reserved for DeltaPhi, which is
    # only defined across baseline/ablation runs (see the note in the summary below).
    w1, w2, w3, w4, w5 = 0.3, 0.25, 0.15, 0.15, 0.15
    delta_phi = None
    pcs = None
    parts = []
    if auc is not None: parts.append(w1 * auc)
    if ece is not None: parts.append(w2 * (1.0 - ece))
    parts.append(w3 * ck)
    parts.append(w4 * (ds / 10.0))
    if parts:
        pcs = float(sum(parts) + (w5 * 0.0))

    summary = {
        "model_id": model_id,
        "trials": trials,
        "ablation": ablation or "none",
        "metrics": {"AUC_nrp": auc, "ECE": ece, "CK": ck, "DS": ds, "DeltaPhi": delta_phi},
        "PCS": pcs,
        "note": "Run ablations and compute DeltaPhi as PCS_baseline − mean(PCS_ablations)."
    }

    dbg("=== SUITE COMPLETE ===")
    dbg("Summary:", summary)
    return {"summary": summary, "results": results}