# LayoutLM_train / app.py
# heerjtdev's picture
# Update app.py
# bb64eee verified
import gradio as gr
import subprocess
import os
import sys
from datetime import datetime
import shutil
# Script that train_model() launches as a subprocess (path relative to the working directory).
TRAINING_SCRIPT = "HF_LayoutLM_with_Passage.py"
# Directory created before training and searched for the trained model afterwards.
MODEL_OUTPUT_DIR = "checkpoints"
# File name of the final model this app looks for (presumably written by the training script).
MODEL_FILE_NAME = "layoutlmv3_crf_passage.pth"
# Full path that is checked and copied when the model is prepared for download.
MODEL_FILE_PATH = os.path.join(MODEL_OUTPUT_DIR, MODEL_FILE_NAME)
# ----------------------------------------------------------------
def retrieve_model():
    """
    Look for the final trained model and stage a copy of it for download.

    Useful when the training job finished server-side but the client
    connection timed out before the download button could be shown.

    Returns:
        A 3-tuple matching the Gradio outputs wired to this handler:
        the status text for the log box, the path of the staged copy
        (``None`` when nothing can be downloaded), and a ``gr.Button``
        update that shows or hides the download button.
    """
    # Local import: this module does not import tempfile at file level.
    import tempfile

    checked_at = datetime.now()
    header = f"--- Model Status Check: {checked_at.strftime('%Y-%m-%d %H:%M:%S')} ---\n"

    if not os.path.exists(MODEL_FILE_PATH):
        log_output = (
            header
            + f"❌ Model file not found at {MODEL_FILE_PATH}.\n"
            + "Training may still be running or it failed. Check back later."
        )
        return log_output, None, gr.Button(visible=False)

    # Copy to the system temp directory, a simple location Gradio can reliably serve.
    timestamp = checked_at.strftime("%Y%m%d_%H%M%S")
    download_path = os.path.join(
        tempfile.gettempdir(),
        f"layoutlmv3_trained_{timestamp}_recovered.pth",
    )

    try:
        # getsize is inside the try so a file that vanishes after the
        # existence check is reported instead of raising out of the handler.
        file_size = os.path.getsize(MODEL_FILE_PATH) / (1024 * 1024)  # Size in MB
        shutil.copy2(MODEL_FILE_PATH, download_path)
    except OSError as e:
        log_output = (
            "--- Model Status Check FAILED ---\n"
            f"⚠️ Trained model found, but could not prepare for download: {e}\n"
            f"📍 Original Path: {MODEL_FILE_PATH}. Try again or check Space logs."
        )
        return log_output, None, gr.Button(visible=False)

    log_output = (
        header
        + "🎉 SUCCESS! A trained model was found and recovered.\n"
        + f"📦 Model file: {MODEL_FILE_PATH}\n"
        + f"📊 Model size: {file_size:.2f} MB\n"
        + f"🔗 Download path prepared: {download_path}\n\n"
        + "⬇️ Click the '📥 Download Model' button below to save your model."
    )
    return log_output, download_path, gr.Button(visible=True)
def clear_memory(dataset_file: gr.File):
    """
    Delete the model output directory, staged model copies and the uploaded dataset.

    Args:
        dataset_file: Current value of the upload component, or ``None``.
            Its ``.name`` attribute is used as the path when present,
            otherwise ``str(dataset_file)``.

    Returns:
        A 4-tuple matching the Gradio outputs wired to this handler: the
        log text, ``None`` to reset the stored model path, a ``gr.Button``
        update hiding the download button, and ``None`` to clear the
        download file component.
    """
    # Local imports: this module does not import these at file level.
    import glob
    import tempfile

    log_output = f"--- Memory Clear Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ---\n"
    had_errors = False

    # 1. Clear Model Checkpoints Directory
    if os.path.exists(MODEL_OUTPUT_DIR):
        try:
            shutil.rmtree(MODEL_OUTPUT_DIR)
        except OSError as e:
            had_errors = True
            log_output += f"❌ ERROR deleting model directory {MODEL_OUTPUT_DIR}: {e}\n"
        else:
            log_output += f"✅ Successfully deleted model directory: {MODEL_OUTPUT_DIR}\n"
    else:
        log_output += f"ℹ️ Model directory not found: {MODEL_OUTPUT_DIR} (Nothing to delete)\n"

    # 2. Clear the model copies this app stages in the system temp directory
    #    (named layoutlmv3_trained_<timestamp>*.pth); otherwise every
    #    training run / status check leaves another full copy on disk.
    staged_pattern = os.path.join(tempfile.gettempdir(), "layoutlmv3_trained_*.pth")
    for staged_path in glob.glob(staged_pattern):
        try:
            os.remove(staged_path)
        except OSError as e:
            had_errors = True
            log_output += f"❌ ERROR deleting staged model copy {staged_path}: {e}\n"
        else:
            log_output += f"✅ Successfully deleted staged model copy: {staged_path}\n"

    # 3. Clear Uploaded Dataset File (Temporary file cleanup)
    if dataset_file is not None:
        input_path = dataset_file.name if hasattr(dataset_file, 'name') else str(dataset_file)
        try:
            # Remove directly instead of check-then-remove to avoid a race.
            os.remove(input_path)
        except FileNotFoundError:
            log_output += f"ℹ️ Uploaded dataset file not found at {input_path}.\n"
        except OSError as e:
            had_errors = True
            log_output += f"❌ ERROR deleting dataset file {input_path}: {e}\n"
        else:
            log_output += f"✅ Successfully deleted uploaded dataset file: {input_path}\n"
    else:
        log_output += "ℹ️ No dataset file currently tracked for deletion.\n"

    # 4. Final message and state reset
    log_output += f"--- Memory Clear Complete: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ---\n"
    if had_errors:
        # Do not claim success when a deletion above failed.
        log_output += "⚠️ Some files could not be removed. Check the errors above before starting a new training run."
    else:
        log_output += "✨ Files and checkpoints have been removed. You can now start a fresh training run."

    # Reset log_output, model_path_state, download_btn visibility, and model_download component
    return log_output, None, gr.Button(visible=False), None
