aagamjtdev committed on
Commit
a21bd5b
·
1 Parent(s): 525a040

download .pth file

Browse files
Files changed (1) hide show
  1. app.py +277 -35
app.py CHANGED
@@ -310,12 +310,173 @@
310
  # # and sometimes cause issues in managed Space environments.
311
  # demo.launch()
312
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
313
 
314
  import gradio as gr
315
  import subprocess
316
  import os
317
  import sys
318
  from datetime import datetime
 
319
 
320
  # FIX: Update the script name to the correct one you uploaded
321
  TRAINING_SCRIPT = "HF_LayoutLM_with_Passage.py"
@@ -339,15 +500,13 @@ def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float,
339
 
340
  # 2. File Handling: Use the temporary path of the uploaded file
341
  if dataset_file is None:
342
- yield "❌ ERROR: Please upload a file.", None
343
- return
344
 
345
  # Using .name (Corrected in previous steps)
346
  input_path = dataset_file.name
347
 
348
  if not input_path.lower().endswith(".json"):
349
- yield "❌ ERROR: Please upload a valid Label Studio JSON file (.json).", None
350
- return
351
 
352
  progress(0.1, desc="Starting LayoutLMv3 Training...")
353
 
@@ -356,7 +515,6 @@ def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float,
356
  # 3. Construct the subprocess command
357
  command = [
358
  sys.executable,
359
- # Now uses the corrected TRAINING_SCRIPT variable
360
  TRAINING_SCRIPT,
361
  "--mode", "train",
362
  "--input", input_path,
@@ -367,7 +525,6 @@ def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float,
367
  ]
368
 
369
  log_output += f"Executing command: {' '.join(command)}\n\n"
370
- yield log_output, None # Yield the command to the log output
371
 
372
  try:
373
  # 4. Run the training script and capture output
@@ -382,7 +539,8 @@ def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float,
382
  # Stream logs in real-time
383
  for line in iter(process.stdout.readline, ""):
384
  log_output += line
385
- yield log_output, None # Send partial log to Gradio output
 
386
 
387
  process.stdout.close()
388
  return_code = process.wait()
@@ -390,81 +548,165 @@ def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float,
390
  # 5. Check for successful completion
391
  if return_code == 0:
392
  log_output += "\nβœ… TRAINING COMPLETE! Model saved."
393
-
394
- # 6. Prepare download links based on script's saved path
395
- model_exists = os.path.exists(MODEL_FILE_PATH)
396
-
397
- if model_exists:
398
- log_output += f"\nModel path: {MODEL_FILE_PATH}"
399
- # Return final log, and the file path for Gradio's download component
400
- return log_output, MODEL_FILE_PATH
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
401
  else:
402
  log_output += f"\n⚠️ WARNING: Training completed, but model file not found at expected path ({MODEL_FILE_PATH})."
 
 
 
 
 
 
 
 
 
403
  return log_output, None
404
  else:
405
  log_output += f"\n\n❌ TRAINING FAILED with return code {return_code}. Check logs above."
406
  return log_output, None
407
 
408
  except FileNotFoundError:
409
- return f"❌ ERROR: The training script '{TRAINING_SCRIPT}' was not found. Ensure it is in the root directory of your Space.", None
 
 
410
  except Exception as e:
411
- return f"❌ An unexpected error occurred: {e}", None
 
 
 
 
412
 
413
 
414
  # --- Gradio Interface Setup (using Blocks for a nicer layout) ---
415
- with gr.Blocks(title="LayoutLMv3 Fine-Tuning App") as demo:
416
  gr.Markdown("# πŸš€ LayoutLMv3 Fine-Tuning on Hugging Face Spaces")
417
  gr.Markdown(
418
  """
419
- Upload your Label Studio JSON file, set your hyperparameters, and click **Train Model** to fine-tune the LayoutLMv3 model using your script.
 
 
 
 
 
 
420
 
421
- **Note:** The trained model is saved in the **`checkpoints/`** folder as **`layoutlmv3_crf_passage.pth`**.
422
  """
423
  )
424
 
425
  with gr.Row():
426
  with gr.Column(scale=1):
 
427
  file_input = gr.File(
428
- label="1. Upload Label Studio JSON Dataset"
 
429
  )
430
 
431
  gr.Markdown("---")
432
  gr.Markdown("### βš™οΈ Training Parameters")
433
 
