diff --git "a/app.py" "b/app.py"
new file mode 100644--- /dev/null
+++ "b/app.py"
@@ -0,0 +1,2772 @@
+import os
+import shutil
+import stat
+import yaml
+import streamlit as st
+from ultralytics import YOLO
+from roboflow import Roboflow
+import re
+from urllib.parse import urlparse
+import random
+import logging
+import requests
+import json
+import concurrent.futures
+import multiprocessing
+from difflib import get_close_matches
+from io import StringIO
+import matplotlib.pyplot as plt
+import base64
+import io
+import zipfile
+from huggingface_hub import HfApi, HfFolder
+import torch
+from PIL import Image
+from ultralytics.data.utils import autosplit
+
# Configure logging: INFO-and-above records go to a local log file,
# each line prefixed with its timestamp.
logging.basicConfig(
    format='%(asctime)s - %(message)s',
    level=logging.INFO,
    filename='rolo_app.log',
)

# ----------------------------------------------------------------------------
# SESSION STATE INITIALIZATION
# ----------------------------------------------------------------------------
# Every key below is seeded only when it is absent from st.session_state, so a
# value that is already stored is never overwritten. The table is ordered the
# same way the keys are first created.
_SESSION_STATE_DEFAULTS = {
    # Early stopping bookkeeping (read and updated by on_train_epoch_end).
    'early_stop_patience': 10,
    'best_val_loss': float('inf'),
    'epochs_no_improvement': 0,

    # Workflow / dataset preparation.
    'workflow_step': "Prepare Datasets",
    'dataset_info_list': [],
    'dataset_location': '',
    'dataset_prepared': False,
    'classes_to_include': {},
    'class_name_mapping': {},
    'merging_confirmed': False,
    'selected_images': set(),
    'class_counters': {},
    'adjustment_confirmed': False,
    'dataset_finalized': False,

    # Per-epoch metric series plotted during training.
    'metrics_data': {
        'epoch': [],
        'train_loss': [],
        'val_loss': [],
        'mAP50': [],
        'mAP50_95': [],
        'top1_acc': [],
        'top5_acc': [],
        'mAPpose50': [],
        'mAPpose50_95': [],
        'precision': [],
        'recall': [],
        'F1': [],
    },
    'is_rerunning': False,
    'class_image_counts': {},
    'datasets_used': [],
    'model_location': None,
    'model_url': "",
    'templates_found': False,
    'final_readme_content_editor': "",
    'total_class_images': {},

    # Training hyperparameters.
    'lrf': 0.01,
    'warmup_epochs': 3,
    'cache_dataset': False,
    'freeze_layers': 0,
    'hsv_h': 0.015,
    'hsv_s': 0.7,
    'hsv_v': 0.4,

    # Annotation filtering thresholds.
    'min_bbox_area': 0.0001,
    'max_bbox_area': 1.0,
    'min_visible_keypoints': 1,

    # Dataset split ratios.
    'train_ratio': 0.7,
    'val_ratio': 0.2,
    'test_ratio': 0.1,

    # Keypoint layout; filled in when a keypoint data.yaml is read.
    'kpt_shape': None,
}

for _state_key, _state_default in _SESSION_STATE_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

# Placeholders that the training callback redraws in place.
loss_chart_placeholder = st.empty()
map_chart_placeholder = st.empty()
progress_text_placeholder = st.empty()  # Epoch progress text
+
+# ----------------------------------------------------------------------------
+# CACHED FUNCTIONS (UNCHANGED)
+# ----------------------------------------------------------------------------
+
@st.cache_data(show_spinner=False)
def load_optimizers(url: str, timeout: float = 15.0):
    """
    [DATASET PREP] Load available optimizers and their descriptions from an external JSON.

    Args:
        url: Location of the JSON document.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout a stalled connection would block the app indefinitely.

    Returns:
        The decoded JSON payload, or an empty dict when the request or the
        JSON decoding fails (the failure is logged).
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        return response.json()
    except Exception as e:
        # Best-effort loader: any network/HTTP/decoding problem degrades to
        # "no optimizers" instead of crashing the page.
        logging.error(f"[DATASET PREP] Failed to load optimizers: {e}")
        return {}
+
@st.cache_data(show_spinner=False)
def load_recommended_datasets(url: str, timeout: float = 15.0):
    """
    [DATASET PREP] Load recommended datasets from an external JSON file hosted on GitHub.
    Each dataset should have 'name', 'description', 'source', and 'url'.
    NOTE: These recommended datasets are meant for *detection tasks* only.

    Args:
        url: Location of the JSON document.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout a stalled connection would block the app indefinitely.

    Returns:
        The decoded JSON payload, or an empty list when the request or the
        JSON decoding fails (the failure is logged).
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        return response.json()
    except Exception as e:
        # Best-effort loader: failures degrade to "no recommendations".
        logging.error(f"[DATASET PREP] Failed to load recommended datasets: {e}")
        return []
+
@st.cache_data(show_spinner=False)
def load_presets(url: str, timeout: float = 15.0):
    """
    Load configuration presets from an external JSON.

    Args:
        url: Location of the JSON document.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout a stalled connection would block the app indefinitely.

    Returns:
        The decoded JSON payload, or an empty dict when the request or the
        JSON decoding fails (the failure is logged).
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        return response.json()
    except Exception as e:
        # Best-effort loader: failures degrade to "no presets".
        logging.error(f"[DATASET PREP] Failed to load presets: {e}")
        return {}
+
@st.cache_data(show_spinner=False)
def cache_load_model_configs(url: str, timeout: float = 15.0):
    """
    Loads model architectures and their pre-trained models from a remote JSON.
    The JSON now has categories for each architecture:
    model_config[arch]["categories"][category] -> list of model entries

    Args:
        url: Location of the JSON document.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout a stalled connection would block the app indefinitely.

    Returns:
        The decoded JSON payload, or None when the request or the JSON
        decoding fails (the failure is logged).
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        return response.json()
    except Exception as e:
        # Callers must handle None: it signals that no configuration is available.
        logging.error(f"[DATASET PREP] Failed to load model configurations: {e}")
        return None
+
@st.cache_resource(show_spinner=False)
def cache_download_pretrained_model(model_filename, download_url, models_dir='pretrained_models'):
    """
    [DATASET PREP] Download a pre-trained model file into ``models_dir`` unless it is already there.

    The payload is streamed into a temporary ``<name>.part`` file and only
    renamed to its final name once the download completed. An interrupted
    download therefore never leaves a truncated file that a later call would
    mistake for a valid, already-downloaded model.

    Args:
        model_filename: File name to store the model under inside ``models_dir``.
        download_url: URL the model is fetched from.
        models_dir: Directory that holds downloaded models; created if missing.

    Returns:
        Path to the local model file, or None when the download failed
        (the failure is logged).
    """
    if not os.path.exists(models_dir):
        os.makedirs(models_dir, exist_ok=True)
        logging.info(f"[DATASET PREP] Created directory '{models_dir}' for pre-trained models.")

    model_path = os.path.join(models_dir, model_filename)
    if os.path.exists(model_path):
        logging.info(f"[DATASET PREP] Pre-trained model '{model_filename}' already exists locally.")
        return model_path

    partial_path = model_path + '.part'
    try:
        logging.info(f"[DATASET PREP] Starting download for model '{model_filename}' from '{download_url}'.")
        # The context manager releases the connection even if writing fails;
        # the timeout keeps a stalled server from blocking the app forever.
        with requests.get(download_url, stream=True, timeout=30) as response:
            response.raise_for_status()
            with open(partial_path, 'wb') as f:
                for data in response.iter_content(chunk_size=65536):
                    f.write(data)

        # Atomic publish: the final path only ever holds a complete file.
        os.replace(partial_path, model_path)
        logging.info(f"[DATASET PREP] Downloaded '{model_filename}' successfully.")
        return model_path
    except Exception as e:
        logging.error(f"[DATASET PREP] Failed to download model '{model_filename}': {e}")
        # Drop the incomplete temporary file; ignore cleanup problems because
        # the download failure is the error that matters.
        try:
            if os.path.exists(partial_path):
                os.remove(partial_path)
        except OSError:
            pass
        return None
+
+# ----------------------------------------------------------------------------
+# HELPER FUNCTIONS
+# ----------------------------------------------------------------------------
+
def handle_remove_readonly(func, path, exc_info):
    """
    Error hook for shutil.rmtree (Windows file permissions).

    Marks ``path`` as writable and then retries the operation that failed by
    calling ``func(path)``. ``exc_info`` is accepted for the hook signature
    but is not inspected.
    """
    writable_mode = stat.S_IWRITE
    os.chmod(path, writable_mode)
    retry_operation = func
    retry_operation(path)
+
def parse_roboflow_url(url):
    """
    [DATASET PREP] Extract (workspace, project, version) from a Roboflow URL.

    Recognised shapes, checked in this order:
      * ``/<workspace>/<project>/.../dataset/<version>`` - version is the last path segment.
      * ``/<workspace>/<project>?version=<digits>`` - version comes from the query string.
      * ``/<workspace>/<project>`` - version is None.

    Returns (None, None, None) when the path has fewer than two segments.
    The version, when present, is returned as a string.
    """
    parsed = urlparse(url.strip())
    segments = parsed.path.strip('/').split('/')

    # Path form that ends in ".../dataset/<version>".
    if len(segments) >= 4 and segments[-2] == 'dataset':
        return segments[0], segments[1], segments[-1]

    # Everything else needs at least a workspace and a project segment.
    if len(segments) < 2:
        return None, None, None

    version_match = re.search(r'version=(\d+)', parsed.query)
    version = version_match.group(1) if version_match else None
    return segments[0], segments[1], version
+
def get_latest_version(rf, workspace_name, project_name):
    """
    [DATASET PREP] Return the highest version number of a Roboflow project as a string.

    Each version object is read through its ``version_number`` attribute,
    falling back to ``number`` when that attribute does not exist. Entries
    whose value is None are ignored. Returns None (after reporting through
    Streamlit and the log) when no version is found or any step raises.
    """
    try:
        project = rf.workspace(workspace_name).project(project_name)

        numbers = []
        for entry in project.versions():
            fallback = getattr(entry, 'number', None)
            raw_number = getattr(entry, 'version_number', fallback)
            if raw_number is None:
                continue
            numbers.append(int(raw_number))

        if numbers:
            latest_version = max(numbers)
            logging.info(f"[DATASET PREP] Latest version for project '{project_name}' is v{latest_version}.")
            return str(latest_version)

        st.error(f"No versions found for project '{project_name}'.")
        logging.error(f"[DATASET PREP] No versions found for project '{project_name}'.")
        return None
    except Exception as e:
        st.error(f"Failed to retrieve latest version for project '{project_name}': {e}")
        logging.error(f"[DATASET PREP] Failed to retrieve latest version: {e}")
        return None
+
def _append_metric(series, value):
    """Append ``value`` to the list ``series`` unless the value is None."""
    if value is not None:
        series.append(value)


def _render_training_charts(metrics_data, secondary_series):
    """Redraw the loss chart and the task-specific metric chart placeholders.

    Nothing is drawn until at least one epoch has been recorded. When
    ``secondary_series`` is empty the metric chart placeholder is cleared.
    """
    if not metrics_data['epoch']:
        return
    loss_chart_placeholder.line_chart({
        "Train Loss": metrics_data['train_loss'],
        "Val Loss": metrics_data['val_loss']
    })
    if secondary_series:
        map_chart_placeholder.line_chart(secondary_series)
    else:
        map_chart_placeholder.empty()


def on_train_epoch_end(trainer):
    """
    [TRAINING] Callback to capture metrics at the end of each training epoch.
    Includes early stopping logic based on val_loss improvements.
    Dynamically displays metrics based on selected task category.

    Args:
        trainer: Object exposing ``epoch`` and ``metrics`` (a mapping read with
            ``.get``). ``trainer.stop`` is set to True when early stopping fires.

    Side effects:
        Appends to the series in ``st.session_state['metrics_data']``, updates
        the early-stopping counters in session state and redraws the module
        level chart/progress placeholders.
    """
    epoch = trainer.epoch
    # A missing or None metrics mapping is treated as "no metrics yet"
    # instead of raising AttributeError inside the training loop.
    metrics = getattr(trainer, 'metrics', None) or {}
    # NOTE(review): Ultralytics usually reports component losses such as
    # 'val/box_loss' rather than 'train/loss' / 'val/loss' - confirm these keys
    # exist, otherwise the loss chart stays empty and early stopping never fires.
    train_loss = metrics.get('train/loss', None)
    val_loss = metrics.get('val/loss', None)
    total_epochs = st.session_state.get('epochs', 1)  # Get total epochs from session state

    selected_category = st.session_state.get('selected_category', 'detection')
    # Read with a default: this key is not seeded together with the other
    # session defaults, and a KeyError here would abort training.
    show_debug = st.session_state.get("show_debug", False)
    metrics_data = st.session_state['metrics_data']

    # Metrics are cleared at the start of train_model() only; this callback
    # just appends to the running series.
    _append_metric(metrics_data['epoch'], epoch)
    _append_metric(metrics_data['train_loss'], train_loss)
    _append_metric(metrics_data['val_loss'], val_loss)

    # --- Update progress counter (skipped when the epoch index is unknown) ---
    if epoch is not None:
        current_epoch_display = epoch + 1
        if total_epochs > 0:
            progress_percent = (current_epoch_display / total_epochs) * 100
            progress_text_placeholder.write(f"Epoch {current_epoch_display}/{total_epochs} ({progress_percent:.1f}% done)")
        else:
            progress_text_placeholder.write(f"Epoch {epoch + 1} (Total epochs not set)")
    # --- End progress counter update ---

    if selected_category == 'classification':
        _append_metric(metrics_data['top1_acc'], metrics.get('metrics/top1_acc', None))
        _append_metric(metrics_data['top5_acc'], metrics.get('metrics/top5_acc', None))

        secondary_series = {}
        if metrics_data['top1_acc']:
            secondary_series = {
                "Top-1 Accuracy": metrics_data['top1_acc'],
                "Top-5 Accuracy": metrics_data['top5_acc']
            }
        _render_training_charts(metrics_data, secondary_series)
    elif selected_category == 'keypoint':
        _append_metric(metrics_data['mAPpose50'], metrics.get('metrics/mAP50(P)', None))
        _append_metric(metrics_data['mAPpose50_95'], metrics.get('metrics/mAP50-95(P)', None))

        secondary_series = {}
        if metrics_data['mAPpose50']:
            secondary_series = {
                "mAPpose@0.5": metrics_data['mAPpose50'],
                "mAPpose@0.5:0.95": metrics_data['mAPpose50_95']
            }
        _render_training_charts(metrics_data, secondary_series)
    else:  # detection, segmentation, obb
        map50 = metrics.get('metrics/mAP50(B)', None)
        map5095 = metrics.get('metrics/mAP50-95(B)', None)

        # Prefer the "(B)" suffixed keys and fall back to the bare names.
        precision = metrics.get('metrics/precision(B)', metrics.get('metrics/precision', None))
        recall = metrics.get('metrics/recall(B)', metrics.get('metrics/recall', None))
        f1 = metrics.get('metrics/f1(B)', metrics.get('metrics/f1', None))

        if show_debug:
            st.write(f"Epoch {epoch} metrics: {metrics}")  # Debug: See exact keys

        _append_metric(metrics_data['mAP50'], map50)
        _append_metric(metrics_data['mAP50_95'], map5095)
        _append_metric(metrics_data['precision'], precision)
        _append_metric(metrics_data['recall'], recall)
        _append_metric(metrics_data['F1'], f1)

        # Only series that actually received values are plotted.
        labelled_series = (
            ("mAP@0.5", 'mAP50'),
            ("mAP@0.5:0.95", 'mAP50_95'),
            ("Precision", 'precision'),
            ("Recall", 'recall'),
            ("F1 Score", 'F1'),
        )
        secondary_series = {
            label: metrics_data[key]
            for label, key in labelled_series
            if metrics_data[key]
        }
        _render_training_charts(metrics_data, secondary_series)

    # --- Early stopping on validation loss ---
    patience = st.session_state.get('early_stop_patience', 10)
    if val_loss is not None:
        if val_loss < st.session_state['best_val_loss']:
            st.session_state['best_val_loss'] = val_loss
            st.session_state['epochs_no_improvement'] = 0
            if show_debug:
                st.info(f"Epoch {epoch}: Validation loss improved to {val_loss:.4f}. Resetting patience counter.")
        else:
            st.session_state['epochs_no_improvement'] += 1
            if show_debug:
                st.warning(f"Epoch {epoch}: Validation loss did not improve. No improvement for {st.session_state['epochs_no_improvement']} epochs.")

        if st.session_state['epochs_no_improvement'] >= patience:
            st.warning(f"Early stopping triggered at epoch {epoch} due to no improvement in val_loss for {patience} epochs.")
            if show_debug:
                st.warning("Recommendation: Consider increasing 'Early Stopping Patience', adding more 'Data Augmentation', or reducing 'Learning Rate'.")
            trainer.stop = True
    else:
        if show_debug:
            st.warning(f"Epoch {epoch}: Validation loss not available for early stopping.")

    # train_model() sets is_rerunning to True before training; clear it once
    # the first epoch callback has run.
    if st.session_state.get('is_rerunning', False):
        st.session_state['is_rerunning'] = False
