nazib61 committed on
Commit
11b36a2
·
verified ·
1 Parent(s): b44afbe

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -0
app.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from datasets import load_dataset
3
+ from qdrant_client import QdrantClient, models
4
+ from sentence_transformers import SentenceTransformer
5
+
# --- Configuration ---
QDRANT_HOST = "localhost"  # Or your Hugging Face Space Qdrant URL
QDRANT_PORT = 6333
COLLECTION_NAME = "my_text_collection"
MODEL_NAME = 'sentence-transformers/all-MiniLM-L6-v2'

# --- Load Dataset and Model ---
# Using a simple dataset from Hugging Face.
dataset = load_dataset("ag_news", split="test")
# Slice to 1000 articles for a quicker demo. Column access + slice avoids
# materializing every row's dict just to pull out one field.
data = dataset["text"][:1000]

# Load a pre-trained sentence transformer model.
model = SentenceTransformer(MODEL_NAME)

# --- Qdrant Client and Collection Setup ---
# In a Hugging Face Space, you might use a local in-memory instance or
# connect to a running Qdrant container.
qdrant_client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)

# Create the collection only if it does not already exist.
# recreate_collection() is deprecated and, worse, drops any previously
# indexed points on every restart; create_collection() is the intended API.
# collection_exists() replaces the earlier try/except-on-get_collection
# pattern, which swallowed *every* exception (including connection errors).
if qdrant_client.collection_exists(collection_name=COLLECTION_NAME):
    print("Collection already exists.")
else:
    print("Creating collection...")
    qdrant_client.create_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=models.VectorParams(
            size=model.get_sentence_embedding_dimension(),
            distance=models.Distance.COSINE,
        ),
    )

# --- Generate and Index Embeddings ---
print("Generating and indexing embeddings...")
batch_size = 128
for i in range(0, len(data), batch_size):
    batch_texts = data[i:i + batch_size]
    # encode() returns a NumPy array by default; convert_to_tensor=True was
    # wasted work since every vector is immediately converted to a plain list.
    embeddings = model.encode(batch_texts)

    qdrant_client.upsert(
        collection_name=COLLECTION_NAME,
        points=models.Batch(
            # Sequential integer ids: the batch starting at offset i gets
            # ids i .. i+len(batch)-1, so ids are stable across runs.
            ids=list(range(i, i + len(batch_texts))),
            vectors=embeddings.tolist(),
            payloads=[{"text": text} for text in batch_texts],
        ),
    )
print("Embeddings indexed successfully.")
# --- Search Function ---
def search_in_qdrant(query):
    """Embed a user query and return the top-5 nearest articles from Qdrant.

    Args:
        query: Free-text search string typed into the Gradio textbox.

    Returns:
        A Markdown-formatted string of score/text pairs, or a short
        human-readable message when the query is empty or nothing matched.
    """
    if not query:
        return "Please enter a search query."

    query_embedding = model.encode(query).tolist()

    search_result = qdrant_client.search(
        collection_name=COLLECTION_NAME,
        query_vector=query_embedding,
        limit=5,  # Return the top 5 most similar results
    )

    # Bug fix: the original returned "" when Qdrant had no hits, which
    # rendered as a silent blank panel in the UI.
    if not search_result:
        return "No results found."

    # Build the Markdown once with join() instead of repeated += in a loop.
    return "\n".join(
        f"**Score:** {hit.score:.4f}\n**Text:** {hit.payload['text']}\n"
        for hit in search_result
    )
# --- Gradio Interface ---
# Components are registered in render order inside the Blocks context, so
# the layout below is: title, description, query row, button, results panel.
with gr.Blocks() as demo:
    gr.Markdown("# Semantic Search with Qdrant and Gradio")
    gr.Markdown("Enter a query to search for similar news articles from the AG News dataset.")

    with gr.Row():
        search_input = gr.Textbox(label="Search Query")

    search_button = gr.Button("Search")
    search_output = gr.Markdown()

    # Wire the button to the search handler; the textbox value is the sole
    # input and the Markdown panel receives the formatted result string.
    search_button.click(
        fn=search_in_qdrant,
        inputs=search_input,
        outputs=search_output,
    )

if __name__ == "__main__":
    demo.launch()