"""Gradio app for fine-tuning LayoutLMv3 on a Label Studio JSON export.

The UI launches the training script as a subprocess, streams its output into
a log box, and offers the resulting model file for download.
"""

import gradio as gr
import subprocess
import os
import sys
from datetime import datetime
import shutil


# Script that train_model() launches as a subprocess (path is relative to the
# working directory of this app).
TRAINING_SCRIPT = "HF_LayoutLM_with_Passage.py"

# Location where this app looks for the trained model once training finishes.
# NOTE(review): presumably the training script writes to this exact path —
# confirm against HF_LayoutLM_with_Passage.py.
MODEL_OUTPUT_DIR = "checkpoints"
MODEL_FILE_NAME = "layoutlmv3_crf_passage.pth"
MODEL_FILE_PATH = os.path.join(MODEL_OUTPUT_DIR, MODEL_FILE_NAME)


# ----------------------------------------------------------------

def retrieve_model():
    """
    Check for the final model file and stage a copy of it for download.

    Useful when the training job finished server-side but the client
    connection timed out before the download button was shown.

    Returns:
        tuple: ``(log_text, download_path, button_update)``.
        ``download_path`` is the staged temp copy of the model, or ``None``
        when the model is missing or could not be copied. ``button_update``
        is a ``gr.Button`` whose visibility matches whether a download is
        available.
    """
    # Local import: tempfile is not among the module-level imports.
    import tempfile

    checked_at = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

    if not os.path.exists(MODEL_FILE_PATH):
        log_output = (
            f"--- Model Status Check: {checked_at} ---\n"
            f"❌ Model file not found at {MODEL_FILE_PATH}.\n"
            "Training may still be running or it failed. Check back later."
        )
        return log_output, None, gr.Button(visible=False)

    # Copy to a simple temp-dir location so Gradio serves a flat, uniquely
    # named file rather than the working checkpoint itself.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    download_path = os.path.join(
        tempfile.gettempdir(),
        f"layoutlmv3_trained_{timestamp}_recovered.pth",
    )

    try:
        # getsize is inside the guarded section because the file can vanish
        # between the existence check above and this call.
        file_size = os.path.getsize(MODEL_FILE_PATH) / (1024 * 1024)  # Size in MB
        shutil.copy2(MODEL_FILE_PATH, download_path)
    except OSError as e:
        log_output = (
            "--- Model Status Check FAILED ---\n"
            f"⚠️ Trained model found, but could not prepare for download: {e}\n"
            f"📍 Original Path: {MODEL_FILE_PATH}. Try again or check Space logs."
        )
        return log_output, None, gr.Button(visible=False)

    log_output = (
        f"--- Model Status Check: {checked_at} ---\n"
        "🎉 SUCCESS! A trained model was found and recovered.\n"
        f"📦 Model file: {MODEL_FILE_PATH}\n"
        f"📊 Model size: {file_size:.2f} MB\n"
        f"🔗 Download path prepared: {download_path}\n\n"
        "⬇️ Click the '📥 Download Model' button below to save your model."
    )
    return log_output, download_path, gr.Button(visible=True)




def clear_memory(dataset_file: gr.File):
    """
    Delete the model output directory, staged download copies and the
    uploaded dataset file.

    Args:
        dataset_file: Current value of the dataset upload component. Either
            an object exposing the uploaded file's path as ``.name`` or a
            value whose ``str()`` is that path; ``None`` when nothing is
            uploaded.

    Returns:
        tuple: ``(log_text, None, button_update, None)`` — the log of what
        was removed, a cleared model-path state, a hidden download button and
        a cleared download component.
    """
    # Local imports: neither module is among the module-level imports.
    import glob
    import tempfile

    def _now():
        return datetime.now().strftime('%Y-%m-%d %H:%M:%S')

    log_lines = [f"--- Memory Clear Started: {_now()} ---"]

    # 1. Clear Model Checkpoints Directory
    if os.path.exists(MODEL_OUTPUT_DIR):
        try:
            shutil.rmtree(MODEL_OUTPUT_DIR)
            log_lines.append(f"✅ Successfully deleted model directory: {MODEL_OUTPUT_DIR}")
        except OSError as e:
            log_lines.append(f"❌ ERROR deleting model directory {MODEL_OUTPUT_DIR}: {e}")
    else:
        log_lines.append(f"ℹ️ Model directory not found: {MODEL_OUTPUT_DIR} (Nothing to delete)")

    # 2. Clear staged download copies. retrieve_model() and train_model()
    #    copy the model into the temp dir as layoutlmv3_trained_<timestamp>*.pth;
    #    the state pointing at them is reset below, so they would otherwise be
    #    orphaned and keep consuming disk space.
    staged_pattern = os.path.join(tempfile.gettempdir(), "layoutlmv3_trained_*.pth")
    for staged_copy in glob.glob(staged_pattern):
        try:
            os.remove(staged_copy)
            log_lines.append(f"✅ Successfully deleted staged download copy: {staged_copy}")
        except OSError as e:
            log_lines.append(f"❌ ERROR deleting staged download copy {staged_copy}: {e}")

    # 3. Clear Uploaded Dataset File (Temporary file cleanup)
    if dataset_file is None:
        log_lines.append("ℹ️ No dataset file currently tracked for deletion.")
    else:
        input_path = dataset_file.name if hasattr(dataset_file, 'name') else str(dataset_file)
        if os.path.exists(input_path):
            try:
                os.remove(input_path)
                log_lines.append(f"✅ Successfully deleted uploaded dataset file: {input_path}")
            except OSError as e:
                log_lines.append(f"❌ ERROR deleting dataset file {input_path}: {e}")
        else:
            log_lines.append(f"ℹ️ Uploaded dataset file not found at {input_path}.")

    # 4. Final message and state reset
    log_lines.append(f"--- Memory Clear Complete: {_now()} ---")
    log_lines.append("✨ Files and checkpoints have been removed. You can now start a fresh training run.")

    # Reset log_output, model_path_state, download_btn visibility, and model_download component
    return "\n".join(log_lines), None, gr.Button(visible=False), None



