"""Gradio front end that fine-tunes LayoutLMv3 by running a training script.

The app launches ``TRAINING_SCRIPT`` as a subprocess, streams its combined
stdout/stderr into a textbox, and offers the resulting checkpoint
(``MODEL_FILE_PATH``) for download through a temporary copy.
"""
import os
import shutil
import subprocess
import sys
import tempfile
import traceback
from datetime import datetime

import gradio as gr

TRAINING_SCRIPT = "HF_LayoutLM_with_Passage.py"
MODEL_OUTPUT_DIR = "checkpoints"
MODEL_FILE_NAME = "layoutlmv3_crf_passage.pth"
MODEL_FILE_PATH = os.path.join(MODEL_OUTPUT_DIR, MODEL_FILE_NAME)

_RULE = "=" * 60


# ----------------------------------------------------------------
# Private helpers
# ----------------------------------------------------------------
def _timestamp(fmt: str = "%Y-%m-%d %H:%M:%S") -> str:
    """Return the current local time formatted with *fmt*."""
    return datetime.now().strftime(fmt)


def _uploaded_path(dataset_file) -> str:
    """Return the filesystem path behind an uploaded-file value.

    The value is either an object exposing ``.name`` or something whose
    ``str()`` is the path; both forms are accepted.
    """
    return dataset_file.name if hasattr(dataset_file, "name") else str(dataset_file)


def _size_in_mb(path: str) -> float:
    """Return the size of *path* in mebibytes."""
    return os.path.getsize(path) / (1024 * 1024)


def _copy_model_to_temp(suffix: str = "") -> str:
    """Copy the trained model into the system temp directory.

    Args:
        suffix: Text inserted between the timestamp and ``.pth`` in the
            copy's file name.

    Returns:
        Path of the freshly written copy.

    Raises:
        OSError: If the copy cannot be made.
    """
    file_name = f"layoutlmv3_trained_{_timestamp('%Y%m%d_%H%M%S')}{suffix}.pth"
    target = os.path.join(tempfile.gettempdir(), file_name)
    shutil.copy2(MODEL_FILE_PATH, target)
    return target


def _stage_trained_model(log_output: str):
    """Prepare the freshly trained model for download after a successful run.

    Appends status lines to *log_output* and returns
    ``(log_output, download_path)``. Falls back to ``MODEL_FILE_PATH`` when
    the temporary copy cannot be created or verified.
    """
    file_size = _size_in_mb(MODEL_FILE_PATH)
    log_output += f"\n📦 Model file found: {MODEL_FILE_PATH}"
    log_output += f"\n📊 Model size: {file_size:.2f} MB"
    print(f"\n✅ Model exists at: {MODEL_FILE_PATH} ({file_size:.2f} MB)")

    # Copy to a simple location (the system temp dir) to serve the download from.
    try:
        temp_model_path = _copy_model_to_temp()
    except OSError as e:
        log_output += f"\n⚠️ Could not create temp copy: {e}"
        log_output += f"\n📍 Using original path: {MODEL_FILE_PATH}"
        print(f"⚠️ Copy failed: {e}, using original path")
        download_path = MODEL_FILE_PATH
    else:
        log_output += "\n📋 Model copied to temporary download location"
        log_output += f"\n🔗 Download path: {temp_model_path}"
        print(f"✅ Model copied to temp location: {temp_model_path}")
        if os.path.exists(temp_model_path):
            log_output += "\n✅ Download file verified and ready!"
            download_path = temp_model_path
        else:
            log_output += "\n⚠️ Warning: Temp copy verification failed, using original path"
            download_path = MODEL_FILE_PATH

    log_output += f"\n\n{_RULE}"
    log_output += "\n🎉 SUCCESS! Your model is ready for download."
    log_output += f"\n{_RULE}"
    log_output += "\n\n⬇️ Click the '📥 Download Model' button below to save your model."
    log_output += "\n⚠️ CRITICAL: Download NOW! File will be deleted when:"
    log_output += "\n   - This tab is closed"
    log_output += "\n   - Space restarts or goes idle"
    log_output += "\n   - System clears temp files"
    log_output += "\n\n📥 The file will download as a .pth file to your computer's Downloads folder."
    log_output += f"\n\n{_RULE}\n"
    return log_output, download_path