+
+
+# ----------------------------------------------------------------------------
+# DATASET PREPARATION HELPERS
+# ----------------------------------------------------------------------------
+
def download_and_prepare_roboflow_dataset(
    rf_api_key,
    workspace_name,
    project_name,
    version_number,
    selected_architecture,
    selected_category
):
    """
    [DATASET PREP] Downloads and prepares the Roboflow dataset based on the task category.

    Steps:
      1. Download the requested project version in the export format that
         matches the task (with 'yolov8' and 'coco' fallbacks for
         detection/segmentation/obb).
      2. Create or normalise ``data.yaml`` inside the downloaded folder.
      3. Report which of the 'train' / 'valid' / 'test' splits exist on disk.

    Returns:
        ``(dataset_location, class_names, splits)`` on success, or
        ``(None, None, None)`` on any failure (reported via Streamlit and/or the log).
    """
    arch_to_rf_format = {
        "YOLOv3": "yolov3",
        "YOLOv5": "yolov5",
        "YOLOv8": "yolov8",
        "YOLOv9": "yolov9",
        "YOLOv10": "yolov8",
        "YOLOv11": "yolov8",
        "YOLOv12": "yolov8",
        "RTDETR": "yolov8"  # RT-DETR uses YOLOv8 compatible dataset format
    }

    rf = Roboflow(api_key=rf_api_key)
    try:
        project = rf.workspace(workspace_name).project(project_name)
    except Exception as e:
        st.error(f"Failed to locate project '{project_name}' in workspace '{workspace_name}': {e}")
        logging.error(f"[DATASET PREP] Failed to locate project: {e}")
        return None, None, None

    if selected_category == "classification":
        roboflow_export_format = "folder"
    elif selected_category == "keypoint":
        roboflow_export_format = "yolov8-keypoint"
    else:
        roboflow_export_format = arch_to_rf_format.get(selected_architecture, "yolov8")

    try:
        st.info(f"[DATASET PREP] Attempting download format '{roboflow_export_format}' for {selected_category}...")
        dataset = project.version(version_number).download(roboflow_export_format)
        logging.info(f"[DATASET PREP] Dataset '{project_name}' (v{version_number}) downloaded with format '{roboflow_export_format}'.")
    except Exception as e:
        logging.warning(f"[DATASET PREP] Failed to download with format '{roboflow_export_format}': {e}")
        st.warning(f"Failed with '{roboflow_export_format}': {e}. Trying fallbacks...")
        if selected_category in ["detection", "segmentation", "obb"]:
            try:
                dataset = project.version(version_number).download("yolov8")
                logging.info(f"[DATASET PREP] Dataset '{project_name}' (v{version_number}) downloaded with 'yolov8' fallback.")
            except Exception as e2:
                logging.warning(f"Failed with 'yolov8': {e2}")
                st.warning(f"Failed with 'yolov8': {e2}. Trying 'coco' next...")
                try:
                    dataset = project.version(version_number).download("coco")
                    logging.info(f"[DATASET PREP] Dataset '{project_name}' (v{version_number}) downloaded with 'coco' fallback.")
                except Exception as e3:
                    logging.error(f"[DATASET PREP] All fallback attempts failed for '{project_name}': {e3}")
                    st.error(f"Failed to download dataset '{project_name}' with '{roboflow_export_format}', 'yolov8', or 'coco': {e3}")
                    return None, None, None
        else:
            st.error(f"Failed to download dataset '{project_name}' in required '{roboflow_export_format}' format for {selected_category}. No fallback available.")
            logging.error(f"[DATASET PREP] No fallback for {selected_category} after {roboflow_export_format} failed.")
            return None, None, None

    dataset_location = dataset.location
    data_yaml_path = os.path.join(dataset_location, 'data.yaml')
    class_names = []
    kpt_shape_info = None

    if selected_category == "classification":
        # Folder export: every sub-directory of a split is one class.
        discovered = set()
        for split_dir in ['train', 'valid', 'test']:
            split_path = os.path.join(dataset_location, split_dir)
            if os.path.exists(split_path):
                discovered.update(
                    d for d in os.listdir(split_path)
                    if os.path.isdir(os.path.join(split_path, d))
                )
        class_names = sorted(discovered)

        data_yaml = {
            'path': dataset_location,
            'train': 'train',
            'val': 'valid',
            'test': 'test',
            'nc': len(class_names),
            'names': class_names
        }
        with open(data_yaml_path, 'w') as f:
            yaml.safe_dump(data_yaml, f)
        logging.info(f"[DATASET PREP] Created/Updated data.yaml for classification at {data_yaml_path}.")

    else:
        if not os.path.exists(data_yaml_path):
            try:
                class_names = project.classes
                # NOTE(review): project.classes looks like it can be a mapping
                # of class name -> count; only the names belong in data.yaml.
                # The resulting order is the mapping's iteration order - verify
                # it matches the label ids of the export.
                if isinstance(class_names, dict):
                    class_names = list(class_names.keys())
                if not class_names:
                    logging.error("[DATASET PREP] No class names found via API.")
                    return None, None, None
                data_yaml = {
                    'train': 'train/images',
                    'val': 'valid/images',
                    'test': 'test/images',
                    'nc': len(class_names),
                    'names': class_names
                }
                if selected_category == "keypoint":
                    st.warning("Keypoint dataset detected. Please ensure its data.yaml will contain 'kpt_shape' field.")
                with open(data_yaml_path, 'w') as f:
                    yaml.safe_dump(data_yaml, f)
            except Exception as e:
                logging.error(f"[DATASET PREP] Error writing data.yaml: {e}")
                return None, None, None
        else:
            try:
                with open(data_yaml_path, 'r') as f:
                    data_yaml = yaml.safe_load(f)
                class_names = data_yaml.get('names', [])
                # 'names' may be written as an {index: name} mapping; return a
                # list ordered by index so callers can index and extend it.
                if isinstance(class_names, dict):
                    class_names = [class_names[idx] for idx in sorted(class_names)]
                if not class_names:
                    logging.error("[DATASET PREP] No class names found in data.yaml.")
                    return None, None, None
                if selected_category == "keypoint":
                    kpt_shape_info = data_yaml.get('kpt_shape', None)
                    if kpt_shape_info is None:
                        st.error("Keypoint data.yaml missing 'kpt_shape'. Cannot proceed with keypoint training.")
                        logging.error("[DATASET PREP] Keypoint data.yaml missing 'kpt_shape'.")
                        return None, None, None
                    st.session_state['kpt_shape'] = kpt_shape_info
            except Exception as e:
                logging.error(f"[DATASET PREP] Error reading data.yaml: {e}")
                return None, None, None

        # Make every split entry point at its "images" folder. Non-string
        # entries (missing value, list of paths) are left untouched, and the
        # check ignores Windows separators and a trailing slash so an entry
        # that already ends in "images" is not extended twice.
        for split_key in ('train', 'val', 'test'):
            split_value = data_yaml.get(split_key)
            if not isinstance(split_value, str):
                continue
            normalized = split_value.replace('\\', '/').rstrip('/')
            if not normalized.endswith('/images'):
                data_yaml[split_key] = os.path.join(split_value, 'images')
        with open(data_yaml_path, 'w') as f:
            yaml.safe_dump(data_yaml, f)

    # A split counts as available when its folder (classification) or its
    # "images" sub-folder (all other tasks) exists.
    splits = []
    for split in ['train', 'valid', 'test']:
        if selected_category == "classification":
            probe = os.path.join(dataset_location, split)
        else:
            probe = os.path.join(dataset_location, split, 'images')
        if os.path.exists(probe):
            splits.append(split)

    return dataset_location, class_names, splits
+
def download_and_prepare_github_dataset(github_zip_url, dataset_name):
    """
    [DATASET PREP] Downloads and prepares a dataset from a GitHub (or Ultralytics) ZIP archive.
    NOTE: The recommended datasets from your external JSON are for *detection only*.

    The archive is stored as ``temp_datasets/<dataset_name>.zip`` and unpacked
    into ``temp_datasets/<dataset_name>``; a ``data.yaml`` is expected at the
    root of the unpacked folder.

    Returns:
        ``(extract_dir, class_names, available_splits)`` on success, or
        ``(None, None, None)`` on any failure (reported via Streamlit and the log).
    """
    try:
        dataset_zip_path = os.path.join("temp_datasets", f"{dataset_name}.zip")
        os.makedirs("temp_datasets", exist_ok=True)

        # The timeout keeps a stalled server from blocking the app forever.
        # iter_content() yields the decoded body, whereas copying r.raw would
        # write the bytes exactly as transferred (e.g. still gzip-encoded).
        with requests.get(github_zip_url, stream=True, timeout=30) as r:
            r.raise_for_status()
            with open(dataset_zip_path, 'wb') as f:
                for chunk in r.iter_content(chunk_size=65536):
                    f.write(chunk)

        extract_dir = os.path.join("temp_datasets", dataset_name)
        shutil.unpack_archive(dataset_zip_path, extract_dir)
        logging.info(f"[DATASET PREP] Extracted {dataset_name} to {extract_dir}")

        data_yaml_path = os.path.join(extract_dir, 'data.yaml')
        if not os.path.exists(data_yaml_path):
            st.error(f"'data.yaml' not found in the dataset '{dataset_name}'.")
            logging.error(f"[DATASET PREP] 'data.yaml' not found in {dataset_name}.")
            return None, None, None

        with open(data_yaml_path, 'r') as f:
            data_yaml = yaml.safe_load(f)
        class_names = data_yaml.get('names', [])
        # 'names' may be written as an {index: name} mapping; return a list
        # ordered by index so callers can index and extend it.
        if isinstance(class_names, dict):
            class_names = [class_names[idx] for idx in sorted(class_names)]

        splits = ['train', 'valid', 'test']
        available_splits = [s for s in splits if os.path.exists(os.path.join(extract_dir, s))]

        return extract_dir, class_names, available_splits
    except Exception as e:
        st.error(f"Failed to download or prepare GitHub dataset '{dataset_name}': {e}")
        logging.error(f"[DATASET PREP] Failed to download or prepare GitHub dataset '{dataset_name}': {e}")
        return None, None, None
+
+# ----------------------------------------------------------------------------
+# TRAINING LOGIC & AUGMENTATIONS
+# ----------------------------------------------------------------------------
+
def train_model(
    dataset_location,
    epochs,
    batch_size,
    img_size,
    learning_rate,
    optimizer,
    data_augmentation_level,
    pre_trained_model,
    custom_model_path,
    custom_model_name,
    show_debug,
    model_config,
    selected_architecture,
    selected_category
):
    """
    [TRAINING] Train a model (e.g., YOLO, RT-DETR) with the given parameters,
    supporting GPU detection and half precision by default.
    Includes new hyperparameters and task-specific settings.

    Args:
        dataset_location: Folder that contains the prepared ``data.yaml``.
        epochs, batch_size, learning_rate, optimizer: Forwarded to the trainer
            as ``epochs``, ``batch``, ``lr0`` and (lower-cased) ``optimizer``.
        img_size: Image size used when ``st.session_state`` holds no 'img_size'.
        data_augmentation_level: "None", "Basic", "Moderate" or "Advanced".
        pre_trained_model: File name of a configured model, or "Custom".
        custom_model_path: Weights path, used when ``pre_trained_model`` is "Custom".
        custom_model_name: Run name under 'runs/train'.
        show_debug: When true, training errors are reported a second time with details.
        model_config: Mapping ``arch -> {"categories": {category: [entries]}}``.
        selected_architecture, selected_category: Keys into ``model_config``;
            the category also selects the task.

    Returns:
        The trained model object, or None when any step fails.
    """
    st.info(f"[TRAINING] Training model '{selected_architecture}' for task: '{selected_category}'...")
    logging.info("[TRAINING] Initiating model training.")

    train_data_path = os.path.join(dataset_location, "data.yaml")
    if not os.path.exists(train_data_path):
        st.error("data.yaml not found. Please ensure the dataset is prepared correctly.")
        logging.error("[TRAINING] data.yaml not found.")
        return None

    if pre_trained_model == "Custom":
        if not custom_model_path or not os.path.exists(custom_model_path):
            st.error("Please provide a valid path for the custom pre-trained model.")
            logging.error("[TRAINING] Custom model path invalid.")
            return None
        pre_trained_weights = custom_model_path
    else:
        # Tolerate a missing configuration (the loader returns None on failure)
        # or an unknown architecture: both end in the "not found" message
        # below instead of a TypeError/KeyError.
        architecture_entry = (model_config or {}).get(selected_architecture) or {}
        sub_models = (architecture_entry.get("categories") or {}).get(selected_category, [])
        model_entry = next((m for m in sub_models if m["filename"] == pre_trained_model), None)
        if model_entry is None:
            st.error(f"Model '{pre_trained_model}' not found under '{selected_architecture}/{selected_category}'.")
            return None

        downloaded_path = cache_download_pretrained_model(pre_trained_model, model_entry["url"])
        if not downloaded_path:
            st.error(f"Failed to download the pre-trained model '{pre_trained_model}'.")
            return None
        pre_trained_weights = downloaded_path

    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cuda":
        st.write("Using **GPU** (CUDA) for training.")
        logging.info("[TRAINING] Using GPU (CUDA).")
    else:
        st.write("Using **CPU** for training (this may be slower).")
        logging.info("[TRAINING] Using CPU.")

    use_half_precision = st.session_state.get("use_half_precision", True)
    if use_half_precision and device == "cuda":
        st.write("Training with **half-precision (FP16)** on GPU.")
        logging.info("[TRAINING] Using half precision (FP16).")
    elif use_half_precision and device == "cpu":
        st.warning("Half-precision was requested but no GPU found. Skipping half-precision.")
        logging.info("[TRAINING] Skipping half precision due to CPU-only environment.")
        use_half_precision = False

    try:
        # The ultralytics library uses the YOLO class to handle various models, including RT-DETR.
        # When a model like 'rtdetr-l.pt' is passed, it loads the correct architecture.
        model = YOLO(pre_trained_weights)
        logging.info(f"[TRAINING] Loaded pre-trained model: {pre_trained_weights}")
    except Exception as e:
        st.error(f"Failed to load the pre-trained model '{pre_trained_weights}': {e}")
        logging.error(f"[TRAINING] Failed to load the pre-trained model '{pre_trained_weights}': {e}")
        return None

    TASK_MAPPING = {
        "detection": "detect",
        "segmentation": "segment",
        "obb": "obb",
        "classification": "classify",
        "keypoint": "pose"
    }
    yolo_task = TASK_MAPPING.get(selected_category, "detect")

    # Prefer the value stored in session state, but fall back to the
    # ``img_size`` argument instead of raising KeyError when it is not set.
    current_img_size = st.session_state.get("img_size", img_size)
    if selected_category == "classification" and current_img_size == 640:
        current_img_size = 224
        st.info(f"Using recommended image size {current_img_size} for classification.")
    elif selected_category == "keypoint" and current_img_size == 640:
        st.info(f"Using recommended image size {current_img_size} for keypoint detection.")

    augmentation_presets = {
        "None": {
            'degrees': 0.0,
            'translate': 0.0,
            'scale': 0.0,
            'shear': 0.0,
            'fliplr': 0.0,
            'mixup': 0.0,
            'perspective': 0.0,
        },
        "Basic": {
            'degrees': 5.0,
            'translate': 0.1,
            'scale': 0.1,
            'fliplr': 0.5
        },
        "Moderate": {
            'degrees': 15.0,
            'translate': 0.2,
            'scale': 0.3,
            'shear': 2.0,
            'fliplr': 0.5,
            'mixup': 0.2
        },
        "Advanced": {
            'degrees': 30.0,
            'translate': 0.3,
            'scale': 0.5,
            'shear': 5.0,
            'fliplr': 0.5,
            'mixup': 0.4,
            'perspective': 0.1
        },
    }
    # Copy so the preset table itself is never mutated.
    augmentations = dict(augmentation_presets.get(data_augmentation_level, {}))

    augmentations['hsv_h'] = st.session_state["hsv_h"]
    augmentations['hsv_s'] = st.session_state["hsv_s"]
    augmentations['hsv_v'] = st.session_state["hsv_v"]

    train_args = {
        'data': train_data_path,
        'epochs': epochs,
        'batch': batch_size,
        'imgsz': current_img_size,
        'lr0': learning_rate,
        'optimizer': optimizer.lower(),
        'workers': 4,
        'project': 'runs/train',
        'name': custom_model_name,
        'augment': (data_augmentation_level != "None"),
        'task': yolo_task,
        'device': device,
        'half': use_half_precision,
        'lrf': st.session_state['lrf'],
        'warmup_epochs': st.session_state['warmup_epochs'],
        'cache': st.session_state['cache_dataset']
    }
    # Always forward the augmentation values. The "None" preset is explicit
    # zeros: skipping it would leave the trainer on its own defaults for these
    # keys, so the geometric augmentations would not actually be switched off.
    # NOTE(review): the 'augment' flag alone is presumably not what controls
    # training-time augmentation in Ultralytics - confirm against its docs.
    train_args.update(augmentations)

    if st.session_state['freeze_layers'] > 0:
        train_args['freeze'] = st.session_state['freeze_layers']
        logging.info(f"[TRAINING] Freezing {st.session_state['freeze_layers']} layers.")

    model.add_callback("on_train_epoch_end", on_train_epoch_end)

    # Clear previous metrics and reset early stopping state for a fresh training run
    st.session_state['metrics_data'] = {
        'epoch': [], 'train_loss': [], 'val_loss': [],
        'mAP50': [], 'mAP50_95': [],
        'top1_acc': [], 'top5_acc': [],
        'mAPpose50': [], 'mAPpose50_95': [],
        'precision': [], 'recall': [], 'F1': []
    }
    st.session_state['best_val_loss'] = float('inf')
    st.session_state['epochs_no_improvement'] = 0
    st.session_state['is_rerunning'] = True  # Reset to False by the first epoch callback

    try:
        logging.info(f"[TRAINING] Training arguments: {train_args}")
        model.train(**train_args)
        logging.info("[TRAINING] Model training completed successfully.")
        return model
    except Exception as e:
        st.error(f"An error occurred during training: {e}")
        logging.error(f"[TRAINING] An error occurred during training: {e}")
        if show_debug:
            st.error(f"Training Error Details: {e}")
        return None
