Spaces:
Build error
Build error
| import os | |
| import time | |
| import pdfplumber | |
| import docx | |
| import nltk | |
| import gradio as gr | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from langchain_community.embeddings import ( | |
| #HuggingFaceEmbeddings, | |
| OpenAIEmbeddings, | |
| CohereEmbeddings, | |
| ) | |
| from langchain_openai import OpenAIEmbeddings | |
| from langchain_community.vectorstores import FAISS, Chroma | |
| from langchain_text_splitters import ( | |
| RecursiveCharacterTextSplitter, | |
| TokenTextSplitter, | |
| ) | |
| #from langchain.retrievers import ( | |
| # VectorStoreRetriever, | |
| # ContextualCompressionRetriever, | |
| #) | |
| from langchain.retrievers.document_compressors import LLMChainExtractor | |
| from langchain_community.llms import OpenAI | |
| from typing import List, Dict, Any | |
| import pandas as pd | |
# Folder that process_files() scans when it is called without an explicit file path.
FILES_DIR = './files'

# Fetch the NLTK "punkt" sentence-tokenizer data at import time.
# quiet=True is forwarded to the downloader.
nltk.download('punkt', quiet=True)
# Supported embedding models, grouped by provider.
#
# Outer key: provider name (the `model_type` accepted by get_embedding_model).
# Inner mapping: short display name -> identifier passed to the provider's
# embedding class.
#
# Hugging Face identifiers are full "<organisation>/<model>" repo ids unless
# the model is published under the sentence-transformers organisation, which
# is where a bare name is looked up. The e5, gte and gbert models live under
# other organisations and therefore need their prefix.
MODELS = {
    'HuggingFace': {
        'e5-base': "danielheinz/e5-base-sts-en-de",
        'multilingual-e5-base': "intfloat/multilingual-e5-base",
        'paraphrase-miniLM': "paraphrase-multilingual-MiniLM-L12-v2",
        'paraphrase-mpnet': "paraphrase-multilingual-mpnet-base-v2",
        'gte-large': "thenlper/gte-large",
        'gbert-base': "deepset/gbert-base",
    },
    'OpenAI': {
        'text-embedding-ada-002': "text-embedding-ada-002",
    },
    'Cohere': {
        'embed-multilingual-v2.0': "embed-multilingual-v2.0",
    },
}
class FileHandler:
    """Extract plain text from PDF, DOCX and TXT files.

    Every method is a static method, so the helpers can be called either on
    the class (``FileHandler.extract_text(path)``) or on an instance.
    """

    @staticmethod
    def extract_text(file_path):
        """Return the text of *file_path*, choosing the reader by extension.

        The extension comparison is case-insensitive.

        Args:
            file_path: Path to a ``.pdf``, ``.docx`` or ``.txt`` file.

        Returns:
            The extracted text as a single string.

        Raises:
            ValueError: If the extension is none of the supported ones.
        """
        ext = os.path.splitext(file_path)[-1].lower()
        if ext == '.pdf':
            return FileHandler._extract_from_pdf(file_path)
        elif ext == '.docx':
            return FileHandler._extract_from_docx(file_path)
        elif ext == '.txt':
            return FileHandler._extract_from_txt(file_path)
        else:
            raise ValueError(f"Unsupported file type: {ext}")

    @staticmethod
    def _extract_from_pdf(file_path):
        """Return the text of all PDF pages joined by single spaces."""
        with pdfplumber.open(file_path) as pdf:
            # `or ''` guards against a page yielding None instead of a string
            # (e.g. a page without extractable text), which would otherwise
            # make str.join raise TypeError.
            return ' '.join(page.extract_text() or '' for page in pdf.pages)

    @staticmethod
    def _extract_from_docx(file_path):
        """Return the text of all DOCX paragraphs joined by single spaces."""
        doc = docx.Document(file_path)
        return ' '.join(para.text for para in doc.paragraphs)

    @staticmethod
    def _extract_from_txt(file_path):
        """Return the full content of a UTF-8 encoded text file."""
        with open(file_path, 'r', encoding='utf-8') as f:
            return f.read()
