import torch
from typing import Optional, Tuple

from tqdm import tqdm

from .llm_iface import LLM
from .prompts import RESONANCE_PROMPTS
from .utils import dbg


@torch.no_grad()
def run_silent_cogitation(
    llm: LLM,
    prompt_type: str,
    num_steps: int,
    temperature: float,
    injection_vector: Optional[torch.Tensor] = None,
    injection_strength: float = 0.0,
    injection_layer: Optional[int] = None,
) -> Tuple[torch.Tensor, tuple, torch.Tensor, str]:
    """
    Simulates the "silent thought" process and returns the final cognitive state
    along with the reason for termination ('converged' or 'max_steps_reached').

    Returns:
        - final_hidden_state: The hidden state of the last generated token.
        - final_kv_cache: The past_key_values cache after the final step.
        - final_token_id: The ID of the last generated token.
        - termination_reason: A string indicating why the loop ended.
    """
    prompt = RESONANCE_PROMPTS[prompt_type]
    inputs = llm.tokenizer(prompt, return_tensors="pt").to(llm.model.device)

    # Initial forward pass to establish the starting state
    outputs = llm.model(**inputs, output_hidden_states=True, use_cache=True)

    hidden_state = outputs.hidden_states[-1][:, -1, :]
    kv_cache = outputs.past_key_values
    last_token_id = inputs.input_ids[:, -1].unsqueeze(-1)

    previous_hidden_state = hidden_state.clone()
    termination_reason = "max_steps_reached"  # Default assumption

    # Prepare injection if provided
    hook_handle = None
    if injection_vector is not None and injection_strength > 0:
        # Move vector to the correct device and dtype once
        injection_vector = injection_vector.to(
            device=llm.model.device, dtype=llm.model.dtype
        )
        # Default to a middle layer if not specified
        if injection_layer is None:
            injection_layer = llm.config.num_hidden_layers // 2
        dbg(
            f"Injection enabled: Layer {injection_layer}, "
            f"Strength {injection_strength:.2f}, "
            f"Vector Norm {torch.norm(injection_vector).item():.2f}"
        )

        # Define the hook function that performs the activation addition
        def injection_hook(module, layer_input):
            # layer_input is a tuple; the first element is the hidden state tensor
            original_hidden_states = layer_input[0]
            # Add the scaled vector to the hidden states
            modified_hidden_states = original_hidden_states + (
                injection_vector * injection_strength
            )
            return (modified_hidden_states,) + layer_input[1:]

    # Main cognitive loop
    for i in tqdm(
        range(num_steps),
        desc=f"Simulating Thought (Strength {injection_strength:.2f})",
        leave=False,
        bar_format="{l_bar}{bar:10}{r_bar}",
    ):
        # Predict the next token from the current hidden state
        next_token_logits = llm.model.lm_head(hidden_state)

        # Apply temperature and sample the next token ID
        if temperature > 0.01:
            probabilities = torch.nn.functional.softmax(
                next_token_logits / temperature, dim=-1
            )
            next_token_id = torch.multinomial(probabilities, num_samples=1)
        else:
            # Use argmax for deterministic behavior at low temperatures
            next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)

        last_token_id = next_token_id

        # --- Activation Injection via Hook ---
        try:
            if injection_vector is not None and injection_strength > 0:
                target_layer = llm.model.model.layers[injection_layer]
                hook_handle = target_layer.register_forward_pre_hook(injection_hook)

            # Perform the next forward pass
            outputs = llm.model(
                input_ids=next_token_id,
                past_key_values=kv_cache,
                output_hidden_states=True,
                use_cache=True,
            )
        finally:
            # IMPORTANT: Always remove the hook after the forward pass
            if hook_handle:
                hook_handle.remove()
                hook_handle = None

        hidden_state = outputs.hidden_states[-1][:, -1, :]
        kv_cache = outputs.past_key_values

        # Check for convergence
        delta = torch.norm(hidden_state - previous_hidden_state).item()
        if delta < 1e-4 and i > 10:  # Check for stability after a few initial steps
            termination_reason = "converged"
            dbg(f"State converged after {i+1} steps (delta={delta:.6f}).")
            break

        previous_hidden_state = hidden_state.clone()

    dbg(f"Silent cogitation finished. Reason: {termination_reason}")
    return hidden_state, kv_cache, last_token_id, termination_reason
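

# ---------------------------------------------------------------------------
# Minimal, self-contained sketch of the forward-pre-hook mechanic that
# run_silent_cogitation relies on for activation injection. This is purely
# illustrative: the toy nn.Linear stands in for a transformer layer, and none
# of the names below are part of this module's interface. Because of the
# relative imports at the top of this file, run it as a module
# (e.g. `python -m <package>.<this_module>`) rather than as a script.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn

    layer = nn.Linear(4, 4)
    steering_vector = torch.ones(4)
    strength = 0.5

    def demo_hook(module, layer_input):
        # layer_input is a tuple; element 0 is the incoming hidden state.
        # Returning a new tuple replaces the layer's input, exactly as
        # injection_hook does above.
        return (layer_input[0] + steering_vector * strength,) + layer_input[1:]

    x = torch.randn(1, 4)

    handle = layer.register_forward_pre_hook(demo_hook)
    try:
        steered = layer(x)
    finally:
        handle.remove()  # always detach the hook, mirroring the loop above

    # With the hook removed, running the layer on the explicitly shifted input
    # reproduces the steered output.
    assert torch.allclose(steered, layer(x + steering_vector * strength))
    print("Pre-hook shifted the layer input as expected.")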