aagamjtdev commited on
Commit
c822327
Β·
1 Parent(s): 8348feb
Files changed (1) hide show
  1. app.py +55 -0
app.py CHANGED
@@ -16,6 +16,55 @@ MODEL_FILE_PATH = os.path.join(MODEL_OUTPUT_DIR, MODEL_FILE_NAME)
16
 
17
  # ----------------------------------------------------------------
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float, max_len: int, progress=gr.Progress()):
21
  """
@@ -256,6 +305,12 @@ with gr.Blocks(title="LayoutLMv3 Fine-Tuning App", theme=gr.themes.Soft()) as de
256
  visible=False
257
  )
258
 
 
 
 
 
 
 
259
  # File output for download
260
  model_download = gr.File(
261
  label="Your trained model will appear here after clicking Download",
 
16
 
17
  # ----------------------------------------------------------------
18
 
19
+ def retrieve_model():
20
+ """
21
+ Checks for the final model file and prepares it for download.
22
+ Useful for when the training job finishes server-side but the
23
+ client connection has timed out.
24
+ """
25
+ MODEL_OUTPUT_DIR = "checkpoints"
26
+ MODEL_FILE_NAME = "layoutlmv3_crf_passage.pth"
27
+ MODEL_FILE_PATH = os.path.join(MODEL_OUTPUT_DIR, MODEL_FILE_NAME)
28
+
29
+ if os.path.exists(MODEL_FILE_PATH):
30
+ file_size = os.path.getsize(MODEL_FILE_PATH) / (1024 * 1024) # Size in MB
31
+
32
+ # CRITICAL: Copy to a simple location that Gradio can reliably serve
33
+ import tempfile
34
+ temp_dir = tempfile.gettempdir()
35
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
36
+ temp_model_path = os.path.join(temp_dir, f"layoutlmv3_trained_{timestamp}_recovered.pth")
37
+
38
+ try:
39
+ shutil.copy2(MODEL_FILE_PATH, temp_model_path)
40
+ download_path = temp_model_path
41
+
42
+ log_output = (
43
+ f"--- Model Status Check: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ---\n"
44
+ f"πŸŽ‰ SUCCESS! A trained model was found and recovered.\n"
45
+ f"πŸ“¦ Model file: {MODEL_FILE_PATH}\n"
46
+ f"πŸ“Š Model size: {file_size:.2f} MB\n"
47
+ f"πŸ”— Download path prepared: {download_path}\n\n"
48
+ f"⬇️ Click the 'πŸ“₯ Download Model' button below to save your model."
49
+ )
50
+ return log_output, download_path, gr.Button(visible=True)
51
+
52
+ except Exception as e:
53
+ log_output = (
54
+ f"--- Model Status Check FAILED ---\n"
55
+ f"⚠️ Trained model found, but could not prepare for download: {e}\n"
56
+ f"πŸ“ Original Path: {MODEL_FILE_PATH}. Try again or check Space logs."
57
+ )
58
+ return log_output, None, gr.Button(visible=False)
59
+
60
+ else:
61
+ log_output = (
62
+ f"--- Model Status Check: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ---\n"
63
+ f"❌ Model file not found at {MODEL_FILE_PATH}.\n"
64
+ f"Training may still be running or it failed. Check back later."
65
+ )
66
+ return log_output, None, gr.Button(visible=False)
67
+
68
 
69
  def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float, max_len: int, progress=gr.Progress()):
70
  """
 
305
  visible=False
306
  )
307
 
308
+ check_button.click(
309
+ fn=retrieve_model, # A new function we'll define
310
+ inputs=[],
311
+ outputs=[log_output, model_path_state, download_btn]
312
+ )
313
+
314
  # File output for download
315
  model_download = gr.File(
316
  label="Your trained model will appear here after clicking Download",