def get_embedding_model(model_type, model_name):
    """Instantiate the embedding class that belongs to *model_type*.

    Args:
        model_type: Provider key: 'HuggingFace', 'OpenAI' or 'Cohere'.
        model_name: Display name looked up in ``MODELS[model_type]``.

    Returns:
        A HuggingFaceEmbeddings, OpenAIEmbeddings or CohereEmbeddings
        instance configured with the identifier stored in MODELS.

    Raises:
        ValueError: If *model_type* is not one of the three providers.
        KeyError: If *model_name* is not listed under that provider.
    """
    # Reject unknown providers before touching the MODELS table.
    if model_type not in ('HuggingFace', 'OpenAI', 'Cohere'):
        raise ValueError(f"Unsupported model type: {model_type}")

    model_id = MODELS[model_type][model_name]
    if model_type == 'HuggingFace':
        return HuggingFaceEmbeddings(model_name=model_id)
    if model_type == 'OpenAI':
        return OpenAIEmbeddings(model=model_id)
    return CohereEmbeddings(model=model_id)
def get_text_splitter(split_strategy, chunk_size, overlap_size, custom_separators=None):
    """Build the text splitter selected by *split_strategy*.

    Args:
        split_strategy: 'token' or 'recursive'.
        chunk_size: Passed to the splitter as ``chunk_size``.
        overlap_size: Passed to the splitter as ``chunk_overlap``.
        custom_separators: Separator list for the recursive strategy; when
            empty or None the list ["\\n\\n", "\\n", " ", ""] is used. It is
            ignored by the token strategy.

    Returns:
        A TokenTextSplitter or RecursiveCharacterTextSplitter.

    Raises:
        ValueError: If *split_strategy* is neither 'token' nor 'recursive'.
    """
    if split_strategy == 'recursive':
        if custom_separators:
            separator_list = custom_separators
        else:
            separator_list = ["\n\n", "\n", " ", ""]
        return RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=overlap_size,
            separators=separator_list,
        )

    if split_strategy == 'token':
        return TokenTextSplitter(chunk_size=chunk_size, chunk_overlap=overlap_size)

    raise ValueError(f"Unsupported split strategy: {split_strategy}")
def get_vector_store(store_type, texts, embedding_model):
    """Create a vector store of the requested type from *texts*.

    Args:
        store_type: 'FAISS' or 'Chroma'.
        texts: The text chunks handed to the store's ``from_texts``.
        embedding_model: Embedding object handed to ``from_texts``.

    Returns:
        Whatever ``FAISS.from_texts`` / ``Chroma.from_texts`` returns.

    Raises:
        ValueError: If *store_type* is neither 'FAISS' nor 'Chroma'.
    """
    if store_type not in ('FAISS', 'Chroma'):
        raise ValueError(f"Unsupported vector store type: {store_type}")

    store_cls = FAISS if store_type == 'FAISS' else Chroma
    return store_cls.from_texts(texts, embedding_model)
def get_retriever(vector_store, search_type, search_kwargs=None):
    """Return a retriever for *vector_store* using the given search type.

    Args:
        vector_store: Object exposing ``as_retriever(search_type=...,
            search_kwargs=...)``.
        search_type: 'similarity' or 'mmr'.
        search_kwargs: Optional dict of search options (e.g. ``{"k": 5}``).
            When omitted, an empty dict is passed on.

    Returns:
        Whatever ``vector_store.as_retriever`` returns.

    Raises:
        ValueError: If *search_type* is neither 'similarity' nor 'mmr'.
    """
    if search_type not in ('similarity', 'mmr'):
        raise ValueError(f"Unsupported search type: {search_type}")

    # Substitute an empty dict for the None default so the store is always
    # handed a mapping for search_kwargs, never None.
    effective_kwargs = {} if search_kwargs is None else search_kwargs
    return vector_store.as_retriever(search_type=search_type, search_kwargs=effective_kwargs)
