# Scraped file-viewer header (not code), kept as a comment so the module parses:
# author: neuralworm | commit: "initial commit" (c8fa89c) | size: 4.67 kB
import traceback

import torch

from .llm_iface import get_or_load_model
from .utils import dbg
def _check(condition: bool, message: str) -> None:
    """Raise ``AssertionError(message)`` when *condition* is false.

    Used instead of bare ``assert`` because ``assert`` statements are removed
    when Python runs with ``-O``. That would make every diagnostic silently
    report PASSED.
    """
    if not condition:
        raise AssertionError(message)


def _kv_seq_len(past_key_values) -> int:
    """Return the cached sequence length of layer 0 of *past_key_values*.

    Two formats are handled:

    - Cache objects that expose ``get_seq_length`` (newer transformers
      ``Cache`` classes) are queried through that method.
    - Otherwise the legacy tuple-of-``(key, value)`` layout is assumed, with
      the sequence axis being the second-to-last dimension of the key tensor.
    """
    get_seq_length = getattr(past_key_values, "get_seq_length", None)
    if callable(get_seq_length):
        return int(get_seq_length())
    return past_key_values[0][0].shape[-2]


def _run_attention_output_test(llm, inputs) -> str:
    """Test 1: the model returns a tuple with one attention entry per layer.

    ``outputs.attentions`` being ``None`` indicates that the 'eager' attention
    implementation is not active. Eager attention is needed for reliable hook
    behaviour in many transformers versions.
    """
    dbg("Running Test 1: Attention Output Verification...")
    outputs = llm.model(**inputs, output_attentions=True)
    _check(outputs.attentions is not None,
           "FAIL: `outputs.attentions` is None. 'eager' implementation is likely not active.")
    _check(isinstance(outputs.attentions, tuple),
           "FAIL: `outputs.attentions` is not a tuple.")
    _check(len(outputs.attentions) == llm.config.num_hidden_layers,
           "FAIL: Number of attention tuples does not match number of layers.")
    dbg("Test 1 PASSED.")
    return "✅ Test 1: Attention Output PASSED"


def _run_hook_causal_efficacy_test(llm, inputs) -> str:
    """Test 2: a forward pre-hook on a middle layer changes that layer's output.

    This verifies that the hook-based injection mechanism has a real, causal
    effect on the model's computation. The hook is always removed, even if the
    hooked forward pass raises.
    """
    dbg("Running Test 2: Hook Causal Efficacy Verification...")
    target_layer_idx = llm.config.num_hidden_layers // 2

    # Baseline run without any intervention. In the transformers convention,
    # hidden_states[0] is the embedding output, so index i + 1 follows layer i.
    outputs_no_hook = llm.model(**inputs, output_hidden_states=True)
    state_no_hook = outputs_no_hook.hidden_states[target_layer_idx + 1].clone()

    injection_value = 42.0

    def test_hook_fn(module, layer_input):
        # Add a large constant to the first positional input; keep the rest.
        # NOTE(review): assumes the decoder layer receives hidden_states
        # positionally. If it is passed by keyword, layer_input is empty.
        modified_input = layer_input[0] + injection_value
        return (modified_input,) + layer_input[1:]

    target_layer = llm.model.model.layers[target_layer_idx]
    handle = target_layer.register_forward_pre_hook(test_hook_fn)
    try:
        outputs_with_hook = llm.model(**inputs, output_hidden_states=True)
    finally:
        # get_or_load_model may hand back a shared/cached model. A leftover
        # hook would silently alter every later forward pass.
        handle.remove()
    state_with_hook = outputs_with_hook.hidden_states[target_layer_idx + 1].clone()

    # Core assertion: the hook MUST change the subsequent hidden state.
    _check(not torch.allclose(state_no_hook, state_with_hook),
           "FAIL: Hook had no measurable effect on the subsequent layer's hidden state. Injections are not working.")
    dbg("Test 2 PASSED.")
    return "✅ Test 2: Hook Causal Efficacy PASSED"


def _run_kv_cache_integrity_test(llm, inputs) -> str:
    """Test 3: feeding one extra token with the cache grows it by exactly one.

    This exercises the ``past_key_values`` hand-off that the silent cogitation
    loop relies on.
    """
    dbg("Running Test 3: KV-Cache Integrity Verification...")
    outputs1 = llm.model(**inputs, use_cache=True)
    kv_cache1 = outputs1.past_key_values
    _check(kv_cache1 is not None, "FAIL: KV-Cache was not generated in the first pass.")

    # Read the prompt length from the inputs, not the cache: cache objects may
    # be updated in place by the second pass.
    original_seq_len = inputs.input_ids.shape[-1]
    next_token = torch.tensor([[123]], device=llm.model.device)  # Arbitrary next token ID
    outputs2 = llm.model(input_ids=next_token, past_key_values=kv_cache1, use_cache=True)
    actual_len = _kv_seq_len(outputs2.past_key_values)

    _check(actual_len == original_seq_len + 1,
           f"FAIL: KV-Cache sequence length did not update correctly. Expected {original_seq_len + 1}, got {actual_len}.")
    dbg("Test 3 PASSED.")
    return "✅ Test 3: KV-Cache Integrity PASSED"


def run_diagnostic_suite(model_id: str, seed: int) -> str:
    """Run self-tests that verify the mechanical integrity of the experiment.

    The suite checks, in order:

    1. attention outputs are returned;
    2. forward pre-hooks causally change hidden states;
    3. the KV cache grows correctly across passes.

    Args:
        model_id: Identifier passed to ``get_or_load_model``.
        seed: Seed passed to ``get_or_load_model``.

    Returns:
        One "✅ ... PASSED" line per test, joined with newlines.

    Raises:
        AssertionError: A diagnostic check failed.
        Exception: Any other error from loading or running the model. It is
            logged and re-raised unchanged so the caller (e.g. the Gradio UI)
            can stop execution.
    """
    dbg("--- STARTING DIAGNOSTIC SUITE ---")
    llm = None
    try:
        dbg("Loading model for diagnostics...")
        llm = get_or_load_model(model_id, seed)
        inputs = llm.tokenizer("Hello world", return_tensors="pt").to(llm.model.device)
        # Pure inference: skip building autograd graphs for the forward passes.
        with torch.no_grad():
            results = [
                _run_attention_output_test(llm, inputs),
                _run_hook_causal_efficacy_test(llm, inputs),
                _run_kv_cache_integrity_test(llm, inputs),
            ]
        return "\n".join(results)
    except Exception:
        dbg(f"--- DIAGNOSTIC SUITE FAILED --- \n{traceback.format_exc()}")
        # Bare raise keeps the original traceback intact.
        raise
    finally:
        # Release memory on both the success and the failure path.
        del llm
        if torch.cuda.is_available():
            torch.cuda.empty_cache()