# LayoutLM_train / app.py
# (Hugging Face Space file header: author aagamjtdev, commit 28f8ac4, 17.5 kB)
# import gradio as gr
# import subprocess
# import os
# import sys
# from datetime import datetime
#
# # The name of your existing training script
# TRAINING_SCRIPT = "LayoutLM_Train_Passage.py"
#
# # --- CORRECTED MODEL PATH BASED ON LayoutLM_Train_Passage.py ---
# MODEL_OUTPUT_DIR = "checkpoints"
# MODEL_FILE_NAME = "layoutlmv3_crf_passage.pth"
# MODEL_FILE_PATH = os.path.join(MODEL_OUTPUT_DIR, MODEL_FILE_NAME)
#
#
# # ----------------------------------------------------------------
#
#
# def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float, max_len: int, progress=gr.Progress()):
# """
# Handles the Gradio submission and executes the training script using subprocess.
# """
#
# # 1. Setup: Create output directory if it doesn't exist
# os.makedirs(MODEL_OUTPUT_DIR, exist_ok=True)
#
# # 2. File Handling: Use the temporary path of the uploaded file
# # if dataset_file is None or not dataset_file.path.endswith(".json"):
# # return "❌ ERROR: Please upload a valid Label Studio JSON file.", None
#
# input_path = dataset_file.path
#
# progress(0.1, desc="Starting LayoutLMv3 Training...")
#
# log_output = f"--- Training Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ---\n"
#
# # 3. Construct the subprocess command
# command = [
# sys.executable,
# TRAINING_SCRIPT,
# "--mode", "train",
# "--input", input_path,
# "--batch_size", str(batch_size),
# "--epochs", str(epochs),
# "--lr", str(lr),
# "--max_len", str(max_len)
# ]
#
# log_output += f"Executing command: {' '.join(command)}\n\n"
#
# try:
# # 4. Run the training script and capture output
# process = subprocess.Popen(
# command,
# stdout=subprocess.PIPE,
# stderr=subprocess.STDOUT,
# text=True,
# bufsize=1
# )
#
# # Stream logs in real-time
# for line in iter(process.stdout.readline, ""):
# log_output += line
# yield log_output, None # Send partial log to Gradio output
#
# process.stdout.close()
# return_code = process.wait()
#
# # 5. Check for successful completion
# if return_code == 0:
# log_output += "\n✅ TRAINING COMPLETE! Model saved."
#
# # 6. Prepare download links based on script's saved path
# model_exists = os.path.exists(MODEL_FILE_PATH)
#
# if model_exists:
# log_output += f"\nModel path: {MODEL_FILE_PATH}"
# # Return final log, and the file path for Gradio's download component
# return log_output, MODEL_FILE_PATH
# else:
# log_output += f"\n⚠️ WARNING: Training completed, but model file not found at expected path ({MODEL_FILE_PATH})."
# return log_output, None
# else:
# log_output += f"\n\n❌ TRAINING FAILED with return code {return_code}. Check logs above."
# return log_output, None
#
# except FileNotFoundError:
# return f"❌ ERROR: The training script '{TRAINING_SCRIPT}' was not found. Ensure it is in the root directory of your Space.", None
# except Exception as e:
# return f"❌ An unexpected error occurred: {e}", None
#
#
# # --- Gradio Interface Setup (using Blocks for a nicer layout) ---
# with gr.Blocks(title="LayoutLMv3 Fine-Tuning App") as demo:
# gr.Markdown("# 🚀 LayoutLMv3 Fine-Tuning on Hugging Face Spaces")
# gr.Markdown(
# """
# Upload your Label Studio JSON file, set your hyperparameters, and click **Train Model** to fine-tune the LayoutLMv3 model using your script.
#
# **Note:** The trained model is saved in the **`checkpoints/`** folder as **`layoutlmv3_crf_passage.pth`**.
# """
# )
#
# with gr.Row():
# with gr.Column(scale=1):
# file_input = gr.File(
# label="1. Upload Label Studio JSON Dataset"
# )
#
# gr.Markdown("---")
# gr.Markdown("### ⚙️ Training Parameters")
#
# batch_size_input = gr.Slider(
# minimum=1, maximum=32, step=1, value=4, label="Batch Size (--batch_size)"
# )
# epochs_input = gr.Slider(
# minimum=1, maximum=20, step=1, value=5, label="Epochs (--epochs)"
# )
# lr_input = gr.Number(
# value=5e-5, label="Learning Rate (--lr)"
# )
# max_len_input = gr.Number(
# value=512, label="Max Sequence Length (--max_len)"
# )
#
# with gr.Column(scale=2):
# train_button = gr.Button("🔥 Train Model", variant="primary")
#
# log_output = gr.Textbox(
# label="Training Log Output",
# lines=20,
# autoscroll=True,
# placeholder="Click 'Train Model' to start and see real-time logs..."
# )
#
# gr.Markdown("---")
# gr.Markdown(f"### 🎉 Trained Model Output (Saved to `{MODEL_OUTPUT_DIR}/`)")
#
# # Only providing the download link for the saved .pth model file
# model_download = gr.File(label=f"Trained Model File ({MODEL_FILE_NAME})", interactive=False)
#
# # Define the action when the button is clicked
# train_button.click(
# fn=train_model,
# inputs=[file_input, batch_size_input, epochs_input, lr_input, max_len_input],
# outputs=[log_output, model_download]
# )
#
# if __name__ == "__main__":
# demo.launch(server_port=7860, server_name="0.0.0.0")
# import gradio as gr
# import subprocess
# import os
# import sys
# from datetime import datetime
#
# # The name of your existing training script
# TRAINING_SCRIPT = "LayoutLM_Train_Passage.py"
#
# # --- CORRECTED MODEL PATH BASED ON LayoutLM_Train_Passage.py ---
# MODEL_OUTPUT_DIR = "checkpoints"
# MODEL_FILE_NAME = "layoutlmv3_crf_passage.pth"
# MODEL_FILE_PATH = os.path.join(MODEL_OUTPUT_DIR, MODEL_FILE_NAME)
#
#
# # ----------------------------------------------------------------
#
#
# def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float, max_len: int, progress=gr.Progress()):
# """
# Handles the Gradio submission and executes the training script using subprocess.
# """
#
# # 1. Setup: Create output directory if it doesn't exist
# os.makedirs(MODEL_OUTPUT_DIR, exist_ok=True)
#
# # 2. File Handling: Use the temporary path of the uploaded file
# if dataset_file is None:
# yield "❌ ERROR: Please upload a file.", None
# return
#
# # FIX: Gradio returns the path in the .name attribute, not .path
# input_path = dataset_file.name
#
# if not input_path.lower().endswith(".json"):
# yield "❌ ERROR: Please upload a valid Label Studio JSON file (.json).", None
# return
#
# progress(0.1, desc="Starting LayoutLMv3 Training...")
#
# log_output = f"--- Training Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ---\n"
#
# # 3. Construct the subprocess command
# command = [
# sys.executable,
# TRAINING_SCRIPT,
# "--mode", "train",
# "--input", input_path,
# "--batch_size", str(batch_size),
# "--epochs", str(epochs),
# "--lr", str(lr),
# "--max_len", str(max_len)
# ]
#
# log_output += f"Executing command: {' '.join(command)}\n\n"
# yield log_output, None # Yield the command to the log output
#
# try:
# # 4. Run the training script and capture output
# process = subprocess.Popen(
# command,
# stdout=subprocess.PIPE,
# stderr=subprocess.STDOUT,
# text=True,
# bufsize=1
# )
#
# # Stream logs in real-time
# for line in iter(process.stdout.readline, ""):
# log_output += line
# yield log_output, None # Send partial log to Gradio output
#
# process.stdout.close()
# return_code = process.wait()
#
# # 5. Check for successful completion
# if return_code == 0:
# log_output += "\n✅ TRAINING COMPLETE! Model saved."
#
# # 6. Prepare download links based on script's saved path
# model_exists = os.path.exists(MODEL_FILE_PATH)
#
# if model_exists:
# log_output += f"\nModel path: {MODEL_FILE_PATH}"
# # Return final log, and the file path for Gradio's download component
# return log_output, MODEL_FILE_PATH
# else:
# log_output += f"\n⚠️ WARNING: Training completed, but model file not found at expected path ({MODEL_FILE_PATH})."
# return log_output, None
# else:
# log_output += f"\n\n❌ TRAINING FAILED with return code {return_code}. Check logs above."
# return log_output, None
#
# except FileNotFoundError:
# return f"❌ ERROR: The training script '{TRAINING_SCRIPT}' was not found. Ensure it is in the root directory of your Space.", None
# except Exception as e:
# return f"❌ An unexpected error occurred: {e}", None
#
#
# # --- Gradio Interface Setup (using Blocks for a nicer layout) ---
# with gr.Blocks(title="LayoutLMv3 Fine-Tuning App") as demo:
# gr.Markdown("# 🚀 LayoutLMv3 Fine-Tuning on Hugging Face Spaces")
# gr.Markdown(
# """
# Upload your Label Studio JSON file, set your hyperparameters, and click **Train Model** to fine-tune the LayoutLMv3 model using your script.
#
# **Note:** The trained model is saved in the **`checkpoints/`** folder as **`layoutlmv3_crf_passage.pth`**.
# """
# )
#
# with gr.Row():
# with gr.Column(scale=1):
# file_input = gr.File(
# label="1. Upload Label Studio JSON Dataset"
# )
#
# gr.Markdown("---")
# gr.Markdown("### ⚙️ Training Parameters")
#
# batch_size_input = gr.Slider(
# minimum=1, maximum=32, step=1, value=4, label="Batch Size (--batch_size)"
# )
# epochs_input = gr.Slider(
# minimum=1, maximum=20, step=1, value=5, label="Epochs (--epochs)"
# )
# lr_input = gr.Number(
# value=5e-5, label="Learning Rate (--lr)"
# )
# max_len_input = gr.Number(
# value=512, label="Max Sequence Length (--max_len)"
# )
#
# with gr.Column(scale=2):
# train_button = gr.Button("🔥 Train Model", variant="primary")
#
# log_output = gr.Textbox(
# label="Training Log Output",
# lines=20,
# autoscroll=True,
# placeholder="Click 'Train Model' to start and see real-time logs..."
# )
#
# gr.Markdown("---")
# gr.Markdown(f"### 🎉 Trained Model Output (Saved to `{MODEL_OUTPUT_DIR}/`)")
#
# # Only providing the download link for the saved .pth model file
# model_download = gr.File(label=f"Trained Model File ({MODEL_FILE_NAME})", interactive=False)
#
# # Define the action when the button is clicked
# train_button.click(
# fn=train_model,
# inputs=[file_input, batch_size_input, epochs_input, lr_input, max_len_input],
# outputs=[log_output, model_download]
# )
#
# if __name__ == "__main__":
# # Removed server_port and server_name as they are often unnecessary
# # and sometimes cause issues in managed Space environments.
# demo.launch()
import gradio as gr
import subprocess
import os
import sys
from datetime import datetime
# Training entry point; train_model launches it as `sys.executable <script> ...`.
TRAINING_SCRIPT = "HF_LayoutLM_with_Passage.py"

