aagamjtdev commited on
Commit
0c2088f
Β·
1 Parent(s): a21bd5b

correction

Browse files
Files changed (1) hide show
  1. app.py +306 -24
app.py CHANGED
@@ -470,6 +470,248 @@
470
  # if __name__ == "__main__":
471
  # demo.launch()
472
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
473
 
474
  import gradio as gr
475
  import subprocess
@@ -500,13 +742,13 @@ def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float,
500
 
501
  # 2. File Handling: Use the temporary path of the uploaded file
502
  if dataset_file is None:
503
- return "❌ ERROR: Please upload a file.", None
504
 
505
  # Using .name (Corrected in previous steps)
506
  input_path = dataset_file.name
507
 
508
  if not input_path.lower().endswith(".json"):
509
- return "❌ ERROR: Please upload a valid Label Studio JSON file (.json).", None
510
 
511
  progress(0.1, desc="Starting LayoutLMv3 Training...")
512
 
@@ -555,25 +797,26 @@ def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float,
555
  file_size = os.path.getsize(MODEL_FILE_PATH) / (1024 * 1024) # Size in MB
556
  log_output += f"\nπŸ“¦ Model file: {MODEL_FILE_PATH}"
557
  log_output += f"\nπŸ“Š Model size: {file_size:.2f} MB"
558
- log_output += f"\n⬇️ Click the download button below to save your model!"
559
 
560
  print(f"\nβœ… Model exists at: {MODEL_FILE_PATH} ({file_size:.2f} MB)")
561
 
562
- # Create a copy in the root directory for easier access
563
- root_copy = MODEL_FILE_NAME
 
 
564
  try:
565
- shutil.copy2(MODEL_FILE_PATH, root_copy)
566
- log_output += f"\nπŸ“‹ Copy created: {root_copy}"
567
- print(f"βœ… Created copy at: {root_copy}")
568
  except Exception as e:
569
- log_output += f"\n⚠️ Could not create root copy: {e}"
570
- root_copy = MODEL_FILE_PATH
571
 
572
- # Return the full absolute path to ensure Gradio can find it
573
- absolute_path = os.path.abspath(root_copy)
574
- log_output += f"\nπŸ”— Download path: {absolute_path}"
575
 
576
- return log_output, absolute_path
577
  else:
578
  log_output += f"\n⚠️ WARNING: Training completed, but model file not found at expected path ({MODEL_FILE_PATH})."
579
  log_output += f"\nπŸ” Checking directory contents..."
@@ -585,21 +828,42 @@ def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float,
585
  else:
586
  log_output += f"\n❌ Directory {MODEL_OUTPUT_DIR} does not exist!"
587
 
588
- return log_output, None
589
  else:
590
  log_output += f"\n\n❌ TRAINING FAILED with return code {return_code}. Check logs above."
591
- return log_output, None
592
 
593
  except FileNotFoundError:
594
  error_msg = f"❌ ERROR: The training script '{TRAINING_SCRIPT}' was not found. Ensure it is in the root directory of your Space."
595
  print(error_msg)
596
- return error_msg, None
597
  except Exception as e:
598
  error_msg = f"❌ An unexpected error occurred: {e}"
599
  print(error_msg)
600
  import traceback
601
  print(traceback.format_exc())
602
- return error_msg, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
603
 
604
 
605
  # --- Gradio Interface Setup (using Blocks for a nicer layout) ---
@@ -612,8 +876,7 @@ with gr.Blocks(title="LayoutLMv3 Fine-Tuning App", theme=gr.themes.Soft()) as de
612
  **⚠️ IMPORTANT - Free Tier Users:**
613
  - **Download your model IMMEDIATELY** after training completes!
614
  - The model file is **temporary** and will be deleted when the Space restarts.
615
- - The download button will appear below once training is complete.
616
- - Model is saved as: **`layoutlmv3_crf_passage.pth`**
617
 
618
  **⏱️ Timeout Note:** Training may timeout on free tier. Consider reducing epochs or batch size for faster training.
619
  """
@@ -666,8 +929,20 @@ with gr.Blocks(title="LayoutLMv3 Fine-Tuning App", theme=gr.themes.Soft()) as de
666
 
667
  gr.Markdown("### ⬇️ Download Trained Model")
668
 
 
 
 
 
 
 
 
 
 
 
 
 
669
  model_download = gr.File(
670
- label="Trained Model File (layoutlmv3_crf_passage.pth)",
671
  interactive=False,
672
  visible=True
673
  )
@@ -676,7 +951,7 @@ with gr.Blocks(title="LayoutLMv3 Fine-Tuning App", theme=gr.themes.Soft()) as de
676
  """
