import os
import shutil
import stat
import yaml
import gradio as gr
from roboflow import Roboflow
import re
from urllib.parse import urlparse
import random
import logging
import requests
import json
from PIL import Image
import pandas as pd
import matplotlib
matplotlib.use("Agg") # headless (HF Spaces)
import matplotlib.pyplot as plt
from threading import Thread
from queue import Queue
from huggingface_hub import HfApi, HfFolder
import base64
import subprocess
import sys
import time
import glob
# --- Logging ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# --- RT-DETRv2 backend defaults (Supervisely ecosystem) ---
RTDETRV2_REPO_URL = "https://github.com/supervisely-ecosystem/RT-DETRv2"
DEFAULT_REPO_DIR = os.path.join("third_party", "rtdetrv2")
RTDETRV2_MODELS = [
"rtdetrv2-l-640", # labels only; match your config via the command template
"rtdetrv2-x-640"
]
DEFAULT_MODEL = RTDETRV2_MODELS[0]
# --- Utilities ---
def handle_remove_readonly(func, path, exc_info):
try:
os.chmod(path, stat.S_IWRITE)
except Exception:
pass
func(path)
_ROBO_URL_RX = re.compile(
r"""
^(?:
(?:https?://)?(?:universe|app|www)?\.?roboflow\.com/
(?P<ws>[A-Za-z0-9\-_]+)/
(?P<proj>[A-Za-z0-9\-_]+)/?
(?:
(?:dataset/[^/]+/)?
(?:v?(?P<ver>\d+))?
)?
|
(?P<ws2>[A-Za-z0-9\-_]+)/(?P<proj2>[A-Za-z0-9\-_]+)(?:/(?:v)?(?P<ver2>\d+))?
)$
""",
re.VERBOSE | re.IGNORECASE
)
def parse_roboflow_url(s: str):
s = s.strip()
m = _ROBO_URL_RX.match(s)
if m:
ws = m.group('ws') or m.group('ws2')
proj = m.group('proj') or m.group('proj2')
ver = m.group('ver') or m.group('ver2')
return ws, proj, (int(ver) if ver else None)
parsed = urlparse(s)
parts = [p for p in parsed.path.strip('/').split('/') if p]
if len(parts) >= 2:
version = None
if len(parts) >= 3:
vpart = parts[2]
if vpart.lower().startswith('v') and vpart[1:].isdigit():
version = int(vpart[1:])
elif vpart.isdigit():
version = int(vpart)
return parts[0], parts[1], version
if '/' in s and 'roboflow' not in s:
p = s.split('/')
if len(p) >= 2:
version = None
if len(p) >= 3:
v = p[2]
if v.lower().startswith('v') and v[1:].isdigit():
version = int(v[1:])
elif v.isdigit():
version = int(v)
return p[0], p[1], version
return None, None, None
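# Examples (illustrative; "acme"/"widgets" are hypothetical names) -- full URLs and
# shorthand both resolve to a (workspace, project, version) tuple:
#   parse_roboflow_url("https://universe.roboflow.com/acme/widgets/3") -> ("acme", "widgets", 3)
#   parse_roboflow_url("acme/widgets")                                 -> ("acme", "widgets", None)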
def get_latest_version(api_key, workspace, project):
try:
rf = Roboflow(api_key=api_key)
proj = rf.workspace(workspace).project(project)
versions = sorted([int(v.version) for v in proj.versions()], reverse=True)
return versions[0] if versions else None
except Exception as e:
logging.error(f"Could not get latest version for {workspace}/{project}: {e}")
return None
def _extract_class_names(data_yaml):
names = data_yaml.get('names', None)
if isinstance(names, dict):
def _k(x):
try: return int(x)
except Exception: return str(x)
ordered = sorted(names.keys(), key=_k)
names_list = [names[k] for k in ordered]
elif isinstance(names, list):
names_list = names
else:
nc = data_yaml.get('nc', 0)
try: nc = int(nc)
except Exception: nc = 0
names_list = [f"class_{i}" for i in range(nc)]
return [str(x) for x in names_list]
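# Examples (illustrative) -- the three YAML forms normalize to a plain list of strings:
#   _extract_class_names({"names": {0: "cat", 1: "dog"}}) -> ["cat", "dog"]
#   _extract_class_names({"names": ["cat", "dog"]})       -> ["cat", "dog"]
#   _extract_class_names({"nc": 2})                       -> ["class_0", "class_1"]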
def download_dataset(api_key, workspace, project, version):
try:
rf = Roboflow(api_key=api_key)
proj = rf.workspace(workspace).project(project)
ver = proj.version(int(version))
dataset = ver.download("yolov8")
data_yaml_path = os.path.join(dataset.location, 'data.yaml')
with open(data_yaml_path, 'r') as f:
data_yaml = yaml.safe_load(f)
class_names = _extract_class_names(data_yaml)
try:
nc = int(data_yaml.get('nc', len(class_names)))
except Exception:
nc = len(class_names)
if len(class_names) != nc:
logging.warning(f"[{project}-v{version}] names length ({len(class_names)}) != nc ({nc}); using normalized names.")
splits = [s for s in ['train', 'valid', 'test']
if os.path.exists(os.path.join(dataset.location, s))]
return dataset.location, class_names, splits, f"{project}-v{version}"
except Exception as e:
logging.error(f"Failed to download {workspace}/{project}/v{version}: {e}")
return None, [], [], None
def label_path_for(img_path: str) -> str:
split_dir = os.path.dirname(os.path.dirname(img_path)) # .../split
base = os.path.splitext(os.path.basename(img_path))[0] + '.txt'
return os.path.join(split_dir, 'labels', base)
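# Example (illustrative): YOLO-format exports keep labels next to images per split, so
#   label_path_for("<dataset>/train/images/img_001.jpg") -> "<dataset>/train/labels/img_001.txt"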
def gather_class_counts(dataset_info, class_mapping):
if not dataset_info:
return {}
final_names = set(v for v in class_mapping.values() if v is not None)
counts = {name: 0 for name in final_names}
for loc, names, splits, _ in dataset_info:
id_to_name = {idx: class_mapping.get(n, None) for idx, n in enumerate(names)}
for split in splits:
labels_dir = os.path.join(loc, split, 'labels')
if not os.path.exists(labels_dir):
continue
for label_file in os.listdir(labels_dir):
if not label_file.endswith('.txt'):
continue
found = set()
with open(os.path.join(labels_dir, label_file), 'r') as f:
for line in f:
parts = line.strip().split()
if not parts:
continue
try:
cls_id = int(parts[0])
mapped = id_to_name.get(cls_id, None)
if mapped in final_names:
found.add(mapped)
except Exception:
continue
for m in found:
counts[m] += 1
return counts
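# Example (illustrative): with class_mapping == {"Car": "vehicle", "Truck": "vehicle",
# "Sign": None}, a label file containing both a Car and a Truck box adds 1 to
# counts["vehicle"] (each image is counted at most once per final class), and Sign
# boxes are ignored because their mapping is None.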
def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=gr.Progress()):
merged_dir = 'rolo_merged_dataset'
if os.path.exists(merged_dir):
shutil.rmtree(merged_dir, onerror=handle_remove_readonly)
progress(0, desc="Creating directories...")
for split in ['train', 'valid', 'test']:
os.makedirs(os.path.join(merged_dir, split, 'images'), exist_ok=True)
os.makedirs(os.path.join(merged_dir, split, 'labels'), exist_ok=True)
active_classes = sorted(set([cls for cls, limit in class_limits.items() if limit > 0]))
final_class_map = {name: i for i, name in enumerate(active_classes)}
all_images = []
for loc, _, splits, _ in dataset_info:
for split in splits:
img_dir = os.path.join(loc, split, 'images')
if not os.path.exists(img_dir):
continue
for img_file in os.listdir(img_dir):
if img_file.lower().endswith(('.jpg', '.jpeg', '.png')):
all_images.append((os.path.join(img_dir, img_file), split, loc))
random.shuffle(all_images)
progress(0.2, desc="Selecting images based on limits...")
selected_images = []
current_counts = {cls: 0 for cls in active_classes}
loc_to_names = {info[0]: info[1] for info in dataset_info}
for img_path, split, source_loc in progress.tqdm(all_images, desc="Analyzing images"):
lbl_path = label_path_for(img_path)
if not os.path.exists(lbl_path):
continue
source_names = loc_to_names.get(source_loc, [])
image_classes = set()
with open(lbl_path, 'r') as f:
for line in f:
parts = line.strip().split()
if not parts:
continue
try:
cls_id = int(parts[0])
orig = source_names[cls_id]
mapped = class_mapping.get(orig, orig)
if mapped in active_classes:
image_classes.add(mapped)
except Exception:
continue
if not image_classes:
continue
if any(current_counts[c] >= class_limits[c] for c in image_classes):
continue
selected_images.append((img_path, split))
for c in image_classes:
current_counts[c] += 1
progress(0.6, desc=f"Copying {len(selected_images)} files...")
for img_path, split in progress.tqdm(selected_images, desc="Finalizing files"):
lbl_path = label_path_for(img_path)
out_img = os.path.join(merged_dir, split, 'images', os.path.basename(img_path))
out_lbl = os.path.join(merged_dir, split, 'labels', os.path.basename(lbl_path))
shutil.copy(img_path, out_img)
source_loc = None
for info in dataset_info:
if img_path.startswith(info[0]):
source_loc = info[0]
break
source_names = loc_to_names.get(source_loc, [])
with open(lbl_path, 'r') as f_in, open(out_lbl, 'w') as f_out:
for line in f_in:
parts = line.strip().split()
if not parts:
continue
try:
old_id = int(parts[0])
original_name = source_names[old_id]
mapped_name = class_mapping.get(original_name, original_name)
if mapped_name in final_class_map:
new_id = final_class_map[mapped_name]
f_out.write(f"{new_id} {' '.join(parts[1:])}\n")
except Exception:
continue
progress(0.95, desc="Creating data.yaml...")
with open(os.path.join(merged_dir, 'data.yaml'), 'w') as f:
yaml.dump({
'path': os.path.abspath(merged_dir),
'train': 'train/images',
'val': 'valid/images',
'test': 'test/images',
'nc': len(active_classes),
'names': active_classes
}, f)
return f"Dataset finalized with {len(selected_images)} images.", os.path.abspath(merged_dir)
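# Resulting layout (illustrative), consumed by the training command via data.yaml:
#   rolo_merged_dataset/
#     data.yaml                  # path, train, val, test, nc, names
#     train/images  train/labels
#     valid/images  valid/labels
#     test/images   test/labels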
# --- Repo + deps helpers (auto-install for HF Spaces) ---
def run_pip_install(args, desc="pip install"):
logging.info(f"{desc}: {args}")
cmd = [sys.executable, "-m", "pip", "install"] + args
proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
logging.info(proc.stdout)
if proc.returncode != 0:
raise RuntimeError(f"{desc} failed with code {proc.returncode}")
def ensure_repo(repo_dir: str, repo_url: str = RTDETRV2_REPO_URL):
if os.path.isdir(repo_dir) and os.path.isdir(os.path.join(repo_dir, ".git")):
return
    os.makedirs(os.path.dirname(repo_dir) or ".", exist_ok=True)  # dirname may be "" for a bare name
logging.info(f"Cloning RT-DETRv2 repo into {repo_dir} ...")
subprocess.run(["git", "clone", "--depth", "1", repo_url, repo_dir], check=True)
def ensure_python_deps(repo_dir: str):
"""
Auto-install dependencies (idempotent).
- Tries to install pinned basics that are often needed.
- If repo has requirements*.txt, install them.
- Creates a .deps_installed marker to skip on next run.
"""
marker = os.path.join(repo_dir, ".deps_installed")
if os.path.exists(marker):
logging.info("Dependencies already installed; skipping.")
return
# 1) Common essentials for vision training environments on HF Spaces
basics = [
"numpy<2", # safer with many libs
"pillow",
"tqdm",
"pyyaml",
"matplotlib",
"pandas",
"scipy",
"opencv-python-headless",
"packaging",
"requests",
"pycocotools-windows; platform_system=='Windows'",
"pycocotools; platform_system!='Windows'",
]
try:
run_pip_install(basics, desc="Installing common basics")
except Exception as e:
logging.warning(f"Basic installs had issues: {e}")
# 2) Repo requirements
req_files = []
for name in ["requirements.txt", "requirements-dev.txt", "requirements.in"]:
p = os.path.join(repo_dir, name)
if os.path.isfile(p):
req_files.append(p)
for rf in req_files:
try:
run_pip_install(["-r", rf], desc=f"Installing repo requirements from {rf}")
except Exception as e:
logging.warning(f"Installing {rf} failed: {e}")
# 3) Optional: torch if not present (CPU-only by default on Spaces)
try:
import torch # noqa: F401
except Exception:
# Try a CPU-friendly torch; change version/cuda wheels if needed
try:
run_pip_install(["torch", "torchvision", "torchaudio", "--index-url", "https://download.pytorch.org/whl/cpu"], desc="Installing PyTorch (CPU)")
except Exception as e:
logging.warning(f"PyTorch installation failed/skipped: {e}")
# Mark done
with open(marker, "w") as f:
f.write("ok\n")
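# ensure_python_deps is idempotent via the marker above; to force re-installation on
# the next run (illustrative), delete the marker file:
#   os.remove(os.path.join(repo_dir, ".deps_installed"))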
def make_train_command(template: str, data_yaml: str, epochs: int, batch: int, imgsz: int,
lr: float, optimizer: str, run_name: str, output_dir: str) -> str:
return template.format(
data_yaml=data_yaml,
epochs=int(epochs),
batch=int(batch),
imgsz=int(imgsz),
lr=float(lr),
optimizer=str(optimizer),
run_name=str(run_name),
output_dir=output_dir
)
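# Example (illustrative), using the default template from the UI; placeholders the
# template does not use (e.g. {run_name}) are simply ignored by str.format:
#   make_train_command(template, "rolo_merged_dataset/data.yaml", 100, 16, 640,
#                      0.001, "AdamW", "rtdetrv2_run_1", "runs/train/rtdetrv2_run_1")
#   -> "python tools/train.py --data rolo_merged_dataset/data.yaml --epochs 100 "
#      "--batch 16 --imgsz 640 --lr 0.001 --optimizer AdamW --output runs/train/rtdetrv2_run_1"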
_METRIC_PATTERNS = [
(re.compile(r"mAP@0\.5[:/]?0\.95[^0-9]*([0-9]*\.?[0-9]+)"), "mAP50_95"),
(re.compile(r"mAP50[^0-9]*([0-9]*\.?[0-9]+)"), "mAP50"),
(re.compile(r"\bval[_/ ]?loss[^0-9\-]*([0-9]*\.?[0-9]+)"), "val_loss"),
(re.compile(r"\btrain[_/ ]?loss[^0-9\-]*([0-9]*\.?[0-9]+)"), "train_loss"),
(re.compile(r"\bepoch[^0-9]*([0-9]+)"), "epoch"),
]
def parse_metrics_from_line(line: str):
result = {}
for pat, key in _METRIC_PATTERNS:
m = pat.search(line)
if m:
val = m.group(1)
try:
result[key] = int(val) if key == "epoch" else float(val)
except Exception:
pass
return result
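# Example (illustrative): a hypothetical log line such as
#   "epoch 12 train_loss 0.532 val_loss 0.611 mAP50 0.71 mAP@0.5:0.95 0.48"
# parses to {'epoch': 12, 'train_loss': 0.532, 'val_loss': 0.611, 'mAP50': 0.71, 'mAP50_95': 0.48}.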
def guess_final_weights(output_dir: str):
patterns = [
os.path.join(output_dir, "**", "best.*"),
os.path.join(output_dir, "**", "best_model.*"),
os.path.join(output_dir, "**", "checkpoint_best.*"),
]
for p in patterns:
hits = glob.glob(p, recursive=True)
if hits:
return hits[0]
return None
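# Example (illustrative): if training wrote runs/train/rtdetrv2_run_1/weights/best.pth,
#   guess_final_weights("runs/train/rtdetrv2_run_1") -> ".../weights/best.pth"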
# --- Gradio handlers ---
def load_datasets_handler(api_key, url_file, progress=gr.Progress()):
api_key = api_key or os.getenv("ROBOFLOW_API_KEY", "")
if not api_key:
raise gr.Error("Roboflow API Key is required (or set ROBOFLOW_API_KEY).")
if not url_file:
raise gr.Error("Please upload a .txt file with Roboflow URLs or lines like 'workspace/project[/vN]'.")
with open(url_file.name, 'r', encoding='utf-8', errors='ignore') as f:
urls = [line.strip() for line in f if line.strip()]
dataset_info = []
failures = []
for i, raw in enumerate(urls):
progress((i + 1) / max(1, len(urls)), desc=f"Parsing {i+1}/{len(urls)}")
ws, proj, ver = parse_roboflow_url(raw)
if not (ws and proj):
failures.append((raw, "ParseError: could not resolve workspace/project"))
continue
if ver is None:
ver = get_latest_version(api_key, ws, proj)
if ver is None:
failures.append((raw, f"Could not resolve latest version for {ws}/{proj}"))
continue
loc, names, splits, name_str = download_dataset(api_key, ws, proj, int(ver))
if loc:
dataset_info.append((loc, names, splits, name_str))
else:
failures.append((raw, f"DownloadError: {ws}/{proj}/v{ver}"))
if not dataset_info:
msg = "No datasets were loaded successfully.\n" + "\n".join([f"- {u}: {why}" for u, why in failures[:10]])
raise gr.Error(msg)
all_names = sorted({str(n) for _, names, _, _ in dataset_info for n in names})
class_map = {name: name for name in all_names}
initial_counts = gather_class_counts(dataset_info, class_map)
df_data = [[name, name, initial_counts.get(name, 0), False] for name in all_names]
status_text = "Datasets loaded successfully."
if failures:
status_text += f" ({len(dataset_info)} OK, {len(failures)} failed; see console logs)."
return status_text, dataset_info, gr.update(
value=pd.DataFrame(df_data, columns=["Original Name", "Rename To", "Max Images", "Remove"])
)
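# Example URL file accepted by load_datasets_handler (illustrative; "acme" is a
# hypothetical workspace) -- one dataset per line, either a full Roboflow URL or
# "workspace/project[/vN]"; lines without a version resolve to the latest version:
#   https://universe.roboflow.com/acme/widgets/3
#   acme/gadgets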
def update_class_counts_handler(class_df, dataset_info):
if class_df is None or not dataset_info:
return None
class_df = pd.DataFrame(class_df)
mapping = {}
for _, row in class_df.iterrows():
orig = row["Original Name"]
if bool(row["Remove"]):
mapping[orig] = None
else:
mapping[orig] = row["Rename To"]
final_names = sorted(set(v for v in mapping.values() if v))
counts = {k: 0 for k in final_names}
for loc, names, splits, _ in dataset_info:
id_to_final = {idx: mapping.get(n, None) for idx, n in enumerate(names)}
for split in splits:
labels_dir = os.path.join(loc, split, 'labels')
if not os.path.exists(labels_dir):
continue
for label_file in os.listdir(labels_dir):
if not label_file.endswith('.txt'):
continue
found = set()
with open(os.path.join(labels_dir, label_file), 'r') as f:
for line in f:
parts = line.strip().split()
if not parts:
continue
try:
cls_id = int(parts[0])
mapped = id_to_final.get(cls_id, None)
if mapped:
found.add(mapped)
except Exception:
continue
for m in found:
counts[m] += 1
summary_df = pd.DataFrame(list(counts.items()), columns=["Final Class Name", "Est. Total Images"])
return summary_df
def finalize_handler(dataset_info, class_df, progress=gr.Progress()):
if not dataset_info:
raise gr.Error("Load datasets first in Tab 1.")
if class_df is None:
raise gr.Error("Class data is missing.")
class_df = pd.DataFrame(class_df)
class_mapping = {}
class_limits = {}
for _, row in class_df.iterrows():
orig = row["Original Name"]
if bool(row["Remove"]):
continue
final_name = row["Rename To"]
class_mapping[orig] = final_name
class_limits[final_name] = class_limits.get(final_name, 0) + int(row["Max Images"])
status, path = finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress)
return status, path
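# Example (illustrative): two rows that rename "Car" and "Truck" to "vehicle" with
# Max Images 100 and 50 yield class_mapping == {"Car": "vehicle", "Truck": "vehicle"}
# and class_limits == {"vehicle": 150}; rows marked Remove are dropped entirely.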
def training_handler_rtdetrv2(dataset_path, repo_dir, model_choice, run_name, epochs, batch, imgsz, lr, opt,
cmd_template, progress=gr.Progress()):
if not dataset_path:
raise gr.Error("Finalize a dataset in Tab 2 before training.")
# Clone + deps (idempotent)
try:
ensure_repo(repo_dir)
ensure_python_deps(repo_dir)
except subprocess.CalledProcessError as e:
raise gr.Error(f"Failed to clone repo: {e}")
except Exception as e:
raise gr.Error(f"Dependency setup failed: {e}")
    # Output dir (absolute, because the training command runs with cwd=repo_dir
    # while guess_final_weights() is resolved relative to the app's working directory)
    output_dir = os.path.abspath(os.path.join("runs", "train", str(run_name)))
os.makedirs(output_dir, exist_ok=True)
data_yaml = os.path.join(dataset_path, "data.yaml")
if not os.path.isfile(data_yaml):
raise gr.Error(f"'data.yaml' was not found in: {dataset_path}")
# Build command from template
cmd = make_train_command(
template=cmd_template,
data_yaml=data_yaml,
epochs=int(epochs),
batch=int(batch),
imgsz=int(imgsz),
lr=float(lr),
optimizer=str(opt),
run_name=str(run_name),
output_dir=output_dir
)
logging.info(f"Running training command in {repo_dir}: {cmd}")
proc = subprocess.Popen(
cmd, cwd=repo_dir, shell=True,
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
bufsize=1, universal_newlines=True, env={**os.environ}
)
history = {k: [] for k in ['epoch', 'train_loss', 'val_loss', 'mAP50', 'mAP50_95']}
for line in iter(proc.stdout.readline, ''):
line = line.rstrip()
progress(0.0, desc=line[-120:])
metrics = parse_metrics_from_line(line)
if metrics:
for k, v in metrics.items():
history[k].append(v)
            # Plot loss/mAP. Metric lists can grow at different rates (not every log
            # line reports every metric), so truncate each series to a common length
            # before plotting.
            n_loss = min(len(history['epoch']), len(history['train_loss']), len(history['val_loss']))
            fig_loss = plt.figure()
            ax_loss = fig_loss.add_subplot(111)
            ax_loss.plot(history['epoch'][:n_loss], history['train_loss'][:n_loss], "o-", label='Train Loss')
            ax_loss.plot(history['epoch'][:n_loss], history['val_loss'][:n_loss], "o-", label='Val Loss')
            ax_loss.legend(); ax_loss.set_title("Loss")
            n_map = min(len(history['epoch']), len(history['mAP50']), len(history['mAP50_95']))
            fig_map = plt.figure()
            ax_map = fig_map.add_subplot(111)
            ax_map.plot(history['epoch'][:n_map], history['mAP50'][:n_map], "o-", label='mAP@0.5')
            ax_map.plot(history['epoch'][:n_map], history['mAP50_95'][:n_map], "o-", label='mAP@0.5:0.95')
            ax_map.legend(); ax_map.set_title("mAP")
            yield line[-200:], fig_loss, fig_map, None
            # Close figures after yielding so matplotlib does not accumulate open figures.
            plt.close(fig_loss); plt.close(fig_map)
proc.stdout.close()
ret = proc.wait()
if ret != 0:
raise gr.Error(f"Training process exited with code {ret}. Check console/logs for details.")
final_ckpt = guess_final_weights(output_dir)
if final_ckpt and os.path.isfile(final_ckpt):
yield "Training complete!", None, None, gr.File.update(value=final_ckpt, visible=True)
else:
yield ("Training finished. Could not auto-detect a 'best' checkpoint; "
"please check the output directory."), None, None, gr.update(visible=False)
def upload_handler(model_file, hf_token, hf_repo, gh_token, gh_repo, progress=gr.Progress()):
if not model_file:
raise gr.Error("No trained model file available to upload. Train a model first.")
hf_status = "Skipped Hugging Face (credentials not provided)."
if hf_token and hf_repo:
progress(0, desc="Uploading to Hugging Face...")
try:
api = HfApi()
HfFolder.save_token(hf_token)
repo_url = api.create_repo(repo_id=hf_repo, exist_ok=True, token=hf_token)
api.upload_file(
path_or_fileobj=model_file.name,
path_in_repo=os.path.basename(model_file.name),
repo_id=hf_repo,
token=hf_token
)
hf_status = f"Success! Model at: {repo_url}"
except Exception as e:
hf_status = f"Hugging Face Error: {e}"
gh_status = "Skipped GitHub (credentials not provided)."
if gh_token and gh_repo:
progress(0.5, desc="Uploading to GitHub...")
try:
if '/' not in gh_repo:
raise ValueError("GitHub repo must be in the form 'username/repo'.")
username, repo_name = gh_repo.split('/')
api_url = f"https://api.github.com/repos/{username}/{repo_name}/contents/{os.path.basename(model_file.name)}"
headers = {"Authorization": f"token {gh_token}"}
with open(model_file.name, "rb") as f:
content = base64.b64encode(f.read()).decode()
get_resp = requests.get(api_url, headers=headers, timeout=30)
sha = get_resp.json().get('sha') if get_resp.ok else None
data = {"message": "Upload trained model from Rolo app", "content": content}
if sha: data["sha"] = sha
put_resp = requests.put(api_url, headers=headers, json=data, timeout=60)
if put_resp.ok:
gh_status = f"Success! Model at: {put_resp.json()['content']['html_url']}"
else:
msg = put_resp.json().get('message', 'Unknown')
gh_status = f"GitHub Error: {msg}"
except Exception as e:
gh_status = f"GitHub Error: {e}"
progress(1)
return hf_status, gh_status
# --- Gradio UI ---
with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
gr.Markdown("# Rolo: RT-DETRv2 Training Dashboard (Auto-setup for Hugging Face)")
dataset_info_state = gr.State([])
final_dataset_path_state = gr.State(None)
with gr.Tabs():
with gr.TabItem("1. Prepare Datasets"):
gr.Markdown("Upload a `.txt` with Roboflow URLs or `workspace/project[/vN]` lines.")
with gr.Row():
rf_api_key = gr.Textbox(label="Roboflow API Key (or set ROBOFLOW_API_KEY)", type="password", scale=2)
rf_url_file = gr.File(label="Upload Roboflow URLs (.txt)", file_types=[".txt"], scale=1)
load_btn = gr.Button("Load Datasets", variant="primary")
dataset_status = gr.Textbox(label="Status", interactive=False)
with gr.TabItem("2. Manage & Merge"):
gr.Markdown("Rename classes, set image limits, or remove them. Preview, then finalize.")
with gr.Row():
class_df = gr.DataFrame(
headers=["Original Name", "Rename To", "Max Images", "Remove"],
datatype=["str", "str", "number", "bool"],
label="Class Configuration", interactive=True, scale=3
)
with gr.Column(scale=1):
class_count_summary_df = gr.DataFrame(
label="Merged Class Counts Preview",
headers=["Final Class Name", "Est. Total Images"],
interactive=False
)
update_counts_btn = gr.Button("Update Counts")
finalize_btn = gr.Button("Finalize Merged Dataset", variant="primary")
finalize_status = gr.Textbox(label="Status", interactive=False)
with gr.TabItem("3. Configure & Train"):
gr.Markdown("Set hyperparameters and the training command template.")
with gr.Row():
with gr.Column(scale=1):
model_choice_dd = gr.Dropdown(
label="Model Choice (label only; use your config in the template)",
choices=RTDETRV2_MODELS, value=DEFAULT_MODEL
)
run_name_tb = gr.Textbox(label="Run Name", value="rtdetrv2_run_1")
epochs_sl = gr.Slider(1, 500, 100, step=1, label="Epochs")
batch_sl = gr.Slider(1, 64, 16, step=1, label="Batch Size")
imgsz_num = gr.Number(label="Image Size", value=640)
lr_num = gr.Number(label="Learning Rate", value=0.001)
opt_dd = gr.Dropdown(["Adam", "AdamW", "SGD"], value="AdamW", label="Optimizer")
repo_dir_tb = gr.Textbox(label="RT-DETRv2 repo directory", value=DEFAULT_REPO_DIR)
cmd_template_tb = gr.Textbox(
label="Train command template",
value=(
"python tools/train.py "
"--data {data_yaml} "
"--epochs {epochs} "
"--batch {batch} "
"--imgsz {imgsz} "
"--lr {lr} "
"--optimizer {optimizer} "
"--output {output_dir}"
),
lines=4
)
train_btn = gr.Button("Start Training", variant="primary")
with gr.Column(scale=2):
train_status = gr.Textbox(label="Live Status / Logs", interactive=False)
loss_plot = gr.Plot(label="Loss Curves")
map_plot = gr.Plot(label="mAP Curves")
final_model_file = gr.File(label="Download Trained Model (best.*)", interactive=False, visible=False)
with gr.TabItem("4. Upload Model"):
gr.Markdown("Upload your best checkpoint to Hugging Face or GitHub.")
with gr.Row():
with gr.Column():
gr.Markdown("#### Hugging Face")
hf_token = gr.Textbox(label="Hugging Face API Token", type="password")
hf_repo = gr.Textbox(label="Hugging Face Repo ID", placeholder="e.g., username/my-rtdetrv2-model")
with gr.Column():
gr.Markdown("#### GitHub")
gh_token = gr.Textbox(label="GitHub Personal Access Token", type="password")
gh_repo = gr.Textbox(label="GitHub Repo", placeholder="e.g., username/my-rtdetrv2-repo")
upload_btn = gr.Button("Upload Model", variant="primary")
with gr.Row():
hf_status = gr.Textbox(label="Hugging Face Status", interactive=False)
gh_status = gr.Textbox(label="GitHub Status", interactive=False)
load_btn.click(
fn=load_datasets_handler,
inputs=[rf_api_key, rf_url_file],
outputs=[dataset_status, dataset_info_state, class_df]
)
update_counts_btn.click(
fn=update_class_counts_handler,
inputs=[class_df, dataset_info_state],
outputs=[class_count_summary_df]
)
finalize_btn.click(
fn=finalize_handler,
inputs=[dataset_info_state, class_df],
outputs=[finalize_status, final_dataset_path_state]
)
train_btn.click(
fn=training_handler_rtdetrv2,
inputs=[
final_dataset_path_state, # dataset_path
repo_dir_tb, # repo_dir (auto clone + pip install)
model_choice_dd, # model_choice (label only)
run_name_tb,
epochs_sl,
batch_sl,
imgsz_num,
lr_num,
opt_dd,
cmd_template_tb
],
outputs=[train_status, loss_plot, map_plot, final_model_file]
)
upload_btn.click(
fn=upload_handler,
inputs=[final_model_file, hf_token, hf_repo, gh_token, gh_repo],
outputs=[hf_status, gh_status]
)
if __name__ == "__main__":
# Hugging Face Spaces: set server name/port via env if needed.
# Example: app.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)), debug=True)
app.launch(debug=True)