# Hugging Face Space page header (status: Sleeping) — UI residue from the
# Space listing, not part of the application code below.
| import time | |
| import os | |
| from typing import Literal, Tuple | |
| import gradio as gr | |
| import torch | |
| from transformers import AutoModel, AutoTokenizer | |
| import meilisearch | |
# Embedding model: BGE base English v1.5, used to embed incoming queries.
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-base-en-v1.5")
model = AutoModel.from_pretrained("BAAI/bge-base-en-v1.5")
model.eval()  # inference only; disables dropout/batch-norm training behavior

# NOTE(review): CUDA availability is only reported, never acted on — the
# model stays on CPU. Confirm whether GPU placement was intended.
cuda_available = torch.cuda.is_available()
print(f"CUDA available: {cuda_available}")

# Meilisearch backend holding the pre-computed documentation embeddings.
# os.environ[...] raises KeyError at import time if MEILISEARCH_KEY is unset.
meilisearch_client = meilisearch.Client(
    "https://edge.meilisearch.com", os.environ["MEILISEARCH_KEY"]
)
meilisearch_index_name = "docs-embed"
meilisearch_index = meilisearch_client.index(meilisearch_index_name)

# Output formats offered in the UI radio selector.
output_options = ["RAG-friendly", "human-friendly"]
def search_embeddings(
    query_text: str, output_option: Literal["RAG-friendly", "human-friendly"]
) -> Tuple[str, str]:
    """Embed *query_text* and run a pure-semantic search against the docs index.

    Args:
        query_text: Free-form user query.
        output_option: ``"human-friendly"`` renders markdown with timing stats
            and a per-hit source link; ``"RAG-friendly"`` returns the raw hit
            texts joined by a delimiter, suitable for prompt stuffing.

    Returns:
        Tuple of ``(results_markdown, sources_markdown)`` where
        ``sources_markdown`` is a comma-separated list of markdown links to
        the source pages of all hits.

    Raises:
        ValueError: If *output_option* is not one of the two known options.
    """
    # step 1: tokenize and embed the query. BGE retrieval models expect this
    # instruction prefix prepended to search queries.
    start_time_embedding = time.time()
    query_prefix = "Represent this sentence for searching code documentation: "
    query_tokens = tokenizer(
        query_prefix + query_text,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=512,
    )
    with torch.no_grad():
        model_output = model(**query_tokens)
        # CLS-token embedding, L2-normalized (standard BGE pooling).
        sentence_embeddings = model_output[0][:, 0]
        sentence_embeddings = torch.nn.functional.normalize(
            sentence_embeddings, p=2, dim=1
        )
    sentence_embeddings_list = sentence_embeddings[0].tolist()
    elapsed_time_embedding = time.time() - start_time_embedding

    # step 2: vector search (semanticRatio=1.0 means no keyword matching;
    # the empty string query is required but unused).
    start_time_meilisearch = time.time()
    response = meilisearch_index.search(
        "",
        opt_params={
            "vector": sentence_embeddings_list,
            "hybrid": {"semanticRatio": 1.0},
            "limit": 5,
            "attributesToRetrieve": [
                "text",
                "source_page_url",
                "source_page_title",
                "library",
            ],
        },
    )
    elapsed_time_meilisearch = time.time() - start_time_meilisearch
    hits = response["hits"]
    sources_md = ", ".join(
        f"[\"{hit['source_page_title']}\"]({hit['source_page_url']})" for hit in hits
    )

    # step 3: present the results in markdown
    if output_option == "human-friendly":
        md = f"Stats:\n\nembedding time: {elapsed_time_embedding:.2f}s\n\nmeilisearch time: {elapsed_time_meilisearch:.2f}s\n\n---\n\n"
        for hit in hits:
            text, source_page_url, source_page_title = (
                hit["text"],
                hit["source_page_url"],
                hit["source_page_title"],
            )
            source = f'src: ["{source_page_title}"]({source_page_url})'
            md += text + f"\n\n{source}\n\n---\n\n"
        return md, sources_md
    elif output_option == "RAG-friendly":
        hit_texts = [hit["text"] for hit in hits]
        hit_text_str = "\n------------\n".join(hit_texts)
        return hit_text_str, sources_md
    # Previously an unrecognized option silently fell through and returned
    # None, which would break the gradio output binding — fail loudly instead.
    raise ValueError(f"unknown output_option: {output_option!r}")
# Gradio UI: one query textbox plus an output-format selector; two markdown
# panes show the search results and the aggregated source links.
query_box = gr.Textbox(
    label="enter your query", placeholder="Type Markdown here...", lines=10
)
format_choice = gr.Radio(
    label="Select an output option",
    choices=output_options,
    value="RAG-friendly",
)
demo = gr.Interface(
    fn=search_embeddings,
    inputs=[query_box, format_choice],
    outputs=[gr.Markdown(), gr.Markdown()],
    title="HF Docs Embeddings Explorer",
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()