aagamjtdev committed on
Commit
28f8ac4
·
1 Parent(s): d988980
Files changed (1) hide show
  1. app.py +166 -6
app.py CHANGED
@@ -150,16 +150,177 @@
150
  # demo.launch(server_port=7860, server_name="0.0.0.0")
151
 
152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  import gradio as gr
154
  import subprocess
155
  import os
156
  import sys
157
  from datetime import datetime
158
 
159
- # The name of your existing training script
160
- TRAINING_SCRIPT = "LayoutLM_Train_Passage.py"
161
 
162
- # --- CORRECTED MODEL PATH BASED ON LayoutLM_Train_Passage.py ---
163
  MODEL_OUTPUT_DIR = "checkpoints"
164
  MODEL_FILE_NAME = "layoutlmv3_crf_passage.pth"
165
  MODEL_FILE_PATH = os.path.join(MODEL_OUTPUT_DIR, MODEL_FILE_NAME)
@@ -181,7 +342,7 @@ def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float,
181
  yield "❌ ERROR: Please upload a file.", None
182
  return
183
 
184
- # FIX: Gradio returns the path in the .name attribute, not .path
185
  input_path = dataset_file.name
186
 
187
  if not input_path.lower().endswith(".json"):
@@ -195,6 +356,7 @@ def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float,
195
  # 3. Construct the subprocess command
196
  command = [
197
  sys.executable,
 
198
  TRAINING_SCRIPT,
199
  "--mode", "train",
200
  "--input", input_path,
@@ -306,6 +468,4 @@ with gr.Blocks(title="LayoutLMv3 Fine-Tuning App") as demo:
306
  )
307
 
308
  if __name__ == "__main__":
309
- # Removed server_port and server_name as they are often unnecessary
310
- # and sometimes cause issues in managed Space environments.
311
  demo.launch()
 
150
  # demo.launch(server_port=7860, server_name="0.0.0.0")
151
 
152
 
