Update scripts/gte_embedding.py
Browse files- scripts/gte_embedding.py +32 -68
    	
        scripts/gte_embedding.py
    CHANGED
    
    | @@ -1,42 +1,25 @@ | |
| 1 | 
            -
             | 
| 2 | 
            -
             | 
| 3 | 
            -
             | 
| 4 | 
            -
             | 
| 5 | 
            -
            import heapq
         | 
| 6 | 
            -
            import json
         | 
| 7 | 
            -
            import logging
         | 
| 8 | 
            -
            import os
         | 
| 9 | 
            -
            import queue
         | 
| 10 | 
            -
            import sys
         | 
| 11 | 
            -
            import time
         | 
| 12 | 
            -
            from tqdm import tqdm
         | 
| 13 |  | 
| 14 | 
            -
            import torch
         | 
| 15 | 
             
            from collections import defaultdict
         | 
| 16 | 
            -
            from  | 
| 17 | 
            -
            import numpy as np
         | 
| 18 | 
            -
            import torch.distributed as dist
         | 
| 19 | 
            -
            from torch import nn, Tensor
         | 
| 20 | 
            -
            import torch.nn.functional as F
         | 
| 21 | 
            -
            from transformers import AutoModel, AutoTokenizer
         | 
| 22 | 
            -
            from transformers.file_utils import ModelOutput
         | 
| 23 |  | 
| 24 | 
            -
             | 
|  | |
|  | |
|  | |
| 25 |  | 
| 26 |  | 
| 27 | 
            -
            class GTEEmbeddidng(nn.Module):
         | 
| 28 | 
             
                def __init__(self,
         | 
| 29 | 
             
                             model_name: str = None,
         | 
| 30 | 
             
                             normalized: bool = True,
         | 
| 31 | 
            -
                             pooling_method: str = 'cls',
         | 
| 32 | 
             
                             use_fp16: bool = True,
         | 
| 33 | 
             
                             device: str = None
         | 
| 34 | 
             
                            ):
         | 
| 35 | 
             
                    super().__init__()
         | 
| 36 | 
            -
                    self.load_model(model_name)
         | 
| 37 | 
            -
                    self.vocab_size = self.model.config.vocab_size
         | 
| 38 | 
             
                    self.normalized = normalized
         | 
| 39 | 
            -
                    self.pooling_method = pooling_method
         | 
| 40 | 
             
                    if device:
         | 
| 41 | 
             
                        self.device = torch.device(device)
         | 
| 42 | 
             
                    else:
         | 
| @@ -49,40 +32,13 @@ class GTEEmbeddidng(nn.Module): | |
| 49 | 
             
                        else:
         | 
| 50 | 
             
                            self.device = torch.device("cpu")
         | 
| 51 | 
             
                            use_fp16 = False
         | 
| 52 | 
            -
                    self. | 
| 53 | 
            -
                    self.sparse_linear.to(self.device)
         | 
| 54 | 
            -
                    if use_fp16:
         | 
| 55 | 
            -
                        self.model.half()
         | 
| 56 | 
            -
                        self.sparse_linear.half()
         | 
| 57 | 
            -
             | 
| 58 | 
            -
                def load_model(self, model_name):
         | 
| 59 | 
            -
                    if not os.path.exists(model_name):
         | 
| 60 | 
            -
                        cache_folder = os.getenv('HF_HUB_CACHE')
         | 
