aagamjtdev commited on
Commit
dc56cce
·
1 Parent(s): b249ba5
Files changed (1) hide show
  1. app.py +151 -0
app.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
import shlex
import subprocess
import sys
from datetime import datetime

import gradio as gr
6
+
# Name of the existing training script; it is launched as a subprocess and
# resolved relative to the current working directory.
TRAINING_SCRIPT = "LayoutLM_Train_Passage.py"

# --- CORRECTED MODEL PATH BASED ON LayoutLM_Train_Passage.py ---
# Directory and file name where the trained checkpoint is expected to appear
# (per the note above, these mirror what the training script writes).
MODEL_OUTPUT_DIR = "checkpoints"
MODEL_FILE_NAME = "layoutlmv3_crf_passage.pth"
MODEL_FILE_PATH = os.path.join(MODEL_OUTPUT_DIR, MODEL_FILE_NAME)


# ----------------------------------------------------------------
17
+
18
+
def _resolve_upload_path(dataset_file):
    """Return the filesystem path of an uploaded file, or None if unavailable.

    Accepts a plain path (``str`` / ``os.PathLike``) or an object exposing a
    string ``path`` or ``name`` attribute.
    NOTE(review): which of these shapes Gradio delivers depends on the Gradio
    version and the ``gr.File`` ``type`` setting -- confirm against the
    version pinned for this Space.
    """
    if dataset_file is None:
        return None
    if isinstance(dataset_file, (str, os.PathLike)):
        return os.fspath(dataset_file)
    for attr in ("path", "name"):
        candidate = getattr(dataset_file, attr, None)
        if isinstance(candidate, str) and candidate:
            return candidate
    return None


def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float, max_len: int, progress=gr.Progress()):
    """
    Handle the Gradio submission and run the training script as a subprocess.

    This is a generator: every result, including validation errors and the
    final summary, is delivered with ``yield`` because a value passed to
    ``return`` inside a generator is not seen by code iterating over it.

    Args:
        dataset_file: The uploaded dataset (path string or upload object).
        batch_size: Forwarded as ``--batch_size`` (coerced to int).
        epochs: Forwarded as ``--epochs`` (coerced to int).
        lr: Forwarded as ``--lr`` (coerced to float).
        max_len: Forwarded as ``--max_len`` (coerced to int).
        progress: Gradio progress tracker; called once before training starts.

    Yields:
        ``(log_text, model_path)`` tuples. ``model_path`` is ``None`` until
        training succeeds and the checkpoint file exists, in which case the
        last tuple carries ``MODEL_FILE_PATH``.
    """
    # 1. Setup: create the output directory if it doesn't exist.
    os.makedirs(MODEL_OUTPUT_DIR, exist_ok=True)

    # 2. File handling: locate the temporary path of the uploaded file.
    input_path = _resolve_upload_path(dataset_file)
    if not input_path or not input_path.endswith(".json"):
        yield "❌ ERROR: Please upload a valid Label Studio JSON file.", None
        return

    # Popen succeeds whenever the interpreter exists, so a missing script has
    # to be detected explicitly rather than via FileNotFoundError.
    if not os.path.isfile(TRAINING_SCRIPT):
        yield (
            f"❌ ERROR: The training script '{TRAINING_SCRIPT}' was not found. "
            "Ensure it is in the root directory of your Space."
        ), None
        return

    # UI widgets may hand over floats (e.g. 512.0) or None for an empty field;
    # normalise so the flags are rendered as "512", not "512.0".
    try:
        batch_size_arg = str(int(batch_size))
        epochs_arg = str(int(epochs))
        max_len_arg = str(int(max_len))
        lr_arg = str(float(lr))
    except (TypeError, ValueError):
        yield (
            "❌ ERROR: Batch size, epochs and max sequence length must be whole "
            "numbers, and the learning rate must be a number."
        ), None
        return

    progress(0.1, desc="Starting LayoutLMv3 Training...")

    log_output = f"--- Training Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ---\n"

    # 3. Construct the subprocess command (argument list, never a shell string).
    command = [
        sys.executable,
        TRAINING_SCRIPT,
        "--mode", "train",
        "--input", input_path,
        "--batch_size", batch_size_arg,
        "--epochs", epochs_arg,
        "--lr", lr_arg,
        "--max_len", max_len_arg,
    ]

    # shlex.join quotes arguments, so paths containing spaces stay unambiguous.
    log_output += f"Executing command: {shlex.join(command)}\n\n"

    # A Python child block-buffers stdout when it is a pipe; force unbuffered
    # output so log lines arrive as they are printed.
    child_env = dict(os.environ, PYTHONUNBUFFERED="1")

    try:
        # 4. Run the training script and capture its combined stdout/stderr.
        with subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            errors="replace",
            bufsize=1,
            env=child_env,
        ) as process:
            try:
                # Stream logs in real time.
                for line in process.stdout:
                    log_output += line
                    yield log_output, None  # Send partial log to Gradio output
            except BaseException:
                # The consumer went away (generator closed) or reading failed:
                # stop the child so leaving the `with` block does not wait for
                # the whole training run to finish.
                process.kill()
                raise
            return_code = process.wait()
    except Exception as e:
        # Top-level UI boundary: report the failure but keep the log so far.
        log_output += f"\n❌ An unexpected error occurred: {e}"
        yield log_output, None
        return

    # 5. Check for successful completion.
    if return_code != 0:
        log_output += f"\n\n❌ TRAINING FAILED with return code {return_code}. Check logs above."
        yield log_output, None
        return

    log_output += "\n✅ TRAINING COMPLETE! Model saved."

    # 6. Offer the checkpoint for download if it is where we expect it.
    if os.path.exists(MODEL_FILE_PATH):
        log_output += f"\nModel path: {MODEL_FILE_PATH}"
        yield log_output, MODEL_FILE_PATH
    else:
        log_output += f"\n⚠️ WARNING: Training completed, but model file not found at expected path ({MODEL_FILE_PATH})."
        yield log_output, None