434
  batch_size_input = gr.Slider(
435
- minimum=1, maximum=32, step=1, value=4, label="Batch Size (--batch_size)"
 
 
436
  )
437
  epochs_input = gr.Slider(
438
- minimum=1, maximum=20, step=1, value=5, label="Epochs (--epochs)"
 
 
439
  )
440
  lr_input = gr.Number(
441
- value=5e-5, label="Learning Rate (--lr)"
 
442
  )
443
- max_len_input = gr.Number(
444
- value=512, label="Max Sequence Length (--max_len)"
 
 
445
  )
446
 
 
 
447
  with gr.Column(scale=2):
448
- train_button = gr.Button("πŸ”₯ Train Model", variant="primary")
449
 
450
  log_output = gr.Textbox(
451
- label="Training Log Output",
452
- lines=20,
 
453
  autoscroll=True,
454
- placeholder="Click 'Train Model' to start and see real-time logs..."
 
455
  )
456
 
457
- gr.Markdown("---")
458
- gr.Markdown(f"### πŸŽ‰ Trained Model Output (Saved to `{MODEL_OUTPUT_DIR}/`)")
459
 
460
- # Only providing the download link for the saved .pth model file
461
- model_download = gr.File(label=f"Trained Model File ({MODEL_FILE_NAME})", interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
462
 
463
  # Define the action when the button is clicked
464
  train_button.click(
465
  fn=train_model,
466
  inputs=[file_input, batch_size_input, epochs_input, lr_input, max_len_input],
467
- outputs=[log_output, model_download]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
468
  )
469
 
470
  if __name__ == "__main__":
 
310
  # # and sometimes cause issues in managed Space environments.
311
  # demo.launch()
312
 
313
+ #
314
+ # import gradio as gr
315
+ # import subprocess
316
+ # import os
317
+ # import sys
318
+ # from datetime import datetime
319
+ #
320
+ # # FIX: Update the script name to the correct one you uploaded
321
+ # TRAINING_SCRIPT = "HF_LayoutLM_with_Passage.py"
322
+ #
323
+ # # --- CORRECTED MODEL PATH BASED ON YOUR SCRIPT ---
324
+ # MODEL_OUTPUT_DIR = "checkpoints"
325
+ # MODEL_FILE_NAME = "layoutlmv3_crf_passage.pth"
326
+ # MODEL_FILE_PATH = os.path.join(MODEL_OUTPUT_DIR, MODEL_FILE_NAME)
327
+ #
328
+ #
329
+ # # ----------------------------------------------------------------
330
+ #
331
+ #
332
+ # def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float, max_len: int, progress=gr.Progress()):
333
+ # """
334
+ # Handles the Gradio submission and executes the training script using subprocess.
335
+ # """
336
+ #
337
+ # # 1. Setup: Create output directory if it doesn't exist
338
+ # os.makedirs(MODEL_OUTPUT_DIR, exist_ok=True)
339
+ #
340
+ # # 2. File Handling: Use the temporary path of the uploaded file
341
+ # if dataset_file is None:
342
+ # yield "❌ ERROR: Please upload a file.", None
343
+ # return
344
+ #
345
+ # # Using .name (Corrected in previous steps)
346
+ # input_path = dataset_file.name
347
+ #
348
+ # if not input_path.lower().endswith(".json"):
349
+ # yield "❌ ERROR: Please upload a valid Label Studio JSON file (.json).", None
350
+ # return
351
+ #
352
+ # progress(0.1, desc="Starting LayoutLMv3 Training...")
353
+ #
354
+ # log_output = f"--- Training Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ---\n"
355
+ #
356
+ # # 3. Construct the subprocess command
357
+ # command = [
358
+ # sys.executable,
359
+ # # Now uses the corrected TRAINING_SCRIPT variable
360
+ # TRAINING_SCRIPT,
361
+ # "--mode", "train",
362
+ # "--input", input_path,
363
+ # "--batch_size", str(batch_size),
364
+ # "--epochs", str(epochs),
365
+ # "--lr", str(lr),
366
+ # "--max_len", str(max_len)
367
+ # ]
368
+ #
369
+ # log_output += f"Executing command: {' '.join(command)}\n\n"
370
+ # yield log_output, None # Yield the command to the log output
371
+ #
372
+ # try:
373
+ # # 4. Run the training script and capture output
374
+ # process = subprocess.Popen(
375
+ # command,
376
+ # stdout=subprocess.PIPE,
377
+ # stderr=subprocess.STDOUT,
378
+ # text=True,
379
+ # bufsize=1
380
+ # )
381
+ #
382
+ # # Stream logs in real-time
383
+ # for line in iter(process.stdout.readline, ""):
384
+ # log_output += line
385
+ # yield log_output, None # Send partial log to Gradio output
386
+ #
387
+ # process.stdout.close()
388
+ # return_code = process.wait()
389
+ #
390
+ # # 5. Check for successful completion
391
+ # if return_code == 0:
392
+ # log_output += "\nβœ… TRAINING COMPLETE! Model saved."
393
+ #
394
+ # # 6. Prepare download links based on script's saved path
395
+ # model_exists = os.path.exists(MODEL_FILE_PATH)
396
+ #
397
+ # if model_exists:
398
+ # log_output += f"\nModel path: {MODEL_FILE_PATH}"
399
+ # # Return final log, and the file path for Gradio's download component
400
+ # return log_output, MODEL_FILE_PATH
401
+ # else:
402
+ # log_output += f"\n⚠️ WARNING: Training completed, but model file not found at expected path ({MODEL_FILE_PATH})."
403
+ # return log_output, None
404
+ # else:
405
+ # log_output += f"\n\n❌ TRAINING FAILED with return code {return_code}. Check logs above."
406
+ # return log_output, None
407
+ #
408
+ # except FileNotFoundError:
409
+ # return f"❌ ERROR: The training script '{TRAINING_SCRIPT}' was not found. Ensure it is in the root directory of your Space.", None
410
+ # except Exception as e:
411
+ # return f"❌ An unexpected error occurred: {e}", None
412
+ #
413
+ #
414
+ # # --- Gradio Interface Setup (using Blocks for a nicer layout) ---
415
+ # with gr.Blocks(title="LayoutLMv3 Fine-Tuning App") as demo:
416
+ # gr.Markdown("# πŸš€ LayoutLMv3 Fine-Tuning on Hugging Face Spaces")
417
+ # gr.Markdown(
418
+ # """
419
+ # Upload your Label Studio JSON file, set your hyperparameters, and click **Train Model** to fine-tune the LayoutLMv3 model using your script.
420
+ #
421
+ # **Note:** The trained model is saved in the **`checkpoints/`** folder as **`layoutlmv3_crf_passage.pth`**.
422
+ # """
423
+ # )
424
+ #
425
+ # with gr.Row():
426
+ # with gr.Column(scale=1):
427
+ # file_input = gr.File(
428
+ # label="1. Upload Label Studio JSON Dataset"
429
+ # )
430
+ #
431
+ # gr.Markdown("---")
432
+ # gr.Markdown("### βš™οΈ Training Parameters")
433
+ #
434
+ # batch_size_input = gr.Slider(
435
+ # minimum=1, maximum=32, step=1, value=4, label="Batch Size (--batch_size)"
436
+ # )
437
+ # epochs_input = gr.Slider(
438
+ # minimum=1, maximum=20, step=1, value=5, label="Epochs (--epochs)"
439
+ # )
440
+ # lr_input = gr.Number(
441
+ # value=5e-5, label="Learning Rate (--lr)"
442
+ # )
443
+ # max_len_input = gr.Number(
444
+ # value=512, label="Max Sequence Length (--max_len)"
445
+ # )
446
+ #
447
+ # with gr.Column(scale=2):
448
+ # train_button = gr.Button("πŸ”₯ Train Model", variant="primary")
449
+ #
450
+ # log_output = gr.Textbox(
451
+ # label="Training Log Output",
452
+ # lines=20,
453
+ # autoscroll=True,
454
+ # placeholder="Click 'Train Model' to start and see real-time logs..."
455
+ # )
456
+ #
457
+ # gr.Markdown("---")
458
+ # gr.Markdown(f"### πŸŽ‰ Trained Model Output (Saved to `{MODEL_OUTPUT_DIR}/`)")
459
+ #
460
+ # # Only providing the download link for the saved .pth model file
461
+ # model_download = gr.File(label=f"Trained Model File ({MODEL_FILE_NAME})", interactive=False)
462
+ #
463
+ # # Define the action when the button is clicked
464
+ # train_button.click(
465
+ # fn=train_model,
466
+ # inputs=[file_input, batch_size_input, epochs_input, lr_input, max_len_input],
467
+ # outputs=[log_output, model_download]
468
+ # )
469
+ #
470
+ # if __name__ == "__main__":
471
+ # demo.launch()
472
+
473
 
474
  import gradio as gr
475
  import subprocess
476
  import os
477
  import sys
478
  from datetime import datetime
479
+ import shutil
480
 
481
  # FIX: Update the script name to the correct one you uploaded
482
  TRAINING_SCRIPT = "HF_LayoutLM_with_Passage.py"
 
500
 
501
  # 2. File Handling: Use the temporary path of the uploaded file
502
  if dataset_file is None:
503
+ return "❌ ERROR: Please upload a file.", None
 
504
 
505
  # Using .name (Corrected in previous steps)
506
  input_path = dataset_file.name
507
 
508
  if not input_path.lower().endswith(".json"):
509
+ return "❌ ERROR: Please upload a valid Label Studio JSON file (.json).", None
 
510
 
511
  progress(0.1, desc="Starting LayoutLMv3 Training...")
512
 
 
515
  # 3. Construct the subprocess command
516
  command = [
517
  sys.executable,
 
518
  TRAINING_SCRIPT,
519
  "--mode", "train",
520
  "--input", input_path,
 
525
  ]
526
 
527
  log_output += f"Executing command: {' '.join(command)}\n\n"
 
528
 
529
  try:
530
  # 4. Run the training script and capture output
 
539
  # Stream logs in real-time
540
  for line in iter(process.stdout.readline, ""):
541
  log_output += line
542
+ # Print to console as well for debugging
543
+ print(line, end='')
544
 
545
  process.stdout.close()
546
  return_code = process.wait()
 
548
  # 5. Check for successful completion
549
  if return_code == 0:
550
  log_output += "\nβœ… TRAINING COMPLETE! Model saved."
551
+ print("\nβœ… TRAINING COMPLETE! Model saved.")
552
+
553
+ # 6. Verify model file exists
554
+ if os.path.exists(MODEL_FILE_PATH):
555
+ file_size = os.path.getsize(MODEL_FILE_PATH) / (1024 * 1024) # Size in MB
556
+ log_output += f"\nπŸ“¦ Model file: {MODEL_FILE_PATH}"
557
+ log_output += f"\nπŸ“Š Model size: {file_size:.2f} MB"
558
+ log_output += f"\n⬇️ Click the download button below to save your model!"
559
+
560
+ print(f"\nβœ… Model exists at: {MODEL_FILE_PATH} ({file_size:.2f} MB)")
561
+
562
+ # Create a copy in the root directory for easier access
563
+ root_copy = MODEL_FILE_NAME
564
+ try:
565
+ shutil.copy2(MODEL_FILE_PATH, root_copy)
566
+ log_output += f"\nπŸ“‹ Copy created: {root_copy}"
567
+ print(f"βœ… Created copy at: {root_copy}")
568
+ except Exception as e:
569
+ log_output += f"\n⚠️ Could not create root copy: {e}"
570
+ root_copy = MODEL_FILE_PATH
571
+
572
+ # Return the full absolute path to ensure Gradio can find it
573
+ absolute_path = os.path.abspath(root_copy)
574
+ log_output += f"\nπŸ”— Download path: {absolute_path}"
575
+
576
+ return log_output, absolute_path
577
  else:
578
  log_output += f"\n⚠️ WARNING: Training completed, but model file not found at expected path ({MODEL_FILE_PATH})."
579
+ log_output += f"\nπŸ” Checking directory contents..."
580
+
581
+ # List files in checkpoints directory for debugging
582
+ if os.path.exists(MODEL_OUTPUT_DIR):
583
+ files = os.listdir(MODEL_OUTPUT_DIR)
584
+ log_output += f"\nπŸ“ Files in {MODEL_OUTPUT_DIR}: {files}"
585
+ else:
586
+ log_output += f"\n❌ Directory {MODEL_OUTPUT_DIR} does not exist!"
587
+
588
  return log_output, None
589
  else:
590
  log_output += f"\n\n❌ TRAINING FAILED with return code {return_code}. Check logs above."
591
  return log_output, None
592
 
593
  except FileNotFoundError:
594
+ error_msg = f"❌ ERROR: The training script '{TRAINING_SCRIPT}' was not found. Ensure it is in the root directory of your Space."
595
+ print(error_msg)
596
+ return error_msg, None
597
  except Exception as e:
598
+ error_msg = f"❌ An unexpected error occurred: {e}"
599
+ print(error_msg)
600
+ import traceback
601
+ print(traceback.format_exc())
602
+ return error_msg, None
603
 
604
 
605
  # --- Gradio Interface Setup (using Blocks for a nicer layout) ---
606
+ with gr.Blocks(title="LayoutLMv3 Fine-Tuning App", theme=gr.themes.Soft()) as demo:
607
  gr.Markdown("# πŸš€ LayoutLMv3 Fine-Tuning on Hugging Face Spaces")
608
  gr.Markdown(
609
  """
610
+ Upload your Label Studio JSON file, set your hyperparameters, and click **Train Model** to fine-tune the LayoutLMv3 model.
611
+
612
+ **⚠️ IMPORTANT - Free Tier Users:**
613
+ - **Download your model IMMEDIATELY** after training completes!
614
+ - The model file is **temporary** and will be deleted when the Space restarts.
615
+ - The download button will appear below once training is complete.
616
+ - Model is saved as: **`layoutlmv3_crf_passage.pth`**
617
 
618
+ **⏱️ Timeout Note:** Training may timeout on free tier. Consider reducing epochs or batch size for faster training.
619
  """
620
  )
621
 
622
  with gr.Row():
623
  with gr.Column(scale=1):
624
+ gr.Markdown("### πŸ“ Dataset Upload")
625
  file_input = gr.File(
626
+ label="Upload Label Studio JSON Dataset",
627
+ file_types=[".json"]
628
  )
629
 
630
  gr.Markdown("---")
631
  gr.Markdown("### βš™οΈ Training Parameters")
632
 
633
  batch_size_input = gr.Slider(
634
+ minimum=1, maximum=16, step=1, value=4,
635
+ label="Batch Size",
636
+ info="Smaller = less memory, slower training"
637
  )
638
  epochs_input = gr.Slider(
639
+ minimum=1, maximum=10, step=1, value=3,
640
+ label="Epochs",
641
+ info="Fewer epochs = faster training (recommended: 3-5)"
642
  )
643
  lr_input = gr.Number(
644
+ value=5e-5, label="Learning Rate",
645
+ info="Default: 5e-5"
646
  )
647
+ max_len_input = gr.Slider(
648
+ minimum=128, maximum=512, step=128, value=512,
649
+ label="Max Sequence Length",
650
+ info="Shorter = faster training, less memory"
651
  )
652
 
653
+ train_button = gr.Button("πŸ”₯ Start Training", variant="primary", size="lg")
654
+
655
  with gr.Column(scale=2):
656
+ gr.Markdown("### πŸ“Š Training Progress")
657
 
658
  log_output = gr.Textbox(
659
+ label="Training Logs",
660
+ lines=25,
661
+ max_lines=30,
662
  autoscroll=True,
663
+ show_copy_button=True,
664
+ placeholder="Click 'Start Training' to begin...\n\nLogs will appear here in real-time."
665
  )
666
 
667
+ gr.Markdown("### ⬇️ Download Trained Model")
 
668
 
669
+ model_download = gr.File(
670
+ label="Trained Model File (layoutlmv3_crf_passage.pth)",
671
+ interactive=False,
672
+ visible=True
673
+ )
674
+
675
+ gr.Markdown(
676
+ """
677
+ **πŸ“₯ Download Instructions:**
678
+ 1. Wait for training to complete (βœ… appears in logs)
679
+ 2. Click the download button/icon that appears above
680
+ 3. Save the `.pth` file to your local machine
681
+ 4. **Do this immediately** - file is temporary!
682
+
683
+ **πŸ”§ Troubleshooting:**
684
+ - If download button doesn't appear, check the logs for errors
685
+ - Try reducing epochs or batch size if timeout occurs
686
+ - Ensure your JSON file is properly formatted
687
+ """
688
+ )
689
 
690
  # Define the action when the button is clicked
691
  train_button.click(
692
  fn=train_model,
693
  inputs=[file_input, batch_size_input, epochs_input, lr_input, max_len_input],
694
+ outputs=[log_output, model_download],
695
+ api_name="train"
696
+ )
697
+
698
+ # Add example info
699
+ gr.Markdown(
700
+ """
701
+ ---
702
+ ### πŸ“– About
703
+ This Space fine-tunes LayoutLMv3 with CRF for document understanding tasks including:
704
+ - Questions, Options, Answers
705
+ - Section Headings
706
+ - Passages
707
+
708
+ **Model Details:** LayoutLMv3-base + CRF layer for sequence labeling
709
+ """
710
  )
711
 
712
  if __name__ == "__main__":