File size: 3,035 Bytes
c8fa89c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f5b07d
 
c8fa89c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
from typing import List
from tqdm import tqdm

from .llm_iface import LLM
from .utils import dbg

# Neutral, everyday nouns. Their mean activation is used as a baseline that is
# subtracted from a target concept's activation, so that what remains is the
# part of the activation pattern specific to that concept.
BASELINE_WORDS = [
    "thing", "place", "idea", "person", "object",
    "time", "way", "day", "man", "world",
    "life", "hand", "part", "child", "eye",
    "woman", "fact", "group", "case", "point",
]

@torch.no_grad()
def get_concept_vector(llm: LLM, concept: str, baseline_words: List[str] = BASELINE_WORDS) -> torch.Tensor:
    """
    Extract a concept vector using a contrastive (target-minus-baseline) method.

    The last-layer hidden state of the final prompt token is computed for a
    prompt mentioning ``concept`` and for the same prompt filled with each of
    ``baseline_words``. The mean baseline state is subtracted from the target
    state, removing activation the concept shares with neutral words.

    Args:
        llm: Wrapper exposing ``tokenizer``, ``model`` and ``config.hidden_size``.
        concept: The concept word or phrase to extract a vector for.
        baseline_words: Neutral words whose mean activation is subtracted.
            Must contain at least one word.

    Returns:
        A 1-D CPU tensor of length ``llm.config.hidden_size``.

    Raises:
        ValueError: If ``baseline_words`` is empty.
        AssertionError: If a hidden state has an unexpected shape or the
            resulting vector contains NaN or Inf values.
    """
    dbg(f"Extracting contrastive concept vector for '{concept}'...")

    # Materialise once so an iterator/generator argument is handled correctly,
    # and fail fast before any forward pass instead of crashing in torch.stack.
    words = list(baseline_words)
    if not words:
        raise ValueError("baseline_words must contain at least one word.")

    def get_last_token_hidden_state(prompt: str) -> torch.Tensor:
        """Return the last-layer hidden state of the prompt's final token, moved to CPU."""
        inputs = llm.tokenizer(prompt, return_tensors="pt").to(llm.model.device)
        # No inner no_grad context needed: this closure only runs inside the
        # @torch.no_grad()-decorated outer function.
        outputs = llm.model(**inputs, output_hidden_states=True)
        # hidden_states[-1] is the last layer; [0, -1, :] selects batch item 0, final token.
        # NOTE(review): with the template below the final token is presumably the
        # trailing "." (unless the tokenizer appends special tokens) — confirm.
        last_hidden_state = outputs.hidden_states[-1][0, -1, :].cpu()
        assert last_hidden_state.shape == (llm.config.hidden_size,), \
            f"Hidden state shape mismatch. Expected {(llm.config.hidden_size,)}, got {last_hidden_state.shape}"
        return last_hidden_state

    # A simple, neutral prompt template to elicit the concept
    prompt_template = "Here is a sentence about the concept of {}."

    # 1. Activation for the target concept
    dbg(f"  - Getting activation for '{concept}'")
    target_hs = get_last_token_hidden_state(prompt_template.format(concept))

    # 2. Activations for every baseline word, then their element-wise mean
    baseline_hss = [
        get_last_token_hidden_state(prompt_template.format(word))
        for word in tqdm(words, desc=f"  - Calculating baseline for '{concept}'", leave=False, bar_format="{l_bar}{bar:10}{r_bar}")
    ]

    assert all(hs.shape == target_hs.shape for hs in baseline_hss), "Shape mismatch in baseline hidden states."

    mean_baseline_hs = torch.stack(baseline_hss).mean(dim=0)
    dbg(f"  - Mean baseline vector computed with norm {torch.norm(mean_baseline_hs).item():.2f}")

    # 3. The concept vector is the target activation minus the baseline mean
    concept_vector = target_hs - mean_baseline_hs
    norm = torch.norm(concept_vector).item()
    dbg(f"Concept vector for '{concept}' extracted with norm {norm:.2f}.")

    assert torch.isfinite(concept_vector).all(), "Concept vector contains NaN or Inf values."
    return concept_vector