"""Configuration stage for the Loci Similes GUI."""

from __future__ import annotations

import sys

try:
    import gradio as gr
except ImportError as exc:
    missing = getattr(exc, "name", None)
    base_msg = (
        "Optional GUI dependencies are missing. Install them via "
        "'pip install locisimiles[gui]' (Python 3.13+ also requires the "
        "audioop-lts backport) to use the Gradio interface."
    )
    if missing and missing != "gradio":
        raise ImportError(f"{base_msg} (missing package: {missing})") from exc
    raise ImportError(base_msg) from exc

from utils import validate_csv
from locisimiles.pipeline import ClassificationPipelineWithCandidategeneration
from locisimiles.document import Document


def _show_processing_status() -> dict:
    """Show the processing spinner."""
    spinner_html = """
    <div style="display: flex; align-items: center; justify-content: center; padding: 20px; background-color: #e3f2fd; border-radius: 8px; margin: 20px 0;">
        <div style="display: flex; flex-direction: column; align-items: center; gap: 15px;">
            <div style="border: 4px solid #f3f3f3; border-top: 4px solid #2196F3; border-radius: 50%; width: 40px; height: 40px; animation: spin 1s linear infinite;"></div>
            <div style="font-size: 16px; color: #1976D2; font-weight: 500;">Processing documents... This may take several minutes on first run.</div>
            <div style="font-size: 13px; color: #666;">Downloading models, generating embeddings, and classifying candidates...</div>
        </div>
    </div>
    <style>
        @keyframes spin {
            0% { transform: rotate(0deg); }
            100% { transform: rotate(360deg); }
        }
    </style>
    """
    return gr.update(value=spinner_html, visible=True)


def _process_documents(
    query_file: str | None,
    source_file: str | None,
    classification_model: str,
    embedding_model: str,
    top_k: int,
    threshold: float,
) -> tuple:
    """Process the documents using the Loci Similes pipeline and navigate to results step.
    
    Args:
        query_file: Path to query CSV file
        source_file: Path to source CSV file
        classification_model: Name of the classification model
        embedding_model: Name of the embedding model
        top_k: Number of top candidates to retrieve
        threshold: Classification confidence threshold (not used by the pipeline itself; applied later when filtering the displayed results)
    
    Returns:
        Tuple of (processing_status_update, walkthrough_update, results_state, query_doc_state)
    """
    if not query_file or not source_file:
        gr.Warning("Both query and source documents must be uploaded before processing.")
        return gr.update(visible=False), gr.Walkthrough(selected=1), None, None
    
    # Validate both files
    query_valid, query_msg = validate_csv(query_file)
    source_valid, source_msg = validate_csv(source_file)
    
    if not query_valid or not source_valid:
        # Surface the validator's message(s) so the user knows what to fix.
        problems = "; ".join(
            msg for ok, msg in ((query_valid, query_msg), (source_valid, source_msg)) if not ok and msg
        )
        gr.Warning(f"Please ensure both documents are valid before processing. {problems}".strip())
        return gr.update(visible=False), gr.Walkthrough(selected=1), None, None
    
    try:
        # Detect device (prefer GPU if available)
        import torch
        if torch.cuda.is_available():
            device = "cuda"
        elif torch.backends.mps.is_available():
            device = "mps"
        else:
            device = "cpu"
        
        # Initialize pipeline
        # Note: First run will download models (~500MB each), subsequent runs use cached models
        pipeline = ClassificationPipelineWithCandidategeneration(
            classification_name=classification_model,
            embedding_model_name=embedding_model,
            device=device,
        )
        
        # Load documents
        query_doc = Document(query_file)
        source_doc = Document(source_file)
        
        # Run pipeline
        results = pipeline.run(
            query=query_doc,
            source=source_doc,
            top_k=top_k,
        )
        
        # Store results
        num_queries = len(results)
        total_matches = sum(len(matches) for matches in results.values())
        
        print(f"Processing complete! Found matches for {num_queries} query segments ({total_matches} total matches).")
        
        # Return results and navigate to results step (Step 3, id=2)
        return (
            gr.update(visible=False),   # Hide processing status
            gr.Walkthrough(selected=2), # Navigate to Results step
            results,                    # Store results in state
            query_doc,                  # Store query doc in state
        )
        
    except Exception as e:
        print(f"Processing error: {e}", file=sys.stderr)
        import traceback
        traceback.print_exc()
        # gr.Error must be raised to surface in the UI, but raising here would
        # abort the handler before the outputs below reset the interface, so
        # show a warning toast instead.
        gr.Warning(f"Processing failed: {e}")
        return (
            gr.update(visible=False),   # Hide processing status
            gr.Walkthrough(selected=1), # Stay on Configuration step
            None,                       # No results
            None,                       # No query doc
        )
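
# For debugging outside the GUI, the same processing can be driven headlessly.
# A minimal sketch using only calls already invoked above; the model names are
# the defaults from build_config_stage(), and the CSV paths are placeholders:
#
#     pipeline = ClassificationPipelineWithCandidategeneration(
#         classification_name="julian-schelb/PhilBerta-class-latin-intertext-v1",
#         embedding_model_name="julian-schelb/SPhilBerta-emb-lat-intertext-v1",
#         device="cpu",
#     )
#     results = pipeline.run(
#         query=Document("query.csv"),
#         source=Document("source.csv"),
#         top_k=10,
#     )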


def build_config_stage() -> tuple[gr.Step, dict]:
    """Build the configuration stage UI.
    
    Returns:
        Tuple of (Step component, dict of components for external access)
    """
    components = {}
    
    with gr.Step("Pipeline Configuration", id=1) as step:
        gr.Markdown("### βš™οΈ Step 2: Pipeline Configuration")
        gr.Markdown(
            "Configure the two-stage pipeline. Stage 1 (Embedding): Quickly ranks all source segments by similarity to each query segment. "
            "Stage 2 (Classification): Examines the top-K candidates more carefully to identify true intertextual references. "
            "Higher K values catch more potential citations but increase computation time. The threshold filters results by classification confidence."
        )
        
        with gr.Row():
            # Left column: Model Selection
            with gr.Column():
                gr.Markdown("**πŸ€– Model Selection**")
                components["classification_model"] = gr.Dropdown(
                    label="Classification Model",
                    choices=["julian-schelb/PhilBerta-class-latin-intertext-v1"],
                    value="julian-schelb/PhilBerta-class-latin-intertext-v1",
                    interactive=True,
                    info="Model used to classify candidate pairs as intertextual or not",
                )
                components["embedding_model"] = gr.Dropdown(
                    label="Embedding Model",
                    choices=["julian-schelb/SPhilBerta-emb-lat-intertext-v1"],
                    value="julian-schelb/SPhilBerta-emb-lat-intertext-v1",
                    interactive=True,
                    info="Model used to generate embeddings for candidate retrieval",
                )
            
            # Right column: Retrieval Parameters
            with gr.Column():
                gr.Markdown("**πŸ› οΈ Retrieval Parameters**")
                components["top_k"] = gr.Slider(
                    minimum=1,
                    maximum=50,
                    value=10,
                    step=1,
                    label="Top K Candidates",
                    info="How many candidates to examine per query. Higher values find more references but take longer to process.",
                )
                components["threshold"] = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.5,
                    step=0.05,
                    label="Classification Threshold",
                    info="Minimum confidence to count as a 'find'. Lower = more results but more false positives; Higher = fewer but more certain results.",
                )
        
        components["processing_status"] = gr.HTML(visible=False)
        
        with gr.Row():
            components["back_btn"] = gr.Button("← Back to Upload", size="lg")
            components["process_btn"] = gr.Button("Process Documents β†’", variant="primary", size="lg")
    
    return step, components
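
# The two-stage flow described in the Markdown above, schematically. This is a
# rough sketch for orientation only; the actual logic lives inside
# ClassificationPipelineWithCandidategeneration, and the helper names here
# (rank_by_similarity, classify) are hypothetical:
#
#     for q in query_doc:                          # Stage 1: embedding retrieval
#         candidates = rank_by_similarity(q, source_doc)[:top_k]
#         for s in candidates:                     # Stage 2: pair classification
#             score = classify(q, s)               # confidence in [0, 1]
#             # `threshold` is applied later, when results are displayed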


def setup_config_handlers(
    components: dict,
    file_states: dict,
    pipeline_states: dict,
    walkthrough: gr.Walkthrough,
    results_components: dict,
) -> None:
    """Set up event handlers for the configuration stage.
    
    Args:
        components: Dictionary of UI components from build_config_stage
        file_states: Dictionary with query_file_state and source_file_state
        pipeline_states: Dictionary with results_state and query_doc_state
        walkthrough: The Walkthrough component for navigation
        results_components: Components from results stage for updating
    """
    from results_stage import update_results_display
    
    # Back button: Step 2 → Step 1
    components["back_btn"].click(
        fn=lambda: gr.Walkthrough(selected=0),
        outputs=walkthrough,
    )
    
    # Process button: Step 2 → Step 3
    components["process_btn"].click(
        fn=_show_processing_status,
        outputs=components["processing_status"],
    ).then(
        fn=_process_documents,
        inputs=[
            file_states["query_file_state"],
            file_states["source_file_state"],
            components["classification_model"],
            components["embedding_model"],
            components["top_k"],
            components["threshold"],
        ],
        outputs=[
            components["processing_status"],
            walkthrough,
            pipeline_states["results_state"],
            pipeline_states["query_doc_state"],
        ],
    ).then(
        fn=update_results_display,
        inputs=[
            pipeline_states["results_state"],
            pipeline_states["query_doc_state"],
            components["threshold"],
        ],
        outputs=[
            results_components["query_segments"],
            results_components["query_segments_state"],
            results_components["matches_dict_state"],
        ],
    )
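

# Standalone preview of this stage's layout for quick visual checks (event
# handlers are wired by the main app, which supplies the state objects and
# results components). A sketch that assumes gr.Walkthrough, like gr.Tabs,
# can be used as a layout context manager:
if __name__ == "__main__":
    with gr.Blocks() as demo:
        with gr.Walkthrough(selected=1):
            build_config_stage()
    demo.launch()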