# app/services/parts_combination_service.py
import logging
import re
from typing import List, Dict, Optional, Tuple
from collections import defaultdict
# --- Core App Imports ---
from app.core import state
from app.utils import neo4j_utils
from app.core.config import settings
from neo4j.exceptions import Neo4jError
logger = logging.getLogger(__name__)
async def load_chunk_type_map() -> bool:
    """
    Connects to Neo4j and populates two global maps on `state`:
    `chunk_type_map` (chunk id -> chunk type) and `sequence_base_to_parts_map`
    (sequence base name -> list of its parts, sorted by part number).
    """
if state.sequence_map_loaded and state.chunk_type_map_loaded:
return True
logger.info("Loading all chunk types and building sequence map from Neo4j...")
cypher_query = """
MATCH (c:Chunk)
WHERE c.chunk_type IS NOT NULL
RETURN c.id AS chunk_id, c.chunk_type AS chunk_type, c.content AS content
"""
chunk_type_map = {}
temp_groups = defaultdict(list)
sequence_pattern = re.compile(r"^(.*)\s+Part\s+(\d+)$")
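    # Illustrative matches for the pattern (hypothetical chunk types):
    #   "Installation Guide Part 2" -> base "Installation Guide", part 2
    #   "Troubleshooting Part 10"   -> base "Troubleshooting", part 10
    #   "Safety Notes"              -> no match; treated as a standalone chunk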
try:
driver = await neo4j_utils.get_driver()
if not driver:
logger.error("Cannot load chunk type data: Neo4j driver not available.")
return False
records, _, _ = await driver.execute_query(
cypher_query, database_=settings.NEO4J_DATABASE or "neo4j"
)
for record in records:
data = record.data()
chunk_id, chunk_type, content = data.get("chunk_id"), data.get("chunk_type"), data.get("content")
if chunk_id and chunk_type:
chunk_type_map[chunk_id] = chunk_type
if " Part " in chunk_type:
match = sequence_pattern.match(chunk_type)
if match:
base_name, part_number = match.group(1).strip(), int(match.group(2))
temp_groups[base_name].append({"id": chunk_id, "part": part_number, "type": chunk_type, "text": content})
final_sequence_map = {base: sorted(parts, key=lambda x: x['part']) for base, parts in temp_groups.items()}
state.chunk_type_map = chunk_type_map
state.sequence_base_to_parts_map = final_sequence_map
state.chunk_type_map_loaded = True
state.sequence_map_loaded = True
logger.info(f"Successfully built chunk type map for {len(chunk_type_map)} chunks and {len(final_sequence_map)} sequence groups.")
return True
    except Neo4jError as e:
        logger.exception(f"Neo4j query failed while building chunk maps: {e}")
        state.chunk_type_map, state.sequence_base_to_parts_map = {}, {}
        state.chunk_type_map_loaded, state.sequence_map_loaded = False, False
        return False
    except Exception as e:
        logger.exception(f"An unexpected error occurred building chunk maps from Neo4j: {e}")
        state.chunk_type_map, state.sequence_base_to_parts_map = {}, {}
        state.chunk_type_map_loaded, state.sequence_map_loaded = False, False
        return False
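# Illustrative shapes of the structures the loader above populates (ids, types,
# and text are hypothetical; the real values come from Neo4j):
#
#   state.chunk_type_map = {
#       "chunk-001": "Installation Guide Part 1",
#       "chunk-002": "Installation Guide Part 2",
#       "chunk-003": "Safety Notes",
#   }
#   state.sequence_base_to_parts_map = {
#       "Installation Guide": [
#           {"id": "chunk-001", "part": 1, "type": "Installation Guide Part 1", "text": "..."},
#           {"id": "chunk-002", "part": 2, "type": "Installation Guide Part 2", "text": "..."},
#       ],
#   }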
def _get_content_key(chunk_id: str, chunk_text: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
    """Creates a robust, unique key for a chunk from its normalized type and content."""
    chunk_type = state.chunk_type_map.get(chunk_id)
    normalized_type = chunk_type.lower() if chunk_type else None
    normalized_text = chunk_text.strip() if chunk_text else None
    return (normalized_type, normalized_text)
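# Illustrative only: assuming both ids map to the chunk type
# "Installation Guide Part 1", these calls produce the same key, so the second
# chunk would be dropped during deduplication below:
#   _get_content_key("chunk-001", "  Tighten the bolts. ")
#   _get_content_key("chunk-007", "Tighten the bolts.")
#   -> ("installation guide part 1", "Tighten the bolts.")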
def organize_chunks_by_sequence(chunks: List[Dict]) -> List[Dict]:
    """
    Deduplicates re-ranked chunks, then expands context: any sequence whose
    best-scoring part exceeds SEQUENCE_EXPANSION_THRESHOLD has its missing
    parts filled in, and the final sort keeps each sequence together in order.
    """
if not state.sequence_map_loaded or not chunks:
return chunks
# --- Step 1: Deduplicate re-ranker output to get the highest-scoring unique chunks ---
deduplicated_chunks = []
processed_content = set()
for chunk in chunks:
content_key = _get_content_key(chunk['id'], chunk.get('text'))
if content_key not in processed_content:
deduplicated_chunks.append(chunk)
processed_content.add(content_key)
# This map is now the "golden source" of high-priority chunks and their scores.
final_chunks_map = {chunk['id']: chunk for chunk in deduplicated_chunks}
# --- Step 2: Identify which sequences need expansion based on the golden source ---
sequences_to_expand = {}
for chunk in deduplicated_chunks:
chunk_type = state.chunk_type_map.get(chunk.get("id"))
if chunk_type and " Part " in chunk_type:
for base_name, parts in state.sequence_base_to_parts_map.items():
if any(part['id'] == chunk["id"] for part in parts):
# Store the highest score found for any part of this sequence
current_max_score = sequences_to_expand.get(base_name, -1.0)
sequences_to_expand[base_name] = max(current_max_score, chunk.get('rerank_score', 0.0))
break
# --- Step 3: Expand sequences by filling in missing parts ---
for base_name, trigger_score in sequences_to_expand.items():
if trigger_score > settings.SEQUENCE_EXPANSION_THRESHOLD:
logger.info(f"Expanding sequence '{base_name}' triggered by a chunk with score {trigger_score:.4f}")
full_sequence_parts = state.sequence_base_to_parts_map.get(base_name, [])
for part_info in full_sequence_parts:
part_id = part_info['id']
# Add the part ONLY if it's not already in our golden source map
if part_id not in final_chunks_map:
final_chunks_map[part_id] = {
"id": part_id,
"text": part_info['text'],
"rerank_score": -1.0, # Mark as contextually added
}
# --- Step 4: Convert map to list and perform the final sort ---
final_chunks_list = list(final_chunks_map.values())
# Create helper maps for an efficient final sort
chunk_to_sequence_info = {}
for base_name, parts in state.sequence_base_to_parts_map.items():
for part_info in parts:
chunk_to_sequence_info[part_info['id']] = {"base": base_name, "part": part_info['part']}
# Get the max score for each sequence group
sequence_group_scores = {}
for chunk in final_chunks_list:
seq_info = chunk_to_sequence_info.get(chunk['id'])
if seq_info:
base_name = seq_info['base']
current_max = sequence_group_scores.get(base_name, -1.0)
sequence_group_scores[base_name] = max(current_max, chunk.get('rerank_score', 0.0))
def sort_key(chunk):
seq_info = chunk_to_sequence_info.get(chunk['id'])
if seq_info:
# If sequential, sort by the group's max score, then by part number
return (-sequence_group_scores.get(seq_info['base'], -1.0), seq_info['part'])
else:
# If not sequential, sort by its own score, with a tie-breaker
return (-chunk.get('rerank_score', -1.0), 0)
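    # For example (hypothetical scores): if a sequence's best part scored 0.91,
    # its parts get keys (-0.91, 1), (-0.91, 2) and sort together in part
    # order; a standalone chunk at 0.40 gets (-0.40, 0) and lands after them.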
final_chunks_list.sort(key=sort_key)
logger.info(f"Context expansion and deduplication complete. Final chunk count: {len(final_chunks_list)}.")
return final_chunks_list
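if __name__ == "__main__":
    # Minimal offline demo of the expansion and sort logic (hypothetical ids,
    # types, and scores; no Neo4j needed). The state maps are stubbed directly
    # instead of calling load_chunk_type_map(), and the expected order assumes
    # SEQUENCE_EXPANSION_THRESHOLD is below 0.91.
    state.chunk_type_map = {
        "chunk-001": "Installation Guide Part 1",
        "chunk-002": "Installation Guide Part 2",
        "chunk-003": "Safety Notes",
    }
    state.sequence_base_to_parts_map = {
        "Installation Guide": [
            {"id": "chunk-001", "part": 1, "type": "Installation Guide Part 1", "text": "Step one..."},
            {"id": "chunk-002", "part": 2, "type": "Installation Guide Part 2", "text": "Step two..."},
        ],
    }
    state.chunk_type_map_loaded = state.sequence_map_loaded = True
    reranked = [
        {"id": "chunk-002", "text": "Step two...", "rerank_score": 0.91},
        {"id": "chunk-003", "text": "Wear gloves.", "rerank_score": 0.40},
    ]
    for c in organize_chunks_by_sequence(reranked):
        print(c["id"], c["rerank_score"])
    # Expected: chunk-001 -1.0 (contextually added), chunk-002 0.91, chunk-003 0.40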