File size: 7,140 Bytes
d60bab3
 
 
 
b6ee133
d60bab3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
efd3c5f
 
d60bab3
efd3c5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d60bab3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import re
from typing import List, Dict, Any, Union
import ast
from langchain_core.messages import SystemMessage, HumanMessage


# ---------------------------------------------------------------------
# Core Processing Functions
# ---------------------------------------------------------------------
def _parse_citations(response: str) -> List[int]:
    """Collect the distinct ``[N]`` citation numbers in *response*, ascending.

    Only markers made of a bracketed run of digits (e.g. ``[3]``) are
    recognised; duplicates are collapsed.
    """
    seen = set()
    for hit in re.finditer(r'\[(\d+)\]', response):
        seen.add(int(hit.group(1)))

    ordered = list(seen)
    ordered.sort()
    return ordered

def _extract_sources(processed_results: List[Dict[str, Any]], cited_numbers: List[int]) -> List[Dict[str, Any]]:
    """Return copies of the results that the response actually cited.

    Citation numbers are 1-based positions into *processed_results*; numbers
    that fall outside the list are skipped. Each returned dict is a shallow
    copy of the matching result with the citation number stored under
    ``'_citation_number'``, so the input dicts are not modified.
    """
    if not cited_numbers:
        return []

    total = len(processed_results)
    picked = []
    for number in cited_numbers:
        position = number - 1  # citations are 1-based, list positions 0-based
        if not (0 <= position < total):
            continue

        entry = processed_results[position].copy()
        entry['_citation_number'] = number
        picked.append(entry)

    return picked

def clean_citations(response: str) -> str:
    """Normalize document citations to ``[N]`` and strip reference sections.

    Processing steps, in order:

    1. Everything from a ``References`` / ``Sources`` / ``Bibliography``
       heading (markdown ``#`` heading, or plain word followed by ``:``) to
       the end of the text is removed.
    2. ``(Document X[, Page Y][, Year Z])``, ``[Document X ...]``,
       ``[X.Y.Z]`` and unbracketed ``Document X[, Page Y][, Year Z]`` are
       rewritten to ``[X]``; ``[[X]]`` is reduced to ``[X]``.
    3. Every run of whitespace (including newlines) is collapsed to a single
       space and the result is stripped.

    Args:
        response: Raw response text.

    Returns:
        The cleaned text with citations in ``[N]`` form.
    """
    # DOTALL lets ``.*$`` run across newlines, so the heading and all text
    # after it are dropped.
    section_patterns = (
        r'\n\s*#+\s*References?\s*:?.*$',
        r'\n\s*#+\s*Sources?\s*:?.*$',
        r'\n\s*#+\s*Bibliography\s*:?.*$',
        r'\n\s*References?\s*:.*$',
        r'\n\s*Sources?\s*:.*$',
        r'\n\s*Bibliography\s*:.*$',
    )
    for pattern in section_patterns:
        response = re.sub(pattern, '', response, flags=re.IGNORECASE | re.DOTALL)

    # Ordered (pattern, flags) rules; each one rewrites its match to ``[X]``.
    # Order matters: bracketed/parenthesised forms must be handled before the
    # unbracketed rule, otherwise the latter would rewrite their interior.
    citation_rules = (
        # (Document X) / (Document X, Page Y) / (Document X, Page Y, Year Z)
        (r'\(Document\s+(\d+)(?:,\s*Page\s+\d+)?(?:,\s*(?:Year\s+)?\d+)?\)', re.IGNORECASE),
        # [Document X], [Document X, Page Y, ...], [Document X: filename, ...]
        (r'\[Document\s+(\d+)[^\]]*\]', re.IGNORECASE),
        # [X.Y.Z] section-number style
        (r'\[(\d+)\.[\d\.]+\]', 0),
        # Unbracketed "Document X, Page Y, Year Z". The lookahead also accepts
        # end-of-text so a citation that closes the response is normalized.
        (
            r'Document\s+(\d+)(?:,\s*Page\s+\d+)?(?:,\s*(?:Year\s+)?\d+)?(?=\s|[,.]|$)',
            re.IGNORECASE,
        ),
        # [[X]] left over from nested brackets
        (r'\[\[(\d+)\]\]', 0),
    )
    for pattern, flags in citation_rules:
        response = re.sub(pattern, r'[\1]', response, flags=flags)

    # Collapse every whitespace run to one space.
    # NOTE(review): ``\s+`` matches newlines too, so line breaks and markdown
    # structure are flattened; kept as-is for backward compatibility.
    response = re.sub(r'\s+', ' ', response)

    return response.strip()

