Spaces:
Sleeping
Sleeping
| import os | |
| import streamlit as st | |
| from elasticsearch import Elasticsearch | |
| from embedders.labse import LaBSE | |
def _fetch_adjacent_sentences(document, number, *, before):
    """Return the concatenated text of up to 3 sentences adjacent to sentence
    `number` within `document`, in reading order.

    Args:
        document: document id stored in the `document` field of the
            "sentences" index.
        number: the sentence number the context is anchored on.
        before: if True fetch the 3 sentences preceding `number`,
            otherwise the 3 following it.

    Returns:
        The adjacent sentences joined into a single string (may be empty
        when the sentence sits at a document boundary).
    """
    # Only the range bounds differ between the "previous" and "subsequent"
    # lookups; everything else is the same bool/term/range query.
    if before:
        number_range = {"gte": number - 3, "lt": number}
    else:
        number_range = {"gt": number, "lte": number + 3}
    response = es.search(
        index="sentences",
        query={
            "bool": {
                "must": [
                    {"term": {"document": document}},
                    {"range": {"number": number_range}},
                ]
            }
        },
    )
    # Elasticsearch does not guarantee hit order here; sort by sentence number
    # so the context reads in document order.
    hits = sorted(response["hits"]["hits"], key=lambda h: h["_source"]["number"])
    return "".join(h["_source"]["sentence"] for h in hits)


def _fetch_document_name(document):
    """Look up a document's display name ("title - author") in the
    "documents" index; returns the raw id as a fallback when not found."""
    response = es.search(
        index="documents",
        query={"bool": {"must": [{"term": {"id": document}}]}},
    )
    hits = response["hits"]["hits"]
    if not hits:
        # Missing metadata should not crash the whole results page.
        return str(document)
    source = hits[0]["_source"]
    return f"{source['title']} - {source['author']}"


def search():
    """Run a semantic search over the "sentences" index and render results.

    Reads the module-level Streamlit widgets (`model_name`, `query`, `limit`)
    and writes progress to `status_indicator` and results to
    `results_placeholder`. Scoring uses cosine similarity between the query
    embedding and the per-model feature vector stored on each sentence.
    """
    status_indicator.write(f"Loading model {model_name} (it can take ~1 minute the first time)...")
    # Model classes are looked up by name so new embedders only need an import.
    model = globals()[model_name]()
    status_indicator.write("Computing query embeddings...")
    query_vector = model(query)[0, :].tolist()
    status_indicator.write("Performing query...")
    target_field = f"{model_name}_features"
    # script_score with cosineSimilarity ranks every sentence by closeness to
    # the query embedding; +1.0 keeps scores non-negative as ES requires.
    results = es.search(
        index="sentences",
        query={
            "script_score": {
                "query": {"match_all": {}},
                "script": {
                    "source": f"cosineSimilarity(params.query_vector, '{target_field}') + 1.0",
                    "params": {"query_vector": query_vector},
                },
            }
        },
        size=limit,
    )
    for result in results["hits"]["hits"]:
        source = result["_source"]
        sentence = source["sentence"]
        score = result["_score"]
        document = source["document"]
        number = source["number"]
        previous_context = _fetch_adjacent_sentences(document, number, before=True)
        subsequent_context = _fetch_adjacent_sentences(document, number, before=False)
        document_name = _fetch_document_name(document)
        # Highlight the matched sentence in bold between its context windows.
        results_placeholder.markdown(
            f"#### {document_name} (score: {score:.2f})\n"
            f"{previous_context} **{sentence}** {subsequent_context}"
        )
    status_indicator.write("Results ready...")
# Elasticsearch client configured from the environment; ELASTIC_AUTH is
# expected as "user:password".
es = Elasticsearch(os.environ["ELASTIC_HOST"], basic_auth=os.environ["ELASTIC_AUTH"].split(":"))

# --- Streamlit UI -----------------------------------------------------------
st.header("Serica Intelligent Search")
st.write("Perform an intelligent search using a Sentence Embedding Transformer model on the SERICA database")
model_name = st.selectbox("Model", ["LaBSE"])
# BUG FIX: the second positional argument of st.number_input is min_value,
# so the original `st.number_input("Number of results", 10)` prevented users
# from requesting fewer than 10 results. The intended default value is 10.
limit = st.number_input("Number of results", min_value=1, value=10)
query = st.text_input("Query", value="")
status_indicator = st.empty()  # progress messages rendered in place
do_search = st.button("Search")
results_placeholder = st.container()
if do_search:
    search()