677
  **πŸ“₯ Download Instructions:**
678
  1. Wait for training to complete (βœ… appears in logs)
679
- 2. Click the download button/icon that appears above
680
  3. Save the `.pth` file to your local machine
681
  4. **Do this immediately** - file is temporary!
682
 
@@ -687,14 +962,21 @@ with gr.Blocks(title="LayoutLMv3 Fine-Tuning App", theme=gr.themes.Soft()) as de
687
  """
688
  )
689
 
690
- # Define the action when the button is clicked
691
  train_button.click(
692
  fn=train_model,
693
  inputs=[file_input, batch_size_input, epochs_input, lr_input, max_len_input],
694
- outputs=[log_output, model_download],
695
  api_name="train"
696
  )
697
 
 
 
 
 
 
 
 
698
  # Add example info
699
  gr.Markdown(
700
  """
 
470
  # if __name__ == "__main__":
471
  # demo.launch()
472
 
473
+ #
474
+ # import gradio as gr
475
+ # import subprocess
476
+ # import os
477
+ # import sys
478
+ # from datetime import datetime
479
+ # import shutil
480
+ #
481
+ # # FIX: Update the script name to the correct one you uploaded
482
+ # TRAINING_SCRIPT = "HF_LayoutLM_with_Passage.py"
483
+ #
484
+ # # --- CORRECTED MODEL PATH BASED ON YOUR SCRIPT ---
485
+ # MODEL_OUTPUT_DIR = "checkpoints"
486
+ # MODEL_FILE_NAME = "layoutlmv3_crf_passage.pth"
487
+ # MODEL_FILE_PATH = os.path.join(MODEL_OUTPUT_DIR, MODEL_FILE_NAME)
488
+ #
489
+ #
490
+ # # ----------------------------------------------------------------
491
+ #
492
+ #
493
+ # def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float, max_len: int, progress=gr.Progress()):
494
+ # """
495
+ # Handles the Gradio submission and executes the training script using subprocess.
496
+ # """
497
+ #
498
+ # # 1. Setup: Create output directory if it doesn't exist
499
+ # os.makedirs(MODEL_OUTPUT_DIR, exist_ok=True)
500
+ #
501
+ # # 2. File Handling: Use the temporary path of the uploaded file
502
+ # if dataset_file is None:
503
+ # return "❌ ERROR: Please upload a file.", None
504
+ #
505
+ # # Using .name (Corrected in previous steps)
506
+ # input_path = dataset_file.name
507
+ #
508
+ # if not input_path.lower().endswith(".json"):
509
+ # return "❌ ERROR: Please upload a valid Label Studio JSON file (.json).", None
510
+ #
511
+ # progress(0.1, desc="Starting LayoutLMv3 Training...")
512
+ #
513
+ # log_output = f"--- Training Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ---\n"
514
+ #
515
+ # # 3. Construct the subprocess command
516
+ # command = [
517
+ # sys.executable,
518
+ # TRAINING_SCRIPT,
519
+ # "--mode", "train",
520
+ # "--input", input_path,
521
+ # "--batch_size", str(batch_size),
522
+ # "--epochs", str(epochs),
523
+ # "--lr", str(lr),
524
+ # "--max_len", str(max_len)
525
+ # ]
526
+ #
527
+ # log_output += f"Executing command: {' '.join(command)}\n\n"
528
+ #
529
+ # try:
530
+ # # 4. Run the training script and capture output
531
+ # process = subprocess.Popen(
532
+ # command,
533
+ # stdout=subprocess.PIPE,
534
+ # stderr=subprocess.STDOUT,
535
+ # text=True,
536
+ # bufsize=1
537
+ # )
538
+ #
539
+ # # Stream logs in real-time
540
+ # for line in iter(process.stdout.readline, ""):
541
+ # log_output += line
542
+ # # Print to console as well for debugging
543
+ # print(line, end='')
544
+ #
545
+ # process.stdout.close()
546
+ # return_code = process.wait()
547
+ #
548
+ # # 5. Check for successful completion
549
+ # if return_code == 0:
550
+ # log_output += "\nβœ… TRAINING COMPLETE! Model saved."
551
+ # print("\nβœ… TRAINING COMPLETE! Model saved.")
552
+ #
553
+ # # 6. Verify model file exists
554
+ # if os.path.exists(MODEL_FILE_PATH):
555
+ # file_size = os.path.getsize(MODEL_FILE_PATH) / (1024 * 1024) # Size in MB
556
+ # log_output += f"\nπŸ“¦ Model file: {MODEL_FILE_PATH}"
557
+ # log_output += f"\nπŸ“Š Model size: {file_size:.2f} MB"
558
+ # log_output += f"\n⬇️ Click the download button below to save your model!"
559
+ #
560
+ # print(f"\nβœ… Model exists at: {MODEL_FILE_PATH} ({file_size:.2f} MB)")
561
+ #
562
+ # # Create a copy in the root directory for easier access
563
+ # root_copy = MODEL_FILE_NAME
564
+ # try:
565
+ # shutil.copy2(MODEL_FILE_PATH, root_copy)
566
+ # log_output += f"\nπŸ“‹ Copy created: {root_copy}"
567
+ # print(f"βœ… Created copy at: {root_copy}")
568
+ # except Exception as e:
569
+ # log_output += f"\n⚠️ Could not create root copy: {e}"
570
+ # root_copy = MODEL_FILE_PATH
571
+ #
572
+ # # Return the full absolute path to ensure Gradio can find it
573
+ # absolute_path = os.path.abspath(root_copy)
574
+ # log_output += f"\nπŸ”— Download path: {absolute_path}"
575
+ #
576
+ # return log_output, absolute_path
577
+ # else:
578
+ # log_output += f"\n⚠️ WARNING: Training completed, but model file not found at expected path ({MODEL_FILE_PATH})."
579
+ # log_output += f"\nπŸ” Checking directory contents..."
580
+ #
581
+ # # List files in checkpoints directory for debugging
582
+ # if os.path.exists(MODEL_OUTPUT_DIR):
583
+ # files = os.listdir(MODEL_OUTPUT_DIR)
584
+ # log_output += f"\nπŸ“ Files in {MODEL_OUTPUT_DIR}: {files}"
585
+ # else:
586
+ # log_output += f"\n❌ Directory {MODEL_OUTPUT_DIR} does not exist!"
587
+ #
588
+ # return log_output, None
589
+ # else:
590
+ # log_output += f"\n\n❌ TRAINING FAILED with return code {return_code}. Check logs above."
591
+ # return log_output, None
592
+ #
593
+ # except FileNotFoundError:
594
+ # error_msg = f"❌ ERROR: The training script '{TRAINING_SCRIPT}' was not found. Ensure it is in the root directory of your Space."
595
+ # print(error_msg)
596
+ # return error_msg, None
597
+ # except Exception as e:
598
+ # error_msg = f"❌ An unexpected error occurred: {e}"
599
+ # print(error_msg)
600
+ # import traceback
601
+ # print(traceback.format_exc())
602
+ # return error_msg, None
603
+ #
604
+ #
605
+ # # --- Gradio Interface Setup (using Blocks for a nicer layout) ---
606
+ # with gr.Blocks(title="LayoutLMv3 Fine-Tuning App", theme=gr.themes.Soft()) as demo:
607
+ # gr.Markdown("# πŸš€ LayoutLMv3 Fine-Tuning on Hugging Face Spaces")
608
+ # gr.Markdown(
609
+ # """
610
+ # Upload your Label Studio JSON file, set your hyperparameters, and click **Train Model** to fine-tune the LayoutLMv3 model.
611
+ #
612
+ # **⚠️ IMPORTANT - Free Tier Users:**
613
+ # - **Download your model IMMEDIATELY** after training completes!
614
+ # - The model file is **temporary** and will be deleted when the Space restarts.
615
+ # - The download button will appear below once training is complete.
616
+ # - Model is saved as: **`layoutlmv3_crf_passage.pth`**
617
+ #
618
+ # **⏱️ Timeout Note:** Training may timeout on free tier. Consider reducing epochs or batch size for faster training.
619
+ # """
620
+ # )
621
+ #
622
+ # with gr.Row():
623
+ # with gr.Column(scale=1):
624
+ # gr.Markdown("### πŸ“ Dataset Upload")
625
+ # file_input = gr.File(
626
+ # label="Upload Label Studio JSON Dataset",
627
+ # file_types=[".json"]
628
+ # )
629
+ #
630
+ # gr.Markdown("---")
631
+ # gr.Markdown("### βš™οΈ Training Parameters")
632
+ #
633
+ # batch_size_input = gr.Slider(
634
+ # minimum=1, maximum=16, step=1, value=4,
635
+ # label="Batch Size",
636
+ # info="Smaller = less memory, slower training"
637
+ # )
638
+ # epochs_input = gr.Slider(
639
+ # minimum=1, maximum=10, step=1, value=3,
640
+ # label="Epochs",
641
+ # info="Fewer epochs = faster training (recommended: 3-5)"
642
+ # )
643
+ # lr_input = gr.Number(
644
+ # value=5e-5, label="Learning Rate",
645
+ # info="Default: 5e-5"
646
+ # )
647
+ # max_len_input = gr.Slider(
648
+ # minimum=128, maximum=512, step=128, value=512,
649
+ # label="Max Sequence Length",
650
+ # info="Shorter = faster training, less memory"
651
+ # )
652
+ #
653
+ # train_button = gr.Button("πŸ”₯ Start Training", variant="primary", size="lg")
654
+ #
655
+ # with gr.Column(scale=2):
656
+ # gr.Markdown("### πŸ“Š Training Progress")
657
+ #
658
+ # log_output = gr.Textbox(
659
+ # label="Training Logs",
660
+ # lines=25,
661
+ # max_lines=30,
662
+ # autoscroll=True,
663
+ # show_copy_button=True,
664
+ # placeholder="Click 'Start Training' to begin...\n\nLogs will appear here in real-time."
665
+ # )
666
+ #
667
+ # gr.Markdown("### ⬇️ Download Trained Model")
668
+ #
669
+ # model_download = gr.File(
670
+ # label="Trained Model File (layoutlmv3_crf_passage.pth)",
671
+ # interactive=False,
672
+ # visible=True
673
+ # )
674
+ #
675
+ # gr.Markdown(
676
+ # """
677
+ # **πŸ“₯ Download Instructions:**
678
+ # 1. Wait for training to complete (βœ… appears in logs)
679
+ # 2. Click the download button/icon that appears above
680
+ # 3. Save the `.pth` file to your local machine
681
+ # 4. **Do this immediately** - file is temporary!
682
+ #
683
+ # **πŸ”§ Troubleshooting:**
684
+ # - If download button doesn't appear, check the logs for errors
685
+ # - Try reducing epochs or batch size if timeout occurs
686
+ # - Ensure your JSON file is properly formatted
687
+ # """
688
+ # )
689
+ #
690
+ # # Define the action when the button is clicked
691
+ # train_button.click(
692
+ # fn=train_model,
693
+ # inputs=[file_input, batch_size_input, epochs_input, lr_input, max_len_input],
694
+ # outputs=[log_output, model_download],
695
+ # api_name="train"
696
+ # )
697
+ #
698
+ # # Add example info
699
+ # gr.Markdown(
700
+ # """
701
+ # ---
702
+ # ### πŸ“– About
703
+ # This Space fine-tunes LayoutLMv3 with CRF for document understanding tasks including:
704
+ # - Questions, Options, Answers
705
+ # - Section Headings
706
+ # - Passages
707
+ #
708
+ # **Model Details:** LayoutLMv3-base + CRF layer for sequence labeling
709
+ # """
710
+ # )
711
+ #
712
+ # if __name__ == "__main__":
713
+ # demo.launch()
714
+
715
 
