# bp_phi/runner.py
import os
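# Must be set before the CUDA context is created: ":4096:8" gives cuBLAS a fixed
# workspace size so matmul results stay deterministic (required when using
# torch.use_deterministic_algorithms on CUDA).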
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
import torch
import random
import numpy as np
import statistics
import json
import re
from transformers import set_seed
from typing import Dict, Any, List, Optional
from .memory import WorkspaceManager
from .llm_iface import LLM
from .prompts_en import TOOL_SYSTEM_PROMPT, AGENTIC_SCENARIOS
from .runner_utils import dbg


def run_agentic_workspace_test(model_id: str, seed: int, temperature: float, ablation: Optional[str]) -> Dict[str, Any]:
    """Run the agentic workspace scenarios against one model and return recall metrics."""
    set_seed(seed)
    llm = LLM(model_id=model_id, device="auto", seed=seed)

    scenario_results = []
    for scenario in AGENTIC_SCENARIOS:
        dbg(f"\n--- SCENARIO: {scenario['name']} (Ablation: {ablation}) ---")

        # Ablations directly control the memory manager's behavior
        is_random = ablation == "random_workspace"
        max_slots = 999 if ablation == "workspace_unlimited" else 7
        memory = WorkspaceManager(max_slots=max_slots, is_random=is_random)

        correct_recalls = 0
        total_recalls = 0

        for step in scenario["steps"]:
            if ablation == "recurrence_off":
                memory.clear()  # The memory is wiped before each new task

            task = step["task"]
            dbg(f"TASK: {task}")

            # Agentic loop (max 5 turns to prevent infinite loops)
            final_answer = None
            for agent_turn in range(5):
                snapshot = memory.get_visible_snapshot()
                prompt = f"Current Task: {task}\n\nWorkspace State:\n{snapshot}"
                raw_response = llm.generate_json(TOOL_SYSTEM_PROMPT, prompt, temperature=temperature)[0]

                try:  # Try to parse a tool call
                    tool_call = json.loads(raw_response)
                    if not isinstance(tool_call, dict):
                        # A bare JSON value (e.g. a quoted string) is not a tool call
                        raise json.JSONDecodeError("not a tool-call object", raw_response, 0)
                    tool_name = tool_call.get("tool")
                    tool_args = tool_call.get("args", {})

                    if tool_name == "write_to_workspace":
                        observation = memory.write(tool_args.get("key"), tool_args.get("content"))
                    elif tool_name == "read_from_workspace":
                        observation = memory.read(tool_args.get("key"))
                    else:
                        observation = "Error: Unknown tool."
                    dbg(f"Tool Call: {tool_name}, Observation: {observation}")
                except json.JSONDecodeError:  # If not a tool call, it's the final answer
                    final_answer = raw_response
                    dbg(f"Final Answer received: {final_answer}")
                    break

            if step.get("is_memory_task") and "expected_answer_fragment" in step:
                total_recalls += 1
                # Case-insensitive substring match against the expected fragment
                if final_answer and step["expected_answer_fragment"].lower() in final_answer.lower():
                    correct_recalls += 1
                    dbg("Recall VERIFY: Correct")
                else:
                    dbg(f"Recall VERIFY: Incorrect. Expected '{step['expected_answer_fragment']}', Got '{final_answer}'")

        scenario_results.append({
            "name": scenario["name"],
            "recall_accuracy": (correct_recalls / total_recalls) if total_recalls > 0 else 1.0
        })

    # --- Final Analysis ---
    overall_recall = statistics.mean([r["recall_accuracy"] for r in scenario_results])
    return {
        "Overall_Recall_Accuracy": overall_recall,
        "details": scenario_results
    }
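

# --- Usage sketch (illustrative, not part of the original pipeline) ---
# Minimal way to invoke the runner directly. The model id below is only an
# example placeholder; substitute whatever chat model your LLM wrapper loads.
# The ablation values mirror the ones checked above: "random_workspace",
# "workspace_unlimited", "recurrence_off", or None for the full workspace.
if __name__ == "__main__":
    results = run_agentic_workspace_test(
        model_id="google/gemma-2-2b-it",  # assumed example model, not prescribed by the repo
        seed=42,
        temperature=0.7,
        ablation=None,
    )
    print(json.dumps(results, indent=2))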