import torch

from .llm_iface import LLM
from .utils import dbg


@torch.no_grad()
def generate_spontaneous_text(
    llm: LLM,
    final_hidden_state: torch.Tensor,
    final_kv_cache: tuple,
    max_new_tokens: int = 50,
    temperature: float = 0.8,
) -> str:
"""
FIXED: Generates text using a manual, token-by-token forward loop.
This avoids the high-level `model.generate()` function, which is incompatible
with manually constructed states, thus ensuring an unbroken causal chain from
the final cognitive state to the generated text.
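
    Args:
        llm: Wrapper that is expected to expose `model`, `tokenizer`, `seed`,
            and `set_all_seeds` (as used below).
        final_hidden_state: Last-layer hidden state of the final position,
            expected shape [batch, hidden_dim].
        final_kv_cache: `past_key_values` from the forward pass that produced
            `final_hidden_state`.
        max_new_tokens: Upper bound on the number of tokens to generate.
        temperature: Sampling temperature; values at or below 0.01 fall back
            to greedy decoding.

    Returns:
        The decoded text, with special tokens and surrounding whitespace
        stripped.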
"""
dbg("Attempting to generate spontaneous text from converged state (manual loop)...")
generated_token_ids = []
hidden_state = final_hidden_state
kv_cache = final_kv_cache
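    # Loop invariant: `hidden_state` holds the representation of the most
    # recent position and `kv_cache` covers all tokens processed so far.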
    for i in range(max_new_tokens):
        # Set seed for this step for reproducibility
        llm.set_all_seeds(llm.seed + i)  # Offset seed per step
        # Predict the next token from the current hidden state
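        # NOTE: this projects the hidden state straight through the LM head;
        # it assumes the model exposes `lm_head` and that the last entry of
        # `output_hidden_states` has already passed through the final norm
        # (true for most Hugging Face decoder-only models).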
        next_token_logits = llm.model.lm_head(hidden_state)
        # Apply temperature and sample the next token ID
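        # Temperatures at or below 0.01 are treated as greedy decoding, since
        # dividing logits by a near-zero temperature is numerically unstable.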
        if temperature > 0.01:
            probabilities = torch.nn.functional.softmax(next_token_logits / temperature, dim=-1)
            next_token_id = torch.multinomial(probabilities, num_samples=1)
        else:
            next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
        # Check for End-of-Sequence token
        if next_token_id.item() == llm.tokenizer.eos_token_id:
            dbg("EOS token generated. Halting generation.")
            break
        generated_token_ids.append(next_token_id.item())
        # Perform the next forward pass to get the new state
        outputs = llm.model(
            input_ids=next_token_id,
            past_key_values=kv_cache,
            output_hidden_states=True,
            use_cache=True,
        )
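        # Keep only the newest position from the last layer's hidden states
        # (shape [batch, hidden_dim]) and carry the updated KV cache forward.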
        hidden_state = outputs.hidden_states[-1][:, -1, :]
        kv_cache = outputs.past_key_values
    # Decode the collected tokens into a final string
    final_text = llm.tokenizer.decode(generated_token_ids, skip_special_tokens=True).strip()
    dbg(f"Spontaneous text generated: '{final_text}'")
    assert isinstance(final_text, str), "Generated text must be a string."
    return final_text
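

# Illustrative usage sketch (assumes an `LLM` instance named `llm` and prompt
# ids `prompt_ids`; the state arguments come from a prior forward pass):
#
#     outputs = llm.model(input_ids=prompt_ids, output_hidden_states=True, use_cache=True)
#     hidden = outputs.hidden_states[-1][:, -1, :]
#     text = generate_spontaneous_text(llm, hidden, outputs.past_key_values)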