julian-schelb committed on
Commit 3acfa02 · verified · 1 Parent(s): fe7d068

Create config_stage.py

Files changed (1)
  1. config_stage.py +254 -0
config_stage.py ADDED
@@ -0,0 +1,254 @@
+ """Configuration stage for the Loci Similes GUI."""
+
+ from __future__ import annotations
+
+ import sys
+
+ try:
+     import gradio as gr
+ except ImportError as exc:
+     missing = getattr(exc, "name", None)
+     base_msg = (
+         "Optional GUI dependencies are missing. Install them via "
+         "'pip install locisimiles[gui]' (Python 3.13+ also requires the "
+         "audioop-lts backport) to use the Gradio interface."
+     )
+     if missing and missing != "gradio":
+         raise ImportError(f"{base_msg} (missing package: {missing})") from exc
+     raise ImportError(base_msg) from exc
+
+ from .utils import validate_csv
+ from locisimiles.pipeline import ClassificationPipelineWithCandidategeneration
+ from locisimiles.document import Document
+
+
+ def _show_processing_status() -> dict:
+     """Show the processing spinner."""
+     spinner_html = """
+     <div style="display: flex; align-items: center; justify-content: center; padding: 20px; background-color: #e3f2fd; border-radius: 8px; margin: 20px 0;">
+         <div style="display: flex; flex-direction: column; align-items: center; gap: 15px;">
+             <div style="border: 4px solid #f3f3f3; border-top: 4px solid #2196F3; border-radius: 50%; width: 40px; height: 40px; animation: spin 1s linear infinite;"></div>
+             <div style="font-size: 16px; color: #1976D2; font-weight: 500;">Processing documents... This may take several minutes on first run.</div>
+             <div style="font-size: 13px; color: #666;">Downloading models, generating embeddings, and classifying candidates...</div>
+         </div>
+     </div>
+     <style>
+         @keyframes spin {
+             0% { transform: rotate(0deg); }
+             100% { transform: rotate(360deg); }
+         }
+     </style>
+     """
+     return gr.update(value=spinner_html, visible=True)
+
+
+ def _process_documents(
+     query_file: str | None,
+     source_file: str | None,
+     classification_model: str,
+     embedding_model: str,
+     top_k: int,
+     threshold: float,
+ ) -> tuple:
+     """Process the documents with the Loci Similes pipeline and navigate to the results step.
+
+     Args:
+         query_file: Path to the query CSV file.
+         source_file: Path to the source CSV file.
+         classification_model: Name of the classification model.
+         embedding_model: Name of the embedding model.
+         top_k: Number of top candidates to retrieve per query segment.
+         threshold: Classification threshold (not used by the pipeline itself; applied later when filtering results for display).
+
+     Returns:
+         Tuple of (processing_status_update, walkthrough_update, results_state, query_doc_state).
+     """
+     if not query_file or not source_file:
+         gr.Warning("Both query and source documents must be uploaded before processing.")
+         return gr.update(visible=False), gr.Walkthrough(selected=1), None, None
+
+     # Validate both files
+     query_valid, query_msg = validate_csv(query_file)
+     source_valid, source_msg = validate_csv(source_file)
+
+     if not query_valid or not source_valid:
+         gr.Warning("Please ensure both documents are valid before processing.")
+         return gr.update(visible=False), gr.Walkthrough(selected=1), None, None
+
+     try:
+         # Detect device (prefer GPU if available)
+         import torch
+         if torch.cuda.is_available():
+             device = "cuda"
+         elif torch.backends.mps.is_available():
+             device = "mps"
+         else:
+             device = "cpu"
+
+         # Initialize pipeline
+         # Note: the first run downloads the models (~500 MB each); subsequent runs use the cached copies
+         pipeline = ClassificationPipelineWithCandidategeneration(
+             classification_name=classification_model,
+             embedding_model_name=embedding_model,
+             device=device,
+         )
+
+         # Load documents
+         query_doc = Document(query_file)
+         source_doc = Document(source_file)
+
+         # Run pipeline
+         results = pipeline.run(
+             query=query_doc,
+             source=source_doc,
+             top_k=top_k,
+         )
+
+         # Summarize results
+         num_queries = len(results)
+         total_matches = sum(len(matches) for matches in results.values())
+
+         print(f"Processing complete! Found matches for {num_queries} query segments ({total_matches} total matches).")
+
+         # Return results and navigate to the results step (Step 3, id=2)
+         return (
+             gr.update(visible=False),    # Hide processing status
+             gr.Walkthrough(selected=2),  # Navigate to the Results step
+             results,                     # Store results in state
+             query_doc,                   # Store query doc in state
+         )
+
+     except Exception as e:
+         print(f"Processing error: {e}", file=sys.stderr)
+         import traceback
+         traceback.print_exc()
+         # gr.Error must be raised to be displayed; use a warning toast instead so
+         # the return values below still reach the UI and the spinner is hidden.
+         gr.Warning(f"Processing failed: {e}")
+         return (
+             gr.update(visible=False),    # Hide processing status
+             gr.Walkthrough(selected=1),  # Stay on the Configuration step
+             None,                        # No results
+             None,                        # No query doc
+         )
+
+
+ def build_config_stage() -> tuple[gr.Step, dict]:
+     """Build the configuration stage UI.
+
+     Returns:
+         Tuple of (Step component, dict of components for external access)
+     """
+     components = {}
+
+     with gr.Step("Pipeline Configuration", id=1) as step:
+         gr.Markdown("### ⚙️ Step 2: Pipeline Configuration")
+         gr.Markdown(
+             "Configure the two-stage pipeline. Stage 1 (Embedding): Quickly ranks all source segments by similarity to each query segment. "
+             "Stage 2 (Classification): Examines the top-K candidates more carefully to identify true intertextual references. "
+             "Higher K values catch more potential citations but increase computation time. The threshold filters results by classification confidence."
+         )
+
+         with gr.Row():
+             # Left column: Model Selection
+             with gr.Column():
+                 gr.Markdown("**🤖 Model Selection**")
+                 components["classification_model"] = gr.Dropdown(
+                     label="Classification Model",
+                     choices=["julian-schelb/PhilBerta-class-latin-intertext-v1"],
+                     value="julian-schelb/PhilBerta-class-latin-intertext-v1",
+                     interactive=True,
+                     info="Model used to classify candidate pairs as intertextual or not",
+                 )
+                 components["embedding_model"] = gr.Dropdown(
+                     label="Embedding Model",
+                     choices=["julian-schelb/SPhilBerta-emb-lat-intertext-v1"],
+                     value="julian-schelb/SPhilBerta-emb-lat-intertext-v1",
+                     interactive=True,
+                     info="Model used to generate embeddings for candidate retrieval",
+                 )
+
+             # Right column: Retrieval Parameters
+             with gr.Column():
+                 gr.Markdown("**🛠️ Retrieval Parameters**")
+                 components["top_k"] = gr.Slider(
+                     minimum=1,
+                     maximum=50,
+                     value=10,
+                     step=1,
+                     label="Top K Candidates",
+                     info="How many candidates to examine per query. Higher values find more references but take longer to process.",
+                 )
+                 components["threshold"] = gr.Slider(
+                     minimum=0.0,
+                     maximum=1.0,
+                     value=0.5,
+                     step=0.05,
+                     label="Classification Threshold",
+                     info="Minimum confidence to count as a 'find'. Lower = more results but more false positives; Higher = fewer but more certain results.",
+                 )
+
+         components["processing_status"] = gr.HTML(visible=False)
+
+         with gr.Row():
+             components["back_btn"] = gr.Button("← Back to Upload", size="lg")
+             components["process_btn"] = gr.Button("Process Documents →", variant="primary", size="lg")
+
+     return step, components
+
+
+ def setup_config_handlers(
+     components: dict,
+     file_states: dict,
+     pipeline_states: dict,
+     walkthrough: gr.Walkthrough,
+     results_components: dict,
+ ) -> None:
+     """Set up event handlers for the configuration stage.
+
+     Args:
+         components: Dictionary of UI components from build_config_stage
+         file_states: Dictionary with query_file_state and source_file_state
+         pipeline_states: Dictionary with results_state and query_doc_state
+         walkthrough: The Walkthrough component for navigation
+         results_components: Components from results stage for updating
+     """
+     from .results_stage import update_results_display
+
+     # Back button: Step 2 → Step 1
+     components["back_btn"].click(
+         fn=lambda: gr.Walkthrough(selected=0),
+         outputs=walkthrough,
+     )
+
+     # Process button: Step 2 → Step 3
+     components["process_btn"].click(
+         fn=_show_processing_status,
+         outputs=components["processing_status"],
+     ).then(
+         fn=_process_documents,
+         inputs=[
+             file_states["query_file_state"],
+             file_states["source_file_state"],
+             components["classification_model"],
+             components["embedding_model"],
+             components["top_k"],
+             components["threshold"],
+         ],
+         outputs=[
+             components["processing_status"],
+             walkthrough,
+             pipeline_states["results_state"],
+             pipeline_states["query_doc_state"],
+         ],
+     ).then(
+         fn=update_results_display,
+         inputs=[
+             pipeline_states["results_state"],
+             pipeline_states["query_doc_state"],
+             components["threshold"],
+         ],
+         outputs=[
+             results_components["query_segments"],
+             results_components["query_segments_state"],
+             results_components["matches_dict_state"],
+         ],
+     )
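
For orientation, here is a minimal wiring sketch (not part of this commit) showing how build_config_stage and setup_config_handlers could be combined with the upload and results stages in the parent Gradio app. The package path locisimiles.gui, the build_upload_stage and build_results_stage helpers, and the top-level entry point are assumptions inferred from the function signatures above, not confirmed API; only the state-dictionary keys are taken directly from this file.

import gradio as gr

# Assumed package layout; adjust imports to the actual GUI package.
from locisimiles.gui.config_stage import build_config_stage, setup_config_handlers
from locisimiles.gui.upload_stage import build_upload_stage      # hypothetical helper
from locisimiles.gui.results_stage import build_results_stage    # hypothetical helper

with gr.Blocks() as demo:
    # Shared state passed between stages; the upload stage is expected to
    # write the uploaded CSV paths into these two states.
    file_states = {
        "query_file_state": gr.State(None),
        "source_file_state": gr.State(None),
    }
    pipeline_states = {
        "results_state": gr.State(None),
        "query_doc_state": gr.State(None),
    }

    with gr.Walkthrough(selected=0) as walkthrough:
        upload_step, upload_components = build_upload_stage()      # Step 1, id=0 (assumed)
        config_step, config_components = build_config_stage()      # Step 2, id=1
        results_step, results_components = build_results_stage()   # Step 3, id=2 (assumed)

    setup_config_handlers(
        components=config_components,
        file_states=file_states,
        pipeline_states=pipeline_states,
        walkthrough=walkthrough,
        results_components=results_components,
    )

if __name__ == "__main__":
    demo.launch()

The sketch only illustrates how the components, file_states, pipeline_states, and results_components dictionaries produced by the three stages are threaded into setup_config_handlers; the real entry point may differ.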