julian-schelb committed on
Commit
dc017a9
·
verified ·
1 Parent(s): 4b23b4b

Create results_page.py

Browse files
Files changed (1) hide show
  1. results_page.py +362 -0
results_page.py ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Results stage for the Loci Similes GUI."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import csv
6
+ import io
7
+ import re
8
+ from typing import TYPE_CHECKING
9
+
10
+ try:
11
+ import gradio as gr
12
+ except ImportError as exc:
13
+ missing = getattr(exc, "name", None)
14
+ base_msg = (
15
+ "Optional GUI dependencies are missing. Install them via "
16
+ "'pip install locisimiles[gui]' (Python 3.13+ also requires the "
17
+ "audioop-lts backport) to use the Gradio interface."
18
+ )
19
+ if missing and missing != "gradio":
20
+ raise ImportError(f"{base_msg} (missing package: {missing})") from exc
21
+ raise ImportError(base_msg) from exc
22
+
23
+ if TYPE_CHECKING:
24
+ from locisimiles.document import Document, TextSegment
25
+
26
+ import tempfile
27
+ from typing import Any, Dict, List, Tuple
28
+
29
+ try:
30
+ import gradio as gr
31
+ except ImportError as exc:
32
+ raise ImportError("Gradio is required for results page") from exc
33
+
34
+ from locisimiles.document import Document, TextSegment
35
+
36
+ # Type aliases from pipeline
37
+ FullDict = Dict[str, List[Tuple[TextSegment, float, float]]]
38
+
39
+
40
def update_results_display(results: FullDict | None, query_doc: Document | None, threshold: float = 0.5) -> tuple[dict, dict, dict]:
    """Refresh the results view from a new set of pipeline results.

    The conversion to display rows is delegated to
    ``_convert_results_to_display``; this function only packages its output
    for the three Gradio outputs.

    Args:
        results: Pipeline results, or None
        query_doc: Query document, or None
        threshold: Classification probability threshold for counting finds

    Returns:
        Tuple of (query_segments_update, query_segments_state, matches_dict_state).
        The first element is a ``gr.update`` carrying the segment rows, the
        second is the same rows unwrapped, the third is the matches mapping.
    """
    display_rows, display_matches = _convert_results_to_display(results, query_doc, threshold)

    # The same rows go to the visible dataframe (wrapped in an update) and to
    # the state component (raw).
    dataframe_update = gr.update(value=display_rows)
    return dataframe_update, display_rows, display_matches