def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float, max_len: int, progress=gr.Progress()):
    """
    Run the training script in a subprocess and stream its output.

    Args:
        dataset_file: Uploaded dataset; either an object exposing the temp
            file path as ``.name`` or a value whose ``str()`` is that path.
        batch_size: Forwarded to the script as ``--batch_size``.
        epochs: Forwarded to the script as ``--epochs``.
        lr: Forwarded to the script as ``--lr``.
        max_len: Forwarded to the script as ``--max_len``.
        progress: Gradio progress tracker; updated once before launch.

    Yields:
        tuple: ``(log_text, download_path, button_update)``. ``log_text`` is
        the full log accumulated so far. ``download_path`` is ``None`` until
        training succeeded and the model file was found. ``button_update`` is
        a ``gr.Button`` that becomes visible only in that final success case.
    """

    # 1. Setup: Create output directory if it doesn't exist
    os.makedirs(MODEL_OUTPUT_DIR, exist_ok=True)

    # 2. File Handling: Use the temporary path of the uploaded file
    if dataset_file is None:
        yield "❌ ERROR: Please upload a file.", None, gr.Button(visible=False)
        return

    # The upload may arrive as a file-like object (path in .name) or as a
    # plain path; both are supported.
    input_path = dataset_file.name if hasattr(dataset_file, 'name') else str(dataset_file)

    # Verify the file actually exists before proceeding
    if not os.path.exists(input_path):
        error_msg = f"❌ ERROR: Uploaded file not found at {input_path}. Please try uploading again."
        yield error_msg, None, gr.Button(visible=False)
        return

    if not input_path.lower().endswith(".json"):
        yield "❌ ERROR: Please upload a valid Label Studio JSON file (.json).", None, gr.Button(visible=False)
        return

    # Popen executes the Python interpreter, not the script, so a missing
    # script would never raise FileNotFoundError here; check it explicitly.
    if not os.path.isfile(TRAINING_SCRIPT):
        error_msg = f"❌ ERROR: The training script '{TRAINING_SCRIPT}' was not found. Ensure it is in the root directory of your Space."
        print(error_msg)
        yield error_msg, None, gr.Button(visible=False)
        return

    progress(0.1, desc="Starting LayoutLMv3 Training...")

    log_output = f"--- Training Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ---\n"

    # 3. Construct the subprocess command.
    #    "-u" disables the child's stdout buffering; without it Python
    #    block-buffers output written to a pipe and logs arrive in bursts.
    command = [
        sys.executable,
        "-u",
        TRAINING_SCRIPT,
        "--mode", "train",
        "--input", input_path,
        "--batch_size", str(batch_size),
        "--epochs", str(epochs),
        "--lr", str(lr),
        "--max_len", str(max_len)
    ]

    log_output += f"Executing command: {' '.join(command)}\n\n"
    yield log_output, None, gr.Button(visible=False)  # Initial yield

    try:
        # 4. Run the training script and capture output (stderr merged in)
        process = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1
        )

        try:
            # Stream logs in real-time
            for line in iter(process.stdout.readline, ""):
                log_output += line
                # Print to console as well for debugging
                print(line, end='')
                yield log_output, None, gr.Button(visible=False)
        finally:
            # Release the pipe even if streaming is interrupted.
            process.stdout.close()

        return_code = process.wait()

        # 5. Check for successful completion
        if return_code != 0:
            log_output += f"\n\n{'=' * 60}\n"
            log_output += f"❌ TRAINING FAILED with return code {return_code}\n"
            log_output += f"{'=' * 60}\n"
            log_output += "\nPlease check the logs above for error details.\n"
            yield log_output, None, gr.Button(visible=False)
            return

        log_output += "\n" + "=" * 60 + "\n"
        log_output += "✅ TRAINING COMPLETE! Model saved successfully.\n"
        log_output += "=" * 60 + "\n"
        print("\n✅ TRAINING COMPLETE! Model saved.")

        # 6. Verify model file exists
        if not os.path.exists(MODEL_FILE_PATH):
            log_output += f"\n⚠️ WARNING: Training completed, but model file not found at expected path ({MODEL_FILE_PATH})."
            log_output += "\n🔍 Checking directory contents..."

            # List files in checkpoints directory for debugging
            if os.path.exists(MODEL_OUTPUT_DIR):
                files = os.listdir(MODEL_OUTPUT_DIR)
                log_output += f"\n📁 Files in {MODEL_OUTPUT_DIR}: {files}"
            else:
                log_output += f"\n❌ Directory {MODEL_OUTPUT_DIR} does not exist!"

            yield log_output, None, gr.Button(visible=False)
            return

        file_size = os.path.getsize(MODEL_FILE_PATH) / (1024 * 1024)  # Size in MB
        log_output += f"\n📦 Model file found: {MODEL_FILE_PATH}"
        log_output += f"\n📊 Model size: {file_size:.2f} MB"

        print(f"\n✅ Model exists at: {MODEL_FILE_PATH} ({file_size:.2f} MB)")

        # Copy to a simple temp-dir location for serving; fall back to the
        # original checkpoint path if the copy cannot be made.
        import tempfile
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        temp_model_path = os.path.join(tempfile.gettempdir(), f"layoutlmv3_trained_{timestamp}.pth")

        try:
            shutil.copy2(MODEL_FILE_PATH, temp_model_path)
            log_output += "\n📋 Model copied to temporary download location"
            log_output += f"\n🔗 Download path: {temp_model_path}"
            print(f"✅ Model copied to temp location: {temp_model_path}")

            # Verify the copy exists
            if os.path.exists(temp_model_path):
                log_output += "\n✅ Download file verified and ready!"
                download_path = temp_model_path
            else:
                log_output += "\n⚠️ Warning: Temp copy verification failed, using original path"
                download_path = MODEL_FILE_PATH

        except OSError as e:
            log_output += f"\n⚠️ Could not create temp copy: {e}"
            log_output += f"\n📍 Using original path: {MODEL_FILE_PATH}"
            print(f"⚠️ Copy failed: {e}, using original path")
            download_path = MODEL_FILE_PATH

        # Final success message
        log_output += f"\n\n{'=' * 60}"
        log_output += "\n🎉 SUCCESS! Your model is ready for download."
        log_output += f"\n{'=' * 60}"
        log_output += "\n\n⬇️ Click the '📥 Download Model' button below to save your model."
        log_output += "\n⚠️ CRITICAL: Download NOW! File will be deleted when:"
        log_output += "\n   - This tab is closed"
        log_output += "\n   - Space restarts or goes idle"
        log_output += "\n   - System clears temp files"
        log_output += "\n\n📥 The file will download as a .pth file to your computer's Downloads folder."
        log_output += f"\n\n{'=' * 60}\n"

        # Return final logs and make download button visible
        yield log_output, download_path, gr.Button(visible=True)
        return

    except Exception as e:
        # Top-level boundary: surface any unexpected failure in the UI log
        # and print the traceback to the console.
        error_msg = f"❌ An unexpected error occurred: {e}"
        print(error_msg)
        import traceback
        print(traceback.format_exc())
        yield log_output + "\n" + error_msg, None, gr.Button(visible=False)
        return


