import torch
from typing import List
from tqdm import tqdm
from .llm_iface import LLM
from .utils import dbg
# A list of neutral words used to compute the baseline activation.
BASELINE_WORDS = [
    "thing", "place", "idea", "person", "object", "time", "way", "day", "man", "world",
    "life", "hand", "part", "child", "eye", "woman", "fact", "group", "case", "point"
]
# REFACTORING: This function has been moved to module level to make it testable.
# It is no longer a local function inside `get_concept_vector`.
@torch.no_grad()
def _get_last_token_hidden_state(llm: LLM, prompt: str) -> torch.Tensor:
    """Helper that returns the hidden state of the last token of a prompt."""
    inputs = llm.tokenizer(prompt, return_tensors="pt").to(llm.model.device)
    # Gradients are already disabled by the @torch.no_grad() decorator.
    outputs = llm.model(**inputs, output_hidden_states=True)
    # Last layer, first (and only) batch element, last token position.
    last_hidden_state = outputs.hidden_states[-1][0, -1, :].cpu()
    assert last_hidden_state.shape == (llm.config.hidden_size,), \
        f"Hidden state shape mismatch. Expected {(llm.config.hidden_size,)}, got {last_hidden_state.shape}"
    return last_hidden_state

@torch.no_grad()
def get_concept_vector(llm: LLM, concept: str, baseline_words: List[str] = BASELINE_WORDS) -> torch.Tensor:
    """
    Extracts a concept vector via the contrastive method: the activation for a
    prompt mentioning the concept minus the mean activation over neutral baseline words.
    """
    dbg(f"Extracting contrastive concept vector for '{concept}'...")
    prompt_template = "Here is a sentence about the concept of {}."

    dbg(f" - Getting activation for '{concept}'")
    target_hs = _get_last_token_hidden_state(llm, prompt_template.format(concept))

    # Average the last-token hidden states over all baseline words.
    baseline_hss = []
    for word in tqdm(baseline_words, desc=f" - Calculating baseline for '{concept}'", leave=False, bar_format="{l_bar}{bar:10}{r_bar}"):
        baseline_hss.append(_get_last_token_hidden_state(llm, prompt_template.format(word)))
    assert all(hs.shape == target_hs.shape for hs in baseline_hss), "Shape mismatch in baseline hidden states."
    mean_baseline_hs = torch.stack(baseline_hss).mean(dim=0)
    dbg(f" - Mean baseline vector computed with norm {torch.norm(mean_baseline_hs).item():.2f}")

    # The concept vector is the difference between the target and baseline activations.
    concept_vector = target_hs - mean_baseline_hs
    norm = torch.norm(concept_vector).item()
    dbg(f"Concept vector for '{concept}' extracted with norm {norm:.2f}.")
    assert torch.isfinite(concept_vector).all(), "Concept vector contains NaN or Inf values."
    return concept_vector
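

# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal sketch of how this module might be driven. How an `LLM` instance is
# constructed depends on `.llm_iface`, which is not shown here, so the
# `LLM(model_name=...)` call below is a hypothetical placeholder; only
# `get_concept_vector` itself is taken from this module.
if __name__ == "__main__":
    llm = LLM(model_name="gpt2")  # hypothetical constructor; adapt to llm_iface.LLM
    vec = get_concept_vector(llm, "justice")
    print(f"Concept vector for 'justice': shape={tuple(vec.shape)}, norm={torch.norm(vec).item():.2f}")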