# ----------------------------------------------------------------
# Gradio callbacks
# ----------------------------------------------------------------
def retrieve_model():
    """Check for the final model file and prepare it for download.

    Useful when the training job finished server-side but the client
    connection timed out before the download button was shown.

    Returns:
        ``(log text, download path or None, gr.Button visibility update)``.
    """
    if not os.path.exists(MODEL_FILE_PATH):
        log_output = (
            f"--- Model Status Check: {_timestamp()} ---\n"
            f"❌ Model file not found at {MODEL_FILE_PATH}.\n"
            f"Training may still be running or it failed. Check back later."
        )
        return log_output, None, gr.Button(visible=False)

    try:
        file_size = _size_in_mb(MODEL_FILE_PATH)
        download_path = _copy_model_to_temp("_recovered")
    except OSError as e:
        log_output = (
            f"--- Model Status Check FAILED ---\n"
            f"⚠️ Trained model found, but could not prepare for download: {e}\n"
            f"📍 Original Path: {MODEL_FILE_PATH}. Try again or check Space logs."
        )
        return log_output, None, gr.Button(visible=False)

    log_output = (
        f"--- Model Status Check: {_timestamp()} ---\n"
        f"🎉 SUCCESS! A trained model was found and recovered.\n"
        f"📦 Model file: {MODEL_FILE_PATH}\n"
        f"📊 Model size: {file_size:.2f} MB\n"
        f"🔗 Download path prepared: {download_path}\n\n"
        f"⬇️ Click the '📥 Download Model' button below to save your model."
    )
    return log_output, download_path, gr.Button(visible=True)


def clear_memory(dataset_file: gr.File):
    """Delete the model output directory and the uploaded dataset file.

    Args:
        dataset_file: Current value of the upload component, or ``None``.

    Returns:
        ``(log text, None, hidden gr.Button update, None)`` — the ``None``
        values reset the stored model path and the download component.
    """
    log_output = f"--- Memory Clear Started: {_timestamp()} ---\n"

    # 1. Clear the model checkpoints directory.
    if os.path.exists(MODEL_OUTPUT_DIR):
        try:
            shutil.rmtree(MODEL_OUTPUT_DIR)
            log_output += f"✅ Successfully deleted model directory: {MODEL_OUTPUT_DIR}\n"
        except OSError as e:
            log_output += f"❌ ERROR deleting model directory {MODEL_OUTPUT_DIR}: {e}\n"
    else:
        log_output += f"ℹ️ Model directory not found: {MODEL_OUTPUT_DIR} (Nothing to delete)\n"

    # 2. Clear the uploaded dataset file (temporary file cleanup).
    if dataset_file is None:
        log_output += "ℹ️ No dataset file currently tracked for deletion.\n"
    else:
        input_path = _uploaded_path(dataset_file)
        if os.path.exists(input_path):
            try:
                os.remove(input_path)
                log_output += f"✅ Successfully deleted uploaded dataset file: {input_path}\n"
            except OSError as e:
                log_output += f"❌ ERROR deleting dataset file {input_path}: {e}\n"
        else:
            log_output += f"ℹ️ Uploaded dataset file not found at {input_path}.\n"

    # 3. Final message and state reset.
    log_output += f"--- Memory Clear Complete: {_timestamp()} ---\n"
    log_output += "✨ Files and checkpoints have been removed. You can now start a fresh training run."
    return log_output, None, gr.Button(visible=False), None