716
  import gradio as gr
717
  import subprocess
 
742
 
743
  # 2. File Handling: Use the temporary path of the uploaded file
744
  if dataset_file is None:
745
+ return "❌ ERROR: Please upload a file.", None, gr.Button(visible=False)
746
 
747
  # Using .name (Corrected in previous steps)
748
  input_path = dataset_file.name
749
 
750
  if not input_path.lower().endswith(".json"):
751
+ return "❌ ERROR: Please upload a valid Label Studio JSON file (.json).", None, gr.Button(visible=False)
752
 
753
  progress(0.1, desc="Starting LayoutLMv3 Training...")
754
 
 
797
  file_size = os.path.getsize(MODEL_FILE_PATH) / (1024 * 1024) # Size in MB
798
  log_output += f"\nπŸ“¦ Model file: {MODEL_FILE_PATH}"
799
  log_output += f"\nπŸ“Š Model size: {file_size:.2f} MB"
 
800
 
801
  print(f"\nβœ… Model exists at: {MODEL_FILE_PATH} ({file_size:.2f} MB)")
802
 
803
+ # Create a copy in the root directory with timestamp for uniqueness
804
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
805
+ download_filename = f"layoutlmv3_trained_{timestamp}.pth"
806
+
807
  try:
808
+ shutil.copy2(MODEL_FILE_PATH, download_filename)
809
+ log_output += f"\nπŸ“‹ Download file created: {download_filename}"
810
+ print(f"βœ… Created download file: {download_filename}")
811
  except Exception as e:
