|
|
import torch |
|
|
import numpy as np |
|
|
from typing import Optional, List, Dict, Any, Tuple |
|
|
from tqdm import tqdm |
|
|
|
|
|
from .llm_iface import LLM |
|
|
from .prompts import RESONANCE_PROMPTS |
|
|
from .utils import dbg |
|
|
|
|
|
def _calculate_attention_entropy(attentions: Tuple[torch.Tensor, ...]) -> float: |
|
|
""" |
|
|
Berechnet die mittlere Entropie der Attention-Verteilungen. |
|
|
Ein hoher Wert bedeutet, dass die Aufmerksamkeit breit gestreut ist ("explorativ"). |
|
|
Ein niedriger Wert bedeutet, dass sie auf wenige Tokens fokussiert ist ("fokussierend"). |
|
|
""" |
|
|
total_entropy = 0.0 |
|
|
num_heads = 0 |
|
|
|
|
|
|
|
|
for layer_attention in attentions: |
|
|
|
|
|
|
|
|
|
|
|
attention_probs = layer_attention[:, :, -1, :] |
|
|
|
|
|
|
|
|
attention_probs = attention_probs + 1e-9 |
|
|
|
|
|
|
|
|
log_probs = torch.log2(attention_probs) |
|
|
entropy_per_head = -torch.sum(attention_probs * log_probs, dim=-1) |
|
|
|
|
|
total_entropy += torch.sum(entropy_per_head).item() |
|
|
num_heads += attention_probs.shape[1] |
|
|
|
|
|
return total_entropy / num_heads if num_heads > 0 else 0.0 |
|
|
|
|
|
def _sample_next_token(next_token_logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """
    Choose the next token id from ``next_token_logits``.

    With ``temperature > 0`` the token is drawn with ``torch.multinomial`` from the
    temperature-scaled softmax; otherwise the argmax is taken (greedy decoding).
    For 2-D logits of shape (batch, vocab) the result has shape (batch, 1), so it
    can be fed back directly as ``input_ids``.
    """
    # In the greedy case divide by 1.0 so the softmax stays well-defined; argmax is
    # unaffected by the scaling anyway.
    temp_to_use = temperature if temperature > 0.0 else 1.0
    probabilities = torch.nn.functional.softmax(next_token_logits / temp_to_use, dim=-1)
    if temperature > 0.0:
        return torch.multinomial(probabilities, num_samples=1)
    return torch.argmax(probabilities, dim=-1).unsqueeze(-1)


def _make_injection_hook(injection_vector: torch.Tensor, injection_strength: float) -> Any:
    """
    Build a forward pre-hook that adds ``injection_vector * injection_strength`` to
    the first positional input of the hooked module, passing other inputs through.

    NOTE(review): assumes the decoder layer receives its hidden states as the first
    positional argument, shaped (batch, seq_len, hidden), and that
    ``injection_vector`` has shape (hidden,) -- confirm for the model family in use.
    """
    def injection_hook(module: Any, layer_input: Any) -> Any:
        seq_len = layer_input[0].shape[1]
        injection_3d = injection_vector.unsqueeze(0).expand(1, seq_len, -1)
        modified_hidden_states = layer_input[0] + (injection_3d * injection_strength)
        return (modified_hidden_states,) + layer_input[1:]

    return injection_hook


@torch.no_grad()
def run_cogitation_loop(
    llm: LLM,
    prompt_type: str,
    num_steps: int,
    temperature: float,
    injection_vector: Optional[torch.Tensor] = None,
    injection_strength: float = 0.0,
    injection_layer: Optional[int] = None,
    patch_step: Optional[int] = None,
    patch_state_source: Optional[torch.Tensor] = None,
    reset_kv_cache_on_patch: bool = False,
    record_states: bool = False,
    record_attentions: bool = False,
) -> Dict[str, Any]:
    """
    Run a token-by-token generation loop, tracking how the last hidden state evolves.

    This generalized version also supports recording attention patterns and
    computing their entropy.

    The prompt is processed once (prefill). Then, for ``num_steps`` steps:
    1. The next token is chosen from ``lm_head`` applied to the current hidden state.
    2. That token is fed back through the model with the KV cache.
    3. The norm of the change in the last hidden state is recorded.

    Args:
        llm: Wrapper exposing ``tokenizer``, ``model`` (with ``device``, ``dtype``,
            ``lm_head``) and ``stable_config`` (with ``num_layers``, ``layer_list``).
        prompt_type: Key into ``RESONANCE_PROMPTS`` selecting the prompt.
        num_steps: Number of generation steps.
        temperature: ``> 0`` samples at that temperature; ``<= 0`` uses greedy argmax.
        injection_vector: If given and ``injection_strength > 0``, this vector
            (scaled by ``injection_strength``) is added to the input of decoder
            layer ``injection_layer`` during every step's forward pass. It is not
            applied during the prompt prefill.
        injection_strength: Scale of the injection; values ``<= 0`` disable it.
        injection_layer: Target layer index; defaults to ``num_layers // 2``.
        patch_step: Step index at which the current hidden state is replaced by
            ``patch_state_source``.
        patch_state_source: Replacement hidden state used at ``patch_step``.
        reset_kv_cache_on_patch: If True, the KV cache is discarded when patching.
        record_states: If True, the hidden state fed to ``lm_head`` at each step is
            stored (moved to CPU).
        record_attentions: If True, attentions are requested and their mean entropy
            is recorded for the prefill and for every step.

    Returns:
        Dict with ``state_deltas``, ``state_history``, ``attention_entropies``,
        ``final_hidden_state`` and ``final_kv_cache``.

    Raises:
        AssertionError: If injection is active and ``injection_layer`` is outside
            ``[0, num_layers)`` (checked at each step).
    """
    prompt = RESONANCE_PROMPTS[prompt_type]
    inputs = llm.tokenizer(prompt, return_tensors="pt").to(llm.model.device)

    # Prefill: process the full prompt once to get the initial state and KV cache.
    outputs = llm.model(**inputs, output_hidden_states=True, use_cache=True, output_attentions=record_attentions)
    hidden_state_2d = outputs.hidden_states[-1][:, -1, :]
    kv_cache = outputs.past_key_values

    state_deltas: List[float] = []
    state_history: List[torch.Tensor] = []
    attention_entropies: List[float] = []

    if record_attentions and outputs.attentions:
        attention_entropies.append(_calculate_attention_entropy(outputs.attentions))

    # Injection setup is loop-invariant: cast the vector, resolve the default layer
    # and build the hook once rather than on every step.
    injection_hook = None
    if injection_vector is not None and injection_strength > 0:
        injection_vector = injection_vector.to(device=llm.model.device, dtype=llm.model.dtype)
        if injection_layer is None:
            injection_layer = llm.stable_config.num_layers // 2
        injection_hook = _make_injection_hook(injection_vector, injection_strength)

    for i in tqdm(range(num_steps), desc=f"Cognitive Loop ({prompt_type})", leave=False, bar_format="{l_bar}{bar:10}{r_bar}"):
        if i == patch_step and patch_state_source is not None:
            dbg(f"--- Applying Causal Surgery at step {i}: Patching state. ---")
            hidden_state_2d = patch_state_source.clone().to(device=llm.model.device, dtype=llm.model.dtype)
            if reset_kv_cache_on_patch:
                dbg("--- KV-Cache has been RESET as part of the intervention. ---")
                kv_cache = None

        if record_states:
            state_history.append(hidden_state_2d.cpu())

        next_token_id = _sample_next_token(llm.model.lm_head(hidden_state_2d), temperature)

        # The hook is attached only for this step's forward pass. It is always
        # removed, even if the forward pass raises.
        hook_handle = None
        try:
            if injection_hook is not None:
                assert 0 <= injection_layer < llm.stable_config.num_layers, f"Injection layer {injection_layer} is out of bounds."
                target_layer = llm.stable_config.layer_list[injection_layer]
                hook_handle = target_layer.register_forward_pre_hook(injection_hook)

            outputs = llm.model(
                input_ids=next_token_id, past_key_values=kv_cache,
                output_hidden_states=True, use_cache=True,
                output_attentions=record_attentions
            )
        finally:
            if hook_handle is not None:
                hook_handle.remove()

        new_hidden_state = outputs.hidden_states[-1][:, -1, :]
        kv_cache = outputs.past_key_values

        if record_attentions and outputs.attentions:
            attention_entropies.append(_calculate_attention_entropy(outputs.attentions))

        # Norm (torch.norm default) of the change in the last hidden state this step.
        delta = torch.norm(new_hidden_state - hidden_state_2d).item()
        state_deltas.append(delta)

        hidden_state_2d = new_hidden_state.clone()

    dbg(f"Cognitive loop finished after {num_steps} steps.")

    return {
        "state_deltas": state_deltas,
        "state_history": state_history,
        "attention_entropies": attention_entropies,
        "final_hidden_state": hidden_state_2d,
        "final_kv_cache": kv_cache,
    }
|
|
|
|
|
def run_silent_cogitation_seismic(
    llm: LLM,
    prompt_type: str,
    num_steps: int,
    temperature: float,
    injection_vector: Optional[torch.Tensor] = None,
    injection_strength: float = 0.0,
    injection_layer: Optional[int] = None
) -> List[float]:
    """
    Backward-compatible wrapper that keeps the older, simpler interface.

    Delegates to ``run_cogitation_loop`` without patching and without recording
    states or attentions, and returns only the per-step state deltas.
    """
    injection_kwargs = {
        "injection_vector": injection_vector,
        "injection_strength": injection_strength,
        "injection_layer": injection_layer,
    }
    loop_result = run_cogitation_loop(
        llm=llm,
        prompt_type=prompt_type,
        num_steps=num_steps,
        temperature=temperature,
        **injection_kwargs,
    )
    return loop_result["state_deltas"]