# qdarnt / app.py
import gradio as gr
from datasets import load_dataset
from qdrant_client import QdrantClient, models
from sentence_transformers import SentenceTransformer
import torch # Ensure torch is imported
import os
import shutil
import PyPDF2
from docx import Document
import pandas as pd
# --- Configuration ---
QDRANT_PATH = "./qdrant_db"
COLLECTION_NAME = "my_text_collection"
MODEL_NAME = 'KaLM-Embedding/KaLM-embedding-multilingual-mini-instruct-v2.5' # Better model for semantic similarity
# --- Load Model ---
device = "cpu"
model = SentenceTransformer(MODEL_NAME, device=device)
# --- Qdrant Client and Collection Setup ---
qdrant_client = QdrantClient(path=QDRANT_PATH)
# Check if the collection already exists
collection_exists = False
try:
    collection_info = qdrant_client.get_collection(collection_name=COLLECTION_NAME)
    print("Collection already exists.")
    collection_exists = True
except Exception as e:
    print(f"Collection not found: {e}, creating a new one...")
    collection_exists = False
# If collection doesn't exist, create it and populate with data
if not collection_exists:
    # Load the dataset and convert it to a simple list format
    dataset = load_dataset("ag_news", split="test")
    # Convert the dataset to a pandas dataframe to access the text column directly
    df = dataset.to_pandas()
    data = df['text'].tolist()[:1000]  # Get the first 1000 text entries

    # Create the collection with the proper vector configuration.
    # Use the model's reported embedding size; fall back to 768 if it is unavailable.
    vector_size = model.get_sentence_embedding_dimension() or 768
    qdrant_client.create_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=models.VectorParams(size=vector_size, distance=models.Distance.COSINE),
    )

    # Generate embeddings manually to ensure compatibility
    print("Generating and indexing embeddings...")
    embeddings = model.encode(data)

    # Prepare points for insertion
    points = []
    for i, (text, embedding) in enumerate(zip(data, embeddings)):
        point = models.PointStruct(
            id=i,
            vector=embedding.tolist(),
            payload={"document": text}
        )
        points.append(point)

    # Upload the points to the collection
    qdrant_client.upsert(
        collection_name=COLLECTION_NAME,
        points=points
    )
    print("Embeddings indexed successfully.")
# --- Search Function ---
def search_in_qdrant(query):
    """Embed the query and return the top 5 most similar documents as markdown."""
    if not query:
        return "Please enter a search query."

    # Generate an embedding for the query
    query_embedding = model.encode([query])[0].tolist()

    hits = qdrant_client.search(
        collection_name=COLLECTION_NAME,
        query_vector=query_embedding,
        limit=5,
    )
    if not hits:
        return "No results found."

    results_text = ""
    for hit in hits:
        results_text += f"**Score:** {hit.score:.4f}\n"
        # Check that the payload exists and has the document key
        if hit.payload and 'document' in hit.payload:
            results_text += f"**Text:** {hit.payload['document']}\n\n"
        else:
            results_text += "**Text:** [No document content available]\n\n"
    return results_text
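# Illustrative usage (the query string below is an arbitrary example, not part of the app):
#   search_in_qdrant("space exploration")
#   -> markdown string listing up to 5 hits, each with its cosine-similarity score and text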
# --- Upload Function ---
def extract_text_from_file(file_path):
    """Extract text from various file types."""
    file_extension = file_path.lower().split('.')[-1]

    if file_extension == 'txt':
        with open(file_path, 'r', encoding='utf-8') as f:
            return f.read()
    elif file_extension == 'pdf':
        text = ""
        with open(file_path, 'rb') as f:
            pdf_reader = PyPDF2.PdfReader(f)
            for page in pdf_reader.pages:
                text += page.extract_text() + "\n"
        return text
    elif file_extension in ['docx', 'doc']:
        # Note: python-docx reads .docx files; legacy .doc files are likely to fail here.
        doc = Document(file_path)
        text = ""
        for paragraph in doc.paragraphs:
            text += paragraph.text + "\n"
        return text
    elif file_extension in ['csv', 'xlsx', 'xls']:
        if file_extension == 'csv':
            df = pd.read_csv(file_path)
        else:
            df = pd.read_excel(file_path)
        # Convert the entire dataframe to text
        return df.to_string()
    else:
        # Try to read the file as plain text
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                return f.read()
        except UnicodeDecodeError:
            # If UTF-8 fails, retry with a more permissive encoding
            try:
                with open(file_path, 'r', encoding='latin-1') as f:
                    return f.read()
            except Exception:
                return "Could not read file: unsupported format or encoding issue"
def upload_to_qdrant(text_content, file_upload=None):
    """Embed user-provided text and/or an uploaded file and add them to the collection."""
    if not text_content and not file_upload:
        return "Please provide text content or upload a file."

    documents_to_add = []
    # Add the text content if provided
    if text_content:
        documents_to_add.append(text_content)
    # Process the uploaded file if provided
    if file_upload:
        try:
            # Depending on the Gradio version, gr.File may pass a plain file path
            # string or a tempfile-like object with a .name attribute.
            file_path = file_upload if isinstance(file_upload, str) else file_upload.name
            content = extract_text_from_file(file_path)
            documents_to_add.append(content)
        except Exception as e:
            return f"Error reading file: {str(e)}"
    if not documents_to_add:
        return "No content to upload."

    # Pick the next available ID. For simplicity, use the collection's current
    # point count and number new points from there.
    max_id = 0  # Default to 0 if the count cannot be retrieved
    try:
        collection_info = qdrant_client.get_collection(collection_name=COLLECTION_NAME)
        if hasattr(collection_info, 'points_count') and collection_info.points_count is not None:
            max_id = collection_info.points_count
    except Exception:
        max_id = 0

    # Generate embeddings for the new documents
    embeddings = model.encode(documents_to_add)

    # Prepare points for insertion
    points = []
    for i, (doc, embedding) in enumerate(zip(documents_to_add, embeddings)):
        point_id = max_id + i + 1  # Start numbering after the existing points
        point = models.PointStruct(
            id=point_id,
            vector=embedding.tolist(),
            payload={"document": doc}
        )
        points.append(point)

    # Upload the points to the collection
    qdrant_client.upsert(
        collection_name=COLLECTION_NAME,
        points=points
    )
    return f"Successfully added {len(documents_to_add)} document(s) to the collection."
# --- Gradio Interface ---
with gr.Blocks() as demo:
    gr.Markdown("# Semantic Search with Qdrant and Gradio")
    gr.Markdown("Enter a query to search for similar news articles from the AG News dataset.")

    with gr.Tab("Search"):
        with gr.Row():
            search_input = gr.Textbox(label="Search Query", placeholder="e.g., 'Latest news on space exploration'")
            search_button = gr.Button("Search")
        search_output = gr.Markdown()
        search_button.click(search_in_qdrant, inputs=search_input, outputs=search_output)

    with gr.Tab("Upload"):
        with gr.Row():
            text_input = gr.Textbox(label="Text Content", placeholder="Enter text to add to the collection", lines=5)
        with gr.Row():
            file_input = gr.File(label="Or Upload a File", file_types=['.txt', '.pdf', '.docx', '.csv', '.xlsx', '.xls', '.md'])
        upload_button = gr.Button("Upload to Collection")
        upload_output = gr.Textbox(label="Upload Status", interactive=False)
        upload_button.click(upload_to_qdrant, inputs=[text_input, file_input], outputs=upload_output)

if __name__ == "__main__":
    demo.launch()