+
+# ----------------------------------------------------------------------------
+# MERGING + PREPARATION FOR MERGED DATASET
+# ----------------------------------------------------------------------------
+
def process_label_file(args):
    """
    [MERGE] Worker function for parallel processing of label files.
    Reads each label file, parses classes, and returns relevant info.

    Args:
        args: Tuple ``(label_path, class_names_dataset, class_name_mapping,
            unified_class_names)``. It is a single tuple so the function can be
            handed directly to an executor's ``map``.

    Returns:
        ``((image_filename, image_classes_set), None)`` on success, where
        ``image_filename`` is the label file's name with its ``.txt``
        extension swapped for ``.jpg`` and ``image_classes_set`` holds the
        mapped class names that occur in the file and belong to
        ``unified_class_names``. ``(None, error_message)`` when the file
        cannot be read.
    """
    label_path, class_names_dataset, class_name_mapping, unified_class_names = args
    try:
        with open(label_path, 'r') as f:
            lines = f.readlines()
    except Exception as e:
        return None, f"[MERGE] Failed to read label file '{label_path}': {e}"

    # Swap only the trailing extension. A blanket str.replace would also
    # rewrite a ".txt" that appears in the middle of the file name.
    label_filename = os.path.basename(label_path)
    if label_filename.endswith('.txt'):
        image_filename = label_filename[:-len('.txt')] + '.jpg'
    else:
        image_filename = label_filename

    image_classes_set = set()
    for line in lines:
        parts = line.strip().split()
        if not parts:
            continue
        try:
            class_id = int(parts[0])
        except ValueError:
            # Malformed class id: skip the annotation, keep the rest of the file.
            continue
        if class_id < 0 or class_id >= len(class_names_dataset):
            continue
        original_class_name = class_names_dataset[class_id]
        new_class_name = class_name_mapping.get(original_class_name, original_class_name)
        if new_class_name in unified_class_names:
            image_classes_set.add(new_class_name)
    return (image_filename, image_classes_set), None
+
def gather_class_counts(dataset_info_list, class_name_mapping, selected_category):
    """
    [MERGE] Counts, per unified class, how many images are available.

    For ``selected_category == "classification"`` the count is the number of
    image files inside each class folder of every split. For every other
    category it is the number of label files that contain at least one
    annotation of the class.

    Args:
        dataset_info_list: iterable of 4-tuples
            ``(dataset_location, class_names, splits, dataset_name)``.
        class_name_mapping: maps original class names to merged names; names
            that are absent map to themselves.
        selected_category: task type string.

    Returns:
        Dictionary like ``{'class': count, ...}`` keyed by unified class name.
    """
    # Same extension set the classification merge copies, so the reported
    # counts agree with what can actually be merged.
    image_extensions = ('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff')

    all_class_names = set()
    for _, class_names, _, _ in dataset_info_list:
        all_class_names.update(class_names)

    unified_class_names = sorted({class_name_mapping.get(name, name) for name in all_class_names})
    class_counters = {cls: 0 for cls in unified_class_names}

    if selected_category == "classification":
        for dataset_location, _, splits, _ in dataset_info_list:
            for split_key in splits:
                split_dir = os.path.join(dataset_location, split_key)
                if not os.path.isdir(split_dir):
                    continue
                for class_folder_name in os.listdir(split_dir):
                    original_class_path = os.path.join(split_dir, class_folder_name)
                    if not os.path.isdir(original_class_path):
                        continue
                    unified_class_name = class_name_mapping.get(class_folder_name, class_folder_name)
                    if unified_class_name not in class_counters:
                        continue
                    class_counters[unified_class_name] += sum(
                        1 for entry in os.listdir(original_class_path)
                        if entry.lower().endswith(image_extensions)
                    )
        return class_counters

    for dataset_location, class_names_dataset, splits, _ in dataset_info_list:
        class_count = len(class_names_dataset)
        for split_key in splits:
            labels_src = os.path.join(dataset_location, split_key, 'labels')
            if not os.path.exists(labels_src):
                continue
            for root, _dirs, files in os.walk(labels_src):
                for file in files:
                    if not file.endswith('.txt'):
                        continue
                    label_path = os.path.join(root, file)
                    try:
                        with open(label_path, 'r') as f:
                            lines = f.readlines()
                    except (OSError, UnicodeDecodeError):
                        # Unreadable label file: best-effort, just skip it.
                        continue

                    classes_in_file = set()
                    for line in lines:
                        parts = line.split()
                        if not parts:
                            continue
                        try:
                            class_id = int(parts[0])
                        except ValueError:
                            continue
                        if 0 <= class_id < class_count:
                            orig_cls = class_names_dataset[class_id]
                            new_cls = class_name_mapping.get(orig_cls, orig_cls)
                            if new_cls in class_counters:
                                classes_in_file.add(new_cls)
                    # One label file == one image: count each class once per file.
                    for cls in classes_in_file:
                        class_counters[cls] += 1

    return class_counters
+
def merge_classification_datasets(dataset_info_list, classes_to_include, class_name_mapping, show_debug):
    """
    [MERGE - CLASSIFICATION] Merges multiple classification datasets into 'rolo_dataset',
    copying images into new, class-specific subfolders.

    Args:
        dataset_info_list: iterable of 4-tuples
            ``(dataset_location, class_names, splits, dataset_name)``; each split
            directory is expected to hold one sub-folder per class.
        classes_to_include: maps original class name -> requested image limit.
            Limits of classes merged into the same unified name are summed.
        class_name_mapping: maps original class name -> unified class name.
        show_debug: when truthy, errors are also shown in the Streamlit UI.

    Returns:
        The merged dataset directory, or ``None`` when the output directory
        could not be reset/created.

    Side effects:
        Renders statistics in Streamlit, writes ``data.yaml`` and stores the
        merge results in ``st.session_state``.
    """
    merged_dataset_dir = 'rolo_dataset'
    if os.path.exists(merged_dataset_dir):
        st.warning(f"[MERGE] Output directory '{merged_dataset_dir}' already exists. Overwriting...")
        try:
            shutil.rmtree(merged_dataset_dir, onerror=handle_remove_readonly)
        except Exception as e:
            if show_debug:
                st.error(f"[MERGE] Error deleting directory '{merged_dataset_dir}': {e}")
            logging.error(f"[MERGE] Error deleting directory '{merged_dataset_dir}': {e}")
            return None

    # Create the root up front: data.yaml is written even when no class
    # survives, and per-class folders are only created for active classes.
    try:
        os.makedirs(merged_dataset_dir, exist_ok=True)
    except Exception as e:
        if show_debug:
            st.error(f"[MERGE] Error creating directory '{merged_dataset_dir}': {e}")
        logging.error(f"[MERGE] Error creating directory '{merged_dataset_dir}': {e}")
        return None

    all_original_class_names = set()
    for _, class_names, _, _ in dataset_info_list:
        all_original_class_names.update(class_names)

    unified_class_names = sorted(
        {class_name_mapping.get(name, name) for name in all_original_class_names}
    )

    # Combined user limit per unified class, computed once instead of once per folder.
    combined_limits = {cls: 0 for cls in unified_class_names}
    for orig_cls_key, limit in classes_to_include.items():
        target = class_name_mapping.get(orig_cls_key, orig_cls_key)
        if target in combined_limits:
            combined_limits[target] += limit

    image_extensions = ('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff')
    images_per_class_src = {cls: [] for cls in unified_class_names}

    for dataset_location, _class_names_dataset, splits, _dataset_name in dataset_info_list:
        for split_key in splits:
            split_dir = os.path.join(dataset_location, split_key)
            if not os.path.isdir(split_dir):
                continue
            for class_folder_name in os.listdir(split_dir):
                original_class_path = os.path.join(split_dir, class_folder_name)
                if not os.path.isdir(original_class_path):
                    continue

                unified_class_name = class_name_mapping.get(class_folder_name, class_folder_name)
                # Folders that are not a declared class, or whose class was
                # excluded (limit 0), contribute nothing.
                if combined_limits.get(unified_class_name, 0) == 0:
                    continue

                for img_file in os.listdir(original_class_path):
                    if img_file.lower().endswith(image_extensions):
                        img_path_src = os.path.join(original_class_path, img_file)
                        images_per_class_src[unified_class_name].append((img_path_src, split_key))

    st.subheader("Dataset Statistics & Insights")
    st.write("### Available Images per Class (Pre-Adjustment):")
    for cls in unified_class_names:
        st.write(f"- **{cls}**: {len(images_per_class_src[cls])} images")

    fig, ax = plt.subplots(figsize=(6, 3))
    class_counts_list = [len(images_per_class_src[cls]) for cls in unified_class_names]
    ax.bar(unified_class_names, class_counts_list)
    ax.set_title("Class Distribution (All Datasets)")
    ax.set_xlabel("Class Names")
    ax.set_ylabel("Image Count")
    plt.xticks(rotation=45, ha='right')
    st.pyplot(fig)
    # pyplot keeps every figure alive until closed; release it after rendering.
    plt.close(fig)

    selected_images_per_class = {}
    final_class_counts = {}
    adjusted_limits_for_unified_classes = {}

    for unified_cls_name in unified_class_names:
        candidates = images_per_class_src[unified_cls_name]
        combined_limit = combined_limits[unified_cls_name]
        final_limit = min(combined_limit, len(candidates)) if combined_limit > 0 else 0
        adjusted_limits_for_unified_classes[unified_cls_name] = final_limit

        random.shuffle(candidates)
        selected_images_per_class[unified_cls_name] = candidates[:final_limit]
        final_class_counts[unified_cls_name] = len(selected_images_per_class[unified_cls_name])

    st.session_state['total_class_images'] = final_class_counts
    st.session_state['class_image_counts'] = final_class_counts

    for split in ['train', 'valid', 'test']:
        for cls in unified_class_names:
            if final_class_counts[cls] > 0:
                class_dir = os.path.join(merged_dataset_dir, split, cls)
                os.makedirs(class_dir, exist_ok=True)
                logging.info(f"[MERGE] Created directory '{class_dir}'.")

    # Destination paths handed out so far; used to keep same-named images from
    # different datasets/classes from overwriting each other.
    used_destinations = set()
    for unified_cls_name, images_info in selected_images_per_class.items():
        for img_path_src, split_key in images_info:
            img_filename = os.path.basename(img_path_src)
            dst_dir = os.path.join(merged_dataset_dir, split_key, unified_cls_name)
            stem, ext = os.path.splitext(img_filename)
            img_path_dst = os.path.join(dst_dir, img_filename)
            suffix = 1
            while img_path_dst in used_destinations:
                img_path_dst = os.path.join(dst_dir, f"{stem}_{suffix}{ext}")
                suffix += 1
            used_destinations.add(img_path_dst)
            try:
                # Source split names other than train/valid/test have no
                # pre-created folder; make sure the target exists.
                os.makedirs(dst_dir, exist_ok=True)
                shutil.copy(img_path_src, img_path_dst)
            except Exception as e:
                if show_debug:
                    st.error(f"[MERGE] Error copying image {img_filename}: {e}")
                logging.error(f"[MERGE] Error copying image {img_filename}: {e}")

    active_final_classes = [cls for cls, count in final_class_counts.items() if count > 0]
    merged_data_yaml = {
        'path': os.path.abspath(merged_dataset_dir),
        'train': 'train',
        'val': 'valid',
        'test': 'test',
        'nc': len(active_final_classes),
        'names': active_final_classes
    }
    with open(os.path.join(merged_dataset_dir, 'data.yaml'), 'w') as f:
        yaml.safe_dump(merged_data_yaml, f)

    st.session_state['class_counters'] = final_class_counts
    st.session_state['active_classes'] = active_final_classes
    st.session_state['class_limits'] = adjusted_limits_for_unified_classes
    st.session_state['class_id_mapping'] = {name: idx for idx, name in enumerate(active_final_classes)}
    st.session_state['class_name_mapping'] = class_name_mapping
    st.session_state['dataset_info_list'] = dataset_info_list
    st.session_state['merged_dataset_dir'] = merged_dataset_dir

    st.write("### Selected Images Count per Class (After Initial Merging):")
    for cls in unified_class_names:
        cnt = final_class_counts.get(cls, 0)
        limit = adjusted_limits_for_unified_classes.get(cls, 0)
        st.write(f"- **{cls}**: {cnt} selected (Limit: {limit})")

    st.write("### Merged Classes:")
    for old_class, new_class in class_name_mapping.items():
        if old_class != new_class:
            st.write(f"- **{old_class}** merged into **{new_class}**")

    if not active_final_classes:
        st.error("WARNING: No classes or images remain after merging/excluding. Please adjust your class limits or rename logic.")
    return merged_dataset_dir
+
def _find_image_for_label(label_path, labels_root, images_root):
    """Return the image that pairs with ``label_path``, or ``None`` if no image file exists."""
    # Mirror the label's position below labels_root under images_root. A plain
    # str.replace('labels', 'images') would also rewrite any other 'labels'
    # occurring in the dataset path.
    relative_stem = os.path.splitext(os.path.relpath(label_path, labels_root))[0]
    for ext in ('.jpg', '.png', '.jpeg', '.bmp', '.tif', '.tiff', '.webp'):
        candidate = os.path.join(images_root, relative_stem + ext)
        if os.path.exists(candidate):
            return candidate
    return None