# Where the app looks for the checkpoint after a successful run. This must match
# the location the training script saves to. Both paths are relative to the
# current working directory.
MODEL_OUTPUT_DIR, MODEL_FILE_NAME = "checkpoints", "layoutlmv3_crf_passage.pth"
MODEL_FILE_PATH = os.path.join(MODEL_OUTPUT_DIR, MODEL_FILE_NAME)
# ----------------------------------------------------------------
def _upload_path(dataset_file):
    """Return the filesystem path of an uploaded file, or ``None`` if it has none.

    Older Gradio releases hand the callback a tempfile-like wrapper exposing the
    path as ``.name``. Newer releases (``gr.File(type="filepath")``, the default)
    pass the path as a plain ``str``. Both are accepted here.
    """
    if isinstance(dataset_file, str):
        return dataset_file
    return getattr(dataset_file, "name", None)


def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float, max_len: int, progress=gr.Progress()):
    """
    Handle a Gradio submission by running TRAINING_SCRIPT in a subprocess.

    This is a generator. Every ``yield`` produces a ``(log_text, model_file)`` pair
    for the (log Textbox, download File) outputs. ``model_file`` stays ``None``
    unless the script exits with code 0 and MODEL_FILE_PATH exists.

    Results are always *yielded*, never returned. Gradio ignores the ``return``
    value of a generator callback, so a returned status would never be shown.

    Args:
        dataset_file: The uploaded Label Studio export. Either a path string or
            an object with ``.name``; the path must end in ``.json``.
        batch_size: Forwarded as ``--batch_size`` (converted to int).
        epochs: Forwarded as ``--epochs`` (converted to int).
        lr: Forwarded as ``--lr`` (converted to float).
        max_len: Forwarded as ``--max_len`` (converted to int).
        progress: Gradio progress tracker, injected by Gradio.
    """
    # 1. Setup: Create output directory if it doesn't exist
    os.makedirs(MODEL_OUTPUT_DIR, exist_ok=True)

    # 2. Validate everything before launching the (expensive) training process.
    if dataset_file is None:
        yield "❌ ERROR: Please upload a file.", None
        return
    input_path = _upload_path(dataset_file)
    if not input_path or not input_path.lower().endswith(".json"):
        yield "❌ ERROR: Please upload a valid Label Studio JSON file (.json).", None
        return
    try:
        # Gradio sliders/number boxes may deliver floats. str(512.0) == "512.0",
        # which int() would reject if the script parses these arguments as ints.
        batch_size, epochs, max_len = int(batch_size), int(epochs), int(max_len)
        lr = float(lr)
    except (TypeError, ValueError, OverflowError):
        yield "❌ ERROR: Batch size, epochs, learning rate and max length must all be numbers.", None
        return
    # Popen only executes sys.executable. A missing script therefore shows up as
    # a non-zero exit from the interpreter, never as FileNotFoundError, so check
    # for it explicitly.
    if not os.path.isfile(TRAINING_SCRIPT):
        yield f"❌ ERROR: The training script '{TRAINING_SCRIPT}' was not found. Ensure it is in the root directory of your Space.", None
        return

    progress(0.1, desc="Starting LayoutLMv3 Training...")
    log_output = f"--- Training Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ---\n"

    # 3. Construct the subprocess command (argument list, no shell).
    command = [
        sys.executable,
        TRAINING_SCRIPT,
        "--mode", "train",
        "--input", input_path,
        "--batch_size", str(batch_size),
        "--epochs", str(epochs),
        "--lr", str(lr),
        "--max_len", str(max_len),
    ]
    log_output += f"Executing command: {' '.join(command)}\n\n"
    yield log_output, None  # Show the command before training starts

    # A child Python process block-buffers stdout when it is a pipe. Force
    # unbuffered UTF-8 output so the logs really stream line by line.
    child_env = {**os.environ, "PYTHONUNBUFFERED": "1", "PYTHONIOENCODING": "utf-8"}

    process = None
    try:
        # 4. Run the training script and stream its combined stdout/stderr.
        process = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            encoding="utf-8",
            errors="replace",  # a stray undecodable byte must not abort the stream
            bufsize=1,
            env=child_env,
        )
        for line in iter(process.stdout.readline, ""):
            log_output += line
            yield log_output, None  # Send partial log to Gradio output
        return_code = process.wait()
    except FileNotFoundError:
        yield f"❌ ERROR: The training script '{TRAINING_SCRIPT}' was not found. Ensure it is in the root directory of your Space.", None
        return
    except Exception as e:
        yield log_output + f"\n\n❌ An unexpected error occurred: {e}", None
        return
    finally:
        if process is not None:
            # The child is still running only if streaming stopped early (an
            # exception, or Gradio closing this generator). Don't leave an
            # orphan whose unread pipe would eventually block it.
            if process.poll() is None:
                process.kill()
                process.wait()
            process.stdout.close()

    # 5. Report the outcome.
    if return_code != 0:
        log_output += f"\n\n❌ TRAINING FAILED with return code {return_code}. Check logs above."
        yield log_output, None
        return

    log_output += "\n✅ TRAINING COMPLETE! Model saved."
    # 6. Offer the checkpoint for download if it is where we expect it.
    if os.path.exists(MODEL_FILE_PATH):
        log_output += f"\nModel path: {MODEL_FILE_PATH}"
        yield log_output, MODEL_FILE_PATH
    else:
        log_output += f"\n⚠️ WARNING: Training completed, but model file not found at expected path ({MODEL_FILE_PATH})."
        yield log_output, None