# --- Gradio Interface Setup (using Blocks for a nicer layout) ---
with gr.Blocks(title="LayoutLMv3 Fine-Tuning App by Aastik", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 LayoutLMv3 Fine-Tuning on Hugging Face Spaces")
    gr.Markdown(
        """
        Upload your Label Studio JSON file, set your hyperparameters, and click **Train Model** to fine-tune the LayoutLMv3 model.

        **⚠️ IMPORTANT - Free Tier Users:**
        - **Download your model IMMEDIATELY** after training completes!
        - The model file is **temporary** and will be deleted when the Space restarts.
        - A download button will appear below once training is complete.
        - **Real-time logs** will stream during training so you can monitor progress.

        **⏱️ Timeout Note:** Training may timeout on free tier. Consider reducing epochs or batch size for faster training.
        """
    )

    with gr.Row():
        # Left column: dataset upload, hyperparameters and action buttons.
        with gr.Column(scale=1):
            gr.Markdown("### 📁 Dataset Upload")
            file_input = gr.File(
                label="Upload Label Studio JSON Dataset",
                file_types=[".json"]
            )

            gr.Markdown("---")
            gr.Markdown("### ⚙️ Training Parameters")

            batch_size_input = gr.Slider(
                minimum=1, maximum=16, step=1, value=4,
                label="Batch Size",
                info="Smaller = less memory, slower training"
            )
            epochs_input = gr.Slider(
                minimum=1, maximum=10, step=1, value=3,
                label="Epochs",
                info="Fewer epochs = faster training (recommended: 3-5)"
            )
            lr_input = gr.Number(
                value=5e-5, label="Learning Rate",
                info="Default: 5e-5"
            )
            max_len_input = gr.Slider(
                minimum=128, maximum=512, step=128, value=512,
                label="Max Sequence Length",
                info="Shorter = faster training, less memory"
            )

            train_button = gr.Button("🔥 Start Training", variant="primary", size="lg")
            check_button = gr.Button("🔍 Check Model Status/Download", variant="secondary", size="lg")
            clear_button = gr.Button("🧹 Clear Model/Dataset Files", variant="stop", size="lg")

        # Right column: streamed logs and the download controls.
        with gr.Column(scale=2):
            gr.Markdown("### 📊 Training Progress (Real-Time Logs)")

            log_output = gr.Textbox(
                label="Training Logs - Updates in Real-Time",
                lines=25,
                max_lines=30,
                autoscroll=True,
                show_copy_button=True,
                placeholder="Click 'Start Training' to begin...\n\nLogs will stream here in real-time as training progresses."
            )

            gr.Markdown("### ⬇️ Download Trained Model")

            # Hidden state to store the file path
            model_path_state = gr.State(value=None)

            # Download button (initially hidden)
            download_btn = gr.Button(
                "📥 Download Model (.pth file)",
                variant="primary",
                size="lg",
                visible=False
            )

            # File output for download
            model_download = gr.File(
                label="Your trained model will appear here after clicking Download",
                interactive=False,
                visible=True
            )

            gr.Markdown(
                """
                **📥 Download Instructions:**
                1. Wait for training to complete - watch the real-time logs above
                2. Look for **"✅ TRAINING COMPLETE!"** message
                3. Click the **"📥 Download Model"** button that appears above
                4. Save the `.pth` file to your local machine
                5. **Do this immediately** - file is temporary and will be deleted on Space restart!

                **🔧 Troubleshooting:**
                - If download button doesn't appear, check the logs for errors
                - Try reducing epochs or batch size if timeout occurs
                - Ensure your JSON file is properly formatted
                - Logs update in real-time - you can monitor training progress
                """
            )

    # --- Event wiring (kept in the original registration order) ---

    # Re-check the disk for a finished model and stage it for download.
    check_button.click(
        fn=retrieve_model,
        inputs=[],
        outputs=[log_output, model_path_state, download_btn]
    )

    # Remove checkpoints and the uploaded dataset, then reset the download UI.
    clear_button.click(
        fn=clear_memory,
        inputs=[file_input],  # Pass the uploaded file object to delete the temp file
        outputs=[log_output, model_path_state, download_btn, model_download]
    )

    # Training action: train_model is a generator, so logs stream via yield.
    train_button.click(
        fn=train_model,
        inputs=[file_input, batch_size_input, epochs_input, lr_input, max_len_input],
        outputs=[log_output, model_path_state, download_btn],
        api_name="train"
    )

    # Download action: hand the stored path to the File component.
    download_btn.click(
        fn=lambda path: path,
        inputs=[model_path_state],
        outputs=[model_download]
    )

    gr.Markdown(
        """
        ---
        ### 📖 About
        This Space fine-tunes LayoutLMv3 with CRF for document understanding tasks including:
        - Questions, Options, Answers
        - Section Headings
        - Passages

        **Model Details:** LayoutLMv3-base + CRF layer for sequence labeling

        **Features:**
        - ✅ Real-time log streaming during training
        - ✅ Progress monitoring with epoch/batch updates
        - ✅ Immediate model download after completion
        - ✅ Automatic file preparation for download
        """
    )

# Launch the Gradio app only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == "__main__":
    demo.launch()