Federico Galatolo committed
Commit 168a4de · Parent(s): 9532cd7
first commit
Files changed:
- .gitignore +4 -0
- README.md +1 -3
- app.py +120 -0
- embedders/__pycache__/labse.cpython-38.pyc +0 -0
- embedders/labse.py +26 -0
- requirements.txt +19 -0
    	
.gitignore ADDED
@@ -0,0 +1,4 @@
+/env
+/__pycache__/
+
+.env
    	
README.md CHANGED
@@ -1,6 +1,6 @@
 ---
 title: Serica Semantic Search
-emoji: 
+emoji: π
 colorFrom: indigo
 colorTo: pink
 sdk: streamlit
@@ -9,5 +9,3 @@ app_file: app.py
 pinned: false
 license: agpl-3.0
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
    	
app.py ADDED
@@ -0,0 +1,120 @@
+import os
+import streamlit as st
+from elasticsearch import Elasticsearch
+
+from embedders.labse import LaBSE
+
+def search():
+    status_indicator.write(f"Loading model {model_name}...")
+    model = globals()[model_name]()
+
+    status_indicator.write(f"Computing query embeddings...")
+    query_vector = model(query)[0, :].tolist()
+
+    status_indicator.write(f"Performing query...")
+    target_field = f"{model_name}_features"
+    results = es.search(
+        index="sentences",
+        query={
+            "script_score": {
+                "query": {"match_all": {}},
+                "script": {
+                    "source": f"cosineSimilarity(params.query_vector, '{target_field}') + 1.0",
+                    "params": {"query_vector": query_vector}
+                }
+            }
+        },
+        size=limit
+    )
+
+    for result in results["hits"]["hits"]:
+        sentence = result['_source']['sentence']
+        score = result['_score']
+        document = result['_source']['document']
+        number = result['_source']['number']
+
+        previous = es.search(
+            index="sentences",
+            query={
+                "bool": {
+                    "must": [{
+                        "term": {
+                            "document": document
+                        }
+                    },{
+                        "range": {
+                            "number": {
+                                "gte": number-3,
+                                "lt": number,
+                            }
+                        }
+                    }
+                    ]
+                }
+            }
+        )
+
+        previous_hits = sorted(previous["hits"]["hits"], key=lambda e: e["_source"]["number"])
+        previous_context = "".join([r["_source"]["sentence"] for r in previous_hits])
+
+
+        subsequent = es.search(
+            index="sentences",
+            query={
+                "bool": {
+                    "must": [{
+                        "term": {
+                            "document": document
+                        }
+                    },{
+                        "range": {
+                            "number": {
+                                "lte": number+3,
+                                "gt": number,
+                            }
+                        }
+                    }
+                    ]
+                }
+            }
+        )
+
+        subsequent_hits = sorted(subsequent["hits"]["hits"], key=lambda e: e["_source"]["number"])
+        subsequent_context = "".join([r["_source"]["sentence"] for r in subsequent_hits])
+
+
+        document_name_results = es.search(
+            index="documents",
+            query={
+                "bool": {
+                    "must": [{
+                        "term": {
+                            "id": document
+                        }
+                    }
+                    ]
+                }
+            }
+        )
+
+        document_name_data = document_name_results["hits"]["hits"][0]["_source"]
+        document_name = f"{document_name_data['title']} - {document_name_data['author']}"
+
+        results_placeholder.markdown(f"#### {document_name} (score: {score:.2f})\n{previous_context} **{sentence}** {subsequent_context}")
+
+
+    status_indicator.write(f"Results ready...")
+
+es = Elasticsearch(os.environ["ELASTIC_HOST"], basic_auth=os.environ["ELASTIC_AUTH"].split(":"))
+
+st.header("Serica Semantic Search")
+st.write("Perform a semantic search using a Sentence Embedding Transformer model on the SERICA database")
+model_name = st.selectbox("Model", ["LaBSE"])
+limit = st.number_input("Number of results", 10)
+query = st.text_input("Query", value="")
+status_indicator = st.empty()
+do_search = st.button("Search")
+results_placeholder = st.container()
+
+if do_search:
+    search()
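Note: app.py ranks sentences with a script_score query whose source is cosineSimilarity(params.query_vector, '{target_field}') + 1.0; the +1.0 shifts the cosine into [0, 2] because Elasticsearch rejects negative scores. The commit does not create the indices themselves; the sketch below is an assumption of the minimal mappings the queries imply (field names taken from app.py, host is a placeholder), not part of this commit.

# Assumed index mappings (not in this commit), inferred from the queries in app.py.
from elasticsearch import Elasticsearch

es = Elasticsearch("http://localhost:9200")  # placeholder host

# "sentences": one entry per sentence, with the embedding stored as a dense_vector
# named f"{model_name}_features" so the script_score source string resolves to it.
es.indices.create(index="sentences", mappings={"properties": {
    "document": {"type": "keyword"},                           # matched with a term query
    "number": {"type": "integer"},                             # sentence position, used for the +/-3 context ranges
    "sentence": {"type": "text"},
    "LaBSE_features": {"type": "dense_vector", "dims": 768},   # LaBSE().dim == 768
}})

# "documents": metadata looked up to build the "title - author" result header.
es.indices.create(index="documents", mappings={"properties": {
    "id": {"type": "keyword"},
    "title": {"type": "text"},
    "author": {"type": "text"},
}})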
    	
embedders/__pycache__/labse.cpython-38.pyc ADDED
Binary file (1.27 kB).
    	
embedders/labse.py ADDED
@@ -0,0 +1,26 @@
+import torch
+from transformers import BertModel, BertTokenizerFast
+import torch.nn.functional as F
+
+class LaBSE:
+    def __init__(self):
+        self.tokenizer = BertTokenizerFast.from_pretrained("setu4993/LaBSE")
+        self.model = BertModel.from_pretrained("setu4993/LaBSE")
+        self.model.eval()
+
+    @torch.no_grad()
+    def __call__(self, sentences):
+        if not isinstance(sentences, list):
+            sentences = [sentences]
+        tokens = self.tokenizer(sentences, return_tensors="pt", padding=True)
+        outputs = self.model(**tokens)
+        embeddings = outputs.pooler_output
+        return F.normalize(embeddings, p=2).cpu().numpy()
+
+    @property
+    def dim(self):
+        return 768
+
+if __name__ == "__main__":
+    labse = LaBSE()
+    print(labse(["odi et amo", "quare id faciam"]).shape)
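Note: embedders/labse.py wraps the setu4993/LaBSE checkpoint: it tokenizes a batch of sentences, takes the pooler output, and returns L2-normalized vectors, so the cosineSimilarity scoring in app.py effectively reduces to a dot product of unit vectors. An ingestion script is not part of this commit; as a hedged sketch (document and field names assumed from app.py's queries), sentences could be pushed into Elasticsearch like this:

# Hypothetical ingest sketch, not included in this commit.
import os
from elasticsearch import Elasticsearch
from embedders.labse import LaBSE

es = Elasticsearch(os.environ["ELASTIC_HOST"], basic_auth=os.environ["ELASTIC_AUTH"].split(":"))
model = LaBSE()

def index_document(document_id, sentences):
    # "number" is the sentence position; app.py uses it to pull 3 sentences of context on each side.
    embeddings = model(sentences)
    for number, (sentence, vector) in enumerate(zip(sentences, embeddings)):
        es.index(index="sentences", document={
            "document": document_id,
            "number": number,
            "sentence": sentence,
            "LaBSE_features": vector.tolist(),  # matches target_field = f"{model_name}_features"
        })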
    	
requirements.txt ADDED
@@ -0,0 +1,19 @@
+certifi==2022.6.15
+charset-normalizer==2.1.0
+elastic-transport==8.1.2
+elasticsearch==8.3.3
+filelock==3.7.1
+huggingface-hub==0.8.1
+idna==3.3
+numpy==1.23.1
+packaging==21.3
+pyparsing==3.0.9
+PyYAML==6.0
+regex==2022.7.25
+requests==2.28.1
+tokenizers==0.12.1
+tqdm==4.64.0
+transformers==4.21.0
+typing-extensions==4.3.0
+urllib3==1.26.11
+torch==1.12.0