import torch

from .llm_iface import LLM
from .utils import dbg


@torch.no_grad()
def generate_spontaneous_text(
    llm: LLM,
    final_hidden_state: torch.Tensor,
    final_kv_cache: tuple,
    max_new_tokens: int = 50,
    temperature: float = 0.8,
) -> str:
    """
    FIXED: Generates text using a manual, token-by-token forward loop.

    This avoids the high-level `model.generate()` function, which is incompatible
    with manually constructed states, thus ensuring an unbroken causal chain from
    the final cognitive state to the generated text.
    """
    dbg("Attempting to generate spontaneous text from converged state (manual loop)...")

    generated_token_ids = []
    hidden_state = final_hidden_state
    kv_cache = final_kv_cache

    for i in range(max_new_tokens):
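        # Re-seed each step so the per-token sampling below is reproducible.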
        llm.set_all_seeds(llm.seed + i)

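        # Project the current hidden state to vocabulary logits via the LM head.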
        next_token_logits = llm.model.lm_head(hidden_state)

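        # Sample with temperature; fall back to greedy (argmax) decoding when temperature is near zero.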
        if temperature > 0.01:
            probabilities = torch.nn.functional.softmax(next_token_logits / temperature, dim=-1)
            next_token_id = torch.multinomial(probabilities, num_samples=1)
        else:
            next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)

        if next_token_id.item() == llm.tokenizer.eos_token_id:
            dbg("EOS token generated. Halting generation.")
            break

        generated_token_ids.append(next_token_id.item())

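        # Feed the sampled token back through the model, reusing the KV cache so
        # only the new position is computed.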
        outputs = llm.model(
            input_ids=next_token_id,
            past_key_values=kv_cache,
            output_hidden_states=True,
            use_cache=True,
        )

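        # Carry forward the last layer's hidden state at the newest position and the updated cache.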
        hidden_state = outputs.hidden_states[-1][:, -1, :]
        kv_cache = outputs.past_key_values

    final_text = llm.tokenizer.decode(generated_token_ids, skip_special_tokens=True).strip()
    dbg(f"Spontaneous text generated: '{final_text}'")
    assert isinstance(final_text, str), "Generated text must be a string."
    return final_text
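
# Usage sketch (illustrative, not part of the module): assumes an `LLM` wrapper exposing
# `model`, `tokenizer`, `seed`, and `set_all_seeds`, and that the converged state comes from
# a prior forward pass run with `output_hidden_states=True` and `use_cache=True`:
#
#     prior = llm.model(input_ids=prompt_ids, output_hidden_states=True, use_cache=True)
#     state = prior.hidden_states[-1][:, -1, :]
#     text = generate_spontaneous_text(llm, state, prior.past_key_values, temperature=0.8)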