def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float, max_len: int, progress=gr.Progress()):
    """Run the training script in a subprocess and stream its logs.

    This is a generator: every yield is
    ``(accumulated log text, download path or None, gr.Button update)``.
    The download path is only set by the final yield of a successful run.

    Args:
        dataset_file: Uploaded Label Studio JSON file (or ``None``).
        batch_size: Value forwarded as ``--batch_size``.
        epochs: Value forwarded as ``--epochs``.
        lr: Value forwarded as ``--lr``.
        max_len: Value forwarded as ``--max_len``.
        progress: Gradio progress tracker.
    """
    hidden = gr.Button(visible=False)

    # 1. Setup: make sure the output directory exists.
    os.makedirs(MODEL_OUTPUT_DIR, exist_ok=True)

    # 2. Validate the upload.
    if dataset_file is None:
        yield "❌ ERROR: Please upload a file.", None, hidden
        return

    input_path = _uploaded_path(dataset_file)
    if not os.path.exists(input_path):
        yield f"❌ ERROR: Uploaded file not found at {input_path}. Please try uploading again.", None, hidden
        return
    if not input_path.lower().endswith(".json"):
        yield "❌ ERROR: Please upload a valid Label Studio JSON file (.json).", None, hidden
        return

    # A missing script would not raise in Popen (the interpreter itself starts
    # fine and then exits non-zero), so check for it explicitly.
    if not os.path.isfile(TRAINING_SCRIPT):
        error_msg = (
            f"❌ ERROR: The training script '{TRAINING_SCRIPT}' was not found. "
            f"Ensure it is in the root directory of your Space."
        )
        print(error_msg)
        yield error_msg, None, hidden
        return

    progress(0.1, desc="Starting LayoutLMv3 Training...")
    log_output = f"--- Training Started: {_timestamp()} ---\n"

    # 3. Build the subprocess command. "-u" disables the child's stdout
    # buffering so lines reach the pipe as they are printed; sliders may hand
    # over floats, hence the int() casts for the integer options.
    command = [
        sys.executable, "-u", TRAINING_SCRIPT,
        "--mode", "train",
        "--input", input_path,
        "--batch_size", str(int(batch_size)),
        "--epochs", str(int(epochs)),
        "--lr", str(lr),
        "--max_len", str(int(max_len)),
    ]
    log_output += f"Executing command: {' '.join(command)}\n\n"
    yield log_output, None, hidden  # Initial yield

    try:
        # 4. Run the training script, merging stderr into stdout.
        process = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1,
        )

        # Stream logs line by line; echo to the console for debugging.
        for line in iter(process.stdout.readline, ""):
            log_output += line
            print(line, end="")
            yield log_output, None, hidden

        process.stdout.close()
        return_code = process.wait()

        # 5. Non-zero exit: report failure.
        if return_code != 0:
            log_output += f"\n\n{_RULE}\n"
            log_output += f"❌ TRAINING FAILED with return code {return_code}\n"
            log_output += f"{_RULE}\n"
            log_output += "\nPlease check the logs above for error details.\n"
            yield log_output, None, hidden
            return

        log_output += "\n" + _RULE + "\n"
        log_output += "✅ TRAINING COMPLETE! Model saved successfully.\n"
        log_output += _RULE + "\n"
        print("\n✅ TRAINING COMPLETE! Model saved.")

        # 6. Verify the model file exists and stage it for download.
        if os.path.exists(MODEL_FILE_PATH):
            log_output, download_path = _stage_trained_model(log_output)
            yield log_output, download_path, gr.Button(visible=True)
            return

        log_output += (
            f"\n⚠️ WARNING: Training completed, but model file not found at expected path ({MODEL_FILE_PATH})."
        )
        log_output += "\n🔍 Checking directory contents..."
        # List what the run did produce, to help debugging.
        if os.path.exists(MODEL_OUTPUT_DIR):
            files = os.listdir(MODEL_OUTPUT_DIR)
            log_output += f"\n📁 Files in {MODEL_OUTPUT_DIR}: {files}"
        else:
            log_output += f"\n❌ Directory {MODEL_OUTPUT_DIR} does not exist!"
        yield log_output, None, hidden
        return

    except Exception as e:  # top-level boundary: surface anything to the UI
        error_msg = f"❌ An unexpected error occurred: {e}"
        print(error_msg)
        print(traceback.format_exc())
        yield log_output + "\n" + error_msg, None, hidden
        return