def merge_datasets(dataset_info_list, classes_to_include, class_name_mapping, show_debug):
    """
    [MERGE] Merges multiple datasets into 'rolo_dataset', respecting user-specified class limits.
    Handles detection, segmentation, OBB, and keypoint data.

    This step only *selects* images; the selection and its lookup tables are
    stored in ``st.session_state`` for the later adjustment/finalize steps.

    Args:
        dataset_info_list: iterable of 4-tuples
            ``(dataset_location, class_names, splits, dataset_name)``; each split
            is expected to contain ``labels`` and ``images`` sub-folders.
        classes_to_include: maps original class name -> requested image limit.
        class_name_mapping: maps original class name -> unified class name.
        show_debug: when truthy, extra diagnostics are shown in the Streamlit UI.

    Returns:
        The merged dataset directory, or ``None`` when it could not be
        reset/created.
    """
    merged_dataset_dir = 'rolo_dataset'

    if os.path.exists(merged_dataset_dir):
        st.warning(f"[MERGE] Output directory '{merged_dataset_dir}' already exists. Overwriting...")
        logging.warning(f"[MERGE] Output directory '{merged_dataset_dir}' already exists. Overwriting...")
        try:
            shutil.rmtree(merged_dataset_dir, onerror=handle_remove_readonly)
            logging.info(f"[MERGE] Existing directory '{merged_dataset_dir}' deleted successfully.")
        except Exception as e:
            if show_debug:
                st.error(f"[MERGE] Error deleting directory '{merged_dataset_dir}': {e}")
            logging.error(f"[MERGE] Error deleting directory '{merged_dataset_dir}': {e}")
            return None

    try:
        os.makedirs(merged_dataset_dir, exist_ok=True)
        logging.info(f"[MERGE] Merged dataset directory '{merged_dataset_dir}' created.")
    except Exception as e:
        if show_debug:
            st.error(f"[MERGE] Error creating directory '{merged_dataset_dir}': {e}")
        logging.error(f"[MERGE] Error creating directory '{merged_dataset_dir}': {e}")
        return None

    all_class_names = set()
    for _, class_names, _, _ in dataset_info_list:
        all_class_names.update(class_names)

    unified_class_names = sorted({class_name_mapping.get(name, name) for name in all_class_names})

    # Limits of original classes that map to the same unified class are summed.
    class_limits = {name: 0 for name in unified_class_names}
    for original_class, user_chosen_limit in classes_to_include.items():
        new_name = class_name_mapping.get(original_class, original_class)
        if new_name in class_limits:
            class_limits[new_name] += user_chosen_limit

    class_counters = gather_class_counts(dataset_info_list, class_name_mapping, st.session_state['selected_category'])
    st.session_state['class_image_counts'] = class_counters

    total_class_images = dict(class_counters)
    st.session_state['total_class_images'] = total_class_images

    for cls in class_limits:
        max_available = total_class_images.get(cls, 0)
        if class_limits[cls] > max_available:
            class_limits[cls] = max_available
            if show_debug:
                st.warning(f"[MERGE] Class '{cls}' limit adjusted to available images: {max_available}")

    active_classes = [cls for cls, limit in class_limits.items() if limit > 0]
    active_class_set = set(active_classes)
    class_id_mapping = {name: idx for idx, name in enumerate(active_classes)}

    class_to_images = {name: set() for name in unified_class_names}
    image_to_classes = {}
    image_to_label_path = {}

    args_list = []
    # label path -> (labels root, images root) of the split it came from.
    label_roots = {}
    for dataset_location, class_names_dataset, splits, _dataset_name in dataset_info_list:
        for split_key in splits:
            labels_src = os.path.join(dataset_location, split_key, 'labels')
            if not os.path.exists(labels_src):
                continue
            images_root = os.path.join(dataset_location, split_key, 'images')
            for root, _dirs, files in os.walk(labels_src):
                for file_name in files:
                    if not file_name.endswith('.txt'):
                        continue
                    lf = os.path.join(root, file_name)
                    args_list.append((lf, class_names_dataset, class_name_mapping, unified_class_names,))
                    label_roots[lf] = (labels_src, images_root)

    # Reading label files is I/O-bound, so a thread pool overlaps the waits.
    max_workers = min(32, (multiprocessing.cpu_count() or 1) + 4)
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_args = {executor.submit(process_label_file, args): args for args in args_list}
        for future in concurrent.futures.as_completed(future_to_args):
            res, err = future.result()
            if err:
                logging.warning(f"[MERGE] Warning: {err}")
                continue
            if not res:
                continue

            _image_filename, image_classes_set = res
            label_path = future_to_args[future][0]
            labels_src, images_root = label_roots[label_path]
            images_src = _find_image_for_label(label_path, labels_src, images_root)
            if images_src is None:
                # Label without a matching image file: nothing to merge.
                continue

            image_to_classes[images_src] = image_classes_set
            image_to_label_path[images_src] = label_path
            for cls in image_classes_set:
                class_to_images[cls].add(images_src)

    st.subheader("Dataset Statistics & Insights")
    st.write("### Available Images per Class:")
    for cls in unified_class_names:
        st.write(f"- **{cls}**: {len(class_to_images.get(cls, []))} images")

    fig, ax = plt.subplots(figsize=(6, 3))
    class_counts_list = [len(class_to_images.get(cls, [])) for cls in unified_class_names]
    ax.bar(unified_class_names, class_counts_list)
    ax.set_title("Class Distribution (All Datasets)")
    ax.set_xlabel("Class Names")
    ax.set_ylabel("Image Count")
    plt.xticks(rotation=45, ha='right')
    st.pyplot(fig)
    # pyplot keeps every figure alive until closed; release it after rendering.
    plt.close(fig)

    inconsistent_images = [img for img, cls_set in image_to_classes.items() if not cls_set]
    if inconsistent_images:
        st.warning(f"[MERGE] {len(inconsistent_images)} images have no valid annotations (excluded classes only).")

    selected_images = set()
    current_image_class_counts = {cls: 0 for cls in active_classes}

    all_candidate_images = []
    for cls in active_classes:
        all_candidate_images.extend(class_to_images[cls])
    random.shuffle(all_candidate_images)

    # Greedy pass: accept an image only if none of its active classes has
    # already reached its limit.
    for img_path in all_candidate_images:
        if img_path in selected_images:
            continue

        active_in_image = [c for c in image_to_classes.get(img_path, set()) if c in active_class_set]
        if any(current_image_class_counts[c] >= class_limits[c] for c in active_in_image):
            continue

        selected_images.add(img_path)
        for c in active_in_image:
            current_image_class_counts[c] += 1

    st.session_state['selected_images'] = selected_images
    st.session_state['class_counters'] = current_image_class_counts
    st.session_state['active_classes'] = active_classes
    st.session_state['class_limits'] = class_limits
    st.session_state['class_id_mapping'] = class_id_mapping
    st.session_state['class_name_mapping'] = class_name_mapping
    st.session_state['image_to_classes'] = image_to_classes
    st.session_state['image_to_label_path'] = image_to_label_path
    st.session_state['class_to_images'] = class_to_images
    st.session_state['dataset_info_list'] = dataset_info_list
    st.session_state['merged_dataset_dir'] = merged_dataset_dir

    st.session_state['class_image_counts'] = {
        cls: current_image_class_counts.get(cls, 0) for cls in unified_class_names
    }

    datasets_used = sorted({info[3] for info in st.session_state['dataset_info_list']})
    st.session_state['datasets_used'] = datasets_used

    st.write("### Selected Images Count per Class (Before Adjustment):")
    for cls in unified_class_names:
        cnt = current_image_class_counts.get(cls, 0)
        limit = class_limits.get(cls, 0)
        st.write(f"- **{cls}**: {cnt} selected (Limit: {limit})")

    st.write("### Merged Classes:")
    for old_class, new_class in class_name_mapping.items():
        if old_class != new_class:
            st.write(f"- **{old_class}** merged into **{new_class}**")

    if not selected_images:
        st.error("WARNING: No images remain after merging/excluding. Please adjust your class limits or rename logic.")
    return merged_dataset_dir
+
def adjust_selected_images():
    """
    [MERGE] Lets the user override, per active class, how many images are selected
    (one st.number_input per class inside a form) and then rebuilds the selection.

    Classification datasets are not adjustable here: the step is marked as
    confirmed immediately. On submit, the new selection and per-class counts
    are written back to st.session_state and the workflow advances to
    "Finalize Dataset".
    """
    state = st.session_state

    if state['selected_category'] == "classification":
        st.info("Manual image adjustment is not currently supported for Classification datasets. The class limits were applied during the merge step.")
        state['adjustment_confirmed'] = True
        return

    st.header("Adjust Selected Images per Class")
    st.markdown("You can manually adjust the number of selected images per class if desired.")

    requested_limits = {}
    with st.form("adjustment_form"):
        for cls in state['active_classes']:
            selected_now = state['class_counters'].get(cls, 0)
            targeted = state['class_limits'].get(cls, 0)
            available = state['total_class_images'].get(cls, 0)

            requested_limits[cls] = st.number_input(
                f"Adjust images for '{cls}' (Currently Selected: {selected_now}, Originally Targeted: {targeted})",
                min_value=0,
                max_value=available,
                value=selected_now,
                step=1,
                key=f"adjusted_count_{cls}"
            )

        st.info("The final count for each class will be adjusted to the amount you specify, up to the total available images after previous filtering.")
        submitted = st.form_submit_button("Confirm Adjustments")

    if not submitted:
        return

    active = state['active_classes']
    chosen = set()
    per_class_total = {cls: 0 for cls in active}

    shuffled_pool = list(state['image_to_classes'].keys())
    random.shuffle(shuffled_pool)

    for candidate in shuffled_pool:
        labels_here = state['image_to_classes'].get(candidate, set())

        # An image is kept as soon as ONE of its active classes still has room.
        # NOTE(review): classes that co-occur in the image are incremented too
        # and may end up above the requested value - confirm this is intended.
        has_room = any(
            c in active and per_class_total[c] < requested_limits[c]
            for c in labels_here
        )
        if not has_room:
            continue

        chosen.add(candidate)
        for c in labels_here:
            if c in active:
                per_class_total[c] += 1

    state['selected_images'] = chosen
    state['class_counters'] = per_class_total
    state['adjustment_confirmed'] = True

    for cls in active:
        state['class_image_counts'][cls] = per_class_total[cls]

    st.success("Adjustments confirmed. Proceed to finalize the merged dataset.")
    state['workflow_step'] = "Finalize Dataset"
+
def _resolve_image_origin(img_path, dataset_info_list):
    """Return ``(split_key, class_names)`` of the dataset/split that contains ``img_path``."""
    for d_loc, class_names_dataset, splits_in_dataset, _ in dataset_info_list:
        # The trailing separator makes this a directory-prefix test, so
        # 'data-1' does not match images living under 'data-10/...'.
        if not img_path.startswith(os.path.join(d_loc, '')):
            continue
        for s in splits_in_dataset:
            if img_path.startswith(os.path.join(d_loc, s, 'images', '')):
                return s, class_names_dataset
        return 'train', class_names_dataset
    return 'train', None


def finalize_merged_dataset():
    """
    [MERGE] Copies selected images and corresponding labels into a new merged dataset directory,
    respecting the user's class mappings. Creates a final 'data.yaml' consistent with only
    the included classes.

    Everything is read from ``st.session_state`` (selection, mappings, source
    datasets, output directory). Label lines are rewritten with the new class
    ids; lines of excluded/unknown classes are dropped, and an image whose
    label file ends up empty is not copied. Images without a label file are
    copied as-is. For classification datasets nothing is copied because the
    merge step already produced the final layout.
    """
    if st.session_state['selected_category'] == "classification":
        st.info("Classification dataset was finalized during the merge step. Ready for training!")
        st.session_state['dataset_location'] = st.session_state['merged_dataset_dir']
        st.session_state['dataset_finalized'] = True
        return

    st.info("[FINALIZE] Finalizing the merged dataset...")
    selected_images = st.session_state['selected_images']
    active_classes = st.session_state['active_classes']
    class_id_mapping = st.session_state['class_id_mapping']
    class_name_mapping = st.session_state['class_name_mapping']
    image_to_label_path = st.session_state['image_to_label_path']
    dataset_info_list = st.session_state['dataset_info_list']
    merged_dataset_dir = st.session_state['merged_dataset_dir']
    show_debug = st.session_state.get("show_debug", False)

    for split in ('train', 'valid', 'test'):
        try:
            os.makedirs(os.path.join(merged_dataset_dir, split, 'images'), exist_ok=True)
            os.makedirs(os.path.join(merged_dataset_dir, split, 'labels'), exist_ok=True)
        except Exception as e:
            st.error(f"[FINALIZE] Error creating split directories: {e}")
            logging.error(f"[FINALIZE] Error creating split directories: {e}")

    active_class_set = set(active_classes)
    # (split, stem) pairs already written in this run; image and label share
    # the stem, so uniqueness is tracked on the stem rather than the file name.
    used_stems = set()

    # Sorted so that collision suffixes are assigned deterministically.
    for img_path in sorted(selected_images):
        split_key, class_names_dataset = _resolve_image_origin(img_path, dataset_info_list)
        image_filename = os.path.basename(img_path)
        label_src_path = image_to_label_path.get(img_path)

        updated_lines = None  # None => no label file available for this image
        if label_src_path and os.path.exists(label_src_path):
            try:
                with open(label_src_path, 'r') as f:
                    lines = f.readlines()
            except OSError as e:
                logging.error(f"[FINALIZE] Could not read label file '{label_src_path}': {e}")
                continue

            updated_lines = []
            for line in lines:
                parts = line.split()
                if not parts:
                    continue
                try:
                    old_class_id = int(parts[0])
                except ValueError as e:
                    if show_debug:
                        st.warning(f"Error re-indexing line '{line.strip()}' for image {image_filename}: {e}")
                    logging.warning(f"Error re-indexing line '{line.strip()}' for image {image_filename}: {e}")
                    continue

                if class_names_dataset is None or not 0 <= old_class_id < len(class_names_dataset):
                    continue
                old_class_name = class_names_dataset[old_class_id]
                new_class_name = class_name_mapping.get(old_class_name, old_class_name)
                if new_class_name not in active_class_set:
                    continue
                new_class_id = class_id_mapping.get(new_class_name)
                if new_class_id is None:
                    continue
                # Keep the geometry untouched; only the class id is remapped.
                updated_lines.append(" ".join([str(new_class_id)] + parts[1:]))

            if not updated_lines:
                # Every annotation was filtered out: leave the image out too.
                continue

        stem, ext = os.path.splitext(image_filename)
        unique_stem = stem
        suffix = 1
        while (split_key, unique_stem) in used_stems:
            unique_stem = f"{stem}_{suffix}"
            suffix += 1
        used_stems.add((split_key, unique_stem))

        images_dst = os.path.join(merged_dataset_dir, split_key, 'images')
        labels_dst = os.path.join(merged_dataset_dir, split_key, 'labels')
        try:
            # Source split names other than train/valid/test have no
            # pre-created folder; make sure the targets exist.
            os.makedirs(images_dst, exist_ok=True)
            os.makedirs(labels_dst, exist_ok=True)
            shutil.copy(img_path, os.path.join(images_dst, unique_stem + ext))
            if updated_lines:
                with open(os.path.join(labels_dst, unique_stem + '.txt'), 'w') as f:
                    f.write("\n".join(updated_lines))
        except OSError as e:
            if show_debug:
                st.error(f"[FINALIZE] Error copying '{image_filename}': {e}")
            logging.error(f"[FINALIZE] Error copying '{image_filename}': {e}")

    merged_data_yaml = {
        'path': os.path.abspath(merged_dataset_dir),
        'train': 'train/images',
        'val': 'valid/images',
        'test': 'test/images',
        'nc': len(active_classes),
        'names': active_classes
    }

    kpt_shape = st.session_state.get('kpt_shape')
    if st.session_state['selected_category'] == "keypoint" and kpt_shape is not None:
        merged_data_yaml['kpt_shape'] = kpt_shape
        st.info(f"Keypoint dataset detected. Adding kpt_shape: {kpt_shape} to data.yaml")

    with open(os.path.join(merged_dataset_dir, 'data.yaml'), 'w') as f:
        yaml.safe_dump(merged_data_yaml, f)

    st.success(f"[FINALIZE] Final merged dataset is ready at '{merged_dataset_dir}'.")
    st.session_state['dataset_location'] = merged_dataset_dir
    st.session_state['dataset_finalized'] = True
+
+
+# ----------------------------------------------------------------------------
+# UTILITY: ZIP THE MERGED DATASET FOR DOWNLOAD
+# ----------------------------------------------------------------------------
def zip_directory(source_dir):
    """
    Pack every file below ``source_dir`` into an in-memory ZIP archive.

    Entries are stored with paths relative to ``source_dir``. The returned
    ``io.BytesIO`` is rewound to position 0 so it can be handed straight to a
    Streamlit download widget.
    """
    archive = io.BytesIO()
    with zipfile.ZipFile(archive, mode="w", compression=zipfile.ZIP_DEFLATED) as bundle:
        for folder, _subfolders, filenames in os.walk(source_dir):
            for name in filenames:
                absolute = os.path.join(folder, name)
                bundle.write(absolute, os.path.relpath(absolute, start=source_dir))
    archive.seek(0)
    return archive
