import torch

from backend.utils import load_embeddings, load_model, load_texts


# Search
def query_search(query: str, n_answers: int, model_name: str):
    """Embed the query, score it against the corpus embeddings with cosine
    similarity, and return the top-n matching corpus rows."""
    model = load_model(model_name)

    # Encode the query into a single embedding vector of shape (dim,)
    query_emb = model.encode(query, convert_to_tensor=True)

    print("Loading corpus embeddings and texts")
    corpus_emb = load_embeddings()
    corpus_texts = load_texts()

    # Cosine similarity between the query (broadcast to shape (1, dim))
    # and every corpus embedding (shape (n_corpus, dim)) -> one score per row
    hits = torch.nn.functional.cosine_similarity(
        query_emb[None, :], corpus_emb, dim=1, eps=1e-8
    )

    corpus_texts["Similarity"] = hits.tolist()
    print(corpus_texts)
    return corpus_texts.sort_values(by="Similarity", ascending=False).head(n_answers)[
        ["func_documentation_string", "repository_name", "func_code_url"]
    ]
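

# A minimal usage sketch, not part of the module above. Assumptions: load_model()
# accepts a sentence-transformers model name, load_texts() returns a pandas
# DataFrame with the columns selected in query_search(), and the model name and
# query string below are purely illustrative.
if __name__ == "__main__":
    results = query_search(
        query="read a csv file into a dataframe",  # hypothetical example query
        n_answers=5,
        model_name="all-MiniLM-L6-v2",  # hypothetical model name
    )
    print(results)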