def _stage_trained_model(model_path):
    """
    Copy the trained model into the system temp directory for download.

    Args:
        model_path: Path of the model file to copy.

    Returns:
        A ``(download_path, log_text)`` tuple. ``download_path`` is the temp
        copy when it was created and verified, otherwise ``model_path``
        itself; ``log_text`` is the text to append to the training log.
    """
    # Local import: this module does not import tempfile at file level.
    import tempfile

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    temp_model_path = os.path.join(tempfile.gettempdir(), f"layoutlmv3_trained_{timestamp}.pth")

    try:
        shutil.copy2(model_path, temp_model_path)
    except OSError as e:
        print(f"⚠️ Copy failed: {e}, using original path")
        return model_path, (
            f"\n⚠️ Could not create temp copy: {e}"
            f"\n📍 Using original path: {model_path}"
        )

    print(f"✅ Model copied to temp location: {temp_model_path}")
    log_text = (
        "\n📋 Model copied to temporary download location"
        f"\n🔗 Download path: {temp_model_path}"
    )
    if os.path.exists(temp_model_path):
        return temp_model_path, log_text + "\n✅ Download file verified and ready!"
    return model_path, log_text + "\n⚠️ Warning: Temp copy verification failed, using original path"


def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float, max_len: int, progress=gr.Progress()):
    """
    Run the training script in a subprocess and stream its output to the UI.

    Args:
        dataset_file: Uploaded dataset from the Gradio file component. Its
            ``.name`` attribute is used as the path when present, otherwise
            ``str(dataset_file)``. Must point to an existing ``.json`` file.
        batch_size: Passed to the script as ``--batch_size``.
        epochs: Passed to the script as ``--epochs``.
        lr: Passed to the script as ``--lr``.
        max_len: Passed to the script as ``--max_len``.
        progress: Gradio progress tracker, updated once when training starts.

    Yields:
        ``(log_text, download_path, download_button_update)`` tuples. The
        log text is cumulative; ``download_path`` stays ``None`` and the
        button stays hidden until training succeeded and the model file
        was found.
    """
    # 1. Setup: Create output directory if it doesn't exist
    os.makedirs(MODEL_OUTPUT_DIR, exist_ok=True)

    # 2. Input validation
    if dataset_file is None:
        yield "❌ ERROR: Please upload a file.", None, gr.Button(visible=False)
        return

    # The upload may arrive as an object exposing its temp path as .name
    # or as something whose str() is the path.
    input_path = dataset_file.name if hasattr(dataset_file, 'name') else str(dataset_file)

    if not os.path.exists(input_path):
        error_msg = f"❌ ERROR: Uploaded file not found at {input_path}. Please try uploading again."
        yield error_msg, None, gr.Button(visible=False)
        return

    if not input_path.lower().endswith(".json"):
        yield "❌ ERROR: Please upload a valid Label Studio JSON file (.json).", None, gr.Button(visible=False)
        return

    # A missing script does not make Popen raise (the interpreter starts and
    # then exits non-zero), so it has to be checked explicitly.
    if not os.path.isfile(TRAINING_SCRIPT):
        error_msg = f"❌ ERROR: The training script '{TRAINING_SCRIPT}' was not found. Ensure it is in the root directory of your Space."
        print(error_msg)
        yield error_msg, None, gr.Button(visible=False)
        return

    progress(0.1, desc="Starting LayoutLMv3 Training...")
    log_output = f"--- Training Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ---\n"

    # 3. Construct the subprocess command
    command = [
        sys.executable,
        "-u",  # unbuffered child stdout, so lines reach the pipe as they are printed
        TRAINING_SCRIPT,
        "--mode", "train",
        "--input", input_path,
        "--batch_size", str(batch_size),
        "--epochs", str(epochs),
        "--lr", str(lr),
        "--max_len", str(max_len),
    ]
    log_output += f"Executing command: {' '.join(command)}\n\n"
    yield log_output, None, gr.Button(visible=False)  # Initial yield

    # 4. Start the training script
    try:
        process = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            errors="replace",  # an undecodable byte must not abort log streaming
            bufsize=1,
        )
    except OSError as e:
        error_msg = f"❌ ERROR: Could not start the training process: {e}"
        print(error_msg)
        yield log_output + "\n" + error_msg, None, gr.Button(visible=False)
        return

    rule = "=" * 60
    try:
        # Stream logs in real-time; the pipe is closed when the loop ends
        # or when an error interrupts it.
        with process.stdout:
            for line in process.stdout:
                log_output += line
                # Print to console as well for debugging
                print(line, end='')
                yield log_output, None, gr.Button(visible=False)
        return_code = process.wait()

        # 5. Training failed
        if return_code != 0:
            log_output += f"\n\n{rule}\n"
            log_output += f"❌ TRAINING FAILED with return code {return_code}\n"
            log_output += f"{rule}\n"
            log_output += "\nPlease check the logs above for error details.\n"
            yield log_output, None, gr.Button(visible=False)
            return

        log_output += "\n" + rule + "\n"
        log_output += "✅ TRAINING COMPLETE! Model saved successfully.\n"
        log_output += rule + "\n"
        print("\n✅ TRAINING COMPLETE! Model saved.")

        # 6. Training succeeded but the expected model file is missing
        if not os.path.exists(MODEL_FILE_PATH):
            log_output += f"\n⚠️ WARNING: Training completed, but model file not found at expected path ({MODEL_FILE_PATH})."
            log_output += "\n🔍 Checking directory contents..."
            # List files in checkpoints directory for debugging
            if os.path.exists(MODEL_OUTPUT_DIR):
                files = os.listdir(MODEL_OUTPUT_DIR)
                log_output += f"\n📁 Files in {MODEL_OUTPUT_DIR}: {files}"
            else:
                log_output += f"\n❌ Directory {MODEL_OUTPUT_DIR} does not exist!"
            yield log_output, None, gr.Button(visible=False)
            return

        # 7. Model file found: report it and stage a copy for download
        file_size = os.path.getsize(MODEL_FILE_PATH) / (1024 * 1024)  # Size in MB
        log_output += f"\n📦 Model file found: {MODEL_FILE_PATH}"
        log_output += f"\n📊 Model size: {file_size:.2f} MB"
        print(f"\n✅ Model exists at: {MODEL_FILE_PATH} ({file_size:.2f} MB)")

        download_path, staging_log = _stage_trained_model(MODEL_FILE_PATH)
        log_output += staging_log

        # Final success message
        log_output += f"\n\n{rule}"
        log_output += "\n🎉 SUCCESS! Your model is ready for download."
        log_output += f"\n{rule}"
        log_output += "\n\n⬇️ Click the '📥 Download Model' button below to save your model."
        log_output += "\n⚠️ CRITICAL: Download NOW! File will be deleted when:"
        log_output += "\n - This tab is closed"
        log_output += "\n - Space restarts or goes idle"
        log_output += "\n - System clears temp files"
        log_output += "\n\n📥 The file will download as a .pth file to your computer's Downloads folder."
        log_output += f"\n\n{rule}\n"

        # Return final logs and make download button visible
        yield log_output, download_path, gr.Button(visible=True)
    except Exception as e:
        # UI boundary: report any unexpected failure in the log box
        # instead of letting it escape the handler.
        error_msg = f"❌ An unexpected error occurred: {e}"
        print(error_msg)
        import traceback
        print(traceback.format_exc())
        yield log_output + "\n" + error_msg, None, gr.Button(visible=False)