+
+# ----------------------------------------------------------------------------
+# GITHUB & HUGGING FACE UTILITIES (UNCHANGED)
+# ----------------------------------------------------------------------------
+
def create_github_repo(username, token, repo_name, private=False):
    """
    [UPLOAD] Attempts to create a new GitHub repo for the authenticated user.

    Args:
        username: accepted for interface compatibility; it is not used because
            the request goes to the ``/user/repos`` endpoint, which is scoped
            by the token.
        token: GitHub token sent in the ``Authorization`` header.
        repo_name: name of the repository to create.
        private: whether the repository should be private.

    Returns:
        ``(status_code, payload)`` where ``payload`` is the decoded JSON body,
        or ``{"message": <raw text>}`` when the body is not valid JSON.

    Raises:
        requests.RequestException: on network failure or timeout.
    """
    url = "https://api.github.com/user/repos"
    headers = {
        "Authorization": f"token {token}",
        "Accept": "application/vnd.github.v3+json",
    }
    data = {"name": repo_name, "private": private}
    # Without a timeout, requests waits forever on a stalled connection.
    response = requests.post(url, headers=headers, json=data, timeout=30)
    try:
        payload = response.json()
    except ValueError:
        # Proxies/gateways may answer with HTML or an empty body.
        payload = {"message": response.text}
    return response.status_code, payload
+
def get_github_repo_contents(username, repo_name, token):
    """
    Retrieves the contents of the root of a GitHub repository.

    Args:
        username: owner of the repository.
        repo_name: repository name.
        token: GitHub token sent in the ``Authorization`` header.

    Returns:
        ``(status_code, payload)`` where ``payload`` is the decoded JSON body,
        or ``{"message": <raw text>}`` when the body is not valid JSON.

    Raises:
        requests.RequestException: on network failure or timeout.
    """
    url = f"https://api.github.com/repos/{username}/{repo_name}/contents/"
    headers = {
        "Authorization": f"token {token}",
        "Accept": "application/vnd.github.v3+json",
    }
    # Without a timeout, requests waits forever on a stalled connection.
    response = requests.get(url, headers=headers, timeout=30)
    try:
        payload = response.json()
    except ValueError:
        # Proxies/gateways may answer with HTML or an empty body.
        payload = {"message": response.text}
    return response.status_code, payload
+
def list_github_files(owner, repo, path, token=None):
    """
    Lists the files in a specified GitHub repository folder.

    Args:
        owner: repository owner.
        repo: repository name.
        path: folder path inside the repository.
        token: optional GitHub token for authenticated requests.

    Returns:
        A list of file names. Empty when the folder does not exist (404), when
        ``path`` is not a folder, or when the request fails; failures other
        than 404 are reported through ``st.error``.
    """
    api_url = f"https://api.github.com/repos/{owner}/{repo}/contents/{path}"
    headers = {}
    if token:
        headers['Authorization'] = f'token {token}'

    try:
        # Without a timeout, requests waits forever on a stalled connection.
        response = requests.get(api_url, headers=headers, timeout=30)
    except requests.RequestException as e:
        st.error(f"GitHub API error: {e}")
        return []

    if response.status_code == 404:
        return []
    if response.status_code != 200:
        st.error(f"GitHub API error: {response.status_code}")
        return []

    try:
        contents = response.json()
    except ValueError:
        st.error("GitHub API error: response was not valid JSON")
        return []

    # The contents endpoint returns a list for folders but a single object
    # when the path is a file; only a list can be enumerated.
    if not isinstance(contents, list):
        return []
    return [
        item['name'] for item in contents
        if isinstance(item, dict) and item.get('type') == 'file' and 'name' in item
    ]
+
def fetch_github_file(owner, repo, path, token=None):
    """
    Fetches the raw content of a file from GitHub. Returns the file content as text.
    NOTE: The example templates are for *detection tasks* only.

    Args:
        owner: repository owner.
        repo: repository name.
        path: file path inside the repository.
        token: optional GitHub token for authenticated requests.

    Returns:
        The file content as text, or ``None`` when the request fails; failures
        are reported through ``st.error``.
    """
    api_url = f"https://api.github.com/repos/{owner}/{repo}/contents/{path}"
    # The 'raw' media type makes the API return the file body instead of JSON metadata.
    headers = {
        "Accept": "application/vnd.github.v3.raw"
    }
    if token:
        headers['Authorization'] = f'token {token}'

    try:
        # Without a timeout, requests waits forever on a stalled connection.
        response = requests.get(api_url, headers=headers, timeout=30)
    except requests.RequestException as e:
        st.error(f"Failed to fetch {path} from GitHub: {e}")
        return None

    if response.status_code == 200:
        return response.text

    st.error(f"Failed to fetch {path} from GitHub. Status Code: {response.status_code}")
    return None
+
def customize_template(template_content, model_url, selected_category, metrics_data, class_id_mapping, datasets_used, class_image_counts):
    """
    Fill the ``{{PLACEHOLDER}}`` tokens of a template with the model URL and training details.

    Values come from the arguments and from ``st.session_state`` (architecture,
    hyper-parameters, augmentation settings). Only the headline metrics that
    belong to ``selected_category`` are filled in; the others stay "N/A".
    Precision, recall and F1 are filled in whenever they have data.

    Returns:
        The template text with every known placeholder substituted.
    """
    state = st.session_state
    missing = "N/A"

    def _latest(metric_key):
        # Last recorded value of a metric series, formatted to 4 decimals.
        history = metrics_data.get(metric_key)
        if history and len(history) > 0:
            return f"{history[-1]:.4f}"
        return missing

    # Headline metrics per task: (metrics_data key, label in the summary block).
    box_metrics = (('mAP50', 'mAP@0.5'), ('mAP50_95', 'mAP@0.5:0.95'))
    headline_by_task = {
        'detection': box_metrics,
        'segmentation': box_metrics,
        'obb': box_metrics,
        'classification': (('top1_acc', 'Top-1 Accuracy'), ('top5_acc', 'Top-5 Accuracy')),
        'keypoint': (('mAPpose50', 'mAPpose@0.5'), ('mAPpose50_95', 'mAPpose@0.5:0.95')),
    }
    headline = headline_by_task.get(selected_category, ())

    finals = dict.fromkeys(
        ('mAP50', 'mAP50_95', 'top1_acc', 'top5_acc', 'mAPpose50', 'mAPpose50_95'),
        missing,
    )
    for metric_key, _label in headline:
        finals[metric_key] = _latest(metric_key)

    # Precision / recall / F1 are read for every task type.
    precision_txt = _latest('precision')
    recall_txt = _latest('recall')
    f1_txt = _latest('F1')

    placeholders = {
        "{{MODEL_URL}}": model_url,
        "{{MODEL_ARCH}}": state.get("selected_architecture", "N/A"),
        "{{EPOCHS}}": str(state.get("epochs", "N/A")),
        "{{BATCH_SIZE}}": str(state.get("batch_size", "N/A")),
        "{{OPTIMIZER}}": state.get("optimizer", "N/A"),
        "{{LEARNING_RATE}}": f"{state.get('learning_rate', 0.0):.5f}",
        "{{DATA_AUG}}": state.get("data_augmentation_level", "N/A"),
        "{{MODEL_NAME}}": state.get("custom_model_name", "N/A"),
        "{{TASK_TYPE}}": selected_category.capitalize(),
        "{{LRF}}": f"{state.get('lrf', 0.0):.3f}",
        "{{WARMUP_EPOCHS}}": str(state.get('warmup_epochs', 'N/A')),
        "{{HSV_H}}": f"{state.get('hsv_h', 0.0):.3f}",
        "{{HSV_S}}": f"{state.get('hsv_s', 0.0):.2f}",
        "{{HSV_V}}": f"{state.get('hsv_v', 0.0):.2f}",
        "{{FINAL_MAP50}}": finals['mAP50'],
        "{{FINAL_MAP5095}}": finals['mAP50_95'],
        "{{FINAL_TOP1_ACC}}": finals['top1_acc'],
        "{{FINAL_TOP5_ACC}}": finals['top5_acc'],
        "{{FINAL_MAPPOSE50}}": finals['mAPpose50'],
        "{{FINAL_MAPPOSE5095}}": finals['mAPpose50_95'],
        "{{FINAL_PRECISION}}": precision_txt,
        "{{FINAL_RECALL}}": recall_txt,
        "{{FINAL_F1}}": f1_txt,
    }

    # Human-readable summary: task headline metrics first, then the generic
    # ones that actually have data.
    summary_lines = [f"- **{label}:** {finals[metric_key]}" for metric_key, label in headline]
    for label, value in (("Precision", precision_txt), ("Recall", recall_txt), ("F1 Score", f1_txt)):
        if value != missing:
            summary_lines.append(f"- **{label}:** {value}")
    placeholders["{{FINAL_METRICS}}"] = (
        "\n".join(summary_lines)
        if summary_lines
        else "N/A (Metrics not available for this task type or not yet trained)"
    )

    frozen = state.get('freeze_layers', 0)
    if frozen > 0:
        placeholders["{{FREEZE_LAYERS}}"] = f"{frozen} layers frozen"
        placeholders["{{FREEZE_LAYERS_TEXT}}"] = f"The initial {frozen} layers of the model were frozen during training to leverage transfer learning."
    else:
        placeholders["{{FREEZE_LAYERS}}"] = "None"
        placeholders["{{FREEZE_LAYERS_TEXT}}"] = ""

    if state.get('cache_dataset', False):
        placeholders["{{CACHE_DATASET}}"] = "Enabled (RAM)"
        placeholders["{{CACHE_DATASET_TEXT}}"] = "The dataset was cached in RAM during training to speed up I/O operations."
    else:
        placeholders["{{CACHE_DATASET}}"] = "Disabled"
        placeholders["{{CACHE_DATASET_TEXT}}"] = "Dataset caching was disabled during training."

    if class_id_mapping:
        by_id = sorted(class_id_mapping.items(), key=lambda pair: pair[1])
        id_rows = "".join(f"| {cls_id} | {cls_name} |\n" for cls_name, cls_id in by_id)
        placeholders["{{CLASS_IDS}}"] = "| Class ID | Class Name |\n|----------|------------|\n" + id_rows
    else:
        placeholders["{{CLASS_IDS}}"] = "No class IDs available (or model not yet trained with class mapping)."

    placeholders["{{DATASETS_USED}}"] = (
        "\n".join([f"- {ds}" for ds in datasets_used])
        if datasets_used
        else "N/A (No datasets used yet or not tracked)."
    )

    if class_image_counts:
        instance_tasks = ["detection", "segmentation", "obb", "keypoint"]
        count_label = "Images with Instances" if selected_category in instance_tasks else "Images"
        count_rows = "".join(f"| {cls} | {count} |\n" for cls, count in class_image_counts.items())
        placeholders["{{CLASS_IMAGE_COUNTS}}"] = (
            f"| Class Name | {count_label} Count |\n|------------|-------------|\n" + count_rows
        )
    else:
        placeholders["{{CLASS_IMAGE_COUNTS}}"] = "No class image counts available (or model not yet trained)."

    rendered = template_content
    for token, replacement in placeholders.items():
        rendered = rendered.replace(token, str(replacement))
    return rendered
+
def get_download_link(content, filename, mime):
    """
    Build an HTML anchor that downloads ``content`` as ``filename``.

    The payload is embedded in the link itself as a base64 ``data:`` URI, so
    no server-side file is required.

    Args:
        content: Text (encoded as UTF-8) or raw bytes to offer for download.
        filename: Name suggested to the browser for the saved file.
        mime: MIME type placed in the data URI (e.g. ``"text/plain"``).

    Returns:
        An ``<a>`` element as a string. Render it with HTML enabled
        (for example ``st.markdown(..., unsafe_allow_html=True)``).
    """
    import html  # local import: only this helper needs it

    raw = content if isinstance(content, (bytes, bytearray)) else content.encode("utf-8")
    b64 = base64.b64encode(raw).decode("ascii")
    # Escape caller-supplied text so it cannot break out of the attributes.
    safe_name = html.escape(str(filename), quote=True)
    safe_mime = html.escape(str(mime), quote=True)
    return (
        f'<a href="data:{safe_mime};base64,{b64}" '
        f'download="{safe_name}">Download {safe_name}</a>'
    )
+
def generate_zip(generated_examples):
    """
    Pack a ``{member name: content}`` mapping into an in-memory ZIP archive.

    Each key becomes an archive member whose data is the matching value.
    The archive is DEFLATE-compressed and returned as bytes.
    """
    archive_stream = io.BytesIO()
    with zipfile.ZipFile(archive_stream, mode="w", compression=zipfile.ZIP_DEFLATED) as archive:
        for member_name in generated_examples:
            archive.writestr(member_name, generated_examples[member_name])
    # The archive is closed (central directory written) before the bytes are read.
    return archive_stream.getvalue()
+
def detect_model_location():
    """
    Classify the model URL in session state as Hugging Face or GitHub.

    Reads ``st.session_state['model_url']`` and stores ``'huggingface'``,
    ``'github'`` or ``None`` in ``st.session_state['model_location']``.
    A warning is shown for an empty URL and an error for an unsupported host.
    """
    # Tolerate a missing key as well as an explicit None value.
    model_url = (st.session_state.get('model_url') or "").strip()
    if not model_url:
        st.warning("[UPLOAD] Model URL is empty. Please provide a valid URL.")
        st.session_state['model_location'] = None
        return

    # urlparse only fills in the host when a scheme is present, so
    # "huggingface.co/user/model" needs one added before parsing.
    if "://" not in model_url:
        model_url = f"https://{model_url}"
    host = (urlparse(model_url).hostname or "").lower()

    # Match the domain itself or a subdomain of it; a plain substring test
    # would also accept hosts such as "huggingface.co.example.org".
    for domain, location in (("huggingface.co", "huggingface"), ("github.com", "github")):
        if host == domain or host.endswith("." + domain):
            st.session_state['model_location'] = location
            return

    st.error("[UPLOAD] Unsupported model URL. Please provide a Huggingface or GitHub URL.")
    st.session_state['model_location'] = None
+
# ----------------------------------------------------------------------------
# MAIN APP LOGIC
# ----------------------------------------------------------------------------
# Page title, workflow-step selector, and the remote JSON resources
# (optimizers, recommended datasets, presets) used further down the script.

st.title("Rolo: Roboflow & Ultralytics Combined Model Training Dashboard")

steps = [
    "Prepare Datasets",
    "Class Selection",
    "Merge & Adjust",
    "Finalize Dataset",
    "Train Model",
    "Upload Model",
    "Generate Examples",
]

# Pre-select the stored step; a missing or unknown value falls back to the first one.
current_workflow_step_index = (
    steps.index(st.session_state['workflow_step'])
    if st.session_state.get('workflow_step') in steps
    else 0
)

st.session_state['workflow_step'] = st.radio(
    "Workflow Steps", steps, index=current_workflow_step_index
)

# Optimizer catalogue: optimizer name -> description.
opt_json_url = "https://raw.githubusercontent.com/Wuhpondiscord/models/main/optimizers.json"
optimizers_dict = load_optimizers(opt_json_url)
if not optimizers_dict:
    st.warning("Could not load external optimizers.json. Fallback to default if needed.")

if optimizers_dict:
    available_optimizers = list(optimizers_dict)
else:
    available_optimizers = ["Adam", "SGD", "RMSProp"]

# Recommended datasets offered in the "Prepare Datasets" step.
recommended_datasets_url = "https://raw.githubusercontent.com/Wuhpondiscord/models/main/recommended_datasets.json"
recommended_datasets = load_recommended_datasets(recommended_datasets_url)
if not recommended_datasets:
    st.warning("Could not load external recommended_datasets.json. No recommended detection datasets available.")

# Named presets selectable from the sidebar.
presets_url = "https://raw.githubusercontent.com/Wuhpondiscord/models/main/presets.json"
presets = load_presets(presets_url)
+
# ----------------------------------------------------------------------------
# SIDEBAR: SAVE / LOAD CONFIG
# ----------------------------------------------------------------------------
# Export the current settings as JSON, or restore them from an uploaded profile.

