# app/services/parts_combination_service.py
import logging
import re
from collections import defaultdict
from typing import Dict, List, Tuple

from neo4j.exceptions import Neo4jError

# --- Core App Imports ---
from app.core import state
from app.core.config import settings
from app.utils import neo4j_utils

logger = logging.getLogger(__name__)

async def load_chunk_type_map():
    """
    Connects to Neo4j and builds maps for chunk types and sequences.
    """
    if state.sequence_map_loaded and state.chunk_type_map_loaded:
        return True
    logger.info("Loading all chunk types and building sequence map from Neo4j...")
    cypher_query = """
    MATCH (c:Chunk)
    WHERE c.chunk_type IS NOT NULL
    RETURN c.id AS chunk_id, c.chunk_type AS chunk_type, c.content AS content
    """
    chunk_type_map = {}
    temp_groups = defaultdict(list)
    sequence_pattern = re.compile(r"^(.*)\s+Part\s+(\d+)$")
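    # e.g. a chunk_type of "Troubleshooting Guide Part 3" (hypothetical value)
    # matches with base_name "Troubleshooting Guide" and part_number 3.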
    try:
        driver = await neo4j_utils.get_driver()
        if not driver:
            logger.error("Cannot load chunk type data: Neo4j driver not available.")
            return False
        records, _, _ = await driver.execute_query(
            cypher_query, database_=settings.NEO4J_DATABASE or "neo4j"
        )
        for record in records:
            data = record.data()
            chunk_id, chunk_type, content = (
                data.get("chunk_id"),
                data.get("chunk_type"),
                data.get("content"),
            )
            if chunk_id and chunk_type:
                chunk_type_map[chunk_id] = chunk_type
                if " Part " in chunk_type:
                    match = sequence_pattern.match(chunk_type)
                    if match:
                        base_name, part_number = match.group(1).strip(), int(match.group(2))
                        temp_groups[base_name].append(
                            {"id": chunk_id, "part": part_number, "type": chunk_type, "text": content}
                        )
        final_sequence_map = {
            base: sorted(parts, key=lambda x: x["part"])
            for base, parts in temp_groups.items()
        }
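        # final_sequence_map now maps each base name to its parts in ascending
        # part order, e.g. {"Setup Guide": [<part 1>, <part 2>, ...]}
        # (names illustrative).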
        state.chunk_type_map = chunk_type_map
        state.sequence_base_to_parts_map = final_sequence_map
        state.chunk_type_map_loaded = True
        state.sequence_map_loaded = True
        logger.info(
            f"Successfully built chunk type map for {len(chunk_type_map)} chunks "
            f"and {len(final_sequence_map)} sequence groups."
        )
        return True
    except Neo4jError as e:
        logger.exception(f"A Neo4j error occurred building chunk maps: {e}")
    except Exception as e:
        logger.exception(f"An unexpected error occurred building chunk maps from Neo4j: {e}")
    state.chunk_type_map, state.sequence_base_to_parts_map = {}, {}
    state.chunk_type_map_loaded, state.sequence_map_loaded = False, False
    return False


def _get_content_key(chunk_id: str, chunk_text: str) -> Tuple:
    """Creates a robust, unique key for a chunk based on its type and content."""
    chunk_type = state.chunk_type_map.get(chunk_id)
    normalized_type = chunk_type.lower() if chunk_type else None
    normalized_text = chunk_text.strip() if chunk_text else None
    return (normalized_type, normalized_text)
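# Chunks compare equal under this key when they share a (case-insensitively
# normalized) type and identical stripped text, regardless of chunk ID; this
# is what drives the deduplication in organize_chunks_by_sequence below.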


def organize_chunks_by_sequence(chunks: List[Dict]) -> List[Dict]:
    """
    Ensures re-ranker chunks are prioritized, then expands context by filling
    in missing sequential parts.
    """
    if not state.sequence_map_loaded or not chunks:
        return chunks

    # --- Step 1: Deduplicate re-ranker output to get the highest-scoring unique chunks ---
    deduplicated_chunks = []
    processed_content = set()
    for chunk in chunks:
        content_key = _get_content_key(chunk["id"], chunk.get("text"))
        if content_key not in processed_content:
            deduplicated_chunks.append(chunk)
            processed_content.add(content_key)

    # This map is now the "golden source" of high-priority chunks and their scores.
    final_chunks_map = {chunk["id"]: chunk for chunk in deduplicated_chunks}

    # --- Step 2: Identify which sequences need expansion based on the golden source ---
    sequences_to_expand = {}
    for chunk in deduplicated_chunks:
        chunk_type = state.chunk_type_map.get(chunk.get("id"))
        if chunk_type and " Part " in chunk_type:
            for base_name, parts in state.sequence_base_to_parts_map.items():
                if any(part["id"] == chunk["id"] for part in parts):
                    # Store the highest score found for any part of this sequence
                    current_max_score = sequences_to_expand.get(base_name, -1.0)
                    sequences_to_expand[base_name] = max(current_max_score, chunk.get("rerank_score", 0.0))
                    break
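    # sequences_to_expand now maps each base name to the best rerank score
    # seen among its retrieved parts; that score gates the expansion below.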

    # --- Step 3: Expand sequences by filling in missing parts ---
    for base_name, trigger_score in sequences_to_expand.items():
        if trigger_score > settings.SEQUENCE_EXPANSION_THRESHOLD:
            logger.info(f"Expanding sequence '{base_name}' triggered by a chunk with score {trigger_score:.4f}")
            full_sequence_parts = state.sequence_base_to_parts_map.get(base_name, [])
            for part_info in full_sequence_parts:
                part_id = part_info["id"]
                # Add the part ONLY if it's not already in our golden source map
                if part_id not in final_chunks_map:
                    final_chunks_map[part_id] = {
                        "id": part_id,
                        "text": part_info["text"],
                        "rerank_score": -1.0,  # Mark as contextually added
                    }
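    # The -1.0 placeholder never buries filled-in parts: in the final sort,
    # sequence members are ranked by their group's best score, so added parts
    # ride along with the reranked chunk that triggered the expansion.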

    # --- Step 4: Convert map to list and perform the final sort ---
    final_chunks_list = list(final_chunks_map.values())

    # Create helper maps for an efficient final sort
    chunk_to_sequence_info = {}
    for base_name, parts in state.sequence_base_to_parts_map.items():
        for part_info in parts:
            chunk_to_sequence_info[part_info["id"]] = {"base": base_name, "part": part_info["part"]}

    # Get the max score for each sequence group
    sequence_group_scores = {}
    for chunk in final_chunks_list:
        seq_info = chunk_to_sequence_info.get(chunk["id"])
        if seq_info:
            base_name = seq_info["base"]
            current_max = sequence_group_scores.get(base_name, -1.0)
            sequence_group_scores[base_name] = max(current_max, chunk.get("rerank_score", 0.0))

    def sort_key(chunk):
        seq_info = chunk_to_sequence_info.get(chunk["id"])
        if seq_info:
            # If sequential, sort by the group's max score, then by part number
            return (-sequence_group_scores.get(seq_info["base"], -1.0), seq_info["part"])
        # If not sequential, sort by the chunk's own score, with a tie-breaker
        return (-chunk.get("rerank_score", -1.0), 0)
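    # Net effect: sequence groups are ordered among other chunks by their best
    # rerank score and read in part order internally, while standalone chunks
    # are interleaved by their own scores.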
    final_chunks_list.sort(key=sort_key)
    logger.info(f"Context expansion and deduplication complete. Final chunk count: {len(final_chunks_list)}.")
    return final_chunks_list
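

# --- Usage sketch (illustrative only) ---
# A minimal example of how this service might be wired into a retrieval
# pipeline. The reranker output shape ({"id", "text", "rerank_score"}) is
# inferred from organize_chunks_by_sequence above; the function and variable
# names here (build_context, rerank_results) are hypothetical:
#
#     from app.services import parts_combination_service as pcs
#
#     async def build_context(rerank_results: List[Dict]) -> List[Dict]:
#         # Make sure the chunk-type and sequence maps are cached in app state.
#         if not await pcs.load_chunk_type_map():
#             return rerank_results  # fall back to raw re-ranker ordering
#         return pcs.organize_chunks_by_sequence(rerank_results)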