91
+
92
+
# --- Gradio Interface Setup (using Blocks for a nicer layout) ---
# Left column: dataset upload and hyperparameters.
# Right column: train button, streaming log and the model download slot.
with gr.Blocks(title="LayoutLMv3 Fine-Tuning App") as demo:
    gr.Markdown("# 🚀 LayoutLMv3 Fine-Tuning on Hugging Face Spaces")
    gr.Markdown(
        "Upload your Label Studio JSON file, set your hyperparameters, and click "
        "**Train Model** to fine-tune the LayoutLMv3 model using your script.\n\n"
        "**Note:** The trained model is saved in the **`checkpoints/`** folder "
        "as **`layoutlmv3_crf_passage.pth`**."
    )

    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(
                label="1. Upload Label Studio JSON Dataset",
                file_types=[".json"],
            )

            gr.Markdown("---")
            gr.Markdown("### ⚙️ Training Parameters")

            batch_size_input = gr.Slider(
                minimum=1, maximum=32, step=1, value=4, label="Batch Size (--batch_size)"
            )
            epochs_input = gr.Slider(
                minimum=1, maximum=20, step=1, value=5, label="Epochs (--epochs)"
            )
            lr_input = gr.Number(
                value=5e-5, label="Learning Rate (--lr)"
            )
            # precision=0 makes the component produce an integer; the value is
            # forwarded to the --max_len command-line flag.
            max_len_input = gr.Number(
                value=512, precision=0, label="Max Sequence Length (--max_len)"
            )

        with gr.Column(scale=2):
            train_button = gr.Button("🔥 Train Model", variant="primary")

            log_output = gr.Textbox(
                label="Training Log Output",
                lines=20,
                autoscroll=True,
                placeholder="Click 'Train Model' to start and see real-time logs..."
            )

            gr.Markdown("---")
            gr.Markdown(f"### 🎉 Trained Model Output (Saved to `{MODEL_OUTPUT_DIR}/`)")

            # Only providing the download link for the saved .pth model file
            model_download = gr.File(label=f"Trained Model File ({MODEL_FILE_NAME})", interactive=False)

    # Wire the button: train_model yields (log text, model path) pairs that
    # map onto the log textbox and the download component respectively.
    train_button.click(
        fn=train_model,
        inputs=[file_input, batch_size_input, epochs_input, lr_input, max_len_input],
        outputs=[log_output, model_download]
    )

if __name__ == "__main__":
    # Listen on all interfaces on port 7860.
    demo.launch(server_port=7860, server_name="0.0.0.0")