class Edit:
    """Bundle an ablator object with four hook-dict factories and two run seeds.

    Each ``*_dict`` argument is a callable taking ``(block_type, layer_num)``
    (annotated ``(str, int)``) and returning a dict. Elsewhere in this file
    those dicts map module paths such as ``"transformer.transformer_blocks.0"``
    to hook callables. The ``vanilla_*`` / ``ablated_*`` prefixes presumably
    select the unmodified vs. edited run, and ``*_pre_forward_dict`` vs.
    ``*_forward_dict`` presumably select pre-forward vs. forward hooks.
    TODO confirm against the code that consumes this object.

    Args:
        ablator: Object whose methods the hooks call; stored unchanged.
        vanilla_pre_forward_dict: Hook-dict factory stored as-is.
        vanilla_forward_dict: Hook-dict factory stored as-is.
        ablated_pre_forward_dict: Hook-dict factory stored as-is.
        ablated_forward_dict: Hook-dict factory stored as-is.
        vanilla_seed: Seed stored for the vanilla run (default 42, matching
            the previously hard-coded value).
        ablated_seed: Seed stored for the ablated run (default 42, matching
            the previously hard-coded value).
    """

    def __init__(
        self,
        ablator,
        vanilla_pre_forward_dict: Callable[[str, int], dict],
        vanilla_forward_dict: Callable[[str, int], dict],
        ablated_pre_forward_dict: Callable[[str, int], dict],
        ablated_forward_dict: Callable[[str, int], dict],
        *,
        vanilla_seed: int = 42,
        ablated_seed: int = 42,
    ):
        self.ablator = ablator
        self.vanilla_seed = vanilla_seed
        self.vanilla_pre_forward_dict = vanilla_pre_forward_dict
        self.vanilla_forward_dict = vanilla_forward_dict
        self.ablated_seed = ablated_seed
        self.ablated_pre_forward_dict = ablated_pre_forward_dict
        self.ablated_forward_dict = ablated_forward_dict
def get_edit(name: str, **kwargs):
    """Build the hook configuration for the edit called *name*.

    Only ``"edit_streams"`` is supported. It reads these keyword arguments:

    * ``stream`` (required): forwarded unchanged to ``ablator.edit_streams``.
    * ``layers`` (required): iterable of global layer indices to edit. An
      index below ``num_double_blocks`` addresses
      ``transformer.transformer_blocks.<layer>``. Any other index addresses
      ``transformer.single_transformer_blocks.<layer - num_double_blocks>``.
    * ``edit_fn`` (required): bound as ``partial(edit_fn, layer=layer)`` and
      passed to ``ablator.edit_streams`` as ``recompute_fn``.
    * ``num_double_blocks`` (optional, default 19): the split point above.

    Returns:
        An ``Ablation``. Its ``ablated_forward_dict`` factory ignores its
        arguments and always returns the same interventions dict. The other
        three factories return empty dicts.
        NOTE(review): ``Edit`` accepts the same constructor arguments.
        Confirm whether callers expect ``Edit`` rather than ``Ablation``.

    Raises:
        ValueError: if *name* is not a known edit.
        KeyError: if a required keyword argument is missing.
    """
    if name != "edit_streams":
        raise ValueError(f"Unknown edit: {name!r}")

    ablator = TransformerActivationCache()
    stream: str = kwargs["stream"]
    # Materialize once: the two passes below would otherwise exhaust a
    # generator on the first pass and silently drop all single-block hooks.
    layers = list(kwargs["layers"])
    edit_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = kwargs["edit_fn"]
    num_double_blocks: int = kwargs.get("num_double_blocks", 19)

    def make_hook(layer):
        # A factory gives each hook its own `layer`. A lambda written directly
        # in the comprehension would capture the loop variable late, and every
        # hook would call edit_fn with the last layer.
        recompute_fn = partial(edit_fn, layer=layer)
        return lambda *args: ablator.edit_streams(*args, recompute_fn=recompute_fn, stream=stream)

    # Double-stream blocks first, then single-stream blocks, preserving the
    # original insertion order of the interventions dict.
    interventions = {
        f"transformer.transformer_blocks.{layer}": make_hook(layer)
        for layer in layers
        if layer < num_double_blocks
    }
    interventions.update({
        f"transformer.single_transformer_blocks.{layer - num_double_blocks}": make_hook(layer)
        for layer in layers
        if layer >= num_double_blocks
    })

    return Ablation(
        ablator,
        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
        vanilla_forward_dict=lambda block_type, layer_num: {},
        ablated_pre_forward_dict=lambda block_type, layer_num: {},
        ablated_forward_dict=lambda block_type, layer_num: interventions,
    )
