Spaces:
Running
Running
Commit
·
0c2088f
1
Parent(s):
a21bd5b
correction
Browse files
app.py
CHANGED
|
@@ -470,6 +470,248 @@
|
|
| 470 |
# if __name__ == "__main__":
|
| 471 |
# demo.launch()
|
| 472 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 473 |
|
| 474 |
import gradio as gr
|
| 475 |
import subprocess
|
|
@@ -500,13 +742,13 @@ def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float,
|
|
| 500 |
|
| 501 |
# 2. File Handling: Use the temporary path of the uploaded file
|
| 502 |
if dataset_file is None:
|
| 503 |
-
return "β ERROR: Please upload a file.", None
|
| 504 |
|
| 505 |
# Using .name (Corrected in previous steps)
|
| 506 |
input_path = dataset_file.name
|
| 507 |
|
| 508 |
if not input_path.lower().endswith(".json"):
|
| 509 |
-
return "β ERROR: Please upload a valid Label Studio JSON file (.json).", None
|
| 510 |
|
| 511 |
progress(0.1, desc="Starting LayoutLMv3 Training...")
|
| 512 |
|
|
@@ -555,25 +797,26 @@ def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float,
|
|
| 555 |
file_size = os.path.getsize(MODEL_FILE_PATH) / (1024 * 1024) # Size in MB
|
| 556 |
log_output += f"\nπ¦ Model file: {MODEL_FILE_PATH}"
|
| 557 |
log_output += f"\nπ Model size: {file_size:.2f} MB"
|
| 558 |
-
log_output += f"\nβ¬οΈ Click the download button below to save your model!"
|
| 559 |
|
| 560 |
print(f"\nβ
Model exists at: {MODEL_FILE_PATH} ({file_size:.2f} MB)")
|
| 561 |
|
| 562 |
-
# Create a copy in the root directory for
|
| 563 |
-
|
|
|
|
|
|
|
| 564 |
try:
|
| 565 |
-
shutil.copy2(MODEL_FILE_PATH,
|
| 566 |
-
log_output += f"\nπ
|
| 567 |
-
print(f"β
Created
|
| 568 |
except Exception as e:
|
| 569 |
-
log_output += f"\nβ οΈ Could not create
|
| 570 |
-
|
| 571 |
|
| 572 |
-
# Return the
|
| 573 |
-
|
| 574 |
-
log_output += f"\n
|
| 575 |
|
| 576 |
-
return log_output,
|
| 577 |
else:
|
| 578 |
log_output += f"\nβ οΈ WARNING: Training completed, but model file not found at expected path ({MODEL_FILE_PATH})."
|
| 579 |
log_output += f"\nπ Checking directory contents..."
|
|
@@ -585,21 +828,42 @@ def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float,
|
|
| 585 |
else:
|
| 586 |
log_output += f"\nβ Directory {MODEL_OUTPUT_DIR} does not exist!"
|
| 587 |
|
| 588 |
-
return log_output, None
|
| 589 |
else:
|
| 590 |
log_output += f"\n\nβ TRAINING FAILED with return code {return_code}. Check logs above."
|
| 591 |
-
return log_output, None
|
| 592 |
|
| 593 |
except FileNotFoundError:
|
| 594 |
error_msg = f"β ERROR: The training script '{TRAINING_SCRIPT}' was not found. Ensure it is in the root directory of your Space."
|
| 595 |
print(error_msg)
|
| 596 |
-
return error_msg, None
|
| 597 |
except Exception as e:
|
| 598 |
error_msg = f"β An unexpected error occurred: {e}"
|
| 599 |
print(error_msg)
|
| 600 |
import traceback
|
| 601 |
print(traceback.format_exc())
|
| 602 |
-
return error_msg, None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 603 |
|
| 604 |
|
| 605 |
# --- Gradio Interface Setup (using Blocks for a nicer layout) ---
|
|
@@ -612,8 +876,7 @@ with gr.Blocks(title="LayoutLMv3 Fine-Tuning App", theme=gr.themes.Soft()) as de
|
|
| 612 |
**β οΈ IMPORTANT - Free Tier Users:**
|
| 613 |
- **Download your model IMMEDIATELY** after training completes!
|
| 614 |
- The model file is **temporary** and will be deleted when the Space restarts.
|
| 615 |
-
-
|
| 616 |
-
- Model is saved as: **`layoutlmv3_crf_passage.pth`**
|
| 617 |
|
| 618 |
**β±οΈ Timeout Note:** Training may timeout on free tier. Consider reducing epochs or batch size for faster training.
|
| 619 |
"""
|
|
@@ -666,8 +929,20 @@ with gr.Blocks(title="LayoutLMv3 Fine-Tuning App", theme=gr.themes.Soft()) as de
|
|
| 666 |
|
| 667 |
gr.Markdown("### β¬οΈ Download Trained Model")
|
| 668 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 669 |
model_download = gr.File(
|
| 670 |
-
label="
|
| 671 |
interactive=False,
|
| 672 |
visible=True
|
| 673 |
)
|
|
@@ -676,7 +951,7 @@ with gr.Blocks(title="LayoutLMv3 Fine-Tuning App", theme=gr.themes.Soft()) as de
|
|
| 676 |
"""
|
| 677 |
**π₯ Download Instructions:**
|
| 678 |
1. Wait for training to complete (β
appears in logs)
|
| 679 |
-
2. Click the
|
| 680 |
3. Save the `.pth` file to your local machine
|
| 681 |
4. **Do this immediately** - file is temporary!
|
| 682 |
|
|
@@ -687,14 +962,21 @@ with gr.Blocks(title="LayoutLMv3 Fine-Tuning App", theme=gr.themes.Soft()) as de
|
|
| 687 |
"""
|
| 688 |
)
|
| 689 |
|
| 690 |
-
# Define the action
|
| 691 |
train_button.click(
|
| 692 |
fn=train_model,
|
| 693 |
inputs=[file_input, batch_size_input, epochs_input, lr_input, max_len_input],
|
| 694 |
-
outputs=[log_output,
|
| 695 |
api_name="train"
|
| 696 |
)
|
| 697 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 698 |
# Add example info
|
| 699 |
gr.Markdown(
|
| 700 |
"""
|
|
|
|
| 470 |
# if __name__ == "__main__":
|
| 471 |
# demo.launch()
|
| 472 |
|
| 473 |
+
#
|
| 474 |
+
# import gradio as gr
|
| 475 |
+
# import subprocess
|
| 476 |
+
# import os
|
| 477 |
+
# import sys
|
| 478 |
+
# from datetime import datetime
|
| 479 |
+
# import shutil
|
| 480 |
+
#
|
| 481 |
+
# # FIX: Update the script name to the correct one you uploaded
|
| 482 |
+
# TRAINING_SCRIPT = "HF_LayoutLM_with_Passage.py"
|
| 483 |
+
#
|
| 484 |
+
# # --- CORRECTED MODEL PATH BASED ON YOUR SCRIPT ---
|
| 485 |
+
# MODEL_OUTPUT_DIR = "checkpoints"
|
| 486 |
+
# MODEL_FILE_NAME = "layoutlmv3_crf_passage.pth"
|
| 487 |
+
# MODEL_FILE_PATH = os.path.join(MODEL_OUTPUT_DIR, MODEL_FILE_NAME)
|
| 488 |
+
#
|
| 489 |
+
#
|
| 490 |
+
# # ----------------------------------------------------------------
|
| 491 |
+
#
|
| 492 |
+
#
|
| 493 |
+
# def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float, max_len: int, progress=gr.Progress()):
|
| 494 |
+
# """
|
| 495 |
+
# Handles the Gradio submission and executes the training script using subprocess.
|
| 496 |
+
# """
|
| 497 |
+
#
|
| 498 |
+
# # 1. Setup: Create output directory if it doesn't exist
|
| 499 |
+
# os.makedirs(MODEL_OUTPUT_DIR, exist_ok=True)
|
| 500 |
+
#
|
| 501 |
+
# # 2. File Handling: Use the temporary path of the uploaded file
|
| 502 |
+
# if dataset_file is None:
|
| 503 |
+
# return "β ERROR: Please upload a file.", None
|
| 504 |
+
#
|
| 505 |
+
# # Using .name (Corrected in previous steps)
|
| 506 |
+
# input_path = dataset_file.name
|
| 507 |
+
#
|
| 508 |
+
# if not input_path.lower().endswith(".json"):
|
| 509 |
+
# return "β ERROR: Please upload a valid Label Studio JSON file (.json).", None
|
| 510 |
+
#
|
| 511 |
+
# progress(0.1, desc="Starting LayoutLMv3 Training...")
|
| 512 |
+
#
|
| 513 |
+
# log_output = f"--- Training Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ---\n"
|
| 514 |
+
#
|
| 515 |
+
# # 3. Construct the subprocess command
|
| 516 |
+
# command = [
|
| 517 |
+
# sys.executable,
|
| 518 |
+
# TRAINING_SCRIPT,
|
| 519 |
+
# "--mode", "train",
|
| 520 |
+
# "--input", input_path,
|
| 521 |
+
# "--batch_size", str(batch_size),
|
| 522 |
+
# "--epochs", str(epochs),
|
| 523 |
+
# "--lr", str(lr),
|
| 524 |
+
# "--max_len", str(max_len)
|
| 525 |
+
# ]
|
| 526 |
+
#
|
| 527 |
+
# log_output += f"Executing command: {' '.join(command)}\n\n"
|
| 528 |
+
#
|
| 529 |
+
# try:
|
| 530 |
+
# # 4. Run the training script and capture output
|
| 531 |
+
# process = subprocess.Popen(
|
| 532 |
+
# command,
|
| 533 |
+
# stdout=subprocess.PIPE,
|
| 534 |
+
# stderr=subprocess.STDOUT,
|
| 535 |
+
# text=True,
|
| 536 |
+
# bufsize=1
|
| 537 |
+
# )
|
| 538 |
+
#
|
| 539 |
+
# # Stream logs in real-time
|
| 540 |
+
# for line in iter(process.stdout.readline, ""):
|
| 541 |
+
# log_output += line
|
| 542 |
+
# # Print to console as well for debugging
|
| 543 |
+
# print(line, end='')
|
| 544 |
+
#
|
| 545 |
+
# process.stdout.close()
|
| 546 |
+
# return_code = process.wait()
|
| 547 |
+
#
|
| 548 |
+
# # 5. Check for successful completion
|
| 549 |
+
# if return_code == 0:
|
| 550 |
+
# log_output += "\nβ
TRAINING COMPLETE! Model saved."
|
| 551 |
+
# print("\nβ
TRAINING COMPLETE! Model saved.")
|
| 552 |
+
#
|
| 553 |
+
# # 6. Verify model file exists
|
| 554 |
+
# if os.path.exists(MODEL_FILE_PATH):
|
| 555 |
+
# file_size = os.path.getsize(MODEL_FILE_PATH) / (1024 * 1024) # Size in MB
|
| 556 |
+
# log_output += f"\nπ¦ Model file: {MODEL_FILE_PATH}"
|
| 557 |
+
# log_output += f"\nπ Model size: {file_size:.2f} MB"
|
| 558 |
+
# log_output += f"\nβ¬οΈ Click the download button below to save your model!"
|
| 559 |
+
#
|
| 560 |
+
# print(f"\nβ
Model exists at: {MODEL_FILE_PATH} ({file_size:.2f} MB)")
|
| 561 |
+
#
|
| 562 |
+
# # Create a copy in the root directory for easier access
|
| 563 |
+
# root_copy = MODEL_FILE_NAME
|
| 564 |
+
# try:
|
| 565 |
+
# shutil.copy2(MODEL_FILE_PATH, root_copy)
|
| 566 |
+
# log_output += f"\nπ Copy created: {root_copy}"
|
| 567 |
+
# print(f"β
Created copy at: {root_copy}")
|
| 568 |
+
# except Exception as e:
|
| 569 |
+
# log_output += f"\nβ οΈ Could not create root copy: {e}"
|
| 570 |
+
# root_copy = MODEL_FILE_PATH
|
| 571 |
+
#
|
| 572 |
+
# # Return the full absolute path to ensure Gradio can find it
|
| 573 |
+
# absolute_path = os.path.abspath(root_copy)
|
| 574 |
+
# log_output += f"\nπ Download path: {absolute_path}"
|
| 575 |
+
#
|
| 576 |
+
# return log_output, absolute_path
|
| 577 |
+
# else:
|
| 578 |
+
# log_output += f"\nβ οΈ WARNING: Training completed, but model file not found at expected path ({MODEL_FILE_PATH})."
|
| 579 |
+
# log_output += f"\nπ Checking directory contents..."
|
| 580 |
+
#
|
| 581 |
+
# # List files in checkpoints directory for debugging
|
| 582 |
+
# if os.path.exists(MODEL_OUTPUT_DIR):
|
| 583 |
+
# files = os.listdir(MODEL_OUTPUT_DIR)
|
| 584 |
+
# log_output += f"\nπ Files in {MODEL_OUTPUT_DIR}: {files}"
|
| 585 |
+
# else:
|
| 586 |
+
# log_output += f"\nβ Directory {MODEL_OUTPUT_DIR} does not exist!"
|
| 587 |
+
#
|
| 588 |
+
# return log_output, None
|
| 589 |
+
# else:
|
| 590 |
+
# log_output += f"\n\nβ TRAINING FAILED with return code {return_code}. Check logs above."
|
| 591 |
+
# return log_output, None
|
| 592 |
+
#
|
| 593 |
+
# except FileNotFoundError:
|
| 594 |
+
# error_msg = f"β ERROR: The training script '{TRAINING_SCRIPT}' was not found. Ensure it is in the root directory of your Space."
|
| 595 |
+
# print(error_msg)
|
| 596 |
+
# return error_msg, None
|
| 597 |
+
# except Exception as e:
|
| 598 |
+
# error_msg = f"β An unexpected error occurred: {e}"
|
| 599 |
+
# print(error_msg)
|
| 600 |
+
# import traceback
|
| 601 |
+
# print(traceback.format_exc())
|
| 602 |
+
# return error_msg, None
|
| 603 |
+
#
|
| 604 |
+
#
|
| 605 |
+
# # --- Gradio Interface Setup (using Blocks for a nicer layout) ---
|
| 606 |
+
# with gr.Blocks(title="LayoutLMv3 Fine-Tuning App", theme=gr.themes.Soft()) as demo:
|
| 607 |
+
# gr.Markdown("# π LayoutLMv3 Fine-Tuning on Hugging Face Spaces")
|
| 608 |
+
# gr.Markdown(
|
| 609 |
+
# """
|
| 610 |
+
# Upload your Label Studio JSON file, set your hyperparameters, and click **Train Model** to fine-tune the LayoutLMv3 model.
|
| 611 |
+
#
|
| 612 |
+
# **β οΈ IMPORTANT - Free Tier Users:**
|
| 613 |
+
# - **Download your model IMMEDIATELY** after training completes!
|
| 614 |
+
# - The model file is **temporary** and will be deleted when the Space restarts.
|
| 615 |
+
# - The download button will appear below once training is complete.
|
| 616 |
+
# - Model is saved as: **`layoutlmv3_crf_passage.pth`**
|
| 617 |
+
#
|
| 618 |
+
# **β±οΈ Timeout Note:** Training may timeout on free tier. Consider reducing epochs or batch size for faster training.
|
| 619 |
+
# """
|
| 620 |
+
# )
|
| 621 |
+
#
|
| 622 |
+
# with gr.Row():
|
| 623 |
+
# with gr.Column(scale=1):
|
| 624 |
+
# gr.Markdown("### π Dataset Upload")
|
| 625 |
+
# file_input = gr.File(
|
| 626 |
+
# label="Upload Label Studio JSON Dataset",
|
| 627 |
+
# file_types=[".json"]
|
| 628 |
+
# )
|
| 629 |
+
#
|
| 630 |
+
# gr.Markdown("---")
|
| 631 |
+
# gr.Markdown("### βοΈ Training Parameters")
|
| 632 |
+
#
|
| 633 |
+
# batch_size_input = gr.Slider(
|
| 634 |
+
# minimum=1, maximum=16, step=1, value=4,
|
| 635 |
+
# label="Batch Size",
|
| 636 |
+
# info="Smaller = less memory, slower training"
|
| 637 |
+
# )
|
| 638 |
+
# epochs_input = gr.Slider(
|
| 639 |
+
# minimum=1, maximum=10, step=1, value=3,
|
| 640 |
+
# label="Epochs",
|
| 641 |
+
# info="Fewer epochs = faster training (recommended: 3-5)"
|
| 642 |
+
# )
|
| 643 |
+
# lr_input = gr.Number(
|
| 644 |
+
# value=5e-5, label="Learning Rate",
|
| 645 |
+
# info="Default: 5e-5"
|
| 646 |
+
# )
|
| 647 |
+
# max_len_input = gr.Slider(
|
| 648 |
+
# minimum=128, maximum=512, step=128, value=512,
|
| 649 |
+
# label="Max Sequence Length",
|
| 650 |
+
# info="Shorter = faster training, less memory"
|
| 651 |
+
# )
|
| 652 |
+
#
|
| 653 |
+
# train_button = gr.Button("π₯ Start Training", variant="primary", size="lg")
|
| 654 |
+
#
|
| 655 |
+
# with gr.Column(scale=2):
|
| 656 |
+
# gr.Markdown("### π Training Progress")
|
| 657 |
+
#
|
| 658 |
+
# log_output = gr.Textbox(
|
| 659 |
+
# label="Training Logs",
|
| 660 |
+
# lines=25,
|
| 661 |
+
# max_lines=30,
|
| 662 |
+
# autoscroll=True,
|
| 663 |
+
# show_copy_button=True,
|
| 664 |
+
# placeholder="Click 'Start Training' to begin...\n\nLogs will appear here in real-time."
|
| 665 |
+
# )
|
| 666 |
+
#
|
| 667 |
+
# gr.Markdown("### β¬οΈ Download Trained Model")
|
| 668 |
+
#
|
| 669 |
+
# model_download = gr.File(
|
| 670 |
+
# label="Trained Model File (layoutlmv3_crf_passage.pth)",
|
| 671 |
+
# interactive=False,
|
| 672 |
+
# visible=True
|
| 673 |
+
# )
|
| 674 |
+
#
|
| 675 |
+
# gr.Markdown(
|
| 676 |
+
# """
|
| 677 |
+
# **π₯ Download Instructions:**
|
| 678 |
+
# 1. Wait for training to complete (β
appears in logs)
|
| 679 |
+
# 2. Click the download button/icon that appears above
|
| 680 |
+
# 3. Save the `.pth` file to your local machine
|
| 681 |
+
# 4. **Do this immediately** - file is temporary!
|
| 682 |
+
#
|
| 683 |
+
# **π§ Troubleshooting:**
|
| 684 |
+
# - If download button doesn't appear, check the logs for errors
|
| 685 |
+
# - Try reducing epochs or batch size if timeout occurs
|
| 686 |
+
# - Ensure your JSON file is properly formatted
|
| 687 |
+
# """
|
| 688 |
+
# )
|
| 689 |
+
#
|
| 690 |
+
# # Define the action when the button is clicked
|
| 691 |
+
# train_button.click(
|
| 692 |
+
# fn=train_model,
|
| 693 |
+
# inputs=[file_input, batch_size_input, epochs_input, lr_input, max_len_input],
|
| 694 |
+
# outputs=[log_output, model_download],
|
| 695 |
+
# api_name="train"
|
| 696 |
+
# )
|
| 697 |
+
#
|
| 698 |
+
# # Add example info
|
| 699 |
+
# gr.Markdown(
|
| 700 |
+
# """
|
| 701 |
+
# ---
|
| 702 |
+
# ### π About
|
| 703 |
+
# This Space fine-tunes LayoutLMv3 with CRF for document understanding tasks including:
|
| 704 |
+
# - Questions, Options, Answers
|
| 705 |
+
# - Section Headings
|
| 706 |
+
# - Passages
|
| 707 |
+
#
|
| 708 |
+
# **Model Details:** LayoutLMv3-base + CRF layer for sequence labeling
|
| 709 |
+
# """
|
| 710 |
+
# )
|
| 711 |
+
#
|
| 712 |
+
# if __name__ == "__main__":
|
| 713 |
+
# demo.launch()
|
| 714 |
+
|
| 715 |
|
| 716 |
import gradio as gr
|
| 717 |
import subprocess
|
|
|
|
| 742 |
|
| 743 |
# 2. File Handling: Use the temporary path of the uploaded file
|
| 744 |
if dataset_file is None:
|
| 745 |
+
return "β ERROR: Please upload a file.", None, gr.Button(visible=False)
|
| 746 |
|
| 747 |
# Using .name (Corrected in previous steps)
|
| 748 |
input_path = dataset_file.name
|
| 749 |
|
| 750 |
if not input_path.lower().endswith(".json"):
|
| 751 |
+
return "β ERROR: Please upload a valid Label Studio JSON file (.json).", None, gr.Button(visible=False)
|
| 752 |
|
| 753 |
progress(0.1, desc="Starting LayoutLMv3 Training...")
|
| 754 |
|
|
|
|
| 797 |
file_size = os.path.getsize(MODEL_FILE_PATH) / (1024 * 1024) # Size in MB
|
| 798 |
log_output += f"\nπ¦ Model file: {MODEL_FILE_PATH}"
|
| 799 |
log_output += f"\nπ Model size: {file_size:.2f} MB"
|
|
|
|
| 800 |
|
| 801 |
print(f"\nβ
Model exists at: {MODEL_FILE_PATH} ({file_size:.2f} MB)")
|
| 802 |
|
| 803 |
+
# Create a copy in the root directory with timestamp for uniqueness
|
| 804 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 805 |
+
download_filename = f"layoutlmv3_trained_{timestamp}.pth"
|
| 806 |
+
|
| 807 |
try:
|
| 808 |
+
shutil.copy2(MODEL_FILE_PATH, download_filename)
|
| 809 |
+
log_output += f"\nπ Download file created: {download_filename}"
|
| 810 |
+
print(f"β
Created download file: {download_filename}")
|
| 811 |
except Exception as e:
|
| 812 |
+
log_output += f"\nβ οΈ Could not create download file: {e}"
|
| 813 |
+
download_filename = MODEL_FILE_PATH
|
| 814 |
|
| 815 |
+
# Return the path and make download button visible
|
| 816 |
+
log_output += f"\n\nπ SUCCESS! Click the 'Download Model' button below to save your model."
|
| 817 |
+
log_output += f"\nβ οΈ IMPORTANT: Download NOW - file will be deleted when Space restarts!"
|
| 818 |
|
| 819 |
+
return log_output, download_filename, gr.Button(visible=True)
|
| 820 |
else:
|
| 821 |
log_output += f"\nβ οΈ WARNING: Training completed, but model file not found at expected path ({MODEL_FILE_PATH})."
|
| 822 |
log_output += f"\nπ Checking directory contents..."
|
|
|
|
| 828 |
else:
|
| 829 |
log_output += f"\nβ Directory {MODEL_OUTPUT_DIR} does not exist!"
|
| 830 |
|
| 831 |
+
return log_output, None, gr.Button(visible=False)
|
| 832 |
else:
|
| 833 |
log_output += f"\n\nβ TRAINING FAILED with return code {return_code}. Check logs above."
|
| 834 |
+
return log_output, None, gr.Button(visible=False)
|
| 835 |
|
| 836 |
except FileNotFoundError:
|
| 837 |
error_msg = f"β ERROR: The training script '{TRAINING_SCRIPT}' was not found. Ensure it is in the root directory of your Space."
|
| 838 |
print(error_msg)
|
| 839 |
+
return error_msg, None, gr.Button(visible=False)
|
| 840 |
except Exception as e:
|
| 841 |
error_msg = f"β An unexpected error occurred: {e}"
|
| 842 |
print(error_msg)
|
| 843 |
import traceback
|
| 844 |
print(traceback.format_exc())
|
| 845 |
+
return error_msg, None, gr.Button(visible=False)
|
| 846 |
+
|
| 847 |
+
|
| 848 |
+
def download_model():
    """
    Return the path of the trained model file to offer for download.

    Lookup order:
      1. ``MODEL_FILE_PATH`` itself, if it exists.
      2. The most recently modified ``.pth`` file in the current directory.
      3. The most recently modified ``.pth`` file in ``MODEL_OUTPUT_DIR``.

    Returns:
        The path as a string (a bare file name for hits in the current
        directory, ``MODEL_OUTPUT_DIR``-prefixed for hits in the checkpoint
        directory), or ``None`` when no model file can be found.
    """

    def _newest(paths):
        # Keep regular files only, so a directory that merely ends in
        # ".pth" is never handed to the download widget.
        files = [p for p in paths if os.path.isfile(p)]
        if not files:
            return None
        # os.listdir() order is arbitrary; prefer the latest model and
        # break mtime ties by name so the result is deterministic.
        return max(files, key=lambda p: (os.path.getmtime(p), p))

    if os.path.exists(MODEL_FILE_PATH):
        return MODEL_FILE_PATH

    # Check for any .pth files in current directory
    newest = _newest(f for f in os.listdir('.') if f.endswith('.pth'))
    if newest is not None:
        return newest

    # Check checkpoints directory (isdir: listing a plain file would raise)
    if os.path.isdir(MODEL_OUTPUT_DIR):
        newest = _newest(
            os.path.join(MODEL_OUTPUT_DIR, f)
            for f in os.listdir(MODEL_OUTPUT_DIR)
            if f.endswith('.pth')
        )
        if newest is not None:
            return newest

    return None
|
| 867 |
|
| 868 |
|
| 869 |
# --- Gradio Interface Setup (using Blocks for a nicer layout) ---
|
|
|
|
| 876 |
**β οΈ IMPORTANT - Free Tier Users:**
|
| 877 |
- **Download your model IMMEDIATELY** after training completes!
|
| 878 |
- The model file is **temporary** and will be deleted when the Space restarts.
|
| 879 |
+
- A download button will appear below once training is complete.
|
|
|
|
| 880 |
|
| 881 |
**β±οΈ Timeout Note:** Training may timeout on free tier. Consider reducing epochs or batch size for faster training.
|
| 882 |
"""
|
|
|
|
| 929 |
|
| 930 |
gr.Markdown("### β¬οΈ Download Trained Model")
|
| 931 |
|
| 932 |
+
# Hidden state to store the file path
|
| 933 |
+
model_path_state = gr.State(value=None)
|
| 934 |
+
|
| 935 |
+
# Download button (initially hidden)
|
| 936 |
+
download_btn = gr.Button(
|
| 937 |
+
"π₯ Download Model (.pth file)",
|
| 938 |
+
variant="primary",
|
| 939 |
+
size="lg",
|
| 940 |
+
visible=False
|
| 941 |
+
)
|
| 942 |
+
|
| 943 |
+
# File output for download
|
| 944 |
model_download = gr.File(
|
| 945 |
+
label="Your trained model will appear here",
|
| 946 |
interactive=False,
|
| 947 |
visible=True
|
| 948 |
)
|
|
|
|
| 951 |
"""
|
| 952 |
**π₯ Download Instructions:**
|
| 953 |
1. Wait for training to complete (β
appears in logs)
|
| 954 |
+
2. Click the **"Download Model"** button above
|
| 955 |
3. Save the `.pth` file to your local machine
|
| 956 |
4. **Do this immediately** - file is temporary!
|
| 957 |
|
|
|
|
| 962 |
"""
|
| 963 |
)
|
| 964 |
|
| 965 |
+
# Define the training action
|
| 966 |
train_button.click(
|
| 967 |
fn=train_model,
|
| 968 |
inputs=[file_input, batch_size_input, epochs_input, lr_input, max_len_input],
|
| 969 |
+
outputs=[log_output, model_path_state, download_btn],
|
| 970 |
api_name="train"
|
| 971 |
)
|
| 972 |
|
| 973 |
+
# Define the download action
|
| 974 |
+
download_btn.click(
|
| 975 |
+
fn=lambda path: path,
|
| 976 |
+
inputs=[model_path_state],
|
| 977 |
+
outputs=[model_download]
|
| 978 |
+
)
|
| 979 |
+
|
| 980 |
# Add example info
|
| 981 |
gr.Markdown(
|
| 982 |
"""
|