import torch

from transformers.cache_utils import Cache
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel


class LlamaBidirectionalConfig(LlamaConfig):
    """Configuration for LlamaBidirectionalModel.

    Adds two options on top of LlamaConfig: `pooling`, the strategy used to
    pool token states into a single sequence embedding, and `temperature`,
    a scaling factor stored on the config for downstream scoring code.
    """

    model_type = "llama_bidirec"

    def __init__(self, pooling="avg", temperature=1.0, **kwargs):
        self.pooling = pooling
        self.temperature = temperature
        super().__init__(**kwargs)


class LlamaBidirectionalModel(LlamaModel):
    """LlamaModel variant with bidirectional (non-causal) self-attention."""

    config_class = LlamaBidirectionalConfig

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        # Disable the causal flag on every attention layer so each token can
        # attend to positions on both sides of it.
        for layer in self.layers:
            layer.self_attn.is_causal = False

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        # Override LlamaModel's mask construction: build a padding-only mask
        # instead of the default lower-triangular (causal) one.
        assert self.config._attn_implementation in [
            "flash_attention_2",
            "eager",
        ], (
            f"Unsupported attention implementation: "
            f"{self.config._attn_implementation}; "
            f"only flash_attention_2 and eager are supported"
        )

        if self.config._attn_implementation == "flash_attention_2":
            # Flash attention consumes the 2D padding mask directly; return
            # None when nothing is padded so the kernel can skip masking.
            if attention_mask is not None and (attention_mask == 0.0).any():
                return attention_mask
            return None
        elif self.config._attn_implementation == "eager":
            # Eager attention expects an additive 4D mask of shape
            # (batch, 1, query_len, key_len) with no causal component;
            # attention_mask is assumed to be provided here.
            causal_mask = _prepare_4d_attention_mask(
                attention_mask,
                dtype=input_tensor.dtype,
            )
            return causal_mask
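

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original module. It assumes a
    # transformers release whose LlamaModel.forward passes the five arguments
    # above to _update_causal_mask; all model sizes below are arbitrary, and
    # "eager" is requested because only eager and flash_attention_2 are
    # supported by the override.
    from transformers import AutoConfig, AutoModel

    # Optional: make the custom "llama_bidirec" model_type resolvable through
    # the Auto classes.
    AutoConfig.register("llama_bidirec", LlamaBidirectionalConfig)
    AutoModel.register(LlamaBidirectionalConfig, LlamaBidirectionalModel)

    config = LlamaBidirectionalConfig(
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
        vocab_size=1000,
        attn_implementation="eager",
    )
    model = LlamaBidirectionalModel(config).eval()

    input_ids = torch.randint(0, config.vocab_size, (2, 8))
    attention_mask = torch.ones(2, 8, dtype=torch.long)
    with torch.no_grad():
        hidden = model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state

    # pooling="avg": mean of the hidden states over non-padding positions.
    mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
    embeddings = (hidden * mask).sum(dim=1) / mask.sum(dim=1)
    print(embeddings.shape)  # torch.Size([2, 64])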