# --- Gradio Interface Setup (using Blocks for a nicer layout) ---

# Markdown bodies are kept at module level and unindented so their
# rendering does not depend on how deeply the layout code is nested.
_INTRO_MARKDOWN = """
Upload your Label Studio JSON file, set your hyperparameters, and click **Start Training** to fine-tune the LayoutLMv3 model.

**⚠️ IMPORTANT - Free Tier Users:**

- **Download your model IMMEDIATELY** after training completes!
- The model file is **temporary** and will be deleted when the Space restarts.
- A download button will appear below once training is complete.
- **Real-time logs** will stream during training so you can monitor progress.

**⏱️ Timeout Note:** Training may timeout on free tier. Consider reducing epochs or batch size for faster training.
"""

_DOWNLOAD_HELP_MARKDOWN = """
**📥 Download Instructions:**

1. Wait for training to complete - watch the real-time logs above
2. Look for **"✅ TRAINING COMPLETE!"** message
3. Click the **"📥 Download Model"** button that appears above
4. Save the `.pth` file to your local machine
5. **Do this immediately** - file is temporary and will be deleted on Space restart!

**🔧 Troubleshooting:**

- If download button doesn't appear, check the logs for errors
- Try reducing epochs or batch size if timeout occurs
- Ensure your JSON file is properly formatted
- Logs update in real-time - you can monitor training progress
"""