812
+ log_output += f"\n⚠️ Could not create download file: {e}"
813
+ download_filename = MODEL_FILE_PATH
814
 
815
+ # Return the path and make download button visible
816
+ log_output += f"\n\nπŸŽ‰ SUCCESS! Click the 'Download Model' button below to save your model."
817
+ log_output += f"\n⚠️ IMPORTANT: Download NOW - file will be deleted when Space restarts!"
818
 
819
+ return log_output, download_filename, gr.Button(visible=True)
820
  else:
821
  log_output += f"\n⚠️ WARNING: Training completed, but model file not found at expected path ({MODEL_FILE_PATH})."
822
  log_output += f"\nπŸ” Checking directory contents..."
 
828
  else:
829
  log_output += f"\n❌ Directory {MODEL_OUTPUT_DIR} does not exist!"
830
 
831
+ return log_output, None, gr.Button(visible=False)
832
  else:
833
  log_output += f"\n\n❌ TRAINING FAILED with return code {return_code}. Check logs above."
834
+ return log_output, None, gr.Button(visible=False)
835
 
836
  except FileNotFoundError:
837
  error_msg = f"❌ ERROR: The training script '{TRAINING_SCRIPT}' was not found. Ensure it is in the root directory of your Space."
838
  print(error_msg)
