import base64
import glob
import json
import logging
import os
import random
import re
import shutil
import stat
import subprocess
import sys
import time
from queue import Queue
from threading import Thread
from urllib.parse import urlparse

import gradio as gr
import pandas as pd
import requests
import yaml
from huggingface_hub import HfApi, HfFolder
from PIL import Image
from roboflow import Roboflow

import matplotlib
matplotlib.use("Agg")  # headless backend; must be selected before pyplot is imported
import matplotlib.pyplot as plt

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

RTDETRV2_REPO_URL = "https://github.com/supervisely-ecosystem/RT-DETRv2"
DEFAULT_REPO_DIR = os.path.join("third_party", "rtdetrv2")

RTDETRV2_MODELS = [
    "rtdetrv2-l-640",
    "rtdetrv2-x-640"
]
DEFAULT_MODEL = RTDETRV2_MODELS[0]


def handle_remove_readonly(func, path, exc_info):
    """shutil.rmtree error handler: clear the read-only bit and retry the failed call."""
    try:
        os.chmod(path, stat.S_IWRITE)
    except Exception:
        pass
    func(path)


_ROBO_URL_RX = re.compile(
    r"""
    ^(?:
        (?:https?://)?(?:universe|app|www)?\.?roboflow\.com/
        (?P<ws>[A-Za-z0-9\-_]+)/
        (?P<proj>[A-Za-z0-9\-_]+)/?
        (?:
            (?:dataset/[^/]+/)?
            (?:v?(?P<ver>\d+))?
        )?
        |
        (?P<ws2>[A-Za-z0-9\-_]+)/(?P<proj2>[A-Za-z0-9\-_]+)(?:/(?:v)?(?P<ver2>\d+))?
    )$
    """,
    re.VERBOSE | re.IGNORECASE
)


def parse_roboflow_url(s: str):
    s = s.strip()
    m = _ROBO_URL_RX.match(s)
    if m:
        ws = m.group('ws') or m.group('ws2')
        proj = m.group('proj') or m.group('proj2')
        ver = m.group('ver') or m.group('ver2')
        return ws, proj, (int(ver) if ver else None)

    parsed = urlparse(s)
    parts = [p for p in parsed.path.strip('/').split('/') if p]
    if len(parts) >= 2:
        version = None
        if len(parts) >= 3:
            vpart = parts[2]
            if vpart.lower().startswith('v') and vpart[1:].isdigit():
                version = int(vpart[1:])
            elif vpart.isdigit():
                version = int(vpart)
        return parts[0], parts[1], version

    if '/' in s and 'roboflow' not in s:
        p = s.split('/')
        if len(p) >= 2:
            version = None
            if len(p) >= 3:
                v = p[2]
                if v.lower().startswith('v') and v[1:].isdigit():
                    version = int(v[1:])
                elif v.isdigit():
                    version = int(v)
            return p[0], p[1], version

    return None, None, None
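# A minimal sketch of the inputs the parser above is meant to accept
# (the workspace/project names here are hypothetical):
#   parse_roboflow_url("https://universe.roboflow.com/acme/widgets/3") -> ("acme", "widgets", 3)
#   parse_roboflow_url("acme/widgets/v2")                              -> ("acme", "widgets", 2)
#   parse_roboflow_url("acme/widgets")                                 -> ("acme", "widgets", None)  # version resolved later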


def get_latest_version(api_key, workspace, project):
    try:
        rf = Roboflow(api_key=api_key)
        proj = rf.workspace(workspace).project(project)
        versions = sorted([int(v.version) for v in proj.versions()], reverse=True)
        return versions[0] if versions else None
    except Exception as e:
        logging.error(f"Could not get latest version for {workspace}/{project}: {e}")
        return None


def _extract_class_names(data_yaml):
    names = data_yaml.get('names', None)
    if isinstance(names, dict):
        def _k(x):
            try:
                return int(x)
            except Exception:
                return str(x)
        ordered = sorted(names.keys(), key=_k)
        names_list = [names[k] for k in ordered]
    elif isinstance(names, list):
        names_list = names
    else:
        nc = data_yaml.get('nc', 0)
        try:
            nc = int(nc)
        except Exception:
            nc = 0
        names_list = [f"class_{i}" for i in range(nc)]
    return [str(x) for x in names_list]


def download_dataset(api_key, workspace, project, version):
    try:
        rf = Roboflow(api_key=api_key)
        proj = rf.workspace(workspace).project(project)
        ver = proj.version(int(version))
        dataset = ver.download("yolov8")

        data_yaml_path = os.path.join(dataset.location, 'data.yaml')
        with open(data_yaml_path, 'r') as f:
            data_yaml = yaml.safe_load(f)

        class_names = _extract_class_names(data_yaml)
        try:
            nc = int(data_yaml.get('nc', len(class_names)))
        except Exception:
            nc = len(class_names)
        if len(class_names) != nc:
            logging.warning(f"[{project}-v{version}] names length ({len(class_names)}) != nc ({nc}); using normalized names.")

        splits = [s for s in ['train', 'valid', 'test']
                  if os.path.exists(os.path.join(dataset.location, s))]

        return dataset.location, class_names, splits, f"{project}-v{version}"
    except Exception as e:
        logging.error(f"Failed to download {workspace}/{project}/v{version}: {e}")
        return None, [], [], None


def label_path_for(img_path: str) -> str:
    split_dir = os.path.dirname(os.path.dirname(img_path))
    base = os.path.splitext(os.path.basename(img_path))[0] + '.txt'
    return os.path.join(split_dir, 'labels', base)


def gather_class_counts(dataset_info, class_mapping):
    """Count, per final class name, how many images contain at least one instance of it."""
    if not dataset_info:
        return {}
    final_names = set(v for v in class_mapping.values() if v is not None)
    counts = {name: 0 for name in final_names}
    for loc, names, splits, _ in dataset_info:
        id_to_name = {idx: class_mapping.get(n, None) for idx, n in enumerate(names)}
        for split in splits:
            labels_dir = os.path.join(loc, split, 'labels')
            if not os.path.exists(labels_dir):
                continue
            for label_file in os.listdir(labels_dir):
                if not label_file.endswith('.txt'):
                    continue
                found = set()
                with open(os.path.join(labels_dir, label_file), 'r') as f:
                    for line in f:
                        parts = line.strip().split()
                        if not parts:
                            continue
                        try:
                            cls_id = int(parts[0])
                            mapped = id_to_name.get(cls_id, None)
                            if mapped in final_names:
                                found.add(mapped)
                        except Exception:
                            continue
                for m in found:
                    counts[m] += 1
    return counts


def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=gr.Progress()):
    merged_dir = 'rolo_merged_dataset'
    if os.path.exists(merged_dir):
        shutil.rmtree(merged_dir, onerror=handle_remove_readonly)

    progress(0, desc="Creating directories...")
    for split in ['train', 'valid', 'test']:
        os.makedirs(os.path.join(merged_dir, split, 'images'), exist_ok=True)
        os.makedirs(os.path.join(merged_dir, split, 'labels'), exist_ok=True)

    active_classes = sorted(set([cls for cls, limit in class_limits.items() if limit > 0]))
    final_class_map = {name: i for i, name in enumerate(active_classes)}

    all_images = []
    for loc, _, splits, _ in dataset_info:
        for split in splits:
            img_dir = os.path.join(loc, split, 'images')
            if not os.path.exists(img_dir):
                continue
            for img_file in os.listdir(img_dir):
                if img_file.lower().endswith(('.jpg', '.jpeg', '.png')):
                    all_images.append((os.path.join(img_dir, img_file), split, loc))
    random.shuffle(all_images)

    progress(0.2, desc="Selecting images based on limits...")
    selected_images = []
    current_counts = {cls: 0 for cls in active_classes}
    loc_to_names = {info[0]: info[1] for info in dataset_info}

    for img_path, split, source_loc in progress.tqdm(all_images, desc="Analyzing images"):
        lbl_path = label_path_for(img_path)
        if not os.path.exists(lbl_path):
            continue

        source_names = loc_to_names.get(source_loc, [])
        image_classes = set()
        with open(lbl_path, 'r') as f:
            for line in f:
                parts = line.strip().split()
                if not parts:
                    continue
                try:
                    cls_id = int(parts[0])
                    orig = source_names[cls_id]
                    mapped = class_mapping.get(orig, orig)
                    if mapped in active_classes:
                        image_classes.add(mapped)
                except Exception:
                    continue

        if not image_classes:
            continue
        # Skip the image if any of its classes has already reached its limit.
        if any(current_counts[c] >= class_limits[c] for c in image_classes):
            continue

        selected_images.append((img_path, split))
        for c in image_classes:
            current_counts[c] += 1

    progress(0.6, desc=f"Copying {len(selected_images)} files...")
    for img_path, split in progress.tqdm(selected_images, desc="Finalizing files"):
        lbl_path = label_path_for(img_path)
        out_img = os.path.join(merged_dir, split, 'images', os.path.basename(img_path))
        out_lbl = os.path.join(merged_dir, split, 'labels', os.path.basename(lbl_path))
        shutil.copy(img_path, out_img)

        source_loc = None
        for info in dataset_info:
            if img_path.startswith(info[0]):
                source_loc = info[0]
                break
        source_names = loc_to_names.get(source_loc, [])

        # Rewrite label files so class ids follow the merged class ordering.
        with open(lbl_path, 'r') as f_in, open(out_lbl, 'w') as f_out:
            for line in f_in:
                parts = line.strip().split()
                if not parts:
                    continue
                try:
                    old_id = int(parts[0])
                    original_name = source_names[old_id]
                    mapped_name = class_mapping.get(original_name, original_name)
                    if mapped_name in final_class_map:
                        new_id = final_class_map[mapped_name]
                        f_out.write(f"{new_id} {' '.join(parts[1:])}\n")
                except Exception:
                    continue

    progress(0.95, desc="Creating data.yaml...")
    with open(os.path.join(merged_dir, 'data.yaml'), 'w') as f:
        yaml.dump({
            'path': os.path.abspath(merged_dir),
            'train': 'train/images',
            'val': 'valid/images',
            'test': 'test/images',
            'nc': len(active_classes),
            'names': active_classes
        }, f)

    return f"Dataset finalized with {len(selected_images)} images.", os.path.abspath(merged_dir)


def run_pip_install(args, desc="pip install"):
    logging.info(f"{desc}: {args}")
    cmd = [sys.executable, "-m", "pip", "install"] + args
    proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
    logging.info(proc.stdout)
    if proc.returncode != 0:
        raise RuntimeError(f"{desc} failed with code {proc.returncode}")


def ensure_repo(repo_dir: str, repo_url: str = RTDETRV2_REPO_URL):
    if os.path.isdir(repo_dir) and os.path.isdir(os.path.join(repo_dir, ".git")):
        return
    # Guard against a bare directory name, for which os.path.dirname() returns ''.
    os.makedirs(os.path.dirname(repo_dir) or ".", exist_ok=True)
    logging.info(f"Cloning RT-DETRv2 repo into {repo_dir} ...")
    subprocess.run(["git", "clone", "--depth", "1", repo_url, repo_dir], check=True)


def ensure_python_deps(repo_dir: str):
    """
    Auto-install dependencies (idempotent).
    - Tries to install pinned basics that are often needed.
    - If the repo has requirements*.txt files, installs them.
    - Creates a .deps_installed marker to skip on the next run.
    """
    marker = os.path.join(repo_dir, ".deps_installed")
    if os.path.exists(marker):
        logging.info("Dependencies already installed; skipping.")
        return

    basics = [
        "numpy<2",
        "pillow",
        "tqdm",
        "pyyaml",
        "matplotlib",
        "pandas",
        "scipy",
        "opencv-python-headless",
        "packaging",
        "requests",
        "pycocotools-windows; platform_system=='Windows'",
        "pycocotools; platform_system!='Windows'",
    ]
    try:
        run_pip_install(basics, desc="Installing common basics")
    except Exception as e:
        logging.warning(f"Basic installs had issues: {e}")

    req_files = []
    for name in ["requirements.txt", "requirements-dev.txt", "requirements.in"]:
        p = os.path.join(repo_dir, name)
        if os.path.isfile(p):
            req_files.append(p)

    for rf in req_files:
        try:
            run_pip_install(["-r", rf], desc=f"Installing repo requirements from {rf}")
        except Exception as e:
            logging.warning(f"Installing {rf} failed: {e}")

    # Install PyTorch (CPU wheels) only if it is not already importable.
    try:
        import torch  # noqa: F401
    except Exception:
        try:
            run_pip_install(
                ["torch", "torchvision", "torchaudio", "--index-url", "https://download.pytorch.org/whl/cpu"],
                desc="Installing PyTorch (CPU)"
            )
        except Exception as e:
            logging.warning(f"PyTorch installation failed/skipped: {e}")

    with open(marker, "w") as f:
        f.write("ok\n")


def make_train_command(template: str, data_yaml: str, epochs: int, batch: int, imgsz: int,
                       lr: float, optimizer: str, run_name: str, output_dir: str) -> str:
    return template.format(
        data_yaml=data_yaml,
        epochs=int(epochs),
        batch=int(batch),
        imgsz=int(imgsz),
        lr=float(lr),
        optimizer=str(optimizer),
        run_name=str(run_name),
        output_dir=output_dir
    )
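# A rough illustration of how a template is expanded (the template here is a
# shortened, hypothetical one; keyword arguments that the template does not
# reference are simply ignored by str.format):
#   make_train_command("python tools/train.py --data {data_yaml} --epochs {epochs}",
#                      "data.yaml", 100, 16, 640, 0.001, "AdamW", "run1", "runs/train/run1")
#   -> "python tools/train.py --data data.yaml --epochs 100"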


_METRIC_PATTERNS = [
    (re.compile(r"mAP@0\.5[:/]?0\.95[^0-9]*([0-9]*\.?[0-9]+)"), "mAP50_95"),
    (re.compile(r"mAP50[^0-9]*([0-9]*\.?[0-9]+)"), "mAP50"),
    (re.compile(r"\bval[_/ ]?loss[^0-9\-]*([0-9]*\.?[0-9]+)"), "val_loss"),
    (re.compile(r"\btrain[_/ ]?loss[^0-9\-]*([0-9]*\.?[0-9]+)"), "train_loss"),
    (re.compile(r"\bepoch[^0-9]*([0-9]+)"), "epoch"),
]


def parse_metrics_from_line(line: str):
    result = {}
    for pat, key in _METRIC_PATTERNS:
        m = pat.search(line)
        if m:
            val = m.group(1)
            try:
                result[key] = int(val) if key == "epoch" else float(val)
            except Exception:
                pass
    return result
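# A hypothetical log line, just to show what the patterns above would scrape:
#   parse_metrics_from_line("epoch 3: train_loss 0.45 mAP50 0.61")
#   -> {'epoch': 3, 'train_loss': 0.45, 'mAP50': 0.61}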


def guess_final_weights(output_dir: str):
    patterns = [
        os.path.join(output_dir, "**", "best.*"),
        os.path.join(output_dir, "**", "best_model.*"),
        os.path.join(output_dir, "**", "checkpoint_best.*"),
    ]
    for p in patterns:
        hits = glob.glob(p, recursive=True)
        if hits:
            return hits[0]
    return None


def load_datasets_handler(api_key, url_file, progress=gr.Progress()):
    api_key = api_key or os.getenv("ROBOFLOW_API_KEY", "")
    if not api_key:
        raise gr.Error("Roboflow API Key is required (or set ROBOFLOW_API_KEY).")
    if not url_file:
        raise gr.Error("Please upload a .txt file with Roboflow URLs or lines like 'workspace/project[/vN]'.")

    # gr.File may hand over a path string or a tempfile-like object, depending on the Gradio version.
    url_path = url_file if isinstance(url_file, str) else url_file.name
    with open(url_path, 'r', encoding='utf-8', errors='ignore') as f:
        urls = [line.strip() for line in f if line.strip()]

    dataset_info = []
    failures = []
    for i, raw in enumerate(urls):
        progress((i + 1) / max(1, len(urls)), desc=f"Parsing {i+1}/{len(urls)}")
        ws, proj, ver = parse_roboflow_url(raw)
        if not (ws and proj):
            failures.append((raw, "ParseError: could not resolve workspace/project"))
            continue
        if ver is None:
            ver = get_latest_version(api_key, ws, proj)
            if ver is None:
                failures.append((raw, f"Could not resolve latest version for {ws}/{proj}"))
                continue

        loc, names, splits, name_str = download_dataset(api_key, ws, proj, int(ver))
        if loc:
            dataset_info.append((loc, names, splits, name_str))
        else:
            failures.append((raw, f"DownloadError: {ws}/{proj}/v{ver}"))

    if not dataset_info:
        msg = "No datasets were loaded successfully.\n" + "\n".join([f"- {u}: {why}" for u, why in failures[:10]])
        raise gr.Error(msg)

    all_names = sorted({str(n) for _, names, _, _ in dataset_info for n in names})
    class_map = {name: name for name in all_names}
    initial_counts = gather_class_counts(dataset_info, class_map)
    df_data = [[name, name, initial_counts.get(name, 0), False] for name in all_names]
    status_text = "Datasets loaded successfully."
    if failures:
        status_text += f" ({len(dataset_info)} OK, {len(failures)} failed; see console logs)."

    return status_text, dataset_info, gr.update(
        value=pd.DataFrame(df_data, columns=["Original Name", "Rename To", "Max Images", "Remove"])
    )


def update_class_counts_handler(class_df, dataset_info):
    if class_df is None or not dataset_info:
        return None
    class_df = pd.DataFrame(class_df)
    mapping = {}
    for _, row in class_df.iterrows():
        orig = row["Original Name"]
        if bool(row["Remove"]):
            mapping[orig] = None
        else:
            mapping[orig] = row["Rename To"]

    final_names = sorted(set(v for v in mapping.values() if v))
    counts = {k: 0 for k in final_names}
    for loc, names, splits, _ in dataset_info:
        id_to_final = {idx: mapping.get(n, None) for idx, n in enumerate(names)}
        for split in splits:
            labels_dir = os.path.join(loc, split, 'labels')
            if not os.path.exists(labels_dir):
                continue
            for label_file in os.listdir(labels_dir):
                if not label_file.endswith('.txt'):
                    continue
                found = set()
                with open(os.path.join(labels_dir, label_file), 'r') as f:
                    for line in f:
                        parts = line.strip().split()
                        if not parts:
                            continue
                        try:
                            cls_id = int(parts[0])
                            mapped = id_to_final.get(cls_id, None)
                            if mapped:
                                found.add(mapped)
                        except Exception:
                            continue
                for m in found:
                    counts[m] += 1

    summary_df = pd.DataFrame(list(counts.items()), columns=["Final Class Name", "Est. Total Images"])
    return summary_df


def finalize_handler(dataset_info, class_df, progress=gr.Progress()):
    if not dataset_info:
        raise gr.Error("Load datasets first in Tab 1.")
    if class_df is None:
        raise gr.Error("Class data is missing.")

    class_df = pd.DataFrame(class_df)
    class_mapping = {}
    class_limits = {}
    for _, row in class_df.iterrows():
        orig = row["Original Name"]
        if bool(row["Remove"]):
            continue
        final_name = row["Rename To"]
        class_mapping[orig] = final_name
        class_limits[final_name] = class_limits.get(final_name, 0) + int(row["Max Images"])

    status, path = finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress)
    return status, path


def training_handler_rtdetrv2(dataset_path, repo_dir, model_choice, run_name, epochs, batch, imgsz, lr, opt,
                              cmd_template, progress=gr.Progress()):
    if not dataset_path:
        raise gr.Error("Finalize a dataset in Tab 2 before training.")

    try:
        ensure_repo(repo_dir)
        ensure_python_deps(repo_dir)
    except subprocess.CalledProcessError as e:
        raise gr.Error(f"Failed to clone repo: {e}")
    except Exception as e:
        raise gr.Error(f"Dependency setup failed: {e}")

    output_dir = os.path.join("runs", "train", str(run_name))
    os.makedirs(output_dir, exist_ok=True)

    data_yaml = os.path.join(dataset_path, "data.yaml")
    if not os.path.isfile(data_yaml):
        raise gr.Error(f"'data.yaml' was not found in: {dataset_path}")

    cmd = make_train_command(
        template=cmd_template,
        data_yaml=data_yaml,
        epochs=int(epochs),
        batch=int(batch),
        imgsz=int(imgsz),
        lr=float(lr),
        optimizer=str(opt),
        run_name=str(run_name),
        output_dir=output_dir
    )

    logging.info(f"Running training command in {repo_dir}: {cmd}")
    proc = subprocess.Popen(
        cmd, cwd=repo_dir, shell=True,
        stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
        bufsize=1, universal_newlines=True, env={**os.environ}
    )

    history = {k: [] for k in ['epoch', 'train_loss', 'val_loss', 'mAP50', 'mAP50_95']}
    for line in iter(proc.stdout.readline, ''):
        line = line.rstrip()
        progress(0.0, desc=line[-120:])
        metrics = parse_metrics_from_line(line)
        if metrics:
            for k, v in metrics.items():
                history[k].append(v)

        # The metric series are scraped independently and can have different lengths,
        # so plot each one against its own step index rather than the epoch list.
        fig_loss = plt.figure()
        ax_loss = fig_loss.add_subplot(111)
        ax_loss.plot(range(len(history['train_loss'])), history['train_loss'], "o-", label='Train Loss')
        ax_loss.plot(range(len(history['val_loss'])), history['val_loss'], "o-", label='Val Loss')
        ax_loss.legend(); ax_loss.set_title("Loss")

        fig_map = plt.figure()
        ax_map = fig_map.add_subplot(111)
        ax_map.plot(range(len(history['mAP50'])), history['mAP50'], "o-", label='mAP@0.5')
        ax_map.plot(range(len(history['mAP50_95'])), history['mAP50_95'], "o-", label='mAP@0.5:0.95')
        ax_map.legend(); ax_map.set_title("mAP")

        yield line[-200:], fig_loss, fig_map, None
        # Close the figures once Gradio has consumed them so matplotlib does not
        # accumulate open figures over a long training run.
        plt.close(fig_loss)
        plt.close(fig_map)

    proc.stdout.close()
    ret = proc.wait()
    if ret != 0:
        raise gr.Error(f"Training process exited with code {ret}. Check console/logs for details.")

    final_ckpt = guess_final_weights(output_dir)
    if final_ckpt and os.path.isfile(final_ckpt):
        yield "Training complete!", None, None, gr.update(value=final_ckpt, visible=True)
    else:
        yield ("Training finished. Could not auto-detect a 'best' checkpoint; "
               "please check the output directory."), None, None, gr.update(visible=False)


def upload_handler(model_file, hf_token, hf_repo, gh_token, gh_repo, progress=gr.Progress()):
    if not model_file:
        raise gr.Error("No trained model file available to upload. Train a model first.")
    # gr.File may deliver a path string or a tempfile-like object, depending on the Gradio version.
    model_path = model_file if isinstance(model_file, str) else model_file.name

    hf_status = "Skipped Hugging Face (credentials not provided)."
    if hf_token and hf_repo:
        progress(0, desc="Uploading to Hugging Face...")
        try:
            api = HfApi()
            HfFolder.save_token(hf_token)
            repo_url = api.create_repo(repo_id=hf_repo, exist_ok=True, token=hf_token)
            api.upload_file(
                path_or_fileobj=model_path,
                path_in_repo=os.path.basename(model_path),
                repo_id=hf_repo,
                token=hf_token
            )
            hf_status = f"Success! Model at: {repo_url}"
        except Exception as e:
            hf_status = f"Hugging Face Error: {e}"

    gh_status = "Skipped GitHub (credentials not provided)."
    if gh_token and gh_repo:
        progress(0.5, desc="Uploading to GitHub...")
        try:
            if '/' not in gh_repo:
                raise ValueError("GitHub repo must be in the form 'username/repo'.")
            username, repo_name = gh_repo.split('/', 1)
            api_url = f"https://api.github.com/repos/{username}/{repo_name}/contents/{os.path.basename(model_path)}"
            headers = {"Authorization": f"token {gh_token}"}

            with open(model_path, "rb") as f:
                content = base64.b64encode(f.read()).decode()

            # Fetch the existing file's SHA (if any) so the PUT updates instead of failing.
            get_resp = requests.get(api_url, headers=headers, timeout=30)
            sha = get_resp.json().get('sha') if get_resp.ok else None

            data = {"message": "Upload trained model from Rolo app", "content": content}
            if sha:
                data["sha"] = sha

            put_resp = requests.put(api_url, headers=headers, json=data, timeout=60)
            if put_resp.ok:
                gh_status = f"Success! Model at: {put_resp.json()['content']['html_url']}"
            else:
                msg = put_resp.json().get('message', 'Unknown')
                gh_status = f"GitHub Error: {msg}"
        except Exception as e:
            gh_status = f"GitHub Error: {e}"

    progress(1)
    return hf_status, gh_status


with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
    gr.Markdown("# Rolo: RT-DETRv2 Training Dashboard (Auto-setup for Hugging Face)")

    dataset_info_state = gr.State([])
    final_dataset_path_state = gr.State(None)

    with gr.Tabs():
        with gr.TabItem("1. Prepare Datasets"):
            gr.Markdown("Upload a `.txt` with Roboflow URLs or `workspace/project[/vN]` lines.")
            with gr.Row():
                rf_api_key = gr.Textbox(label="Roboflow API Key (or set ROBOFLOW_API_KEY)", type="password", scale=2)
                rf_url_file = gr.File(label="Upload Roboflow URLs (.txt)", file_types=[".txt"], scale=1)
            load_btn = gr.Button("Load Datasets", variant="primary")
            dataset_status = gr.Textbox(label="Status", interactive=False)

        with gr.TabItem("2. Manage & Merge"):
            gr.Markdown("Rename classes, set image limits, or remove them. Preview, then finalize.")
            with gr.Row():
                class_df = gr.DataFrame(
                    headers=["Original Name", "Rename To", "Max Images", "Remove"],
                    datatype=["str", "str", "number", "bool"],
                    label="Class Configuration", interactive=True, scale=3
                )
                with gr.Column(scale=1):
                    class_count_summary_df = gr.DataFrame(
                        label="Merged Class Counts Preview",
                        headers=["Final Class Name", "Est. Total Images"],
                        interactive=False
                    )
                    update_counts_btn = gr.Button("Update Counts")
            finalize_btn = gr.Button("Finalize Merged Dataset", variant="primary")
            finalize_status = gr.Textbox(label="Status", interactive=False)

        with gr.TabItem("3. Configure & Train"):
            gr.Markdown("Set hyperparameters and the training command template.")
            with gr.Row():
                with gr.Column(scale=1):
                    model_choice_dd = gr.Dropdown(
                        label="Model Choice (label only; use your config in the template)",
                        choices=RTDETRV2_MODELS, value=DEFAULT_MODEL
                    )
                    run_name_tb = gr.Textbox(label="Run Name", value="rtdetrv2_run_1")
                    epochs_sl = gr.Slider(1, 500, 100, step=1, label="Epochs")
                    batch_sl = gr.Slider(1, 64, 16, step=1, label="Batch Size")
                    imgsz_num = gr.Number(label="Image Size", value=640)
                    lr_num = gr.Number(label="Learning Rate", value=0.001)
                    opt_dd = gr.Dropdown(["Adam", "AdamW", "SGD"], value="AdamW", label="Optimizer")
                    repo_dir_tb = gr.Textbox(label="RT-DETRv2 repo directory", value=DEFAULT_REPO_DIR)
                    cmd_template_tb = gr.Textbox(
                        label="Train command template",
                        value=(
                            "python tools/train.py "
                            "--data {data_yaml} "
                            "--epochs {epochs} "
                            "--batch {batch} "
                            "--imgsz {imgsz} "
                            "--lr {lr} "
                            "--optimizer {optimizer} "
                            "--output {output_dir}"
                        ),
                        lines=4
                    )
                    train_btn = gr.Button("Start Training", variant="primary")
                with gr.Column(scale=2):
                    train_status = gr.Textbox(label="Live Status / Logs", interactive=False)
                    loss_plot = gr.Plot(label="Loss Curves")
                    map_plot = gr.Plot(label="mAP Curves")
                    final_model_file = gr.File(label="Download Trained Model (best.*)", interactive=False, visible=False)

        with gr.TabItem("4. Upload Model"):
            gr.Markdown("Upload your best checkpoint to Hugging Face or GitHub.")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Hugging Face")
                    hf_token = gr.Textbox(label="Hugging Face API Token", type="password")
                    hf_repo = gr.Textbox(label="Hugging Face Repo ID", placeholder="e.g., username/my-rtdetrv2-model")
                with gr.Column():
                    gr.Markdown("#### GitHub")
                    gh_token = gr.Textbox(label="GitHub Personal Access Token", type="password")
                    gh_repo = gr.Textbox(label="GitHub Repo", placeholder="e.g., username/my-rtdetrv2-repo")
            upload_btn = gr.Button("Upload Model", variant="primary")
            with gr.Row():
                hf_status = gr.Textbox(label="Hugging Face Status", interactive=False)
                gh_status = gr.Textbox(label="GitHub Status", interactive=False)

    load_btn.click(
        fn=load_datasets_handler,
        inputs=[rf_api_key, rf_url_file],
        outputs=[dataset_status, dataset_info_state, class_df]
    )
    update_counts_btn.click(
        fn=update_class_counts_handler,
        inputs=[class_df, dataset_info_state],
        outputs=[class_count_summary_df]
    )
    finalize_btn.click(
        fn=finalize_handler,
        inputs=[dataset_info_state, class_df],
        outputs=[finalize_status, final_dataset_path_state]
    )
    train_btn.click(
        fn=training_handler_rtdetrv2,
        inputs=[
            final_dataset_path_state,
            repo_dir_tb,
            model_choice_dd,
            run_name_tb,
            epochs_sl,
            batch_sl,
            imgsz_num,
            lr_num,
            opt_dd,
            cmd_template_tb
        ],
        outputs=[train_status, loss_plot, map_plot, final_model_file]
    )
    upload_btn.click(
        fn=upload_handler,
        inputs=[final_model_file, hf_token, hf_repo, gh_token, gh_repo],
        outputs=[hf_status, gh_status]
    )


if __name__ == "__main__":
    app.launch(debug=True)