_ABOUT_MARKDOWN = """
---

### 📖 About

This Space fine-tunes LayoutLMv3 with CRF for document understanding tasks including:

- Questions, Options, Answers
- Section Headings
- Passages

**Model Details:** LayoutLMv3-base + CRF layer for sequence labeling

**Features:**

- ✅ Real-time log streaming during training
- ✅ Progress monitoring with epoch/batch updates
- ✅ Immediate model download after completion
- ✅ Automatic file preparation for download
"""

# NOTE(review): the Row/Column nesting below is reconstructed from the order
# of the statements; confirm it matches the intended layout.
with gr.Blocks(title="LayoutLMv3 Fine-Tuning App by Aastik", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 LayoutLMv3 Fine-Tuning on Hugging Face Spaces")
    gr.Markdown(_INTRO_MARKDOWN)

    with gr.Row():
        # Left column: dataset upload, hyperparameters and action buttons
        with gr.Column(scale=1):
            gr.Markdown("### 📁 Dataset Upload")
            file_input = gr.File(
                label="Upload Label Studio JSON Dataset",
                file_types=[".json"],
            )
            gr.Markdown("---")
            gr.Markdown("### ⚙️ Training Parameters")
            batch_size_input = gr.Slider(
                minimum=1, maximum=16, step=1, value=4,
                label="Batch Size",
                info="Smaller = less memory, slower training",
            )
            epochs_input = gr.Slider(
                minimum=1, maximum=10, step=1, value=3,
                label="Epochs",
                info="Fewer epochs = faster training (recommended: 3-5)",
            )
            lr_input = gr.Number(
                value=5e-5, label="Learning Rate",
                info="Default: 5e-5",
            )
            max_len_input = gr.Slider(
                minimum=128, maximum=512, step=128, value=512,
                label="Max Sequence Length",
                info="Shorter = faster training, less memory",
            )
            train_button = gr.Button("🔥 Start Training", variant="primary", size="lg")
            check_button = gr.Button("🔍 Check Model Status/Download", variant="secondary", size="lg")
            clear_button = gr.Button("🧹 Clear Model/Dataset Files", variant="stop", size="lg")

        # Right column: live logs and model download
        with gr.Column(scale=2):
            gr.Markdown("### 📊 Training Progress (Real-Time Logs)")
            log_output = gr.Textbox(
                label="Training Logs - Updates in Real-Time",
                lines=25,
                max_lines=30,
                autoscroll=True,
                show_copy_button=True,
                placeholder="Click 'Start Training' to begin...\n\nLogs will stream here in real-time as training progresses.",
            )
            gr.Markdown("### ⬇️ Download Trained Model")
            # Hidden state to store the file path
            model_path_state = gr.State(value=None)
            # Download button (initially hidden)
            download_btn = gr.Button(
                "📥 Download Model (.pth file)",
                variant="primary",
                size="lg",
                visible=False,
            )
            # File output for download
            model_download = gr.File(
                label="Your trained model will appear here after clicking Download",
                interactive=False,
                visible=True,
            )
            gr.Markdown(_DOWNLOAD_HELP_MARKDOWN)

    # --- Event wiring (registered in the same order as before) ---

    # Look for an already-trained model and stage it for download
    check_button.click(
        fn=retrieve_model,
        inputs=[],
        outputs=[log_output, model_path_state, download_btn],
    )
    # Delete checkpoints and the uploaded dataset, then reset the download widgets
    clear_button.click(
        fn=clear_memory,
        inputs=[file_input],  # Pass the uploaded file object to delete the temp file
        outputs=[log_output, model_path_state, download_btn, model_download],
    )
    # Training action - train_model is a generator, so logs stream as it yields
    train_button.click(
        fn=train_model,
        inputs=[file_input, batch_size_input, epochs_input, lr_input, max_len_input],
        outputs=[log_output, model_path_state, download_btn],
        api_name="train",
    )
    # Download action: hand the stored path to the file component
    download_btn.click(
        fn=lambda path: path,
        inputs=[model_path_state],
        outputs=[model_download],
    )

    gr.Markdown(_ABOUT_MARKDOWN)

if __name__ == "__main__":
    demo.launch()