with st.sidebar.expander("Save / Load Configuration"):
    if st.button("Save Current Configuration"):
        # Setting name -> value exported when the setting is not in session state.
        config_defaults = {
            "rf_api_key": "",
            "epochs": 150,
            "batch_size": 32,
            "img_size": 640,
            "learning_rate": 0.0005,
            "optimizer": "Adam",
            "data_augmentation_level": "Moderate",
            "custom_model_name": "rolo_trained_model",
            "selected_architecture": "",
            "selected_category": "",
            "pre_trained_model": "",
            "custom_model_path": "",
            "show_debug": False,
            "use_half_precision": True,
            "early_stop_patience": 10,
            "lrf": 0.01,
            "warmup_epochs": 3,
            "cache_dataset": False,
            "freeze_layers": 0,
            "hsv_h": 0.015,
            "hsv_s": 0.7,
            "hsv_v": 0.4,
            "train_ratio": 0.7,
            "val_ratio": 0.2,
            "test_ratio": 0.1,
        }
        config = {
            name: st.session_state.get(name, fallback)
            for name, fallback in config_defaults.items()
        }
        config_json = json.dumps(config, indent=4)

        st.download_button(
            label="Download Current Configuration",
            data=config_json,
            file_name="rolo_config.json",
            mime="application/json",
        )

    uploaded_profile = st.file_uploader("Load Configuration Profile", type=["json"])
    if uploaded_profile is not None:
        try:
            config = json.load(uploaded_profile)
            # Only settings that already exist in session state are restored.
            for key, value in config.items():
                if key not in st.session_state:
                    continue
                st.session_state[key] = value
            st.success("Configuration loaded successfully.")
        except Exception as e:
            st.error(f"Failed to load configuration: {e}")
+
# ----------------------------------------------------------------------------
# SIDEBAR: PRESETS
# ----------------------------------------------------------------------------
# Selecting a preset copies every one of its entries into session state.

st.sidebar.header("Presets")
if not presets:
    st.sidebar.warning("Could not load presets from the external JSON.")
else:
    preset_names = list(presets.keys())
    selected_preset = st.sidebar.selectbox("Load a Preset", ["None", *preset_names])
    if selected_preset != "None":
        preset = presets[selected_preset]
        for key in preset:
            value = preset[key]
            st.session_state[key] = value
        st.sidebar.success(f"Preset '{selected_preset}' loaded successfully.")
+
# ----------------------------------------------------------------------------
# SIDEBAR: ROBOFLOW CONFIG
# ----------------------------------------------------------------------------
# Masked API-key field; the entered value is stored back in session state.

st.sidebar.header("Roboflow Configuration")
st.session_state["rf_api_key"] = rf_api_key = st.sidebar.text_input(
    "Enter your Roboflow API Key",
    value=st.session_state.get("rf_api_key", ""),
    type="password",
    help="Your secret Roboflow API Key.",
)
+
# ----------------------------------------------------------------------------
# SIDEBAR: TRAINING SETTINGS
# ----------------------------------------------------------------------------
# Each control is seeded from session state and its value is written back.

st.sidebar.header("Training Settings")
st.session_state["epochs"] = epochs = st.sidebar.slider(
    "Number of Epochs",
    min_value=1,
    max_value=500,
    value=st.session_state.get("epochs", 150),
    help="How many epochs to train for?",
)

batch_size_options = [8, 16, 32, 64, 128]
default_batch_size = st.session_state.get("batch_size", 32)
# Position 2 (batch size 32) is used when the stored value is not an offered option.
try:
    batch_size_index = batch_size_options.index(default_batch_size)
except ValueError:
    batch_size_index = 2
st.session_state["batch_size"] = batch_size = st.sidebar.selectbox(
    "Batch Size",
    options=batch_size_options,
    index=batch_size_index,
    help="Batch size per training step.",
)

st.session_state["img_size"] = img_size = st.sidebar.number_input(
    "Image Size (e.g., 640)",
    min_value=64,
    max_value=1280,
    value=st.session_state.get("img_size", 640),
    step=32,
    help="Resolution of training images. Larger sizes can improve small object detection.",
)

st.session_state["learning_rate"] = learning_rate = st.sidebar.number_input(
    "Learning Rate (lr0)",
    min_value=1e-6,
    max_value=0.1,
    value=st.session_state.get("learning_rate", 0.0005),
    format="%.5f",
    step=1e-5,
    help="Initial learning rate.",
)
+
# Optimizer selector plus a description line underneath it.
# `optimizers_dict` may be empty/None when the remote catalogue failed to load
# (the script then uses a built-in optimizer list), so look descriptions up in
# a guaranteed mapping instead of calling `.get` on a possibly-None object.
optimizer_descriptions = optimizers_dict or {}
stored_optimizer = st.session_state.get("optimizer", available_optimizers[0])

optimizer = st.sidebar.selectbox(
    "Optimizer",
    options=available_optimizers,
    # Fall back to the first option when the stored optimizer is not offered.
    index=available_optimizers.index(stored_optimizer)
    if stored_optimizer in available_optimizers else 0,
    help=optimizer_descriptions.get(stored_optimizer, "No description available."),
)
st.session_state["optimizer"] = optimizer
st.sidebar.markdown(
    f"**Description**: {optimizer_descriptions.get(optimizer, 'No description available.')}",
    unsafe_allow_html=True,
)
+
# Augmentation preset; position 2 ("Moderate") is used for unknown stored values.
data_augmentation_options = ["None", "Basic", "Moderate", "Advanced"]
try:
    augmentation_index = data_augmentation_options.index(
        st.session_state.get("data_augmentation_level", "Moderate")
    )
except ValueError:
    augmentation_index = 2
st.session_state["data_augmentation_level"] = data_augmentation_level = st.sidebar.selectbox(
    "Data Augmentation Level",
    data_augmentation_options,
    index=augmentation_index,
    help="Choose a preset augmentation level. More aggressive levels can help with robustness but might increase training time.",
)

st.sidebar.subheader("Fine-tune HSV Augmentations")
# (session key, label, max, default, step, display format, help) per slider.
for hsv_key, hsv_label, hsv_max, hsv_default, hsv_step, hsv_format, hsv_help in (
    ("hsv_h", "HSV Hue Augmentation", 0.1, 0.015, 0.001, "%.3f", "Hue variation for augmentation."),
    ("hsv_s", "HSV Saturation Augmentation", 1.0, 0.7, 0.01, "%.2f", "Saturation variation for augmentation."),
    ("hsv_v", "HSV Value Augmentation", 1.0, 0.4, 0.01, "%.2f", "Value (brightness) variation for augmentation."),
):
    st.session_state[hsv_key] = st.sidebar.slider(
        hsv_label,
        min_value=0.0,
        max_value=hsv_max,
        value=st.session_state.get(hsv_key, hsv_default),
        step=hsv_step,
        format=hsv_format,
        help=hsv_help,
    )


st.session_state["custom_model_name"] = custom_model_name = st.sidebar.text_input(
    "Custom Model Name",
    value=st.session_state.get("custom_model_name", "rolo_trained_model"),
    help="Name of the run, used for the training folder (e.g., runs/train/my_custom_model_name).",
)

st.session_state["use_half_precision"] = use_half_precision = st.sidebar.checkbox(
    "Enable Half-Precision (FP16) on GPU",
    value=st.session_state.get("use_half_precision", True),
    help="Requires a GPU (CUDA) for acceleration. If enabled on CPU, it will be ignored.",
)

st.session_state["early_stop_patience"] = early_stop_patience = st.sidebar.number_input(
    "Early Stopping Patience (epochs)",
    min_value=1,
    max_value=100,
    value=st.session_state.get("early_stop_patience", 10),
    help="Stop training if no improvement in validation loss for this many epochs.",
)

st.sidebar.subheader("Learning Rate Schedule")
lrf_default = st.session_state.get("lrf", 0.01)
st.session_state["lrf"] = st.sidebar.number_input(
    "Final LR Factor (lrf)",
    min_value=0.000,
    max_value=1.0,
    value=lrf_default,
    step=0.001,
    format="%.3f",
    help="Learning rate final multiplier (lr = lr0 * lrf). Affects learning rate schedule shape.",
)
warmup_default = st.session_state.get("warmup_epochs", 3)
st.session_state["warmup_epochs"] = st.sidebar.number_input(
    "Warmup Epochs",
    min_value=0,
    max_value=20,
    value=warmup_default,
    step=1,
    help="Number of epochs for learning rate warmup. Gradually increases LR at start.",
)

st.sidebar.subheader("Transfer Learning Options")
freeze_default = st.session_state.get("freeze_layers", 0)
st.session_state["freeze_layers"] = st.sidebar.number_input(
    "Freeze Layers",
    min_value=0,
    max_value=30,
    value=freeze_default,
    step=1,
    help="Number of initial layers to freeze during training. Reduces computation and helps prevent overfitting on small datasets. Set to 0 for no freezing.",
)

st.sidebar.subheader("Dataset Caching")
cache_default = st.session_state.get("cache_dataset", False)
st.session_state["cache_dataset"] = st.sidebar.checkbox(
    "Cache Dataset for Training",
    value=cache_default,
    help="If checked, Ultralytics will cache preprocessed images. Set to True for RAM caching. Can significantly speed up training on large datasets but requires more memory.",
)
+
# Sidebar: architecture -> task category -> pre-trained weights, driven by the
# remote model configuration. The script stops early when nothing selectable
# exists, because later steps assume a valid category is set.
st.sidebar.header("Model Architecture & Category")
model_config_url = "https://raw.githubusercontent.com/Wuhpondiscord/models/main/model_config.json"
model_config = cache_load_model_configs(model_config_url)
if not model_config:
    st.warning("Could not load model configuration. Please check your internet connection.")
    st.stop()

architecture_options = list(model_config.keys())
stored_architecture = st.session_state.get("selected_architecture", architecture_options[0])
selected_architecture = st.sidebar.selectbox(
    "Select Model Architecture",
    options=architecture_options,
    index=architecture_options.index(stored_architecture)
    if stored_architecture in architecture_options else 0,
)
st.session_state["selected_architecture"] = selected_architecture

all_cats_dict = model_config[selected_architecture]["categories"]
# Only categories that list at least one pre-trained model are selectable.
valid_categories = [cat for cat, models in all_cats_dict.items() if len(models) > 0]
if not valid_categories:
    # Without this guard the category selectbox returns None and the default
    # lookup below indexes an empty list.
    st.error(
        f"Architecture '{selected_architecture}' has no task category with "
        "pre-trained models. Please select a different architecture."
    )
    st.stop()

stored_category = st.session_state.get("selected_category", valid_categories[0])
selected_category = st.sidebar.selectbox(
    "Select Model Category (Task)",
    valid_categories,
    index=valid_categories.index(stored_category)
    if stored_category in valid_categories else 0,
)
st.session_state["selected_category"] = selected_category

pre_trained_models_info = model_config[selected_architecture]["categories"].get(selected_category, [])
available_pretrained_filenames = [m["filename"] for m in pre_trained_models_info]
pre_trained_options = available_pretrained_filenames + ["Custom"]

# pre_trained_options always contains "Custom", so its first entry is a safe
# default even if the filename list were empty.
stored_pre_trained = st.session_state.get("pre_trained_model", pre_trained_options[0])
pre_trained_model = st.sidebar.selectbox(
    "Select Pre-Trained Model",
    options=pre_trained_options,
    index=pre_trained_options.index(stored_pre_trained)
    if stored_pre_trained in pre_trained_options else 0,
    help="Choose a built-in pretrained model or 'Custom'.",
)
st.session_state["pre_trained_model"] = pre_trained_model

if pre_trained_model == "Custom":
    custom_model_path = st.sidebar.text_input(
        "Enter Custom Model Path",
        value=st.session_state.get("custom_model_path", ""),
        help="Local path to a custom .pt file.",
    )
    st.session_state["custom_model_path"] = custom_model_path

show_debug = st.sidebar.checkbox(
    "Show Debug Messages",
    value=st.session_state.get("show_debug", False),
    help="Display detailed debugging info in the UI.",
)
st.session_state["show_debug"] = show_debug

st.sidebar.markdown("---")
st.sidebar.markdown("Developed by [wuhp](https://huggingface.co/wuhp).")
+
+# ----------------------------------------------------------------------------
+# MAIN WORKFLOW STEPS
+# ----------------------------------------------------------------------------
+
def allow_dataset_download():
    """
    Offer the merged dataset for download.

    Reads ``merged_dataset_dir`` from session state. When that directory
    exists, a ZIP of the whole directory is offered, followed by a second
    button for ``data.yaml`` if the file is present. Renders nothing otherwise.
    """
    dataset_dir = st.session_state.get('merged_dataset_dir', None)
    if not dataset_dir or not os.path.exists(dataset_dir):
        return

    st.write("## Download Merged Dataset")
    archive_buffer = zip_directory(dataset_dir)
    st.download_button(
        label="Download Entire Merged Dataset (ZIP)",
        data=archive_buffer.getvalue(),
        file_name="merged_dataset.zip",
        mime="application/octet-stream",
    )

    yaml_file = os.path.join(dataset_dir, "data.yaml")
    if not os.path.exists(yaml_file):
        return

    with open(yaml_file, "r") as handle:
        yaml_text = handle.read()
    st.download_button(
        label="Download data.yaml",
        data=yaml_text,
        file_name="data.yaml",
        mime="text/yaml",
    )