58
+
59
+
60
+ def _format_metric_with_bar(value: float, is_above_threshold: bool = False) -> str:
61
+ """Format a metric value with a visual progress bar.
62
+
63
+ Args:
64
+ value: Metric value between 0 and 1
65
+ is_above_threshold: Whether to highlight this value
66
+
67
+ Returns:
68
+ HTML string with progress bar
69
+ """
70
+ percentage = int(value * 100)
71
+
72
+ # Choose color based on threshold
73
+ if is_above_threshold:
74
+ bar_color = "#6B9BD1" # Blue accent for findings
75
+ bg_color = "#E3F2FD" # Light blue background
76
+ else:
77
+ bar_color = "#B0B0B0" # Gray for below threshold
78
+ bg_color = "#F5F5F5" # Light gray background
79
+
80
+ html = f'''
81
+ <div style="display: flex; align-items: center; gap: 8px; width: 100%;">
82
+ <div style="flex: 1; background-color: {bg_color}; border-radius: 4px; overflow: hidden; height: 20px; position: relative;">
83
+ <div style="background-color: {bar_color}; width: {percentage}%; height: 100%; transition: width 0.3s;"></div>
84
+ </div>
85
+ <span style="min-width: 45px; text-align: right; font-weight: {'bold' if is_above_threshold else 'normal'};">{value:.3f}</span>
86
+ </div>
87
+ '''
88
+ return html
89
+
90
+
91
+ def _convert_results_to_display(results: FullDict | None, query_doc: Document | None, threshold: float = 0.5) -> tuple[list[list], dict]:
92
+ """Convert pipeline results to display format.
93
+
94
+ Args:
95
+ results: Pipeline results (FullDict format)
96
+ query_doc: Query document
97
+ threshold: Classification probability threshold for counting finds
98
+
99
+ Returns:
100
+ Tuple of (query_segments_list, matches_dict)
101
+ """
102
+ if results is None or query_doc is None:
103
+ # Return empty data if no results
104
+ return [], {}
105
+
106
+ # First pass: Create raw matches dictionary and count finds
107
+ raw_matches = {}
108
+ find_counts = {}
109
+
110
+ for query_id, match_list in results.items():
111
+ # Sort by probability (descending) to show most likely matches first
112
+ sorted_matches = sorted(match_list, key=lambda x: x[2], reverse=True) # x[2] is probability
113
+
114
+ # Store raw numeric values
115
+ raw_matches[query_id] = sorted_matches
116
+
117
+ # Count finds above threshold
118
+ find_counts[query_id] = sum(1 for _, _, prob in sorted_matches if prob >= threshold)
119
+
120
+ # Convert query document to list format with find counts
121
+ # Document is iterable and returns TextSegments in order
122
+ query_segments = []
123
+ for segment in query_doc:
124
+ find_count = find_counts.get(segment.id, 0)
125
+ query_segments.append([segment.id, segment.text, find_count])
126
+
127
+ # Second pass: Format matches with HTML progress bars
128
+ matches_dict = {}
129
+ for query_id, match_list in raw_matches.items():
130
+ matches_dict[query_id] = [
131
+ [
132
+ source_seg.id,
133
+ source_seg.text,
134
+ _format_metric_with_bar(round(similarity, 3), probability >= threshold),
135
+ _format_metric_with_bar(round(probability, 3), probability >= threshold)
136
+ ]
137
+ for source_seg, similarity, probability in match_list
138
+ ]
139
+
140
+ return query_segments, matches_dict
141
+
142
+
143
def _on_query_select(evt: gr.SelectData, query_segments: list, matches_dict: dict) -> tuple[dict, dict]:
    """React to a click in the query table by showing that segment's matches.

    Only the first entry of ``evt.index`` is used; it is treated as the row
    number of the clicked cell.

    Args:
        evt: Selection event data
        query_segments: List of query segments
        matches_dict: Dictionary mapping query IDs to matches

    Returns:
        A tuple of (prompt_visibility_update, dataframe_update_with_data)
    """
    selection = evt.index
    has_row = selection is not None and len(selection) >= 1

    if not has_row or selection[0] >= len(query_segments):
        # Nothing usable was selected: keep the prompt, keep the table hidden
        return gr.update(visible=True), gr.update(visible=False)

    # First column of the selected row is the segment ID
    selected_id = query_segments[selection[0]][0]
    match_rows = matches_dict.get(selected_id, [])

    hide_prompt = gr.update(visible=False)
    show_matches = gr.update(value=match_rows, visible=True)
    return hide_prompt, show_matches
