import logging
import re
from logging import Logger
from pathlib import Path
from typing import Dict, List, Tuple

import pandas as pd
from elasticsearch.exceptions import ConnectionError
from natasha import Doc, MorphVocab, NewsEmbedding, NewsMorphTagger, Segmenter

from common.common import (
    get_elastic_abbreviation_query,
    get_elastic_group_query,
    get_elastic_people_query,
    get_elastic_query,
    get_elastic_rocks_nn_query,
    get_elastic_segmentation_query,
)
from common.configuration import Configuration, Query, SummaryChunks
from common.constants import PROMPT, PROMPT_CLASSIFICATION
from components.elastic import create_index_elastic_chunks
from components.elastic.elasticsearch_client import ElasticsearchClient
from components.embedding_extraction import EmbeddingExtractor
from components.nmd.aggregate_answers import aggregate_answers
from components.nmd.faiss_vector_search import FaissVectorSearch
from components.nmd.llm_chunk_search import LLMChunkSearch
from components.nmd.metadata_manager import MetadataManager
from components.nmd.query_classification import QueryClassification
from components.nmd.rancker import DocumentRanking
from components.services.dataset import DatasetService

logger = logging.getLogger(__name__)
class Dispatcher:
    """Routes user queries across vector (Faiss), full-text (Elasticsearch)
    and LLM-based search components."""

    def __init__(
        self,
        embedding_model: EmbeddingExtractor,
        config: Configuration,
        logger: Logger,
        dataset_service: DatasetService,
    ):
        self.dataset_service = dataset_service
        self.config = config
        self.embedder = embedding_model
        self.dataset_id = None
        self.try_load_default_dataset()
        self.llm_search = LLMChunkSearch(config.llm_config, PROMPT, logger)
        if self.config.db_config.elastic.use_elastic:
            self.elastic_search = ElasticsearchClient(
                host=config.db_config.elastic.es_host,
                port=config.db_config.elastic.es_port,
            )
        self.query_classification = QueryClassification(
            config.llm_config, PROMPT_CLASSIFICATION, logger
        )
        # Natasha components used for Russian morphological lemmatization.
        self.segmenter = Segmenter()
        self.morph_tagger = NewsMorphTagger(NewsEmbedding())
        self.morph_vocab = MorphVocab()
    def try_load_default_dataset(self):
        default_dataset = self.dataset_service.get_default_dataset()
        if default_dataset is None or default_dataset.id is None:
            self.faiss_search = None
            self.meta_database = None
        elif default_dataset.id != self.dataset_id:
            # Only reload when the default dataset actually changed;
            # an already-loaded dataset must not be wiped.
            logger.info(f'Reloading dataset {default_dataset.id}')
            self.reset_dataset(default_dataset.id)
    def reset_dataset(self, dataset_id: int):
        logger.info(f'Reset dataset to dataset_id: {dataset_id}')
        data_path = Path(self.config.db_config.faiss.path_to_metadata)
        dataset_file = data_path / str(dataset_id) / 'dataset.pkl'
        df = pd.read_pickle(dataset_file)
        logger.info(f'Dataset loaded from {dataset_file}')
        logger.info(f'Dataset shape: {df.shape}')
        self.faiss_search = FaissVectorSearch(self.embedder, df, self.config.db_config)
        logger.info('Faiss search initialized')
        self.meta_database = MetadataManager(df, logger)
        logger.info('Meta database initialized')
        if self.config.db_config.elastic.use_elastic:
            create_index_elastic_chunks(df, logger)
            logger.info('Elastic index created')
        self.document_ranking = DocumentRanking(df, self.config)
        logger.info('Document ranking initialized')
        # Remember which dataset is loaded so try_load_default_dataset()
        # does not reload it on every request.
        self.dataset_id = dataset_id
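    # Expected on-disk layout, inferred from the read above:
    #     <db_config.faiss.path_to_metadata>/<dataset_id>/dataset.pkl
    # where dataset.pkl is a pickled pandas DataFrame of chunk metadata.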
    def __vector_search(self, query: str) -> Dict[int, Dict]:
        """
        Searches the Faiss vector database for the nearest vectors.

        Args:
            query: The user's query.

        Returns:
            A dictionary of chunks.
        """
        query_embeds, scores, indexes = self.faiss_search.search_vectors(query)
        if self.config.db_config.ranker.use_ranging:
            indexes = self.document_ranking.doc_ranking(query_embeds, scores, indexes)
        return self.meta_database.search(indexes)
    def __elastic_search(
        self, query: str, index_name: str, search_function, size: int
    ) -> Dict:
        """
        Runs a full-text search against Elasticsearch.

        Args:
            query: The user's query.
            index_name: Name of the index.
            search_function: Query-building function; depends on the index to search.
            size: Number of nearest neighbors, i.e. the sample size.

        Returns:
            A dictionary with the answers.
        """
        self.elastic_search.set_index(index_name)
        return self.elastic_search.search(query=search_function(query), size=size)
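    # Each get_elastic_*_query helper passed as `search_function` is assumed
    # to turn the raw query string into an Elasticsearch query body, e.g.
    # (illustrative sketch only, not the actual implementation):
    #
    #     def get_elastic_query(query: str) -> dict:
    #         return {'match': {'text': query}}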
    @staticmethod
    def _get_indexes_full_text_elastic_search(elastic_answer: List[Dict]) -> List:
        """
        Extracts the indexes of the chunks found by Elasticsearch.

        Args:
            elastic_answer: Full-text search results over the chunks.

        Returns:
            A list of indexes.
        """
        answer = []
        for answer_dict in elastic_answer:
            answer.append(answer_dict['_source']['index'])
        return answer
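    # Each Elasticsearch hit is expected to look roughly like this (shape
    # inferred from the field access above):
    #     {'_source': {'index': 42, ...}, ...}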
    def _lemmatization_text(self, text: str) -> str:
        """Lemmatizes text with Natasha and returns the space-joined lemmas."""
        doc = Doc(text)
        doc.segment(self.segmenter)
        doc.tag_morph(self.morph_tagger)
        for token in doc.tokens:
            token.lemmatize(self.morph_vocab)
        return ' '.join([token.lemma for token in doc.tokens])
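    # Illustrative example (Russian morphology): 'реализация проектов'
    # lemmatizes to 'реализация проект', so the substring checks in
    # _get_abbreviations compare normalized word forms rather than
    # inflected surface forms.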
    def _get_abbreviations(self, query: Query) -> Query:
        """Expands abbreviations: looks up known phrases from the abbreviation
        index in the query and appends their abbreviation in parentheses."""
        query_abbreviation = query.query_abbreviation
        abbreviations_replaced = query.abbreviations_replaced
        try:
            if (
                self.config.db_config.elastic.use_elastic
                and self.config.db_config.search.abbreviation_search.use_abbreviation_search
            ):
                abbreviation_answer = self.__elastic_search(
                    query=query.query,
                    index_name=self.config.db_config.search.abbreviation_search.index_name,
                    search_function=get_elastic_abbreviation_query,
                    size=self.config.db_config.search.abbreviation_search.k_neighbors,
                )
                if len(abbreviation_answer) > 0:
                    query_lemmatization = self._lemmatization_text(query.query)
                    for abbreviation in abbreviation_answer:
                        abbreviation_lemmatization = self._lemmatization_text(
                            abbreviation['_source']['text'].lower()
                        )
                        if abbreviation_lemmatization in query_lemmatization:
                            query_abbreviation_lemmatization = (
                                self._lemmatization_text(query_abbreviation)
                            )
                            # re.escape guards against regex metacharacters
                            # in the indexed phrase; the match can be absent
                            # from query_abbreviation even when present in
                            # the raw query.
                            match = re.search(
                                re.escape(abbreviation_lemmatization),
                                query_abbreviation_lemmatization,
                            )
                            if match is None:
                                continue
                            index = match.span()[1]
                            space_index = query_abbreviation.find(' ', index)
                            if space_index != -1:
                                query_abbreviation = '{} ({}) {}'.format(
                                    query_abbreviation[:space_index],
                                    abbreviation['_source']['abbreviation'],
                                    query_abbreviation[space_index:],
                                )
                            else:
                                query_abbreviation = '{} ({})'.format(
                                    query_abbreviation,
                                    abbreviation['_source']['abbreviation'],
                                )
        except ConnectionError:
            logger.info('Connection Error Elasticsearch')
        return Query(
            query=query.query,
            query_abbreviation=query_abbreviation,
            abbreviations_replaced=abbreviations_replaced,
        )
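    # Illustrative example (assumed index contents): if the abbreviation
    # index maps the phrase 'Норильский никель' to 'НН', the query
    # 'сотрудники Норильский никель сегодня' is expanded to
    # 'сотрудники Норильский никель (НН) сегодня'.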
    def search_answer(self, query: Query) -> SummaryChunks:
        """
        Searches for chunks answering the user's question across the different search types.

        Args:
            query: The user's query.

        Returns:
            The chunks found for the user's query.
        """
        self.try_load_default_dataset()
        query = self._get_abbreviations(query)
        logger.info(f'Start search for {query.query_abbreviation}')
        logger.info(f'Use elastic search: {self.config.db_config.elastic.use_elastic}')
        answer = {}

        if self.config.db_config.search.vector_search.use_vector_search:
            logger.info('Start vector search.')
            answer['vector_answer'] = self.__vector_search(query.query_abbreviation)
            logger.info(f'Vector search found {len(answer["vector_answer"])} chunks')

        try:
            if self.config.db_config.elastic.use_elastic:
                if self.config.db_config.search.people_elastic_search.use_people_search:
                    logger.info('Start people search.')
                    people_answer = self.__elastic_search(
                        query.query,
                        index_name=self.config.db_config.search.people_elastic_search.index_name,
                        search_function=get_elastic_people_query,
                        size=self.config.db_config.search.people_elastic_search.k_neighbors,
                    )
                    logger.info(f'People search found {len(people_answer)} chunks')
                    answer['people_answer'] = people_answer

                if self.config.db_config.search.chunks_elastic_search.use_chunks_search:
                    logger.info('Start full text chunks search.')
                    chunks_answer = self.__elastic_search(
                        query.query,
                        index_name=self.config.db_config.search.chunks_elastic_search.index_name,
                        search_function=get_elastic_query,
                        size=self.config.db_config.search.chunks_elastic_search.k_neighbors,
                    )
                    indexes = self._get_indexes_full_text_elastic_search(chunks_answer)
                    chunks_answer = self.meta_database.search(indexes)
                    logger.info(
                        f'Full text chunks search found {len(chunks_answer)} chunks'
                    )
                    answer['chunks_answer'] = chunks_answer

                if self.config.db_config.search.groups_elastic_search.use_groups_search:
                    logger.info('Start groups search.')
                    groups_answer = self.__elastic_search(
                        query.query,
                        index_name=self.config.db_config.search.groups_elastic_search.index_name,
                        search_function=get_elastic_group_query,
                        size=self.config.db_config.search.groups_elastic_search.k_neighbors,
                    )
                    if len(groups_answer) != 0:
                        logger.info(f'Groups search found {len(groups_answer)} chunks')
                        answer['groups_answer'] = groups_answer

                if (
                    self.config.db_config.search.rocks_nn_elastic_search.use_rocks_nn_search
                ):
                    logger.info('Start Rocks NN search.')
                    rocks_nn_answer = self.__elastic_search(
                        query.query,
                        index_name=self.config.db_config.search.rocks_nn_elastic_search.index_name,
                        search_function=get_elastic_rocks_nn_query,
                        size=self.config.db_config.search.rocks_nn_elastic_search.k_neighbors,
                    )
                    if len(rocks_nn_answer) != 0:
                        logger.info(
                            f'Rocks NN search found {len(rocks_nn_answer)} chunks'
                        )
                        answer['rocks_nn_answer'] = rocks_nn_answer

                if (
                    self.config.db_config.search.segmentation_elastic_search.use_segmentation_search
                ):
                    logger.info('Start Segmentation search.')
                    segmentation_answer = self.__elastic_search(
                        query.query,
                        index_name=self.config.db_config.search.segmentation_elastic_search.index_name,
                        search_function=get_elastic_segmentation_query,
                        size=self.config.db_config.search.segmentation_elastic_search.k_neighbors,
                    )
                    if len(segmentation_answer) != 0:
                        logger.info(
                            f'Segmentation search found {len(segmentation_answer)} chunks'
                        )
                        answer['segmentation_answer'] = segmentation_answer
        except ConnectionError:
            logger.info('Connection Error Elasticsearch')

        final_answer = aggregate_answers(**answer)
        logger.info(f'Final answer found {len(final_answer)} chunks')
        return SummaryChunks(**final_answer)
    def llm_classification(self, query: str) -> str:
        return self.query_classification.classification(query)
    def llm_answer(
        self, query: str, answer_chunks: SummaryChunks
    ) -> Tuple[str, str, str, int]:
        """
        Picks the correct answer with the help of an LLM.

        Args:
            query: The query.
            answer_chunks: Answers from the vector and Elasticsearch searches.

        Returns:
            The original chunks from the searches, and the chunk selected by the model.
        """
        return self.llm_search.llm_chunk_search(query, answer_chunks, PROMPT)
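# Usage sketch (hypothetical wiring; the constructors of Configuration,
# EmbeddingExtractor and DatasetService are assumptions, not part of this
# module):
#
#     config = Configuration('config.yaml')
#     dispatcher = Dispatcher(
#         embedding_model=EmbeddingExtractor(config),
#         config=config,
#         logger=logging.getLogger(__name__),
#         dataset_service=DatasetService(config),
#     )
#     chunks = dispatcher.search_answer(query)   # query: common.configuration.Query
#     result = dispatcher.llm_answer(query.query, chunks)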