153
+ # import gradio as gr
154
+ # import subprocess
155
+ # import os
156
+ # import sys
157
+ # from datetime import datetime
158
+ #
159
+ # # The name of your existing training script
160
+ # TRAINING_SCRIPT = "LayoutLM_Train_Passage.py"
161
+ #
162
+ # # --- CORRECTED MODEL PATH BASED ON LayoutLM_Train_Passage.py ---
163
+ # MODEL_OUTPUT_DIR = "checkpoints"
164
+ # MODEL_FILE_NAME = "layoutlmv3_crf_passage.pth"
165
+ # MODEL_FILE_PATH = os.path.join(MODEL_OUTPUT_DIR, MODEL_FILE_NAME)
166
+ #
167
+ #
168
+ # # ----------------------------------------------------------------
169
+ #
170
+ #
171
+ # def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float, max_len: int, progress=gr.Progress()):
172
+ # """
173
+ # Handles the Gradio submission and executes the training script using subprocess.
174
+ # """
175
+ #
176
+ # # 1. Setup: Create output directory if it doesn't exist
177
+ # os.makedirs(MODEL_OUTPUT_DIR, exist_ok=True)
178
+ #
179
+ # # 2. File Handling: Use the temporary path of the uploaded file
180
+ # if dataset_file is None:
181
+ # yield "❌ ERROR: Please upload a file.", None
182
+ # return
183
+ #
184
+ # # FIX: Gradio returns the path in the .name attribute, not .path
185
+ # input_path = dataset_file.name
186
+ #
187
+ # if not input_path.lower().endswith(".json"):
188
+ # yield "❌ ERROR: Please upload a valid Label Studio JSON file (.json).", None
189
+ # return
190
+ #
191
+ # progress(0.1, desc="Starting LayoutLMv3 Training...")
192
+ #
193
+ # log_output = f"--- Training Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ---\n"
194
+ #
195
+ # # 3. Construct the subprocess command
196
+ # command = [
197
+ # sys.executable,
198
+ # TRAINING_SCRIPT,
199
+ # "--mode", "train",
200
+ # "--input", input_path,
201
+ # "--batch_size", str(batch_size),
202
+ # "--epochs", str(epochs),
203
+ # "--lr", str(lr),
204
+ # "--max_len", str(max_len)
205
+ # ]
206
+ #
207
+ # log_output += f"Executing command: {' '.join(command)}\n\n"
208
+ # yield log_output, None # Yield the command to the log output
209
+ #
210
+ # try:
211
+ # # 4. Run the training script and capture output
212
+ # process = subprocess.Popen(
213
+ # command,
214
+ # stdout=subprocess.PIPE,
215
+ # stderr=subprocess.STDOUT,
216
+ # text=True,
217
+ # bufsize=1
218
+ # )
219
+ #
220
+ # # Stream logs in real-time
221
+ # for line in iter(process.stdout.readline, ""):
222
+ # log_output += line
223
+ # yield log_output, None # Send partial log to Gradio output
224
+ #
225
+ # process.stdout.close()
226
+ # return_code = process.wait()
227
+ #
228
+ # # 5. Check for successful completion
229
+ # if return_code == 0:
230
+ # log_output += "\n✅ TRAINING COMPLETE! Model saved."
231
+ #
232
+ # # 6. Prepare download links based on script's saved path
233
+ # model_exists = os.path.exists(MODEL_FILE_PATH)
234
+ #
235
+ # if model_exists:
236
+ # log_output += f"\nModel path: {MODEL_FILE_PATH}"
237
+ # # Return final log, and the file path for Gradio's download component
238
+ # return log_output, MODEL_FILE_PATH
239
+ # else:
240
+ # log_output += f"\n⚠️ WARNING: Training completed, but model file not found at expected path ({MODEL_FILE_PATH})."
241
+ # return log_output, None
242
+ # else:
243
+ # log_output += f"\n\n❌ TRAINING FAILED with return code {return_code}. Check logs above."
244
+ # return log_output, None
245
+ #
246
+ # except FileNotFoundError:
247
+ # return f"❌ ERROR: The training script '{TRAINING_SCRIPT}' was not found. Ensure it is in the root directory of your Space.", None
248
+ # except Exception as e:
249
+ # return f"❌ An unexpected error occurred: {e}", None
250
+ #
251
+ #
252
+ # # --- Gradio Interface Setup (using Blocks for a nicer layout) ---
253
+ # with gr.Blocks(title="LayoutLMv3 Fine-Tuning App") as demo:
254
+ # gr.Markdown("# 🚀 LayoutLMv3 Fine-Tuning on Hugging Face Spaces")
255
+ # gr.Markdown(
256
+ # """
257
+ # Upload your Label Studio JSON file, set your hyperparameters, and click **Train Model** to fine-tune the LayoutLMv3 model using your script.
258
+ #
259
+ # **Note:** The trained model is saved in the **`checkpoints/`** folder as **`layoutlmv3_crf_passage.pth`**.
260
+ # """
261
+ # )
262
+ #
263
+ # with gr.Row():
264
+ # with gr.Column(scale=1):
265
+ # file_input = gr.File(
266
+ # label="1. Upload Label Studio JSON Dataset"
267
+ # )
268
+ #
269
+ # gr.Markdown("---")
270
+ # gr.Markdown("### ⚙️ Training Parameters")
271
+ #
272
+ # batch_size_input = gr.Slider(
273
+ # minimum=1, maximum=32, step=1, value=4, label="Batch Size (--batch_size)"
274
+ # )
275
+ # epochs_input = gr.Slider(
276
+ # minimum=1, maximum=20, step=1, value=5, label="Epochs (--epochs)"
277
+ # )
278
+ # lr_input = gr.Number(
279
+ # value=5e-5, label="Learning Rate (--lr)"
280
+ # )
281
+ # max_len_input = gr.Number(
282
+ # value=512, label="Max Sequence Length (--max_len)"
283
+ # )
284
+ #
285
+ # with gr.Column(scale=2):
286
+ # train_button = gr.Button("🔥 Train Model", variant="primary")
287
+ #
288
+ # log_output = gr.Textbox(
289
+ # label="Training Log Output",
290
+ # lines=20,
291
+ # autoscroll=True,
292
+ # placeholder="Click 'Train Model' to start and see real-time logs..."
293
+ # )
294
+ #
295
+ # gr.Markdown("---")
296
+ # gr.Markdown(f"### 🎉 Trained Model Output (Saved to `{MODEL_OUTPUT_DIR}/`)")
297
+ #
298
+ # # Only providing the download link for the saved .pth model file
299
+ # model_download = gr.File(label=f"Trained Model File ({MODEL_FILE_NAME})", interactive=False)
300
+ #
301
+ # # Define the action when the button is clicked
302
+ # train_button.click(
303
+ # fn=train_model,
304
+ # inputs=[file_input, batch_size_input, epochs_input, lr_input, max_len_input],
305
+ # outputs=[log_output, model_download]
306
+ # )
307
+ #
308
+ # if __name__ == "__main__":
309
+ # # Removed server_port and server_name as they are often unnecessary
310
+ # # and sometimes cause issues in managed Space environments.
311
+ # demo.launch()
312
+
313
+
314
  import gradio as gr
315
  import subprocess
316
  import os
317
  import sys
318
  from datetime import datetime
319
 
320
+ # FIX: Update the script name to the correct one you uploaded
321
+ TRAINING_SCRIPT = "HF_LayoutLM_with_Passage.py"
322
 
323
+ # --- CORRECTED MODEL PATH BASED ON YOUR SCRIPT ---
324
  MODEL_OUTPUT_DIR = "checkpoints"
325
  MODEL_FILE_NAME = "layoutlmv3_crf_passage.pth"
326
  MODEL_FILE_PATH = os.path.join(MODEL_OUTPUT_DIR, MODEL_FILE_NAME)
 
342
  yield "❌ ERROR: Please upload a file.", None
343
  return
344
 
345
+ # Using .name (Corrected in previous steps)
346
  input_path = dataset_file.name
347
 
348
  if not input_path.lower().endswith(".json"):
 
356
  # 3. Construct the subprocess command
357
  command = [
358
  sys.executable,
359
+ # Now uses the corrected TRAINING_SCRIPT variable
360
  TRAINING_SCRIPT,
361
  "--mode", "train",
362
  "--input", input_path,
 
468
  )
469
 
470
  if __name__ == "__main__":
 
 
471
  demo.launch()