Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| """ | |
| Скрипт для оценки качества различных стратегий чанкинга. | |
| Сравнивает стратегии на основе релевантности чанков к вопросам. | |
| """ | |
| import argparse | |
| import json | |
| import os | |
| import sys | |
| from pathlib import Path | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| from fuzzywuzzy import fuzz | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| from tqdm import tqdm | |
| from transformers import AutoModel, AutoTokenizer | |
| # Константы для настройки | |
| DATA_FOLDER = "data/docs" # Путь к папке с документами | |
| MODEL_NAME = "intfloat/e5-base" # Название модели для векторизации | |
| DATASET_PATH = "data/dataset.xlsx" # Путь к Excel-датасету с вопросами | |
| BATCH_SIZE = 8 # Размер батча для векторизации | |
| DEVICE = "cuda:1" if torch.cuda.is_available() else "cpu" # Устройство для вычислений | |
| SIMILARITY_THRESHOLD = 0.7 # Порог для нечеткого сравнения | |
| OUTPUT_DIR = "data" # Директория для сохранения результатов | |
| TOP_CHUNKS_DIR = "data/top_chunks" # Директория для сохранения топ-чанков | |
| TOP_N_VALUES = [5, 10, 20, 30, 50, 70, 100] # Значения N для оценки | |
| # Параметры стратегий чанкинга | |
| FIXED_SIZE_CONFIG = { | |
| "words_per_chunk": 50, # Количество слов в чанке | |
| "overlap_words": 25 # Количество слов перекрытия | |
| } | |
| sys.path.insert(0, str(Path(__file__).parent.parent)) | |
| from ntr_fileparser import UniversalParser | |
| from ntr_text_fragmentation import Destructurer | |
| def _average_pool( | |
| last_hidden_states: torch.Tensor, attention_mask: torch.Tensor | |
| ) -> torch.Tensor: | |
| """ | |
| Расчёт усредненного эмбеддинга по всем токенам | |
| Args: | |
| last_hidden_states: Матрица эмбеддингов отдельных токенов размерности (batch_size, seq_len, embedding_size) - последний скрытый слой | |
| attention_mask: Маска, чтобы не учитывать при усреднении пустые токены | |
| Returns: | |
| torch.Tensor - Усредненный эмбеддинг размерности (batch_size, embedding_size) | |
| """ | |
| last_hidden = last_hidden_states.masked_fill( | |
| ~attention_mask[..., None].bool(), 0.0 | |
| ) | |
| return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] | |
| def parse_args(): | |
| """ | |
| Парсит аргументы командной строки. | |
| Returns: | |
| Аргументы командной строки | |
| """ | |
| parser = argparse.ArgumentParser(description="Скрипт для оценки качества чанкинга") | |
| parser.add_argument("--data-folder", type=str, default=DATA_FOLDER, | |
| help=f"Путь к папке с документами (по умолчанию: {DATA_FOLDER})") | |
| parser.add_argument("--model-name", type=str, default=MODEL_NAME, | |
| help=f"Название модели для векторизации (по умолчанию: {MODEL_NAME})") | |
| parser.add_argument("--dataset-path", type=str, default=DATASET_PATH, | |
| help=f"Путь к Excel-датасету с вопросами (по умолчанию: {DATASET_PATH})") | |
| parser.add_argument("--batch-size", type=int, default=BATCH_SIZE, | |
| help=f"Размер батча для векторизации (по умолчанию: {BATCH_SIZE})") | |
| parser.add_argument("--similarity-threshold", type=float, default=SIMILARITY_THRESHOLD, | |
| help=f"Порог для нечеткого сравнения (по умолчанию: {SIMILARITY_THRESHOLD})") | |
| parser.add_argument("--output-dir", type=str, default=OUTPUT_DIR, | |
| help=f"Директория для сохранения результатов (по умолчанию: {OUTPUT_DIR})") | |
| parser.add_argument("--force-recompute", action="store_true", | |
| help="Принудительно пересчитать эмбеддинги, игнорируя сохраненные") | |
| parser.add_argument("--use-sentence-transformers", action="store_true", | |
| help="Использовать библиотеку sentence_transformers для извлечения эмбеддингов (для FRIDA и других моделей)") | |
| parser.add_argument("--device", type=str, default=DEVICE, | |
| help=f"Устройство для вычислений (по умолчанию: {DEVICE})") | |
| # Параметры для fixed_size стратегии | |
| parser.add_argument("--words-per-chunk", type=int, default=FIXED_SIZE_CONFIG["words_per_chunk"], | |
| help=f"Количество слов в чанке для fixed_size стратегии (по умолчанию: {FIXED_SIZE_CONFIG['words_per_chunk']})") | |
| parser.add_argument("--overlap-words", type=int, default=FIXED_SIZE_CONFIG["overlap_words"], | |
| help=f"Количество слов перекрытия для fixed_size стратегии (по умолчанию: {FIXED_SIZE_CONFIG['overlap_words']})") | |
| return parser.parse_args() | |
| def read_documents(folder_path: str) -> dict: | |
| """ | |
| Читает все документы из указанной папки. | |
| Args: | |
| folder_path: Путь к папке с документами | |
| Returns: | |
| Словарь {имя_файла: parsed_document} | |
| """ | |
| print(f"Чтение документов из {folder_path}...") | |
| parser = UniversalParser() | |
| documents = {} | |
| for file_path in tqdm(list(Path(folder_path).glob("*.docx")), desc="Чтение документов"): | |
| try: | |
| doc_name = file_path.stem | |
| documents[doc_name] = parser.parse_by_path(str(file_path)) | |
| except Exception as e: | |
| print(f"Ошибка при чтении файла {file_path}: {e}") | |
| return documents | |
| def process_documents(documents: dict, fixed_size_config: dict) -> pd.DataFrame: | |
| """ | |
| Обрабатывает документы со стратегией fixed_size для чанкинга. | |
| Args: | |
| documents: Словарь с распарсенными документами | |
| fixed_size_config: Конфигурация для fixed_size стратегии | |
| Returns: | |
| DataFrame с чанками | |
| """ | |
| print("Обработка документов стратегией fixed_size...") | |
| all_data = [] | |
| for doc_name, document in tqdm(documents.items(), desc="Применение стратегии fixed_size"): | |
| # Стратегия fixed_size для чанкинга | |
| destructurer = Destructurer(document) | |
| destructurer.configure('fixed_size', | |
| words_per_chunk=fixed_size_config["words_per_chunk"], | |
| overlap_words=fixed_size_config["overlap_words"]) | |
| fixed_size_entities, _ = destructurer.destructure() | |
| # Обрабатываем только сущности для поиска | |
| for entity in fixed_size_entities: | |
| if hasattr(entity, 'use_in_search') and entity.use_in_search: | |
| entity_data = { | |
| 'id': str(entity.id), | |
| 'doc_name': doc_name, | |
| 'name': entity.name, | |
| 'text': entity.text, | |
| 'type': entity.type, | |
| 'strategy': 'fixed_size', | |
| 'metadata': json.dumps(entity.metadata, ensure_ascii=False) | |
| } | |
| all_data.append(entity_data) | |
| # Создаем DataFrame | |
| df = pd.DataFrame(all_data) | |
| # Фильтруем по типу, исключая Document | |
| df = df[df['type'] != 'Document'] | |
| return df | |
| def load_questions_dataset(file_path: str) -> pd.DataFrame: | |
| """ | |
| Загружает датасет с вопросами из Excel-файла. | |
| Args: | |
| file_path: Путь к Excel-файлу | |
| Returns: | |
| DataFrame с вопросами и пунктами | |
| """ | |
| print(f"Загрузка датасета из {file_path}...") | |
| df = pd.read_excel(file_path) | |
| print(f"Загружен датасет со столбцами: {df.columns.tolist()}") | |
| # Преобразуем NaN в пустые строки для текстовых полей | |
| text_columns = ['question', 'text', 'item_type'] | |
| for col in text_columns: | |
| if col in df.columns: | |
| df[col] = df[col].fillna('') | |
| return df | |
| def setup_model_and_tokenizer(model_name: str, use_sentence_transformers: bool = False, device: str = DEVICE): | |
| """ | |
| Инициализирует модель и токенизатор. | |
| Args: | |
| model_name: Название предобученной модели | |
| use_sentence_transformers: Использовать ли библиотеку sentence_transformers | |
| device: Устройство для вычислений | |
| Returns: | |
| Кортеж (модель, токенизатор) или объект SentenceTransformer | |
| """ | |
| print(f"Загрузка модели {model_name} на устройство {device}...") | |
| if use_sentence_transformers: | |
| try: | |
| from sentence_transformers import SentenceTransformer | |
| model = SentenceTransformer(model_name, device=device) | |
| return model, None | |
| except ImportError: | |
| print("Библиотека sentence_transformers не установлена. Установите её с помощью pip install sentence-transformers") | |
| raise | |
| else: | |
| tokenizer = AutoTokenizer.from_pretrained(model_name) | |
| model = AutoModel.from_pretrained(model_name).to(device) | |
| model.eval() | |
| return model, tokenizer | |
| def get_embeddings(texts: list[str], model, tokenizer=None, batch_size: int = BATCH_SIZE, use_sentence_transformers: bool = False, device: str = DEVICE) -> np.ndarray: | |
| """ | |
| Получает эмбеддинги для списка текстов с использованием average pooling или sentence_transformers. | |
| Args: | |
| texts: Список текстов | |
| model: Модель для векторизации или SentenceTransformer | |
| tokenizer: Токенизатор (None для sentence_transformers) | |
| batch_size: Размер батча | |
| use_sentence_transformers: Использовать ли библиотеку sentence_transformers | |
| device: Устройство для вычислений | |
| Returns: | |
| Массив эмбеддингов | |
| """ | |
| if use_sentence_transformers: | |
| # Используем sentence_transformers для получения эмбеддингов | |
| all_embeddings = [] | |
| for i in tqdm(range(0, len(texts), batch_size), desc="Векторизация текстов (sentence_transformers)"): | |
| batch_texts = texts[i:i+batch_size] | |
| # Получаем эмбеддинги с помощью sentence_transformers | |
| embeddings = model.encode(batch_texts, batch_size=batch_size, show_progress_bar=False) | |
| all_embeddings.append(embeddings) | |
| return np.vstack(all_embeddings) | |
| else: | |
| # Используем стандартный подход с average pooling | |
| all_embeddings = [] | |
| for i in tqdm(range(0, len(texts), batch_size), desc="Векторизация текстов"): | |
| batch_texts = texts[i:i+batch_size] | |
| # Токенизация с обрезкой и padding | |
| encoding = tokenizer( | |
| batch_texts, | |
| padding=True, | |
| truncation=True, | |
| max_length=512, | |
| return_tensors="pt" | |
| ).to(device) | |
| # Получаем эмбеддинги с average pooling | |
| with torch.no_grad(): | |
| outputs = model(**encoding) | |
| embeddings = _average_pool(outputs.last_hidden_state, encoding["attention_mask"]) | |
| all_embeddings.append(embeddings.cpu().numpy()) | |
| return np.vstack(all_embeddings) | |
| def calculate_chunk_overlap(chunk_text: str, punct_text: str) -> float: | |
| """ | |
| Рассчитывает степень перекрытия между чанком и пунктом с использованием partial_ratio. | |
| Args: | |
| chunk_text: Текст чанка | |
| punct_text: Текст пункта | |
| Returns: | |
| Коэффициент перекрытия от 0 до 1 | |
| """ | |
| # Если чанк входит в пункт, возвращаем 1.0 (полное вхождение) | |
| if chunk_text in punct_text: | |
| return 1.0 | |
| # Если пункт входит в чанк, возвращаем соотношение длин | |
| if punct_text in chunk_text: | |
| return len(punct_text) / len(chunk_text) | |
| # Используем partial_ratio из fuzzywuzzy, который лучше обрабатывает | |
| # случаи, когда один текст является подстрокой другого, даже с небольшими различиями | |
| partial_ratio_score = fuzz.partial_ratio(chunk_text, punct_text) / 100.0 | |
| return partial_ratio_score | |
| def save_embeddings_and_data(embeddings: np.ndarray, data: pd.DataFrame, filename: str, output_dir: str): | |
| """ | |
| Сохраняет эмбеддинги и соответствующие данные в файлы. | |
| Args: | |
| embeddings: Массив эмбеддингов | |
| data: DataFrame с данными | |
| filename: Базовое имя файла | |
| output_dir: Директория для сохранения | |
| """ | |
| embeddings_path = os.path.join(output_dir, f"{filename}_embeddings.npy") | |
| data_path = os.path.join(output_dir, f"{filename}_data.csv") | |
| # Сохраняем эмбеддинги | |
| np.save(embeddings_path, embeddings) | |
| print(f"Эмбеддинги сохранены в {embeddings_path}") | |
| # Сохраняем данные | |
| data.to_csv(data_path, index=False) | |
| print(f"Данные сохранены в {data_path}") | |
| def load_embeddings_and_data(filename: str, output_dir: str) -> tuple[np.ndarray | None, pd.DataFrame | None]: | |
| """ | |
| Загружает эмбеддинги и соответствующие данные из файлов. | |
| Args: | |
| filename: Базовое имя файла | |
| output_dir: Директория, где хранятся файлы | |
| Returns: | |
| Кортеж (эмбеддинги, данные) или (None, None), если файлы не найдены | |
| """ | |
| embeddings_path = os.path.join(output_dir, f"{filename}_embeddings.npy") | |
| data_path = os.path.join(output_dir, f"{filename}_data.csv") | |
| if os.path.exists(embeddings_path) and os.path.exists(data_path): | |
| print(f"Загрузка данных из {embeddings_path} и {data_path}...") | |
| embeddings = np.load(embeddings_path) | |
| data = pd.read_csv(data_path) | |
| return embeddings, data | |
| return None, None | |
| def save_top_chunks_for_question( | |
| question_id: int, | |
| question_text: str, | |
| question_puncts: list[str], | |
| top_chunks: pd.DataFrame, | |
| similarities: dict, | |
| overlap_data: list, | |
| output_dir: str | |
| ): | |
| """ | |
| Сохраняет топ-чанки для конкретного вопроса в JSON-файл. | |
| Args: | |
| question_id: ID вопроса | |
| question_text: Текст вопроса | |
| question_puncts: Список пунктов, относящихся к вопросу | |
| top_chunks: DataFrame с топ-чанками | |
| similarities: Словарь с косинусными схожестями для чанков | |
| overlap_data: Данные о перекрытии чанков с пунктами | |
| output_dir: Директория для сохранения | |
| """ | |
| # Подготавливаем результаты для сохранения | |
| chunks_data = [] | |
| for i, (idx, chunk) in enumerate(top_chunks.iterrows()): | |
| # Получаем данные о перекрытии для текущего чанка | |
| chunk_overlaps = overlap_data[i] if i < len(overlap_data) else [] | |
| # Преобразуем numpy типы в стандартные типы Python | |
| similarity = float(similarities.get(idx, 0.0)) | |
| # Формируем данные чанка | |
| chunk_data = { | |
| 'chunk_id': chunk['id'], | |
| 'doc_name': chunk['doc_name'], | |
| 'text': chunk['text'], | |
| 'similarity': similarity, | |
| 'overlaps': chunk_overlaps | |
| } | |
| chunks_data.append(chunk_data) | |
| # Преобразуем numpy.int64 в int для question_id | |
| question_id = int(question_id) | |
| # Формируем общий результат | |
| result = { | |
| 'question_id': question_id, | |
| 'question_text': question_text, | |
| 'puncts': question_puncts, | |
| 'chunks': chunks_data | |
| } | |
| # Создаем имя файла | |
| filename = f"question_{question_id}_top_chunks.json" | |
| filepath = os.path.join(output_dir, filename) | |
| # Класс для сериализации numpy типов | |
| class NumpyEncoder(json.JSONEncoder): | |
| def default(self, obj): | |
| if isinstance(obj, np.integer): | |
| return int(obj) | |
| if isinstance(obj, np.floating): | |
| return float(obj) | |
| if isinstance(obj, np.ndarray): | |
| return obj.tolist() | |
| return super().default(obj) | |
| # Сохраняем в JSON с кастомным энкодером | |
| with open(filepath, 'w', encoding='utf-8') as f: | |
| json.dump(result, f, ensure_ascii=False, indent=2, cls=NumpyEncoder) | |
| print(f"Топ-чанки для вопроса {question_id} сохранены в {filepath}") | |
| def evaluate_for_top_n_with_mapping( | |
| questions_df: pd.DataFrame, | |
| chunks_df: pd.DataFrame, | |
| question_embeddings: np.ndarray, | |
| chunk_embeddings: np.ndarray, | |
| question_id_to_idx: dict, | |
| top_n: int, | |
| similarity_threshold: float, | |
| top_chunks_dir: str = None | |
| ) -> tuple[dict[str, float], pd.DataFrame]: | |
| """ | |
| Оценивает качество чанкинга для заданного значения top_n с использованием маппинга id -> индекс. | |
| Args: | |
| questions_df: DataFrame с вопросами и релевантными пунктами (исходный датасет) | |
| chunks_df: DataFrame с чанками | |
| question_embeddings: Эмбеддинги вопросов | |
| chunk_embeddings: Эмбеддинги чанков | |
| question_id_to_idx: Словарь соответствия id вопроса и его индекса в массиве эмбеддингов | |
| top_n: Количество чанков в топе для каждого вопроса | |
| similarity_threshold: Порог для нечеткого сравнения | |
| top_chunks_dir: Директория для сохранения топ-чанков (если None, то не сохраняем) | |
| Returns: | |
| Кортеж (словарь с усредненными метриками, DataFrame с метриками по отдельным вопросам) | |
| """ | |
| print(f"Оценка для top-{top_n}...") | |
| # Вычисляем косинусную близость между вопросами и чанками | |
| similarity_matrix = cosine_similarity(question_embeddings, chunk_embeddings) | |
| # Счетчики для метрик на основе текста | |
| total_puncts = 0 | |
| found_puncts = 0 | |
| total_chunks = 0 | |
| relevant_chunks = 0 | |
| # Счетчики для метрик на основе документов | |
| total_docs_required = 0 | |
| found_relevant_docs = 0 | |
| total_docs_found = 0 | |
| # Для сохранения метрик по отдельным вопросам | |
| question_metrics = [] | |
| # Выводим информацию о столбцах для отладки | |
| print(f"Столбцы в исходном датасете: {questions_df.columns.tolist()}") | |
| # Группируем вопросы по id (у нас 20 уникальных вопросов) | |
| for question_id in tqdm(questions_df['id'].unique(), desc=f"Оценка top-{top_n}"): | |
| # Получаем строки для текущего вопроса из исходного датасета | |
| question_rows = questions_df[questions_df['id'] == question_id] | |
| # Проверяем, есть ли вопрос с таким id в нашем маппинге | |
| if question_id not in question_id_to_idx: | |
| print(f"Предупреждение: вопрос с id {question_id} отсутствует в маппинге") | |
| continue | |
| # Если нет строк с таким id, пропускаем | |
| if len(question_rows) == 0: | |
| continue | |
| # Получаем индекс вопроса в массиве эмбеддингов | |
| question_idx = question_id_to_idx[question_id] | |
| # Получаем текст вопроса | |
| question_text = question_rows['question'].iloc[0] | |
| # Получаем все пункты для этого вопроса | |
| puncts = question_rows['text'].tolist() | |
| question_total_puncts = len(puncts) | |
| total_puncts += question_total_puncts | |
| # Получаем связанные документы | |
| relevant_docs = [] | |
| if 'filename' in question_rows.columns: | |
| relevant_docs = [f for f in question_rows['filename'].unique() if f and not pd.isna(f)] | |
| question_total_docs_required = len(relevant_docs) | |
| total_docs_required += question_total_docs_required | |
| print(f"Найдено {question_total_docs_required} документов для вопроса {question_id}") | |
| else: | |
| print(f"Столбец 'filename' отсутствует. Используем все документы.") | |
| relevant_docs = chunks_df['doc_name'].unique().tolist() | |
| question_total_docs_required = len(relevant_docs) | |
| total_docs_required += question_total_docs_required | |
| # Если для вопроса нет релевантных документов, пропускаем | |
| if not relevant_docs: | |
| print(f"Для вопроса {question_id} нет связанных документов") | |
| continue | |
| # Флаги для отслеживания найденных пунктов | |
| punct_found = [False] * question_total_puncts | |
| # Для отслеживания найденных документов | |
| docs_found_for_question = set() | |
| # Для хранения всех чанков вопроса для ограничения top_n | |
| all_question_chunks = [] | |
| all_question_similarities = [] | |
| # Собираем чанки для всех документов по этому вопросу | |
| for filename in relevant_docs: | |
| if not filename or pd.isna(filename): | |
| continue | |
| # Фильтруем чанки по имени файла | |
| doc_chunks = chunks_df[chunks_df['doc_name'] == filename] | |
| if doc_chunks.empty: | |
| print(f"Предупреждение: документ {filename} не содержит чанков") | |
| continue | |
| # Индексы чанков для текущего файла | |
| doc_chunk_indices = doc_chunks.index.tolist() | |
| # Получаем значения близости для чанков текущего файла | |
| doc_similarities = [ | |
| similarity_matrix[question_idx, chunks_df.index.get_loc(idx)] | |
| for idx in doc_chunk_indices | |
| ] | |
| # Добавляем чанки и их схожести к общему списку для вопроса | |
| for i, idx in enumerate(doc_chunk_indices): | |
| all_question_chunks.append((idx, doc_chunks.iloc[doc_chunks.index.get_indexer([idx])[0]])) | |
| all_question_similarities.append(doc_similarities[i]) | |
| # Сортируем все чанки по убыванию схожести и берем top_n | |
| sorted_indices = np.argsort(all_question_similarities)[-min(top_n, len(all_question_similarities)):][::-1] | |
| top_chunks_indices = [all_question_chunks[i][0] for i in sorted_indices] | |
| top_chunks = [all_question_chunks[i][1] for i in sorted_indices] | |
| # Увеличиваем счетчик общего числа чанков | |
| question_total_chunks = len(top_chunks) | |
| total_chunks += question_total_chunks | |
| # Для сохранения данных топ-чанков | |
| all_top_chunks = pd.DataFrame([chunk for chunk in top_chunks]) | |
| all_chunk_similarities = {idx: all_question_similarities[i] for i, idx in enumerate([all_question_chunks[j][0] for j in sorted_indices])} | |
| all_chunk_overlaps = [] | |
| # Для каждого чанка проверяем его релевантность к пунктам | |
| question_relevant_chunks = 0 | |
| for i, chunk in enumerate(top_chunks): | |
| is_relevant = False | |
| chunk_overlaps = [] | |
| # Добавляем документ в найденные | |
| docs_found_for_question.add(chunk['doc_name']) | |
| # Проверяем перекрытие с каждым пунктом | |
| for j, punct in enumerate(puncts): | |
| overlap = calculate_chunk_overlap(chunk['text'], punct) | |
| # Если нужно сохранить топ-чанки и top_n == 20 | |
| if top_chunks_dir and top_n == 20: | |
| chunk_overlaps.append({ | |
| 'punct_index': j, | |
| 'punct_text': punct[:100] + '...' if len(punct) > 100 else punct, | |
| 'overlap': overlap | |
| }) | |
| # Если перекрытие больше порога, чанк релевантен | |
| if overlap >= similarity_threshold: | |
| is_relevant = True | |
| punct_found[j] = True | |
| if is_relevant: | |
| question_relevant_chunks += 1 | |
| # Если нужно сохранить топ-чанки и top_n == 20 | |
| if top_chunks_dir and top_n == 20: | |
| all_chunk_overlaps.append(chunk_overlaps) | |
| # Если нужно сохранить топ-чанки и top_n == 20 | |
| if top_chunks_dir and top_n == 20 and not all_top_chunks.empty: | |
| save_top_chunks_for_question( | |
| question_id, | |
| question_text, | |
| puncts, | |
| all_top_chunks, | |
| all_chunk_similarities, | |
| all_chunk_overlaps, | |
| top_chunks_dir | |
| ) | |
| # Подсчитываем метрики для текущего вопроса | |
| question_found_puncts = sum(punct_found) | |
| found_puncts += question_found_puncts | |
| relevant_chunks += question_relevant_chunks | |
| # Обновляем метрики для документов | |
| question_found_relevant_docs = sum(1 for doc in docs_found_for_question if doc in relevant_docs) | |
| found_relevant_docs += question_found_relevant_docs | |
| question_total_docs_found = len(docs_found_for_question) | |
| total_docs_found += question_total_docs_found | |
| # Вычисляем метрики для текущего вопроса | |
| question_text_precision = question_relevant_chunks / question_total_chunks if question_total_chunks > 0 else 0 | |
| question_text_recall = question_found_puncts / question_total_puncts if question_total_puncts > 0 else 0 | |
| question_text_f1 = 2 * question_text_precision * question_text_recall / (question_text_precision + question_text_recall) if question_text_precision + question_text_recall > 0 else 0 | |
| question_doc_precision = question_found_relevant_docs / question_total_docs_found if question_total_docs_found > 0 else 0 | |
| question_doc_recall = question_found_relevant_docs / question_total_docs_required if question_total_docs_required > 0 else 0 | |
| question_doc_f1 = 2 * question_doc_precision * question_doc_recall / (question_doc_precision + question_doc_recall) if question_doc_precision + question_doc_recall > 0 else 0 | |
| # Сохраняем метрики вопроса | |
| question_metrics.append({ | |
| 'question_id': question_id, | |
| 'question_text': question_text, | |
| 'top_n': top_n, | |
| 'text_precision': question_text_precision, | |
| 'text_recall': question_text_recall, | |
| 'text_f1': question_text_f1, | |
| 'doc_precision': question_doc_precision, | |
| 'doc_recall': question_doc_recall, | |
| 'doc_f1': question_doc_f1, | |
| 'found_puncts': question_found_puncts, | |
| 'total_puncts': question_total_puncts, | |
| 'relevant_chunks': question_relevant_chunks, | |
| 'total_chunks': question_total_chunks, | |
| 'found_relevant_docs': question_found_relevant_docs, | |
| 'total_docs_required': question_total_docs_required, | |
| 'total_docs_found': question_total_docs_found | |
| }) | |
| # Вычисляем метрики для текста | |
| text_precision = relevant_chunks / total_chunks if total_chunks > 0 else 0 | |
| text_recall = found_puncts / total_puncts if total_puncts > 0 else 0 | |
| text_f1 = 2 * text_precision * text_recall / (text_precision + text_recall) if text_precision + text_recall > 0 else 0 | |
| # Вычисляем метрики для документов | |
| doc_precision = found_relevant_docs / total_docs_found if total_docs_found > 0 else 0 | |
| doc_recall = found_relevant_docs / total_docs_required if total_docs_required > 0 else 0 | |
| doc_f1 = 2 * doc_precision * doc_recall / (doc_precision + doc_recall) if doc_precision + doc_recall > 0 else 0 | |
| aggregated_metrics = { | |
| 'top_n': top_n, | |
| 'text_precision': text_precision, | |
| 'text_recall': text_recall, | |
| 'text_f1': text_f1, | |
| 'doc_precision': doc_precision, | |
| 'doc_recall': doc_recall, | |
| 'doc_f1': doc_f1, | |
| 'found_puncts': found_puncts, | |
| 'total_puncts': total_puncts, | |
| 'relevant_chunks': relevant_chunks, | |
| 'total_chunks': total_chunks, | |
| 'found_relevant_docs': found_relevant_docs, | |
| 'total_docs_required': total_docs_required, | |
| 'total_docs_found': total_docs_found | |
| } | |
| return aggregated_metrics, pd.DataFrame(question_metrics) | |
| def main(): | |
| """ | |
| Основная функция скрипта. | |
| """ | |
| args = parse_args() | |
| # Устанавливаем устройство из аргументов | |
| device = args.device | |
| # Создаем выходной каталог, если его нет | |
| os.makedirs(args.output_dir, exist_ok=True) | |
| # Создаем директорию для топ-чанков | |
| top_chunks_dir = os.path.join(args.output_dir, "top_chunks") | |
| os.makedirs(top_chunks_dir, exist_ok=True) | |
| # Загружаем датасет с вопросами | |
| questions_df = load_questions_dataset(args.dataset_path) | |
| # Формируем уникальное имя для сохраняемых файлов на основе параметров стратегии и модели | |
| strategy_config_str = f"fixed_size_w{args.words_per_chunk}_o{args.overlap_words}" | |
| chunks_filename = f"chunks_{strategy_config_str}_{args.model_name.replace('/', '_')}" | |
| questions_filename = f"questions_{args.model_name.replace('/', '_')}" | |
| # Пытаемся загрузить сохраненные эмбеддинги и данные | |
| chunk_embeddings, chunks_df = None, None | |
| question_embeddings, questions_df_with_embeddings = None, None | |
| if not args.force_recompute: | |
| chunk_embeddings, chunks_df = load_embeddings_and_data(chunks_filename, args.output_dir) | |
| question_embeddings, questions_df_with_embeddings = load_embeddings_and_data(questions_filename, args.output_dir) | |
| # Если не удалось загрузить данные или включен режим принудительного пересчета | |
| if chunk_embeddings is None or chunks_df is None: | |
| # Читаем и обрабатываем документы | |
| documents = read_documents(args.data_folder) | |
| # Формируем конфигурацию для стратегии fixed_size | |
| fixed_size_config = { | |
| "words_per_chunk": args.words_per_chunk, | |
| "overlap_words": args.overlap_words | |
| } | |
| # Получаем DataFrame с чанками | |
| chunks_df = process_documents(documents, fixed_size_config) | |
| # Настраиваем модель и токенизатор | |
| model, tokenizer = setup_model_and_tokenizer(args.model_name, args.use_sentence_transformers, device) | |
| # Получаем эмбеддинги для чанков | |
| chunk_embeddings = get_embeddings(chunks_df['text'].tolist(), model, tokenizer, args.batch_size, args.use_sentence_transformers, device) | |
| # Сохраняем эмбеддинги и данные | |
| save_embeddings_and_data(chunk_embeddings, chunks_df, chunks_filename, args.output_dir) | |
| # Если не удалось загрузить эмбеддинги вопросов или включен режим принудительного пересчета | |
| if question_embeddings is None or questions_df_with_embeddings is None: | |
| # Получаем уникальные вопросы (по id) | |
| unique_questions = questions_df.drop_duplicates(subset=['id'])[['id', 'question']] | |
| # Настраиваем модель и токенизатор (если еще не настроены) | |
| if 'model' not in locals() or 'tokenizer' not in locals(): | |
| model, tokenizer = setup_model_and_tokenizer(args.model_name, args.use_sentence_transformers, device) | |
| # Получаем эмбеддинги для вопросов | |
| question_embeddings = get_embeddings(unique_questions['question'].tolist(), model, tokenizer, args.batch_size, args.use_sentence_transformers, device) | |
| # Сохраняем эмбеддинги и данные | |
| save_embeddings_and_data(question_embeddings, unique_questions, questions_filename, args.output_dir) | |
| # Устанавливаем questions_df_with_embeddings для дальнейшего использования | |
| questions_df_with_embeddings = unique_questions | |
| # Создаем словарь соответствия id вопроса и его индекса в эмбеддингах | |
| question_id_to_idx = { | |
| row['id']: i | |
| for i, (_, row) in enumerate(questions_df_with_embeddings.iterrows()) | |
| } | |
| # Оцениваем стратегию чанкинга для разных значений top_n | |
| aggregated_results = [] | |
| all_question_metrics = [] | |
| for top_n in TOP_N_VALUES: | |
| metrics, question_metrics = evaluate_for_top_n_with_mapping( | |
| questions_df, # Исходный датасет с связью между вопросами и документами | |
| chunks_df, # Датасет с чанками | |
| question_embeddings, # Эмбеддинги вопросов | |
| chunk_embeddings, # Эмбеддинги чанков | |
| question_id_to_idx, # Маппинг id вопроса к индексу в эмбеддингах | |
| top_n, # Количество чанков в топе | |
| args.similarity_threshold, # Порог для определения перекрытия | |
| top_chunks_dir if top_n == 20 else None # Сохраняем топ-чанки только для top_n=20 | |
| ) | |
| aggregated_results.append(metrics) | |
| all_question_metrics.append(question_metrics) | |
| # Объединяем все метрики по вопросам | |
| all_question_metrics_df = pd.concat(all_question_metrics) | |
| # Создаем DataFrame с агрегированными результатами | |
| aggregated_results_df = pd.DataFrame(aggregated_results) | |
| # Сохраняем результаты | |
| results_filename = f"results_{strategy_config_str}_{args.model_name.replace('/', '_')}.csv" | |
| results_path = os.path.join(args.output_dir, results_filename) | |
| aggregated_results_df.to_csv(results_path, index=False) | |
| # Сохраняем метрики по вопросам | |
| question_metrics_filename = f"question_metrics_{strategy_config_str}_{args.model_name.replace('/', '_')}.xlsx" | |
| question_metrics_path = os.path.join(args.output_dir, question_metrics_filename) | |
| all_question_metrics_df.to_excel(question_metrics_path, index=False) | |
| print(f"\nРезультаты сохранены в {results_path}") | |
| print(f"Метрики по вопросам сохранены в {question_metrics_path}") | |
| print(f"Топ-20 чанков для каждого вопроса сохранены в {top_chunks_dir}") | |
| print("\nМетрики для различных значений top_n:") | |
| print(aggregated_results_df[['top_n', 'text_precision', 'text_recall', 'text_f1', 'doc_precision', 'doc_recall', 'doc_f1']]) | |
| if __name__ == "__main__": | |
| main() |