import os
import shutil
import stat
import re
import random
import logging
import json
import base64
from urllib.parse import urlparse
from threading import Thread
from queue import Queue

import yaml
import gradio as gr
import requests
import torch
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from ultralytics import YOLO
from roboflow import Roboflow
from huggingface_hub import HfApi, HfFolder

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

RTDETR_MODELS = {
    "detection": [
        {
            "filename": "rtdetr-l.pt",
            "url": "https://github.com/ultralytics/assets/releases/download/v8.0.0/rtdetr-l.pt",
            "description": "RT-DETR Large model (Default)"
        },
        {
            "filename": "rtdetr-x.pt",
            "url": "https://github.com/ultralytics/assets/releases/download/v8.0.0/rtdetr-x.pt",
            "description": "RT-DETR Extra-Large model"
        }
    ]
}
DEFAULT_MODEL = "rtdetr-l.pt"


def handle_remove_readonly(func, path, exc_info):
    """Error handler for shutil.rmtree: clear the read-only flag and retry.

    Needed mainly on Windows, where cached dataset files can be read-only and
    would otherwise make the deletion fail.
    """
    try:
        os.chmod(path, stat.S_IWRITE)
    except Exception:
        pass
    func(path)


_ROBO_URL_RX = re.compile(
    r"""
    ^(?:
        (?:https?://)?(?:universe|app|www)?\.?roboflow\.com/   # Any roboflow host
        (?P<ws>[A-Za-z0-9\-_]+)/                                # workspace
        (?P<proj>[A-Za-z0-9\-_]+)/?                             # project
        (?:
            (?:dataset/[^/]+/)?                                 # optional 'dataset/<fmt>/'
            (?:v?(?P<ver>\d+))?                                 # optional version 'vN' or 'N'
        )?
        |
        (?P<ws2>[A-Za-z0-9\-_]+)/(?P<proj2>[A-Za-z0-9\-_]+)(?:/(?:v)?(?P<ver2>\d+))?  # raw ws/proj[/vN]
    )$
    """,
    re.VERBOSE | re.IGNORECASE
)


def parse_roboflow_url(s: str):
    """
    Accepts:
      - https://universe.roboflow.com/<workspace>/<project>[/vN | /N]
      - https://app.roboflow.com/<workspace>/<project>[/vN | /N]
      - https://roboflow.com/<workspace>/<project>[/vN | /N]
      - raw: <workspace>/<project>[/vN | /N]
    Returns: (workspace, project, version_or_None)
    """
    s = s.strip()

    m = _ROBO_URL_RX.match(s)
    if m:
        ws = m.group('ws') or m.group('ws2')
        proj = m.group('proj') or m.group('proj2')
        ver = m.group('ver') or m.group('ver2')
        return ws, proj, (int(ver) if ver else None)

    parsed = urlparse(s)
    parts = [p for p in parsed.path.strip('/').split('/') if p]
    if len(parts) >= 2:
        version = None
        if len(parts) >= 3:
            vpart = parts[2]
            if vpart.lower().startswith('v') and vpart[1:].isdigit():
                version = int(vpart[1:])
            elif vpart.isdigit():
                version = int(vpart)
        return parts[0], parts[1], version

    if '/' in s and 'roboflow' not in s:
        p = s.split('/')
        if len(p) >= 2:
            version = None
            if len(p) >= 3:
                v = p[2]
                if v.lower().startswith('v') and v[1:].isdigit():
                    version = int(v[1:])
                elif v.isdigit():
                    version = int(v)
            return p[0], p[1], version

    return None, None, None
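
# Illustrative behaviour of parse_roboflow_url (hypothetical workspace/project names):
#   parse_roboflow_url("https://universe.roboflow.com/acme/widgets/3")  -> ("acme", "widgets", 3)
#   parse_roboflow_url("acme/widgets/v2")                               -> ("acme", "widgets", 2)
#   parse_roboflow_url("acme/widgets")                                  -> ("acme", "widgets", None)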


def get_latest_version(api_key, workspace, project):
    """Gets the latest version number of a Roboflow project."""
    try:
        rf = Roboflow(api_key=api_key)
        proj = rf.workspace(workspace).project(project)
        versions = sorted([int(v.version) for v in proj.versions()], reverse=True)
        return versions[0] if versions else None
    except Exception as e:
        logging.error(f"Could not get latest version for {workspace}/{project}: {e}")
        return None


def _extract_class_names(data_yaml):
    """
    Return a list[str] of class names in index order.
    Handles:
      - list (possibly containing non-str types)
      - dict with numeric keys (e.g., {0: 'cat', 1: 'dog'})
      - fallback to ['class_0', ..., f'class_{nc-1}'] if names missing
    """
    names = data_yaml.get('names', None)

    if isinstance(names, dict):
        def _k(x):
            # Numeric keys sort numerically; anything else sorts after them as a
            # string (the tuples keep int and str keys comparable).
            try:
                return (0, int(x))
            except Exception:
                return (1, str(x))
        ordered_keys = sorted(names.keys(), key=_k)
        names_list = [names[k] for k in ordered_keys]
    elif isinstance(names, list):
        names_list = names
    else:
        nc = data_yaml.get('nc', 0)
        try:
            nc = int(nc)
        except Exception:
            nc = 0
        names_list = [f"class_{i}" for i in range(nc)]

    return [str(x) for x in names_list]
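
# Illustrative 'names' shapes normalized by _extract_class_names:
#   names: ['cat', 'dog']          -> ['cat', 'dog']
#   names: {0: 'cat', 1: 'dog'}    -> ['cat', 'dog']
#   names missing, nc: 2           -> ['class_0', 'class_1']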


def download_dataset(api_key, workspace, project, version):
    """Downloads a single dataset from Roboflow (yolov8 format works fine for RT-DETR)."""
    try:
        rf = Roboflow(api_key=api_key)
        proj = rf.workspace(workspace).project(project)
        ver = proj.version(int(version))
        dataset = ver.download("yolov8")

        data_yaml_path = os.path.join(dataset.location, 'data.yaml')
        with open(data_yaml_path, 'r') as f:
            data_yaml = yaml.safe_load(f)

        class_names = _extract_class_names(data_yaml)
        try:
            nc = int(data_yaml.get('nc', len(class_names)))
        except Exception:
            nc = len(class_names)
        if len(class_names) != nc:
            logging.warning(f"[{project}-v{version}] names length ({len(class_names)}) != nc ({nc}); using normalized names.")

        splits = [s for s in ['train', 'valid', 'test']
                  if os.path.exists(os.path.join(dataset.location, s))]

        return dataset.location, class_names, splits, f"{project}-v{version}"
    except Exception as e:
        logging.error(f"Failed to download {workspace}/{project}/v{version}: {e}")
        return None, [], [], None
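
# A Roboflow "yolov8" export typically unpacks to:
#   <location>/
#     data.yaml
#     train/images, train/labels
#     valid/images, valid/labels
#     test/images,  test/labels   (not every project includes every split)
# The merge logic below relies on this images/ + labels/ layout.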


def label_path_for(img_path: str) -> str:
    """Convert .../split/images/file.jpg -> .../split/labels/file.txt in a safe way."""
    split_dir = os.path.dirname(os.path.dirname(img_path))
    base = os.path.splitext(os.path.basename(img_path))[0] + '.txt'
    return os.path.join(split_dir, 'labels', base)
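
# Example (hypothetical path):
#   label_path_for("rolo_merged_dataset/train/images/img_0001.jpg")
#   -> "rolo_merged_dataset/train/labels/img_0001.txt"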


def gather_class_counts(dataset_info, class_mapping):
    """
    Count, per final class, how many images contain at least one instance of that class
    (counted once per image). class_mapping maps original_name -> final_name.
    """
    if not dataset_info:
        return {}

    final_names = set(class_mapping.values())
    counts = {name: 0 for name in final_names}

    for loc, names, splits, _ in dataset_info:
        id_to_name = {}
        for idx, n in enumerate(names):
            id_to_name[idx] = class_mapping.get(n, None)

        for split in splits:
            labels_dir = os.path.join(loc, split, 'labels')
            if not os.path.exists(labels_dir):
                continue
            for label_file in os.listdir(labels_dir):
                if not label_file.endswith('.txt'):
                    continue
                found = set()
                with open(os.path.join(labels_dir, label_file), 'r') as f:
                    for line in f:
                        parts = line.strip().split()
                        if not parts:
                            continue
                        try:
                            cls_id = int(parts[0])
                            mapped = id_to_name.get(cls_id, None)
                            if mapped in final_names:
                                found.add(mapped)
                        except Exception:
                            continue
                for m in found:
                    counts[m] += 1

    return counts
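
# Illustrative class_mapping (hypothetical names): {'Car': 'vehicle', 'Truck': 'vehicle',
# 'Person': 'person'} merges 'Car' and 'Truck' into one final class 'vehicle'.
# Counts are per image: a label file with three 'Car' boxes adds 1 to 'vehicle'.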


def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=gr.Progress()):
    """Core function to merge datasets based on user rules."""
    merged_dir = 'rolo_merged_dataset'
    if os.path.exists(merged_dir):
        shutil.rmtree(merged_dir, onerror=handle_remove_readonly)

    progress(0, desc="Creating directories...")
    for split in ['train', 'valid', 'test']:
        os.makedirs(os.path.join(merged_dir, split, 'images'), exist_ok=True)
        os.makedirs(os.path.join(merged_dir, split, 'labels'), exist_ok=True)

    active_classes = [cls for cls, limit in class_limits.items() if limit > 0]
    active_classes = sorted(set(active_classes))
    final_class_map = {name: i for i, name in enumerate(active_classes)}

    all_images = []
    for loc, _, splits, _ in dataset_info:
        for split in splits:
            img_dir = os.path.join(loc, split, 'images')
            if not os.path.exists(img_dir):
                continue
            for img_file in os.listdir(img_dir):
                if img_file.lower().endswith(('.jpg', '.jpeg', '.png')):
                    all_images.append((os.path.join(img_dir, img_file), split, loc))
    random.shuffle(all_images)

    progress(0.2, desc="Selecting images based on limits...")
    selected_images = []
    current_counts = {cls: 0 for cls in active_classes}

    loc_to_names = {info[0]: info[1] for info in dataset_info}

    for img_path, split, source_loc in progress.tqdm(all_images, desc="Analyzing images"):
        lbl_path = label_path_for(img_path)
        if not os.path.exists(lbl_path):
            continue

        source_names = loc_to_names.get(source_loc, [])
        image_classes = set()
        with open(lbl_path, 'r') as f:
            for line in f:
                parts = line.strip().split()
                if not parts:
                    continue
                try:
                    cls_id = int(parts[0])
                    orig = source_names[cls_id]
                    mapped = class_mapping.get(orig, orig)
                    if mapped in active_classes:
                        image_classes.add(mapped)
                except Exception:
                    continue

        if not image_classes:
            continue

        if any(current_counts[c] >= class_limits[c] for c in image_classes):
            continue

        selected_images.append((img_path, split))
        for c in image_classes:
            current_counts[c] += 1

    progress(0.6, desc=f"Copying {len(selected_images)} files...")
    for img_path, split in progress.tqdm(selected_images, desc="Finalizing files"):
        lbl_path = label_path_for(img_path)
        out_img = os.path.join(merged_dir, split, 'images', os.path.basename(img_path))
        out_lbl = os.path.join(merged_dir, split, 'labels', os.path.basename(lbl_path))
        shutil.copy(img_path, out_img)

        source_loc = None
        for info in dataset_info:
            if img_path.startswith(info[0]):
                source_loc = info[0]
                break
        source_names = loc_to_names.get(source_loc, [])

        with open(lbl_path, 'r') as f_in, open(out_lbl, 'w') as f_out:
            for line in f_in:
                parts = line.strip().split()
                if not parts:
                    continue
                try:
                    old_id = int(parts[0])
                    original_name = source_names[old_id]
                    mapped_name = class_mapping.get(original_name, original_name)
                    if mapped_name in final_class_map:
                        new_id = final_class_map[mapped_name]
                        f_out.write(f"{new_id} {' '.join(parts[1:])}\n")
                except Exception:
                    continue

    progress(0.95, desc="Creating data.yaml...")
    with open(os.path.join(merged_dir, 'data.yaml'), 'w') as f:
        yaml.dump({
            'path': os.path.abspath(merged_dir),
            'train': 'train/images',
            'val': 'valid/images',
            'test': 'test/images',
            'nc': len(active_classes),
            'names': active_classes
        }, f)

    return f"Dataset finalized with {len(selected_images)} images.", os.path.abspath(merged_dir)
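
# The data.yaml written above ends up shaped roughly like this (illustrative values
# for two final classes; key order may differ):
#   path: /abs/path/to/rolo_merged_dataset
#   train: train/images
#   val: valid/images
#   test: test/images
#   nc: 2
#   names: ['person', 'vehicle']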


def load_datasets_handler(api_key, url_file, progress=gr.Progress()):
    """Handles the 'Load Datasets' button click."""
    api_key = api_key or os.getenv("ROBOFLOW_API_KEY", "")
    if not api_key:
        raise gr.Error("Roboflow API Key is required (or set ROBOFLOW_API_KEY).")
    if not url_file:
        raise gr.Error("Please upload a .txt file with Roboflow URLs or lines like 'workspace/project[/vN]'.")

    with open(url_file.name, 'r', encoding='utf-8', errors='ignore') as f:
        urls = [line.strip() for line in f if line.strip()]

    dataset_info = []
    failures = []

    for i, raw in enumerate(urls):
        progress((i + 1) / max(1, len(urls)), desc=f"Parsing {i+1}/{len(urls)}")
        ws, proj, ver = parse_roboflow_url(raw)
        if not (ws and proj):
            failures.append((raw, "ParseError: could not resolve workspace/project"))
            continue

        if ver is None:
            ver = get_latest_version(api_key, ws, proj)
            if ver is None:
                failures.append((raw, f"Could not resolve latest version for {ws}/{proj}"))
                continue

        loc, names, splits, name_str = download_dataset(api_key, ws, proj, int(ver))
        if loc:
            dataset_info.append((loc, names, splits, name_str))
        else:
            failures.append((raw, f"DownloadError: {ws}/{proj}/v{ver}"))

    if not dataset_info:
        msg = "No datasets were loaded successfully.\n" + "\n".join([f"- {u}: {why}" for u, why in failures[:10]])
        raise gr.Error(msg)

    all_names = sorted({str(n) for _, names, _, _ in dataset_info for n in names})
    class_map = {name: name for name in all_names}

    initial_counts = gather_class_counts(dataset_info, class_map)
    df_data = [[name, name, initial_counts.get(name, 0), False] for name in all_names]
    status_text = "Datasets loaded successfully."
    if failures:
        status_text += f" ({len(dataset_info)} OK, {len(failures)} failed; see console logs)."

    return status_text, dataset_info, gr.DataFrame.update(
        value=pd.DataFrame(df_data, columns=["Original Name", "Rename To", "Max Images", "Remove"])
    )
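
# The uploaded .txt lists one dataset per line; an illustrative file (hypothetical
# projects) could look like:
#   https://universe.roboflow.com/acme/widgets/3
#   acme/gadgets/v2
#   acme/things            (no version given, so the latest is resolved via the API)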


def update_class_counts_handler(class_df, dataset_info):
    """
    Provides live feedback on class counts as the user edits the DataFrame.
    We compute a mapping of original -> final (or None if removed), then count images
    for each final name.
    """
    if class_df is None or not dataset_info:
        return None

    class_df = pd.DataFrame(class_df)
    mapping = {}
    for _, row in class_df.iterrows():
        orig = row["Original Name"]
        if bool(row["Remove"]):
            mapping[orig] = None
        else:
            mapping[orig] = row["Rename To"]

    final_names = sorted(set(v for v in mapping.values() if v))
    counts = {k: 0 for k in final_names}

    for loc, names, splits, _ in dataset_info:
        id_to_final = {}
        for idx, n in enumerate(names):
            id_to_final[idx] = mapping.get(n, None)

        for split in splits:
            labels_dir = os.path.join(loc, split, 'labels')
            if not os.path.exists(labels_dir):
                continue
            for label_file in os.listdir(labels_dir):
                if not label_file.endswith('.txt'):
                    continue
                found = set()
                with open(os.path.join(labels_dir, label_file), 'r') as f:
                    for line in f:
                        parts = line.strip().split()
                        if not parts:
                            continue
                        try:
                            cls_id = int(parts[0])
                            mapped = id_to_final.get(cls_id, None)
                            if mapped:
                                found.add(mapped)
                        except Exception:
                            continue
                for m in found:
                    counts[m] += 1

    summary_df = pd.DataFrame(list(counts.items()), columns=["Final Class Name", "Est. Total Images"])
    return summary_df


def finalize_handler(dataset_info, class_df, progress=gr.Progress()):
    """Handles the 'Finalize' button click."""
    if not dataset_info:
        raise gr.Error("Load datasets first in Tab 1.")
    if class_df is None:
        raise gr.Error("Class data is missing.")

    class_df = pd.DataFrame(class_df)
    class_mapping = {}
    class_limits = {}
    for _, row in class_df.iterrows():
        orig = row["Original Name"]
        if bool(row["Remove"]):
            continue
        final_name = row["Rename To"]
        class_mapping[orig] = final_name
        class_limits[final_name] = class_limits.get(final_name, 0) + int(row["Max Images"])

    status, path = finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress)
    return status, path
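
# Limits are summed across originals that share a final name: if 'Car' (Max Images 200)
# and 'Truck' (Max Images 100) are both renamed to 'vehicle', the merged dataset is
# capped at 300 images containing 'vehicle'.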


def training_handler(dataset_path, model_filename, run_name, epochs, batch, imgsz, lr, opt, progress=gr.Progress()):
    """Handles the training process with live feedback."""
    if not dataset_path:
        raise gr.Error("Finalize a dataset in Tab 2 before training.")

    device_str = "0" if torch.cuda.is_available() else "cpu"

    metrics_queue = Queue()

    def on_epoch_end(trainer):
        m = trainer.metrics or {}
        metrics_queue.put({
            'epoch': (trainer.epoch or 0) + 1,
            'train_loss': m.get('train/loss') or m.get('loss'),
            'val_loss': m.get('val/loss'),
            'mAP50': m.get('metrics/mAP50(B)') or m.get('metrics/mAP50'),
            'mAP50_95': m.get('metrics/mAP50-95(B)') or m.get('metrics/mAP50-95')
        })

    def train_thread_func():
        try:
            model_url = next(m['url'] for m in RTDETR_MODELS['detection'] if m['filename'] == model_filename)
            weights_path = os.path.join('pretrained_models', model_filename)
            if not os.path.exists(weights_path):
                os.makedirs('pretrained_models', exist_ok=True)
                r = requests.get(model_url, stream=True, timeout=60)
                r.raise_for_status()
                with open(weights_path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)

            model = YOLO(weights_path)
            model.add_callback("on_train_epoch_end", on_epoch_end)

            model.train(
                data=os.path.join(dataset_path, 'data.yaml'),
                epochs=int(epochs),
                batch=int(batch),
                imgsz=int(imgsz),
                lr0=float(lr),
                optimizer=str(opt),
                project='runs/train',
                name=str(run_name),
                exist_ok=True,
                device=device_str
            )
            metrics_queue.put("done")
        except Exception as e:
            logging.exception("Training thread error")
            metrics_queue.put(f"error: {e}")

    Thread(target=train_thread_func, daemon=True).start()

    history = {k: [] for k in ['epoch', 'train_loss', 'val_loss', 'mAP50', 'mAP50_95']}
    while True:
        item = metrics_queue.get()
        if isinstance(item, str):
            if item == "done":
                break
            if item.startswith("error"):
                raise gr.Error(f"Training failed: {item}")
            continue  # ignore any other control strings

        for key in ['epoch', 'train_loss', 'val_loss', 'mAP50', 'mAP50_95']:
            val = item.get(key, None)
            if val is not None:
                history[key].append(val)

        current_epoch = history['epoch'][-1] if history['epoch'] else 0
        total_epochs = int(epochs)
        frac = min(max(current_epoch / max(1, total_epochs), 0.0), 1.0)
        progress(frac, desc=f"Epoch {current_epoch}/{total_epochs}")

        # Only plot series that have a value for every epoch seen so far; some
        # metric keys may be absent depending on the ultralytics version.
        fig_loss = plt.figure()
        ax_loss = fig_loss.add_subplot(111)
        if len(history['train_loss']) == len(history['epoch']):
            ax_loss.plot(history['epoch'], history['train_loss'], "o-", label='Train Loss')
        if len(history['val_loss']) == len(history['epoch']):
            ax_loss.plot(history['epoch'], history['val_loss'], "o-", label='Val Loss')
        ax_loss.legend()
        ax_loss.set_title("Loss")

        fig_map = plt.figure()
        ax_map = fig_map.add_subplot(111)
        if len(history['mAP50']) == len(history['epoch']):
            ax_map.plot(history['epoch'], history['mAP50'], "o-", label='mAP@0.5')
        if len(history['mAP50_95']) == len(history['epoch']):
            ax_map.plot(history['epoch'], history['mAP50_95'], "o-", label='mAP@0.5:0.95')
        ax_map.legend()
        ax_map.set_title("mAP")

        yield f"Epoch {current_epoch}/{total_epochs} complete.", fig_loss, fig_map, None

    final_path = os.path.join('runs', 'train', str(run_name), 'weights', 'best.pt')
    if not os.path.exists(final_path):
        raise gr.Error("Training finished, but 'best.pt' was not found.")

    yield "Training complete!", None, None, gr.File.update(value=final_path, visible=True)
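
# Design note: model.train() blocks, so it runs in a daemon thread and streams
# per-epoch metrics back through a Queue; the generator above turns each queue
# item into a Gradio progress update and fresh plots. The metric keys pulled in
# on_epoch_end are best-effort; exact names vary across ultralytics versions,
# hence the fallbacks there and the length checks before plotting.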


def upload_handler(model_file, hf_token, hf_repo, gh_token, gh_repo, progress=gr.Progress()):
    """Handles model upload to Hugging Face and GitHub."""
    if not model_file:
        raise gr.Error("No trained model file available to upload. Train a model first.")

    hf_status = "Skipped Hugging Face (credentials not provided)."
    if hf_token and hf_repo:
        progress(0, desc="Uploading to Hugging Face...")
        try:
            api = HfApi()
            HfFolder.save_token(hf_token)
            repo_url = api.create_repo(repo_id=hf_repo, exist_ok=True, token=hf_token)
            api.upload_file(
                path_or_fileobj=model_file.name,
                path_in_repo=os.path.basename(model_file.name),
                repo_id=hf_repo,
                token=hf_token
            )
            hf_status = f"Success! Model at: {repo_url}"
        except Exception as e:
            hf_status = f"Hugging Face Error: {e}"

    gh_status = "Skipped GitHub (credentials not provided)."
    if gh_token and gh_repo:
        progress(0.5, desc="Uploading to GitHub...")
        try:
            if '/' not in gh_repo:
                raise ValueError("GitHub repo must be in the form 'username/repo'.")

            username, repo_name = gh_repo.split('/')
            api_url = f"https://api.github.com/repos/{username}/{repo_name}/contents/{os.path.basename(model_file.name)}"
            headers = {"Authorization": f"token {gh_token}"}

            with open(model_file.name, "rb") as f:
                content = base64.b64encode(f.read()).decode()

            get_resp = requests.get(api_url, headers=headers, timeout=30)
            sha = get_resp.json().get('sha') if get_resp.ok else None

            data = {"message": "Upload trained model from Rolo app", "content": content}
            if sha:
                data["sha"] = sha

            put_resp = requests.put(api_url, headers=headers, json=data, timeout=60)

            if put_resp.ok:
                gh_status = f"Success! Model at: {put_resp.json()['content']['html_url']}"
            else:
                msg = put_resp.json().get('message', 'Unknown')
                gh_status = f"GitHub Error: {msg}"
        except Exception as e:
            gh_status = f"GitHub Error: {e}"

    progress(1)
    return hf_status, gh_status
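
# The GitHub upload uses the REST "contents" API: the GET fetches the existing
# file's sha (if any), and the PUT sends the base64-encoded weights, including
# that sha when updating an existing file so the API accepts the overwrite.
# GitHub enforces per-file size limits (around 100 MB), so very large checkpoints
# may be rejected by this route.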


with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
    gr.Markdown("# Rolo: A Dedicated RT-DETR Training Dashboard")

    dataset_info_state = gr.State([])
    final_dataset_path_state = gr.State(None)

    with gr.Tabs():
        with gr.TabItem("1. Prepare Datasets"):
            gr.Markdown("### Load Roboflow Datasets\nProvide your Roboflow API key and upload a `.txt` file containing one Roboflow dataset URL or `workspace/project[/vN]` per line.")
            with gr.Row():
                rf_api_key = gr.Textbox(label="Roboflow API Key (or set ROBOFLOW_API_KEY env)", type="password", scale=2)
                rf_url_file = gr.File(label="Upload Roboflow URLs (.txt)", file_types=[".txt"], scale=1)
            load_btn = gr.Button("Load Datasets", variant="primary")
            dataset_status = gr.Textbox(label="Status", interactive=False)

        with gr.TabItem("2. Manage & Merge"):
            gr.Markdown("### Configure Classes and Finalize Dataset\nRename classes to merge them, set image limits, or remove them. Click **Update Counts** to preview, then **Finalize** to create the dataset.")
            with gr.Row():
                class_df = gr.DataFrame(
                    headers=["Original Name", "Rename To", "Max Images", "Remove"],
                    datatype=["str", "str", "number", "bool"],
                    label="Class Configuration", interactive=True, scale=3
                )
                with gr.Column(scale=1):
                    class_count_summary_df = gr.DataFrame(
                        label="Merged Class Counts Preview",
                        headers=["Final Class Name", "Est. Total Images"],
                        interactive=False
                    )
                    update_counts_btn = gr.Button("Update Counts")

            finalize_btn = gr.Button("Finalize Merged Dataset", variant="primary")
            finalize_status = gr.Textbox(label="Status", interactive=False)

        with gr.TabItem("3. Configure & Train"):
            gr.Markdown("### Set Hyperparameters and Train the RT-DETR Model")
            with gr.Row():
                with gr.Column(scale=1):
                    model_file_dd = gr.Dropdown(
                        label="Select Pre-Trained RT-DETR Model",
                        choices=[m["filename"] for m in RTDETR_MODELS["detection"]],
                        value=DEFAULT_MODEL
                    )
                    run_name_tb = gr.Textbox(label="Run Name", value="rtdetr_run_1")
                    epochs_sl = gr.Slider(1, 500, 100, step=1, label="Epochs")
                    batch_sl = gr.Slider(1, 32, 8, step=1, label="Batch Size")
                    imgsz_num = gr.Number(label="Image Size", value=640)
                    lr_num = gr.Number(label="Learning Rate", value=0.001)
                    opt_dd = gr.Dropdown(["Adam", "AdamW", "SGD"], value="Adam", label="Optimizer")
                    train_btn = gr.Button("Start Training", variant="primary")
                with gr.Column(scale=2):
                    train_status = gr.Textbox(label="Live Status", interactive=False)
                    loss_plot = gr.Plot(label="Loss Curves")
                    map_plot = gr.Plot(label="mAP Curves")
                    final_model_file = gr.File(label="Download Trained Model (best.pt)", interactive=False, visible=False)

        with gr.TabItem("4. Upload Model"):
            gr.Markdown("### Upload Your Trained Model\nAfter training, you can upload the `best.pt` file to Hugging Face and/or GitHub.")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Hugging Face")
                    hf_token = gr.Textbox(label="Hugging Face API Token", type="password")
                    hf_repo = gr.Textbox(label="Hugging Face Repo ID", placeholder="e.g., username/my-rtdetr-model")
                with gr.Column():
                    gr.Markdown("#### GitHub")
                    gh_token = gr.Textbox(label="GitHub Personal Access Token", type="password")
                    gh_repo = gr.Textbox(label="GitHub Repo", placeholder="e.g., username/my-rtdetr-repo")
            upload_btn = gr.Button("Upload Model", variant="primary")
            with gr.Row():
                hf_status = gr.Textbox(label="Hugging Face Status", interactive=False)
                gh_status = gr.Textbox(label="GitHub Status", interactive=False)

    load_btn.click(
        fn=load_datasets_handler,
        inputs=[rf_api_key, rf_url_file],
        outputs=[dataset_status, dataset_info_state, class_df]
    )
    update_counts_btn.click(
        fn=update_class_counts_handler,
        inputs=[class_df, dataset_info_state],
        outputs=[class_count_summary_df]
    )
    finalize_btn.click(
        fn=finalize_handler,
        inputs=[dataset_info_state, class_df],
        outputs=[finalize_status, final_dataset_path_state]
    )
    train_btn.click(
        fn=training_handler,
        inputs=[final_dataset_path_state, model_file_dd, run_name_tb, epochs_sl, batch_sl, imgsz_num, lr_num, opt_dd],
        outputs=[train_status, loss_plot, map_plot, final_model_file]
    )
    upload_btn.click(
        fn=upload_handler,
        inputs=[final_model_file, hf_token, hf_repo, gh_token, gh_repo],
        outputs=[hf_status, gh_status]
    )


if __name__ == "__main__":
    app.launch(debug=True)