# app/services/parts_combination_service.py

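"""
Service for combining multi-part document chunks.

Builds chunk-type and "Part N" sequence maps from Neo4j, deduplicates
re-ranked retrieval results, and expands context with any missing sequential
parts before the final sort.
"""
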
import logging
import re
from typing import List, Dict, Optional, Tuple
from collections import defaultdict

# --- Core App Imports ---
from app.core import state
from app.utils import neo4j_utils
from app.core.config import settings
from neo4j.exceptions import Neo4jError

logger = logging.getLogger(__name__)


async def load_chunk_type_map():
    """
    Connects to Neo4j and builds maps for chunk types and sequences.
    """
    if state.sequence_map_loaded and state.chunk_type_map_loaded:
        return True

    logger.info("Loading all chunk types and building sequence map from Neo4j...")
    
    cypher_query = """
    MATCH (c:Chunk)
    WHERE c.chunk_type IS NOT NULL
    RETURN c.id AS chunk_id, c.chunk_type AS chunk_type, c.content AS content
    """
    
    chunk_type_map = {}
    temp_groups = defaultdict(list)
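    # chunk_type values of the form "<base name> Part <n>" mark members of a
    # multi-part sequence; the regex splits them into base name and part number.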
    sequence_pattern = re.compile(r"^(.*)\s+Part\s+(\d+)$")
    
    try:
        driver = await neo4j_utils.get_driver()
        if not driver:
            logger.error("Cannot load chunk type data: Neo4j driver not available.")
            return False

        records, _, _ = await driver.execute_query(
            cypher_query, database_=settings.NEO4J_DATABASE or "neo4j"
        )
        
        for record in records:
            data = record.data()
            chunk_id, chunk_type, content = data.get("chunk_id"), data.get("chunk_type"), data.get("content")
            if chunk_id and chunk_type:
                chunk_type_map[chunk_id] = chunk_type
                if " Part " in chunk_type:
                    match = sequence_pattern.match(chunk_type)
                    if match:
                        base_name, part_number = match.group(1).strip(), int(match.group(2))
                        temp_groups[base_name].append({"id": chunk_id, "part": part_number, "type": chunk_type, "text": content})

        final_sequence_map = {base: sorted(parts, key=lambda x: x['part']) for base, parts in temp_groups.items()}

        state.chunk_type_map = chunk_type_map
        state.sequence_base_to_parts_map = final_sequence_map
        state.chunk_type_map_loaded = True
        state.sequence_map_loaded = True
        
        logger.info(f"Successfully built chunk type map for {len(chunk_type_map)} chunks and {len(final_sequence_map)} sequence groups.")
        return True
        
    except Neo4jError as e:
        logger.error(f"Neo4j error while building chunk maps: {e}")
        state.chunk_type_map, state.sequence_base_to_parts_map = {}, {}
        state.chunk_type_map_loaded, state.sequence_map_loaded = False, False
        return False
    except Exception as e:
        logger.exception(f"An unexpected error occurred building chunk maps from Neo4j: {e}")
        state.chunk_type_map, state.sequence_base_to_parts_map = {}, {}
        state.chunk_type_map_loaded, state.sequence_map_loaded = False, False
        return False


def _get_content_key(chunk_id: str, chunk_text: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
    """Creates a robust, unique key for a chunk based on its type and content."""
    chunk_type = state.chunk_type_map.get(chunk_id)
    normalized_type = chunk_type.lower() if chunk_type else None
    normalized_text = chunk_text.strip() if chunk_text else None
    return (normalized_type, normalized_text)


def organize_chunks_by_sequence(chunks: List[Dict]) -> List[Dict]:
    """
    Ensures re-ranker chunks are prioritized, then expands context by filling
    in missing sequential parts.
    """
    if not state.sequence_map_loaded or not chunks:
        return chunks

    # --- Step 1: Deduplicate re-ranker output to get the highest-scoring unique chunks ---
    deduplicated_chunks = []
    processed_content = set()
    for chunk in chunks:
        content_key = _get_content_key(chunk['id'], chunk.get('text'))
        if content_key not in processed_content:
            deduplicated_chunks.append(chunk)
            processed_content.add(content_key)
    
    # This map is now the "golden source" of high-priority chunks and their scores.
    final_chunks_map = {chunk['id']: chunk for chunk in deduplicated_chunks}

    # --- Step 2: Identify which sequences need expansion based on the golden source ---
    sequences_to_expand = {}
    for chunk in deduplicated_chunks:
        chunk_type = state.chunk_type_map.get(chunk.get("id"))
        if chunk_type and " Part " in chunk_type:
            for base_name, parts in state.sequence_base_to_parts_map.items():
                if any(part['id'] == chunk["id"] for part in parts):
                    # Store the highest score found for any part of this sequence
                    current_max_score = sequences_to_expand.get(base_name, -1.0)
                    sequences_to_expand[base_name] = max(current_max_score, chunk.get('rerank_score', 0.0))
                    break
    
    # --- Step 3: Expand sequences by filling in missing parts ---
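    # Only sequences whose best-scoring retrieved part clears the configured
    # SEQUENCE_EXPANSION_THRESHOLD are filled in; weaker matches stay untouched.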
    for base_name, trigger_score in sequences_to_expand.items():
        if trigger_score > settings.SEQUENCE_EXPANSION_THRESHOLD:
            logger.info(f"Expanding sequence '{base_name}' triggered by a chunk with score {trigger_score:.4f}")
            full_sequence_parts = state.sequence_base_to_parts_map.get(base_name, [])
            
            for part_info in full_sequence_parts:
                part_id = part_info['id']
                # Add the part ONLY if it's not already in our golden source map
                if part_id not in final_chunks_map:
                    final_chunks_map[part_id] = {
                        "id": part_id,
                        "text": part_info['text'],
                        "rerank_score": -1.0, # Mark as contextually added
                    }

    # --- Step 4: Convert map to list and perform the final sort ---
    final_chunks_list = list(final_chunks_map.values())
    
    # Create helper maps for an efficient final sort
    chunk_to_sequence_info = {}
    for base_name, parts in state.sequence_base_to_parts_map.items():
        for part_info in parts:
            chunk_to_sequence_info[part_info['id']] = {"base": base_name, "part": part_info['part']}

    # Get the max score for each sequence group
    sequence_group_scores = {}
    for chunk in final_chunks_list:
        seq_info = chunk_to_sequence_info.get(chunk['id'])
        if seq_info:
            base_name = seq_info['base']
            current_max = sequence_group_scores.get(base_name, -1.0)
            sequence_group_scores[base_name] = max(current_max, chunk.get('rerank_score', 0.0))

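    # Sequence members sort by their group's best score and then by part number,
    # keeping each sequence contiguous and in reading order; standalone chunks
    # sort by their own score.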
    def sort_key(chunk):
        seq_info = chunk_to_sequence_info.get(chunk['id'])
        if seq_info:
            # If sequential, sort by the group's max score, then by part number
            return (-sequence_group_scores.get(seq_info['base'], -1.0), seq_info['part'])
        else:
            # If not sequential, sort by its own score, with a tie-breaker
            return (-chunk.get('rerank_score', -1.0), 0)

    final_chunks_list.sort(key=sort_key)
            
    logger.info(f"Context expansion and deduplication complete. Final chunk count: {len(final_chunks_list)}.")
    return final_chunks_list
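

# --- Illustrative usage (a sketch, not part of the service API) ---
# The ids, texts and scores below are made up; in the app the chunk dicts come
# from the re-ranking step.
#
#   await load_chunk_type_map()
#   reranked = [
#       {"id": "chunk-a", "text": "...", "rerank_score": 0.91},
#       {"id": "chunk-b", "text": "...", "rerank_score": 0.42},
#   ]
#   ordered = organize_chunks_by_sequence(reranked)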