Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import pandas as pd | |
| import time | |
| import threading | |
| from data_utils import list_available_datasets, get_dataset_info | |
| from model_utils import list_available_huggingface_models | |
| from training_utils import ( | |
| start_model_training, | |
| stop_model_training, | |
| get_running_training_jobs, | |
| simulate_training | |
| ) | |
| from utils import ( | |
| set_page_config, | |
| display_sidebar, | |
| add_log, | |
| display_logs, | |
| plot_training_progress | |
| ) | |
# Page scaffolding: run the project's page-configuration and sidebar helpers
# (imported from utils) before any page-specific widgets are created.
set_page_config()
display_sidebar()

# Page heading and a one-line description of what this page is for
st.title("Model Training")
st.markdown("Configure and train code generation models on your datasets.")

# Two tabs: tab1 configures a new training job, tab2 monitors existing jobs
tab_labels = ["Configure Training", "Monitor Jobs"]
tab1, tab2 = st.tabs(tab_labels)
# "Configure Training" tab: collect the settings for a new training job and
# launch it (real or simulated) when the user presses "Start Training".
with tab1:
    st.subheader("Train a New Model")

    # Model ID input. Surrounding whitespace is stripped so that a blank or
    # padded entry cannot slip past the emptiness / duplicate checks below.
    model_id = st.text_input(
        "Model ID", placeholder="e.g., my_codegen_model_v1"
    ).strip()

    # Dataset selection
    available_datasets = list_available_datasets()
    if not available_datasets:
        st.warning("No datasets available. Please upload a dataset in the Dataset Management section.")
        dataset_name = None
    else:
        dataset_name = st.selectbox("Select Dataset", available_datasets)

    # Model selection
    model_options = list_available_huggingface_models()
    base_model = st.selectbox("Select Base Model", model_options)

    # Training parameters
    st.markdown("### Training Parameters")
    col1, col2 = st.columns(2)

    with col1:
        learning_rate = st.number_input(
            "Learning Rate",
            min_value=1e-6,
            max_value=1e-3,
            value=2e-5,
            # Streamlit's default step for float inputs is 0.01, which is
            # larger than the entire allowed range; an explicit small step
            # keeps the +/- steppers usable.
            step=1e-6,
            format="%.2e",
        )
        batch_size = st.slider("Batch Size", min_value=1, max_value=32, value=8, step=1)

    with col2:
        epochs = st.slider("Number of Epochs", min_value=1, max_value=10, value=3, step=1)
        use_simulation = st.checkbox("Use Simulation Mode (for demonstration)", value=True)

    # Start training button (disabled until a dataset has been selected)
    if st.button("Start Training", disabled=not dataset_name):
        if not model_id:
            st.error("Please provide a model ID")
        elif model_id in st.session_state.get('trained_models', {}):
            st.error(f"Model with ID '{model_id}' already exists. Please choose a different ID.")
        elif model_id in st.session_state.get('training_progress', {}):
            st.error(f"A training job for model '{model_id}' already exists.")
        else:
            # Initialize stop_events if not present
            if 'stop_events' not in st.session_state:
                st.session_state.stop_events = {}

            # Start training (real or simulated). The value returned by the
            # start function is stored under the model ID; the monitor tab
            # later passes it to stop_model_training().
            if use_simulation:
                st.session_state.stop_events[model_id] = simulate_training(
                    model_id, dataset_name, base_model, epochs
                )
                add_log(f"Started simulated training for model '{model_id}'")
            else:
                st.session_state.stop_events[model_id] = start_model_training(
                    model_id, dataset_name, base_model, learning_rate, batch_size, epochs
                )
                add_log(f"Started training for model '{model_id}'")

            st.success(f"Training job started for model '{model_id}'")
            # Pause before rerunning so the success message is on screen
            # long enough to be read.
            time.sleep(1)
            st.rerun()
