import torch
from .llm_iface import LLM
from .utils import dbg

@torch.no_grad()
def generate_spontaneous_text(
    llm: LLM,
    final_hidden_state: torch.Tensor,
    final_kv_cache: tuple,
    max_new_tokens: int = 50,
    temperature: float = 0.8
) -> str:
    """
    FIXED: Generates text using a manual, token-by-token forward loop.
    This avoids the high-level `model.generate()` function, which is incompatible
    with manually constructed states, thus ensuring an unbroken causal chain from
    the final cognitive state to the generated text.
    """
    dbg("Attempting to generate spontaneous text from converged state (manual loop)...")

    generated_token_ids = []
    hidden_state = final_hidden_state
    kv_cache = final_kv_cache

    for i in range(max_new_tokens):
        # Re-seed deterministically for this step (base seed offset by step index)
        llm.set_all_seeds(llm.seed + i)

        # Predict the next token from the current hidden state
        next_token_logits = llm.model.lm_head(hidden_state)

        # Apply temperature and sample the next token ID
        if temperature > 0.01:
            probabilities = torch.nn.functional.softmax(next_token_logits / temperature, dim=-1)
            next_token_id = torch.multinomial(probabilities, num_samples=1)
        else:
            next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)

        # Check for End-of-Sequence token
        if next_token_id.item() == llm.tokenizer.eos_token_id:
            dbg("EOS token generated. Halting generation.")
            break

        generated_token_ids.append(next_token_id.item())

        # Perform the next forward pass to get the new state
        outputs = llm.model(
            input_ids=next_token_id,
            past_key_values=kv_cache,
            output_hidden_states=True,
            use_cache=True,
        )

        hidden_state = outputs.hidden_states[-1][:, -1, :]
        kv_cache = outputs.past_key_values

    # Decode the collected tokens into a final string
    final_text = llm.tokenizer.decode(generated_token_ids, skip_special_tokens=True).strip()
    dbg(f"Spontaneous text generated: '{final_text}'")
    assert isinstance(final_text, str), "Generated text must be a string."
    return final_text
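
# --- Usage sketch (illustrative only, kept commented out) ---
# A minimal example of how this function might be driven, assuming the `LLM`
# wrapper exposes `model`, `tokenizer`, `seed`, and `set_all_seeds` as used
# above. The `LLM(...)` construction and the prefill call are hypothetical
# placeholders; in practice the converged hidden state and KV cache come from
# the surrounding cognitive loop.
#
#   llm = LLM(...)  # hypothetical construction; see llm_iface for the real API
#   inputs = llm.tokenizer("seed prompt", return_tensors="pt").to(llm.model.device)
#   with torch.no_grad():
#       prefill = llm.model(**inputs, output_hidden_states=True, use_cache=True)
#   state = prefill.hidden_states[-1][:, -1, :]   # last layer, last position
#   cache = prefill.past_key_values
#   text = generate_spontaneous_text(llm, state, cache, max_new_tokens=30)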