import torch
import numpy as np
from typing import Optional, List, Dict, Any, Tuple
from tqdm import tqdm

from .llm_iface import LLM
from .prompts import RESONANCE_PROMPTS
from .utils import dbg

def _calculate_attention_entropy(attentions: Tuple[torch.Tensor, ...]) -> float:
    """
    Berechnet die mittlere Entropie der Attention-Verteilungen.
    Ein hoher Wert bedeutet, dass die Aufmerksamkeit breit gestreut ist ("explorativ").
    Ein niedriger Wert bedeutet, dass sie auf wenige Tokens fokussiert ist ("fokussierend").
    """
    total_entropy = 0.0
    num_heads = 0
    
    # Iterate over all layers
    for layer_attention in attentions:
        # layer_attention shape: [batch_size, num_heads, seq_len, seq_len]
        # For our purposes batch_size is 1 and only the last query position
        # matters, i.e. the last row of the attention matrix.
        attention_probs = layer_attention[:, :, -1, :]
        
        # Stabilize the logarithm against zero probabilities
        attention_probs = attention_probs + 1e-9
        
        # Entropy formula: H = -sum(p * log2(p)), measured in bits
        log_probs = torch.log2(attention_probs)
        entropy_per_head = -torch.sum(attention_probs * log_probs, dim=-1)
        
        total_entropy += torch.sum(entropy_per_head).item()
        num_heads += attention_probs.shape[1]
        
    return total_entropy / num_heads if num_heads > 0 else 0.0

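# A quick sanity check for the entropy metric (a minimal sketch, not part of
# the measurement pipeline): uniform attention over 8 keys should yield
# log2(8) = 3.0 bits, while fully focused attention should be close to 0.
# Tensor shapes mimic one decoding step: [batch=1, heads=2, q_len=1, key_len=8].
def _entropy_sanity_check() -> None:
    uniform = torch.full((1, 2, 1, 8), 1.0 / 8)
    focused = torch.zeros((1, 2, 1, 8))
    focused[..., 0] = 1.0  # all attention mass on the first key
    assert abs(_calculate_attention_entropy((uniform,)) - 3.0) < 1e-3
    assert _calculate_attention_entropy((focused,)) < 1e-2
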
@torch.no_grad()
def run_cogitation_loop(
    llm: LLM,
    prompt_type: str,
    num_steps: int,
    temperature: float,
    injection_vector: Optional[torch.Tensor] = None,
    injection_strength: float = 0.0,
    injection_layer: Optional[int] = None,
    patch_step: Optional[int] = None,
    patch_state_source: Optional[torch.Tensor] = None,
    reset_kv_cache_on_patch: bool = False,
    record_states: bool = False,
    record_attentions: bool = False,
) -> Dict[str, Any]:
    """
    Eine verallgemeinerte Version, die nun auch die Aufzeichnung von Attention-Mustern
    und die Berechnung der Entropie unterstützt.
    """
    prompt = RESONANCE_PROMPTS[prompt_type]
    inputs = llm.tokenizer(prompt, return_tensors="pt").to(llm.model.device)

    outputs = llm.model(**inputs, output_hidden_states=True, use_cache=True, output_attentions=record_attentions)
    hidden_state_2d = outputs.hidden_states[-1][:, -1, :]
    kv_cache = outputs.past_key_values

    state_deltas: List[float] = []
    state_history: List[torch.Tensor] = []
    attention_entropies: List[float] = []

    if record_attentions and outputs.attentions:
        attention_entropies.append(_calculate_attention_entropy(outputs.attentions))

    for i in tqdm(range(num_steps), desc=f"Cognitive Loop ({prompt_type})", leave=False, bar_format="{l_bar}{bar:10}{r_bar}"):
        if i == patch_step and patch_state_source is not None:
            dbg(f"--- Applying Causal Surgery at step {i}: Patching state. ---")
            hidden_state_2d = patch_state_source.clone().to(device=llm.model.device, dtype=llm.model.dtype)
            if reset_kv_cache_on_patch:
                dbg("--- KV-Cache has been RESET as part of the intervention. ---")
                kv_cache = None
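                # With the cache cleared, the model rebuilds context from
                # scratch on the next forward pass; only the newly sampled
                # token is visible to it.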
        
        if record_states:
            state_history.append(hidden_state_2d.cpu())

        next_token_logits = llm.model.lm_head(hidden_state_2d)
        
        # temperature == 0.0 means greedy decoding; substitute 1.0 to avoid a division by zero
        temp_to_use = temperature if temperature > 0.0 else 1.0
        probabilities = torch.nn.functional.softmax(next_token_logits / temp_to_use, dim=-1)
        if temperature > 0.0:
            next_token_id = torch.multinomial(probabilities, num_samples=1)
        else:
            next_token_id = torch.argmax(probabilities, dim=-1).unsqueeze(-1)

        hook_handle = None
        if injection_vector is not None and injection_strength > 0:
            injection_vector = injection_vector.to(device=llm.model.device, dtype=llm.model.dtype)
            if injection_layer is None:
                injection_layer = llm.stable_config.num_layers // 2

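            # A forward-pre-hook receives the layer's positional inputs as a
            # tuple; returning a modified tuple replaces them, so the scaled
            # vector is added to the residual stream entering the target layer.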
            def injection_hook(module: Any, layer_input: Any) -> Any:
                seq_len = layer_input[0].shape[1]
                injection_3d = injection_vector.unsqueeze(0).expand(1, seq_len, -1)
                modified_hidden_states = layer_input[0] + (injection_3d * injection_strength)
                return (modified_hidden_states,) + layer_input[1:]

        try:
            if injection_vector is not None and injection_strength > 0 and injection_layer is not None:
                assert 0 <= injection_layer < llm.stable_config.num_layers, f"Injection layer {injection_layer} is out of bounds."
                target_layer = llm.stable_config.layer_list[injection_layer]
                hook_handle = target_layer.register_forward_pre_hook(injection_hook)

            outputs = llm.model(
                input_ids=next_token_id, past_key_values=kv_cache,
                output_hidden_states=True, use_cache=True,
                output_attentions=record_attentions
            )
        finally:
            if hook_handle: 
                hook_handle.remove()
                hook_handle = None

        new_hidden_state = outputs.hidden_states[-1][:, -1, :]
        kv_cache = outputs.past_key_values

        if record_attentions and outputs.attentions:
            attention_entropies.append(_calculate_attention_entropy(outputs.attentions))

        delta = torch.norm(new_hidden_state - hidden_state_2d).item()
        state_deltas.append(delta)

        hidden_state_2d = new_hidden_state.clone()

    dbg(f"Cognitive loop finished after {num_steps} steps.")
    
    return {
        "state_deltas": state_deltas,
        "state_history": state_history,
        "attention_entropies": attention_entropies,
        "final_hidden_state": hidden_state_2d,
        "final_kv_cache": kv_cache,
    }

def run_silent_cogitation_seismic(
    llm: LLM,
    prompt_type: str,
    num_steps: int,
    temperature: float,
    injection_vector: Optional[torch.Tensor] = None,
    injection_strength: float = 0.0,
    injection_layer: Optional[int] = None
) -> List[float]:
    """
    Ein abwärtskompatibler Wrapper, der die alte, einfachere Schnittstelle beibehält.
    Ruft den neuen, verallgemeinerten Loop auf und gibt nur die Deltas zurück.
    """
    results = run_cogitation_loop(
        llm=llm, prompt_type=prompt_type, num_steps=num_steps, temperature=temperature,
        injection_vector=injection_vector, injection_strength=injection_strength,
        injection_layer=injection_layer
    )
    return results["state_deltas"]
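
# Minimal usage sketch (assumptions: the LLM constructor shown here is
# hypothetical, see .llm_iface for the real interface, and "resonance" stands
# in for an actual key of RESONANCE_PROMPTS). Run as a module, e.g.
# `python -m <package>.<this_module>`, so the relative imports resolve.
if __name__ == "__main__":
    llm = LLM("gpt2")  # hypothetical constructor call; adjust to llm_iface's API
    results = run_cogitation_loop(
        llm=llm,
        prompt_type="resonance",  # hypothetical prompt key
        num_steps=50,
        temperature=0.7,
        record_attentions=True,
    )
    mean_delta = sum(results["state_deltas"]) / len(results["state_deltas"])
    print(f"mean state delta: {mean_delta:.4f}")
    if results["attention_entropies"]:
        mean_ent = sum(results["attention_entropies"]) / len(results["attention_entropies"])
        print(f"mean attention entropy: {mean_ent:.3f} bits")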