diff --git "a/app.py" "b/app.py" new file mode 100644--- /dev/null +++ "b/app.py" @@ -0,0 +1,2772 @@ +import os +import shutil +import stat +import yaml +import streamlit as st +from ultralytics import YOLO +from roboflow import Roboflow +import re +from urllib.parse import urlparse +import random +import logging +import requests +import json +import concurrent.futures +import multiprocessing +from difflib import get_close_matches +from io import StringIO +import matplotlib.pyplot as plt +import base64 +import io +import zipfile +from huggingface_hub import HfApi, HfFolder +import torch +from PIL import Image +from ultralytics.data.utils import autosplit + +# Configure logging +logging.basicConfig( + filename='rolo_app.log', + level=logging.INFO, + format='%(asctime)s - %(message)s' +) + +# ------------------------------------------------------------------------- +# ADDITIONAL SESSION STATE FOR EARLY STOPPING AND NEW METRICS +# ------------------------------------------------------------------------- +if 'early_stop_patience' not in st.session_state: + st.session_state['early_stop_patience'] = 10 +if 'best_val_loss' not in st.session_state: + st.session_state['best_val_loss'] = float('inf') +if 'epochs_no_improvement' not in st.session_state: + st.session_state['epochs_no_improvement'] = 0 + +# ---------------------------------------------------------------------------- +# SESSION STATE INITIALIZATION +# ---------------------------------------------------------------------------- +if 'workflow_step' not in st.session_state: + st.session_state['workflow_step'] = "Prepare Datasets" + +if 'dataset_info_list' not in st.session_state: + st.session_state['dataset_info_list'] = [] +if 'dataset_location' not in st.session_state: + st.session_state['dataset_location'] = '' +if 'dataset_prepared' not in st.session_state: + st.session_state['dataset_prepared'] = False +if 'classes_to_include' not in st.session_state: + st.session_state['classes_to_include'] = {} +if 'class_name_mapping' not in st.session_state: + st.session_state['class_name_mapping'] = {} +if 'merging_confirmed' not in st.session_state: + st.session_state['merging_confirmed'] = False +if 'selected_images' not in st.session_state: + st.session_state['selected_images'] = set() +if 'class_counters' not in st.session_state: + st.session_state['class_counters'] = {} +if 'adjustment_confirmed' not in st.session_state: + st.session_state['adjustment_confirmed'] = False +if 'dataset_finalized' not in st.session_state: + st.session_state['dataset_finalized'] = False + +if 'metrics_data' not in st.session_state: + st.session_state['metrics_data'] = { + 'epoch': [], + 'train_loss': [], + 'val_loss': [], + 'mAP50': [], + 'mAP50_95': [], + 'top1_acc': [], + 'top5_acc': [], + 'mAPpose50': [], + 'mAPpose50_95': [], + # Add new metrics + 'precision': [], + 'recall': [], + 'F1': [] + } +if 'is_rerunning' not in st.session_state: + st.session_state['is_rerunning'] = False +if 'class_image_counts' not in st.session_state: + st.session_state['class_image_counts'] = {} +if 'datasets_used' not in st.session_state: + st.session_state['datasets_used'] = [] +if 'model_location' not in st.session_state: + st.session_state['model_location'] = None +if 'model_url' not in st.session_state: + st.session_state['model_url'] = "" +if 'templates_found' not in st.session_state: + st.session_state['templates_found'] = False +if 'final_readme_content_editor' not in st.session_state: + st.session_state['final_readme_content_editor'] = "" +if 'total_class_images' not in st.session_state: + st.session_state['total_class_images'] = {} + +if 'lrf' not in st.session_state: + st.session_state['lrf'] = 0.01 +if 'warmup_epochs' not in st.session_state: + st.session_state['warmup_epochs'] = 3 +if 'cache_dataset' not in st.session_state: + st.session_state['cache_dataset'] = False +if 'freeze_layers' not in st.session_state: + st.session_state['freeze_layers'] = 0 +if 'hsv_h' not in st.session_state: + st.session_state['hsv_h'] = 0.015 +if 'hsv_s' not in st.session_state: + st.session_state['hsv_s'] = 0.7 +if 'hsv_v' not in st.session_state: + st.session_state['hsv_v'] = 0.4 + +if 'min_bbox_area' not in st.session_state: + st.session_state['min_bbox_area'] = 0.0001 +if 'max_bbox_area' not in st.session_state: + st.session_state['max_bbox_area'] = 1.0 +if 'min_visible_keypoints' not in st.session_state: + st.session_state['min_visible_keypoints'] = 1 + +if 'train_ratio' not in st.session_state: + st.session_state['train_ratio'] = 0.7 +if 'val_ratio' not in st.session_state: + st.session_state['val_ratio'] = 0.2 +if 'test_ratio' not in st.session_state: + st.session_state['test_ratio'] = 0.1 + +if 'kpt_shape' not in st.session_state: + st.session_state['kpt_shape'] = None + +loss_chart_placeholder = st.empty() +map_chart_placeholder = st.empty() +progress_text_placeholder = st.empty() # New placeholder for epoch progress + +# ---------------------------------------------------------------------------- +# CACHED FUNCTIONS (UNCHANGED) +# ---------------------------------------------------------------------------- + +@st.cache_data(show_spinner=False) +def load_optimizers(url: str): + """ + [DATASET PREP] Load available optimizers and their descriptions from an external JSON. + """ + try: + response = requests.get(url) + response.raise_for_status() + return response.json() + except Exception as e: + logging.error(f"[DATASET PREP] Failed to load optimizers: {e}") + return {} + +@st.cache_data(show_spinner=False) +def load_recommended_datasets(url: str): + """ + [DATASET PREP] Load recommended datasets from an external JSON file hosted on GitHub. + Each dataset should have 'name', 'description', 'source', and 'url'. + NOTE: These recommended datasets are meant for *detection tasks* only. + """ + try: + response = requests.get(url) + response.raise_for_status() + return response.json() + except Exception as e: + logging.error(f"[DATASET PREP] Failed to load recommended datasets: {e}") + return [] + +@st.cache_data(show_spinner=False) +def load_presets(url: str): + """ + Load configuration presets from an external JSON. + """ + try: + response = requests.get(url) + response.raise_for_status() + return response.json() + except Exception as e: + logging.error(f"[DATASET PREP] Failed to load presets: {e}") + return {} + +@st.cache_data(show_spinner=False) +def cache_load_model_configs(url: str): + """ + Loads model architectures and their pre-trained models from a remote JSON. + The JSON now has categories for each architecture: + model_config[arch]["categories"][category] -> list of model entries + """ + try: + response = requests.get(url) + response.raise_for_status() + return response.json() + except Exception as e: + logging.error(f"[DATASET PREP] Failed to load model configurations: {e}") + return None + +@st.cache_resource(show_spinner=False) +def cache_download_pretrained_model(model_filename, download_url, models_dir='pretrained_models'): + """ + [DATASET PREP] Example of using Streamlit's @st.cache_resource to cache a file download. + If the file already exists locally, we skip downloading. + """ + if not os.path.exists(models_dir): + os.makedirs(models_dir, exist_ok=True) + logging.info(f"[DATASET PREP] Created directory '{models_dir}' for pre-trained models.") + + model_path = os.path.join(models_dir, model_filename) + if os.path.exists(model_path): + logging.info(f"[DATASET PREP] Pre-trained model '{model_filename}' already exists locally.") + return model_path + + try: + logging.info(f"[DATASET PREP] Starting download for model '{model_filename}' from '{download_url}'.") + response = requests.get(download_url, stream=True) + response.raise_for_status() + + with open(model_path, 'wb') as f: + for data in response.iter_content(chunk_size=4096): + f.write(data) + + logging.info(f"[DATASET PREP] Downloaded '{model_filename}' successfully.") + return model_path + except Exception as e: + logging.error(f"[DATASET PREP] Failed to download model '{model_filename}': {e}") + return None + +# ---------------------------------------------------------------------------- +# HELPER FUNCTIONS +# ---------------------------------------------------------------------------- + +def handle_remove_readonly(func, path, exc_info): + """Function to handle errors during shutil.rmtree (Windows file permissions).""" + os.chmod(path, stat.S_IWRITE) + func(path) + +def parse_roboflow_url(url): + """[DATASET PREP] Parses a Roboflow URL to extract workspace, project, and version number.""" + parsed_url = urlparse(url.strip()) + path_parts = parsed_url.path.strip('/').split('/') + + if len(path_parts) >= 4 and path_parts[-2] == 'dataset': + workspace = path_parts[0] + project = path_parts[1] + version = path_parts[-1] + return workspace, project, version + + query = parsed_url.query + match = re.search(r'version=(\d+)', query) + if match: + version = match.group(1) + if len(path_parts) >= 2: + workspace = path_parts[0] + project = path_parts[1] + return workspace, project, version + + if len(path_parts) >= 2: + workspace = path_parts[0] + project = path_parts[1] + version = None + return workspace, project, version + + return None, None, None + +def get_latest_version(rf, workspace_name, project_name): + """[DATASET PREP] Retrieves the latest version number for a Roboflow project.""" + try: + project = rf.workspace(workspace_name).project(project_name) + versions = project.versions() + version_numbers = [] + for v in versions: + version_num = getattr(v, 'version_number', getattr(v, 'number', None)) + if version_num is not None: + version_numbers.append(int(version_num)) + if not version_numbers: + st.error(f"No versions found for project '{project_name}'.") + logging.error(f"[DATASET PREP] No versions found for project '{project_name}'.") + return None + latest_version = max(version_numbers) + logging.info(f"[DATASET PREP] Latest version for project '{project_name}' is v{latest_version}.") + return str(latest_version) + except Exception as e: + st.error(f"Failed to retrieve latest version for project '{project_name}': {e}") + logging.error(f"[DATASET PREP] Failed to retrieve latest version: {e}") + return None + +def on_train_epoch_end(trainer): + """ + [TRAINING] Callback to capture metrics at the end of each training epoch. + Includes early stopping logic based on val_loss improvements. + Dynamically displays metrics based on selected task category. + """ + epoch = trainer.epoch + train_loss = trainer.metrics.get('train/loss', None) + val_loss = trainer.metrics.get('val/loss', None) + total_epochs = st.session_state.get('epochs', 1) # Get total epochs from session state + + selected_category = st.session_state.get('selected_category', 'detection') + + # Removed the problematic metric reset block. + # Metrics are now cleared at the start of train_model() only. + # The 'is_rerunning' flag will correctly be True for the first epoch's callback, + # and then set to False, as intended by its initial setup. + + if epoch is not None: + st.session_state['metrics_data']['epoch'].append(epoch) + if train_loss is not None: + st.session_state['metrics_data']['train_loss'].append(train_loss) + if val_loss is not None: + st.session_state['metrics_data']['val_loss'].append(val_loss) + + # --- Update progress counter --- + if total_epochs > 0: + current_epoch_display = epoch + 1 + progress_percent = (current_epoch_display / total_epochs) * 100 + progress_text_placeholder.write(f"Epoch {current_epoch_display}/{total_epochs} ({progress_percent:.1f}% done)") + else: + progress_text_placeholder.write(f"Epoch {epoch + 1} (Total epochs not set)") + # --- End progress counter update --- + + if selected_category == 'classification': + top1_acc = trainer.metrics.get('metrics/top1_acc', None) + top5_acc = trainer.metrics.get('metrics/top5_acc', None) + if top1_acc is not None: + st.session_state['metrics_data']['top1_acc'].append(top1_acc) + if top5_acc is not None: + st.session_state['metrics_data']['top5_acc'].append(top5_acc) + + if st.session_state['metrics_data']['epoch']: + loss_chart_placeholder.line_chart({ + "Train Loss": st.session_state['metrics_data']['train_loss'], + "Val Loss": st.session_state['metrics_data']['val_loss'] + }) + if st.session_state['metrics_data']['top1_acc']: + map_chart_placeholder.line_chart({ + "Top-1 Accuracy": st.session_state['metrics_data']['top1_acc'], + "Top-5 Accuracy": st.session_state['metrics_data']['top5_acc'] + }) + else: + map_chart_placeholder.empty() + elif selected_category == 'keypoint': + mappose50 = trainer.metrics.get('metrics/mAP50(P)', None) + mappose5095 = trainer.metrics.get('metrics/mAP50-95(P)', None) + if mappose50 is not None: + st.session_state['metrics_data']['mAPpose50'].append(mappose50) + if mappose5095 is not None: + st.session_state['metrics_data']['mAPpose50_95'].append(mappose5095) + + if st.session_state['metrics_data']['epoch']: + loss_chart_placeholder.line_chart({ + "Train Loss": st.session_state['metrics_data']['train_loss'], + "Val Loss": st.session_state['metrics_data']['val_loss'] + }) + if st.session_state['metrics_data']['mAPpose50']: + map_chart_placeholder.line_chart({ + "mAPpose@0.5": st.session_state['metrics_data']['mAPpose50'], + "mAPpose@0.5:0.95": st.session_state['metrics_data']['mAPpose50_95'] + }) + else: + map_chart_placeholder.empty() + else: # detection, segmentation, obb + map50 = trainer.metrics.get('metrics/mAP50(B)', None) + map5095 = trainer.metrics.get('metrics/mAP50-95(B)', None) + + # 1. Capture Precision, Recall, F1 using the correct (B) suffix keys, with fallback + precision = trainer.metrics.get('metrics/precision(B)', trainer.metrics.get('metrics/precision', None)) + recall = trainer.metrics.get('metrics/recall(B)', trainer.metrics.get('metrics/recall', None)) + f1 = trainer.metrics.get('metrics/f1(B)', trainer.metrics.get('metrics/f1', None)) # F1 also often has (B) suffix + + if st.session_state["show_debug"]: + st.write(f"Epoch {epoch} metrics: {trainer.metrics}") # Debug: See exact keys + + if map50 is not None: + st.session_state['metrics_data']['mAP50'].append(map50) + if map5095 is not None: + st.session_state['metrics_data']['mAP50_95'].append(map5095) + if precision is not None: + st.session_state['metrics_data']['precision'].append(precision) + if recall is not None: + st.session_state['metrics_data']['recall'].append(recall) + if f1 is not None: + st.session_state['metrics_data']['F1'].append(f1) + + # 2. Build a single dictionary for map/precision curves and plot once + if st.session_state['metrics_data']['epoch']: # Ensure there's at least one epoch data + loss_chart_placeholder.line_chart({ + "Train Loss": st.session_state['metrics_data']['train_loss'], + "Val Loss": st.session_state['metrics_data']['val_loss'] + }) + + chart_data_for_map_metrics = {} + if st.session_state['metrics_data']['mAP50']: + chart_data_for_map_metrics["mAP@0.5"] = st.session_state['metrics_data']['mAP50'] + if st.session_state['metrics_data']['mAP50_95']: + chart_data_for_map_metrics["mAP@0.5:0.95"] = st.session_state['metrics_data']['mAP50_95'] + if st.session_state['metrics_data']['precision']: + chart_data_for_map_metrics["Precision"] = st.session_state['metrics_data']['precision'] + if st.session_state['metrics_data']['recall']: + chart_data_for_map_metrics["Recall"] = st.session_state['metrics_data']['recall'] + if st.session_state['metrics_data']['F1']: + chart_data_for_map_metrics["F1 Score"] = st.session_state['metrics_data']['F1'] + + if chart_data_for_map_metrics: + map_chart_placeholder.line_chart(chart_data_for_map_metrics) + else: + map_chart_placeholder.empty() + + patience = st.session_state.get('early_stop_patience', 10) + if val_loss is not None: + if val_loss < st.session_state['best_val_loss']: + st.session_state['best_val_loss'] = val_loss + st.session_state['epochs_no_improvement'] = 0 + if st.session_state["show_debug"]: + st.info(f"Epoch {epoch}: Validation loss improved to {val_loss:.4f}. Resetting patience counter.") + else: + st.session_state['epochs_no_improvement'] += 1 + if st.session_state["show_debug"]: + st.warning(f"Epoch {epoch}: Validation loss did not improve. No improvement for {st.session_state['epochs_no_improvement']} epochs.") + + if st.session_state['epochs_no_improvement'] >= patience: + st.warning(f"Early stopping triggered at epoch {epoch} due to no improvement in val_loss for {patience} epochs.") + if st.session_state["show_debug"]: + st.warning("Recommendation: Consider increasing 'Early Stopping Patience', adding more 'Data Augmentation', or reducing 'Learning Rate'.") + trainer.stop = True + else: + if st.session_state["show_debug"]: + st.warning(f"Epoch {epoch}: Validation loss not available for early stopping.") + + # This line now just ensures is_rerunning is False after the first epoch, + # since the clearing logic has been moved outside this callback. + if st.session_state.get('is_rerunning', False): + st.session_state['is_rerunning'] = False + + +# ---------------------------------------------------------------------------- +# DATASET PREPARATION HELPERS +# ---------------------------------------------------------------------------- + +def download_and_prepare_roboflow_dataset( + rf_api_key, + workspace_name, + project_name, + version_number, + selected_architecture, + selected_category +): + """ + [DATASET PREP] Downloads and prepares the Roboflow dataset based on the task category. + """ + arch_to_rf_format = { + "YOLOv3": "yolov3", + "YOLOv5": "yolov5", + "YOLOv8": "yolov8", + "YOLOv9": "yolov9", + "YOLOv10": "yolov8", + "YOLOv11": "yolov8", + "YOLOv12": "yolov8", + "RTDETR": "yolov8" # RT-DETR uses YOLOv8 compatible dataset format + } + + rf = Roboflow(api_key=rf_api_key) + try: + project = rf.workspace(workspace_name).project(project_name) + except Exception as e: + st.error(f"Failed to locate project '{project_name}' in workspace '{workspace_name}': {e}") + logging.error(f"[DATASET PREP] Failed to locate project: {e}") + return None, None, None + + roboflow_export_format = None + if selected_category == "classification": + roboflow_export_format = "folder" + elif selected_category == "keypoint": + roboflow_export_format = "yolov8-keypoint" + else: + roboflow_export_format = arch_to_rf_format.get(selected_architecture, "yolov8") + + try: + st.info(f"[DATASET PREP] Attempting download format '{roboflow_export_format}' for {selected_category}...") + dataset = project.version(version_number).download(roboflow_export_format) + logging.info(f"[DATASET PREP] Dataset '{project_name}' (v{version_number}) downloaded with format '{roboflow_export_format}'.") + except Exception as e: + logging.warning(f"[DATASET PREP] Failed to download with format '{roboflow_export_format}': {e}") + st.warning(f"Failed with '{roboflow_export_format}': {e}. Trying fallbacks...") + if selected_category in ["detection", "segmentation", "obb"]: + try: + dataset = project.version(version_number).download("yolov8") + logging.info(f"[DATASET PREP] Dataset '{project_name}' (v{version_number}) downloaded with 'yolov8' fallback.") + except Exception as e2: + logging.warning(f"Failed with 'yolov8': {e2}") + st.warning(f"Failed with 'yolov8': {e2}. Trying 'coco' next...") + try: + dataset = project.version(version_number).download("coco") + logging.info(f"[DATASET PREP] Dataset '{project_name}' (v{version_number}) downloaded with 'coco' fallback.") + except Exception as e3: + logging.error(f"[DATASET PREP] All fallback attempts failed for '{project_name}': {e3}") + st.error(f"Failed to download dataset '{project_name}' with '{roboflow_export_format}', 'yolov8', or 'coco': {e3}") + return None, None, None + else: + st.error(f"Failed to download dataset '{project_name}' in required '{roboflow_export_format}' format for {selected_category}. No fallback available.") + logging.error(f"[DATASET PREP] No fallback for {selected_category} after {roboflow_export_format} failed.") + return None, None, None + + dataset_location = dataset.location + data_yaml_path = os.path.join(dataset_location, 'data.yaml') + class_names = [] + kpt_shape_info = None + + if selected_category == "classification": + class_names = [] + for split_dir in ['train', 'valid', 'test']: + split_path = os.path.join(dataset_location, split_dir) + if os.path.exists(split_path): + class_names.extend([d for d in os.listdir(split_path) if os.path.isdir(os.path.join(split_path, d))]) + class_names = sorted(list(set(class_names))) + + data_yaml = { + 'path': dataset_location, + 'train': 'train', + 'val': 'valid', + 'test': 'test', + 'nc': len(class_names), + 'names': class_names + } + with open(data_yaml_path, 'w') as f: + yaml.safe_dump(data_yaml, f) + logging.info(f"[DATASET PREP] Created/Updated data.yaml for classification at {data_yaml_path}.") + + else: + if not os.path.exists(data_yaml_path): + try: + class_names = project.classes + if not class_names: + logging.error("[DATASET PREP] No class names found via API.") + return None, None, None + data_yaml = { + 'train': 'train/images', + 'val': 'valid/images', + 'test': 'test/images', + 'nc': len(class_names), + 'names': class_names + } + if selected_category == "keypoint": + st.warning("Keypoint dataset detected. Please ensure its data.yaml will contain 'kpt_shape' field.") + with open(data_yaml_path, 'w') as f: + yaml.safe_dump(data_yaml, f) + except Exception as e: + logging.error(f"[DATASET PREP] Error writing data.yaml: {e}") + return None, None, None + else: + try: + with open(data_yaml_path, 'r') as f: + data_yaml = yaml.safe_load(f) + class_names = data_yaml.get('names', []) + if not class_names: + logging.error("[DATASET PREP] No class names found in data.yaml.") + return None, None, None + if selected_category == "keypoint": + kpt_shape_info = data_yaml.get('kpt_shape', None) + if kpt_shape_info is None: + st.error("Keypoint data.yaml missing 'kpt_shape'. Cannot proceed with keypoint training.") + logging.error("[DATASET PREP] Keypoint data.yaml missing 'kpt_shape'.") + return None, None, None + st.session_state['kpt_shape'] = kpt_shape_info + except Exception as e: + logging.error(f"[DATASET PREP] Error reading data.yaml: {e}") + return None, None, None + + if 'train' in data_yaml and not data_yaml['train'].endswith('/images'): data_yaml['train'] = os.path.join(data_yaml['train'], 'images') + if 'val' in data_yaml and not data_yaml['val'].endswith('/images'): data_yaml['val'] = os.path.join(data_yaml['val'], 'images') + if 'test' in data_yaml and 'test' in data_yaml and not data_yaml['test'].endswith('/images'): data_yaml['test'] = os.path.join(data_yaml['test'], 'images') + with open(data_yaml_path, 'w') as f: + yaml.safe_dump(data_yaml, f) + + splits = [] + if selected_category == "classification": + for split in ['train', 'valid', 'test']: + if os.path.exists(os.path.join(dataset_location, split)): + splits.append(split) + else: + for split in ['train', 'valid', 'test']: + if os.path.exists(os.path.join(dataset_location, split, 'images')): + splits.append(split) + + return dataset_location, class_names, splits + +def download_and_prepare_github_dataset(github_zip_url, dataset_name): + """ + [DATASET PREP] Downloads and prepares a dataset from a GitHub (or Ultralytics) ZIP archive. + NOTE: The recommended datasets from your external JSON are for *detection only*. + """ + try: + dataset_zip_path = os.path.join("temp_datasets", f"{dataset_name}.zip") + os.makedirs("temp_datasets", exist_ok=True) + + with requests.get(github_zip_url, stream=True) as r: + r.raise_for_status() + with open(dataset_zip_path, 'wb') as f: + shutil.copyfileobj(r.raw, f) + + extract_dir = os.path.join("temp_datasets", dataset_name) + shutil.unpack_archive(dataset_zip_path, extract_dir) + logging.info(f"[DATASET PREP] Extracted {dataset_name} to {extract_dir}") + + data_yaml_path = os.path.join(extract_dir, 'data.yaml') + if not os.path.exists(data_yaml_path): + st.error(f"'data.yaml' not found in the dataset '{dataset_name}'.") + logging.error(f"[DATASET PREP] 'data.yaml' not found in {dataset_name}.") + return None, None, None + + with open(data_yaml_path, 'r') as f: + data_yaml = yaml.safe_load(f) + class_names = data_yaml.get('names', []) + + splits = ['train', 'valid', 'test'] + available_splits = [s for s in splits if os.path.exists(os.path.join(extract_dir, s))] + + return extract_dir, class_names, available_splits + except Exception as e: + st.error(f"Failed to download or prepare GitHub dataset '{dataset_name}': {e}") + logging.error(f"[DATASET PREP] Failed to download or prepare GitHub dataset '{dataset_name}': {e}") + return None, None, None + +# ---------------------------------------------------------------------------- +# TRAINING LOGIC & AUGMENTATIONS +# ---------------------------------------------------------------------------- + +def train_model( + dataset_location, + epochs, + batch_size, + img_size, + learning_rate, + optimizer, + data_augmentation_level, + pre_trained_model, + custom_model_path, + custom_model_name, + show_debug, + model_config, + selected_architecture, + selected_category +): + """ + [TRAINING] Train a model (e.g., YOLO, RT-DETR) with the given parameters, + supporting GPU detection and half precision by default. + Includes new hyperparameters and task-specific settings. + """ + st.info(f"[TRAINING] Training model '{selected_architecture}' for task: '{selected_category}'...") + logging.info("[TRAINING] Initiating model training.") + + train_data_path = os.path.join(dataset_location, "data.yaml") + if not os.path.exists(train_data_path): + st.error("data.yaml not found. Please ensure the dataset is prepared correctly.") + logging.error("[TRAINING] data.yaml not found.") + return None + + if pre_trained_model == "Custom": + if not custom_model_path or not os.path.exists(custom_model_path): + st.error("Please provide a valid path for the custom pre-trained model.") + logging.error("[TRAINING] Custom model path invalid.") + return None + pre_trained_weights = custom_model_path + else: + model_entry = None + sub_models = model_config[selected_architecture]["categories"].get(selected_category, []) + for m in sub_models: + if m["filename"] == pre_trained_model: + model_entry = m + break + if model_entry is None: + st.error(f"Model '{pre_trained_model}' not found under '{selected_architecture}/{selected_category}'.") + return None + + downloaded_path = cache_download_pretrained_model(pre_trained_model, model_entry["url"]) + if not downloaded_path: + st.error(f"Failed to download the pre-trained model '{pre_trained_model}'.") + return None + pre_trained_weights = downloaded_path + + device = "cuda" if torch.cuda.is_available() else "cpu" + if device == "cuda": + st.write("Using **GPU** (CUDA) for training.") + logging.info("[TRAINING] Using GPU (CUDA).") + else: + st.write("Using **CPU** for training (this may be slower).") + logging.info("[TRAINING] Using CPU.") + + use_half_precision = st.session_state.get("use_half_precision", True) + if use_half_precision and device == "cuda": + st.write("Training with **half-precision (FP16)** on GPU.") + logging.info("[TRAINING] Using half precision (FP16).") + elif use_half_precision and device == "cpu": + st.warning("Half-precision was requested but no GPU found. Skipping half-precision.") + logging.info("[TRAINING] Skipping half precision due to CPU-only environment.") + use_half_precision = False + + try: + # The ultralytics library uses the YOLO class to handle various models, including RT-DETR. + # When a model like 'rtdetr-l.pt' is passed, it loads the correct architecture. + model = YOLO(pre_trained_weights) + logging.info(f"[TRAINING] Loaded pre-trained model: {pre_trained_weights}") + except Exception as e: + st.error(f"Failed to load the pre-trained model '{pre_trained_weights}': {e}") + logging.error(f"[TRAINING] Failed to load the pre-trained model '{pre_trained_weights}': {e}") + return None + + TASK_MAPPING = { + "detection": "detect", + "segmentation": "segment", + "obb": "obb", + "classification": "classify", + "keypoint": "pose" + } + yolo_task = TASK_MAPPING.get(selected_category, "detect") + + current_img_size = st.session_state["img_size"] + if selected_category == "classification" and current_img_size == 640: + current_img_size = 224 + st.info(f"Using recommended image size {current_img_size} for classification.") + elif selected_category == "keypoint" and current_img_size == 640: + current_img_size = 640 + st.info(f"Using recommended image size {current_img_size} for keypoint detection.") + + + augmentations = {} + if data_augmentation_level == "None": + augmentations = { + 'degrees': 0.0, + 'translate': 0.0, + 'scale': 0.0, + 'shear': 0.0, + 'fliplr': 0.0, + 'mixup': 0.0, + 'perspective': 0.0, + } + elif data_augmentation_level == "Basic": + augmentations = { + 'degrees': 5.0, + 'translate': 0.1, + 'scale': 0.1, + 'fliplr': 0.5 + } + elif data_augmentation_level == "Moderate": + augmentations = { + 'degrees': 15.0, + 'translate': 0.2, + 'scale': 0.3, + 'shear': 2.0, + 'fliplr': 0.5, + 'mixup': 0.2 + } + elif data_augmentation_level == "Advanced": + augmentations = { + 'degrees': 30.0, + 'translate': 0.3, + 'scale': 0.5, + 'shear': 5.0, + 'fliplr': 0.5, + 'mixup': 0.4, + 'perspective':0.1 + } + + augmentations['hsv_h'] = st.session_state["hsv_h"] + augmentations['hsv_s'] = st.session_state["hsv_s"] + augmentations['hsv_v'] = st.session_state["hsv_v"] + + + train_args = { + 'data': train_data_path, + 'epochs': epochs, + 'batch': batch_size, + 'imgsz': current_img_size, + 'lr0': learning_rate, + 'optimizer':optimizer.lower(), + 'workers': 4, + 'project': 'runs/train', + 'name': custom_model_name, + 'augment': (data_augmentation_level != "None"), # Enable/disable augmentation based on selection + 'task': yolo_task, + 'device': device, + 'half': use_half_precision, + 'lrf': st.session_state['lrf'], + 'warmup_epochs': st.session_state['warmup_epochs'], + 'cache': st.session_state['cache_dataset'] + } + if data_augmentation_level != "None": # Only update with augs if not 'None' + train_args.update(augmentations) + + if st.session_state['freeze_layers'] > 0: + train_args['freeze'] = st.session_state['freeze_layers'] + logging.info(f"[TRAINING] Freezing {st.session_state['freeze_layers']} layers.") + + model.add_callback("on_train_epoch_end", on_train_epoch_end) + + # Clear previous metrics and reset early stopping state for a fresh training run + st.session_state['metrics_data'] = { + 'epoch': [], 'train_loss': [], 'val_loss': [], + 'mAP50': [], 'mAP50_95': [], + 'top1_acc': [], 'top5_acc': [], + 'mAPpose50': [], 'mAPpose50_95': [], + 'precision': [], 'recall': [], 'F1': [] + } + st.session_state['best_val_loss'] = float('inf') + st.session_state['epochs_no_improvement'] = 0 + st.session_state['is_rerunning'] = True # Set to True, will be reset to False after first epoch callback + + try: + logging.info(f"[TRAINING] Training arguments: {train_args}") + model.train(**train_args) + logging.info("[TRAINING] Model training completed successfully.") + return model + except Exception as e: + st.error(f"An error occurred during training: {e}") + logging.error(f"[TRAINING] An error occurred during training: {e}") + if show_debug: + st.error(f"Training Error Details: {e}") + return None + +# ---------------------------------------------------------------------------- +# MERGING + PREPARATION FOR MERGED DATASET +# ---------------------------------------------------------------------------- + +def process_label_file(args): + """ + [MERGE] Worker function for parallel processing of label files. + Reads each label file, parses classes, and returns relevant info. + """ + label_path, class_names_dataset, class_name_mapping, unified_class_names = args + try: + with open(label_path, 'r') as f: + lines = f.readlines() + except Exception as e: + return None, f"[MERGE] Failed to read label file '{label_path}': {e}" + + image_filename = os.path.basename(label_path).replace('.txt', '.jpg') + image_classes_set = set() + for line in lines: + parts = line.strip().split() + if len(parts) < 1: + continue + try: + class_id = int(parts[0]) + if class_id < 0 or class_id >= len(class_names_dataset): + continue + original_class_name = class_names_dataset[class_id] + new_class_name = class_name_mapping.get(original_class_name, original_class_name) + if new_class_name in unified_class_names: + image_classes_set.add(new_class_name) + except (ValueError, IndexError): + continue + return (image_filename, image_classes_set), None + +def gather_class_counts(dataset_info_list, class_name_mapping, selected_category): + """ + [MERGE] Gathers how many images (for classification) or how many images contain at least one + annotation for a given class (for others). + Returns a dictionary like {'class': count, ...}. + """ + all_class_names = [] + for _, class_names, _, _ in dataset_info_list: + all_class_names.extend(class_names) + all_class_names = sorted(set(all_class_names)) + + unified_class_names = set() + for class_name in all_class_names: + new_name = class_name_mapping.get(class_name, class_name) + unified_class_names.add(new_name) + unified_class_names = sorted(list(unified_class_names)) + + class_counters = {cls: 0 for cls in unified_class_names} + + if selected_category == "classification": + for dataset_location, _, splits, _ in dataset_info_list: + for split_key in splits: + split_dir = os.path.join(dataset_location, split_key) + if not os.path.exists(split_dir): + continue + for class_folder_name in os.listdir(split_dir): + original_class_path = os.path.join(split_dir, class_folder_name) + if not os.path.isdir(original_class_path): + continue + + unified_class_name = class_name_mapping.get(class_folder_name, class_folder_name) + if unified_class_name in class_counters: + class_counters[unified_class_name] += len([f for f in os.listdir(original_class_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]) + else: + for dataset_location, class_names_dataset, splits, dataset_name in dataset_info_list: + for split_key in splits: + labels_src = os.path.join(dataset_location, split_key, 'labels') + if not os.path.exists(labels_src): + continue + for root, dirs, files in os.walk(labels_src): + for file in files: + if file.endswith('.txt'): + label_path = os.path.join(root, file) + try: + with open(label_path, 'r') as f: + lines = f.readlines() + except: + continue + + image_classes_found_in_file = set() + for line in lines: + parts = line.strip().split() + if len(parts) < 1: continue + try: + class_id = int(parts[0]) + if 0 <= class_id < len(class_names_dataset): + orig_cls = class_names_dataset[class_id] + new_cls = class_name_mapping.get(orig_cls, orig_cls) + if new_cls in unified_class_names: + image_classes_found_in_file.add(new_cls) + except: + continue + for cls in image_classes_found_in_file: + class_counters[cls] += 1 + + return class_counters + +def merge_classification_datasets(dataset_info_list, classes_to_include, class_name_mapping, show_debug): + """ + [MERGE - CLASSIFICATION] Merges multiple classification datasets into 'rolo_dataset', + copying images into new, class-specific subfolders. + """ + merged_dataset_dir = 'rolo_dataset' + if os.path.exists(merged_dataset_dir): + st.warning(f"[MERGE] Output directory '{merged_dataset_dir}' already exists. Overwriting...") + try: + shutil.rmtree(merged_dataset_dir, onerror=handle_remove_readonly) + except Exception as e: + if show_debug: st.error(f"[MERGE] Error deleting directory '{merged_dataset_dir}': {e}") + logging.error(f"[MERGE] Error deleting directory '{merged_dataset_dir}': {e}") + return None + + all_original_class_names = set() + for _, class_names, _, _ in dataset_info_list: + all_original_class_names.update(class_names) + + unified_class_names = set() + for class_name in all_original_class_names: + new_name = class_name_mapping.get(class_name, class_name) + unified_class_names.add(new_name) + unified_class_names = sorted(list(unified_class_names)) + + class_image_counts = {cls: 0 for cls in unified_class_names} + images_per_class_src = {cls: [] for cls in unified_class_names} + + for dataset_location, class_names_dataset, splits, dataset_name in dataset_info_list: + for split_key in splits: + split_dir = os.path.join(dataset_location, split_key) + if not os.path.exists(split_dir): + continue + for class_folder_name in os.listdir(split_dir): + original_class_path = os.path.join(split_dir, class_folder_name) + if not os.path.isdir(original_class_path): + continue + + unified_class_name = class_name_mapping.get(class_folder_name, class_folder_name) + + total_limit_for_unified_class = 0 + for orig_cls_key, limit in classes_to_include.items(): + if class_name_mapping.get(orig_cls_key, orig_cls_key) == unified_class_name: + total_limit_for_unified_class += limit + + if total_limit_for_unified_class == 0: + continue + + for img_file in os.listdir(original_class_path): + if img_file.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff')): + img_path_src = os.path.join(original_class_path, img_file) + images_per_class_src[unified_class_name].append((img_path_src, split_key)) + + st.subheader("Dataset Statistics & Insights") + st.write("### Available Images per Class (Pre-Adjustment):") + for cls in unified_class_names: + st.write(f"- **{cls}**: {len(images_per_class_src[cls])} images") + + fig, ax = plt.subplots(figsize=(6, 3)) + class_counts_list = [len(images_per_class_src[cls]) for cls in unified_class_names] + ax.bar(unified_class_names, class_counts_list) + ax.set_title("Class Distribution (All Datasets)") + ax.set_xlabel("Class Names") + ax.set_ylabel("Image Count") + plt.xticks(rotation=45, ha='right') + st.pyplot(fig) + + + selected_images_per_class = {cls: [] for cls in unified_class_names} + final_class_counts = {cls: 0 for cls in unified_class_names} + adjusted_limits_for_unified_classes = {} + + for unified_cls_name in unified_class_names: + combined_limit = 0 + for orig_cls_key, limit in classes_to_include.items(): + if class_name_mapping.get(orig_cls_key, orig_cls_key) == unified_cls_name: + combined_limit += limit + + max_available = len(images_per_class_src[unified_cls_name]) + final_limit = min(combined_limit, max_available) if combined_limit > 0 else 0 + adjusted_limits_for_unified_classes[unified_cls_name] = final_limit + + random.shuffle(images_per_class_src[unified_cls_name]) + selected_images_per_class[unified_cls_name] = images_per_class_src[unified_cls_name][:final_limit] + final_class_counts[unified_cls_name] = len(selected_images_per_class[unified_cls_name]) + + st.session_state['total_class_images'] = final_class_counts + st.session_state['class_image_counts'] = final_class_counts + + for split in ['train', 'valid', 'test']: + for cls in unified_class_names: + if final_class_counts[cls] > 0: + os.makedirs(os.path.join(merged_dataset_dir, split, cls), exist_ok=True) + logging.info(f"[MERGE] Created directory '{os.path.join(merged_dataset_dir, split, cls)}'.") + + for unified_cls_name, images_info in selected_images_per_class.items(): + for img_path_src, split_key in images_info: + img_filename = os.path.basename(img_path_src) + img_path_dst = os.path.join(merged_dataset_dir, split_key, unified_cls_name, img_filename) + try: + shutil.copy(img_path_src, img_path_dst) + except Exception as e: + if show_debug: st.error(f"[MERGE] Error copying image {img_filename}: {e}") + logging.error(f"[MERGE] Error copying image {img_filename}: {e}") + + active_final_classes = [cls for cls, count in final_class_counts.items() if count > 0] + merged_data_yaml = { + 'path': os.path.abspath(merged_dataset_dir), + 'train': 'train', + 'val': 'valid', + 'test': 'test', + 'nc': len(active_final_classes), + 'names': active_final_classes + } + with open(os.path.join(merged_dataset_dir, 'data.yaml'), 'w') as f: + yaml.safe_dump(merged_data_yaml, f) + + st.session_state['class_counters'] = final_class_counts + st.session_state['active_classes'] = active_final_classes + st.session_state['class_limits'] = adjusted_limits_for_unified_classes + st.session_state['class_id_mapping'] = {name: idx for idx, name in enumerate(st.session_state['active_classes'])} + st.session_state['class_name_mapping'] = class_name_mapping + st.session_state['dataset_info_list'] = dataset_info_list + st.session_state['merged_dataset_dir'] = merged_dataset_dir + + st.write("### Selected Images Count per Class (After Initial Merging):") + for cls in unified_class_names: + cnt = final_class_counts.get(cls, 0) + limit = adjusted_limits_for_unified_classes.get(cls, 0) + st.write(f"- **{cls}**: {cnt} selected (Limit: {limit})") + + st.write("### Merged Classes:") + for old_class, new_class in class_name_mapping.items(): + if old_class != new_class: + st.write(f"- **{old_class}** merged into **{new_class}**") + + if len(active_final_classes) == 0: + st.error("WARNING: No classes or images remain after merging/excluding. Please adjust your class limits or rename logic.") + return merged_dataset_dir + +def merge_datasets(dataset_info_list, classes_to_include, class_name_mapping, show_debug): + """ + [MERGE] Merges multiple datasets into 'rolo_dataset', respecting user-specified class limits. + Handles detection, segmentation, OBB, and keypoint data. + """ + merged_dataset_dir = 'rolo_dataset' + + if os.path.exists(merged_dataset_dir): + st.warning(f"[MERGE] Output directory '{merged_dataset_dir}' already exists. Overwriting...") + logging.warning(f"[MERGE] Output directory '{merged_dataset_dir}' already exists. Overwriting...") + try: + shutil.rmtree(merged_dataset_dir, onerror=handle_remove_readonly) + logging.info(f"[MERGE] Existing directory '{merged_dataset_dir}' deleted successfully.") + except Exception as e: + if show_debug: + st.error(f"[MERGE] Error deleting directory '{merged_dataset_dir}': {e}") + logging.error(f"[MERGE] Error deleting directory '{merged_dataset_dir}': {e}") + return None + + try: + os.makedirs(merged_dataset_dir, exist_ok=True) + logging.info(f"[MERGE] Merged dataset directory '{merged_dataset_dir}' created.") + except Exception as e: + if show_debug: + st.error(f"[MERGE] Error creating directory '{merged_dataset_dir}': {e}") + logging.error(f"[MERGE] Error creating directory '{merged_dataset_dir}': {e}") + return None + + all_class_names = [] + for _, class_names, _, _ in dataset_info_list: + all_class_names.extend(class_names) + all_class_names = sorted(set(all_class_names)) + + unified_class_names = set() + for class_name in all_class_names: + new_name = class_name_mapping.get(class_name, class_name) + unified_class_names.add(new_name) + unified_class_names = sorted(list(unified_class_names)) + + class_limits = {name: 0 for name in unified_class_names} + for original_class, user_chosen_limit in classes_to_include.items(): + new_name = class_name_mapping.get(original_class, original_class) + if new_name in class_limits: + class_limits[new_name] += user_chosen_limit + + class_counters = gather_class_counts(dataset_info_list, class_name_mapping, st.session_state['selected_category']) + st.session_state['class_image_counts'] = class_counters + + total_class_images = {cls: count for cls, count in class_counters.items()} + st.session_state['total_class_images'] = total_class_images + + for cls in class_limits: + max_available = total_class_images.get(cls, 0) + if class_limits[cls] > max_available: + class_limits[cls] = max_available + if show_debug: + st.warning(f"[MERGE] Class '{cls}' limit adjusted to available images: {max_available}") + + active_classes = [cls for cls, limit in class_limits.items() if limit > 0] + + class_id_mapping = {name: idx for idx, name in enumerate(active_classes)} + + class_to_images = {name: set() for name in unified_class_names} + image_to_classes = {} + image_to_label_path = {} + + args_list = [] + for dataset_location, class_names_dataset, splits, dataset_name in dataset_info_list: + for split_key in splits: + labels_src = os.path.join(dataset_location, split_key, 'labels') + if not os.path.exists(labels_src): + continue + label_files = [ + os.path.join(root, f) + for root, dirs, files in os.walk(labels_src) + for f in files if f.endswith('.txt') + ] + for lf in label_files: + args_list.append((lf, class_names_dataset, class_name_mapping, unified_class_names,)) + + max_workers = min(32, (multiprocessing.cpu_count() or 1) + 4) + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + future_to_args = {executor.submit(process_label_file, args): args for args in args_list} + for future in concurrent.futures.as_completed(future_to_args): + res, err = future.result() + if err: + logging.warning(f"[MERGE] Warning: {err}") + continue + if not res: + continue + + image_filename, image_classes_set = res + lf_args = future_to_args[future] + label_path = lf_args[0] + images_src = label_path.replace('labels', 'images').replace('.txt', '.jpg') + if not os.path.exists(images_src): + alt_path = images_src.replace('.jpg', '.png') + if os.path.exists(alt_path): + images_src = alt_path + else: + continue + + image_to_classes[images_src] = image_classes_set + image_to_label_path[images_src] = label_path + for cls in image_classes_set: + class_to_images[cls].add(images_src) + + st.subheader("Dataset Statistics & Insights") + st.write("### Available Images per Class:") + for cls in unified_class_names: + st.write(f"- **{cls}**: {len(class_to_images.get(cls, []))} images") + + fig, ax = plt.subplots(figsize=(6, 3)) + class_counts_list = [len(class_to_images.get(cls, [])) for cls in unified_class_names] + ax.bar(unified_class_names, class_counts_list) + ax.set_title("Class Distribution (All Datasets)") + ax.set_xlabel("Class Names") + ax.set_ylabel("Image Count") + plt.xticks(rotation=45, ha='right') + st.pyplot(fig) + + inconsistent_images = [img for img, cls_set in image_to_classes.items() if not cls_set] + if inconsistent_images: + st.warning(f"[MERGE] {len(inconsistent_images)} images have no valid annotations (excluded classes only).") + + selected_images = set() + current_image_class_counts = {cls: 0 for cls in active_classes} + + all_candidate_images = [] + for cls in active_classes: + all_candidate_images.extend(list(class_to_images[cls])) + random.shuffle(all_candidate_images) + + for img_path in all_candidate_images: + if img_path in selected_images: continue + + classes_in_this_image = image_to_classes.get(img_path, set()) + + would_exceed_limit = False + for c in classes_in_this_image: + if c in active_classes and current_image_class_counts[c] >= class_limits[c]: + would_exceed_limit = True + break + + if would_exceed_limit: + continue + + selected_images.add(img_path) + for c in classes_in_this_image: + if c in active_classes: + current_image_class_counts[c] += 1 + + st.session_state['selected_images'] = selected_images + st.session_state['class_counters'] = current_image_class_counts + st.session_state['active_classes'] = active_classes + st.session_state['class_limits'] = class_limits + st.session_state['class_id_mapping'] = class_id_mapping + st.session_state['class_name_mapping'] = class_name_mapping + st.session_state['image_to_classes'] = image_to_classes + st.session_state['image_to_label_path'] = image_to_label_path + st.session_state['class_to_images'] = class_to_images + st.session_state['dataset_info_list'] = dataset_info_list + st.session_state['merged_dataset_dir'] = merged_dataset_dir + + class_image_counts = {} + for cls in unified_class_names: + class_image_counts[cls] = current_image_class_counts.get(cls, 0) + st.session_state['class_image_counts'] = class_image_counts + + datasets_used = list(sorted(set([info[3] for info in st.session_state['dataset_info_list']]))) + st.session_state['datasets_used'] = datasets_used + + st.write("### Selected Images Count per Class (Before Adjustment):") + for cls in unified_class_names: + cnt = current_image_class_counts.get(cls, 0) + limit = class_limits.get(cls, 0) + st.write(f"- **{cls}**: {cnt} selected (Limit: {limit})") + + st.write("### Merged Classes:") + for old_class, new_class in class_name_mapping.items(): + if old_class != new_class: + st.write(f"- **{old_class}** merged into **{new_class}**") + + if len(selected_images) == 0: + st.error("WARNING: No images remain after merging/excluding. Please adjust your class limits or rename logic.") + return merged_dataset_dir + +def adjust_selected_images(): + """ + [MERGE] Allows user to manually adjust the number of selected images per class + using st.number_input. Then recalculates which images to keep. + """ + if st.session_state['selected_category'] == "classification": + st.info("Manual image adjustment is not currently supported for Classification datasets. The class limits were applied during the merge step.") + st.session_state['adjustment_confirmed'] = True + return + + st.header("Adjust Selected Images per Class") + st.markdown("You can manually adjust the number of selected images per class if desired.") + + adjusted_class_limits_input = {} + with st.form("adjustment_form"): + for cls in st.session_state['active_classes']: + current_selected_images = st.session_state['class_counters'].get(cls, 0) + original_input_limit = st.session_state['class_limits'].get(cls, 0) + total_available_images_for_class = st.session_state['total_class_images'].get(cls, 0) + + adjusted_val = st.number_input( + f"Adjust images for '{cls}' (Currently Selected: {current_selected_images}, Originally Targeted: {original_input_limit})", + min_value=0, + max_value=total_available_images_for_class, + value=current_selected_images, + step=1, + key=f"adjusted_count_{cls}" + ) + adjusted_class_limits_input[cls] = adjusted_val + + st.info("The final count for each class will be adjusted to the amount you specify, up to the total available images after previous filtering.") + confirm_adjustment = st.form_submit_button("Confirm Adjustments") + + if confirm_adjustment: + new_selected_images = set() + new_class_image_counts = {cls: 0 for cls in st.session_state['active_classes']} + + all_images_with_active_classes = list(st.session_state['image_to_classes'].keys()) + random.shuffle(all_images_with_active_classes) + + for img_path in all_images_with_active_classes: + classes_in_this_image = st.session_state['image_to_classes'].get(img_path, set()) + + can_add_image = False + for c in classes_in_this_image: + if c in st.session_state['active_classes'] and new_class_image_counts[c] < adjusted_class_limits_input[c]: + can_add_image = True + break + + if not can_add_image: + continue + + new_selected_images.add(img_path) + for c in classes_in_this_image: + if c in st.session_state['active_classes']: + new_class_image_counts[c] += 1 + + st.session_state['selected_images'] = new_selected_images + st.session_state['class_counters'] = new_class_image_counts + st.session_state['adjustment_confirmed'] = True + + for cls in st.session_state['active_classes']: + st.session_state['class_image_counts'][cls] = new_class_image_counts[cls] + + st.success("Adjustments confirmed. Proceed to finalize the merged dataset.") + st.session_state['workflow_step'] = "Finalize Dataset" + +def finalize_merged_dataset(): + """ + [MERGE] Copies selected images and corresponding labels into a new merged dataset directory, + respecting the user's class mappings. Creates a final 'data.yaml' consistent with only + the included classes. + """ + if st.session_state['selected_category'] == "classification": + st.info("Classification dataset was finalized during the merge step. Ready for training!") + st.session_state['dataset_location'] = st.session_state['merged_dataset_dir'] + st.session_state['dataset_finalized'] = True + return + + st.info("[FINALIZE] Finalizing the merged dataset...") + selected_images = st.session_state['selected_images'] + active_classes = st.session_state['active_classes'] + class_id_mapping = st.session_state['class_id_mapping'] + class_name_mapping = st.session_state['class_name_mapping'] + image_to_classes = st.session_state['image_to_classes'] + image_to_label_path = st.session_state['image_to_label_path'] + dataset_info_list = st.session_state['dataset_info_list'] + merged_dataset_dir = st.session_state['merged_dataset_dir'] + + splits_dirs = ['train', 'valid', 'test'] + for split in splits_dirs: + images_dst = os.path.join(merged_dataset_dir, split, 'images') + labels_dst = os.path.join(merged_dataset_dir, split, 'labels') + try: + os.makedirs(images_dst, exist_ok=True) + os.makedirs(labels_dst, exist_ok=True) + except Exception as e: + st.error(f"[FINALIZE] Error creating split directories: {e}") + logging.error(f"[FINALIZE] Error creating split directories: {e}") + + for img_path in selected_images: + split_key = 'train' + for d_info in dataset_info_list: + d_loc = d_info[0] + splits_in_dataset = d_info[2] + for s in splits_in_dataset: + split_images_dir = os.path.join(d_loc, s, 'images') + if img_path.startswith(split_images_dir): + split_key = s + break + if split_key != 'train': + break + + image_filename = os.path.basename(img_path) + label_filename = os.path.splitext(image_filename)[0] + '.txt' + label_src_path = image_to_label_path.get(img_path) + + image_dst_path = os.path.join(merged_dataset_dir, split_key, 'images', image_filename) + label_dst_path = os.path.join(merged_dataset_dir, split_key, 'labels', label_filename) + + shutil.copy(img_path, image_dst_path) + + if label_src_path and os.path.exists(label_src_path): + with open(label_src_path, 'r') as f: + lines = f.readlines() + + updated_lines = [] + for line in lines: + parts = line.strip().split() + if len(parts) < 1: + continue + try: + old_class_id = int(parts[0]) + old_class_name = None + + for d_loc2, class_names_dataset, _, _ in dataset_info_list: + if img_path.startswith(d_loc2): + if 0 <= old_class_id < len(class_names_dataset): + old_class_name = class_names_dataset[old_class_id] + break + + if old_class_name is None: + continue + + new_class_name = class_name_mapping.get(old_class_name, old_class_name) + if new_class_name not in active_classes: + continue + new_class_id = class_id_mapping.get(new_class_name, None) + if new_class_id is None: + continue + + coords = parts[1:] + updated_line = f"{new_class_id} " + " ".join(coords) + updated_lines.append(updated_line) + except Exception as e: + if st.session_state["show_debug"]: + st.warning(f"Error re-indexing line '{line.strip()}' for image {image_filename}: {e}") + logging.warning(f"Error re-indexing line '{line.strip()}' for image {image_filename}: {e}") + continue + + if not updated_lines: + if os.path.exists(image_dst_path): + os.remove(image_dst_path) + if os.path.exists(label_dst_path): + os.remove(label_dst_path) + else: + with open(label_dst_path, 'w') as f: + f.write("\n".join(updated_lines)) + + merged_data_yaml = { + 'path': os.path.abspath(merged_dataset_dir), + 'train': 'train/images', + 'val': 'valid/images', + 'test': 'test/images', + 'nc': len(active_classes), + 'names': active_classes + } + + if st.session_state['selected_category'] == "keypoint" and st.session_state['kpt_shape'] is not None: + merged_data_yaml['kpt_shape'] = st.session_state['kpt_shape'] + st.info(f"Keypoint dataset detected. Adding kpt_shape: {st.session_state['kpt_shape']} to data.yaml") + + with open(os.path.join(merged_dataset_dir, 'data.yaml'), 'w') as f: + yaml.safe_dump(merged_data_yaml, f) + + st.success(f"[FINALIZE] Final merged dataset is ready at '{merged_dataset_dir}'.") + st.session_state['dataset_location'] = merged_dataset_dir + st.session_state['dataset_finalized'] = True + + +# ---------------------------------------------------------------------------- +# UTILITY: ZIP THE MERGED DATASET FOR DOWNLOAD +# ---------------------------------------------------------------------------- +def zip_directory(source_dir): + """ + Utility function to zip an entire folder into memory (BytesIO) + so we can offer a download in Streamlit. + """ + buffer = io.BytesIO() + with zipfile.ZipFile(buffer, "w", zipfile.ZIP_DEFLATED) as zipf: + for root, _, files in os.walk(source_dir): + for file in files: + file_path = os.path.join(root, file) + arcname = os.path.relpath(file_path, start=source_dir) + zipf.write(file_path, arcname) + buffer.seek(0) + return buffer + +# ---------------------------------------------------------------------------- +# GITHUB & HUGGING FACE UTILITIES (UNCHANGED) +# ---------------------------------------------------------------------------- + +def create_github_repo(username, token, repo_name, private=False): + """ + [UPLOAD] Attempts to create a new GitHub repo under the specified username/org. + """ + url = "https://api.github.com/user/repos" + headers = { + "Authorization": f"token {token}", + "Accept": "application/vnd.github.v3+json", + } + data = {"name": repo_name, "private": private} + response = requests.post(url, headers=headers, json=data) + return response.status_code, response.json() + +def get_github_repo_contents(username, repo_name, token): + """ + Retrieves the contents of a GitHub repository. + """ + url = f"https://api.github.com/repos/{username}/{repo_name}/contents/" + headers = { + "Authorization": f"token {token}", + "Accept": "application/vnd.github.v3+json", + } + response = requests.get(url, headers=headers) + return response.status_code, response.json() + +def list_github_files(owner, repo, path, token=None): + """ + Lists the files in a specified GitHub repository folder. + Returns a list of file names. + """ + api_url = f"https://api.github.com/repos/{owner}/{repo}/contents/{path}" + headers = {} + if token: + headers['Authorization'] = f'token {token}' + response = requests.get(api_url, headers=headers) + if response.status_code == 200: + contents = response.json() + return [item['name'] for item in contents if item['type'] == 'file'] + elif response.status_code == 404: + return [] + else: + st.error(f"GitHub API error: {response.status_code}") + return [] + +def fetch_github_file(owner, repo, path, token=None): + """ + Fetches the raw content of a file from GitHub. Returns the file content as text. + NOTE: The example templates are for *detection tasks* only. + """ + api_url = f"https://api.github.com/repos/{owner}/{repo}/contents/{path}" + headers = { + "Accept": "application/vnd.github.v3.raw" + } + if token: + headers['Authorization'] = f'token {token}' + response = requests.get(api_url, headers=headers) + if response.status_code == 200: + return response.text + else: + st.error(f"Failed to fetch {path} from GitHub. Status Code: {response.status_code}") + return None + +def customize_template(template_content, model_url, selected_category, metrics_data, class_id_mapping, datasets_used, class_image_counts): + """ + Replaces placeholders in the template with the actual model URL and training details. + """ + # 1) Compute final metrics (improved robustness with .get and length check) + final_map50 = final_map50_95 = final_top1_acc = final_top5_acc = final_mappose50 = final_mappose50_95 = "N/A" + final_prec = final_rec = final_f1 = "N/A" + + if selected_category in ('detection', 'segmentation', 'obb'): + if metrics_data.get('mAP50') and len(metrics_data['mAP50']) > 0: + final_map50 = f"{metrics_data['mAP50'][-1]:.4f}" + if metrics_data.get('mAP50_95') and len(metrics_data['mAP50_95']) > 0: + final_map50_95 = f"{metrics_data['mAP50_95'][-1]:.4f}" + elif selected_category == 'classification': + if metrics_data.get('top1_acc') and len(metrics_data['top1_acc']) > 0: + final_top1_acc = f"{metrics_data['top1_acc'][-1]:.4f}" + if metrics_data.get('top5_acc') and len(metrics_data['top5_acc']) > 0: + final_top5_acc = f"{metrics_data['top5_acc'][-1]:.4f}" + elif selected_category == 'keypoint': + if metrics_data.get('mAPpose50') and len(metrics_data['mAPpose50']) > 0: + final_mappose50 = f"{metrics_data['mAPpose50'][-1]:.4f}" + if metrics_data.get('mAPpose50_95') and len(metrics_data['mAPpose50_95']) > 0: + final_mappose50_95 = f"{metrics_data['mAPpose50_95'][-1]:.4f}" + + # Get F1, Precision, Recall if available (assuming they are always appended regardless of category) + if metrics_data.get('precision') and len(metrics_data['precision']) > 0: + final_prec = f"{metrics_data['precision'][-1]:.4f}" + if metrics_data.get('recall') and len(metrics_data['recall']) > 0: + final_rec = f"{metrics_data['recall'][-1]:.4f}" + if metrics_data.get('F1') and len(metrics_data['F1']) > 0: + final_f1 = f"{metrics_data['F1'][-1]:.4f}" + + placeholders = { + "{{MODEL_URL}}": model_url, + "{{MODEL_ARCH}}": st.session_state.get("selected_architecture", "N/A"), + "{{EPOCHS}}": str(st.session_state.get("epochs", "N/A")), + "{{BATCH_SIZE}}": str(st.session_state.get("batch_size", "N/A")), + "{{OPTIMIZER}}": st.session_state.get("optimizer", "N/A"), + "{{LEARNING_RATE}}": f"{st.session_state.get('learning_rate', 0.0):.5f}", + "{{DATA_AUG}}": st.session_state.get("data_augmentation_level", "N/A"), + "{{MODEL_NAME}}": st.session_state.get("custom_model_name", "N/A"), + "{{TASK_TYPE}}": selected_category.capitalize(), + "{{LRF}}": f"{st.session_state.get('lrf', 0.0):.3f}", + "{{WARMUP_EPOCHS}}": str(st.session_state.get('warmup_epochs', 'N/A')), + "{{HSV_H}}": f"{st.session_state.get('hsv_h', 0.0):.3f}", + "{{HSV_S}}": f"{st.session_state.get('hsv_s', 0.0):.2f}", + "{{HSV_V}}": f"{st.session_state.get('hsv_v', 0.0):.2f}", + # --- ADDED: Individual Metric Placeholders --- + "{{FINAL_MAP50}}": final_map50, + "{{FINAL_MAP5095}}": final_map50_95, + "{{FINAL_TOP1_ACC}}": final_top1_acc, + "{{FINAL_TOP5_ACC}}": final_top5_acc, + "{{FINAL_MAPPOSE50}}": final_mappose50, + "{{FINAL_MAPPOSE5095}}": final_mappose50_95, + "{{FINAL_PRECISION}}": final_prec, + "{{FINAL_RECALL}}": final_rec, + "{{FINAL_F1}}": final_f1, + # --- END ADDED --- + } + + # --- UPDATED: Construct FINAL_METRICS block using the computed values --- + final_metrics_lines = [] + if selected_category in ('detection', 'segmentation', 'obb'): + final_metrics_lines.append(f"- **mAP@0.5:** {final_map50}") + final_metrics_lines.append(f"- **mAP@0.5:0.95:** {final_map50_95}") + elif selected_category == 'classification': + final_metrics_lines.append(f"- **Top-1 Accuracy:** {final_top1_acc}") + final_metrics_lines.append(f"- **Top-5 Accuracy:** {final_top5_acc}") + elif selected_category == 'keypoint': + final_metrics_lines.append(f"- **mAPpose@0.5:** {final_mappose50}") + final_metrics_lines.append(f"- **mAPpose@0.5:0.95:** {final_mappose50_95}") + + # Append F1, Precision, Recall if available (can be combined with other metrics) + # These are general metrics that can apply across tasks, so append if they have data + if final_prec != "N/A": + final_metrics_lines.append(f"- **Precision:** {final_prec}") + if final_rec != "N/A": + final_metrics_lines.append(f"- **Recall:** {final_rec}") + if final_f1 != "N/A": + final_metrics_lines.append(f"- **F1 Score:** {final_f1}") + + if final_metrics_lines: + placeholders["{{FINAL_METRICS}}"] = "\n".join(final_metrics_lines) + else: + placeholders["{{FINAL_METRICS}}"] = "N/A (Metrics not available for this task type or not yet trained)" + # --- END UPDATED --- + + freeze_layers = st.session_state.get('freeze_layers', 0) + if freeze_layers > 0: + placeholders["{{FREEZE_LAYERS}}"] = f"{freeze_layers} layers frozen" + placeholders["{{FREEZE_LAYERS_TEXT}}"] = f"The initial {freeze_layers} layers of the model were frozen during training to leverage transfer learning." + else: + placeholders["{{FREEZE_LAYERS}}"] = "None" + placeholders["{{FREEZE_LAYERS_TEXT}}"] = "" + + cache_dataset = st.session_state.get('cache_dataset', False) + if cache_dataset: + placeholders["{{CACHE_DATASET}}"] = "Enabled (RAM)" + placeholders["{{CACHE_DATASET_TEXT}}"] = "The dataset was cached in RAM during training to speed up I/O operations." + else: + placeholders["{{CACHE_DATASET}}"] = "Disabled" + placeholders["{{CACHE_DATASET_TEXT}}"] = "Dataset caching was disabled during training." + + if class_id_mapping: + sorted_classes = sorted(class_id_mapping.items(), key=lambda item: item[1]) + class_ids_md = "| Class ID | Class Name |\n|----------|------------|\n" + for cls_name, cls_id in sorted_classes: + class_ids_md += f"| {cls_id} | {cls_name} |\n" + else: + class_ids_md = "No class IDs available (or model not yet trained with class mapping)." + placeholders["{{CLASS_IDS}}"] = class_ids_md + + if datasets_used: + datasets_used_str = "\n".join([f"- {ds}" for ds in datasets_used]) + else: + datasets_used_str = "N/A (No datasets used yet or not tracked)." + placeholders["{{DATASETS_USED}}"] = datasets_used_str + + if class_image_counts: + count_type_label = "Images" + if selected_category == "classification": + count_type_label = "Images" + elif selected_category in ["detection", "segmentation", "obb", "keypoint"]: + count_type_label = "Images with Instances" + + class_counts_md = f"| Class Name | {count_type_label} Count |\n|------------|-------------|\n" + for cls, count in class_image_counts.items(): + class_counts_md += f"| {cls} | {count} |\n" + else: + class_counts_md = "No class image counts available (or model not yet trained)." + placeholders["{{CLASS_IMAGE_COUNTS}}"] = class_counts_md + + final_content = template_content + for key, val in placeholders.items(): + final_content = final_content.replace(key, str(val)) + + return final_content + +def get_download_link(content, filename, mime): + """ + Generates a download link for the given content. + """ + b64 = base64.b64encode(content.encode()).decode() + return f'Download {filename}' + +def generate_zip(generated_examples): + """ + Generates a ZIP file from the generated examples. Returns the ZIP as bytes. + """ + # Corrected to use a local BytesIO object + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zipf: + for fname, content in generated_examples.items(): + zipf.writestr(fname, content) + zip_buffer.seek(0) + return zip_buffer.read() + +def detect_model_location(): + """ + Attempts to detect if the model URL is on Huggingface or GitHub. + Sets 'model_location' in session_state accordingly. + """ + model_url = st.session_state.get('model_url', "").strip() + if not model_url: + st.warning("[UPLOAD] Model URL is empty. Please provide a valid URL.") + st.session_state['model_location'] = None + return + + parsed_url = urlparse(model_url) + if 'huggingface.co' in parsed_url.netloc: + st.session_state['model_location'] = 'huggingface' + elif 'github.com' in parsed_url.netloc: + st.session_state['model_location'] = 'github' + else: + st.error("[UPLOAD] Unsupported model URL. Please provide a Huggingface or GitHub URL.") + st.session_state['model_location'] = None + +# ---------------------------------------------------------------------------- +# MAIN APP LOGIC +# ---------------------------------------------------------------------------- + +st.title("Rolo: Roboflow & Ultralytics Combined Model Training Dashboard") + +steps = [ + "Prepare Datasets", + "Class Selection", + "Merge & Adjust", + "Finalize Dataset", + "Train Model", + "Upload Model", + "Generate Examples" +] + +current_workflow_step_index = 0 +if 'workflow_step' in st.session_state: + try: + current_workflow_step_index = steps.index(st.session_state['workflow_step']) + except ValueError: + current_workflow_step_index = 0 + +st.session_state['workflow_step'] = st.radio( + "Workflow Steps", + steps, + index=current_workflow_step_index +) + +opt_json_url = "https://raw.githubusercontent.com/Wuhpondiscord/models/main/optimizers.json" +optimizers_dict = load_optimizers(opt_json_url) +if not optimizers_dict: + st.warning("Could not load external optimizers.json. Fallback to default if needed.") + +available_optimizers = list(optimizers_dict.keys()) if optimizers_dict else ["Adam", "SGD", "RMSProp"] + +recommended_datasets_url = "https://raw.githubusercontent.com/Wuhpondiscord/models/main/recommended_datasets.json" +recommended_datasets = load_recommended_datasets(recommended_datasets_url) +if not recommended_datasets: + st.warning("Could not load external recommended_datasets.json. No recommended detection datasets available.") + +presets_url = "https://raw.githubusercontent.com/Wuhpondiscord/models/main/presets.json" +presets = load_presets(presets_url) + +# ---------------------------------------------------------------------------- +# SIDEBAR: SAVE / LOAD CONFIG +# ---------------------------------------------------------------------------- + +with st.sidebar.expander("Save / Load Configuration"): + if st.button("Save Current Configuration"): + config = { + "rf_api_key": st.session_state.get("rf_api_key", ""), + "epochs": st.session_state.get("epochs", 150), + "batch_size": st.session_state.get("batch_size", 32), + "img_size": st.session_state.get("img_size", 640), + "learning_rate": st.session_state.get("learning_rate", 0.0005), + "optimizer": st.session_state.get("optimizer", "Adam"), + "data_augmentation_level": st.session_state.get("data_augmentation_level", "Moderate"), + "custom_model_name": st.session_state.get("custom_model_name", "rolo_trained_model"), + "selected_architecture": st.session_state.get("selected_architecture", ""), + "selected_category": st.session_state.get("selected_category", ""), + "pre_trained_model": st.session_state.get("pre_trained_model", ""), + "custom_model_path": st.session_state.get("custom_model_path", ""), + "show_debug": st.session_state.get("show_debug", False), + "use_half_precision": st.session_state.get("use_half_precision", True), + "early_stop_patience": st.session_state.get("early_stop_patience", 10), + "lrf": st.session_state.get("lrf", 0.01), + "warmup_epochs": st.session_state.get("warmup_epochs", 3), + "cache_dataset": st.session_state.get("cache_dataset", False), + "freeze_layers": st.session_state.get("freeze_layers", 0), + "hsv_h": st.session_state.get("hsv_h", 0.015), + "hsv_s": st.session_state.get("hsv_s", 0.7), + "hsv_v": st.session_state.get("hsv_v", 0.4), + "train_ratio": st.session_state.get("train_ratio", 0.7), + "val_ratio": st.session_state.get("val_ratio", 0.2), + "test_ratio": st.session_state.get("test_ratio", 0.1), + } + config_json = json.dumps(config, indent=4) + + st.download_button( + label="Download Current Configuration", + data=config_json, + file_name="rolo_config.json", + mime="application/json" + ) + + uploaded_profile = st.file_uploader("Load Configuration Profile", type=["json"]) + if uploaded_profile is not None: + try: + config = json.load(uploaded_profile) + for key, value in config.items(): + if key in st.session_state: + st.session_state[key] = value + st.success("Configuration loaded successfully.") + except Exception as e: + st.error(f"Failed to load configuration: {e}") + +# ---------------------------------------------------------------------------- +# SIDEBAR: PRESETS +# ---------------------------------------------------------------------------- + +st.sidebar.header("Presets") +if presets: + preset_names = list(presets.keys()) + selected_preset = st.sidebar.selectbox("Load a Preset", ["None"] + preset_names) + if selected_preset != "None": + preset = presets[selected_preset] + for key, value in preset.items(): + st.session_state[key] = value + st.sidebar.success(f"Preset '{selected_preset}' loaded successfully.") +else: + st.sidebar.warning("Could not load presets from the external JSON.") + +# ---------------------------------------------------------------------------- +# SIDEBAR: ROBOFLOW CONFIG +# ---------------------------------------------------------------------------- + +st.sidebar.header("Roboflow Configuration") +rf_api_key = st.sidebar.text_input( + "Enter your Roboflow API Key", + type="password", + value=st.session_state.get("rf_api_key", ""), + help="Your secret Roboflow API Key." +) +st.session_state["rf_api_key"] = rf_api_key + +# ---------------------------------------------------------------------------- +# SIDEBAR: TRAINING SETTINGS +# ---------------------------------------------------------------------------- + +st.sidebar.header("Training Settings") +epochs = st.sidebar.slider( + "Number of Epochs", + min_value=1, + max_value=500, + value=st.session_state.get("epochs", 150), + help="How many epochs to train for?" +) +st.session_state["epochs"] = epochs + +batch_size_options = [8, 16, 32, 64, 128] +default_batch_size = st.session_state.get("batch_size", 32) +batch_size = st.sidebar.selectbox( + "Batch Size", + options=batch_size_options, + index=batch_size_options.index(default_batch_size) + if default_batch_size in batch_size_options else 2, + help="Batch size per training step." +) +st.session_state["batch_size"] = batch_size + +img_size = st.sidebar.number_input( + "Image Size (e.g., 640)", + min_value=64, + max_value=1280, + value=st.session_state.get("img_size", 640), + step=32, + help="Resolution of training images. Larger sizes can improve small object detection." +) +st.session_state["img_size"] = img_size + +learning_rate = st.sidebar.number_input( + "Learning Rate (lr0)", + min_value=1e-6, + max_value=0.1, + value=st.session_state.get("learning_rate", 0.0005), + format="%.5f", + step=1e-5, + help="Initial learning rate." +) +st.session_state["learning_rate"] = learning_rate + +optimizer = st.sidebar.selectbox( + "Optimizer", + options=available_optimizers, + index=available_optimizers.index( + st.session_state.get("optimizer", available_optimizers[0]) + ) if st.session_state.get("optimizer", available_optimizers[0]) in available_optimizers else 0, + help=optimizers_dict.get( + st.session_state.get("optimizer", available_optimizers[0]), + "No description available." + ) +) +st.session_state["optimizer"] = optimizer +st.sidebar.markdown(f"**Description**: {optimizers_dict.get(optimizer, 'No description available.')}", unsafe_allow_html=True) + +data_augmentation_options = ["None", "Basic", "Moderate", "Advanced"] +data_augmentation_level = st.sidebar.selectbox( + "Data Augmentation Level", + data_augmentation_options, + index=data_augmentation_options.index( + st.session_state.get("data_augmentation_level", "Moderate") + ) if st.session_state.get("data_augmentation_level", "Moderate") in data_augmentation_options else 2, + help="Choose a preset augmentation level. More aggressive levels can help with robustness but might increase training time." +) +st.session_state["data_augmentation_level"] = data_augmentation_level + +st.sidebar.subheader("Fine-tune HSV Augmentations") +st.session_state["hsv_h"] = st.sidebar.slider("HSV Hue Augmentation", min_value=0.0, max_value=0.1, value=st.session_state.get("hsv_h", 0.015), step=0.001, format="%.3f", help="Hue variation for augmentation.") +st.session_state["hsv_s"] = st.sidebar.slider("HSV Saturation Augmentation", min_value=0.0, max_value=1.0, value=st.session_state.get("hsv_s", 0.7), step=0.01, format="%.2f", help="Saturation variation for augmentation.") +st.session_state["hsv_v"] = st.sidebar.slider("HSV Value Augmentation", min_value=0.0, max_value=1.0, value=st.session_state.get("hsv_v", 0.4), step=0.01, format="%.2f", help="Value (brightness) variation for augmentation.") + + +custom_model_name = st.sidebar.text_input( + "Custom Model Name", + value=st.session_state.get("custom_model_name", "rolo_trained_model"), + help="Name of the run, used for the training folder (e.g., runs/train/my_custom_model_name)." +) +st.session_state["custom_model_name"] = custom_model_name + +use_half_precision = st.sidebar.checkbox( + "Enable Half-Precision (FP16) on GPU", + value=st.session_state.get("use_half_precision", True), + help="Requires a GPU (CUDA) for acceleration. If enabled on CPU, it will be ignored." +) +st.session_state["use_half_precision"] = use_half_precision + +early_stop_patience = st.sidebar.number_input( + "Early Stopping Patience (epochs)", + min_value=1, + max_value=100, + value=st.session_state.get("early_stop_patience", 10), + help="Stop training if no improvement in validation loss for this many epochs." +) +st.session_state["early_stop_patience"] = early_stop_patience + +st.sidebar.subheader("Learning Rate Schedule") +st.session_state["lrf"] = st.sidebar.number_input( + "Final LR Factor (lrf)", + min_value=0.000, + max_value=1.0, + value=st.session_state.get("lrf", 0.01), + step=0.001, + format="%.3f", + help="Learning rate final multiplier (lr = lr0 * lrf). Affects learning rate schedule shape." +) +st.session_state["warmup_epochs"] = st.sidebar.number_input( + "Warmup Epochs", + min_value=0, + max_value=20, + value=st.session_state.get("warmup_epochs", 3), + step=1, + help="Number of epochs for learning rate warmup. Gradually increases LR at start." +) + +st.sidebar.subheader("Transfer Learning Options") +st.session_state["freeze_layers"] = st.sidebar.number_input( + "Freeze Layers", + min_value=0, + max_value=30, + value=st.session_state.get("freeze_layers", 0), + step=1, + help="Number of initial layers to freeze during training. Reduces computation and helps prevent overfitting on small datasets. Set to 0 for no freezing." +) + +st.sidebar.subheader("Dataset Caching") +st.session_state["cache_dataset"] = st.sidebar.checkbox( + "Cache Dataset for Training", + value=st.session_state.get("cache_dataset", False), + help="If checked, Ultralytics will cache preprocessed images. Set to True for RAM caching. Can significantly speed up training on large datasets but requires more memory." +) + +st.sidebar.header("Model Architecture & Category") +model_config_url = "https://raw.githubusercontent.com/Wuhpondiscord/models/main/model_config.json" +model_config = cache_load_model_configs(model_config_url) +if not model_config: + st.warning("Could not load model configuration. Please check your internet connection.") + st.stop() + +architecture_options = list(model_config.keys()) +selected_architecture = st.sidebar.selectbox( + "Select Model Architecture", + options=architecture_options, + index=architecture_options.index(st.session_state.get("selected_architecture", architecture_options[0])) if st.session_state.get("selected_architecture", architecture_options[0]) in architecture_options else 0 +) +st.session_state["selected_architecture"] = selected_architecture + +all_cats_dict = model_config[selected_architecture]["categories"] +valid_categories = [cat for cat, models in all_cats_dict.items() if len(models) > 0] + +selected_category = st.sidebar.selectbox( + "Select Model Category (Task)", + valid_categories, + index=valid_categories.index(st.session_state.get("selected_category", valid_categories[0])) if valid_categories and st.session_state.get("selected_category", valid_categories[0]) in valid_categories else 0 +) +st.session_state["selected_category"] = selected_category + +pre_trained_models_info = model_config[selected_architecture]["categories"].get(selected_category, []) +available_pretrained_filenames = [m["filename"] for m in pre_trained_models_info] +pre_trained_options = available_pretrained_filenames + ["Custom"] + +pre_trained_model = st.sidebar.selectbox( + "Select Pre-Trained Model", + options=pre_trained_options, + index=pre_trained_options.index(st.session_state.get("pre_trained_model", available_pretrained_filenames[0])) if pre_trained_options and st.session_state.get("pre_trained_model", available_pretrained_filenames[0]) in pre_trained_options else 0, + help="Choose a built-in pretrained model or 'Custom'." +) +st.session_state["pre_trained_model"] = pre_trained_model + +if pre_trained_model == "Custom": + custom_model_path = st.sidebar.text_input( + "Enter Custom Model Path", + value=st.session_state.get("custom_model_path", ""), + help="Local path to a custom .pt file." + ) + st.session_state["custom_model_path"] = custom_model_path + +show_debug = st.sidebar.checkbox( + "Show Debug Messages", + value=st.session_state.get("show_debug", False), + help="Display detailed debugging info in the UI." +) +st.session_state["show_debug"] = show_debug + +st.sidebar.markdown("---") +st.sidebar.markdown("Developed by [wuhp](https://huggingface.co/wuhp).") + +# ---------------------------------------------------------------------------- +# MAIN WORKFLOW STEPS +# ---------------------------------------------------------------------------- + +def allow_dataset_download(): + """ + After finalizing the dataset, user can download it as a ZIP, + or just the data.yaml. + """ + merged_dataset_dir = st.session_state.get('merged_dataset_dir', None) + if merged_dataset_dir and os.path.exists(merged_dataset_dir): + st.write("## Download Merged Dataset") + zip_buf = zip_directory(merged_dataset_dir) + st.download_button( + label="Download Entire Merged Dataset (ZIP)", + data=zip_buf.getvalue(), + file_name="merged_dataset.zip", + mime="application/octet-stream" + ) + yaml_path = os.path.join(merged_dataset_dir, "data.yaml") + if os.path.exists(yaml_path): + with open(yaml_path, "r") as f: + yaml_data = f.read() + st.download_button( + label="Download data.yaml", + data=yaml_data, + file_name="data.yaml", + mime="text/yaml" + ) + +# ---------------------------------------------------------------------------- +# 1) Prepare Datasets +# 2) Class Selection +# 3) Merge & Adjust +# 4) Finalize Dataset +# 5) Train Model +# 6) Upload Model +# 7) Generate Examples +# ---------------------------------------------------------------------------- + +if st.session_state['workflow_step'] == "Prepare Datasets": + st.subheader("Step 1: Prepare Datasets") + st.markdown("Choose to load a recommended dataset (for **detection only**) or add your own Roboflow dataset URLs.") + st.info(f"Current selected task: **{st.session_state['selected_category'].capitalize()}**") + + tabs = st.tabs(["Recommended Datasets (Detection Only)", "Custom Roboflow Datasets"]) + + with tabs[0]: + if st.session_state['selected_category'] != "detection": + st.warning("Recommended datasets are currently only available for **Detection** task. Please select 'Detection' in the sidebar or use 'Custom Roboflow Datasets' for other tasks.") + + if recommended_datasets and st.session_state['selected_category'] == "detection": + st.markdown("### Select Recommended Datasets (for Detection)") + selected_recommended = st.multiselect( + "Choose datasets to load:", + options=[dataset["name"] for dataset in recommended_datasets], + format_func=lambda x: next(item for item in recommended_datasets if item["name"] == x)["description"] + ) + + if st.button("Load Selected Recommended Datasets"): + if not selected_recommended: + st.error("Please select at least one recommended dataset.") + else: + dataset_info_list = [] + progress_bar = st.progress(0) + for idx, dataset_name in enumerate(selected_recommended): + dataset = next(item for item in recommended_datasets if item["name"] == dataset_name) + try: + source = dataset["source"] + url = dataset["url"] + if source in ["GitHub", "Ultralytics"]: + github_zip_url = url + dloc, class_names, splits = download_and_prepare_github_dataset( + github_zip_url, + dataset_name + ) + elif source == "Roboflow": + workspace, project, version = parse_roboflow_url(url) + if not workspace or not project: + st.warning(f"Roboflow URL '{url}' is invalid.") + continue + + if not version: + version = get_latest_version(Roboflow(api_key=rf_api_key), workspace, project) + if not version: + st.warning(f"Could not retrieve latest version for '{project}'. Skipping.") + continue + + st.info(f"[DATASET PREP] Downloading Roboflow dataset '{project}' (v{version})...") + dloc, class_names, splits = download_and_prepare_roboflow_dataset( + rf_api_key, + workspace, + project, + version, + st.session_state["selected_architecture"], + st.session_state["selected_category"] + ) + else: + st.error(f"Unknown source '{source}' for dataset '{dataset_name}'.") + continue + + if dloc: + dataset_info_list.append((dloc, class_names, splits, dataset_name)) + st.success(f"Loaded recommended dataset '{dataset_name}'.") + except Exception as e: + st.error(f"Failed to load dataset '{dataset_name}': {e}") + + progress_bar.progress((idx + 1) / len(selected_recommended)) + + if dataset_info_list: + st.session_state['dataset_info_list'].extend(dataset_info_list) + st.session_state['dataset_prepared'] = True + st.success("Selected recommended detection datasets have been loaded successfully.") + st.session_state['workflow_step'] = "Class Selection" + elif not recommended_datasets: + st.warning("No recommended detection datasets available. Please check the JSON file.") + + with tabs[1]: + st.markdown("### Add Custom Roboflow Datasets") + st.markdown("Upload a `.txt` file containing Roboflow dataset URLs, one per line.") + + custom_rf_uploaded_file = st.file_uploader( + "Upload a .txt file with Roboflow dataset URLs", + type=["txt"], + help="Example: https://universe.roboflow.com/workspace/project/version..." + ) + + st.markdown("Split ratios for your dataset (will be applied *before* merging):") + col_split_ratio1, col_split_ratio2, col_split_ratio3 = st.columns(3) + with col_split_ratio1: + st.session_state['train_ratio'] = st.slider("Train Split (%)", min_value=50, max_value=90, value=int(st.session_state.get('train_ratio', 0.7)*100), step=5) / 100.0 + with col_split_ratio2: + st.session_state['val_ratio'] = st.slider("Validation Split (%)", min_value=5, max_value=20, value=int(st.session_state.get('val_ratio', 0.2)*100), step=5) / 100.0 + with col_split_ratio3: + st.session_state['test_ratio'] = st.slider("Test Split (%)", min_value=0, max_value=20, value=int(st.session_state.get('test_ratio', 0.1)*100), step=5) / 100.0 + + sum_ratios = st.session_state['train_ratio'] + st.session_state['val_ratio'] + st.session_state['test_ratio'] + if not (0.99 <= sum_ratios <= 1.01): + st.warning(f"Warning: Split ratios sum to {sum_ratios*100:.1f}%. They should sum to 100% for balanced splitting.") + + + if st.button("Load Custom Roboflow Datasets"): + if not rf_api_key: + st.error("Please enter your Roboflow API Key.") + elif not custom_rf_uploaded_file: + st.error("Please upload a `.txt` file containing Roboflow dataset URLs.") + else: + try: + content = custom_rf_uploaded_file.read().decode("utf-8") + urls = content.strip().split('\n') + urls = [url.strip() for url in urls if url.strip()] + if not urls: + st.error("The uploaded file is empty or contains invalid URLs.") + st.stop() + except Exception as e: + st.error(f"Failed to read the uploaded file: {e}") + st.stop() + + dataset_info_list = [] + progress_bar = st.progress(0) + for idx, url in enumerate(urls): + workspace, project, version = parse_roboflow_url(url) + if not workspace or not project: + st.warning(f"URL '{url}' is invalid.") + continue + + if not version: + version = get_latest_version(Roboflow(api_key=rf_api_key), workspace, project) + if not version: + st.warning(f"Could not retrieve latest version for '{project}'. Skipping.") + continue + + st.info(f"[DATASET PREP] Downloading dataset '{project}' (v{version}) from workspace '{workspace}'...") + dloc, class_names, splits = download_and_prepare_roboflow_dataset( + rf_api_key, + workspace, + project, + version, + st.session_state["selected_architecture"], + st.session_state["selected_category"] + ) + if dloc: + dataset_info_list.append((dloc, class_names, splits, f"{project}_v{version}")) + st.success(f"Dataset '{project}' (v{version}) prepared at '{dloc}'") + progress_bar.progress((idx + 1) / len(urls)) + + if not dataset_info_list: + st.error("No datasets were prepared successfully. Please check the URLs and try again.") + else: + st.session_state['dataset_info_list'].extend(dataset_info_list) + st.session_state['dataset_prepared'] = True + st.success("All custom Roboflow datasets have been prepared successfully.") + st.session_state['workflow_step'] = "Class Selection" + +elif st.session_state['workflow_step'] == "Class Selection": + if not st.session_state.get('dataset_prepared', False): + st.warning("Datasets not prepared yet. Go to 'Prepare Datasets' first.") + else: + st.subheader("Step 2: Class Selection, Renaming, and Removal") + st.info(f"Current selected task: **{st.session_state['selected_category'].capitalize()}**") + + class_name_mapping = st.session_state['class_name_mapping'] + class_image_counts_pre_merge = gather_class_counts( + st.session_state['dataset_info_list'], + class_name_mapping, + st.session_state['selected_category'] + ) + st.session_state['class_image_counts'] = class_image_counts_pre_merge + + all_class_names_raw = [] + for _, class_names, _, _ in st.session_state['dataset_info_list']: + all_class_names_raw.extend(class_names) + all_class_names_raw = sorted(set(all_class_names_raw)) + + merged_names = set() + for class_name in all_class_names_raw: + new_name = class_name_mapping.get(class_name, class_name) + merged_names.add(new_name) + merged_names = sorted(list(merged_names)) + + st.write("### Current Merged Class Distribution:") + if st.session_state['selected_category'] == "classification": + st.write("Counts represent **images** per class.") + else: + st.write("Counts represent **images that contain at least one annotation** for a class.") + for cls in merged_names: + st.write(f"- **{cls}**: {class_image_counts_pre_merge.get(cls, 0)} available") + + with st.form("class_selection_form"): + classes_to_include = {} + new_mapping = {} + + st.markdown(""" + **Instructions**: + - Use **"Rename Class"** to merge classes by assigning the **same** new name to multiple classes. + - Use **"Max Images"** to limit how many images are pulled from that class. + - Check **"Remove"** if you want to exclude the class entirely. + """) + + enable_fuzzy_suggestions = st.checkbox("Enable Fuzzy Class Renaming Suggestions", value=False, help="Suggests similar existing class names for renaming.") + + for c in all_class_names_raw: + current_final_name = class_name_mapping.get(c, c) + current_avail = class_image_counts_pre_merge.get(current_final_name, 0) + + remove_checkbox_col, rename_col, images_col = st.columns([1, 2.5, 2.5]) + + with remove_checkbox_col: + remove_class = st.checkbox( + "Remove", + value=False, + key=f"remove_class_{c}", + help=f"Check to exclude '{c}' entirely." + ) + + with rename_col: + rename_val = st.text_input( + label=f"Rename '{c}'", + value=current_final_name, + key=f"rename_input_{c}" + ) + if enable_fuzzy_suggestions: + suggestions = get_close_matches(rename_val, all_class_names_raw, n=3, cutoff=0.7) + suggestions = [s for s in suggestions if s != c and s != rename_val] + if suggestions: + st.markdown(f"Suggestions: {', '.join(suggestions)}", unsafe_allow_html=True) + + with images_col: + max_value_for_input = current_avail if current_avail > 0 else 0 + default_value = current_avail + disabled_flag = remove_class + + new_count = st.number_input( + f"Max Images (Avail: {current_avail})", + min_value=0, + max_value=max_value_for_input, + value=default_value, + step=1, + disabled=disabled_flag, + key=f"max_images_{c}" + ) + if remove_class: + new_count = 0 + + new_mapping[c] = rename_val.strip() + classes_to_include[c] = new_count + + st.info("If multiple classes share the same renamed name, they are merged into one. Counts refer to images.") + confirm = st.form_submit_button("Confirm Selections & Renaming") + + if confirm: + st.session_state['classes_to_include'] = classes_to_include + st.session_state['class_name_mapping'] = new_mapping + st.session_state['merging_confirmed'] = True + st.success("Class selections, renaming, and removal confirmed.") + st.session_state['workflow_step'] = "Merge & Adjust" + +elif st.session_state['workflow_step'] == "Merge & Adjust": + if not st.session_state.get('merging_confirmed', False): + st.warning("Please complete 'Class Selection' first.") + else: + st.subheader("Step 3: Merge & Adjust Datasets") + st.info(f"Current selected task: **{st.session_state['selected_category'].capitalize()}**") + + if st.button("Merge Datasets"): + merged_dataset_dir = None + if st.session_state['selected_category'] == "classification": + merged_dataset_dir = merge_classification_datasets( + st.session_state['dataset_info_list'], + st.session_state['classes_to_include'], + st.session_state['class_name_mapping'], + show_debug + ) + else: + merged_dataset_dir = merge_datasets( + st.session_state['dataset_info_list'], + st.session_state['classes_to_include'], + st.session_state['class_name_mapping'], + show_debug + ) + if merged_dataset_dir: + st.success(f"Merged dataset created at '{merged_dataset_dir}'. You can now adjust it.") + + if st.session_state.get('merged_dataset_dir', None): + adjust_selected_images() + +elif st.session_state['workflow_step'] == "Finalize Dataset": + if not st.session_state.get('adjustment_confirmed', False): + st.warning("Please complete 'Merge & Adjust' first.") + else: + st.subheader("Step 4: Finalize Merged Dataset") + st.info(f"Current selected task: **{st.session_state['selected_category'].capitalize()}**") + + st.markdown("### Automated Dataset Splitting") + st.info("You can re-split the merged dataset into train/validation/test sets after finalization. This step assumes a YOLO-compatible dataset structure with 'images' and 'labels' folders within each split.") + + col_split_auto_1, col_split_auto_2, col_split_auto_3 = st.columns(3) + with col_split_auto_1: + st.session_state['train_ratio'] = st.number_input("Train Ratio", min_value=0.0, max_value=1.0, value=st.session_state.get("train_ratio", 0.7), step=0.05, format="%.2f") + with col_split_auto_2: + st.session_state['val_ratio'] = st.number_input("Validation Ratio", min_value=0.0, max_value=1.0, value=st.session_state.get("val_ratio", 0.2), step=0.05, format="%.2f") + with col_split_auto_3: + st.session_state['test_ratio'] = st.number_input("Test Ratio", min_value=0.0, max_value=1.0, value=st.session_state.get("test_ratio", 0.1), step=0.05, format="%.2f") + + sum_ratios = st.session_state['train_ratio'] + st.session_state['val_ratio'] + st.session_state['test_ratio'] + if not (0.99 <= sum_ratios <= 1.01): + st.warning(f"Warning: Split ratios sum to {sum_ratios*100:.1f}%. It is recommended they sum to 100%.") + + if st.button("Finalize Dataset"): + finalize_merged_dataset() + if st.session_state.get('dataset_finalized', False): + st.session_state['workflow_step'] = "Train Model" + st.success("Dataset finalized. Ready for training!") + + if st.session_state.get('dataset_location') and st.button("Re-split Finalized Dataset"): + merged_dataset_path = st.session_state['dataset_location'] + if merged_dataset_path and os.path.exists(merged_dataset_path): + if st.session_state['selected_category'] == "classification": + st.warning("Automated re-splitting is not currently supported for Classification datasets. They use a folder-per-class structure.") + else: + try: + with st.spinner("Re-splitting dataset... This might take a moment."): + for split_dir in ['train', 'valid', 'test']: + if os.path.exists(os.path.join(merged_dataset_path, split_dir)): + shutil.rmtree(os.path.join(merged_dataset_path, split_dir), onerror=handle_remove_readonly) + + autosplit( + path=merged_dataset_path, + weights=(st.session_state['train_ratio'], st.session_state['val_ratio'], st.session_state['test_ratio']), + random=True, + verbose=True + ) + st.success(f"Dataset successfully re-split into {st.session_state['train_ratio']*100}% train, {st.session_state['val_ratio']*100}% val, {st.session_state['test_ratio']*100}% test.") + + data_yaml_path = os.path.join(merged_dataset_path, 'data.yaml') + with open(data_yaml_path, 'r') as f: + data_yaml = yaml.safe_load(f) + data_yaml['train'] = 'train/images' + data_yaml['val'] = 'valid/images' + data_yaml['test'] = 'test/images' + with open(data_yaml_path, 'w') as f: + yaml.safe_dump(data_yaml, f) + + except Exception as e: + st.error(f"Failed to re-split dataset: {e}. Ensure the merged dataset has a compatible structure (e.g., images and labels directly under train/val/test subfolders).") + logging.error(f"Failed to re-split dataset: {e}") + + allow_dataset_download() + +elif st.session_state['workflow_step'] == "Train Model": + st.subheader("Step 5: Train Model") + st.info(f"Current selected task: **{st.session_state['selected_category'].capitalize()}**") + + auto_download_model = st.checkbox("Automatically present model download after training completes?", value=True) + + if st.button("Start Training"): + if not st.session_state.get('dataset_finalized', False): + st.error("Datasets are not fully finalized. Please finalize the dataset first.") + else: + # Clear previous metrics and progress text before starting new training + # This is done *here* before starting training, not in the callback. + st.session_state['metrics_data'] = { + 'epoch': [], 'train_loss': [], 'val_loss': [], + 'mAP50': [], 'mAP50_95': [], + 'top1_acc': [], 'top5_acc': [], + 'mAPpose50': [], 'mAPpose50_95': [], + 'precision': [], 'recall': [], 'F1': [] + } + loss_chart_placeholder.empty() + map_chart_placeholder.empty() + progress_text_placeholder.empty() + + dataset_location = st.session_state['dataset_location'] + with st.spinner("Training in progress..."): + model = train_model( + dataset_location = dataset_location, + epochs = st.session_state["epochs"], + batch_size = st.session_state["batch_size"], + img_size = st.session_state["img_size"], + learning_rate = st.session_state["learning_rate"], + optimizer = st.session_state["optimizer"], + data_augmentation_level = st.session_state["data_augmentation_level"], + pre_trained_model = st.session_state["pre_trained_model"], + custom_model_path = st.session_state.get("custom_model_path", ""), + custom_model_name = st.session_state["custom_model_name"], + show_debug = st.session_state["show_debug"], + model_config = model_config, + selected_architecture = st.session_state["selected_architecture"], + selected_category = st.session_state["selected_category"] + ) + if model: + st.success("Model training completed!") + # Clear final progress text + progress_text_placeholder.empty() + trained_weights_path = os.path.join( + "runs", "train", st.session_state["custom_model_name"], "weights", "best.pt" + ) + if auto_download_model and os.path.exists(trained_weights_path): + st.write("### Download Your Trained Model:") + with open(trained_weights_path, "rb") as f: + st.download_button( + label="Download best.pt", + data=f, + file_name=f"{st.session_state['custom_model_name']}.pt", + mime="application/octet-stream", + ) + elif not os.path.exists(trained_weights_path): + st.warning("Trained model not found. Check your training run output directory.") + st.session_state['workflow_step'] = "Upload Model" + +elif st.session_state['workflow_step'] == "Upload Model": + st.subheader("Step 6: Upload Model") + st.markdown("Upload your trained model and optionally a custom README, **OR** generate one from a template.") + st.info(f"Current selected task: **{st.session_state['selected_category'].capitalize()}**") + + col1, col2 = st.columns(2) + with col1: + model_file = st.file_uploader("Upload Trained Model (.pt)", type=["pt"], help="Select your trained model file.") + with col2: + readme_file = st.file_uploader("Upload README Content (.txt)", type=["txt"], help="Select a `.txt` file with your README content.") + + st.markdown("---") + st.markdown("**Or create a README from a template on GitHub**:") + st.info("Note: Default template is optimized for detection tasks and may not fully cover classification/keypoint specifics.") + generate_from_template = st.button("Generate README From Template") + + if "final_readme_content_editor" not in st.session_state: + st.session_state["final_readme_content_editor"] = "" + + if generate_from_template: + try: + template_url = "https://raw.githubusercontent.com/Wuhpondiscord/models/main/template.txt" + response = requests.get(template_url) + response.raise_for_status() + template_text = response.text + + st.session_state["final_readme_content_editor"] = customize_template( + template_text, + st.session_state.get("model_url", "N/A"), + st.session_state["selected_category"], + st.session_state["metrics_data"], + st.session_state.get("class_id_mapping", {}), + st.session_state.get('datasets_used', []), + st.session_state.get('class_image_counts', {}) + ) + st.success("README template loaded and placeholders filled. You can edit below.") + except requests.exceptions.RequestException as e: + st.error(f"Failed to fetch README template from GitHub: {e}") + except Exception as e: + st.error(f"Error generating README from template: {e}") + + if st.session_state["final_readme_content_editor"]: + st.text_area( + "Generated README Content", + value=st.session_state["final_readme_content_editor"], + height=300, + key="final_readme_content_editor" + ) + + st.markdown("### GitHub Credentials") + gh_username = st.text_input("GitHub Username", value=st.session_state.get('gh_username', ''), help="Your GitHub username or organization.") + gh_repo_name = st.text_input("GitHub Repository Name", value=st.session_state.get('gh_repo_name', ''), help="Just the name of the repository (no slashes).") + gh_personal_token = st.text_input("GitHub Personal Access Token", type="password", value=st.session_state.get('gh_personal_token', ''), help="Personal access token for authentication.") + gh_make_private = st.checkbox("Make new GitHub repo private?", value=st.session_state.get('gh_make_private', False)) + + st.session_state['gh_username'] = gh_username + st.session_state['gh_repo_name'] = gh_repo_name + st.session_state['gh_personal_token'] = gh_personal_token + st.session_state['gh_make_private'] = gh_make_private + + st.markdown("### Hugging Face Credentials") + hf_api_key = st.text_input("Hugging Face API Key", type="password", value=st.session_state.get('hf_api_key', ''), help="Your Hugging Face API key.") + hf_model_repo = st.text_input("Hugging Face Model Repository Name", value=st.session_state.get('hf_model_repo', ''), help="Format: username/repo-name (e.g., 'myusername/my-model-repo')") + + st.session_state['hf_api_key'] = hf_api_key + st.session_state['hf_model_repo'] = hf_model_repo + + + if st.button("Upload to GitHub and Hugging Face"): + final_readme_str = "" + + if readme_file is not None: + final_readme_str = readme_file.read().decode("utf-8") + elif st.session_state["final_readme_content_editor"]: + final_readme_str = st.session_state["final_readme_content_editor"] + else: + st.warning("No README content provided or generated. Uploading without a README.") + + if not model_file: + st.error("Please upload the trained model (.pt) before publishing.") + else: + temp_dir = "temp_upload_data" + os.makedirs(temp_dir, exist_ok=True) + + # --- FIX START --- + model_filename_to_upload = st.session_state.get('custom_model_name', 'best') + '.pt' + # Define all temporary paths here, before any conditional logic that uses them + temp_model_path = os.path.join(temp_dir, model_filename_to_upload) + temp_readme_path = os.path.join(temp_dir, "README.md") + temp_license_path = os.path.join(temp_dir, "Licensing.rolo") + # --- FIX END --- + + trained_weights_path = os.path.join( + "runs", "train", st.session_state["custom_model_name"], "weights", "best.pt" + ) + if os.path.exists(trained_weights_path): + shutil.copy(trained_weights_path, temp_model_path) + st.info(f"Using locally trained model: {trained_weights_path}") + else: + with open(temp_model_path, "wb") as f: + f.write(model_file.read()) + st.info("Using uploaded model file.") + + + if final_readme_str.strip(): + with open(temp_readme_path, "w", encoding="utf-8") as f: + f.write(final_readme_str) + + license_url = "https://raw.githubusercontent.com/wuhplaptop/sdadasgds/main/Licensing.rolo" + try: + lic_response = requests.get(license_url) + lic_response.raise_for_status() + with open(temp_license_path, "wb") as f: + f.write(lic_response.content) + except Exception as e: + st.error(f"Failed to download license file: {e}") + st.stop() + + if gh_username and gh_repo_name and gh_personal_token: + with st.spinner("Uploading to GitHub..."): + status_code, resp = create_github_repo(gh_username, gh_personal_token, gh_repo_name, private=gh_make_private) + if status_code in [201, 422]: + if status_code == 422: + st.info(f"GitHub repo '{gh_repo_name}' already exists. Proceeding to upload/update files.") + else: + st.success(f"GitHub repo '{gh_repo_name}' created successfully.") + else: + st.error(f"Failed to create/access GitHub repo: {resp.get('message', 'Unknown error')}") + st.stop() + + try: + gh_api_url = f"https://api.github.com/repos/{gh_username}/{gh_repo_name}/contents/" + headers = { + "Authorization": f"token {gh_personal_token}", + "Accept": "application/vnd.github.v3+json", + } + + def upload_to_github(file_path, repo_path): + check_url = gh_api_url + repo_path + check_response = requests.get(check_url, headers=headers) + sha = None + if check_response.status_code == 200: + sha = check_response.json().get('sha') + + with open(file_path, "rb") as f: + content = base64.b64encode(f.read()).decode("utf-8") + + data_payload = {"message": f"Upload {repo_path} from Rolo", "content": content} + if sha: + data_payload["sha"] = sha + + response = requests.put( + check_url, + headers=headers, + json=data_payload, + ) + return response.status_code, response.json() + + status, response_json = upload_to_github(temp_model_path, model_filename_to_upload) + if status in [200, 201]: + st.success(f"Uploaded {model_filename_to_upload} to GitHub.") + st.session_state['model_url'] = f"https://github.com/{gh_username}/{gh_repo_name}/raw/main/{model_filename_to_upload}" + else: + st.error(f"Failed to upload {model_filename_to_upload} to GitHub: {response_json.get('message', 'Unknown error')}") + + if final_readme_str.strip(): + status, response_json = upload_to_github(temp_readme_path, "README.md") + if status in [200, 201]: + st.success("Uploaded README.md to GitHub.") + else: + st.error(f"Failed to upload README.md to GitHub: {response_json.get('message', 'Unknown error')}") + + status, response_json = upload_to_github(temp_license_path, "Licensing.rolo") + if status in [200, 201]: + st.success("Uploaded Licensing.rolo to GitHub.") + else: + st.error(f"Failed to upload Licensing.rolo to GitHub: {response_json.get('message', 'Unknown error')}") + + except Exception as e: + st.error(f"Failed to upload to GitHub: {e}") + + else: + st.error("Please provide GitHub credentials to upload the model.") + + if hf_api_key and hf_model_repo: + with st.spinner("Uploading to Hugging Face..."): + try: + api = HfApi() + HfFolder.save_token(hf_api_key) + + repo_url = api.create_repo(repo_id=hf_model_repo, exist_ok=True, token=hf_api_key) + + api.upload_file( + path_or_fileobj=temp_model_path, + path_in_repo=model_filename_to_upload, + repo_id=hf_model_repo, + token=hf_api_key + ) + st.success(f"Uploaded model to Hugging Face: {repo_url}") + st.session_state['model_url'] = f"https://huggingface.co/{hf_model_repo}/blob/main/{model_filename_to_upload}" + + if final_readme_str.strip(): + api.upload_file( + path_or_fileobj=temp_readme_path, + path_in_repo="README.md", + repo_id=hf_model_repo, + token=hf_api_key + ) + st.success("Uploaded README.md to Hugging Face.") + + api.upload_file( + path_or_fileobj=temp_license_path, + path_in_repo="Licensing.rolo", + repo_id=hf_model_repo, + token=hf_api_key + ) + st.success("Uploaded Licensing.rolo to Hugging Face.") + + except Exception as e: + st.error(f"Failed to upload to Hugging Face: {e}") + else: + st.warning("Hugging Face credentials not provided. Skipping Hugging Face upload.") + + shutil.rmtree(temp_dir, ignore_errors=True) + st.success("Upload process completed.") + + if st.button("Proceed to Generate Examples"): + st.session_state['workflow_step'] = "Generate Examples" + st.rerun() + +elif st.session_state['workflow_step'] == "Generate Examples": + st.subheader("Step 7: Generate Code Examples") + st.markdown(""" + These templates are primarily designed for **detection tasks**. + If you need examples for other tasks (classification, keypoint), + you'll have to provide or create different templates on the GitHub repository. + """) + st.info(f"Current selected task: **{st.session_state['selected_category'].capitalize()}**") + + if st.button("Detect Model Location"): + detect_model_location() + if st.session_state['model_location']: + st.success(f"Model detected on {st.session_state['model_location'].capitalize()}: {st.session_state['model_url']}") + else: + st.warning("Could not automatically detect model location. Please input the model URL manually.") + + if not st.session_state['model_location'] or not st.session_state.get('model_url'): + model_url_input = st.text_input("Enter your Model URL (Huggingface or GitHub)", value=st.session_state.get('model_url', "")) + if st.button("Set Model URL for Examples"): + if model_url_input: + st.session_state['model_url'] = model_url_input + parsed_url = urlparse(model_url_input) + if 'huggingface.co' in parsed_url.netloc: + st.session_state['model_location'] = 'huggingface' + elif 'github.com' in parsed_url.netloc: + st.session_state['model_location'] = 'github' + else: + st.error("Unsupported model URL. Please provide a Huggingface or GitHub URL.") + else: + st.error("Please enter a valid model URL.") + + if st.session_state.get('model_location') and st.session_state.get('model_url'): + example_type = st.selectbox("Select Example Type", ["Python", "Colab"]) + if st.button("Generate Examples"): + owner = "Wuhpondiscord" + repo = "models" + path = f"examples/{example_type.lower()}" + token = st.session_state.get('gh_personal_token') + + files = list_github_files(owner, repo, path, token=token) + if not files: + st.info(f"No example templates found for {example_type} for this task type. Please add them to '{owner}/{repo}/{path}'.") + else: + st.session_state['templates_found'] = True + st.success(f"Found {len(files)} template(s) for {example_type} examples in {path}.") + + generated_examples = {} + for fname in files: + template_content = fetch_github_file(owner, repo, f"{path}/{fname}", token=token) + if template_content: + customized_content = customize_template( + template_content, + st.session_state['model_url'], + st.session_state["selected_category"], + st.session_state["metrics_data"], + st.session_state.get("class_id_mapping", {}), + st.session_state.get('datasets_used', []), + st.session_state.get('class_image_counts', {}) + ) + generated_examples[fname] = customized_content + + if generated_examples: + st.markdown("### Generated Examples") + for fname, content in generated_examples.items(): + st.markdown(f"**{fname}**") + st.text_area(f"Example: {fname}", value=content, height=300, key=f"example_text_{fname}") + download_link = get_download_link(content, fname, "text/plain") + st.markdown(download_link, unsafe_allow_html=True) + + zip_bytes = generate_zip(generated_examples) + st.download_button( + label="Download All Examples as ZIP", + data=zip_bytes, + file_name=f"{example_type.lower()}_examples.zip", + mime="application/zip" + ) \ No newline at end of file