168
+
169
+
170
+ def _extract_numeric_from_html(html_str: str) -> float:
171
+ """Extract numeric value from HTML formatted metric string.
172
+
173
+ Args:
174
+ html_str: HTML string with embedded numeric value
175
+
176
+ Returns:
177
+ Extracted numeric value
178
+ """
179
+ import re
180
+ # Extract the number from the span at the end: <span ...>0.XXX</span>
181
+ match = re.search(r'<span[^>]*>([\d.]+)</span>', html_str)
182
+ if match:
183
+ return float(match.group(1))
184
+ # Fallback: if it's already a number
185
+ try:
186
+ return float(html_str)
187
+ except (ValueError, TypeError):
188
+ return 0.0
189
+
190
+
191
+ def _export_results_to_csv(query_segments: list, matches_dict: dict, threshold: float) -> str:
192
+ """Export results to a CSV file.
193
+
194
+ Args:
195
+ query_segments: List of query segments with find counts
196
+ matches_dict: Dictionary mapping query IDs to matches
197
+ threshold: Classification probability threshold
198
+
199
+ Returns:
200
+ Path to the temporary CSV file
201
+ """
202
+ # Create a temporary file
203
+ temp_file = tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.csv', newline='', encoding='utf-8')
204
+
205
+ with temp_file as f:
206
+ writer = csv.writer(f)
207
+
208
+ # Write header
209
+ writer.writerow([
210
+ "Query_Segment_ID",
211
+ "Query_Text",
212
+ "Source_Segment_ID",
213
+ "Source_Text",
214
+ "Similarity",
215
+ "Probability",
216
+ "Above_Threshold"
217
+ ])
218
+
219
+ # Write data for each query segment
220
+ for query_row in query_segments:
221
+ query_id = query_row[0]
222
+ query_text = query_row[1]
223
+
224
+ # Get matches for this query segment
225
+ matches = matches_dict.get(query_id, [])
226
+
227
+ if matches:
228
+ for match in matches:
229
+ source_id = match[0]
230
+ source_text = match[1]
231
+ # Extract numeric values from HTML formatted strings
232
+ similarity = _extract_numeric_from_html(match[2]) if isinstance(match[2], str) else match[2]
233
+ probability = _extract_numeric_from_html(match[3]) if isinstance(match[3], str) else match[3]
234
+ above_threshold = "Yes" if probability >= threshold else "No"
235
+
236
+ writer.writerow([
237
+ query_id,
238
+ query_text,
239
+ source_id,
240
+ source_text,
241
+ similarity,
242
+ probability,
243
+ above_threshold
244
+ ])
245
+ else:
246
+ # Write row even if no matches
247
+ writer.writerow([
248
+ query_id,
249
+ query_text,
250
+ "",
251
+ "",
252
+ "",
253
+ "",
254
+ ""
255
+ ])
256
+
257
+ return temp_file.name
258
+
259
+
260
def build_results_stage() -> tuple[gr.Step, dict[str, Any]]:
    """Construct the "Results" step of the walkthrough.

    The step contains a download button, a table of query segments on the
    left, and on the right either a selection prompt (visible at first) or
    a table of candidate source segments (hidden at first), followed by a
    "Start Over" button.

    Returns:
        A tuple of (Step component, components_dict) where components_dict
        maps names to the interactive components that callers need to wire
        up later.
    """
    intro_text = (
        "Select a query segment on the left to view potential intertextual references from the source document. "
        "Similarity measures the cosine similarity between embeddings (0-1, higher = more similar). "
        "Probability is the classifier's confidence that the pair represents an intertextual reference (0-1, higher = more likely)."
    )
    query_headers = ["Segment ID", "Text", "Finds"]
    match_headers = ["Source ID", "Source Text", "Similarity", "Probability"]

    with gr.Step("Results", id=2) as step:
        # Per-session copies of the data currently shown in the tables
        query_segments_state = gr.State(value=[])
        matches_dict_state = gr.State(value={})

        gr.Markdown("### 📊 Step 3: View Results")
        gr.Markdown(intro_text)

        # Download button
        with gr.Row():
            download_btn = gr.DownloadButton("Download Results as CSV", variant="primary")

        with gr.Row():
            # Left column: query segments
            with gr.Column(scale=1):
                gr.Markdown("### Query Document Segments")
                query_segments = gr.Dataframe(
                    value=[],
                    headers=query_headers,
                    interactive=False,
                    show_label=False,
                    label="Query Document Segments",
                    wrap=True,
                    max_height=600,
                    col_count=(3, "fixed"),
                )

            # Right column: matching source segments
            with gr.Column(scale=1):
                gr.Markdown("### Potential Intertextual References")

                # Placeholder displayed until a query segment is selected
                selection_prompt = gr.Markdown(
                    """
                    <div style="display: flex; align-items: center; justify-content: center; height: 400px; font-size: 18px; color: #666;">
                        <div style="text-align: center;">
                            <div style="font-size: 48px; margin-bottom: 20px;">←</div>
                            <div>Select a query segment to view</div>
                            <div>potential intertextual references</div>
                        </div>
                    </div>
                    """,
                    visible=True
                )

                # Match table, hidden until a selection is made; the two
                # metric columns are declared as HTML so the bars render
                source_matches = gr.Dataframe(
                    headers=match_headers,
                    interactive=False,
                    show_label=False,
                    label="Potential Intertextual References from Source Document",
                    wrap=True,
                    max_height=600,
                    visible=False,
                    datatype=["str", "str", "html", "html"],
                )

        with gr.Row():
            restart_btn = gr.Button("← Start Over", size="lg")

    # Everything a caller needs in order to attach handlers or push data
    components = dict(
        query_segments=query_segments,
        query_segments_state=query_segments_state,
        matches_dict_state=matches_dict_state,
        source_matches=source_matches,
        selection_prompt=selection_prompt,
        download_btn=download_btn,
        restart_btn=restart_btn,
    )
    return step, components
342
+
343
+
344
def setup_results_handlers(components: dict, walkthrough: gr.Walkthrough) -> None:
    """Attach the event handlers of the results stage.

    Two handlers are registered: selecting a row in the query table shows
    the matches for that segment, and the restart button switches the
    walkthrough back to step index 0.

    Args:
        components: Dictionary of UI components from build_results_stage
        walkthrough: The Walkthrough component for navigation
    """
    segment_table = components["query_segments"]
    select_inputs = [components["query_segments_state"], components["matches_dict_state"]]
    select_outputs = [components["selection_prompt"], components["source_matches"]]

    # Row selection in the query table drives the right-hand panel
    segment_table.select(fn=_on_query_select, inputs=select_inputs, outputs=select_outputs)

    # Restart button: Step 3 → Step 1
    restart_button = components["restart_btn"]
    restart_button.click(fn=lambda: gr.Walkthrough(selected=0), outputs=walkthrough)