Spaces:
Runtime error
Runtime error
| import asyncio | |
| import logging | |
| from uuid import UUID | |
| import numpy as np | |
| from ntr_text_fragmentation import LinkerEntity | |
| from ntr_text_fragmentation.integrations.sqlalchemy import \ | |
| SQLAlchemyEntityRepository | |
| from sqlalchemy import func, select | |
| from sqlalchemy.orm import Session, sessionmaker | |
| from components.dbo.models.entity import EntityModel | |
| logger = logging.getLogger(__name__) | |
| class ChunkRepository(SQLAlchemyEntityRepository): | |
| """ | |
| Репозиторий для работы с сущностями (чанками, документами, связями), | |
| хранящимися в базе данных с использованием SQL Alchemy. | |
| Наследуется от SQLAlchemyEntityRepository, предоставляя конкретную реализацию | |
| для модели EntityModel. | |
| """ | |
| def __init__(self, db_session_factory: sessionmaker[Session]): | |
| """ | |
| Инициализация репозитория. | |
| Args: | |
| db_session_factory: Фабрика сессий SQLAlchemy. | |
| """ | |
| super().__init__(db_session_factory) | |
| def _entity_model_class(self): | |
| """Возвращает класс модели SQLAlchemy.""" | |
| return EntityModel | |
| def _map_db_entity_to_linker_entity(self, db_entity: EntityModel) -> LinkerEntity: | |
| """ | |
| Преобразует объект EntityModel из базы данных в объект LinkerEntity | |
| или его соответствующий подкласс. | |
| Args: | |
| db_entity: Сущность EntityModel из базы данных. | |
| Returns: | |
| Объект LinkerEntity или его подкласс. | |
| """ | |
| # Создаем базовый LinkerEntity со всеми данными из БД | |
| # Преобразуем строковые UUID обратно в объекты UUID | |
| base_data = LinkerEntity( | |
| id=UUID(db_entity.uuid), | |
| name=db_entity.name, | |
| text=db_entity.text, | |
| in_search_text=db_entity.in_search_text, | |
| metadata=db_entity.metadata_json or {}, | |
| source_id=UUID(db_entity.source_id) if db_entity.source_id else None, | |
| target_id=UUID(db_entity.target_id) if db_entity.target_id else None, | |
| number_in_relation=db_entity.number_in_relation, | |
| type=db_entity.entity_type, | |
| groupper=db_entity.entity_type, | |
| ) | |
| # Используем LinkerEntity._deserialize для получения объекта нужного типа | |
| # на основе поля 'type', взятого из db_entity.entity_type | |
| try: | |
| deserialized_entity = base_data.deserialize() | |
| return deserialized_entity | |
| except Exception as e: | |
| logger.error( | |
| f"Error deserializing entity {base_data.id} of type {base_data.type}: {e}" | |
| ) | |
| return base_data | |
| def add_entities( | |
| self, | |
| entities: list[LinkerEntity], | |
| dataset_id: int, | |
| embeddings: dict[str, np.ndarray] | None = None, | |
| ): | |
| """ | |
| Добавляет список сущностей LinkerEntity в базу данных. | |
| Args: | |
| entities: Список сущностей LinkerEntity для добавления. | |
| dataset_id: ID датасета, к которому принадлежат сущности. | |
| embeddings: Словарь эмбеддингов {entity_id_str: embedding}, где entity_id_str - строка UUID. | |
| """ | |
| embeddings = embeddings or {} | |
| with self.db() as session: | |
| db_entities_to_add = [] | |
| for entity in entities: | |
| # Преобразуем UUID в строку для хранения в базе | |
| entity_id_str = str(entity.id) | |
| embedding = embeddings.get(entity_id_str) | |
| db_entity = EntityModel( | |
| uuid=entity_id_str, | |
| name=entity.name, | |
| text=entity.text, | |
| entity_type=entity.type, | |
| in_search_text=entity.in_search_text, | |
| metadata_json=( | |
| entity.metadata if isinstance(entity.metadata, dict) else {} | |
| ), | |
| source_id=str(entity.source_id) if entity.source_id else None, | |
| target_id=str(entity.target_id) if entity.target_id else None, | |
| number_in_relation=entity.number_in_relation, | |
| dataset_id=dataset_id, | |
| embedding=embedding, | |
| ) | |
| db_entities_to_add.append(db_entity) | |
| session.add_all(db_entities_to_add) | |
| session.commit() | |
| async def add_entities_async( | |
| self, | |
| entities: list[LinkerEntity], | |
| dataset_id: int, | |
| embeddings: dict[str, np.ndarray] | None = None, | |
| ): | |
| """Асинхронно добавляет список сущностей LinkerEntity в базу данных.""" | |
| # TODO: Реализовать с использованием async-сессии | |
| await asyncio.to_thread(self.add_entities, entities, dataset_id, embeddings) | |
| def get_searching_entities( | |
| self, | |
| dataset_id: int, | |
| ) -> tuple[list[LinkerEntity], list[np.ndarray]]: | |
| """ | |
| Получает сущности из указанного датасета, которые имеют текст для поиска | |
| (in_search_text не None), вместе с их эмбеддингами. | |
| Args: | |
| dataset_id: ID датасета. | |
| Returns: | |
| Кортеж из двух списков: список LinkerEntity и список их эмбеддингов (numpy array). | |
| Порядок эмбеддингов соответствует порядку сущностей. | |
| """ | |
| entity_model = self._entity_model_class | |
| linker_entities = [] | |
| embeddings_list = [] | |
| with self.db() as session: | |
| stmt = select(entity_model).where( | |
| entity_model.in_search_text.isnot(None), | |
| entity_model.dataset_id == dataset_id, | |
| entity_model.embedding.isnot(None) | |
| ) | |
| db_models = session.execute(stmt).scalars().all() | |
| # Переносим цикл внутрь сессии | |
| for model in db_models: | |
| # Теперь маппинг происходит при активной сессии | |
| linker_entity = self._map_db_entity_to_linker_entity(model) | |
| linker_entities.append(linker_entity) | |
| # Извлекаем эмбеддинг. | |
| # _map_db_entity_to_linker_entity может поместить его в метаданные. | |
| embedding = linker_entity.metadata.get('_embedding') | |
| if embedding is None and hasattr(model, 'embedding'): # Fallback | |
| embedding = model.embedding # Доступ к model.embedding тоже должен быть внутри сессии | |
| if embedding is not None: | |
| embeddings_list.append(embedding) | |
| else: | |
| # Обработка случая отсутствия эмбеддинга | |
| print(f"Warning: Entity {model.uuid} has in_search_text but no embedding.") | |
| linker_entities.pop() | |
| # Возвращаем результаты после закрытия сессии | |
| return linker_entities, embeddings_list | |
| async def get_searching_entities_async( | |
| self, | |
| dataset_id: int, | |
| ) -> tuple[list[LinkerEntity], list[np.ndarray]]: | |
| """Асинхронно получает сущности для поиска вместе с эмбеддингами.""" | |
| # TODO: Реализовать с использованием async-сессии | |
| return await asyncio.to_thread(self.get_searching_entities, dataset_id) | |
| def get_all_entities_for_dataset(self, dataset_id: int) -> list[LinkerEntity]: | |
| """ | |
| Получает все сущности для указанного датасета. | |
| Args: | |
| dataset_id: ID датасета. | |
| Returns: | |
| Список всех LinkerEntity для данного датасета. | |
| """ | |
| entity_model = self._entity_model_class | |
| linker_entities = [] | |
| with self.db() as session: | |
| stmt = select(entity_model).where( | |
| entity_model.dataset_id == dataset_id | |
| ) | |
| db_models = session.execute(stmt).scalars().all() | |
| # Переносим цикл внутрь сессии для маппинга | |
| for model in db_models: | |
| try: | |
| linker_entity = self._map_db_entity_to_linker_entity(model) | |
| linker_entities.append(linker_entity) | |
| except Exception as e: | |
| logger.error(f"Error mapping entity {getattr(model, 'uuid', 'N/A')} in dataset {dataset_id}: {e}") | |
| logger.info(f"Loaded {len(linker_entities)} entities for dataset {dataset_id}") | |
| return linker_entities | |
| async def get_all_entities_for_dataset_async(self, dataset_id: int) -> list[LinkerEntity]: | |
| """Асинхронно получает все сущности для указанного датасета.""" | |
| # TODO: Реализовать с использованием async-сессии | |
| return await asyncio.to_thread(self.get_all_entities_for_dataset, dataset_id) | |
| def count_entities_by_dataset_id(self, dataset_id: int) -> int: | |
| """ | |
| Подсчитывает общее количество сущностей для указанного датасета. | |
| Args: | |
| dataset_id: ID датасета. | |
| Returns: | |
| Общее количество сущностей в датасете. | |
| """ | |
| entity_model = self._entity_model_class | |
| id_column = self._get_id_column() # Получаем колонку ID (uuid или id) | |
| with self.db() as session: | |
| stmt = select(func.count(id_column)).where( | |
| entity_model.dataset_id == dataset_id | |
| ) | |
| count = session.execute(stmt).scalar_one() | |
| return count | |
| async def count_entities_by_dataset_id_async(self, dataset_id: int) -> int: | |
| """Асинхронно подсчитывает общее количество сущностей для датасета.""" | |
| # TODO: Реализовать с использованием async-сессии | |
| return await asyncio.to_thread(self.count_entities_by_dataset_id, dataset_id) | |
| async def get_entities_by_ids_async(self, entity_ids: list[UUID]) -> list[LinkerEntity]: | |
| """Асинхронно получить сущности по списку ID.""" | |
| # TODO: Реализовать с использованием async-сессии | |
| return await asyncio.to_thread(self.get_entities_by_ids, entity_ids) | |