# "Monitor Jobs" tab: show status, progress, details and controls for one
# selected training job, followed by its progress plot and the log output.
with tab2:
    st.subheader("Training Jobs")

    # Check if there are any training jobs (key missing or mapping empty)
    training_progress = st.session_state.get('training_progress')
    if not training_progress:
        st.info("No training jobs found. Start a new training job in the 'Configure Training' tab.")
    else:
        # List all training jobs
        all_jobs = list(training_progress)
        selected_job = st.selectbox("Select Training Job", all_jobs)

        if selected_job:
            # Get job progress
            job_progress = training_progress[selected_job]

            # Display job status; statuses not listed here render in gray
            status = job_progress['status']
            status_color = {
                'initialized': 'blue',
                'running': 'green',
                'completed': 'green',
                'failed': 'red',
                'stopped': 'orange',
            }.get(status, 'gray')
            st.markdown(f"### Status: :{status_color}[{status.upper()}]")

            # Display progress bar. st.progress rejects floats outside
            # [0.0, 1.0], so the percentage is clamped after scaling.
            progress_fraction = min(max(job_progress['progress'] / 100, 0.0), 1.0)
            st.progress(progress_fraction)

            # Display job details
            col1, col2 = st.columns(2)

            with col1:
                st.markdown("### Job Details")
                st.markdown(f"**Model ID:** {selected_job}")
                st.markdown(f"**Current Epoch:** {job_progress['current_epoch']}/{job_progress['total_epochs']}")
                st.markdown(f"**Started At:** {job_progress['started_at']}")
                # A missing 'completed_at' entry is treated the same as an
                # empty one: the line is simply not shown.
                completed_at = job_progress.get('completed_at')
                if completed_at:
                    st.markdown(f"**Completed At:** {completed_at}")

            with col2:
                # Training controls
                st.markdown("### Controls")
                stop_events = st.session_state.get('stop_events', {})

                # Only show stop button for running jobs that have a stored
                # stop event
                if status == 'running' and selected_job in stop_events:
                    if st.button("Stop Training"):
                        stop_model_training(selected_job, stop_events[selected_job])
                        st.success(f"Stopping training for model '{selected_job}'")
                        time.sleep(1)
                        st.rerun()

                # Add delete button for completed/failed/stopped jobs
                if status in ('completed', 'failed', 'stopped'):
                    if st.button("Delete Job"):
                        del st.session_state.training_progress[selected_job]
                        # Drop the matching stop event too, if one was stored
                        stop_events.pop(selected_job, None)
                        add_log(f"Deleted training job for model '{selected_job}'")
                        st.success(f"Training job for model '{selected_job}' deleted")
                        time.sleep(1)
                        st.rerun()

            # Display training progress plot
            st.markdown("### Training Progress")
            plot_training_progress(selected_job)

            # Display logs
            st.markdown("### Training Logs")
            display_logs()
# Display running jobs summary at the bottom: one row per job returned by
# get_running_training_jobs(), showing its name, epoch counter and progress.
st.markdown("---")
st.subheader("Running Jobs Summary")

running_jobs = get_running_training_jobs()
if not running_jobs:
    st.info("No active training jobs")
else:
    # Read the progress mapping defensively: the key may be absent from the
    # session state, and a reported job may have no entry in the mapping.
    tracked_progress = st.session_state.get('training_progress', {})
    for job in running_jobs:
        progress = tracked_progress.get(job)
        if progress is None:
            # No progress record for this job; skip it rather than crash
            continue

        col1, col2, col3 = st.columns([2, 1, 1])
        with col1:
            st.markdown(f"**{job}**")
        with col2:
            st.markdown(f"Epoch {progress['current_epoch']}/{progress['total_epochs']}")
        with col3:
            # st.progress rejects floats outside [0.0, 1.0]; clamp to be safe
            st.progress(min(max(progress['progress'] / 100, 0.0), 1.0))