# NOTE: everything between the triple quotes below is a module-level string
# literal, so the code inside it is never executed or importable. It holds a
# disabled `get_ablation(name, **kwargs)` factory. The factory maps an
# ablation name to an `Ablation` built from four `(block_type, layer_num)`
# hook-dict factories keyed by module path (e.g.
# "transformer.{block_type}.{layer_num}"). Names that match no branch fall
# through and return None.
# NOTE(review): the text inside the string has lost its indentation, so it
# cannot be re-enabled verbatim. Its "edit_streams" branch also builds lambdas
# inside dict comprehensions that reference the comprehension variable
# `layer`, so every hook would see the last layer (late binding). Fix both
# before re-enabling.
"""
def get_ablation(name: str, **kwargs):
if name == "intermediate_text_stream_to_input":
ablator = TransformerActivationCache()
return Ablation(ablator,
vanilla_pre_forward_dict=lambda block_type, layer_num: {},
vanilla_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablator.cache_attention_activation(*args, full_output=True)},
ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.transformer_blocks.0": lambda *args: ablator.replace_stream_input(*args, stream="text")},
ablated_forward_dict=lambda block_type, layer_num: {})
elif name == "input_to_intermediate_text_stream":
ablator = TransformerActivationCache()
return Ablation(ablator,
vanilla_pre_forward_dict=lambda block_type, layer_num: {},
vanilla_forward_dict=lambda block_type, layer_num: {f"transformer.transformer_blocks.0": lambda *args: ablator.cache_attention_activation(*args, full_output=True)},
ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablator.replace_stream_input(*args, stream="text")},
ablated_forward_dict=lambda block_type, layer_num: {})
elif name == "set_input_text":
tensor: torch.Tensor = kwargs["tensor"]
ablator = TransformerActivationCache()
return Ablation(ablator,
vanilla_pre_forward_dict=lambda block_type, layer_num: {},
vanilla_forward_dict=lambda block_type, layer_num: {},
ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.0": lambda *args: ablator.replace_stream_input(*args, use_tensor=tensor, stream="text")},
ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.0": lambda *args: ablator.clamp_output(*args)})
elif name == "replace_text_stream_activation":
ablator = AttentionAblationCacheHook()
weight = kwargs["weight"] if "weight" in kwargs else 1.0
return Ablation(ablator,
vanilla_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.cache_text_stream},
vanilla_forward_dict=lambda block_type, layer_num: {},
ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.cache_and_inject_pre_forward},
ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablator.set_ablated_attention(*args, weight=weight)})
elif name == "replace_text_stream":
ablator = TransformerActivationCache()
weight = kwargs["weight"] if "weight" in kwargs else 1.0
return Ablation(ablator,
vanilla_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.cache_text_stream},
vanilla_forward_dict=lambda block_type, layer_num: {},
ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.cache_and_inject_pre_forward},
ablated_forward_dict=lambda block_type, layer_num: {})
elif name == "input=output":
return Ablation(None,
vanilla_pre_forward_dict=lambda block_type, layer_num: {},
vanilla_forward_dict=lambda block_type, layer_num: {},
ablated_pre_forward_dict=lambda block_type, layer_num: {},
ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablate_block(*args)})
elif name == "reweight_text_stream":
ablator = TransformerActivationCache()
residual_w=kwargs["residual_w"]
activation_w=kwargs["activation_w"]
return Ablation(ablator,
vanilla_pre_forward_dict=lambda block_type, layer_num: {},
vanilla_forward_dict=lambda block_type, layer_num: {},
ablated_pre_forward_dict=lambda block_type, layer_num: {},
ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablator.reweight_text_stream(*args, residual_w=residual_w, activation_w=activation_w)})
elif name == "add_input_text":
tensor: torch.Tensor = kwargs["tensor"]
ablator = TransformerActivationCache()
return Ablation(ablator,
vanilla_pre_forward_dict=lambda block_type, layer_num: {},
vanilla_forward_dict=lambda block_type, layer_num: {},
ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.0": lambda *args: ablator.add_text_stream_input(*args, use_tensor=tensor)},
ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.0": lambda *args: ablator.clamp_output(*args)})
elif name == "nothing":
ablator = TransformerActivationCache()
return Ablation(ablator,
vanilla_pre_forward_dict=lambda block_type, layer_num: {},
vanilla_forward_dict=lambda block_type, layer_num: {},
ablated_pre_forward_dict=lambda block_type, layer_num: {},
ablated_forward_dict=lambda block_type, layer_num: {})
elif name == "reweight_image_stream":
ablator = TransformerActivationCache()
residual_w=kwargs["residual_w"]
activation_w=kwargs["activation_w"]
return Ablation(ablator,
vanilla_pre_forward_dict=lambda block_type, layer_num: {},
vanilla_forward_dict=lambda block_type, layer_num: {},
ablated_pre_forward_dict=lambda block_type, layer_num: {},
ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablator.reweight_image_stream(*args, residual_w=residual_w, activation_w=activation_w)})
if name == "intermediate_image_stream_to_input":
ablator = TransformerActivationCache()
return Ablation(ablator,
vanilla_pre_forward_dict=lambda block_type, layer_num: {},
vanilla_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablator.cache_attention_activation(*args, full_output=True)},
ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.transformer_blocks.0": lambda *args: ablator.replace_stream_input(*args, stream='image')},
ablated_forward_dict=lambda block_type, layer_num: {})
elif name == "replace_text_stream_one_layer":
ablator = AttentionAblationCacheHook()
weight = kwargs["weight"] if "weight" in kwargs else 1.0
return Ablation(ablator,
vanilla_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.cache_text_stream},
vanilla_forward_dict=lambda block_type, layer_num: {},
ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.cache_and_inject_pre_forward},
ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.restore_text_stream})
elif name == "replace_intermediate_representation":
ablator = TransformerActivationCache()
tensor: torch.Tensor = kwargs["tensor"]
return Ablation(ablator,
vanilla_pre_forward_dict=lambda block_type, layer_num: {},
vanilla_forward_dict=lambda block_type, layer_num: {},
ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.single_transformer_blocks.0": lambda *args: ablator.replace_stream_input(*args, use_tensor=tensor, stream='text_image')},
ablated_forward_dict=lambda block_type, layer_num: {})
elif name == "destroy_registers":
ablator = TransformerActivationCache()
layers: List[int] = kwargs['layers']
k: float = kwargs["k"]
stream: str = kwargs['stream']
random: bool = kwargs["random"] if "random" in kwargs else False
lowest_norm: bool = kwargs["lowest_norm"] if "lowest_norm" in kwargs else False
return Ablation(ablator,
vanilla_pre_forward_dict=lambda block_type, layer_num: {},
vanilla_forward_dict=lambda block_type, layer_num: {},
ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.single_transformer_blocks.{i}": lambda *args: ablator.destroy_registers(*args, k=k, stream=stream, random_ablation=random, lowest_norm=lowest_norm) for i in layers},
ablated_forward_dict=lambda block_type, layer_num: {})
elif name == "patch_registers":
ablator = TransformerActivationCache()
layers: List[int] = kwargs['layers']
k: float = kwargs["k"]
stream: str = kwargs['stream']
random: bool = kwargs["random"] if "random" in kwargs else False
lowest_norm: bool = kwargs["lowest_norm"] if "lowest_norm" in kwargs else False
return Ablation(ablator,
vanilla_pre_forward_dict=lambda block_type, layer_num: {f"transformer.single_transformer_blocks.{i}": lambda *args: ablator.destroy_registers(*args, k=k, stream=stream, random_ablation=random, lowest_norm=lowest_norm) for i in layers},
vanilla_forward_dict=lambda block_type, layer_num: {},
ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.single_transformer_blocks.{i}": lambda *args: ablator.set_cached_registers(*args, k=k, stream=stream, random_ablation=random, lowest_norm=lowest_norm) for i in layers},
ablated_forward_dict=lambda block_type, layer_num: {})
elif name == "add_registers":
ablator = TransformerActivationCache()
num_registers: int = kwargs["num_registers"]
return Ablation(ablator,
vanilla_pre_forward_dict=lambda block_type, layer_num: {},
vanilla_forward_dict=lambda block_type, layer_num: {},
ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer": lambda *args: insert_extra_registers(*args, num_registers=num_registers)},
ablated_forward_dict=lambda block_type, layer_num: {f"transformer": lambda *args: discard_extra_registers(*args, num_registers=num_registers)},)
elif name == "edit_streams":
ablator = TransformerActivationCache()
stream: str = kwargs["stream"]
layers = kwargs["layers"]
edit_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = kwargs["edit_fn"]
interventions = {f"transformer.transformer_blocks.{layer}": lambda *args: ablator.edit_streams(*args, recompute_fn=partial(edit_fn, layer=layer), stream=stream) for layer in layers if layer < 19}
interventions.update({f"transformer.single_transformer_blocks.{layer - 19}": lambda *args: ablator.edit_streams(*args, recompute_fn=partial(edit_fn, layer=layer), stream=stream) for layer in layers if layer >= 19})
return Ablation(ablator,
vanilla_pre_forward_dict=lambda block_type, layer_num: {},
vanilla_forward_dict=lambda block_type, layer_num: {},
ablated_pre_forward_dict=lambda block_type, layer_num: {},
ablated_forward_dict=lambda block_type, layer_num: interventions,
)
"""