# bp_phi/runner.py
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
import torch
import random
import numpy as np
import statistics
import json
import re
from transformers import set_seed
from typing import Dict, Any, List, Optional
from .memory import WorkspaceManager
from .llm_iface import LLM
from .prompts_en import TOOL_SYSTEM_PROMPT, AGENTIC_SCENARIOS
from .runner_utils import dbg

def run_agentic_workspace_test(model_id: str, seed: int, temperature: float, ablation: Optional[str]) -> Dict[str, Any]:
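    """Run the agentic workspace benchmark over every scenario in AGENTIC_SCENARIOS.

    For each scenario, the model is driven through a tool-use loop against an
    external WorkspaceManager memory. The requested ablation ("random_workspace",
    "workspace_unlimited", or "recurrence_off") changes how that memory behaves,
    and recall accuracy on memory tasks is reported per scenario and overall.
    """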
    set_seed(seed)
    llm = LLM(model_id=model_id, device="auto", seed=seed)

    scenario_results = []

    for scenario in AGENTIC_SCENARIOS:
        dbg(f"\n--- SCENARIO: {scenario['name']} (Ablation: {ablation}) ---")

        # Ablations directly control the memory manager's behavior
        is_random = ablation == "random_workspace"
        max_slots = 999 if ablation == "workspace_unlimited" else 7
        memory = WorkspaceManager(max_slots=max_slots, is_random=is_random)

        correct_recalls = 0
        total_recalls = 0

        for step in scenario["steps"]:
            if ablation == "recurrence_off":
                memory.clear() # The memory is wiped before each new task

            task = step["task"]
            dbg(f"TASK: {task}")

            # Agentic loop (max 5 turns to prevent infinite loops)
            final_answer = None
            for agent_turn in range(5):
                snapshot = memory.get_visible_snapshot()
                prompt = f"Current Task: {task}\n\nWorkspace State:\n{snapshot}"

                raw_response = llm.generate_json(TOOL_SYSTEM_PROMPT, prompt, temperature=temperature)[0]

                try: # Try to parse a tool call
                    tool_call = json.loads(raw_response)
                    tool_name = tool_call.get("tool")
                    tool_args = tool_call.get("args", {})

                    if tool_name == "write_to_workspace":
                        observation = memory.write(tool_args.get("key"), tool_args.get("content"))
                    elif tool_name == "read_from_workspace":
                        observation = memory.read(tool_args.get("key"))
                    else:
                        observation = "Error: Unknown tool."
                    dbg(f"Tool Call: {tool_name}, Observation: {observation}")

                except (json.JSONDecodeError, AttributeError): # Not a parsable tool call: treat it as the final answer
                    final_answer = raw_response
                    dbg(f"Final Answer received: {final_answer}")
                    break

            if step.get("is_memory_task") and "expected_answer_fragment" in step:
                total_recalls += 1
                if final_answer and step["expected_answer_fragment"].lower() in final_answer.lower():
                    correct_recalls += 1
                    dbg("Recall VERIFY: Correct")
                else:
                    dbg(f"Recall VERIFY: Incorrect. Expected '{step['expected_answer_fragment']}', Got '{final_answer}'")

        scenario_results.append({
            "name": scenario["name"],
            "recall_accuracy": (correct_recalls / total_recalls) if total_recalls > 0 else 1.0
        })

    # --- Final Analysis ---
    overall_recall = statistics.mean([r["recall_accuracy"] for r in scenario_results])

    return {
        "Overall_Recall_Accuracy": overall_recall,
        "details": scenario_results
    }
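
# --- Usage sketch (illustrative only) ---
# A minimal example of how the runner might be invoked directly. The model id,
# seed, and temperature below are placeholder assumptions, not project defaults;
# the ablation strings are the ones handled above.
if __name__ == "__main__":
    results = run_agentic_workspace_test(
        model_id="gpt2",  # placeholder model id for illustration
        seed=42,
        temperature=0.7,
        ablation=None,  # or "random_workspace", "workspace_unlimited", "recurrence_off"
    )
    print(json.dumps(results, indent=2))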