|
|
import gradio as gr |
|
|
from datasets import load_dataset |
|
|
from qdrant_client import QdrantClient, models |
|
|
from sentence_transformers import SentenceTransformer |
|
|
import torch |
|
|
import os |
|
|
import shutil |
|
|
import PyPDF2 |
|
|
from docx import Document |
|
|
import pandas as pd |
|
|
|
|
|
|
|
|
# --- Configuration -----------------------------------------------------------
# Filesystem location handed to QdrantClient(path=...) for its local storage.
QDRANT_PATH = "./qdrant_db"

# Name of the collection every search/upload in this script targets.
COLLECTION_NAME = "my_text_collection"

# Sentence-Transformers checkpoint used to embed both documents and queries.
MODEL_NAME = "KaLM-Embedding/KaLM-embedding-multilingual-mini-instruct-v2.5"

# --- Embedding model ---------------------------------------------------------
# Inference device is fixed to the CPU.
device = "cpu"
model = SentenceTransformer(MODEL_NAME, device=device)

# --- Vector store ------------------------------------------------------------
# Client opened against the local path above (no host/port is configured).
qdrant_client = QdrantClient(path=QDRANT_PATH)
|
|
|
|
|
|
|
|
# --- Collection bootstrap ----------------------------------------------------
# Probe for the collection: get_collection raising is taken to mean "missing".
collection_exists = False
try:
    collection_info = qdrant_client.get_collection(collection_name=COLLECTION_NAME)
except Exception as e:
    # NOTE(review): every failure (not only a missing collection) ends up here
    # and triggers a seeding attempt; narrow the exception type once the
    # client's "not found" error class is confirmed.
    print(f"Collection not found: {e}, creating a new one...")
    collection_exists = False
else:
    print("Collection already exists.")
    collection_exists = True

if not collection_exists:
    # Seed corpus: the first 1000 texts of the AG News "test" split.
    dataset = load_dataset("ag_news", split="test")
    df = dataset.to_pandas()
    data = df["text"].tolist()[:1000]

    # Embed BEFORE creating the collection. If encoding fails, nothing has been
    # created yet, so the next start-up retries seeding instead of finding an
    # empty collection and skipping it.
    print("Generating and indexing embeddings...")
    embeddings = model.encode(data)

    # Prefer the dimension reported by the model; otherwise fall back to the
    # width of the vectors that were actually produced (never a guessed size).
    vector_size = model.get_sentence_embedding_dimension() or len(embeddings[0])
    qdrant_client.create_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=models.VectorParams(
            size=vector_size,
            distance=models.Distance.COSINE,
        ),
    )

    # One point per document; the id is the document's position in `data`.
    points = [
        models.PointStruct(
            id=i,
            vector=embedding.tolist(),
            payload={"document": text},
        )
        for i, (text, embedding) in enumerate(zip(data, embeddings))
    ]

    try:
        qdrant_client.upsert(
            collection_name=COLLECTION_NAME,
            points=points,
        )
    except Exception:
        # Roll back so a half-initialised (empty) collection is not mistaken
        # for a fully seeded one on the next run.
        qdrant_client.delete_collection(collection_name=COLLECTION_NAME)
        raise
    print("Embeddings indexed successfully.")
|
|
|
|
|
|
|
|
|
|
|
def search_in_qdrant(query):
    """Return the closest stored documents for *query*, formatted as Markdown.

    Args:
        query: Free-text search string (typically from the Gradio textbox).

    Returns:
        A Markdown string with one ``Score`` / ``Text`` entry per hit (at most
        five), or a short message when the query is blank or nothing matched.
    """
    # None, "" and whitespace-only input are all treated as "no query":
    # embedding a blank string would only surface arbitrary neighbours.
    if not query or not query.strip():
        return "Please enter a search query."

    # encode() takes a batch; take the single resulting vector as a plain list.
    query_embedding = model.encode([query])[0].tolist()

    hits = qdrant_client.search(
        collection_name=COLLECTION_NAME,
        query_vector=query_embedding,
        limit=5,
    )
    if not hits:
        return "No results found."

    entries = []
    for hit in hits:
        # Points without a "document" payload still get listed, with a
        # placeholder in place of the text.
        if hit.payload and "document" in hit.payload:
            text = hit.payload["document"]
        else:
            text = "[No document content available]"
        entries.append(f"**Score:** {hit.score:.4f}\n**Text:** {text}\n\n")
    return "".join(entries)