# --- Gradio Interface Setup (using Blocks for a nicer layout) ---
with gr.Blocks(title="LayoutLMv3 Fine-Tuning App by Aastik", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 LayoutLMv3 Fine-Tuning on Hugging Face Spaces")
    gr.Markdown(
        """
        Upload your Label Studio JSON file, set your hyperparameters, and click **Start Training** to fine-tune the LayoutLMv3 model.

        **⚠️ IMPORTANT - Free Tier Users:**
        - **Download your model IMMEDIATELY** after training completes!
        - The model file is **temporary** and will be deleted when the Space restarts.
        - A download button will appear below once training is complete.
        - **Real-time logs** will stream during training so you can monitor progress.

        **⏱️ Timeout Note:** Training may timeout on free tier. Consider reducing epochs or batch size for faster training.
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📁 Dataset Upload")
            file_input = gr.File(
                label="Upload Label Studio JSON Dataset",
                file_types=[".json"],
            )

            gr.Markdown("---")
            gr.Markdown("### ⚙️ Training Parameters")

            batch_size_input = gr.Slider(
                minimum=1, maximum=16, step=1, value=4,
                label="Batch Size",
                info="Smaller = less memory, slower training",
            )
            epochs_input = gr.Slider(
                minimum=1, maximum=10, step=1, value=3,
                label="Epochs",
                info="Fewer epochs = faster training (recommended: 3-5)",
            )
            lr_input = gr.Number(
                value=5e-5,
                label="Learning Rate",
                info="Default: 5e-5",
            )
            max_len_input = gr.Slider(
                minimum=128, maximum=512, step=128, value=512,
                label="Max Sequence Length",
                info="Shorter = faster training, less memory",
            )

            train_button = gr.Button("🔥 Start Training", variant="primary", size="lg")
            check_button = gr.Button("🔍 Check Model Status/Download", variant="secondary", size="lg")
            clear_button = gr.Button("🧹 Clear Model/Dataset Files", variant="stop", size="lg")

        with gr.Column(scale=2):
            gr.Markdown("### 📊 Training Progress (Real-Time Logs)")
            log_output = gr.Textbox(
                label="Training Logs - Updates in Real-Time",
                lines=25,
                max_lines=30,
                autoscroll=True,
                show_copy_button=True,
                placeholder="Click 'Start Training' to begin...\n\nLogs will stream here in real-time as training progresses.",
            )

            gr.Markdown("### ⬇️ Download Trained Model")

            # Hidden state holding the path of the file offered for download.
            model_path_state = gr.State(value=None)

            # Download button (initially hidden; callbacks toggle visibility).
            download_btn = gr.Button(
                "📥 Download Model (.pth file)",
                variant="primary",
                size="lg",
                visible=False,
            )

            # File output for download.
            model_download = gr.File(
                label="Your trained model will appear here after clicking Download",
                interactive=False,
                visible=True,
            )

            gr.Markdown(
                """
                **📥 Download Instructions:**
                1. Wait for training to complete - watch the real-time logs above
                2. Look for **"✅ TRAINING COMPLETE!"** message
                3. Click the **"📥 Download Model"** button that appears above
                4. Save the `.pth` file to your local machine
                5. **Do this immediately** - file is temporary and will be deleted on Space restart!

                **🔧 Troubleshooting:**
                - If download button doesn't appear, check the logs for errors
                - Try reducing epochs or batch size if timeout occurs
                - Ensure your JSON file is properly formatted
                - Logs update in real-time - you can monitor training progress
                """
            )

    # --- Event wiring ---
    # Recover an already-trained model without re-running training.
    check_button.click(
        fn=retrieve_model,
        inputs=[],
        outputs=[log_output, model_path_state, download_btn],
    )

    # Pass the uploaded file value so its temp file can be deleted too.
    clear_button.click(
        fn=clear_memory,
        inputs=[file_input],
        outputs=[log_output, model_path_state, download_btn, model_download],
    )

    # Training action - streams logs in real time via the generator's yields.
    train_button.click(
        fn=train_model,
        inputs=[file_input, batch_size_input, epochs_input, lr_input, max_len_input],
        outputs=[log_output, model_path_state, download_btn],
        api_name="train",
    )

    # Download action: hand the stored path to the file component.
    download_btn.click(
        fn=lambda path: path,
        inputs=[model_path_state],
        outputs=[model_download],
    )

    gr.Markdown(
        """
        ---
        ### 📖 About
        This Space fine-tunes LayoutLMv3 with CRF for document understanding tasks including:
        - Questions, Options, Answers
        - Section Headings
        - Passages

        **Model Details:** LayoutLMv3-base + CRF layer for sequence labeling

        **Features:**
        - ✅ Real-time log streaming during training
        - ✅ Progress monitoring with epoch/batch updates
        - ✅ Immediate model download after completion
        - ✅ Automatic file preparation for download
        """
    )

if __name__ == "__main__":
    demo.launch()