839
+ return error_msg, None, gr.Button(visible=False)
840
  except Exception as e:
841
  error_msg = f"❌ An unexpected error occurred: {e}"
842
  print(error_msg)
843
  import traceback
844
  print(traceback.format_exc())
845
+ return error_msg, None, gr.Button(visible=False)
846
+
847
+
848
+ def download_model():
849
+ """
850
+ Returns the model file for download.
851
+ """
852
+ if os.path.exists(MODEL_FILE_PATH):
853
+ return MODEL_FILE_PATH
854
+ else:
855
+ # Check for any .pth files in current directory
856
+ pth_files = [f for f in os.listdir('.') if f.endswith('.pth')]
857
+ if pth_files:
858
+ return pth_files[0]
859
+
860
+ # Check checkpoints directory
861
+ if os.path.exists(MODEL_OUTPUT_DIR):
862
+ pth_files = [os.path.join(MODEL_OUTPUT_DIR, f) for f in os.listdir(MODEL_OUTPUT_DIR) if f.endswith('.pth')]
863
+ if pth_files:
864
+ return pth_files[0]
865
+
866
+ return None
867
 
868
 
869
  # --- Gradio Interface Setup (using Blocks for a nicer layout) ---
 
876
  **⚠️ IMPORTANT - Free Tier Users:**
877
  - **Download your model IMMEDIATELY** after training completes!
878
  - The model file is **temporary** and will be deleted when the Space restarts.
879
+ - A download button will appear below once training is complete.
 
880
 
881
  **⏱️ Timeout Note:** Training may timeout on free tier. Consider reducing epochs or batch size for faster training.
882
  """
 
929
 
930
  gr.Markdown("### ⬇️ Download Trained Model")
931
 
932
+ # Hidden state to store the file path
933
+ model_path_state = gr.State(value=None)
934
+
935
+ # Download button (initially hidden)
936
+ download_btn = gr.Button(
937
+ "πŸ“₯ Download Model (.pth file)",
938
+ variant="primary",
939
+ size="lg",
940
+ visible=False
941
+ )
942
+
943
+ # File output for download
944
  model_download = gr.File(
945
+ label="Your trained model will appear here",
946
  interactive=False,
947
  visible=True
948
  )
 
951
  """
952
  **πŸ“₯ Download Instructions:**
953
  1. Wait for training to complete (βœ… appears in logs)
954
+ 2. Click the **"Download Model"** button above
955
  3. Save the `.pth` file to your local machine
956
  4. **Do this immediately** - file is temporary!
957
 
 
962
  """
963
  )
964
 
965
+ # Define the training action
966
  train_button.click(
967
  fn=train_model,
968
  inputs=[file_input, batch_size_input, epochs_input, lr_input, max_len_input],
969
+ outputs=[log_output, model_path_state, download_btn],
970
  api_name="train"
971
  )
972
 
973
+ # Define the download action
974
+ download_btn.click(
975
+ fn=lambda path: path,
976
+ inputs=[model_path_state],
977
+ outputs=[model_download]
978
+ )
979
+
980
  # Add example info
981
  gr.Markdown(
982
  """