|
|
import torch |
|
|
from typing import Optional, Tuple |
|
|
from tqdm import tqdm |
|
|
|
|
|
from .llm_iface import LLM |
|
|
from .prompts import RESONANCE_PROMPTS |
|
|
from .utils import dbg |
|
|
|
|
|
@torch.no_grad()
def run_silent_cogitation(
    llm: LLM,
    prompt_type: str,
    num_steps: int,
    temperature: float,
    injection_vector: Optional[torch.Tensor] = None,
    injection_strength: float = 0.0,
    injection_layer: Optional[int] = None,
    *,
    convergence_threshold: float = 1e-4,
    convergence_min_steps: int = 10,
) -> Tuple[torch.Tensor, tuple, torch.Tensor, str]:
    """
    Simulates the "silent thought" process and returns the final cognitive state
    along with the reason for termination ('converged' or 'max_steps_reached').

    The prompt stored under ``prompt_type`` is encoded once to prime the KV
    cache. The model then repeatedly samples a token from its own last hidden
    state and feeds that token back through the cache, for up to
    ``num_steps`` steps. It stops early once the last hidden state stops
    changing. Optionally, a steering vector is added to the input of one
    decoder layer on every step (not on the initial prompt pass).

    Args:
        llm: Wrapper exposing ``model``, ``tokenizer`` and ``config``.
        prompt_type: Key into ``RESONANCE_PROMPTS``.
        num_steps: Maximum number of autoregressive steps.
        temperature: Sampling temperature. Values <= 0.01 select the argmax
            token (greedy decoding).
        injection_vector: Optional vector added to the target layer's input
            hidden states. It must broadcast against them; presumably its
            shape is ``(hidden_size,)`` (TODO: confirm with callers).
        injection_strength: Scale factor for ``injection_vector``. Injection is
            only active when this is strictly positive.
        injection_layer: Index of the decoder layer to steer. Defaults to
            ``llm.config.num_hidden_layers // 2``.
        convergence_threshold: Stop early when the L2 norm of the change in
            the last hidden state between consecutive steps is below this.
        convergence_min_steps: Convergence is only checked once the step index
            exceeds this value, so early transients are not mistaken for a
            fixed point.

    Returns:
        - final_hidden_state: Last-layer hidden state at the final position,
          shape ``(batch, hidden)``.
        - final_kv_cache: The past_key_values cache after the final step.
        - final_token_id: ID of the last token fed to the model, shape
          ``(batch, 1)``. If no step ran, this is the last prompt token.
        - termination_reason: 'converged' or 'max_steps_reached'.

    Raises:
        KeyError: If ``prompt_type`` is not a key of ``RESONANCE_PROMPTS``.
    """
    prompt = RESONANCE_PROMPTS[prompt_type]
    inputs = llm.tokenizer(prompt, return_tensors="pt").to(llm.model.device)

    # Prime the KV cache with the whole prompt. No injection on this pass.
    outputs = llm.model(**inputs, output_hidden_states=True, use_cache=True)
    hidden_state = outputs.hidden_states[-1][:, -1, :]
    kv_cache = outputs.past_key_values
    last_token_id = inputs.input_ids[:, -1].unsqueeze(-1)

    # clone() so the comparison baseline stays valid even if the model
    # reuses its output buffers between calls.
    previous_hidden_state = hidden_state.clone()
    termination_reason = "max_steps_reached"

    hook_handle = None
    if injection_vector is not None and injection_strength > 0:
        injection_vector = injection_vector.to(device=llm.model.device, dtype=llm.model.dtype)
        if injection_layer is None:
            injection_layer = llm.config.num_hidden_layers // 2

        dbg(f"Injection enabled: Layer {injection_layer}, Strength {injection_strength:.2f}, Vector Norm {torch.norm(injection_vector).item():.2f}")

        # Scale once. The hook then only performs the addition.
        steering = injection_vector * injection_strength

        def injection_hook(module, args, kwargs):
            # Decoder layers may receive hidden_states positionally or as a
            # keyword argument; steer whichever form is used.
            if args:
                return (args[0] + steering,) + tuple(args[1:]), kwargs
            if "hidden_states" in kwargs:
                new_kwargs = dict(kwargs)
                new_kwargs["hidden_states"] = kwargs["hidden_states"] + steering
                return args, new_kwargs
            raise RuntimeError(
                "injection_hook: could not locate hidden_states in the layer inputs"
            )

        # Assumes an HF-style decoder exposing `llm.model.model.layers`.
        # TODO: confirm for other architectures.
        target_layer = llm.model.model.layers[injection_layer]
        # Register once for the whole loop. Only the decoder forward below
        # triggers it; lm_head and sampling do not touch this layer.
        hook_handle = target_layer.register_forward_pre_hook(injection_hook, with_kwargs=True)

    try:
        for i in tqdm(range(num_steps), desc=f"Simulating Thought (Strength {injection_strength:.2f})", leave=False, bar_format="{l_bar}{bar:10}{r_bar}"):
            # Project the current last hidden state to vocabulary logits.
            # Upcast to float32 so softmax/multinomial are numerically
            # stable when the model runs in half precision.
            next_token_logits = llm.model.lm_head(hidden_state).float()

            if temperature > 0.01:
                probabilities = torch.nn.functional.softmax(next_token_logits / temperature, dim=-1)
                next_token_id = torch.multinomial(probabilities, num_samples=1)
            else:
                # Near-zero temperature: greedy decoding.
                next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)

            last_token_id = next_token_id

            # Feed only the new token; earlier context comes from the cache.
            outputs = llm.model(
                input_ids=next_token_id,
                past_key_values=kv_cache,
                output_hidden_states=True,
                use_cache=True,
            )
            hidden_state = outputs.hidden_states[-1][:, -1, :]
            kv_cache = outputs.past_key_values

            # Converged when the last hidden state has effectively stopped moving.
            delta = torch.norm(hidden_state - previous_hidden_state).item()
            if delta < convergence_threshold and i > convergence_min_steps:
                termination_reason = "converged"
                dbg(f"State converged after {i+1} steps (delta={delta:.6f}).")
                break

            previous_hidden_state = hidden_state.clone()
    finally:
        # Always detach the steering hook so the model is left unmodified.
        if hook_handle is not None:
            hook_handle.remove()

    dbg(f"Silent cogitation finished. Reason: {termination_reason}")
    return hidden_state, kv_cache, last_token_id, termination_reason
|
|
|