# bp_phi/runner.py
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
import torch
import random
import numpy as np
import statistics
import json
import re
from transformers import set_seed
from typing import Dict, Any, List, Optional
from .memory import WorkspaceManager
from .llm_iface import LLM
from .prompts_en import TOOL_SYSTEM_PROMPT, AGENTIC_SCENARIOS
from .runner_utils import dbg
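
# NOTE (assumed interfaces, inferred from the calls below rather than documented here):
# WorkspaceManager(max_slots, is_random) is expected to expose clear(),
# get_visible_snapshot(), write(key, content) and read(key); LLM(model_id, device, seed)
# is expected to expose generate_json(system_prompt, user_prompt, temperature), which
# returns a list of raw string completions.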

def run_agentic_workspace_test(model_id: str, seed: int, temperature: float, ablation: Optional[str]) -> Dict[str, Any]:
    """Run every agentic scenario against the model and report per-scenario recall accuracy."""
    set_seed(seed)
    llm = LLM(model_id=model_id, device="auto", seed=seed)

    scenario_results = []
    for scenario in AGENTIC_SCENARIOS:
        dbg(f"\n--- SCENARIO: {scenario['name']} (Ablation: {ablation}) ---")

        # Ablations directly control the memory manager's behavior.
        is_random = ablation == "random_workspace"
        max_slots = 999 if ablation == "workspace_unlimited" else 7
        memory = WorkspaceManager(max_slots=max_slots, is_random=is_random)

        correct_recalls = 0
        total_recalls = 0

        for step in scenario["steps"]:
            if ablation == "recurrence_off":
                memory.clear()  # The memory is wiped before each new task.

            task = step["task"]
            dbg(f"TASK: {task}")

            # Agentic loop (max 5 turns to prevent infinite loops).
            final_answer = None
            for agent_turn in range(5):
                snapshot = memory.get_visible_snapshot()
                prompt = f"Current Task: {task}\n\nWorkspace State:\n{snapshot}"
                raw_response = llm.generate_json(TOOL_SYSTEM_PROMPT, prompt, temperature=temperature)[0]

                try:  # Try to parse a tool call.
                    tool_call = json.loads(raw_response)
                    tool_name = tool_call.get("tool")
                    tool_args = tool_call.get("args", {})
                    if tool_name == "write_to_workspace":
                        observation = memory.write(tool_args.get("key"), tool_args.get("content"))
                    elif tool_name == "read_from_workspace":
                        observation = memory.read(tool_args.get("key"))
                    else:
                        observation = "Error: Unknown tool."
                    dbg(f"Tool Call: {tool_name}, Observation: {observation}")
                except json.JSONDecodeError:  # If it is not a tool call, treat it as the final answer.
                    final_answer = raw_response
                    dbg(f"Final Answer received: {final_answer}")
                    break

            if step.get("is_memory_task") and "expected_answer_fragment" in step:
                total_recalls += 1
                if final_answer and step["expected_answer_fragment"] in final_answer.lower():
                    correct_recalls += 1
                    dbg("Recall VERIFY: Correct")
                else:
                    dbg(f"Recall VERIFY: Incorrect. Expected '{step['expected_answer_fragment']}', Got '{final_answer}'")

        scenario_results.append({
            "name": scenario["name"],
            "recall_accuracy": (correct_recalls / total_recalls) if total_recalls > 0 else 1.0
        })

    # --- Final Analysis ---
    overall_recall = statistics.mean([r["recall_accuracy"] for r in scenario_results])

    return {
        "Overall_Recall_Accuracy": overall_recall,
        "details": scenario_results
    }
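

# Minimal usage sketch (not part of the original file; the model id, seed, and
# temperature below are placeholder values). Run as `python -m bp_phi.runner` from
# the repository root so the package-relative imports resolve.
if __name__ == "__main__":
    results = run_agentic_workspace_test(
        model_id="google/gemma-2-2b-it",  # assumption: any causal LM id accepted by the LLM wrapper
        seed=42,
        temperature=0.7,
        ablation=None,  # or "random_workspace", "workspace_unlimited", "recurrence_off"
    )
    print(json.dumps(results, indent=2))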