import torch
from torch import nn
import torch.nn.functional as F
from transformers import LlamaForCausalLM, PreTrainedTokenizer
from transformers.modeling_outputs import CausalLMOutputWithPast
from dataclasses import dataclass

@dataclass
class SelfCorrectiveLlamaOutput(CausalLMOutputWithPast):
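    """Causal LM output extended with the hallucination detector's per-position logits."""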
    hallucination_logits: torch.FloatTensor = None

class SelfCorrectiveLlama(LlamaForCausalLM):
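    """LlamaForCausalLM augmented with a SwiGLU-style hallucination-detection head and a
    custom generation loop that can inject correction instructions into the output."""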
    def __init__(self, config):
        super().__init__(config)
        
        self.correction_cooldown = getattr(config, "correction_cooldown", 30)
        self.num_new_tokens = 2
        self.deletion_threshold = getattr(config, "deletion_threshold", 0.7)

        intermediate_size = config.intermediate_size
        self.hallucination_gate_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False)
        self.hallucination_up_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False)
        self.hallucination_down_proj = nn.Linear(intermediate_size, config.hidden_size, bias=False)
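        # The detector classifies each position into num_new_tokens + 1 classes;
        # generate() treats class 1 as "rewrite sentence" and class 2 as "rewrite response".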
        self.hallucination_detector = nn.Linear(config.hidden_size, self.num_new_tokens + 1)
        self.hallucination_norm = nn.LayerNorm(config.hidden_size)
    
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        # Get the full sequence of input IDs from the past, if available
        past_input_ids = kwargs.get("past_input_ids", None)

        # If past_input_ids exists, concatenate it with the new input_ids
        if past_input_ids is not None:
            input_ids = torch.cat([past_input_ids, input_ids], dim=-1)
        
        # Call the original prepare_inputs_for_generation method
        model_inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, **kwargs)

        # Update model_kwargs to include the full input_ids sequence for the next step
        model_inputs["past_input_ids"] = input_ids
        
        return model_inputs
    
    def forward(
        self, 
        input_ids, 
        attention_mask=None, 
        labels=None, 
        hallucination_labels=None,
        past_input_ids=None,
        **kwargs
    ):
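        """Run the base model, compute LM logits, and score every position with the
        hallucination detector (which is fed a detached copy of the hidden states)."""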
        # Pass inputs through the base LLaMA model.
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
        last_hidden = outputs.last_hidden_state

        # Calculate main token logits from the original lm_head.
        logits = self.lm_head(last_hidden)

        # Detach the hidden state to prevent gradients from the hallucination loss
        # from flowing back into the base model's LoRA adapters.
        detector_input = last_hidden.detach()

        # Pass the detached hidden state through the SwiGLU-based detector.
        gate_output = self.hallucination_gate_proj(detector_input)
        up_output = self.hallucination_up_proj(detector_input)
        gated_hidden = F.silu(gate_output) * up_output
        detector_hidden = self.hallucination_down_proj(gated_hidden)

        # Apply a residual connection and LayerNorm using the detached input.
        normalized_hidden = self.hallucination_norm(detector_hidden + detector_input)

        # Get the final hallucination logits from the detector head.
        hallucination_logits = self.hallucination_detector(normalized_hidden)

        # Return a custom output object with both sets of logits.
        return SelfCorrectiveLlamaOutput(
            loss=None, # Loss calculation is handled by the Trainer
            logits=logits,
            hallucination_logits=hallucination_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=None,
            attentions=outputs.attentions
        )

    def _initialize_instruction_tokens(self, tokenizer: PreTrainedTokenizer):
        """A helper to tokenize and cache instruction phrases."""
        if not hasattr(self, "rewrite_sentence_ids"):
            self.rewrite_sentence_ids = tokenizer(
                "[rewrite sentence]",
                return_tensors="pt",
                add_special_tokens=False,
            ).input_ids.to(self.device)
        if not hasattr(self, "rewrite_response_ids"):
            self.rewrite_response_ids = tokenizer(
                "[rewrite response]",
                return_tensors="pt",
                add_special_tokens=False,
            ).input_ids.to(self.device)

    @torch.no_grad()
    def generate(
        self,
        input_ids,
        tokenizer: PreTrainedTokenizer,
        max_new_tokens=512,
        temperature=0.3,
        **kwargs,
    ):
        """
        Custom generate method to orchestrate self-correction.

        NOTE: This implementation currently only supports a batch size of 1.
        """
        # Set the model to evaluation mode and cache instruction tokens.
        self.eval() 
        self._initialize_instruction_tokens(tokenizer)

        # Initialize the sequence with the prompt and its attention mask.
        generated_ids = input_ids
        attention_mask = torch.ones_like(input_ids)
        
        # Initialize a counter to track tokens since the last correction.
        # Start it at the cooldown value to allow immediate correction if needed.
        tokens_since_correction = self.correction_cooldown
        
        # The first forward pass processes the prompt and gets the initial KV cache.
        outputs = self(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=None,
            return_dict=True,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values
        
        # Start the generation loop with the logits for the token after the prompt.
        next_token_logits = outputs.logits[:, -1, :]
        hallucination_logits = outputs.hallucination_logits[:, -1, :]

        # Autoregressively generate tokens one by one.
        for _ in range(max_new_tokens):
            # Apply softmax to get hallucination probabilities.
            hallucination_probs = F.softmax(hallucination_logits, dim=-1)

            # Check if the cooldown period has passed.
            can_correct = tokens_since_correction >= self.correction_cooldown

            # Conditionally choose the next tokens based on the detector's output and the cooldown.
            if can_correct and hallucination_probs[0, 1] > self.deletion_threshold:
                current_tokens = self.rewrite_sentence_ids
                tokens_since_correction = 0 # Reset the counter
            elif can_correct and hallucination_probs[0, 2] > self.deletion_threshold:
                current_tokens = self.rewrite_response_ids
                tokens_since_correction = 0 # Reset the counter
            else:
                if temperature > 0.0:
                    scaled_logits = next_token_logits / temperature
                    probs = F.softmax(scaled_logits, dim=-1)
                    current_tokens = torch.multinomial(probs, num_samples=1)
                else:
                    current_tokens = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
                
                # Increment the counter by the number of tokens just generated.
                tokens_since_correction += current_tokens.shape[1]
            
            generated_ids = torch.cat([generated_ids, current_tokens], dim=-1)
            
            # Stop generating if an EOS token is produced.
            if torch.any(current_tokens == tokenizer.eos_token_id):
                break

            # Prepare for the next iteration.
            cache_position = torch.arange(attention_mask.shape[1], attention_mask.shape[1] + current_tokens.shape[1], device=self.device)
            attention_mask = torch.cat(
                [attention_mask, torch.ones((1, current_tokens.shape[1]), device=self.device, dtype=torch.long)],
                dim=1
            )
            
            # Perform a forward pass with only the new tokens, the KV cache, and the correct positions.
            outputs = self(
                input_ids=current_tokens,
                attention_mask=attention_mask,
                cache_position=cache_position,
                past_key_values=past_key_values,
                return_dict=True,
                use_cache=True,
            )

            # Update the state for the next loop.
            past_key_values = outputs.past_key_values
            next_token_logits = outputs.logits[:, -1, :]
            hallucination_logits = outputs.hallucination_logits[:, -1, :]
            
        
        # Return the final generated sequence.
        return generated_ids
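

# Minimal usage sketch, not part of the model definition. Assumptions: the checkpoint
# name below is a placeholder, the detector head is expected to have been trained
# separately (from_pretrained would otherwise initialize it randomly), and the custom
# generate loop only supports batch size 1.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model_name = "meta-llama/Llama-2-7b-hf"  # placeholder; substitute a real checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = SelfCorrectiveLlama.from_pretrained(model_name)

    prompt = "Briefly explain what a KV cache is."
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # The custom generate method needs the tokenizer so it can encode the
    # "[rewrite sentence]" / "[rewrite response]" instruction phrases.
    output_ids = model.generate(
        input_ids=inputs.input_ids,
        tokenizer=tokenizer,
        max_new_tokens=128,
        temperature=0.3,
    )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))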