# app/services/parts_combination_service.py
import logging
import re
from typing import List, Dict, Optional, Tuple
from collections import defaultdict

# --- Core App Imports ---
from app.core import state
from app.utils import neo4j_utils
from app.core.config import settings
from neo4j.exceptions import Neo4jError

logger = logging.getLogger(__name__)

async def load_chunk_type_map():
    """
    Connects to Neo4j and builds maps for chunk types and sequences.
    """
    if state.sequence_map_loaded and state.chunk_type_map_loaded:
        return True
    logger.info("Loading all chunk types and building sequence map from Neo4j...")
    cypher_query = """
    MATCH (c:Chunk)
    WHERE c.chunk_type IS NOT NULL
    RETURN c.id AS chunk_id, c.chunk_type AS chunk_type, c.content AS content
    """
    chunk_type_map = {}
    temp_groups = defaultdict(list)
    sequence_pattern = re.compile(r"^(.*)\s+Part\s+(\d+)$")
    try:
        driver = await neo4j_utils.get_driver()
        if not driver:
            logger.error("Cannot load chunk type data: Neo4j driver not available.")
            return False
        records, _, _ = await driver.execute_query(
            cypher_query, database_=settings.NEO4J_DATABASE or "neo4j"
        )
        for record in records:
            data = record.data()
            chunk_id = data.get("chunk_id")
            chunk_type = data.get("chunk_type")
            content = data.get("content")
            if chunk_id and chunk_type:
                chunk_type_map[chunk_id] = chunk_type
                if " Part " in chunk_type:
                    match = sequence_pattern.match(chunk_type)
                    if match:
                        base_name = match.group(1).strip()
                        part_number = int(match.group(2))
                        temp_groups[base_name].append(
                            {"id": chunk_id, "part": part_number, "type": chunk_type, "text": content}
                        )
        final_sequence_map = {
            base: sorted(parts, key=lambda x: x["part"]) for base, parts in temp_groups.items()
        }
        state.chunk_type_map = chunk_type_map
        state.sequence_base_to_parts_map = final_sequence_map
        state.chunk_type_map_loaded = True
        state.sequence_map_loaded = True
        logger.info(
            f"Successfully built chunk type map for {len(chunk_type_map)} chunks "
            f"and {len(final_sequence_map)} sequence groups."
        )
        return True
    except Neo4jError as e:
        logger.error(f"Neo4j query failed while building chunk maps: {e}")
        state.chunk_type_map, state.sequence_base_to_parts_map = {}, {}
        state.chunk_type_map_loaded, state.sequence_map_loaded = False, False
        return False
    except Exception as e:
        logger.exception(f"An unexpected error occurred building chunk maps from Neo4j: {e}")
        state.chunk_type_map, state.sequence_base_to_parts_map = {}, {}
        state.chunk_type_map_loaded, state.sequence_map_loaded = False, False
        return False
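
# Illustrative shape of the two maps built above. The ids, types, and texts
# here are made-up examples, not values from a real database:
#   state.chunk_type_map             -> {"c-101": "Install Guide Part 1", "c-207": "FAQ", ...}
#   state.sequence_base_to_parts_map ->
#       {"Install Guide": [{"id": "c-101", "part": 1,
#                           "type": "Install Guide Part 1", "text": "..."}, ...]}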


def _get_content_key(chunk_id: str, chunk_text: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
    """Creates a robust, unique key for a chunk based on its type and content."""
    chunk_type = state.chunk_type_map.get(chunk_id)
    normalized_type = chunk_type.lower() if chunk_type else None
    normalized_text = chunk_text.strip() if chunk_text else None
    return (normalized_type, normalized_text)
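
# For example (hypothetical id and type), if state.chunk_type_map maps "c-101"
# to "Install Guide Part 1", then:
#   _get_content_key("c-101", "  Step one.  ") == ("install guide part 1", "Step one.")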


def organize_chunks_by_sequence(chunks: List[Dict]) -> List[Dict]:
    """
    Ensures re-ranker chunks are prioritized, then expands context by filling
    in missing sequential parts.
    """
    if not state.sequence_map_loaded or not chunks:
        return chunks
    # --- Step 1: Deduplicate re-ranker output to get the highest-scoring unique chunks ---
    deduplicated_chunks = []
    processed_content = set()
    for chunk in chunks:
        content_key = _get_content_key(chunk["id"], chunk.get("text"))
        if content_key not in processed_content:
            deduplicated_chunks.append(chunk)
            processed_content.add(content_key)
    # This map is now the "golden source" of high-priority chunks and their scores.
    final_chunks_map = {chunk["id"]: chunk for chunk in deduplicated_chunks}
    # --- Step 2: Identify which sequences need expansion based on the golden source ---
    sequences_to_expand = {}
    for chunk in deduplicated_chunks:
        chunk_type = state.chunk_type_map.get(chunk.get("id"))
        if chunk_type and " Part " in chunk_type:
            for base_name, parts in state.sequence_base_to_parts_map.items():
                if any(part["id"] == chunk["id"] for part in parts):
                    # Store the highest score found for any part of this sequence
                    current_max_score = sequences_to_expand.get(base_name, -1.0)
                    sequences_to_expand[base_name] = max(current_max_score, chunk.get("rerank_score", 0.0))
                    break
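    # Note: this lookup scans every sequence group per sequential chunk. The
    # reverse map chunk_to_sequence_info built in Step 4 below could be hoisted
    # above this loop to make the lookup constant-time per chunk.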
    # --- Step 3: Expand sequences by filling in missing parts ---
    for base_name, trigger_score in sequences_to_expand.items():
        if trigger_score > settings.SEQUENCE_EXPANSION_THRESHOLD:
            logger.info(f"Expanding sequence '{base_name}' triggered by a chunk with score {trigger_score:.4f}")
            full_sequence_parts = state.sequence_base_to_parts_map.get(base_name, [])
            for part_info in full_sequence_parts:
                part_id = part_info["id"]
                # Add the part ONLY if it's not already in our golden source map
                if part_id not in final_chunks_map:
                    final_chunks_map[part_id] = {
                        "id": part_id,
                        "text": part_info["text"],
                        "rerank_score": -1.0,  # Mark as contextually added
                    }
    # --- Step 4: Convert map to list and perform the final sort ---
    final_chunks_list = list(final_chunks_map.values())
    # Create helper maps for an efficient final sort
    chunk_to_sequence_info = {}
    for base_name, parts in state.sequence_base_to_parts_map.items():
        for part_info in parts:
            chunk_to_sequence_info[part_info["id"]] = {"base": base_name, "part": part_info["part"]}
    # Get the max score for each sequence group
    sequence_group_scores = {}
    for chunk in final_chunks_list:
        seq_info = chunk_to_sequence_info.get(chunk["id"])
        if seq_info:
            base_name = seq_info["base"]
            current_max = sequence_group_scores.get(base_name, -1.0)
            sequence_group_scores[base_name] = max(current_max, chunk.get("rerank_score", 0.0))

    def sort_key(chunk):
        seq_info = chunk_to_sequence_info.get(chunk["id"])
        if seq_info:
            # If sequential, sort by the group's max score, then by part number
            return (-sequence_group_scores.get(seq_info["base"], -1.0), seq_info["part"])
        # If not sequential, sort by its own score, with a tie-breaker
        return (-chunk.get("rerank_score", -1.0), 0)

    final_chunks_list.sort(key=sort_key)
    logger.info(f"Context expansion and deduplication complete. Final chunk count: {len(final_chunks_list)}.")
    return final_chunks_list
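

# --- Illustrative usage sketch (not part of the original service) ---
# A minimal smoke test of organize_chunks_by_sequence, assuming the app package
# is importable and settings.SEQUENCE_EXPANSION_THRESHOLD is below 0.91. All
# ids, chunk types, texts, and scores below are made up for illustration only.
if __name__ == "__main__":
    state.chunk_type_map = {
        "c1": "Install Guide Part 1",
        "c2": "Install Guide Part 2",
        "c3": "FAQ",
    }
    state.sequence_base_to_parts_map = {
        "Install Guide": [
            {"id": "c1", "part": 1, "type": "Install Guide Part 1", "text": "Step one."},
            {"id": "c2", "part": 2, "type": "Install Guide Part 2", "text": "Step two."},
        ]
    }
    state.chunk_type_map_loaded = True
    state.sequence_map_loaded = True

    # The re-ranker returned only Part 2 and an unrelated FAQ chunk; Part 1
    # should be filled in with rerank_score -1.0, and the whole sequence should
    # sort as a block ahead of the lower-scored FAQ chunk.
    reranked = [
        {"id": "c2", "text": "Step two.", "rerank_score": 0.91},
        {"id": "c3", "text": "Answers.", "rerank_score": 0.40},
    ]
    for c in organize_chunks_by_sequence(reranked):
        print(c["id"], c.get("rerank_score"))
    # Expected output order: c1 (-1.0), c2 (0.91), c3 (0.40)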