| 61 | 
            -
                        model_name = snapshot_download(repo_id=model_name,
         | 
| 62 | 
            -
                                                       cache_dir=cache_folder,
         | 
| 63 | 
            -
                                                       ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
         | 
| 64 | 
            -
             | 
| 65 | 
            -
                    self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
         | 
| 66 | 
            -
                    self.sparse_linear = torch.nn.Linear(in_features=self.model.config.hidden_size, out_features=1)
         | 
| 67 | 
             
                    self.tokenizer = AutoTokenizer.from_pretrained(model_name)
         | 
| 68 | 
            -
                    self.model. | 
| 69 | 
            -
             | 
| 70 | 
            -
             | 
| 71 | 
            -
             | 
| 72 | 
            -
                     | 
| 73 | 
            -
                        logger.warring('The parameters of  sparse linear is not found')
         | 
| 74 | 
            -
             | 
| 75 | 
            -
                def dense_embedding(self, hidden_state, mask):
         | 
| 76 | 
            -
                    if self.pooling_method == 'cls':
         | 
| 77 | 
            -
                        return hidden_state[:, 0]
         | 
| 78 | 
            -
                    elif self.pooling_method == 'mean':
         | 
| 79 | 
            -
                        s = torch.sum(hidden_state * mask.unsqueeze(-1).float(), dim=1)
         | 
| 80 | 
            -
                        d = mask.sum(axis=1, keepdim=True).float()
         | 
| 81 | 
            -
                        return s / d
         | 
| 82 | 
            -
             | 
| 83 | 
            -
                def sparse_embedding(self, hidden_state, input_ids, return_embedding: bool = True):
         | 
| 84 | 
            -
                    token_weights = torch.relu(self.sparse_linear(hidden_state))
         | 
| 85 | 
            -
                    return token_weights
         | 
| 86 |  | 
| 87 | 
             
                def _process_token_weights(self, token_weights: np.ndarray, input_ids: list):
         | 
| 88 | 
             
                    # conver to dict
         | 
| @@ -127,7 +83,7 @@ class GTEEmbeddidng(nn.Module): | |
| 127 |  | 
| 128 | 
             
                @torch.no_grad()
         | 
| 129 | 
             
                def _encode(self,
         | 
| 130 | 
            -
                            texts: Dict[str, Tensor] = None,
         | 
| 131 | 
             
                            dimension: int = None,
         | 
| 132 | 
             
                            max_length: int = 1024,
         | 
| 133 | 
             
                            batch_size: int = 16,
         | 
| @@ -136,27 +92,22 @@ class GTEEmbeddidng(nn.Module): | |
| 136 |  | 
| 137 | 
             
                    text_input = self.tokenizer(texts, padding=True, truncation=True, return_tensors='pt', max_length=max_length)
         | 
| 138 | 
             
                    text_input = {k: v.to(self.model.device) for k,v in text_input.items()}
         | 
| 139 | 
            -
                     | 
| 140 |  | 
| 141 | 
             
                    output = {}
         | 
| 142 | 
             
                    if return_dense:
         | 
| 143 | 
            -
                        dense_vecs =  | 
| 144 | 
            -
                        dense_vecs = dense_vecs[:, :dimension]
         | 
| 145 | 
             
                        if self.normalized:
         | 
| 146 | 
             
                            dense_vecs = torch.nn.functional.normalize(dense_vecs, dim=-1)
         | 
| 147 | 
             
                        output['dense_embeddings'] = dense_vecs
         | 
| 148 | 
             
                    if return_sparse:
         | 
| 149 | 
            -
                        token_weights =  | 
| 150 | 
             
                        token_weights = list(map(self._process_token_weights, token_weights.detach().cpu().numpy().tolist(),
         | 
| 151 | 
             
                                                                text_input['input_ids'].cpu().numpy().tolist()))
         | 
| 152 | 
             
                        output['token_weights'] = token_weights
         | 
| 153 |  | 
| 154 | 
             
                    return output
         | 
| 155 |  | 
| 156 | 
            -
                def load_pooler(self, model_dir):
         | 
| 157 | 
            -
                    sparse_state_dict = torch.load(os.path.join(model_dir, 'sparse_linear.pt'), map_location='cpu')
         | 
| 158 | 
            -
                    self.sparse_linear.load_state_dict(sparse_state_dict)
         | 
| 159 | 
            -
                
         | 
| 160 | 
             
                def _compute_sparse_scores(self, embs1, embs2):
         | 
| 161 | 
             
                    scores = 0
         | 
| 162 | 
             
                    for token, weight in embs1.items():
         | 
| @@ -188,3 +139,16 @@ class GTEEmbeddidng(nn.Module): | |
| 188 | 
             
                        self.compute_sparse_scores(embs1['token_weights'], embs2['token_weights']) * sparse_weight
         | 
| 189 | 
             
                    scores = scores.tolist()
         | 
| 190 | 
             
                    return scores
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # coding=utf-8
         | 
| 2 | 
            +
            # Copyright 2024 The GTE Team Authors and Alibaba Group.
         | 
| 3 | 
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 4 |  | 
|  | |
| 5 | 
             
            from collections import defaultdict
         | 
| 6 | 
            +
            from typing import Dict, List, Tuple
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 7 |  | 
| 8 | 
            +
            import numpy as np
         | 
| 9 | 
            +
            import torch
         | 
| 10 | 
            +
            from transformers import AutoModelForTokenClassification, AutoTokenizer
         | 
| 11 | 
            +
            from transformers.utils import is_torch_npu_available
         | 
| 12 |  | 
| 13 |  | 
| 14 | 
            +
            class GTEEmbeddidng(torch.nn.Module):
         | 
| 15 | 
             
                def __init__(self,
         | 
| 16 | 
             
                             model_name: str = None,
         | 
| 17 | 
             
                             normalized: bool = True,
         | 
|  | |
| 18 | 
             
                             use_fp16: bool = True,
         | 
| 19 | 
             
                             device: str = None
         | 
| 20 | 
             
                            ):
         | 
| 21 | 
             
                    super().__init__()
         | 
|  | |
|  | |
| 22 | 
             
                    self.normalized = normalized
         | 
|  | |
| 23 | 
             
                    if device:
         | 
| 24 | 
             
                        self.device = torch.device(device)
         | 
| 25 | 
             
                    else:
         | 
|  | |
| 32 | 
             
                        else:
         | 
| 33 | 
             
                            self.device = torch.device("cpu")
         | 
| 34 | 
             
                            use_fp16 = False
         | 
| 35 | 
            +
                    self.use_fp16 = use_fp16
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 36 | 
             
                    self.tokenizer = AutoTokenizer.from_pretrained(model_name)
         | 
| 37 | 
            +
                    self.model = AutoModelForTokenClassification.from_pretrained(
         | 
| 38 | 
            +
                        model_name, trust_remote_code=True, torch_dtype=torch.float16 if self.use_fp16 else None
         | 
| 39 | 
            +
                    )
         | 
| 40 | 
            +
                    self.vocab_size = self.model.config.vocab_size
         | 
| 41 | 
            +
                    self.model.to(self.device)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 42 |  | 
| 43 | 
             
                def _process_token_weights(self, token_weights: np.ndarray, input_ids: list):
         | 
| 44 | 
             
                    # conver to dict
         | 
|  | |
| 83 |  | 
| 84 | 
             
                @torch.no_grad()
         | 
| 85 | 
             
                def _encode(self,
         | 
| 86 | 
            +
                            texts: Dict[str, torch.Tensor] = None,
         | 
| 87 | 
             
                            dimension: int = None,
         | 
| 88 | 
             
                            max_length: int = 1024,
         | 
| 89 | 
             
                            batch_size: int = 16,
         | 
|  | |
| 92 |  | 
| 93 | 
             
                    text_input = self.tokenizer(texts, padding=True, truncation=True, return_tensors='pt', max_length=max_length)
         | 
| 94 | 
             
                    text_input = {k: v.to(self.model.device) for k,v in text_input.items()}
         | 
| 95 | 
            +
                    model_out = self.model(**text_input, return_dict=True)
         | 
| 96 |  | 
| 97 | 
             
                    output = {}
         | 
| 98 | 
             
                    if return_dense:
         | 
| 99 | 
            +
                        dense_vecs = model_out.last_hidden_state[:, 0, :dimension]
         | 
|  | |
| 100 | 
             
                        if self.normalized:
         | 
| 101 | 
             
                            dense_vecs = torch.nn.functional.normalize(dense_vecs, dim=-1)
         | 
| 102 | 
             
                        output['dense_embeddings'] = dense_vecs
         | 
| 103 | 
             
                    if return_sparse:
         | 
| 104 | 
            +
                        token_weights = torch.relu(model_out.logits).squeeze(-1)
         | 
| 105 | 
             
                        token_weights = list(map(self._process_token_weights, token_weights.detach().cpu().numpy().tolist(),
         | 
| 106 | 
             
                                                                text_input['input_ids'].cpu().numpy().tolist()))
         | 
| 107 | 
             
                        output['token_weights'] = token_weights
         | 
| 108 |  | 
| 109 | 
             
                    return output
         | 
| 110 |  | 
|  | |
|  | |
|  | |
|  | |
| 111 | 
             
                def _compute_sparse_scores(self, embs1, embs2):
         | 
| 112 | 
             
                    scores = 0
         | 
| 113 | 
             
                    for token, weight in embs1.items():
         | 
|  | |
| 139 | 
             
                        self.compute_sparse_scores(embs1['token_weights'], embs2['token_weights']) * sparse_weight
         | 
| 140 | 
             
                    scores = scores.tolist()
         | 
| 141 | 
             
                    return scores
         | 
| 142 | 
            +
             | 
| 143 | 
            +
             | 
| 144 | 
            +
            if __name__ == '__main__':
         | 
| 145 | 
            +
                gte = GTEEmbeddidng('Alibaba-NLP/gte-multilingual-base')
         | 
| 146 | 
            +
                docs =  [
         | 
| 147 | 
            +
                    "黑龙江离俄罗斯很近",
         | 
| 148 | 
            +
                    "哈尔滨是中国黑龙江省的省会,位于中国东北",
         | 
| 149 | 
            +
                    "you are the hero"
         | 
| 150 | 
            +
                ]
         | 
| 151 | 
            +
                print('docs', docs)
         | 
| 152 | 
            +
                embs = gte.encode(docs, return_dense=True,return_sparse=True)
         | 
| 153 | 
            +
                print('dense vecs', embs['dense_embeddings'])
         | 
| 154 | 
            +
                print('sparse vecs', embs['token_weights'])
         | 

