Spaces:
Runtime error
Runtime error
| import json | |
| import logging | |
| from collections import defaultdict | |
| from pathlib import Path | |
| from typing import Callable | |
| import numpy as np | |
| from components.embedding_extraction import EmbeddingExtractor | |
| from components.parser.abbreviations.abbreviation import Abbreviation | |
| from components.parser.abbreviations.structures import AbbreviationsCollection | |
| from components.parser.features.dataset_creator import DatasetCreator | |
| from components.parser.features.documents_dataset import DatasetRow, DocumentsDataset | |
| from components.parser.features.hierarchy_parser import Hierarchy, HierarchyParser | |
| from components.parser.paths import DatasetPaths | |
| from components.parser.xml import ParsedXMLs, XMLParser | |
| from components.parser.xml.constants import ACTUAL_STATUSES | |
| logger = logging.getLogger(__name__) | |
| class DatasetCreationPipeline: | |
| """ | |
| Пайплайн для обработки XML файлов со следующими шагами: | |
| 1. Парсинг XML файлов из директории | |
| 2. Извлечение аббревиатур из распаршенного контента | |
| 3. Применение аббревиатур к текстовому и табличному контенту | |
| 4. Обработка контента с помощью HierarchyParser | |
| 5. Создание и сохранение финального датасета | |
| """ | |
| def __init__( | |
| self, | |
| dataset_id: int, | |
| prepared_abbreviations: list[Abbreviation], | |
| document_ids: list[int], | |
| document_formats: list[str], | |
| datasets_path: Path, | |
| documents_path: Path, | |
| vectorizer: EmbeddingExtractor | None = None, | |
| save_intermediate_files: bool = False, | |
| old_dataset_id: int | None = None, | |
| ) -> None: | |
| """ | |
| Инициализация пайплайна. | |
| Args: | |
| dataset_id: Идентификатор датасета | |
| vectorizer: Векторизатор для создания эмбеддингов | |
| prepared_abbreviations: Датафрейм с аббревиатурами, извлечёнными ранее | |
| xml_ids: Список идентификаторов XML файлов | |
| save_intermediate_files: Флаг, указывающий, нужно ли сохранять промежуточные файлы | |
| old_dataset: Старый датасет, если он есть | |
| """ | |
| self.datasets_path = datasets_path | |
| self.documents_path = documents_path | |
| self.dataset_id = dataset_id | |
| self.paths = DatasetPaths( | |
| self.datasets_path / str(dataset_id), save_intermediate_files | |
| ) | |
| self.document_ids = document_ids | |
| self.document_formats = document_formats | |
| self.prepared_abbreviations = self._group_abbreviations(prepared_abbreviations) | |
| self.dataset_creator = DatasetCreator() | |
| self.vectorizer = vectorizer | |
| self.xml_parser = XMLParser() | |
| self.hierarchy_parser = HierarchyParser() | |
| self.abbreviations: AbbreviationsCollection | None = None | |
| self.info: ParsedXMLs | None = None | |
| self.dataset: DocumentsDataset | None = None | |
| self.old_paths = ( | |
| DatasetPaths( | |
| self.datasets_path / str(old_dataset_id), | |
| save_intermediate_files, | |
| ) | |
| if old_dataset_id | |
| else None | |
| ) | |
| logger.info(f'DatasetCreationPipeline initialized for {dataset_id}') | |
| def run( | |
| self, | |
| progress_callback: Callable[[int, int], None] | None = None, | |
| ) -> DocumentsDataset: | |
| """ | |
| Выполнение полного пайплайна обработки. | |
| Args: | |
| progress_callback: Функция, которая будет вызываться при каждом шаге векторизации. | |
| Принимает два аргумента: current и total. | |
| current - текущий шаг векторизации. | |
| total - общее количество шагов векторизации. | |
| Returns: | |
| DocumentsDataset: Векторизованный датасет. | |
| """ | |
| logger.info(f'Running pipeline for {self.dataset_id}') | |
| # Создание выходной директории | |
| Path(self.paths.root_path).mkdir(parents=True, exist_ok=True) | |
| logger.info('Folder created') | |
| logger.info('Processing XML files...') | |
| # Парсинг XML файлов | |
| parsed_xmls = self.process_xml_files() | |
| logger.info('XML files processed') | |
| logger.info('Saving XML info...') | |
| self.info = [xml.only_info() for xml in parsed_xmls.xmls] | |
| parsed_xmls.to_pandas().to_csv( | |
| self.paths.xml_info, | |
| index=False, | |
| ) | |
| logger.info('XML info saved') | |
| logger.info('Saving txt files...') | |
| # Сохранение промежуточных txt файлов | |
| if self.paths.save_intermediate_files: | |
| self._save_txt_files(parsed_xmls) | |
| logger.info('Txt files saved') | |
| logger.info('Processing abbreviations...') | |
| # Обработка аббревиатур | |
| self.abbreviations = self.process_abbreviations(parsed_xmls) | |
| logger.info('Abbreviations processed') | |
| logger.info('Saving abbreviations...') | |
| AbbreviationsCollection(self.abbreviations).to_pandas().to_csv( | |
| self.paths.abbreviations, | |
| index=False, | |
| ) | |
| logger.info('Abbreviations saved') | |
| logger.info('Saving txt files with abbreviations...') | |
| # Сохранение промежуточных txt файлов с применением аббревиатур | |
| if self.paths.save_intermediate_files: | |
| self._save_txt_files(parsed_xmls) | |
| logger.info('Txt files with abbreviations saved') | |
| logger.info('Extracting hierarchies...') | |
| hierarchies = self._extract_hierarchies(parsed_xmls) | |
| logger.info('Hierarchies extracted') | |
| logger.info('Saving hierarchies...') | |
| if self.paths.save_intermediate_files: | |
| self._save_hierarchies(hierarchies) | |
| logger.info('Hierarchies saved') | |
| logger.info('Creating dataset...') | |
| dataset = self.create_dataset(parsed_xmls, hierarchies) | |
| if self.vectorizer: | |
| logger.info('Vectorizing dataset...') | |
| dataset.vectorize_with( | |
| self.vectorizer, | |
| progress_callback=progress_callback, | |
| ) | |
| logger.info('Dataset vectorized') | |
| logger.info('Saving dataset...') | |
| dataset.to_pickle(self.paths.dataset) | |
| logger.info('Dataset saved') | |
| return dataset | |
| def process_xml_files(self) -> ParsedXMLs: | |
| """ | |
| Парсинг XML файлов из указанной директории. | |
| Возвращает: | |
| ParsedXMLs: Структура с данными из всех XML файлов | |
| """ | |
| parsed_xmls = [] | |
| for document_id, document_format in zip( | |
| self.document_ids, self.document_formats | |
| ): | |
| parsed_xml = XMLParser.parse( | |
| self.documents_path / f'{document_id}.{document_format}', | |
| include_content=True, | |
| ) | |
| if ('состав' in parsed_xml.name.lower()) or ( | |
| 'составы' in parsed_xml.name.lower() | |
| ): | |
| continue | |
| if parsed_xml.status not in ACTUAL_STATUSES: | |
| continue | |
| parsed_xml.id = document_id | |
| parsed_xmls.append(parsed_xml) | |
| return ParsedXMLs(parsed_xmls) | |
| def process_abbreviations( | |
| self, | |
| parsed_xmls: ParsedXMLs, | |
| ) -> list[Abbreviation]: | |
| """ | |
| Обработка и применение аббревиатур к контенту документов. | |
| Теперь аббревиатуры уже извлечены во время парсинга, этот метод: | |
| 1. Устанавливает document_id для извлеченных аббревиатур | |
| 2. Применяет только документно-специфичные аббревиатуры к соответствующим документам | |
| 3. Объединяет все аббревиатуры (извлеченные и предварительно подготовленные) для возврата | |
| Args: | |
| parsed_xmls: Структура с данными из всех XML файлов | |
| Returns: | |
| list[Abbreviation]: Список всех аббревиатур для датасета | |
| """ | |
| all_abbreviations = {} | |
| # Итерируем по документам | |
| for xml in parsed_xmls.xmls: | |
| # Устанавливаем document_id для извлеченных аббревиатур, если они есть | |
| doc_specific_abbreviations = [] | |
| if xml.abbreviations: | |
| for abbreviation in xml.abbreviations: | |
| abbreviation.document_id = xml.id | |
| doc_specific_abbreviations = xml.abbreviations | |
| # Применяем только аббревиатуры, извлеченные из этого документа | |
| if doc_specific_abbreviations: | |
| # Если есть аббревиатуры из документа, применяем их | |
| xml.apply_abbreviations(doc_specific_abbreviations) | |
| # Получаем подготовленные аббревиатуры для текущего документа | |
| prepared_abbr = self.prepared_abbreviations.get(xml.id, []) | |
| # Объединяем все аббревиатуры для возврата (не для применения) | |
| combined_abbr = (doc_specific_abbreviations or []) + prepared_abbr | |
| # Сохраняем объединенный список в document.abbreviations и в общем словаре | |
| if combined_abbr: | |
| xml.abbreviations = combined_abbr | |
| all_abbreviations[xml.id] = combined_abbr | |
| return self._ungroup_abbreviations(all_abbreviations) | |
| def _get_already_parsed_xmls( | |
| self, | |
| ) -> tuple[list[int], list[DatasetRow], list[np.ndarray]]: | |
| if self.old_paths: | |
| self.old_dataset = DocumentsDataset.from_pickle(self.old_paths.dataset) | |
| ids = set([int(row.DocNumber) for row in self.old_dataset.rows]) | |
| ids = ids.intersection(self.xml_ids) | |
| rows = [row for row in self.old_dataset.rows if row.DocNumber in ids] | |
| embs = [ | |
| emb | |
| for row, emb in zip(rows, self.old_dataset.vectors) | |
| if row.DocNumber in ids | |
| ] | |
| return ids, rows, embs | |
| return [], [], [] | |
| def _extract_hierarchies( | |
| self, | |
| parsed_xmls: ParsedXMLs, | |
| ) -> dict[int, tuple[Hierarchy, Hierarchy]]: | |
| """ | |
| Извлечение иерархических структур из текстового и табличного контента. | |
| Args: | |
| parsed_xmls: Структура с данными из всех XML файлов | |
| Returns: | |
| dict[int, tuple[Hierarchy, Hierarchy]]: Словарь иерархических структур для каждого документа | |
| """ | |
| hierarchies = {} | |
| for xml in parsed_xmls.xmls: | |
| doc_id = xml.id | |
| # Обработка текстового контента | |
| if xml.text: | |
| text_lines = xml.text.to_text().split('\n') | |
| self.hierarchy_parser.parse(text_lines, doc_id, '') | |
| text_hierarchy = self.hierarchy_parser.hierarchy() | |
| else: | |
| text_hierarchy = {} | |
| # Обработка табличного контента | |
| if xml.tables: | |
| table_lines = xml.tables.to_text().split('\n') | |
| self.hierarchy_parser.parse_table(table_lines, doc_id) | |
| table_hierarchy = self.hierarchy_parser.hierarchy() | |
| else: | |
| table_hierarchy = {} | |
| hierarchies[doc_id] = (text_hierarchy, table_hierarchy) | |
| return hierarchies | |
| def create_dataset( | |
| self, | |
| parsed_xmls: ParsedXMLs, | |
| hierarchies: dict[int, tuple[Hierarchy, Hierarchy]], | |
| ) -> DocumentsDataset: | |
| """ | |
| Создание финального датасета с векторизацией. | |
| Args: | |
| parsed_xmls: Структура с данными из всех XML файлов | |
| hierarchies: Словарь с иерархической структурой документов | |
| Returns: | |
| DocumentsDataset: Датасет с векторизованными текстами | |
| """ | |
| xmls = {xml.id: xml for xml in parsed_xmls.xmls} | |
| self.dataset = self.dataset_creator.create_dataset(xmls, hierarchies) | |
| return self.dataset | |
| def _group_abbreviations( | |
| self, | |
| abbreviations: list[Abbreviation], | |
| ) -> dict[int, list[Abbreviation]]: | |
| """ | |
| Преобразует список аббревиатур в словарь, где ключи - идентификаторы документов, а значения - списки аббревиатур. | |
| """ | |
| doc_to_abbreviations = defaultdict(list) | |
| for abbreviation in abbreviations: | |
| doc_to_abbreviations[abbreviation.document_id].append(abbreviation) | |
| return doc_to_abbreviations | |
| def _ungroup_abbreviations( | |
| self, abbreviations: dict[int, list[Abbreviation]] | |
| ) -> list[Abbreviation]: | |
| """ | |
| Преобразует словарь аббревиатур в список аббревиатур. | |
| """ | |
| return sum(abbreviations.values(), []) | |
| def _save_txt_files(self, parsed_xmls: ParsedXMLs) -> None: | |
| """ | |
| Сохранение текстового и табличного контента в текстовые файлы. | |
| """ | |
| self.paths.txt_path.mkdir(parents=True, exist_ok=True) | |
| for xml in parsed_xmls.xmls: | |
| with open(self.paths.txt_path / f'{xml.id}.txt', 'w', encoding='utf-8') as f: | |
| f.write(xml.text.to_text()) | |
| if xml.tables: | |
| with open(self.paths.txt_path / f'{xml.id}_table.txt', 'w', encoding='utf-8') as f: | |
| f.write(xml.tables.to_text()) | |
| def _save_hierarchies( | |
| self, | |
| hierarchies: dict[int, tuple[Hierarchy, Hierarchy]], | |
| ) -> None: | |
| """ | |
| Сохранение иерархий в JSON файлы. | |
| """ | |
| self.paths.jsons_path.mkdir(parents=True, exist_ok=True) | |
| for doc_id, (text_hierarchy, table_hierarchy) in hierarchies.items(): | |
| if text_hierarchy: | |
| with open(self.paths.jsons_path / f'{doc_id}.json', 'w', encoding='utf-8') as f: | |
| json.dump(text_hierarchy, f) | |
| if table_hierarchy: | |
| with open(self.paths.jsons_path / f'{doc_id}_table.json', 'w', encoding='utf-8') as f: | |
| json.dump(table_hierarchy, f) | |