import os
import warnings
from operator import attrgetter
from typing import Dict, Optional, Tuple

import torch
import torch.nn.functional as F
from torchtyping import TensorType
from transformers import AutoTokenizer, BatchEncoding, TextIteratorStreamer

import nnsight
from nnsight import LanguageModel
from nnsight.intervention import Envoy

warnings.filterwarnings("ignore")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# nnsight with multi-threading: https://github.com/ndif-team/nnsight/issues/280
nnsight.CONFIG.APP.GLOBAL_TRACING = False

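# The paths below point to precomputed artifacts shipped with this Space;
# "layer" is the decoder block whose residual stream gets edited, and "k"
# holds the steering magnitudes used for negative and positive coefficients
# respectively (see ModelBase.generate).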
config = {
    "model_name": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "steering_vec": "activations/llama3-8b-steering-vec.pt",
    "offset": "activations/llama3-8b-offset.pt",
    "layer": 20,
    "k": (8.5, 6),
}
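
# Locate the list of decoder blocks: Llama-style models expose it as
# model.layers, GPT-2-style models as transformer.h.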
def detect_module_attrs(model: LanguageModel) -> str:
    if "model" in model._modules and "layers" in model.model._modules:
        return "model.layers"
    elif "transformer" in model._modules and "h" in model.transformer._modules:
        return "transformer.h"
    else:
        raise Exception("Failed to detect module attributes.")

class ModelBase:
    def __init__(
        self, model_name: str,
        steering_vec: TensorType, offset: TensorType,
        k: Tuple[float, float], steering_layer: int,
        tokenizer: Optional[AutoTokenizer] = None, block_module_attr: Optional[str] = None
    ):
        if tokenizer is None:
            self.tokenizer = self._load_tokenizer(model_name)
        else:
            self.tokenizer = tokenizer
        self.model = self._load_model(model_name, self.tokenizer)
        self.device = self.model.device
        self.hidden_size = self.model.config.hidden_size
        if block_module_attr is None:
            self.block_modules = self.get_module(detect_module_attrs(self.model))
        else:
            self.block_modules = self.get_module(block_module_attr)
        self.steering_layer = steering_layer
        self.k = k
        # Steer along a unit vector so that coeff * k directly sets the magnitude.
        self.unit_vec = F.normalize(steering_vec, dim=-1)
        self.unit_vec, self.offset = self.set_dtype(self.unit_vec, offset)
    def _load_model(self, model_name: str, tokenizer: AutoTokenizer) -> LanguageModel:
        return LanguageModel(
            model_name, tokenizer=tokenizer, dispatch=True,
            trust_remote_code=True, device_map="auto", torch_dtype=torch.bfloat16,
        )

    def _load_tokenizer(self, model_name: str) -> AutoTokenizer:
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        # Left padding so batched generation appends new tokens at the end.
        tokenizer.padding_side = "left"
        if not tokenizer.pad_token:
            tokenizer.pad_token_id = tokenizer.eos_token_id
            tokenizer.pad_token = tokenizer.eos_token
        return tokenizer
    def tokenize(self, prompt: str) -> BatchEncoding:
        return self.tokenizer(prompt, padding=True, truncation=False, return_tensors="pt")

    def get_module(self, attr: str) -> Envoy:
        return attrgetter(attr)(self.model)

    def set_dtype(self, *tensors):
        # Cast the given tensors to the model dtype (bfloat16 here).
        if len(tensors) == 1:
            return tensors[0].to(self.model.dtype)
        else:
            return tuple(tensor.to(self.model.dtype) for tensor in tensors)
    def apply_chat_template(self, instruction: str) -> str:
        messages = [{"role": "user", "content": instruction}]
        return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
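    # For Llama-3.1-Instruct the rendered prompt looks roughly like (the exact
    # string depends on the tokenizer's bundled chat template):
    #   <|start_header_id|>user<|end_header_id|>\n\n{instruction}<|eot_id|>
    #   <|start_header_id|>assistant<|end_header_id|>\n\n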
    def generate(self, prompt: str, streamer: TextIteratorStreamer, steering: bool, coeff: float, generation_config: Dict):
        formatted_prompt = self.apply_chat_template(prompt)
        inputs = self.tokenize(formatted_prompt)
        if steering:
            # Asymmetric scale: k[0] for negative coefficients, k[1] for positive.
            k = self.k[0] if coeff < 0 else self.k[1]
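            # The intervention rewrites the layer's residual stream h as
            #   h <- h - ((h - offset) @ u) * u + coeff * k * u
            # i.e. it removes the component of (h - offset) along the unit
            # steering direction u and substitutes one of magnitude coeff * k.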
            with self.model.generate(inputs, do_sample=True, streamer=streamer, **generation_config):
                # Apply the edit at every generation step, not just the prompt pass.
                self.block_modules.all()
                acts = self.block_modules[self.steering_layer].output[0].clone()
                # Component of the (offset-centered) activations along the steering direction.
                proj = (acts - self.offset) @ self.unit_vec.unsqueeze(-1) * self.unit_vec
                # Remove it and add back the steering vector scaled by coeff * k.
                self.block_modules[self.steering_layer].output[0][:] = acts - proj + coeff * k * self.unit_vec
        else:
            inputs = inputs.to(self.device)
            _ = self.model._model.generate(**inputs, do_sample=True, streamer=streamer, **generation_config)

def load_model() -> ModelBase:
    steering_vec = torch.load(config["steering_vec"], weights_only=True)
    offset = torch.load(config["offset"], weights_only=True)
    model = ModelBase(
        config["model_name"], steering_vec=steering_vec, offset=offset,
        k=config["k"], steering_layer=config["layer"],
    )
    return model
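
# Example usage (a minimal sketch; the prompt, generation settings, and the
# threading pattern below are illustrative assumptions, not part of this module):
if __name__ == "__main__":
    from threading import Thread

    model = load_model()
    streamer = TextIteratorStreamer(
        model.tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    # Run generation in a background thread so the streamer can be consumed here.
    thread = Thread(
        target=model.generate,
        kwargs=dict(
            prompt="Tell me about cats.",
            streamer=streamer,
            steering=True,
            coeff=1.0,
            generation_config={"max_new_tokens": 128},
        ),
    )
    thread.start()
    for chunk in streamer:  # stream decoded text as it is generated
        print(chunk, end="", flush=True)
    thread.join()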