Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| from huggingface_hub import HfApi, HfFolder | |
| import datasets | |
| import pandas as pd | |
| import logging | |
| import os | |
| from transformers import AutoTokenizer, AutoModel | |
| import torch | |
| import torch.nn.functional as F | |
def login():
    """Save the Hugging Face token from the ``HF_TOKEN`` env var, once per session.

    The session flag ``st.session_state['logged']`` is only set after the
    token has been saved successfully. A failed attempt is therefore retried
    on the next Streamlit rerun instead of being reported as
    "Already logged in".

    Returns:
        True if a login was performed during this call; False if the session
        was already logged in or no token is configured.
    """
    if 'logged' in st.session_state:
        logging.info("Already logged in")
        return False

    logging.info("Trying to log in to HF")
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        # Nothing to save without a token (save_token(None) would fail).
        # Leave the flag unset so a later rerun can retry once it is configured.
        logging.warning("HF_TOKEN is not set; skipping Hugging Face login")
        return False

    # save_token is a static method on HfFolder, so no instance is needed.
    # The former HfApi().set_access_token call was deprecated and removed in
    # newer huggingface_hub releases, so it is no longer used.
    HfFolder.save_token(hf_token)
    st.session_state['logged'] = True
    logging.info("Succesfully logged")
    return True
def load_model(model_name='sentence-transformers/all-MiniLM-L6-v2'):
    """Load the sentence-embedding model and its matching tokenizer.

    Args:
        model_name: Hugging Face Hub id (or local path) of the checkpoint.
            The default is the MiniLM model previously hard-coded here. The
            FAISS index was presumably built with it, and query embeddings
            must come from the same model as the indexed ones (confirm
            before changing).

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    logging.info("Trying to load model")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    logging.info("Model loaded")
    return model, tokenizer
def load_index():
    """Load the articles dataset from disk and attach its FAISS index.

    The dataset is read from ``Data/articles.hf``. The prebuilt index in
    ``Data/articles.index`` is then registered under the name ``embedding``,
    so ``get_nearest_examples('embedding', ...)`` works on the result.

    Returns:
        The ``datasets.Dataset`` with the FAISS index loaded.
    """
    logging.info("Trying to load index")
    articles = datasets.Dataset.load_from_disk("Data/articles.hf")
    logging.info("Articles dataset loaded")

    # Register the on-disk FAISS index so nearest-neighbour queries can run.
    articles.load_faiss_index("embedding", "Data/articles.index")
    logging.info("FAISS index loaded")
    return articles
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the non-padding positions.

    Args:
        model_output: Model output whose first element holds per-token
            embeddings (assumed ``(batch, seq_len, hidden)``).
        attention_mask: 1 for real tokens and 0 for padding (assumed
            ``(batch, seq_len)``).

    Returns:
        The masked mean over dimension 1, i.e. ``(batch, hidden)`` for the
        shapes assumed above.
    """
    token_embeddings = model_output[0]
    # Broadcast the mask across the hidden dimension so padding adds nothing.
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = (token_embeddings * mask).sum(dim=1)
    # Clamp prevents division by zero for rows that are entirely padding.
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts
def get_embedding(query, model, tokenizer):
    """Encode text into L2-normalised, mean-pooled sentence embeddings.

    Args:
        query: Text to encode (passed straight to ``tokenizer``).
        model: Transformer whose first output holds token embeddings.
        tokenizer: Tokenizer matching ``model``.

    Returns:
        NumPy array with one unit-length (L2) embedding per row.
    """
    inputs = tokenizer(query, padding=True, truncation=True, return_tensors='pt')
    # Inference only: skip autograd bookkeeping so .numpy() works on the result.
    with torch.no_grad():
        outputs = model(**inputs)
    pooled = mean_pooling(outputs, inputs['attention_mask'])
    return F.normalize(pooled, p=2, dim=1).numpy()
def get_answers(query, num_answers):
    """Return the ``num_answers`` indexed articles nearest to ``query``.

    Loads the embedding model and the FAISS-indexed dataset, embeds the
    query, and searches the ``embedding`` index.

    NOTE(review): the model and index are reloaded on every call. Consider
    caching them (e.g. Streamlit's resource cache) if searches are frequent.

    Args:
        query: Free-text search query.
        num_answers: Number of nearest neighbours to retrieve.

    Returns:
        A list of dicts, one per retrieved article, keyed by dataset column.
    """
    logging.info("Getting answers for {}".format(query))
    model, tokenizer = load_model()
    index = load_index()
    # Flatten the single-row embedding batch into one query vector.
    query_vector = get_embedding(query, model, tokenizer).reshape(-1)
    _scores, examples = index.get_nearest_examples(
        'embedding', query_vector, num_answers)
    frame = pd.DataFrame.from_dict(examples)
    logging.info("Succesfully got answers for {}".format(query))
    return frame.to_dict('records')
def display_article(article):
    """Render one search result: linked title, abstract and a separator.

    Args:
        article: Mapping with at least ``'id'`` (used to build the
            arxiv.org/abs link), ``'title'`` and ``'abstract'`` keys.
    """
    import html  # function-scope import keeps this fix self-contained

    with st.container():
        href = "https://arxiv.org/abs/{}".format(article['id'])
        # The title is injected into raw HTML (unsafe_allow_html=True), so it
        # must be escaped. Titles containing '<' or '&' (e.g. inequalities)
        # would otherwise be parsed as markup and garble or hide the title.
        title = "<h3><a href=\"{}\">{}</a></h3>".format(
            html.escape(href, quote=True), html.escape(str(article['title'])))
        st.write(title, unsafe_allow_html=True)
        st.markdown(article['abstract'])
        st.write("---")
def display_answers(query, max_answers=100):
    """Search for ``query`` and render the leading results.

    Fetches up to ``max_answers`` neighbours. Only as many as
    ``st.session_state['num_articles_to_show']`` are displayed.

    Args:
        query: Free-text search query.
        max_answers: Number of neighbours to request from the index.
    """
    st.write("---")
    results = get_answers(query, max_answers)
    visible = results[:st.session_state['num_articles_to_show']]
    for result in visible:
        display_article(result)