+
+# ----------------------------------------------------------------------------
+# 1) Prepare Datasets
+# 2) Class Selection
+# 3) Merge & Adjust
+# 4) Finalize Dataset
+# 5) Train Model
+# 6) Upload Model
+# 7) Generate Examples
+# ----------------------------------------------------------------------------
+
# Step 1: load recommended (detection-only) datasets or custom Roboflow
# datasets, then advance the workflow to "Class Selection".
if st.session_state['workflow_step'] == "Prepare Datasets":
    st.subheader("Step 1: Prepare Datasets")
    st.markdown("Choose to load a recommended dataset (for **detection only**) or add your own Roboflow dataset URLs.")
    st.info(f"Current selected task: **{st.session_state['selected_category'].capitalize()}**")

    tabs = st.tabs(["Recommended Datasets (Detection Only)", "Custom Roboflow Datasets"])

    with tabs[0]:
        if st.session_state['selected_category'] != "detection":
            st.warning("Recommended datasets are currently only available for **Detection** task. Please select 'Detection' in the sidebar or use 'Custom Roboflow Datasets' for other tasks.")

        if recommended_datasets and st.session_state['selected_category'] == "detection":
            st.markdown("### Select Recommended Datasets (for Detection)")
            selected_recommended = st.multiselect(
                "Choose datasets to load:",
                options=[dataset["name"] for dataset in recommended_datasets],
                # Options are dataset names; the description is what is displayed.
                format_func=lambda x: next(item for item in recommended_datasets if item["name"] == x)["description"]
            )

            if st.button("Load Selected Recommended Datasets"):
                if not selected_recommended:
                    st.error("Please select at least one recommended dataset.")
                else:
                    dataset_info_list = []
                    progress_bar = st.progress(0)
                    for idx, dataset_name in enumerate(selected_recommended):
                        dataset = next(item for item in recommended_datasets if item["name"] == dataset_name)
                        try:
                            source = dataset["source"]
                            url = dataset["url"]
                            if source in ["GitHub", "Ultralytics"]:
                                github_zip_url = url
                                dloc, class_names, splits = download_and_prepare_github_dataset(
                                    github_zip_url,
                                    dataset_name
                                )
                            elif source == "Roboflow":
                                workspace, project, version = parse_roboflow_url(url)
                                if not workspace or not project:
                                    st.warning(f"Roboflow URL '{url}' is invalid.")
                                    continue

                                if not version:
                                    version = get_latest_version(Roboflow(api_key=rf_api_key), workspace, project)
                                    if not version:
                                        st.warning(f"Could not retrieve latest version for '{project}'. Skipping.")
                                        continue

                                st.info(f"[DATASET PREP] Downloading Roboflow dataset '{project}' (v{version})...")
                                dloc, class_names, splits = download_and_prepare_roboflow_dataset(
                                    rf_api_key,
                                    workspace,
                                    project,
                                    version,
                                    st.session_state["selected_architecture"],
                                    st.session_state["selected_category"]
                                )
                            else:
                                st.error(f"Unknown source '{source}' for dataset '{dataset_name}'.")
                                continue

                            if dloc:
                                dataset_info_list.append((dloc, class_names, splits, dataset_name))
                                st.success(f"Loaded recommended dataset '{dataset_name}'.")
                        except Exception as e:
                            st.error(f"Failed to load dataset '{dataset_name}': {e}")

                        progress_bar.progress((idx + 1) / len(selected_recommended))

                    if dataset_info_list:
                        st.session_state['dataset_info_list'].extend(dataset_info_list)
                        st.session_state['dataset_prepared'] = True
                        st.success("Selected recommended detection datasets have been loaded successfully.")
                        st.session_state['workflow_step'] = "Class Selection"
        elif not recommended_datasets:
            st.warning("No recommended detection datasets available. Please check the JSON file.")

    with tabs[1]:
        st.markdown("### Add Custom Roboflow Datasets")
        st.markdown("Upload a `.txt` file containing Roboflow dataset URLs, one per line.")

        custom_rf_uploaded_file = st.file_uploader(
            "Upload a .txt file with Roboflow dataset URLs",
            type=["txt"],
            help="Example: https://universe.roboflow.com/workspace/project/version..."
        )

        st.markdown("Split ratios for your dataset (will be applied *before* merging):")
        # The stored ratios are floats that the "Finalize Dataset" step can set
        # anywhere in 0..1. Round (not truncate) to whole percent and clamp to
        # each slider's range, otherwise Streamlit raises for an out-of-range
        # default value.
        col_split_ratio1, col_split_ratio2, col_split_ratio3 = st.columns(3)
        with col_split_ratio1:
            train_percent = min(90, max(50, int(round(st.session_state.get('train_ratio', 0.7) * 100))))
            st.session_state['train_ratio'] = st.slider("Train Split (%)", min_value=50, max_value=90, value=train_percent, step=5) / 100.0
        with col_split_ratio2:
            val_percent = min(20, max(5, int(round(st.session_state.get('val_ratio', 0.2) * 100))))
            st.session_state['val_ratio'] = st.slider("Validation Split (%)", min_value=5, max_value=20, value=val_percent, step=5) / 100.0
        with col_split_ratio3:
            test_percent = min(20, max(0, int(round(st.session_state.get('test_ratio', 0.1) * 100))))
            st.session_state['test_ratio'] = st.slider("Test Split (%)", min_value=0, max_value=20, value=test_percent, step=5) / 100.0

        sum_ratios = st.session_state['train_ratio'] + st.session_state['val_ratio'] + st.session_state['test_ratio']
        if not (0.99 <= sum_ratios <= 1.01):
            st.warning(f"Warning: Split ratios sum to {sum_ratios*100:.1f}%. They should sum to 100% for balanced splitting.")


        if st.button("Load Custom Roboflow Datasets"):
            if not rf_api_key:
                st.error("Please enter your Roboflow API Key.")
            elif not custom_rf_uploaded_file:
                st.error("Please upload a `.txt` file containing Roboflow dataset URLs.")
            else:
                try:
                    content = custom_rf_uploaded_file.read().decode("utf-8")
                    # One URL per line; blank lines are ignored.
                    urls = [line.strip() for line in content.splitlines() if line.strip()]
                    if not urls:
                        st.error("The uploaded file is empty or contains invalid URLs.")
                        st.stop()
                except Exception as e:
                    st.error(f"Failed to read the uploaded file: {e}")
                    st.stop()

                dataset_info_list = []
                progress_bar = st.progress(0)
                for idx, url in enumerate(urls):
                    # As in the recommended-datasets loop, one failing dataset
                    # is reported and skipped instead of aborting the batch.
                    try:
                        workspace, project, version = parse_roboflow_url(url)
                        if not workspace or not project:
                            st.warning(f"URL '{url}' is invalid.")
                            continue

                        if not version:
                            version = get_latest_version(Roboflow(api_key=rf_api_key), workspace, project)
                            if not version:
                                st.warning(f"Could not retrieve latest version for '{project}'. Skipping.")
                                continue

                        st.info(f"[DATASET PREP] Downloading dataset '{project}' (v{version}) from workspace '{workspace}'...")
                        dloc, class_names, splits = download_and_prepare_roboflow_dataset(
                            rf_api_key,
                            workspace,
                            project,
                            version,
                            st.session_state["selected_architecture"],
                            st.session_state["selected_category"]
                        )
                        if dloc:
                            dataset_info_list.append((dloc, class_names, splits, f"{project}_v{version}"))
                            st.success(f"Dataset '{project}' (v{version}) prepared at '{dloc}'")
                    except Exception as e:
                        st.error(f"Failed to load dataset from '{url}': {e}")
                        logging.error(f"Failed to load dataset from '{url}': {e}")
                    progress_bar.progress((idx + 1) / len(urls))

                if not dataset_info_list:
                    st.error("No datasets were prepared successfully. Please check the URLs and try again.")
                else:
                    st.session_state['dataset_info_list'].extend(dataset_info_list)
                    st.session_state['dataset_prepared'] = True
                    st.success("All custom Roboflow datasets have been prepared successfully.")
                    st.session_state['workflow_step'] = "Class Selection"