def process_files(file_path, model_type, model_name, split_strategy, chunk_size, overlap_size, custom_separators):
    """Extract text, split it into chunks and build the embedding model.

    Args:
        file_path: Path of a single file to process. When falsy, every
            regular file directly inside FILES_DIR is processed instead.
        model_type: Provider key forwarded to get_embedding_model.
        model_name: Model display name forwarded to get_embedding_model.
        split_strategy: Strategy forwarded to get_text_splitter.
        chunk_size: Chunk size forwarded to get_text_splitter.
        overlap_size: Chunk overlap forwarded to get_text_splitter.
        custom_separators: Separator list (or None) forwarded to
            get_text_splitter.

    Returns:
        A ``(chunks, embedding_model)`` tuple, where ``chunks`` is the result
        of the splitter's ``split_text``.

    Raises:
        ValueError: Propagated from FileHandler.extract_text,
            get_text_splitter or get_embedding_model for unsupported input.
    """
    if file_path:
        text = FileHandler.extract_text(file_path)
    else:
        file_texts = []
        # Sorted so the chunk order does not depend on the arbitrary order
        # in which the operating system lists the directory.
        for entry in sorted(os.listdir(FILES_DIR)):
            candidate = os.path.join(FILES_DIR, entry)
            # Sub-directories cannot be read as documents; skip them.
            if not os.path.isfile(candidate):
                continue
            file_texts.append(FileHandler.extract_text(candidate))
        # A blank line between files keeps the last word of one file from
        # fusing with the first word of the next.
        text = "\n\n".join(file_texts)

    text_splitter = get_text_splitter(split_strategy, chunk_size, overlap_size, custom_separators)
    chunks = text_splitter.split_text(text)

    embedding_model = get_embedding_model(model_type, model_name)
    return chunks, embedding_model
def search_embeddings(chunks, embedding_model, vector_store_type, search_type, query, top_k):
    """Index *chunks* and run a single timed retrieval for *query*.

    The vector store is rebuilt from *chunks* on every call; only the
    retrieval itself is included in the reported time.

    Args:
        chunks: Text chunks to index.
        embedding_model: Embedding object used to build the vector store.
        vector_store_type: Store type forwarded to get_vector_store.
        search_type: Search type forwarded to get_retriever.
        query: The search query string.
        top_k: Number of results requested (passed as ``{"k": top_k}``).

    Returns:
        A ``(results, elapsed_seconds)`` tuple.
    """
    vector_store = get_vector_store(vector_store_type, chunks, embedding_model)
    retriever = get_retriever(vector_store, search_type, {"k": top_k})

    # perf_counter is a monotonic high-resolution clock, so the measured
    # interval cannot be skewed by system clock adjustments.
    started = time.perf_counter()
    # `invoke` is the current retriever entry point; it replaces the
    # deprecated `get_relevant_documents`.
    results = retriever.invoke(query)
    elapsed = time.perf_counter() - started

    return results, elapsed
def calculate_statistics(results, search_time):
    """Summarise a search result list.

    Args:
        results: Sequence of objects exposing a ``page_content`` string.
        search_time: Elapsed search time, stored unchanged.

    Returns:
        A dict with the keys ``num_results``, ``avg_content_length`` and
        ``search_time``. The average is 0.0 when *results* is empty.
    """
    num_results = len(results)
    total_length = sum(len(doc.page_content) for doc in results)
    # An empty result list would otherwise divide by zero.
    avg_length = total_length / num_results if num_results else 0.0
    return {
        "num_results": num_results,
        "avg_content_length": avg_length,
        "search_time": search_time,
    }