|
|
|
|
|
|
|
|
def extract_text_from_file(file_path):
    """Extract plain text from a file, picking a reader by file extension.

    Supported extensions: ``txt`` (UTF-8), ``pdf`` (PyPDF2), ``docx``/``doc``
    (python-docx), ``csv``/``xlsx``/``xls`` (pandas, rendered with
    ``DataFrame.to_string()``). Any other extension is read as UTF-8 text,
    falling back to Latin-1 when UTF-8 decoding fails.

    Args:
        file_path: Path of the file to read.

    Returns:
        The extracted text. For the generic fallback only, the literal message
        ``"Could not read file: unsupported format or encoding issue"`` is
        returned if the Latin-1 retry fails as well.

    Raises:
        Whatever the selected reader raises (e.g. ``UnicodeDecodeError`` for a
        non-UTF-8 ``.txt`` file, ``OSError`` for an unreadable path).
    """
    # splitext only looks at the last path component, so an extension-less file
    # (e.g. one literally named "csv") or a dotted directory name is not
    # mistaken for a known type.
    file_extension = os.path.splitext(file_path)[1].lower().lstrip(".")

    if file_extension == "txt":
        with open(file_path, "r", encoding="utf-8") as f:
            return f.read()

    if file_extension == "pdf":
        with open(file_path, "rb") as f:
            pdf_reader = PyPDF2.PdfReader(f)
            # Defensive: a page with no extractable text (None or "") must not
            # break the concatenation; it contributes just its newline.
            return "".join(
                (page.extract_text() or "") + "\n" for page in pdf_reader.pages
            )

    if file_extension in ("docx", "doc"):
        # NOTE(review): python-docx targets .docx; legacy binary .doc files
        # will presumably raise here and be reported by the caller - confirm.
        doc = Document(file_path)
        return "".join(paragraph.text + "\n" for paragraph in doc.paragraphs)

    if file_extension in ("csv", "xlsx", "xls"):
        if file_extension == "csv":
            df = pd.read_csv(file_path)
        else:
            df = pd.read_excel(file_path)
        return df.to_string()

    # Unknown extension: best-effort plain-text read.
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            return f.read()
    except UnicodeDecodeError:
        try:
            with open(file_path, "r", encoding="latin-1") as f:
                return f.read()
        except Exception:
            # Narrowed from a bare except so KeyboardInterrupt/SystemExit are
            # no longer swallowed; the returned message is unchanged.
            return "Could not read file: unsupported format or encoding issue"
|
|
|
|
|
def upload_to_qdrant(text_content, file_upload=None):
    """Embed user-supplied text and/or an uploaded file and store them.

    Args:
        text_content: Raw text typed by the user; may be empty or None.
        file_upload: Uploaded file, either a filesystem path or an object
            exposing the path as ``.name``; may be None.

    Returns:
        A human-readable status string: a success message with the number of
        documents added, or an explanation of why nothing was added.
    """
    import uuid  # local import keeps this block self-contained

    if not text_content and not file_upload:
        return "Please provide text content or upload a file."

    documents_to_add = []

    # Whitespace-only text carries nothing worth indexing.
    if text_content and text_content.strip():
        documents_to_add.append(text_content)

    if file_upload:
        # Accept both a plain path and a file-like wrapper with a .name path.
        upload_path = getattr(file_upload, "name", file_upload)
        try:
            content = extract_text_from_file(upload_path)
        except Exception as e:
            return f"Error reading file: {str(e)}"
        # Skip files from which no text could be extracted.
        if content and content.strip():
            documents_to_add.append(content)

    if not documents_to_add:
        return "No content to upload."

    embeddings = model.encode(documents_to_add)

    # Random UUID ids instead of "points_count + i + 1": count-derived ids
    # collide with (and silently overwrite) existing points after deletions,
    # concurrent uploads, or a failed count lookup.
    points = [
        models.PointStruct(
            id=str(uuid.uuid4()),
            vector=embedding.tolist(),
            payload={"document": doc},
        )
        for doc, embedding in zip(documents_to_add, embeddings)
    ]

    qdrant_client.upsert(
        collection_name=COLLECTION_NAME,
        points=points,
    )

    return f"Successfully added {len(documents_to_add)} document(s) to the collection."
|
|
|
|
|
|
|
|
# --- Gradio UI ---------------------------------------------------------------
# NOTE(review): the source had lost its indentation; the nesting below is
# reconstructed from statement order - confirm the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("# Semantic Search with Qdrant and Gradio")
    gr.Markdown("Enter a query to search for similar news articles from the AG News dataset.")

    # Search tab: query box + button on one row, results rendered as Markdown.
    with gr.Tab("Search"):
        with gr.Row():
            search_input = gr.Textbox(
                label="Search Query",
                placeholder="e.g., 'Latest news on space exploration'",
            )
            search_button = gr.Button("Search")
        search_output = gr.Markdown()
        search_button.click(
            search_in_qdrant,
            inputs=search_input,
            outputs=search_output,
        )

    # Upload tab: free text and/or a file, with a read-only status box.
    with gr.Tab("Upload"):
        with gr.Row():
            text_input = gr.Textbox(
                label="Text Content",
                placeholder="Enter text to add to the collection",
                lines=5,
            )
        with gr.Row():
            file_input = gr.File(
                label="Or Upload a File",
                file_types=['.txt', '.pdf', '.docx', '.csv', '.xlsx', '.xls', '.md'],
            )
        upload_button = gr.Button("Upload to Collection")
        upload_output = gr.Textbox(label="Upload Status", interactive=False)
        upload_button.click(
            upload_to_qdrant,
            inputs=[text_input, file_input],
            outputs=upload_output,
        )

# Start the web app only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()