import torch
from typing import Optional, Tuple
from tqdm import tqdm
from .llm_iface import LLM
from .prompts import RESONANCE_PROMPTS
from .utils import dbg


@torch.no_grad()
def run_silent_cogitation(
llm: LLM,
prompt_type: str,
num_steps: int,
temperature: float,
injection_vector: Optional[torch.Tensor] = None,
injection_strength: float = 0.0,
injection_layer: Optional[int] = None,
) -> Tuple[torch.Tensor, tuple, torch.Tensor, str]:
"""
Simulates the "silent thought" process and returns the final cognitive state
along with the reason for termination ('converged' or 'max_steps_reached').
Returns:
- final_hidden_state: The hidden state of the last generated token.
- final_kv_cache: The past_key_values cache after the final step.
- final_token_id: The ID of the last generated token.
- termination_reason: A string indicating why the loop ended.
"""
prompt = RESONANCE_PROMPTS[prompt_type]
inputs = llm.tokenizer(prompt, return_tensors="pt").to(llm.model.device)
# Initial forward pass to establish the starting state
outputs = llm.model(**inputs, output_hidden_states=True, use_cache=True)
hidden_state = outputs.hidden_states[-1][:, -1, :]
kv_cache = outputs.past_key_values
last_token_id = inputs.input_ids[:, -1].unsqueeze(-1)
previous_hidden_state = hidden_state.clone()
termination_reason = "max_steps_reached" # Default assumption
# Prepare injection if provided
hook_handle = None
if injection_vector is not None and injection_strength > 0:
# Move vector to the correct device and dtype once
injection_vector = injection_vector.to(device=llm.model.device, dtype=llm.model.dtype)
# Default to a middle layer if not specified
if injection_layer is None:
            injection_layer = llm.model.config.num_hidden_layers // 2
dbg(f"Injection enabled: Layer {injection_layer}, Strength {injection_strength:.2f}, Vector Norm {torch.norm(injection_vector).item():.2f}")
# Define the hook function that performs the activation addition
def injection_hook(module, layer_input):
# layer_input is a tuple, the first element is the hidden state tensor
original_hidden_states = layer_input[0]
# Add the scaled vector to the hidden states
modified_hidden_states = original_hidden_states + (injection_vector * injection_strength)
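            # Returning a tuple from a forward pre-hook replaces the layer's
            # positional args before its forward runs (standard torch.nn.Module
            # pre-hook semantics), so the layer sees the steered hidden states.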
return (modified_hidden_states,) + layer_input[1:]
# Main cognitive loop
for i in tqdm(range(num_steps), desc=f"Simulating Thought (Strength {injection_strength:.2f})", leave=False, bar_format="{l_bar}{bar:10}{r_bar}"):
# Predict the next token from the current hidden state
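        # (For Llama-style models, hidden_states[-1] has already passed through
        # the final layer norm, so it can be fed to lm_head directly.)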
next_token_logits = llm.model.lm_head(hidden_state)
# Apply temperature and sample the next token ID
if temperature > 0.01:
probabilities = torch.nn.functional.softmax(next_token_logits / temperature, dim=-1)
next_token_id = torch.multinomial(probabilities, num_samples=1)
else: # Use argmax for deterministic behavior at low temperatures
next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
last_token_id = next_token_id
# --- Activation Injection via Hook ---
try:
if injection_vector is not None and injection_strength > 0:
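                # NOTE: `.model.model.layers` assumes a Llama-style module tree;
                # other architectures may nest their decoder layers differently.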
target_layer = llm.model.model.layers[injection_layer]
hook_handle = target_layer.register_forward_pre_hook(injection_hook)
# Perform the next forward pass
outputs = llm.model(
input_ids=next_token_id,
past_key_values=kv_cache,
output_hidden_states=True,
use_cache=True,
)
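            # attention_mask is omitted here; with batch size 1 and no padding,
            # HF reconstructs the full causal mask from the KV-cache length.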
finally:
# IMPORTANT: Always remove the hook after the forward pass
if hook_handle:
hook_handle.remove()
hook_handle = None
hidden_state = outputs.hidden_states[-1][:, -1, :]
kv_cache = outputs.past_key_values
# Check for convergence
delta = torch.norm(hidden_state - previous_hidden_state).item()
        if delta < 1e-4 and i > 10:  # Require stability beyond the warm-up steps; the absolute threshold is model- and dtype-dependent
termination_reason = "converged"
dbg(f"State converged after {i+1} steps (delta={delta:.6f}).")
break
previous_hidden_state = hidden_state.clone()
dbg(f"Silent cogitation finished. Reason: {termination_reason}")
return hidden_state, kv_cache, last_token_id, termination_reason
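

if __name__ == "__main__":
    # Minimal usage sketch, illustrative only. It assumes `LLM("model-name")`
    # is a valid constructor and that "control" is a key in RESONANCE_PROMPTS;
    # verify both against llm_iface.py / prompts.py before running
    # (e.g. via `python -m <package>.<this_module>`).
    llm = LLM("meta-llama/Llama-3.1-8B-Instruct")  # hypothetical constructor call
    state, cache, token_id, reason = run_silent_cogitation(
        llm,
        prompt_type="control",  # hypothetical prompt key
        num_steps=200,
        temperature=0.7,
    )
    print(f"Finished ({reason}); final state norm: {torch.norm(state).item():.2f}")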