def _process_context(context: Union[str, List[Dict[str, Any]]]) -> tuple[str, List[Dict[str, Any]]]:
    """Turn raw context into a prompt-ready string plus normalized results.

    A string context is passed through unchanged with an empty result list.
    A list context is normalized entry by entry into dicts with the keys
    ``answer``, ``filename``, ``page``, ``year``, ``source`` and
    ``document_id``, and rendered as numbered ``[1]``, ``[2]``, ... sections.

    Raises:
        ValueError: If the list is empty, the string is blank, or *context*
            is neither a string nor a list.
    """
    if isinstance(context, str):
        if not context.strip():
            raise ValueError("Context cannot be empty")
        return context, []

    if not isinstance(context, list):
        raise ValueError("Context must be either a string or list of retrieval results")

    if not context:
        raise ValueError("No retrieval results provided")

    shared_fields = ('filename', 'page', 'year', 'source')
    documents = []
    for entry in context:
        # Entries may arrive as the string form of a Python literal.
        if isinstance(entry, str):
            entry = ast.literal_eval(entry)

        nested = entry.get('answer_metadata', {})
        # Nested metadata is used only when it holds at least one value that
        # is neither None nor 'Unknown'; otherwise fall back to the top level.
        use_nested = bool(nested) and not all(
            value is None or value == 'Unknown' for value in nested.values()
        )

        if use_nested:
            # Retrieved documents: metadata lives under 'answer_metadata'.
            origin = nested
            answer_text = entry.get('answer', '')
            doc_id = nested.get('_id', 'Unknown')
        else:
            # Ingested files: metadata sits at the top level of the entry.
            origin = entry
            answer_text = entry.get('answer', entry.get('content', ''))
            doc_id = entry.get('_id', entry.get('document_id', 'Unknown'))

        record = {'answer': answer_text}
        for field in shared_fields:
            record[field] = origin.get(field, 'Unknown')
        record['document_id'] = doc_id
        documents.append(record)

    # Each section is just its 1-based number followed by the answer text.
    rendered = "\n".join(
        f"[{number}]\n{doc['answer']}\n" for number, doc in enumerate(documents, 1)
    )
    return rendered, documents

def _build_messages(system_prompt: str, question: str, context: str) -> list:
    """Assemble the two-message list for the LLM call.

    The system message carries *system_prompt* unchanged; the human message
    places *context* under a ``### CONTEXT`` heading followed by *question*
    under ``### USER QUESTION``.
    """
    system_message = SystemMessage(content=system_prompt)
    human_message = HumanMessage(
        content=f"### CONTEXT\n{context}\n\n### USER QUESTION\n{question}"
    )
    return [system_message, human_message]

def _create_sources_list(cited_sources: List[Dict[str, Any]]) -> List[Dict[str, str]]:
    """Build ``{"link", "title"}`` entries for the cited sources.

    The link is ``doc://<filename>``. The title starts with the filename and
    appends ``Page <page>`` and ``(<year>)`` when those values are not
    ``'Unknown'``, joined by ``" - "``. Missing keys default to ``'Unknown'``.
    """
    entries = []
    for source in cited_sources:
        name = source.get('filename', 'Unknown')
        page = source.get('page', 'Unknown')
        year = source.get('year', 'Unknown')

        # Optional title segments are added only when the value is known.
        title_bits = [name]
        title_bits += [f"Page {page}"] if page != 'Unknown' else []
        title_bits += [f"({year})"] if year != 'Unknown' else []

        entries.append({"link": f"doc://{name}", "title": " - ".join(title_bits)})

    return entries