import os
import sys
import subprocess
import shutil
import stat
import re
import random
import logging
import json
import time
import base64
from urllib.parse import urlparse
from threading import Thread
from queue import Queue
from glob import glob

import yaml
import requests
import gradio as gr
import torch
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from roboflow import Roboflow

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
|
|
|
REPO_URL = "https://github.com/supervisely-ecosystem/RT-DETRv2"
REPO_DIR = os.path.join(os.getcwd(), "third_party", "RT-DETRv2")
PY_IMPL_DIR = os.path.join(REPO_DIR, "rtdetrv2_pytorch")
WEIGHTS_DIR = os.path.join(PY_IMPL_DIR, "weights")

COMMON_REQUIREMENTS = [
    "gradio>=4.36.1",
    "roboflow>=1.1.28",
    "pandas>=2.0.0",
    "matplotlib>=3.7.0",
    "pyyaml>=6.0.1",
    "Pillow>=10.0.0",
    "requests>=2.31.0",
    "huggingface_hub>=0.22.0",
]
|
|
|
|
|
def pip_install(args):
    logging.info(f"pip install {' '.join(args)}")
    subprocess.check_call([sys.executable, "-m", "pip", "install"] + args)
|
|
|
|
|
def ensure_repo_and_requirements():
    os.makedirs(os.path.dirname(REPO_DIR), exist_ok=True)
    if not os.path.exists(REPO_DIR):
        logging.info(f"Cloning RT-DETRv2 repo to {REPO_DIR} ...")
        subprocess.check_call(["git", "clone", "--depth", "1", REPO_URL, REPO_DIR])
    else:
        logging.info("RT-DETRv2 repo already present, pulling latest...")
        try:
            subprocess.check_call(["git", "-C", REPO_DIR, "pull", "--ff-only"])
        except Exception:
            logging.warning("Could not pull latest; continuing with current checkout.")

    pip_install(COMMON_REQUIREMENTS)

    req_file = os.path.join(PY_IMPL_DIR, "requirements.txt")
    if os.path.exists(req_file):
        pip_install(["-r", req_file])
    else:
        logging.info("No rtdetrv2_pytorch/requirements.txt found; relying on common reqs.")
|
|
|
|
|
|
|
|
try:
    ensure_repo_and_requirements()
except Exception:
    # Swallow bootstrap failures so the UI can still start; training will surface a
    # clearer error later if the repo or its requirements are missing.
    logging.exception("Bootstrap failed")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MODEL_CHOICES = [
    ("rtdetrv2_s", "Small (default)"),
    ("rtdetrv2_l", "Large"),
    ("rtdetrv2_x", "X-Large"),
]
DEFAULT_MODEL_KEY = "rtdetrv2_s"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def handle_remove_readonly(func, path, exc_info):
    """shutil.rmtree onerror callback: clear the read-only bit and retry the failed op."""
    try:
        os.chmod(path, stat.S_IWRITE)
    except Exception:
        pass
    func(path)
|
|
|
|
|
_ROBO_URL_RX = re.compile(
    r"""
    ^(?:
        (?:https?://)?(?:universe|app|www)?\.?roboflow\.com/
        (?P<ws>[A-Za-z0-9\-_]+)/
        (?P<proj>[A-Za-z0-9\-_]+)/?
        (?:(?:dataset/[^/]+/)?(?:v?(?P<ver>\d+))?)?
    |
        (?P<ws2>[A-Za-z0-9\-_]+)/(?P<proj2>[A-Za-z0-9\-_]+)(?:/(?:v)?(?P<ver2>\d+))?
    )$
    """, re.VERBOSE | re.IGNORECASE
)
|
|
|
|
|
def parse_roboflow_url(s: str):
    """Parse a Roboflow URL or 'workspace/project[/vN]' string into (workspace, project, version)."""
    s = s.strip()
    m = _ROBO_URL_RX.match(s)
    if m:
        ws = m.group('ws') or m.group('ws2')
        proj = m.group('proj') or m.group('proj2')
        ver = m.group('ver') or m.group('ver2')
        return ws, proj, (int(ver) if ver else None)

    parsed = urlparse(s)
    parts = [p for p in parsed.path.strip('/').split('/') if p]
    if len(parts) >= 2:
        version = None
        if len(parts) >= 3:
            vpart = parts[2]
            if vpart.lower().startswith('v') and vpart[1:].isdigit():
                version = int(vpart[1:])
            elif vpart.isdigit():
                version = int(vpart)
        return parts[0], parts[1], version

    if '/' in s and 'roboflow' not in s:
        p = s.split('/')
        if len(p) >= 2:
            version = None
            if len(p) >= 3:
                v = p[2]
                if v.lower().startswith('v') and v[1:].isdigit():
                    version = int(v[1:])
                elif v.isdigit():
                    version = int(v)
            return p[0], p[1], version

    return None, None, None
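
# Illustrative examples (hypothetical workspace/project names) of inputs this parser accepts:
#
#   parse_roboflow_url("https://universe.roboflow.com/acme/widgets/3")  -> ("acme", "widgets", 3)
#   parse_roboflow_url("https://app.roboflow.com/acme/widgets")         -> ("acme", "widgets", None)
#   parse_roboflow_url("acme/widgets/v2")                               -> ("acme", "widgets", 2)
#   parse_roboflow_url("not a roboflow reference")                      -> (None, None, None)
#
# A missing version (None) is resolved later via get_latest_version().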
|
|
|
|
|
def get_latest_version(api_key, workspace, project):
    try:
        rf = Roboflow(api_key=api_key)
        proj = rf.workspace(workspace).project(project)
        versions = sorted([int(v.version) for v in proj.versions()], reverse=True)
        return versions[0] if versions else None
    except Exception as e:
        logging.error(f"Could not get latest version for {workspace}/{project}: {e}")
        return None
|
|
|
|
|
def _extract_class_names(data_yaml):
    names = data_yaml.get('names', None)
    if isinstance(names, dict):
        def _k(x):
            try:
                return int(x)
            except Exception:
                return str(x)
        ordered_keys = sorted(names.keys(), key=_k)
        names_list = [names[k] for k in ordered_keys]
    elif isinstance(names, list):
        names_list = names
    else:
        nc = data_yaml.get('nc', 0)
        try:
            nc = int(nc)
        except Exception:
            nc = 0
        names_list = [f"class_{i}" for i in range(nc)]
    return [str(x) for x in names_list]
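
# data.yaml files in the wild store `names` either as a list or as an id->name mapping.
# Both shapes normalize to the same ordered list here (illustrative values):
#
#   names: ['car', 'person']         -> ['car', 'person']
#   names: {0: 'car', 1: 'person'}   -> ['car', 'person']
#   names missing, nc: 3             -> ['class_0', 'class_1', 'class_2']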
|
|
|
|
|
def download_dataset(api_key, workspace, project, version):
    """Download a Roboflow dataset in YOLOv8 format (labels are compatible with our merger)."""
    try:
        rf = Roboflow(api_key=api_key)
        proj = rf.workspace(workspace).project(project)
        ver = proj.version(int(version))
        dataset = ver.download("yolov8")

        data_yaml_path = os.path.join(dataset.location, 'data.yaml')
        with open(data_yaml_path, 'r') as f:
            data_yaml = yaml.safe_load(f)

        class_names = _extract_class_names(data_yaml)
        try:
            nc = int(data_yaml.get('nc', len(class_names)))
        except Exception:
            nc = len(class_names)
        if len(class_names) != nc:
            logging.warning(f"[{project}-v{version}] names length ({len(class_names)}) != nc ({nc}); using normalized names.")

        splits = [s for s in ['train', 'valid', 'test'] if os.path.exists(os.path.join(dataset.location, s))]
        return dataset.location, class_names, splits, f"{project}-v{version}"
    except Exception as e:
        logging.error(f"Failed to download {workspace}/{project}/v{version}: {e}")
        return None, [], [], None
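
# On success the function returns a 4-tuple consumed by the dataset_info state:
#   (local_dataset_dir, class_names, present_splits, "project-vN")
# e.g. (illustrative) ("/tmp/widgets-3", ["car", "person"], ["train", "valid"], "widgets-v3").
# On failure it returns (None, [], [], None) and logs the reason.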
|
|
|
|
|
def label_path_for(img_path: str) -> str:
    """Map .../<split>/images/<name>.<ext> to the matching .../<split>/labels/<name>.txt."""
    split_dir = os.path.dirname(os.path.dirname(img_path))
    base = os.path.splitext(os.path.basename(img_path))[0] + '.txt'
    return os.path.join(split_dir, 'labels', base)
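
# Example of the YOLO-style layout this helper assumes (illustrative path):
#   .../widgets-3/train/images/0001.jpg  ->  .../widgets-3/train/labels/0001.txt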
|
|
|
|
|
def gather_class_counts(dataset_info, class_mapping):
    """Count, for each final class name, how many images contain at least one instance of it."""
    if not dataset_info:
        return {}
    final_names = set(v for v in class_mapping.values() if v is not None)
    counts = {name: 0 for name in final_names}

    for loc, names, splits, _ in dataset_info:
        id_to_name = {idx: class_mapping.get(n, None) for idx, n in enumerate(names)}
        for split in splits:
            labels_dir = os.path.join(loc, split, 'labels')
            if not os.path.exists(labels_dir):
                continue
            for label_file in os.listdir(labels_dir):
                if not label_file.endswith('.txt'):
                    continue
                found = set()
                with open(os.path.join(labels_dir, label_file), 'r') as f:
                    for line in f:
                        parts = line.strip().split()
                        if not parts:
                            continue
                        try:
                            cls_id = int(parts[0])
                            mapped = id_to_name.get(cls_id, None)
                            if mapped:
                                found.add(mapped)
                        except Exception:
                            continue
                for m in found:
                    counts[m] += 1
    return counts
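
# Counts are per image, not per bounding box: an image with three "car" boxes adds 1 to
# "car". Illustrative result for a mapping that merges two classes and drops one:
#   class_mapping = {"car": "vehicle", "truck": "vehicle", "person": None}
#   -> {"vehicle": <number of images containing a car and/or truck>}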
|
|
|
|
|
def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=gr.Progress()):
    merged_dir = 'rolo_merged_dataset'
    if os.path.exists(merged_dir):
        shutil.rmtree(merged_dir, onerror=handle_remove_readonly)

    progress(0, desc="Creating directories...")
    for split in ['train', 'valid', 'test']:
        os.makedirs(os.path.join(merged_dir, split, 'images'), exist_ok=True)
        os.makedirs(os.path.join(merged_dir, split, 'labels'), exist_ok=True)

    active_classes = sorted({cls for cls, limit in class_limits.items() if limit > 0})
    final_class_map = {name: i for i, name in enumerate(active_classes)}

    all_images = []
    for loc, _, splits, _ in dataset_info:
        for split in splits:
            img_dir = os.path.join(loc, split, 'images')
            if not os.path.exists(img_dir):
                continue
            for img_file in os.listdir(img_dir):
                if img_file.lower().endswith(('.jpg', '.jpeg', '.png')):
                    all_images.append((os.path.join(img_dir, img_file), split, loc))
    random.shuffle(all_images)

    progress(0.2, desc="Selecting images based on limits...")
    selected_images = []
    current_counts = {cls: 0 for cls in active_classes}
    loc_to_names = {info[0]: info[1] for info in dataset_info}

    for img_path, split, source_loc in progress.tqdm(all_images, desc="Analyzing images"):
        lbl_path = label_path_for(img_path)
        if not os.path.exists(lbl_path):
            continue

        source_names = loc_to_names.get(source_loc, [])
        image_classes = set()
        with open(lbl_path, 'r') as f:
            for line in f:
                parts = line.strip().split()
                if not parts:
                    continue
                try:
                    cls_id = int(parts[0])
                    orig = source_names[cls_id]
                    mapped = class_mapping.get(orig, orig)
                    if mapped in active_classes:
                        image_classes.add(mapped)
                except Exception:
                    continue

        if not image_classes:
            continue

        if any(current_counts[c] >= class_limits[c] for c in image_classes):
            continue

        selected_images.append((img_path, split))
        for c in image_classes:
            current_counts[c] += 1

    progress(0.6, desc=f"Copying {len(selected_images)} files...")
    for img_path, split in progress.tqdm(selected_images, desc="Finalizing files"):
        lbl_path = label_path_for(img_path)
        out_img = os.path.join(merged_dir, split, 'images', os.path.basename(img_path))
        out_lbl = os.path.join(merged_dir, split, 'labels', os.path.basename(lbl_path))
        shutil.copy(img_path, out_img)

        source_loc = None
        for info in dataset_info:
            if img_path.startswith(info[0]):
                source_loc = info[0]
                break
        source_names = loc_to_names.get(source_loc, [])

        with open(lbl_path, 'r') as f_in, open(out_lbl, 'w') as f_out:
            for line in f_in:
                parts = line.strip().split()
                if not parts:
                    continue
                try:
                    old_id = int(parts[0])
                    original_name = source_names[old_id]
                    mapped_name = class_mapping.get(original_name, original_name)
                    if mapped_name in final_class_map:
                        new_id = final_class_map[mapped_name]
                        f_out.write(f"{new_id} {' '.join(parts[1:])}\n")
                except Exception:
                    continue

    progress(0.95, desc="Creating data.yaml...")
    with open(os.path.join(merged_dir, 'data.yaml'), 'w') as f:
        yaml.dump({
            'path': os.path.abspath(merged_dir),
            'train': 'train/images',
            'val': 'valid/images',
            'test': 'test/images',
            'nc': len(active_classes),
            'names': active_classes
        }, f)

    return f"Dataset finalized with {len(selected_images)} images.", os.path.abspath(merged_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def detect_training_entrypoint():
    """
    Probe a few common layouts inside the Supervisely repo:
      1) rtdetrv2_pytorch/train.py
      2) tools/train.py
      3) src/main.py (fallback)
    Returns (python_file, style), where style hints how to build args.
    """
    cand1 = os.path.join(PY_IMPL_DIR, "train.py")
    cand2 = os.path.join(REPO_DIR, "tools", "train.py")
    if os.path.exists(cand1):
        return cand1, "pytorch_train"
    if os.path.exists(cand2):
        return cand2, "tools_train"

    cand3 = os.path.join(REPO_DIR, "src", "main.py")
    if os.path.exists(cand3):
        return cand3, "app_main"
    return None, None
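
# Possible return values, in probe order (paths relative to the cloned repo):
#   (".../rtdetrv2_pytorch/train.py", "pytorch_train")
#   (".../tools/train.py",            "tools_train")
#   (".../src/main.py",               "app_main")
#   (None, None) if nothing matched; training_handler raises a gr.Error in that case.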
|
|
|
|
|
def build_command(entrypoint, style, dataset_path, model_key, run_name, epochs, batch, imgsz, lr, optimizer):
    """
    Build a best-guess command for the detected style.
    Users never have to edit the CLI; we build it for them.
    We keep args conservative and standard (data, epochs, batch, img size).
    """
    data_yaml = os.path.join(dataset_path, "data.yaml")
    out_dir = os.path.join("runs", "train", str(run_name))
    os.makedirs(out_dir, exist_ok=True)

    if style == "pytorch_train":
        cmd = [
            sys.executable, entrypoint,
            "--data", data_yaml,
            "--model", model_key,
            "--epochs", str(int(epochs)),
            "--batch", str(int(batch)),
            "--imgsz", str(int(imgsz)),
            "--project", os.path.abspath(out_dir)
        ]
        if lr is not None:
            cmd += ["--lr", str(float(lr))]
        if optimizer:
            cmd += ["--optimizer", str(optimizer)]
        return cmd, out_dir

    if style == "tools_train":
        cmd = [
            sys.executable, entrypoint,
            "--data", data_yaml,
            "--model", model_key,
            "--epochs", str(int(epochs)),
            "--batch-size", str(int(batch)),
            "--imgsz", str(int(imgsz)),
            "--project", os.path.abspath(out_dir),
            "--name", "exp"
        ]
        if lr is not None:
            cmd += ["--lr0", str(float(lr))]
        if optimizer:
            cmd += ["--optimizer", str(optimizer)]
        return cmd, out_dir

    if style == "app_main":
        cmd = [
            sys.executable, entrypoint,
            "--data", data_yaml,
            "--model", model_key,
            "--epochs", str(int(epochs)),
            "--batch", str(int(batch)),
            "--imgsz", str(int(imgsz)),
            "--output", os.path.abspath(out_dir)
        ]
        if lr is not None:
            cmd += ["--lr", str(float(lr))]
        if optimizer:
            cmd += ["--optimizer", str(optimizer)]
        return cmd, out_dir

    raise gr.Error("Could not locate a training script inside the RT-DETRv2 repo. Please check the repo layout.")
|
|
|
|
|
def find_best_checkpoint(out_dir):
    # Prefer checkpoints explicitly marked as "best", then fall back to any checkpoint.
    patterns = [
        os.path.join(out_dir, "**", "best*.pt"),
        os.path.join(out_dir, "**", "best*.pth"),
        os.path.join(out_dir, "**", "model_best*.pt"),
        os.path.join(out_dir, "**", "model_best*.pth"),
    ]
    for p in patterns:
        files = sorted(glob(p, recursive=True))
        if files:
            return files[0]

    any_ckpt = sorted(glob(os.path.join(out_dir, "**", "*.pt"), recursive=True) +
                      glob(os.path.join(out_dir, "**", "*.pth"), recursive=True))
    return any_ckpt[-1] if any_ckpt else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_datasets_handler(api_key, url_file, progress=gr.Progress()):
    api_key = api_key or os.getenv("ROBOFLOW_API_KEY", "")
    if not api_key:
        raise gr.Error("Roboflow API Key is required (or set ROBOFLOW_API_KEY).")
    if not url_file:
        raise gr.Error("Please upload a .txt file with Roboflow URLs or lines like 'workspace/project[/vN]'.")

    # gr.File may pass a tempfile wrapper or a plain filepath string depending on the Gradio version.
    url_path = url_file if isinstance(url_file, str) else url_file.name
    with open(url_path, 'r', encoding='utf-8', errors='ignore') as f:
        urls = [line.strip() for line in f if line.strip()]

    dataset_info, failures = [], []
    for i, raw in enumerate(urls):
        progress((i + 1) / max(1, len(urls)), desc=f"Parsing {i+1}/{len(urls)}")
        ws, proj, ver = parse_roboflow_url(raw)
        if not (ws and proj):
            failures.append((raw, "ParseError: could not resolve workspace/project"))
            continue
        if ver is None:
            ver = get_latest_version(api_key, ws, proj)
            if ver is None:
                failures.append((raw, f"Could not resolve latest version for {ws}/{proj}"))
                continue

        loc, names, splits, name_str = download_dataset(api_key, ws, proj, int(ver))
        if loc:
            dataset_info.append((loc, names, splits, name_str))
        else:
            failures.append((raw, f"DownloadError: {ws}/{proj}/v{ver}"))

    if not dataset_info:
        msg = "No datasets were loaded successfully.\n" + "\n".join([f"- {u}: {why}" for u, why in failures[:10]])
        raise gr.Error(msg)

    all_names = sorted({str(n) for _, names, _, _ in dataset_info for n in names})
    class_map = {name: name for name in all_names}

    initial_counts = gather_class_counts(dataset_info, class_map)
    df = pd.DataFrame([[name, name, initial_counts.get(name, 0), False] for name in all_names],
                      columns=["Original Name", "Rename To", "Max Images", "Remove"])
    status_text = "Datasets loaded successfully."
    if failures:
        status_text += f" ({len(dataset_info)} OK, {len(failures)} failed; see console logs)."

    return status_text, dataset_info, df
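
# The uploaded URLs file may mix any of the accepted line formats, one dataset per line
# (hypothetical entries):
#
#   https://universe.roboflow.com/acme/widgets/3
#   acme/gadgets/v2
#   acme/sprockets          <- no version: the latest version is looked up via the API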
|
|
|
|
|
def update_class_counts_handler(class_df, dataset_info):
    if class_df is None or not dataset_info:
        return None

    class_df = pd.DataFrame(class_df)
    mapping = {}
    for _, row in class_df.iterrows():
        orig = row["Original Name"]
        mapping[orig] = None if bool(row["Remove"]) else row["Rename To"]

    # Reuse the same per-image counting as gather_class_counts(), applied to the edited
    # mapping, then drop removed/empty names and sort for a stable preview table.
    counts = gather_class_counts(dataset_info, mapping)
    counts = {k: counts[k] for k in sorted(counts) if k}

    return pd.DataFrame(list(counts.items()), columns=["Final Class Name", "Est. Total Images"])
|
|
|
|
|
def finalize_handler(dataset_info, class_df, progress=gr.Progress()):
    if not dataset_info:
        raise gr.Error("Load datasets first in Tab 1.")
    if class_df is None:
        raise gr.Error("Class data is missing.")

    class_df = pd.DataFrame(class_df)
    class_mapping, class_limits = {}, {}
    for _, row in class_df.iterrows():
        orig = row["Original Name"]
        if bool(row["Remove"]):
            continue
        final_name = row["Rename To"]
        class_mapping[orig] = final_name
        class_limits[final_name] = class_limits.get(final_name, 0) + int(row["Max Images"])

    status, path = finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress)
    return status, path
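
# How the editable class table maps onto finalize_merged_dataset() inputs
# (hypothetical rows):
#
#   Original Name | Rename To | Max Images | Remove
#   car           | vehicle   | 500        | False
#   truck         | vehicle   | 300        | False
#   person        | person    | 0          | True
#
#   -> class_mapping = {"car": "vehicle", "truck": "vehicle"}
#      class_limits  = {"vehicle": 800}      # limits for merged names are summed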
|
|
|
|
|
def training_handler(dataset_path, model_choice_key, run_name, epochs, batch, imgsz, lr, opt, progress=gr.Progress()): |
|
|
if not dataset_path: |
|
|
raise gr.Error("Finalize a dataset in Tab 2 before training.") |
|
|
|
|
|
|
|
|
entrypoint, style = detect_training_entrypoint() |
|
|
if not entrypoint: |
|
|
raise gr.Error("RT-DETRv2 training script not found in the repo. Please check repo contents.") |
|
|
|
|
|
|
|
|
cmd, out_dir = build_command( |
|
|
entrypoint=entrypoint, |
|
|
style=style, |
|
|
dataset_path=dataset_path, |
|
|
model_key=model_choice_key, |
|
|
run_name=run_name, |
|
|
epochs=epochs, |
|
|
batch=batch, |
|
|
imgsz=imgsz, |
|
|
lr=lr, |
|
|
optimizer=opt |
|
|
) |
|
|
logging.info(f"Training command: {' '.join(cmd)}") |
|
|
|
|
|
|
|
|
q = Queue() |
|
|
|
|
|
def run_train(): |
|
|
try: |
|
|
env = os.environ.copy() |
|
|
env["PYTHONPATH"] = REPO_DIR + os.pathsep + env.get("PYTHONPATH", "") |
|
|
proc = subprocess.Popen(cmd, cwd=REPO_DIR, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, bufsize=1, text=True, env=env) |
|
|
for line in proc.stdout: |
|
|
q.put(line.rstrip()) |
|
|
proc.wait() |
|
|
q.put(f"__EXITCODE__:{proc.returncode}") |
|
|
except Exception as e: |
|
|
q.put(f"__ERROR__:{e}") |
|
|
|
|
|
Thread(target=run_train, daemon=True).start() |
|
|
|
|
|
log_lines = [] |
|
|
last_epoch = 0 |
|
|
total_epochs = int(epochs) |
|
|
while True: |
|
|
line = q.get() |
|
|
if line.startswith("__EXITCODE__"): |
|
|
code = int(line.split(":", 1)[1]) |
|
|
if code != 0: |
|
|
raise gr.Error(f"Training process exited with code {code}. Check logs above.") |
|
|
break |
|
|
if line.startswith("__ERROR__"): |
|
|
raise gr.Error(f"Training failed: {line.split(':',1)[1]}") |
|
|
|
|
|
log_lines.append(line) |
|
|
|
|
|
m = re.search(r"[Ee]poch\s+(\d+)\s*/\s*(\d+)", line) |
|
|
if m: |
|
|
try: |
|
|
last_epoch = int(m.group(1)) |
|
|
total_epochs = max(total_epochs, int(m.group(2))) |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
frac = min(max(last_epoch / max(1, total_epochs), 0.0), 1.0) |
|
|
progress(frac, desc=f"Epoch {last_epoch}/{total_epochs}") |
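        # NOTE: progress here is a heuristic scraped from the subprocess stdout; it looks
        # for lines containing something like "Epoch 12/100" (illustrative). If the
        # upstream trainer logs a different format, the bar stays put but the raw log
        # tail keeps streaming below.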
|
|
|
|
|
|
|
|
fig_loss = plt.figure() |
|
|
ax_loss = fig_loss.add_subplot(111) |
|
|
ax_loss.set_title("Loss (see logs)") |
|
|
fig_map = plt.figure() |
|
|
ax_map = fig_map.add_subplot(111) |
|
|
ax_map.set_title("mAP (see logs)") |
|
|
|
|
|
yield "\n".join(log_lines[-30:]), fig_loss, fig_map, None |
|
|
|
|
|
|
|
|
    ckpt = find_best_checkpoint(out_dir)
    if not ckpt or not os.path.exists(ckpt):
        # Fall back to scanning the whole runs/ tree in case the repo wrote elsewhere.
        alt = find_best_checkpoint("runs")
        if not alt or not os.path.exists(alt):
            raise gr.Error("Training finished, but no checkpoint file was found. See logs for details.")
        ckpt = alt

    # Gradio 4 removed the component .update() classmethods; return a new component instead.
    yield "Training complete!", None, None, gr.File(value=ckpt, visible=True)
|
|
|
|
|
def upload_handler(model_file, hf_token, hf_repo, gh_token, gh_repo, progress=gr.Progress()):
    if not model_file:
        raise gr.Error("No trained model file available to upload. Train a model first.")

    from huggingface_hub import HfApi, HfFolder

    # gr.File may pass a tempfile wrapper or a plain filepath string depending on the Gradio version.
    model_path = model_file if isinstance(model_file, str) else model_file.name

    hf_status = "Skipped Hugging Face (credentials not provided)."
    if hf_token and hf_repo:
        progress(0, desc="Uploading to Hugging Face...")
        try:
            api = HfApi()
            HfFolder.save_token(hf_token)
            repo_url = api.create_repo(repo_id=hf_repo, exist_ok=True, token=hf_token)
            api.upload_file(
                path_or_fileobj=model_path,
                path_in_repo=os.path.basename(model_path),
                repo_id=hf_repo,
                token=hf_token
            )
            hf_status = f"Success! Model at: {repo_url}"
        except Exception as e:
            hf_status = f"Hugging Face Error: {e}"

    gh_status = "Skipped GitHub (credentials not provided)."
    if gh_token and gh_repo:
        progress(0.5, desc="Uploading to GitHub...")
        try:
            if '/' not in gh_repo:
                raise ValueError("GitHub repo must be in the form 'username/repo'.")

            username, repo_name = gh_repo.split('/', 1)
            api_url = f"https://api.github.com/repos/{username}/{repo_name}/contents/{os.path.basename(model_path)}"
            headers = {"Authorization": f"token {gh_token}"}

            with open(model_path, "rb") as f:
                content = base64.b64encode(f.read()).decode()

            # The contents API requires the existing file's SHA when overwriting in place.
            get_resp = requests.get(api_url, headers=headers, timeout=30)
            sha = get_resp.json().get('sha') if get_resp.ok else None

            data = {"message": "Upload trained model from Rolo app", "content": content}
            if sha:
                data["sha"] = sha

            put_resp = requests.put(api_url, headers=headers, json=data, timeout=60)

            if put_resp.ok:
                gh_status = f"Success! Model at: {put_resp.json()['content']['html_url']}"
            else:
                msg = put_resp.json().get('message', 'Unknown')
                gh_status = f"GitHub Error: {msg}"
        except Exception as e:
            gh_status = f"GitHub Error: {e}"

    progress(1)
    return hf_status, gh_status
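
# GitHub upload uses the REST contents API: PUT /repos/{owner}/{repo}/contents/{path}
# with a base64-encoded body, plus the existing blob's "sha" when overwriting. GitHub
# blocks files larger than 100 MB in ordinary repos, so very large checkpoints are
# better uploaded to Hugging Face (or pushed via Git LFS).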
|
|
|
|
|
|
|
|
|
|
|
|
|
|
with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app: |
|
|
gr.Markdown("# Rolo: RT-DETRv2 Training (Supervisely ecosystem only)") |
|
|
|
|
|
dataset_info_state = gr.State([]) |
|
|
final_dataset_path_state = gr.State(None) |
|
|
|
|
|
with gr.Tabs(): |
|
|
with gr.TabItem("1. Prepare Datasets"): |
|
|
gr.Markdown("### Load Roboflow Datasets\nProvide your Roboflow API key and upload a `.txt` file containing one Roboflow dataset URL or `workspace/project[/vN]` per line.") |
|
|
with gr.Row(): |
|
|
rf_api_key = gr.Textbox(label="Roboflow API Key (or set ROBOFLOW_API_KEY env)", type="password", scale=2) |
|
|
rf_url_file = gr.File(label="Upload Roboflow URLs (.txt)", file_types=[".txt"], scale=1) |
|
|
load_btn = gr.Button("Load Datasets", variant="primary") |
|
|
dataset_status = gr.Textbox(label="Status", interactive=False) |
|
|
|
|
|
with gr.TabItem("2. Manage & Merge"): |
|
|
gr.Markdown("### Configure Classes and Finalize Dataset\nRename classes to merge them, set image limits, or remove them. Click **Update Counts** to preview, then **Finalize** to create the dataset.") |
|
|
with gr.Row(): |
|
|
class_df = gr.DataFrame( |
|
|
headers=["Original Name", "Rename To", "Max Images", "Remove"], |
|
|
datatype=["str", "str", "number", "bool"], |
|
|
label="Class Configuration", interactive=True, scale=3 |
|
|
) |
|
|
with gr.Column(scale=1): |
|
|
class_count_summary_df = gr.DataFrame( |
|
|
label="Merged Class Counts Preview", |
|
|
headers=["Final Class Name", "Est. Total Images"], |
|
|
interactive=False |
|
|
) |
|
|
update_counts_btn = gr.Button("Update Counts") |
|
|
finalize_btn = gr.Button("Finalize Merged Dataset", variant="primary") |
|
|
finalize_status = gr.Textbox(label="Status", interactive=False) |
|
|
|
|
|
with gr.TabItem("3. Configure & Train"): |
|
|
gr.Markdown("### Set Hyperparameters and Train the RT-DETRv2 Model") |
|
|
with gr.Row(): |
|
|
with gr.Column(scale=1): |
|
|
model_file_dd = gr.Dropdown( |
|
|
label="Model (only RT-DETRv2 from Supervisely)", |
|
|
choices=[k for k, _ in MODEL_CHOICES], |
|
|
value=DEFAULT_MODEL_KEY |
|
|
) |
|
|
model_hints = gr.Markdown( |
|
|
"Choices: " + |
|
|
", ".join([f"`{k}` ({label})" for k, label in MODEL_CHOICES]) |
|
|
) |
|
|
run_name_tb = gr.Textbox(label="Run Name", value="rtdetrv2_run_1") |
|
|
epochs_sl = gr.Slider(1, 500, 100, step=1, label="Epochs") |
|
|
batch_sl = gr.Slider(1, 64, 16, step=1, label="Batch Size") |
|
|
imgsz_num = gr.Number(label="Image Size", value=640) |
|
|
lr_num = gr.Number(label="Learning Rate", value=0.001) |
|
|
opt_dd = gr.Dropdown(["Adam", "AdamW", "SGD"], value="Adam", label="Optimizer") |
|
|
train_btn = gr.Button("Start Training", variant="primary") |
|
|
with gr.Column(scale=2): |
|
|
train_status = gr.Textbox(label="Live Logs (tail)", interactive=False, lines=12) |
|
|
loss_plot = gr.Plot(label="Loss") |
|
|
map_plot = gr.Plot(label="mAP") |
|
|
final_model_file = gr.File(label="Download Trained Model", interactive=False, visible=False) |
|
|
|
|
|
with gr.TabItem("4. Upload Model"): |
|
|
gr.Markdown("### Upload Your Trained Model") |
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
gr.Markdown("#### Hugging Face") |
|
|
hf_token = gr.Textbox(label="Hugging Face API Token", type="password") |
|
|
hf_repo = gr.Textbox(label="Hugging Face Repo ID", placeholder="e.g., username/my-rtdetrv2-model") |
|
|
with gr.Column(): |
|
|
gr.Markdown("#### GitHub") |
|
|
gh_token = gr.Textbox(label="GitHub Personal Access Token", type="password") |
|
|
gh_repo = gr.Textbox(label="GitHub Repo", placeholder="e.g., username/my-rtdetrv2-repo") |
|
|
upload_btn = gr.Button("Upload Model", variant="primary") |
|
|
with gr.Row(): |
|
|
hf_status = gr.Textbox(label="Hugging Face Status", interactive=False) |
|
|
gh_status = gr.Textbox(label="GitHub Status", interactive=False) |
|
|
|
|
|
|
|
|
load_btn.click( |
|
|
fn=load_datasets_handler, |
|
|
inputs=[rf_api_key, rf_url_file], |
|
|
outputs=[dataset_status, dataset_info_state, class_df] |
|
|
) |
|
|
update_counts_btn.click( |
|
|
fn=update_class_counts_handler, |
|
|
inputs=[class_df, dataset_info_state], |
|
|
outputs=[class_count_summary_df] |
|
|
) |
|
|
finalize_btn.click( |
|
|
fn=finalize_handler, |
|
|
inputs=[dataset_info_state, class_df], |
|
|
outputs=[finalize_status, final_dataset_path_state] |
|
|
) |
|
|
train_btn.click( |
|
|
fn=training_handler, |
|
|
inputs=[final_dataset_path_state, model_file_dd, run_name_tb, epochs_sl, batch_sl, imgsz_num, lr_num, opt_dd], |
|
|
outputs=[train_status, loss_plot, map_plot, final_model_file] |
|
|
) |
|
|
upload_btn.click( |
|
|
fn=upload_handler, |
|
|
inputs=[final_model_file, hf_token, hf_repo, gh_token, gh_repo], |
|
|
outputs=[hf_status, gh_status] |
|
|
) |
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
os.environ.setdefault("YOLO_CONFIG_DIR", "/tmp/Ultralytics") |
|
|
app.launch(debug=True) |
|
|
|