# --- Gradio Interface Setup (using Blocks for a nicer layout) ---
# Builds the `demo` app: upload + hyperparameter inputs on the left, streaming
# log and model download on the right, wired to train_model.
with gr.Blocks(title="LayoutLMv3 Fine-Tuning App") as demo:
    gr.Markdown("# 🚀 LayoutLMv3 Fine-Tuning on Hugging Face Spaces")
    # Paths come from the module constants so this note cannot drift from them.
    # The blank line keeps the Note as its own Markdown paragraph.
    gr.Markdown(
        f"""
Upload your Label Studio JSON file, set your hyperparameters, and click **Train Model** to fine-tune the LayoutLMv3 model using your script.

**Note:** The trained model is saved in the **`{MODEL_OUTPUT_DIR}/`** folder as **`{MODEL_FILE_NAME}`**.
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(
                label="1. Upload Label Studio JSON Dataset",
                # Client-side filter only; train_model still validates the extension.
                file_types=[".json"],
            )
            gr.Markdown("---")
            gr.Markdown("### ⚙️ Training Parameters")
            batch_size_input = gr.Slider(
                minimum=1, maximum=32, step=1, value=4, label="Batch Size (--batch_size)"
            )
            epochs_input = gr.Slider(
                minimum=1, maximum=20, step=1, value=5, label="Epochs (--epochs)"
            )
            lr_input = gr.Number(
                value=5e-5, label="Learning Rate (--lr)"
            )
            # precision=0 makes Gradio round to and deliver an int, so the CLI
            # receives "512" rather than "512.0".
            max_len_input = gr.Number(
                value=512, precision=0, label="Max Sequence Length (--max_len)"
            )
        with gr.Column(scale=2):
            train_button = gr.Button("🔥 Train Model", variant="primary")
            log_output = gr.Textbox(
                label="Training Log Output",
                lines=20,
                autoscroll=True,
                placeholder="Click 'Train Model' to start and see real-time logs..."
            )
            gr.Markdown("---")
            gr.Markdown(f"### 🎉 Trained Model Output (Saved to `{MODEL_OUTPUT_DIR}/`)")
            # Only providing the download link for the saved .pth model file
            model_download = gr.File(label=f"Trained Model File ({MODEL_FILE_NAME})", interactive=False)
    # train_model is a generator: each yielded (log, file) pair updates both outputs.
    train_button.click(
        fn=train_model,
        inputs=[file_input, batch_size_input, epochs_input, lr_input, max_len_input],
        outputs=[log_output, model_download]
    )
if __name__ == "__main__":
    # train_model streams results (a generator callback) and uses gr.Progress.
    # Both run through Gradio's queue: Gradio 4+ enables it by default, while on
    # 3.x it must be switched on explicitly. queue() returns the Blocks instance,
    # so it is safe to call either way.
    demo.queue().launch()