from __future__ import annotations

import torch
import torch.nn.functional as F

from transformers import LlamaModel, LlamaConfig, PreTrainedTokenizer
from transformers.modeling_attn_mask_utils import AttentionMaskConverter


class DramaModel(LlamaModel):
    """
    DramaModel is a modified version of LlamaModel that supports bi-directional
    attention and provides query and document encoding functionality.
    """

    def __init__(self, config: LlamaConfig):
        """
        Initializes the DramaModel by disabling causal masking in the
        self-attention layers.
        """
        super().__init__(config)
        # Make every self-attention layer bi-directional.
        for layer in self.layers:
            layer.self_attn.is_causal = False

        # Prefix prepended to queries (but not documents) before encoding.
        self.query_prefix = "Query: "
        self.max_seq_len = 8192
        self.hidden_size = config.hidden_size

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_seen_tokens=None,
        output_attentions=False,
    ):
        """
        Builds the attention mask without applying a causal (lower-triangular)
        constraint, so attention stays bi-directional.
        """
        if self.config._attn_implementation == "flash_attention_2":
            # Flash attention takes the 2D padding mask directly; it is only
            # needed when at least one position is padded.
            if attention_mask is not None and (attention_mask == 0.0).any():
                return attention_mask
            return None
        if attention_mask is None or attention_mask.dim() == 4:
            # Either no mask is needed or a 4D mask was already provided.
            return attention_mask

        # Expand the 2D padding mask to a 4D additive mask; no causal triangle
        # is applied.
        return AttentionMaskConverter._expand_mask(
            mask=attention_mask,
            dtype=input_tensor.dtype,
        )

    def _average_pool(
        self, last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Computes the average-pooled representation of the last hidden states,
        ignoring padded positions.
        """
        # Zero out padded positions so they do not contribute to the mean.
        last_hidden = last_hidden_states.masked_fill(
            ~attention_mask[..., None].bool(), 0.0
        )
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

    def _tokenize(
        self,
        tokenizer: PreTrainedTokenizer,
        texts: list[str],
        max_seq_len: int | None = None,
    ):
        """
        Tokenizes input text sequences with an optional sequence-length limit.
        """
        if max_seq_len is None:
            max_seq_len = self.max_seq_len
        tokenized = tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=max_seq_len,
            return_tensors="pt",
        ).to(self.device)
        return tokenized

    def encode(self, input_ids, attention_mask, dim, *args, **kwargs):
        """
        Runs a forward pass through the model and computes L2-normalized embeddings.

        Args:
            input_ids (torch.Tensor): Input token IDs.
            attention_mask (torch.Tensor): Attention mask tensor.
            dim (int, optional): Dimensionality of the output embeddings; if None,
                the full hidden size is used.

        Returns:
            torch.Tensor: Normalized output embeddings of shape (batch_size, dim).
        """
        outputs = self.forward(
            input_ids, attention_mask, *args, **kwargs
        )
        # Keep only the first `dim` hidden dimensions before pooling
        # (slicing with dim=None keeps the full hidden size).
        embeddings = self._average_pool(
            outputs.last_hidden_state[:, :, :dim], attention_mask
        )

        embeddings = F.normalize(embeddings, p=2, dim=1)
        return embeddings

    def encode_queries(
        self,
        tokenizer: PreTrainedTokenizer,
        queries: list[str],
        max_seq_len: int | None = None,
        dim: int | None = None,
    ):
        """
        Encodes a list of queries into embeddings.

        Args:
            tokenizer (PreTrainedTokenizer): Tokenizer for text processing.
            queries (list[str]): List of query texts.
            max_seq_len (int, optional): Maximum sequence length.
            dim (int, optional): Dimensionality of the output embeddings.

        Returns:
            torch.Tensor: Encoded query embeddings of shape (num_queries, dim).
        """
        if not queries:
            raise ValueError("queries must not be empty.")
        if not isinstance(queries, list) or not all(isinstance(q, str) for q in queries):
            raise ValueError("queries must be a list of strings.")
        if tokenizer is None:
            raise ValueError("tokenizer must not be None.")
        if dim is not None and (dim < 1 or dim > self.hidden_size):
            raise ValueError(f"dim must be in range [1, {self.hidden_size}].")
        # Queries are prefixed so the model can distinguish them from documents.
        queries = [self.query_prefix + query for query in queries]
        tokenized_queries = self._tokenize(tokenizer, queries, max_seq_len)
        embeddings = self.encode(**tokenized_queries, dim=dim)
        return embeddings

    def encode_documents(
        self,
        tokenizer: PreTrainedTokenizer,
        documents: list[str],
        max_seq_len: int | None = None,
        dim: int | None = None,
    ):
        """
        Encodes a list of documents into embeddings.

        Args:
            tokenizer (PreTrainedTokenizer): Tokenizer for text processing.
            documents (list[str]): List of document texts.
            max_seq_len (int, optional): Maximum sequence length.
            dim (int, optional): Dimensionality of the output embeddings.

        Returns:
            torch.Tensor: Encoded document embeddings of shape (num_documents, dim).
        """
        if not documents:
            raise ValueError("documents must not be empty.")
        if not isinstance(documents, list) or not all(isinstance(d, str) for d in documents):
            raise ValueError("documents must be a list of strings.")
        if tokenizer is None:
            raise ValueError("tokenizer must not be None.")
        if dim is not None and (dim < 1 or dim > self.hidden_size):
            raise ValueError(f"dim must be in range [1, {self.hidden_size}].")
        # Documents are encoded without the query prefix.
        tokenized_documents = self._tokenize(tokenizer, documents, max_seq_len)
        embeddings = self.encode(**tokenized_documents, dim=dim)
        return embeddings
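
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the model definition). It shows how
# the encoding API above is typically exercised. The checkpoint path below is
# a placeholder assumption; substitute the DRAMA checkpoint you actually use.
# `dim=256` likewise assumes the checkpoint's hidden size is at least 256.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    checkpoint = "path/to/drama-checkpoint"  # placeholder, not a verified model id
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = DramaModel.from_pretrained(checkpoint).eval()

    queries = ["What is bi-directional attention?"]
    documents = ["Bi-directional attention lets every token attend to all other tokens."]

    with torch.no_grad():
        # dim=None would keep the full hidden size; a smaller value truncates
        # the hidden states before pooling.
        query_emb = model.encode_queries(tokenizer, queries, dim=256)
        doc_emb = model.encode_documents(tokenizer, documents, dim=256)

    # Embeddings are L2-normalized, so the dot product equals cosine similarity.
    scores = query_emb @ doc_emb.T
    print(scores)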