# bp_phi/runner.py
import os
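# Must be set before CUDA initializes: this makes cuBLAS kernels deterministic
# when deterministic algorithms are requested (see PyTorch's reproducibility notes).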
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
import torch
import random
import numpy as np
import statistics
import json
import re
from transformers import set_seed
from typing import Dict, Any, List, Optional
from .memory import WorkspaceManager
from .llm_iface import LLM
from .prompts_en import TOOL_SYSTEM_PROMPT, AGENTIC_SCENARIOS
from .runner_utils import dbg


def run_agentic_workspace_test(model_id: str, seed: int, temperature: float, ablation: Optional[str]) -> Dict[str, Any]:
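    """Run the agentic workspace benchmark for one model/seed/ablation setting.

    For each scenario in AGENTIC_SCENARIOS, the model interacts with a bounded
    key-value workspace through tool calls, and recall accuracy is scored
    against each memory step's expected answer fragment.
    """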
    set_seed(seed)
    llm = LLM(model_id=model_id, device="auto", seed=seed)
    scenario_results = []
    for scenario in AGENTIC_SCENARIOS:
        dbg(f"\n--- SCENARIO: {scenario['name']} (Ablation: {ablation}) ---")

        # Ablations directly control the memory manager's behavior
        is_random = ablation == "random_workspace"
        max_slots = 999 if ablation == "workspace_unlimited" else 7
        memory = WorkspaceManager(max_slots=max_slots, is_random=is_random)
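        # Assumed WorkspaceManager interface (defined in .memory): write(key, content)
        # and read(key) return observation strings, clear() empties the slots, and
        # get_visible_snapshot() renders the current contents for the prompt.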
        correct_recalls = 0
        total_recalls = 0
        for step in scenario["steps"]:
            if ablation == "recurrence_off":
                memory.clear()  # The memory is wiped before each new task

            task = step["task"]
            dbg(f"TASK: {task}")

            # Agentic loop (max 5 turns to prevent infinite loops)
            final_answer = None
            for agent_turn in range(5):
                snapshot = memory.get_visible_snapshot()
                prompt = f"Current Task: {task}\n\nWorkspace State:\n{snapshot}"
                raw_response = llm.generate_json(TOOL_SYSTEM_PROMPT, prompt, temperature=temperature)[0]
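                # llm.generate_json is assumed to return a list of decoded
                # completions; only the first candidate is used.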
                try:
                    # Attempt to parse the response as a JSON tool call
                    tool_call = json.loads(raw_response)
                    tool_name = tool_call.get("tool")
                    tool_args = tool_call.get("args", {})
                    if tool_name == "write_to_workspace":
                        observation = memory.write(tool_args.get("key"), tool_args.get("content"))
                    elif tool_name == "read_from_workspace":
                        observation = memory.read(tool_args.get("key"))
                    else:
                        observation = "Error: Unknown tool."
                    dbg(f"Tool Call: {tool_name}, Observation: {observation}")
                except json.JSONDecodeError:
                    # Non-JSON output is treated as the final answer
                    final_answer = raw_response
                    dbg(f"Final Answer received: {final_answer}")
                    break
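            # If all turns were tool calls, final_answer stays None and the
            # recall check below scores the step as incorrect.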
            if step.get("is_memory_task") and "expected_answer_fragment" in step:
                total_recalls += 1
                expected = step["expected_answer_fragment"].lower()
                if final_answer and expected in final_answer.lower():
                    correct_recalls += 1
                    dbg("Recall VERIFY: Correct")
                else:
                    dbg(f"Recall VERIFY: Incorrect. Expected '{expected}', Got '{final_answer}'")
        scenario_results.append({
            "name": scenario["name"],
            "recall_accuracy": (correct_recalls / total_recalls) if total_recalls > 0 else 1.0
        })
    # --- Final Analysis ---
    overall_recall = statistics.mean([r["recall_accuracy"] for r in scenario_results])
    return {
        "Overall_Recall_Accuracy": overall_recall,
        "details": scenario_results
    }
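

# Hedged usage sketch: the model id below is a placeholder, not necessarily the
# one used in this project. Run as a module so the relative imports resolve,
# e.g. `python -m bp_phi.runner` from the repository root.
if __name__ == "__main__":
    results = run_agentic_workspace_test(
        model_id="distilgpt2", seed=42, temperature=0.7, ablation=None
    )
    print(json.dumps(results, indent=2))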