Spaces:
Running
Running
| import torch | |
| from torch import nn | |
| from safetensors.torch import load_file | |
| from transformers import AutoModel, AutoTokenizer | |
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
| # Load the model state_dict from safetensors | |
| def load_model_safetensors(model, load_path="model.safetensors"): | |
| # Load the safetensors file | |
| state_dict = load_file(load_path) | |
| # Load the state dict into the model | |
| model.load_state_dict(state_dict) | |
| return model | |
| ########################## | |
| # JINA EMBEDDINGS | |
| ########################## | |
| # Jina Configs | |
| JINA_CONTEXT_LEN = 1024 | |
| # Adapter for embeddings | |
| class Adapter(nn.Module): | |
| def __init__(self, hidden_size): | |
| super(Adapter, self).__init__() | |
| self.down_project = nn.Linear(hidden_size, hidden_size // 2) | |
| self.activation = nn.ReLU() | |
| self.up_project = nn.Linear(hidden_size // 2, hidden_size) | |
| def forward(self, x): | |
| down = self.down_project(x) | |
| activated = self.activation(down) | |
| up = self.up_project(activated) | |
| return up + x # Residual connection | |
| # Pool by attention score | |
| class AttentionPooling(nn.Module): | |
| def __init__(self, hidden_size): | |
| super(AttentionPooling, self).__init__() | |
| self.attention_weights = nn.Parameter(torch.randn(hidden_size)) | |
| def forward(self, hidden_states): | |
| # hidden_states: [seq_len, batch_size, hidden_size] | |
| scores = torch.matmul(hidden_states, self.attention_weights) | |
| attention_weights = torch.softmax(scores, dim=0) | |
| weighted_sum = torch.sum(attention_weights.unsqueeze(-1) * hidden_states, dim=0) | |
| return weighted_sum | |
| # Custom bi-encoder model with MLP layers for interaction | |
| class CrossEncoderWithSharedBase(nn.Module): | |
| def __init__(self, base_model, num_labels=2, num_heads=8): | |
| super(CrossEncoderWithSharedBase, self).__init__() | |
| # Shared pre-trained model | |
| self.shared_encoder = base_model | |
| hidden_size = self.shared_encoder.config.hidden_size | |
| # Sentence-specific adapters | |
| self.adapter1 = Adapter(hidden_size) | |
| self.adapter2 = Adapter(hidden_size) | |
| # Cross-attention layers | |
| self.cross_attention_1_to_2 = nn.MultiheadAttention(hidden_size, num_heads) | |
| self.cross_attention_2_to_1 = nn.MultiheadAttention(hidden_size, num_heads) | |
| # Attention pooling layers | |
| self.attn_pooling_1_to_2 = AttentionPooling(hidden_size) | |
| self.attn_pooling_2_to_1 = AttentionPooling(hidden_size) | |
| # Projection layer with non-linearity | |
| self.projection_layer = nn.Sequential( | |
| nn.Linear(hidden_size * 2, hidden_size), | |
| nn.ReLU() | |
| ) | |
| # Classifier with three hidden layers | |
| self.classifier = nn.Sequential( | |
| nn.Linear(hidden_size, hidden_size // 2), | |
| nn.ReLU(), | |
| nn.Dropout(0.1), | |
| nn.Linear(hidden_size // 2, hidden_size // 4), | |
| nn.ReLU(), | |
| nn.Dropout(0.1), | |
| nn.Linear(hidden_size // 4, num_labels) | |
| ) | |
| def forward(self, input_ids1, attention_mask1, input_ids2, attention_mask2): | |
| # Encode sentences | |
| outputs1 = self.shared_encoder(input_ids1, attention_mask=attention_mask1) | |
| outputs2 = self.shared_encoder(input_ids2, attention_mask=attention_mask2) | |
| # Apply sentence-specific adapters | |
| embeds1 = self.adapter1(outputs1.last_hidden_state) | |
| embeds2 = self.adapter2(outputs2.last_hidden_state) | |
| # Transpose for attention layers | |
| embeds1 = embeds1.transpose(0, 1) | |
| embeds2 = embeds2.transpose(0, 1) | |
| # Cross-attention | |
| cross_attn_1_to_2, _ = self.cross_attention_1_to_2(embeds1, embeds2, embeds2) | |
| cross_attn_2_to_1, _ = self.cross_attention_2_to_1(embeds2, embeds1, embeds1) | |
| # Attention pooling | |
| pooled_1_to_2 = self.attn_pooling_1_to_2(cross_attn_1_to_2) | |
| pooled_2_to_1 = self.attn_pooling_2_to_1(cross_attn_2_to_1) | |
| # Concatenate and project | |
| combined = torch.cat((pooled_1_to_2, pooled_2_to_1), dim=1) | |
| projected = self.projection_layer(combined) | |
| # Classification | |
| logits = self.classifier(projected) | |
| return logits | |
| # Prediction function | |
| def embeddings_predict_relevance(sentence1, sentence2, model, tokenizer, device): | |
| model.eval() | |
| inputs1 = tokenizer(sentence1, return_tensors="pt", truncation=True, padding="max_length", max_length=1024) | |
| inputs2 = tokenizer(sentence2, return_tensors="pt", truncation=True, padding="max_length", max_length=1024) | |
| input_ids1 = inputs1['input_ids'].to(device) | |
| attention_mask1 = inputs1['attention_mask'].to(device) | |
| input_ids2 = inputs2['input_ids'].to(device) | |
| attention_mask2 = inputs2['attention_mask'].to(device) | |
| with torch.no_grad(): | |
| outputs = model(input_ids1=input_ids1, attention_mask1=attention_mask1, | |
| input_ids2=input_ids2, attention_mask2=attention_mask2) | |
| probabilities = torch.softmax(outputs, dim=1) | |
| predicted_label = torch.argmax(probabilities, dim=1).item() | |
| return predicted_label, probabilities.cpu().numpy() | |
| # Jina model | |
| JINA_MODEL_NAME = "jinaai/jina-embeddings-v2-small-en" | |
| jina_tokenizer = AutoTokenizer.from_pretrained(JINA_MODEL_NAME) | |
| jina_base_model = AutoModel.from_pretrained(JINA_MODEL_NAME) | |
| jina_model = CrossEncoderWithSharedBase(jina_base_model, num_labels=2) | |
| jina_model = load_model_safetensors(jina_model, load_path="models/jinaai-jina-embeddings-v2-small-en-TwinEncoder-OffTopic-Classifier-20240915-151858.safetensors") | |
| ########################## | |
| # CROSS-ENCODER | |
| ########################## | |
| # STSB Configs | |
| STSB_CONTEXT_LEN = 512 | |
| # ms-macro Configs | |
| MS_CONTEXT_LEN = 512 | |
| class CrossEncoderWithMLP(nn.Module): | |
| def __init__(self, base_model, num_labels=2): | |
| super(CrossEncoderWithMLP, self).__init__() | |
| # Existing cross-encoder model | |
| self.base_model = base_model | |
| # Hidden size of the base model | |
| hidden_size = base_model.config.hidden_size | |
| # MLP layers after combining the cross-encoders | |
| self.mlp = nn.Sequential( | |
| nn.Linear(hidden_size, hidden_size // 2), # Input: a single sentence | |
| nn.ReLU(), | |
| nn.Linear(hidden_size // 2, hidden_size // 4), # Reduce the size of the layer | |
| nn.ReLU() | |
| ) | |
| # Classifier head | |
| self.classifier = nn.Linear(hidden_size // 4, num_labels) | |
| def forward(self, input_ids, attention_mask): | |
| # Encode the pair of sentences in one pass | |
| outputs = self.base_model(input_ids, attention_mask) | |
| pooled_output = outputs.pooler_output | |
| # Pass the pooled output through mlp layers | |
| mlp_output = self.mlp(pooled_output) | |
| # Pass the final MLP output through the classifier | |
| logits = self.classifier(mlp_output) | |
| return logits | |
| def cross_encoder_predict_relevance(sentence1, sentence2, model, tokenizer, device): | |
| model.eval() | |
| # Tokenize the pair of sentences | |
| encoding = tokenizer( | |
| sentence1, sentence2, # Takes in a two sentences as a pair | |
| return_tensors="pt", | |
| truncation=True, | |
| padding="max_length", | |
| max_length=512, | |
| return_token_type_ids=False | |
| ) | |
| # Extract the input_ids and attention mask | |
| input_ids = encoding["input_ids"].to(device) | |
| attention_mask = encoding["attention_mask"].to(device) | |
| with torch.no_grad(): | |
| outputs = model( | |
| input_ids=input_ids, | |
| attention_mask=attention_mask | |
| ) # Returns logits | |
| # Convert raw logits into probabilities for each class and get the predicted label | |
| probabilities = torch.softmax(outputs, dim=1) | |
| predicted_label = torch.argmax(probabilities, dim=1).item() | |
| return predicted_label, probabilities.cpu().numpy() | |
| # STSB model | |
| STSB_MODEL_NAME = "cross-encoder/stsb-roberta-base" | |
| stsb_tokenizer = AutoTokenizer.from_pretrained(STSB_MODEL_NAME) | |
| stsb_base_model = AutoModel.from_pretrained(STSB_MODEL_NAME) | |
| stsb_model = CrossEncoderWithMLP(stsb_base_model, num_labels=2) | |
| stsb_model = load_model_safetensors(stsb_model, load_path="models/cross-encoder-stsb-roberta-base-CrossEncoder-OffTopic-Classifier-20240920-174009.safetensors") | |
| # MS model | |
| MS_MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2" | |
| ms_tokenizer = AutoTokenizer.from_pretrained(MS_MODEL_NAME) | |
| ms_base_model = AutoModel.from_pretrained(MS_MODEL_NAME) | |
| ms_model = CrossEncoderWithMLP(ms_base_model, num_labels=2) | |
| ms_model = load_model_safetensors(ms_model, load_path="models/cross-encoder-ms-marco-MiniLM-L-6-v2-CrossEncoder-OffTopic-Classifier-20240918-090615.safetensors") | |