Spaces:
Runtime error
Runtime error
| import logging | |
| import faiss | |
| import numpy as np | |
| from common.constants import DO_NORMALIZATION | |
| from components.embedding_extraction import EmbeddingExtractor | |
| logger = logging.getLogger(__name__) | |
class FaissVectorSearch:
    """Exact inner-product search over a fixed set of embeddings using FAISS.

    The index is built once in the constructor from an ``id -> embedding``
    mapping. Row ``i`` of the FAISS index corresponds to the ``i``-th key of
    that mapping (in iteration order); ``index_to_id`` stores this
    correspondence so that FAISS row numbers can be translated back to ids.
    """

    def __init__(
        self,
        model: EmbeddingExtractor,
        ids_to_embeddings: dict[str, np.ndarray],
    ):
        """Store the embedding model and build the search index.

        Args:
            model: Object used to embed the query text; its
                ``query_embed_extraction`` method is called on every search.
            ids_to_embeddings: Mapping from entity id to its embedding
                vector. May be empty, in which case no index is created.
        """
        self.model = model
        # FAISS row number -> original id, in the mapping's iteration order.
        self.index_to_id = dict(enumerate(ids_to_embeddings))
        self.__create_index(ids_to_embeddings)

    def __create_index(self, ids_to_embeddings: dict[str, np.ndarray]):
        """Create the FAISS index used for vector search.

        Sets ``self.index`` to ``None`` when there is nothing to index,
        otherwise to a ``faiss.IndexFlatIP`` holding all embeddings.
        """
        if not ids_to_embeddings:
            self.index = None
            return

        # FAISS operates on float32 matrices; cast explicitly so that
        # float64 (or other dtype) embeddings do not depend on the installed
        # FAISS version converting them for us.
        embeddings = np.array(
            list(ids_to_embeddings.values()),
            dtype=np.float32,
        )
        dim = embeddings.shape[1]
        self.index = faiss.IndexFlatIP(dim)
        self.index.add(embeddings)

    def search_vectors(
        self,
        query: str,
        max_entities: int = 100,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Search the index for the vectors closest to a text query.

        Args:
            query: Query string to search for.
            max_entities: Maximum number of entities to return.

        Returns:
            A tuple of three arrays:
                - the query embedding as returned by
                  ``model.query_embed_extraction`` (expected shape
                  ``(1, embedding_size)``);
                - inner-product scores of the hits, higher is better
                  (NOTE(review): these equal cosine similarity only if both
                  the stored and the query embeddings are L2-normalized;
                  this class does not normalize the stored ones);
                - ids of the hits, aligned with the scores.
            All three arrays are empty when the index is empty. Fewer than
            ``max_entities`` hits are returned when the index holds fewer
            vectors than requested.
        """
        logger.info("Searching vectors in index for query: %s", query)
        if self.index is None:
            return (np.array([]), np.array([]), np.array([]))

        query_embeds = self.model.query_embed_extraction(query, DO_NORMALIZATION)

        # Never ask for more neighbours than the index contains.
        k = min(max_entities, self.index.ntotal)
        similarities, indexes = self.index.search(query_embeds, k)

        # FAISS pads missing results with label -1; looking those up in
        # index_to_id would raise KeyError, so drop them (and their scores).
        found = indexes[0] >= 0
        ids = [self.index_to_id[int(index)] for index in indexes[0][found]]
        return query_embeds, similarities[0][found], np.array(ids)