| import gradio as gr | |
| import pandas as pd | |
def compare_embeddings(file, query, model_types, model_names, split_strategy, chunk_size, overlap_size, custom_separators, vector_store_type, search_type, top_k):
    """Run the same query against several embedding models and collect results.

    Model types and model names are paired positionally with ``zip``, so the
    i-th name must belong to the i-th type; surplus entries in the longer
    list are ignored.

    Args:
        file: Uploaded file (an object with a ``.name`` path attribute or a
            plain path string), or a falsy value to use the default folder.
        query: The search query string.
        model_types: Sequence of provider keys.
        model_names: Sequence of model display names.
        split_strategy: Text splitting strategy.
        chunk_size: Chunk size for the splitter.
        overlap_size: Chunk overlap for the splitter.
        custom_separators: Comma-separated separator string, or empty.
        vector_store_type: Vector store type.
        search_type: Retrieval search type.
        top_k: Number of results per model.

    Returns:
        A two-element list ``[results_df, stats_df]``. ``results_df`` stacks
        the retrieved documents of every model and carries a ``Model`` column
        identifying the model each row came from; ``stats_df`` has one row
        per model. Two frames are returned regardless of how many models ran,
        matching the two Dataframe outputs of the interface that calls this
        function. Both frames are empty when no model pair was supplied.
    """
    # Accept both an upload object exposing `.name` and a bare path string.
    file_path = getattr(file, "name", file) if file else None
    separators = custom_separators.split(',') if custom_separators else None

    result_frames = []
    stats_frames = []
    for model_type, model_name in zip(model_types, model_names):
        chunks, embedding_model = process_files(
            file_path,
            model_type,
            model_name,
            split_strategy,
            chunk_size,
            overlap_size,
            separators,
        )
        results, search_time = search_embeddings(
            chunks,
            embedding_model,
            vector_store_type,
            search_type,
            query,
            top_k,
        )
        stats = calculate_statistics(results, search_time)
        stats["model"] = f"{model_type} - {model_name}"
        formatted_results, formatted_stats = format_results(results, stats)
        # Tag each row with its model so rows stay attributable once stacked.
        result_frames.append(formatted_results.assign(Model=stats["model"]))
        stats_frames.append(formatted_stats)

    # pd.concat raises on an empty list, so handle "no models" explicitly.
    if not result_frames:
        return [pd.DataFrame(), pd.DataFrame()]

    return [
        pd.concat(result_frames, ignore_index=True),
        pd.concat(stats_frames, ignore_index=True),
    ]
def format_results(results, stats):
    """Convert retrieved documents and their statistics into DataFrames.

    Args:
        results: Iterable of documents exposing ``page_content`` and a
            ``metadata`` mapping (or anything ``dict()`` accepts).
        stats: Dict of statistics for this search.

    Returns:
        A ``(documents_df, stats_df)`` tuple. ``documents_df`` has one row
        per document with a ``Content`` column followed by one column per
        metadata key; ``stats_df`` is a single-row frame built from *stats*.
    """
    # One record per document: the content first, then every metadata pair
    # (a metadata key named "Content" overrides the page content).
    rows = [
        {"Content": document.page_content, **dict(document.metadata)}
        for document in results
    ]
    documents_df = pd.DataFrame(rows)
    stats_df = pd.DataFrame([stats])
    return documents_df, stats_df
# Gradio interface
#
# The input components are listed in the same order as the parameters of
# compare_embeddings, which receives their values positionally.
_input_components = [
    gr.File(label="Upload File (Optional)"),
    gr.Textbox(label="Search Query"),
    gr.CheckboxGroup(
        choices=list(MODELS.keys()),
        label="Embedding Model Types",
        value=["HuggingFace"],
    ),
    gr.CheckboxGroup(
        # Flat list of every model display name across all providers.
        choices=[name for group in MODELS.values() for name in group],
        label="Embedding Models",
        value=["e5-base"],
    ),
    gr.Radio(choices=["token", "recursive"], label="Split Strategy", value="recursive"),
    gr.Slider(100, 1000, step=100, value=500, label="Chunk Size"),
    gr.Slider(0, 100, step=10, value=50, label="Overlap Size"),
    gr.Textbox(label="Custom Split Separators (comma-separated, optional)"),
    gr.Radio(choices=["FAISS", "Chroma"], label="Vector Store Type", value="FAISS"),
    gr.Radio(choices=["similarity", "mmr"], label="Search Type", value="similarity"),
    gr.Slider(1, 10, step=1, value=5, label="Top K"),
]

# Two tables: the retrieved documents and the per-search statistics.
_output_components = [
    gr.Dataframe(label="Results"),
    gr.Dataframe(label="Statistics"),
]

iface = gr.Interface(
    fn=compare_embeddings,
    inputs=_input_components,
    outputs=_output_components,
    title="Embedding Comparison Tool",
    description="Compare different embedding models and retrieval strategies",
)

# Start the web UI as soon as the module is executed.
iface.launch()