import traceback

import torch

from .llm_iface import get_or_load_model
from .utils import dbg

def run_diagnostic_suite(model_id: str, seed: int) -> str:
    """
    Runs a series of self-tests to verify the mechanical integrity of the experiment.
    Raises an exception on any critical failure so execution stops immediately.
    """
| dbg("--- STARTING DIAGNOSTIC SUITE ---") | |
| results = [] | |
| try: | |
| # --- Setup --- | |
| dbg("Loading model for diagnostics...") | |
| llm = get_or_load_model(model_id, seed) | |
| test_prompt = "Hello world" | |
| inputs = llm.tokenizer(test_prompt, return_tensors="pt").to(llm.model.device) | |

        # --- Test 1: Attention Output Verification ---
        dbg("Running Test 1: Attention Output Verification...")
        # This test ensures that the 'eager' attention implementation is active, which is
        # necessary for reliable hook functionality in many transformers versions.
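        # (Eager attention is normally selected at load time. Assuming get_or_load_model
        # forwards kwargs to transformers, it would be enabled with something like:
        #   AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="eager"))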
        outputs = llm.model(**inputs, output_attentions=True)
        assert outputs.attentions is not None, \
            "FAIL: `outputs.attentions` is None. 'eager' implementation is likely not active."
        assert isinstance(outputs.attentions, tuple), \
            "FAIL: `outputs.attentions` is not a tuple."
        assert len(outputs.attentions) == llm.config.num_hidden_layers, \
            "FAIL: Number of attention tensors does not match number of layers."
        results.append("✅ Test 1: Attention Output PASSED")
        dbg("Test 1 PASSED.")

        # --- Test 2: Hook Causal Efficacy ---
        dbg("Running Test 2: Hook Causal Efficacy Verification...")
        # This is the most critical test. It verifies that our injection mechanism
        # (via hooks) has a real, causal effect on the model's computation.
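        # (A forward pre-hook receives (module, args) and may return a replacement
        # args tuple; returning None would leave the layer's inputs unchanged.)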

        # Run 1: get the baseline hidden state without any intervention.
        outputs_no_hook = llm.model(**inputs, output_hidden_states=True)
        target_layer_idx = llm.config.num_hidden_layers // 2
        state_no_hook = outputs_no_hook.hidden_states[target_layer_idx + 1].clone()

        # Define a simple hook that adds a large, constant value to the layer input.
        injection_value = 42.0

        def test_hook_fn(module, layer_input):
            modified_input = layer_input[0] + injection_value
            return (modified_input,) + layer_input[1:]

        target_layer = llm.model.model.layers[target_layer_idx]
        handle = target_layer.register_forward_pre_hook(test_hook_fn)

        # Run 2: get the hidden state with the hook active.
        outputs_with_hook = llm.model(**inputs, output_hidden_states=True)
        state_with_hook = outputs_with_hook.hidden_states[target_layer_idx + 1].clone()
        handle.remove()  # Clean up the hook immediately.

        # The core assertion: the hook MUST change the subsequent hidden state.
        assert not torch.allclose(state_no_hook, state_with_hook), \
            "FAIL: Hook had no measurable effect on the subsequent layer's hidden state. Injections are not working."
        results.append("✅ Test 2: Hook Causal Efficacy PASSED")
        dbg("Test 2 PASSED.")

        # --- Test 3: KV-Cache Integrity ---
        dbg("Running Test 3: KV-Cache Integrity Verification...")
        # This test ensures that `past_key_values` are passed and updated correctly,
        # which is the core mechanic of the silent cogitation loop.

        # Step 1: initial pass with `use_cache=True`.
        outputs1 = llm.model(**inputs, use_cache=True)
        kv_cache1 = outputs1.past_key_values
        assert kv_cache1 is not None, "FAIL: KV-Cache was not generated in the first pass."

        # Step 2: second pass using the cache from step 1.
        next_token = torch.tensor([[123]], device=llm.model.device)  # Arbitrary next token ID.
        outputs2 = llm.model(input_ids=next_token, past_key_values=kv_cache1, use_cache=True)
        kv_cache2 = outputs2.past_key_values
        original_seq_len = inputs.input_ids.shape[-1]

        # The sequence length of the keys/values in the cache should have grown by 1.
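        # (Cached keys/values are laid out as [batch, num_heads, seq_len, head_dim],
        # so shape[-2] is the sequence length. Note: recent transformers versions
        # return a Cache object here; tuple-style indexing kv_cache[layer][0] is
        # supported through its legacy interface in many versions.)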
        assert kv_cache2[0][0].shape[-2] == original_seq_len + 1, \
            f"FAIL: KV-Cache sequence length did not update correctly. Expected {original_seq_len + 1}, got {kv_cache2[0][0].shape[-2]}."
        results.append("✅ Test 3: KV-Cache Integrity PASSED")
        dbg("Test 3 PASSED.")

        # Clean up memory.
        del llm
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return "\n".join(results)

    except Exception:
        dbg(f"--- DIAGNOSTIC SUITE FAILED ---\n{traceback.format_exc()}")
        # Re-raise the exception so it can be caught by the Gradio UI.
        raise
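
# Minimal usage sketch (illustrative only; assumes this file is importable as part
# of its package under a placeholder name such as `diagnostics.py`, and that "gpt2"
# stands in for any supported model id):
#
#   from .diagnostics import run_diagnostic_suite
#   report = run_diagnostic_suite("gpt2", seed=42)
#   print(report)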