|
|
import torch |
|
|
from typing import List |
|
|
from tqdm import tqdm |
|
|
|
|
|
from .llm_iface import LLM |
|
|
from .utils import dbg |
|
|
|
|
|
|
|
|
|
|
|
# Neutral, high-frequency nouns whose mean activation serves as the contrastive
# baseline in get_concept_vector. Splitting a single literal on whitespace
# yields the same 20-element list of strings, in the same order.
BASELINE_WORDS = (
    "thing place idea person object time way day man world "
    "life hand part child eye woman fact group case point"
).split()
|
|
|
|
|
@torch.no_grad()
def get_concept_vector(
    llm: LLM,
    concept: str,
    baseline_words: List[str] = BASELINE_WORDS,
    *,
    prompt_template: str = "Here is a sentence about the concept of {}.",
) -> torch.Tensor:
    """
    Extract a concept vector using the contrastive method, inspired by Anthropic's research.

    The final-token hidden state (taken from the last entry of the model's
    ``hidden_states``) is computed for ``prompt_template`` filled with
    ``concept``. The mean of the same quantity over ``baseline_words`` is then
    subtracted to distill a purer representation of the concept.

    Args:
        llm: Wrapper exposing ``tokenizer``, ``model`` and ``config.hidden_size``.
        concept: The target concept inserted into the prompt template.
        baseline_words: Neutral words whose mean activation is subtracted.
            Any iterable of strings is accepted; it is copied, never mutated.
        prompt_template: ``str.format`` template with one positional
            placeholder, filled with the concept and with each baseline word.

    Returns:
        A 1-D CPU tensor of length ``llm.config.hidden_size``.

    Raises:
        ValueError: If ``baseline_words`` is empty.
        RuntimeError: If a hidden state has an unexpected shape, or the
            resulting concept vector contains NaN or Inf values.
    """
    # Materialize once: supports generators, gives tqdm a known total, and
    # avoids aliasing the shared module-level default list.
    baseline_words = list(baseline_words)
    if not baseline_words:
        # Fail fast, before spending a forward pass on the target prompt.
        raise ValueError("baseline_words must contain at least one word.")

    dbg(f"Extracting contrastive concept vector for '{concept}'...")

    expected_shape = (llm.config.hidden_size,)

    def get_last_token_hidden_state(prompt: str) -> torch.Tensor:
        """Return the final token's hidden state from the last hidden_states entry, on CPU."""
        inputs = llm.tokenizer(prompt, return_tensors="pt").to(llm.model.device)
        # No inner torch.no_grad() needed: the decorator on the enclosing
        # function keeps grad mode disabled for calls made from inside it.
        outputs = llm.model(**inputs, output_hidden_states=True)
        # [-1] picks the last hidden_states entry (presumably the final layer);
        # [0, -1, :] selects batch item 0 and the final token position.
        last_hidden_state = outputs.hidden_states[-1][0, -1, :].cpu()
        # Explicit raise instead of assert so the check survives `python -O`.
        if last_hidden_state.shape != expected_shape:
            raise RuntimeError(
                f"Hidden state shape mismatch. Expected {expected_shape}, got {last_hidden_state.shape}"
            )
        return last_hidden_state

    dbg(f" - Getting activation for '{concept}'")
    target_hs = get_last_token_hidden_state(prompt_template.format(concept))

    # Every state is validated against expected_shape inside the helper, so
    # the target and all baselines are guaranteed to share one shape.
    baseline_hss = [
        get_last_token_hidden_state(prompt_template.format(word))
        for word in tqdm(
            baseline_words,
            desc=f" - Calculating baseline for '{concept}'",
            leave=False,
            bar_format="{l_bar}{bar:10}{r_bar}",
        )
    ]

    mean_baseline_hs = torch.stack(baseline_hss).mean(dim=0)
    dbg(f" - Mean baseline vector computed with norm {torch.norm(mean_baseline_hs).item():.2f}")

    # Contrastive step: remove the component shared with neutral words.
    concept_vector = target_hs - mean_baseline_hs
    norm = torch.norm(concept_vector).item()
    dbg(f"Concept vector for '{concept}' extracted with norm {norm:.2f}.")

    if not torch.isfinite(concept_vector).all():
        raise RuntimeError("Concept vector contains NaN or Inf values.")
    return concept_vector
|
|
|