+
+elif st.session_state['workflow_step'] == "Class Selection":
+ if not st.session_state.get('dataset_prepared', False):
+ st.warning("Datasets not prepared yet. Go to 'Prepare Datasets' first.")
+ else:
+ st.subheader("Step 2: Class Selection, Renaming, and Removal")
+ st.info(f"Current selected task: **{st.session_state['selected_category'].capitalize()}**")
+
+ class_name_mapping = st.session_state['class_name_mapping']
+ class_image_counts_pre_merge = gather_class_counts(
+ st.session_state['dataset_info_list'],
+ class_name_mapping,
+ st.session_state['selected_category']
+ )
+ st.session_state['class_image_counts'] = class_image_counts_pre_merge
+
+ all_class_names_raw = []
+ for _, class_names, _, _ in st.session_state['dataset_info_list']:
+ all_class_names_raw.extend(class_names)
+ all_class_names_raw = sorted(set(all_class_names_raw))
+
+ merged_names = set()
+ for class_name in all_class_names_raw:
+ new_name = class_name_mapping.get(class_name, class_name)
+ merged_names.add(new_name)
+ merged_names = sorted(list(merged_names))
+
+ st.write("### Current Merged Class Distribution:")
+ if st.session_state['selected_category'] == "classification":
+ st.write("Counts represent **images** per class.")
+ else:
+ st.write("Counts represent **images that contain at least one annotation** for a class.")
+ for cls in merged_names:
+ st.write(f"- **{cls}**: {class_image_counts_pre_merge.get(cls, 0)} available")
+
+ with st.form("class_selection_form"):
+ classes_to_include = {}
+ new_mapping = {}
+
+ st.markdown("""
+ **Instructions**:
+ - Use **"Rename Class"** to merge classes by assigning the **same** new name to multiple classes.
+ - Use **"Max Images"** to limit how many images are pulled from that class.
+ - Check **"Remove"** if you want to exclude the class entirely.
+ """)
+
+ enable_fuzzy_suggestions = st.checkbox("Enable Fuzzy Class Renaming Suggestions", value=False, help="Suggests similar existing class names for renaming.")
+
+ for c in all_class_names_raw:
+ current_final_name = class_name_mapping.get(c, c)
+ current_avail = class_image_counts_pre_merge.get(current_final_name, 0)
+
+ remove_checkbox_col, rename_col, images_col = st.columns([1, 2.5, 2.5])
+
+ with remove_checkbox_col:
+ remove_class = st.checkbox(
+ "Remove",
+ value=False,
+ key=f"remove_class_{c}",
+ help=f"Check to exclude '{c}' entirely."
+ )
+
+ with rename_col:
+ rename_val = st.text_input(
+ label=f"Rename '{c}'",
+ value=current_final_name,
+ key=f"rename_input_{c}"
+ )
+ if enable_fuzzy_suggestions:
+ suggestions = get_close_matches(rename_val, all_class_names_raw, n=3, cutoff=0.7)
+ suggestions = [s for s in suggestions if s != c and s != rename_val]
+ if suggestions:
+ st.markdown(f"Suggestions: {', '.join(suggestions)}", unsafe_allow_html=True)
+
+ with images_col:
+ max_value_for_input = current_avail if current_avail > 0 else 0
+ default_value = current_avail
+ disabled_flag = remove_class
+
+ new_count = st.number_input(
+ f"Max Images (Avail: {current_avail})",
+ min_value=0,
+ max_value=max_value_for_input,
+ value=default_value,
+ step=1,
+ disabled=disabled_flag,
+ key=f"max_images_{c}"
+ )
+ if remove_class:
+ new_count = 0
+
+ new_mapping[c] = rename_val.strip()
+ classes_to_include[c] = new_count
+
+ st.info("If multiple classes share the same renamed name, they are merged into one. Counts refer to images.")
+ confirm = st.form_submit_button("Confirm Selections & Renaming")
+
+ if confirm:
+ st.session_state['classes_to_include'] = classes_to_include
+ st.session_state['class_name_mapping'] = new_mapping
+ st.session_state['merging_confirmed'] = True
+ st.success("Class selections, renaming, and removal confirmed.")
+ st.session_state['workflow_step'] = "Merge & Adjust"
+
+elif st.session_state['workflow_step'] == "Merge & Adjust":
+ if not st.session_state.get('merging_confirmed', False):
+ st.warning("Please complete 'Class Selection' first.")
+ else:
+ st.subheader("Step 3: Merge & Adjust Datasets")
+ st.info(f"Current selected task: **{st.session_state['selected_category'].capitalize()}**")
+
+ if st.button("Merge Datasets"):
+ merged_dataset_dir = None
+ if st.session_state['selected_category'] == "classification":
+ merged_dataset_dir = merge_classification_datasets(
+ st.session_state['dataset_info_list'],
+ st.session_state['classes_to_include'],
+ st.session_state['class_name_mapping'],
+ show_debug
+ )
+ else:
+ merged_dataset_dir = merge_datasets(
+ st.session_state['dataset_info_list'],
+ st.session_state['classes_to_include'],
+ st.session_state['class_name_mapping'],
+ show_debug
+ )
+ if merged_dataset_dir:
+ st.success(f"Merged dataset created at '{merged_dataset_dir}'. You can now adjust it.")
+
+ if st.session_state.get('merged_dataset_dir', None):
+ adjust_selected_images()
+
+elif st.session_state['workflow_step'] == "Finalize Dataset":
+ if not st.session_state.get('adjustment_confirmed', False):
+ st.warning("Please complete 'Merge & Adjust' first.")
+ else:
+ st.subheader("Step 4: Finalize Merged Dataset")
+ st.info(f"Current selected task: **{st.session_state['selected_category'].capitalize()}**")
+
+ st.markdown("### Automated Dataset Splitting")
+ st.info("You can re-split the merged dataset into train/validation/test sets after finalization. This step assumes a YOLO-compatible dataset structure with 'images' and 'labels' folders within each split.")
+
+ col_split_auto_1, col_split_auto_2, col_split_auto_3 = st.columns(3)
+ with col_split_auto_1:
+ st.session_state['train_ratio'] = st.number_input("Train Ratio", min_value=0.0, max_value=1.0, value=st.session_state.get("train_ratio", 0.7), step=0.05, format="%.2f")
+ with col_split_auto_2:
+ st.session_state['val_ratio'] = st.number_input("Validation Ratio", min_value=0.0, max_value=1.0, value=st.session_state.get("val_ratio", 0.2), step=0.05, format="%.2f")
+ with col_split_auto_3:
+ st.session_state['test_ratio'] = st.number_input("Test Ratio", min_value=0.0, max_value=1.0, value=st.session_state.get("test_ratio", 0.1), step=0.05, format="%.2f")
+
+ sum_ratios = st.session_state['train_ratio'] + st.session_state['val_ratio'] + st.session_state['test_ratio']
+ if not (0.99 <= sum_ratios <= 1.01):
+ st.warning(f"Warning: Split ratios sum to {sum_ratios*100:.1f}%. It is recommended they sum to 100%.")
+
+ if st.button("Finalize Dataset"):
+ finalize_merged_dataset()
+ if st.session_state.get('dataset_finalized', False):
+ st.session_state['workflow_step'] = "Train Model"
+ st.success("Dataset finalized. Ready for training!")
+
+ if st.session_state.get('dataset_location') and st.button("Re-split Finalized Dataset"):
+ merged_dataset_path = st.session_state['dataset_location']
+ if merged_dataset_path and os.path.exists(merged_dataset_path):
+ if st.session_state['selected_category'] == "classification":
+ st.warning("Automated re-splitting is not currently supported for Classification datasets. They use a folder-per-class structure.")
+ else:
+ try:
+ with st.spinner("Re-splitting dataset... This might take a moment."):
+ for split_dir in ['train', 'valid', 'test']:
+ if os.path.exists(os.path.join(merged_dataset_path, split_dir)):
+ shutil.rmtree(os.path.join(merged_dataset_path, split_dir), onerror=handle_remove_readonly)
+
+ autosplit(
+ path=merged_dataset_path,
+ weights=(st.session_state['train_ratio'], st.session_state['val_ratio'], st.session_state['test_ratio']),
+ random=True,
+ verbose=True
+ )
+ st.success(f"Dataset successfully re-split into {st.session_state['train_ratio']*100}% train, {st.session_state['val_ratio']*100}% val, {st.session_state['test_ratio']*100}% test.")
+
+ data_yaml_path = os.path.join(merged_dataset_path, 'data.yaml')
+ with open(data_yaml_path, 'r') as f:
+ data_yaml = yaml.safe_load(f)
+ data_yaml['train'] = 'train/images'
+ data_yaml['val'] = 'valid/images'
+ data_yaml['test'] = 'test/images'
+ with open(data_yaml_path, 'w') as f:
+ yaml.safe_dump(data_yaml, f)
+
+ except Exception as e:
+ st.error(f"Failed to re-split dataset: {e}. Ensure the merged dataset has a compatible structure (e.g., images and labels directly under train/val/test subfolders).")
+ logging.error(f"Failed to re-split dataset: {e}")
+
+ allow_dataset_download()
+
+elif st.session_state['workflow_step'] == "Train Model":  # Step 5: train a model on the finalized dataset and offer the resulting weights for download.
+ st.subheader("Step 5: Train Model")
+ st.info(f"Current selected task: **{st.session_state['selected_category'].capitalize()}**")
+
+ auto_download_model = st.checkbox("Automatically present model download after training completes?", value=True)  # Controls whether the best.pt download button is shown after training.
+
+ if st.button("Start Training"):
+ if not st.session_state.get('dataset_finalized', False):  # Training is refused until the dataset has been finalized.
+ st.error("Datasets are not fully finalized. Please finalize the dataset first.")
+ else:
+ # Clear previous metrics and progress text before starting new training
+ # This is done *here* before starting training, not in the callback.
+ st.session_state['metrics_data'] = {
+ 'epoch': [], 'train_loss': [], 'val_loss': [],
+ 'mAP50': [], 'mAP50_95': [],
+ 'top1_acc': [], 'top5_acc': [],
+ 'mAPpose50': [], 'mAPpose50_95': [],
+ 'precision': [], 'recall': [], 'F1': []
+ }
+ loss_chart_placeholder.empty()  # The three placeholders are created outside this block.
+ map_chart_placeholder.empty()
+ progress_text_placeholder.empty()
+
+ dataset_location = st.session_state['dataset_location']
+ with st.spinner("Training in progress..."):
+ model = train_model(  # Hyperparameters come from session state; model_config is defined outside this block.
+ dataset_location = dataset_location,
+ epochs = st.session_state["epochs"],
+ batch_size = st.session_state["batch_size"],
+ img_size = st.session_state["img_size"],
+ learning_rate = st.session_state["learning_rate"],
+ optimizer = st.session_state["optimizer"],
+ data_augmentation_level = st.session_state["data_augmentation_level"],
+ pre_trained_model = st.session_state["pre_trained_model"],
+ custom_model_path = st.session_state.get("custom_model_path", ""),
+ custom_model_name = st.session_state["custom_model_name"],
+ show_debug = st.session_state["show_debug"],
+ model_config = model_config,
+ selected_architecture = st.session_state["selected_architecture"],
+ selected_category = st.session_state["selected_category"]
+ )
+ if model:  # A truthy return value is treated as a successful training run.
+ st.success("Model training completed!")
+ # Clear final progress text
+ progress_text_placeholder.empty()
+ trained_weights_path = os.path.join(  # Expected checkpoint location - presumably matches the project/name that train_model uses; verify.
+ "runs", "train", st.session_state["custom_model_name"], "weights", "best.pt"
+ )
+ if auto_download_model and os.path.exists(trained_weights_path):
+ st.write("### Download Your Trained Model:")
+ with open(trained_weights_path, "rb") as f:
+ st.download_button(
+ label="Download best.pt",
+ data=f,
+ file_name=f"{st.session_state['custom_model_name']}.pt",  # The download is named after the custom model name, not best.pt.
+ mime="application/octet-stream",
+ )
+ elif not os.path.exists(trained_weights_path):  # Warn only when the weights are missing; if download is disabled and the file exists, nothing is shown.
+ st.warning("Trained model not found. Check your training run output directory.")
+ st.session_state['workflow_step'] = "Upload Model"  # Advances the workflow to the upload step. NOTE(review): indentation is lost in this diff, so whether this runs only after a successful run cannot be told from here - confirm.
+
+elif st.session_state['workflow_step'] == "Upload Model":
+    # Step 6: collect the trained weights and a README, stage them together
+    # with the license file, then publish to GitHub and/or Hugging Face.
+    st.subheader("Step 6: Upload Model")
+    st.markdown("Upload your trained model and optionally a custom README, **OR** generate one from a template.")
+    st.info(f"Current selected task: **{st.session_state['selected_category'].capitalize()}**")
+
+    # Timeouts (seconds) so a stalled connection cannot hang the script.
+    # The PUT value is (connect, read); the read window is generous because
+    # model files can be large.
+    http_get_timeout = 30
+    http_put_timeout = (10, 300)
+
+    col1, col2 = st.columns(2)
+    with col1:
+        model_file = st.file_uploader("Upload Trained Model (.pt)", type=["pt"], help="Select your trained model file.")
+    with col2:
+        readme_file = st.file_uploader("Upload README Content (.txt)", type=["txt"], help="Select a `.txt` file with your README content.")
+
+    st.markdown("---")
+    st.markdown("**Or create a README from a template on GitHub**:")
+    st.info("Note: Default template is optimized for detection tasks and may not fully cover classification/keypoint specifics.")
+    generate_from_template = st.button("Generate README From Template")
+
+    if "final_readme_content_editor" not in st.session_state:
+        st.session_state["final_readme_content_editor"] = ""
+
+    if generate_from_template:
+        try:
+            template_url = "https://raw.githubusercontent.com/Wuhpondiscord/models/main/template.txt"
+            response = requests.get(template_url, timeout=http_get_timeout)
+            response.raise_for_status()
+            template_text = response.text
+
+            # Written before the text_area below is created in this run, so
+            # the widget picks the generated text up through its key.
+            st.session_state["final_readme_content_editor"] = customize_template(
+                template_text,
+                st.session_state.get("model_url", "N/A"),
+                st.session_state["selected_category"],
+                st.session_state["metrics_data"],
+                st.session_state.get("class_id_mapping", {}),
+                st.session_state.get('datasets_used', []),
+                st.session_state.get('class_image_counts', {})
+            )
+            st.success("README template loaded and placeholders filled. You can edit below.")
+        except requests.exceptions.RequestException as e:
+            st.error(f"Failed to fetch README template from GitHub: {e}")
+        except Exception as e:
+            st.error(f"Error generating README from template: {e}")
+
+    if st.session_state["final_readme_content_editor"]:
+        # The key alone binds the widget to session state. Passing `value=`
+        # as well makes Streamlit warn that the widget was given a default
+        # while its value was also set through the Session State API.
+        st.text_area(
+            "Generated README Content",
+            height=300,
+            key="final_readme_content_editor"
+        )
+
+    st.markdown("### GitHub Credentials")
+    gh_username = st.text_input("GitHub Username", value=st.session_state.get('gh_username', ''), help="Your GitHub username or organization.")
+    gh_repo_name = st.text_input("GitHub Repository Name", value=st.session_state.get('gh_repo_name', ''), help="Just the name of the repository (no slashes).")
+    gh_personal_token = st.text_input("GitHub Personal Access Token", type="password", value=st.session_state.get('gh_personal_token', ''), help="Personal access token for authentication.")
+    gh_make_private = st.checkbox("Make new GitHub repo private?", value=st.session_state.get('gh_make_private', False))
+
+    # Persist the credentials so they survive reruns and are available to the
+    # later workflow steps.
+    st.session_state['gh_username'] = gh_username
+    st.session_state['gh_repo_name'] = gh_repo_name
+    st.session_state['gh_personal_token'] = gh_personal_token
+    st.session_state['gh_make_private'] = gh_make_private
+
+    st.markdown("### Hugging Face Credentials")
+    hf_api_key = st.text_input("Hugging Face API Key", type="password", value=st.session_state.get('hf_api_key', ''), help="Your Hugging Face API key.")
+    hf_model_repo = st.text_input("Hugging Face Model Repository Name", value=st.session_state.get('hf_model_repo', ''), help="Format: username/repo-name (e.g., 'myusername/my-model-repo')")
+
+    st.session_state['hf_api_key'] = hf_api_key
+    st.session_state['hf_model_repo'] = hf_model_repo
+
+    if st.button("Upload to GitHub and Hugging Face"):
+        final_readme_str = ""
+
+        # README precedence: uploaded file, then the generated/edited text.
+        if readme_file is not None:
+            final_readme_str = readme_file.read().decode("utf-8")
+        elif st.session_state["final_readme_content_editor"]:
+            final_readme_str = st.session_state["final_readme_content_editor"]
+        else:
+            st.warning("No README content provided or generated. Uploading without a README.")
+
+        if not model_file:
+            st.error("Please upload the trained model (.pt) before publishing.")
+        else:
+            temp_dir = "temp_upload_data"
+            os.makedirs(temp_dir, exist_ok=True)
+
+            # try/finally guarantees the staging directory is removed even
+            # when st.stop() aborts the run part-way through.
+            try:
+                # Read the model name once, with a fallback, so the staged
+                # file name and the local weights lookup cannot disagree.
+                model_name = st.session_state.get('custom_model_name', 'best')
+                model_filename_to_upload = model_name + '.pt'
+                temp_model_path = os.path.join(temp_dir, model_filename_to_upload)
+                temp_readme_path = os.path.join(temp_dir, "README.md")
+                temp_license_path = os.path.join(temp_dir, "Licensing.rolo")
+
+                # Locally trained weights take precedence over the uploaded
+                # file when both are available.
+                trained_weights_path = os.path.join(
+                    "runs", "train", model_name, "weights", "best.pt"
+                )
+                if os.path.exists(trained_weights_path):
+                    shutil.copy(trained_weights_path, temp_model_path)
+                    st.info(f"Using locally trained model: {trained_weights_path}")
+                else:
+                    with open(temp_model_path, "wb") as f:
+                        f.write(model_file.read())
+                    st.info("Using uploaded model file.")
+
+                if final_readme_str.strip():
+                    with open(temp_readme_path, "w", encoding="utf-8") as f:
+                        f.write(final_readme_str)
+
+                # The license file is mandatory: without it the upload is aborted.
+                license_url = "https://raw.githubusercontent.com/wuhplaptop/sdadasgds/main/Licensing.rolo"
+                try:
+                    lic_response = requests.get(license_url, timeout=http_get_timeout)
+                    lic_response.raise_for_status()
+                    with open(temp_license_path, "wb") as f:
+                        f.write(lic_response.content)
+                except Exception as e:
+                    st.error(f"Failed to download license file: {e}")
+                    st.stop()
+
+                if gh_username and gh_repo_name and gh_personal_token:
+                    with st.spinner("Uploading to GitHub..."):
+                        status_code, resp = create_github_repo(gh_username, gh_personal_token, gh_repo_name, private=gh_make_private)
+                        # 201 = created; 422 is treated as "repository already exists".
+                        if status_code == 201:
+                            st.success(f"GitHub repo '{gh_repo_name}' created successfully.")
+                        elif status_code == 422:
+                            st.info(f"GitHub repo '{gh_repo_name}' already exists. Proceeding to upload/update files.")
+                        else:
+                            st.error(f"Failed to create/access GitHub repo: {resp.get('message', 'Unknown error')}")
+                            st.stop()
+
+                        try:
+                            gh_api_url = f"https://api.github.com/repos/{gh_username}/{gh_repo_name}/contents/"
+                            headers = {
+                                "Authorization": f"token {gh_personal_token}",
+                                "Accept": "application/vnd.github.v3+json",
+                            }
+
+                            def upload_to_github(file_path, repo_path):
+                                """Create or update `repo_path` in the repo from the local `file_path`.
+
+                                Returns a (status_code, parsed_json) tuple from the PUT request.
+                                """
+                                check_url = gh_api_url + repo_path
+                                # An existing file must be updated with its current blob sha.
+                                check_response = requests.get(check_url, headers=headers, timeout=http_get_timeout)
+                                sha = None
+                                if check_response.status_code == 200:
+                                    sha = check_response.json().get('sha')
+
+                                with open(file_path, "rb") as f:
+                                    content = base64.b64encode(f.read()).decode("utf-8")
+
+                                data_payload = {"message": f"Upload {repo_path} from Rolo", "content": content}
+                                if sha:
+                                    data_payload["sha"] = sha
+
+                                response = requests.put(
+                                    check_url,
+                                    headers=headers,
+                                    json=data_payload,
+                                    timeout=http_put_timeout,
+                                )
+                                return response.status_code, response.json()
+
+                            status, response_json = upload_to_github(temp_model_path, model_filename_to_upload)
+                            if status in [200, 201]:
+                                st.success(f"Uploaded {model_filename_to_upload} to GitHub.")
+                                st.session_state['model_url'] = f"https://github.com/{gh_username}/{gh_repo_name}/raw/main/{model_filename_to_upload}"
+                            else:
+                                st.error(f"Failed to upload {model_filename_to_upload} to GitHub: {response_json.get('message', 'Unknown error')}")
+
+                            if final_readme_str.strip():
+                                status, response_json = upload_to_github(temp_readme_path, "README.md")
+                                if status in [200, 201]:
+                                    st.success("Uploaded README.md to GitHub.")
+                                else:
+                                    st.error(f"Failed to upload README.md to GitHub: {response_json.get('message', 'Unknown error')}")
+
+                            status, response_json = upload_to_github(temp_license_path, "Licensing.rolo")
+                            if status in [200, 201]:
+                                st.success("Uploaded Licensing.rolo to GitHub.")
+                            else:
+                                st.error(f"Failed to upload Licensing.rolo to GitHub: {response_json.get('message', 'Unknown error')}")
+
+                        except Exception as e:
+                            st.error(f"Failed to upload to GitHub: {e}")
+
+                else:
+                    st.error("Please provide GitHub credentials to upload the model.")
+
+                if hf_api_key and hf_model_repo:
+                    with st.spinner("Uploading to Hugging Face..."):
+                        try:
+                            api = HfApi()
+                            # NOTE(review): save_token() writes the user's token to the
+                            # machine running this app. Every call below already passes
+                            # token= explicitly, so this looks unnecessary on a shared
+                            # server - confirm nothing else relies on the stored token
+                            # before removing it.
+                            HfFolder.save_token(hf_api_key)
+
+                            repo_url = api.create_repo(repo_id=hf_model_repo, exist_ok=True, token=hf_api_key)
+
+                            api.upload_file(
+                                path_or_fileobj=temp_model_path,
+                                path_in_repo=model_filename_to_upload,
+                                repo_id=hf_model_repo,
+                                token=hf_api_key
+                            )
+                            st.success(f"Uploaded model to Hugging Face: {repo_url}")
+                            # Overrides any GitHub URL stored above when both uploads succeed.
+                            st.session_state['model_url'] = f"https://huggingface.co/{hf_model_repo}/blob/main/{model_filename_to_upload}"
+
+                            if final_readme_str.strip():
+                                api.upload_file(
+                                    path_or_fileobj=temp_readme_path,
+                                    path_in_repo="README.md",
+                                    repo_id=hf_model_repo,
+                                    token=hf_api_key
+                                )
+                                st.success("Uploaded README.md to Hugging Face.")
+
+                            api.upload_file(
+                                path_or_fileobj=temp_license_path,
+                                path_in_repo="Licensing.rolo",
+                                repo_id=hf_model_repo,
+                                token=hf_api_key
+                            )
+                            st.success("Uploaded Licensing.rolo to Hugging Face.")
+
+                        except Exception as e:
+                            st.error(f"Failed to upload to Hugging Face: {e}")
+                else:
+                    st.warning("Hugging Face credentials not provided. Skipping Hugging Face upload.")
+
+                st.success("Upload process completed.")
+            finally:
+                shutil.rmtree(temp_dir, ignore_errors=True)
+
+    if st.button("Proceed to Generate Examples"):
+        st.session_state['workflow_step'] = "Generate Examples"
+        st.rerun()
+
+elif st.session_state['workflow_step'] == "Generate Examples":
+    # Step 7: locate the published model, then fill the example templates
+    # stored on GitHub with the model URL and training metadata.
+    st.subheader("Step 7: Generate Code Examples")
+    st.markdown("""
+ These templates are primarily designed for **detection tasks**.
+ If you need examples for other tasks (classification, keypoint),
+ you'll have to provide or create different templates on the GitHub repository.
+ """)
+    st.info(f"Current selected task: **{st.session_state['selected_category'].capitalize()}**")
+
+    if st.button("Detect Model Location"):
+        detect_model_location()
+        # .get() is used for every read of these keys in this step, so a
+        # missing key is handled the same way everywhere.
+        if st.session_state.get('model_location'):
+            st.success(f"Model detected on {st.session_state['model_location'].capitalize()}: {st.session_state.get('model_url', '')}")
+        else:
+            st.warning("Could not automatically detect model location. Please input the model URL manually.")
+
+    # Manual entry is offered whenever either the location or the URL is missing.
+    if not st.session_state.get('model_location') or not st.session_state.get('model_url'):
+        model_url_input = st.text_input("Enter your Model URL (Huggingface or GitHub)", value=st.session_state.get('model_url', ""))
+        if st.button("Set Model URL for Examples"):
+            if model_url_input:
+                parsed_url = urlparse(model_url_input)
+                if 'huggingface.co' in parsed_url.netloc:
+                    detected_location = 'huggingface'
+                elif 'github.com' in parsed_url.netloc:
+                    detected_location = 'github'
+                else:
+                    detected_location = None
+
+                if detected_location:
+                    # Store URL and location together, and only once the host
+                    # is recognised. Otherwise an unsupported URL could be
+                    # kept alongside a stale location and pass the gate below.
+                    st.session_state['model_url'] = model_url_input
+                    st.session_state['model_location'] = detected_location
+                else:
+                    st.error("Unsupported model URL. Please provide a Huggingface or GitHub URL.")
+            else:
+                st.error("Please enter a valid model URL.")
+
+    if st.session_state.get('model_location') and st.session_state.get('model_url'):
+        example_type = st.selectbox("Select Example Type", ["Python", "Colab"])
+        if st.button("Generate Examples"):
+            # Templates live in a fixed repository, one folder per example type.
+            owner = "Wuhpondiscord"
+            repo = "models"
+            path = f"examples/{example_type.lower()}"
+            token = st.session_state.get('gh_personal_token')
+
+            files = list_github_files(owner, repo, path, token=token)
+            if not files:
+                st.info(f"No example templates found for {example_type} for this task type. Please add them to '{owner}/{repo}/{path}'.")
+            else:
+                st.session_state['templates_found'] = True
+                st.success(f"Found {len(files)} template(s) for {example_type} examples in {path}.")
+
+                # Templates that cannot be fetched are skipped without a message.
+                generated_examples = {}
+                for fname in files:
+                    template_content = fetch_github_file(owner, repo, f"{path}/{fname}", token=token)
+                    if template_content:
+                        customized_content = customize_template(
+                            template_content,
+                            st.session_state['model_url'],
+                            st.session_state["selected_category"],
+                            st.session_state["metrics_data"],
+                            st.session_state.get("class_id_mapping", {}),
+                            st.session_state.get('datasets_used', []),
+                            st.session_state.get('class_image_counts', {})
+                        )
+                        generated_examples[fname] = customized_content
+
+                if generated_examples:
+                    st.markdown("### Generated Examples")
+                    for fname, content in generated_examples.items():
+                        st.markdown(f"**{fname}**")
+                        st.text_area(f"Example: {fname}", value=content, height=300, key=f"example_text_{fname}")
+                        download_link = get_download_link(content, fname, "text/plain")
+                        st.markdown(download_link, unsafe_allow_html=True)
+
+                    # One archive containing every generated example.
+                    zip_bytes = generate_zip(generated_examples)
+                    st.download_button(
+                        label="Download All Examples as ZIP",
+                        data=zip_bytes,
+                        file_name=f"{example_type.lower()}_examples.zip",
+                        mime="application/zip"
+                    )
\ No newline at end of file