# MyFAISS: FAISS vector store extended with metadata filtering for MMR search.
import math
import os
import pickle
import uuid
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple

import numpy as np

from langchain.docstore.base import AddableMixin, Docstore
from langchain.docstore.document import Document
from langchain.docstore.in_memory import InMemoryDocstore
from langchain.embeddings.base import Embeddings
from langchain.vectorstores import FAISS
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance
class MyFAISS(FAISS):
    """FAISS vector store with metadata filtering applied to MMR search.

    Overrides the maximal-marginal-relevance entry points so that candidate
    documents can be pre-filtered by substring matches against their metadata
    before the MMR re-ranking step.
    """

    def max_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND
        diversity among selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch before filtering to
                pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.
            filter: Optional mapping of metadata key -> value. A candidate
                document is kept when any whitespace-separated word of any
                filter value occurs as a substring of the corresponding
                metadata field.

        Returns:
            List of Documents selected by maximal marginal relevance.
            Empty list when no candidate survives the filter.

        Raises:
            ValueError: If an index id cannot be resolved to a Document
                in the docstore.
        """
        # Over-fetch when filtering so enough candidates survive the filter.
        _, indices = self.index.search(
            np.array([embedding], dtype=np.float32),
            fetch_k if filter is None else fetch_k * 2,
        )
        if filter is not None:
            filtered_indices = []
            for i in indices[0]:
                if i == -1:
                    # -1 happens when not enough docs are returned.
                    continue
                _id = self.index_to_docstore_id[i]
                doc = self.docstore.search(_id)
                if not isinstance(doc, Document):
                    raise ValueError(f"Could not find document for id {_id}, got {doc}")
                # str(value) guards against non-string filter values, which
                # previously raised AttributeError on .split().
                if any(
                    filter_word in doc.metadata.get(key, '')
                    for key, value in filter.items()
                    for filter_word in str(value).split()
                ):
                    filtered_indices.append(i)
            indices = np.array([filtered_indices])
        # Drop -1 padding entries BEFORE reconstruction so that positions in
        # `embeddings` stay aligned with positions in `valid_indices`. The
        # previous code skipped -1 only while building `embeddings` but then
        # mapped MMR positions back through the unfiltered `indices[0]`,
        # which selected the wrong documents whenever -1 padding was present.
        valid_indices = [int(i) for i in indices[0] if i != -1]
        if not valid_indices:
            # Nothing matched the filter (or the index returned no hits);
            # running MMR on an empty candidate set would fail.
            return []
        embeddings = [self.index.reconstruct(i) for i in valid_indices]
        mmr_selected = maximal_marginal_relevance(
            np.array([embedding], dtype=np.float32),
            embeddings,
            k=k,
            lambda_mult=lambda_mult,
        )
        docs = []
        for pos in mmr_selected:
            _id = self.index_to_docstore_id[valid_indices[pos]]
            doc = self.docstore.search(_id)
            if not isinstance(doc, Document):
                raise ValueError(f"Could not find document for id {_id}, got {doc}")
            docs.append(doc)
        return docs

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND
        diversity among selected documents.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch before filtering (if
                needed) to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.
            filter: Optional metadata filter forwarded to
                ``max_marginal_relevance_search_by_vector``.

        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        # Embed the query, then delegate to the vector-based implementation.
        embedding = self.embedding_function(query)
        return self.max_marginal_relevance_search_by_vector(
            embedding,
            k,
            fetch_k,
            lambda_mult=lambda_mult,
            filter=filter,
            **kwargs,
        )