Update app.py

app.py CHANGED
@@ -12,6 +12,8 @@ import requests
 import json
 from PIL import Image
 import pandas as pd
+import matplotlib
+matplotlib.use("Agg")  # headless (HF Spaces)
 import matplotlib.pyplot as plt
 from threading import Thread
 from queue import Queue
@@ -22,27 +24,21 @@ import sys
 import time
 import glob
 
-# ---
+# --- Logging ---
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
-#
+# --- RT-DETRv2 backend defaults (Supervisely ecosystem) ---
 RTDETRV2_REPO_URL = "https://github.com/supervisely-ecosystem/RT-DETRv2"
 DEFAULT_REPO_DIR = os.path.join("third_party", "rtdetrv2")
 
-# You can still offer "model size" choices to hint the user which config to use,
-# but the actual command is controlled by the template.
 RTDETRV2_MODELS = [
-    "rtdetrv2-l-640",  #
+    "rtdetrv2-l-640",  # labels only; match your config via the command template
     "rtdetrv2-x-640"
 ]
 DEFAULT_MODEL = RTDETRV2_MODELS[0]
 
-#
-# Utilities
-# ------------------------------
-
+# --- Utilities ---
 def handle_remove_readonly(func, path, exc_info):
-    """Error handler for shutil.rmtree."""
     try:
         os.chmod(path, stat.S_IWRITE)
     except Exception:
@@ -67,14 +63,6 @@ _ROBO_URL_RX = re.compile(
 )
 
 def parse_roboflow_url(s: str):
-    """
-    Support:
-      - https://universe.roboflow.com/<workspace>/<project>[/vN]
-      - https://app.roboflow.com/<workspace>/<project>[/vN]
-      - https://roboflow.com/<workspace>/<project>[/vN]
-      - raw: <workspace>/<project>[/vN]
-    Returns: (workspace, project, version_or_None)
-    """
     s = s.strip()
     m = _ROBO_URL_RX.match(s)
     if m:
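
The docstring removed above was the only written spec of the accepted inputs; for reference, the behavior it described looks like this (`acme`/`widgets` are made-up names):

```python
# Sketch per the removed docstring; not part of the commit.
ws, proj, ver = parse_roboflow_url("https://universe.roboflow.com/acme/widgets/v3")
# -> ("acme", "widgets", 3)   (int vs. "3" depends on _ROBO_URL_RX, which is outside this hunk)
ws, proj, ver = parse_roboflow_url("acme/widgets")
# -> ("acme", "widgets", None); the latest version is then looked up via get_latest_version()
```
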
@@ -110,7 +98,6 @@ def parse_roboflow_url(s: str):
     return None, None, None
 
 def get_latest_version(api_key, workspace, project):
-    """Gets the latest version number of a Roboflow project."""
     try:
         rf = Roboflow(api_key=api_key)
         proj = rf.workspace(workspace).project(project)
@@ -120,39 +107,24 @@ def get_latest_version(api_key, workspace, project):
         logging.error(f"Could not get latest version for {workspace}/{project}: {e}")
         return None
 
-# --- Normalize class names from data.yaml ---
 def _extract_class_names(data_yaml):
-    """
-    Return list[str] of class names in index order.
-    Supports:
-      - list
-      - dict with numeric keys {0:'cat',1:'dog'}
-      - fallback to ['class_0', ...]
-    """
     names = data_yaml.get('names', None)
-
     if isinstance(names, dict):
         def _k(x):
-            try:
-                return int(x)
-            except Exception:
-                return str(x)
+            try: return int(x)
+            except Exception: return str(x)
         ordered = sorted(names.keys(), key=_k)
         names_list = [names[k] for k in ordered]
     elif isinstance(names, list):
         names_list = names
     else:
         nc = data_yaml.get('nc', 0)
-        try:
-            nc = int(nc)
-        except Exception:
-            nc = 0
+        try: nc = int(nc)
+        except Exception: nc = 0
         names_list = [f"class_{i}" for i in range(nc)]
-
     return [str(x) for x in names_list]
 
 def download_dataset(api_key, workspace, project, version):
-    """Download Roboflow dataset in 'yolov8' layout (works fine for RT-DETR variants)."""
     try:
         rf = Roboflow(api_key=api_key)
         proj = rf.workspace(workspace).project(project)
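
The removed docstring enumerated the three `names` shapes; the code above still handles all three, e.g.:

```python
# Behavior implied by the code above (and the removed docstring):
_extract_class_names({'names': ['cat', 'dog']})        # -> ['cat', 'dog']
_extract_class_names({'names': {1: 'dog', 0: 'cat'}})  # -> ['cat', 'dog'] (numeric keys sorted via _k)
_extract_class_names({'nc': 2})                        # -> ['class_0', 'class_1'] (fallback)
```
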
@@ -180,25 +152,17 @@ def download_dataset(api_key, workspace, project, version):
     return None, [], [], None
 
 def label_path_for(img_path: str) -> str:
-    """Convert .../split/images/file.jpg -> .../split/labels/file.txt."""
     split_dir = os.path.dirname(os.path.dirname(img_path))  # .../split
     base = os.path.splitext(os.path.basename(img_path))[0] + '.txt'
     return os.path.join(split_dir, 'labels', base)
 
 def gather_class_counts(dataset_info, class_mapping):
-    """
-    Count per final class how many images contain that class at least once (counted once per image).
-    class_mapping: original_name -> final_name (or None if removed).
-    """
     if not dataset_info:
         return {}
-
     final_names = set(v for v in class_mapping.values() if v is not None)
     counts = {name: 0 for name in final_names}
-
     for loc, names, splits, _ in dataset_info:
         id_to_name = {idx: class_mapping.get(n, None) for idx, n in enumerate(names)}
-
         for split in splits:
             labels_dir = os.path.join(loc, split, 'labels')
             if not os.path.exists(labels_dir):
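
The deleted one-liner was the only statement of the images/labels path convention; concretely:

```python
# Path convention from the removed docstring (file names illustrative):
label_path_for("rolo_merged_dataset/train/images/img_001.jpg")
# -> "rolo_merged_dataset/train/labels/img_001.txt"
```
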
@@ -221,11 +185,9 @@ def gather_class_counts(dataset_info, class_mapping):
                 continue
             for m in found:
                 counts[m] += 1
-
     return counts
 
 def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=gr.Progress()):
-    """Merge datasets following mapping and per-class image limits."""
     merged_dir = 'rolo_merged_dataset'
     if os.path.exists(merged_dir):
         shutil.rmtree(merged_dir, onerror=handle_remove_readonly)
@@ -238,7 +200,6 @@ def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=
     active_classes = sorted(set([cls for cls, limit in class_limits.items() if limit > 0]))
     final_class_map = {name: i for i, name in enumerate(active_classes)}
 
-    # Collect candidates
     all_images = []
     for loc, _, splits, _ in dataset_info:
         for split in splits:
@@ -327,18 +288,81 @@ def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=
 
     return f"Dataset finalized with {len(selected_images)} images.", os.path.abspath(merged_dir)
 
-#
-
-
+# --- Repo + deps helpers (auto-install for HF Spaces) ---
+
+def run_pip_install(args, desc="pip install"):
+    logging.info(f"{desc}: {args}")
+    cmd = [sys.executable, "-m", "pip", "install"] + args
+    proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
+    logging.info(proc.stdout)
+    if proc.returncode != 0:
+        raise RuntimeError(f"{desc} failed with code {proc.returncode}")
 
 def ensure_repo(repo_dir: str, repo_url: str = RTDETRV2_REPO_URL):
-    """Clone the repo into repo_dir if not present."""
     if os.path.isdir(repo_dir) and os.path.isdir(os.path.join(repo_dir, ".git")):
         return
     os.makedirs(os.path.dirname(repo_dir), exist_ok=True)
     logging.info(f"Cloning RT-DETRv2 repo into {repo_dir} ...")
-
-
+    subprocess.run(["git", "clone", "--depth", "1", repo_url, repo_dir], check=True)
+
+def ensure_python_deps(repo_dir: str):
+    """
+    Auto-install dependencies (idempotent).
+    - Tries to install pinned basics that are often needed.
+    - If repo has requirements*.txt, install them.
+    - Creates a .deps_installed marker to skip on next run.
+    """
+    marker = os.path.join(repo_dir, ".deps_installed")
+    if os.path.exists(marker):
+        logging.info("Dependencies already installed; skipping.")
+        return
+
+    # 1) Common essentials for vision training environments on HF Spaces
+    basics = [
+        "numpy<2",  # safer with many libs
+        "pillow",
+        "tqdm",
+        "pyyaml",
+        "matplotlib",
+        "pandas",
+        "scipy",
+        "opencv-python-headless",
+        "packaging",
+        "requests",
+        "pycocotools-windows; platform_system=='Windows'",
+        "pycocotools; platform_system!='Windows'",
+    ]
+    try:
+        run_pip_install(basics, desc="Installing common basics")
+    except Exception as e:
+        logging.warning(f"Basic installs had issues: {e}")
+
+    # 2) Repo requirements
+    req_files = []
+    for name in ["requirements.txt", "requirements-dev.txt", "requirements.in"]:
+        p = os.path.join(repo_dir, name)
+        if os.path.isfile(p):
+            req_files.append(p)
+
+    for rf in req_files:
+        try:
+            run_pip_install(["-r", rf], desc=f"Installing repo requirements from {rf}")
+        except Exception as e:
+            logging.warning(f"Installing {rf} failed: {e}")
+
+    # 3) Optional: torch if not present (CPU-only by default on Spaces)
+    try:
+        import torch  # noqa: F401
+    except Exception:
+        # Try a CPU-friendly torch; change version/cuda wheels if needed
+        try:
+            run_pip_install(["torch", "torchvision", "torchaudio", "--index-url", "https://download.pytorch.org/whl/cpu"], desc="Installing PyTorch (CPU)")
+        except Exception as e:
+            logging.warning(f"PyTorch installation failed/skipped: {e}")
+
+    # Mark done
+    with open(marker, "w") as f:
+        f.write("ok\n")
 
 def make_train_command(template: str, data_yaml: str, epochs: int, batch: int, imgsz: int,
                        lr: float, optimizer: str, run_name: str, output_dir: str) -> str:
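
The two new helpers are designed to be called back to back before training, as the handler further down does; a minimal standalone sketch of the idempotent setup:

```python
# First call clones and installs; subsequent calls return early
# thanks to the .deps_installed marker file.
ensure_repo(DEFAULT_REPO_DIR)           # shallow git clone if the repo is missing
ensure_python_deps(DEFAULT_REPO_DIR)    # basics + requirements*.txt + CPU torch fallback
```
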
@@ -354,7 +378,6 @@ def make_train_command(template: str, data_yaml: str, epochs: int, batch: int, i
 )
 
 _METRIC_PATTERNS = [
-    # add more patterns if your repo prints differently
     (re.compile(r"mAP@0\.5[:/]?0\.95[^0-9]*([0-9]*\.?[0-9]+)"), "mAP50_95"),
     (re.compile(r"mAP50[^0-9]*([0-9]*\.?[0-9]+)"), "mAP50"),
     (re.compile(r"\bval[_/ ]?loss[^0-9\-]*([0-9]*\.?[0-9]+)"), "val_loss"),
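
The body of `make_train_command` falls outside these hunks; judging from the signature, the UI's "Train command template" is a string with named placeholders filled in per run. A hypothetical example (the real default lives in `cmd_template_tb`, not shown in this diff):

```python
# Assumed placeholder names mirroring make_train_command's parameters.
template = ("python tools/train.py --data {data_yaml} --epochs {epochs} "
            "--batch-size {batch} --img-size {imgsz} --lr {lr} "
            "--optimizer {optimizer} --name {run_name} --output-dir {output_dir}")
cmd = make_train_command(template=template, data_yaml="data.yaml", epochs=100, batch=16,
                         imgsz=640, lr=0.001, optimizer="AdamW",
                         run_name="rtdetrv2_run_1", output_dir="runs/train/rtdetrv2_run_1")
```
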
@@ -375,10 +398,6 @@ def parse_metrics_from_line(line):
     return result
 
 def guess_final_weights(output_dir: str):
-    """
-    Try to locate a 'best' checkpoint in output_dir.
-    Supports .pt/.pth/.pdparams etc. Return first match or None.
-    """
     patterns = [
         os.path.join(output_dir, "**", "best.*"),
         os.path.join(output_dir, "**", "best_model.*"),
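
A quick sanity check of the three patterns shown above against an invented log line (`_METRIC_PATTERNS` is truncated by the hunk, so real output may carry more keys, e.g. epoch and train_loss):

```python
line = "epoch 12: train_loss 1.234, val_loss 1.101, mAP50 0.55, mAP@0.5:0.95 0.41"
parse_metrics_from_line(line)
# -> roughly {'val_loss': 1.101, 'mAP50': 0.55, 'mAP50_95': 0.41}
```
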
@@ -390,12 +409,8 @@ def guess_final_weights(output_dir: str):
         return hits[0]
     return None
 
-#
-# Gradio UI Event Handlers
-# ------------------------------
-
+# --- Gradio handlers ---
 def load_datasets_handler(api_key, url_file, progress=gr.Progress()):
-    """Handles the 'Load Datasets' button click."""
     api_key = api_key or os.getenv("ROBOFLOW_API_KEY", "")
     if not api_key:
         raise gr.Error("Roboflow API Key is required (or set ROBOFLOW_API_KEY).")
@@ -407,14 +422,12 @@ def load_datasets_handler(api_key, url_file, progress=gr.Progress()):
 
     dataset_info = []
     failures = []
-
     for i, raw in enumerate(urls):
         progress((i + 1) / max(1, len(urls)), desc=f"Parsing {i+1}/{len(urls)}")
         ws, proj, ver = parse_roboflow_url(raw)
         if not (ws and proj):
             failures.append((raw, "ParseError: could not resolve workspace/project"))
             continue
-
         if ver is None:
             ver = get_latest_version(api_key, ws, proj)
             if ver is None:
@@ -431,26 +444,21 @@ def load_datasets_handler(api_key, url_file, progress=gr.Progress()):
         msg = "No datasets were loaded successfully.\n" + "\n".join([f"- {u}: {why}" for u, why in failures[:10]])
         raise gr.Error(msg)
 
-    # ensure names are strings before sorting
     all_names = sorted({str(n) for _, names, _, _ in dataset_info for n in names})
     class_map = {name: name for name in all_names}
-
     initial_counts = gather_class_counts(dataset_info, class_map)
     df_data = [[name, name, initial_counts.get(name, 0), False] for name in all_names]
     status_text = "Datasets loaded successfully."
     if failures:
         status_text += f" ({len(dataset_info)} OK, {len(failures)} failed; see console logs)."
 
-    # FIX: gr.update(...) (not gr.DataFrame.update)
     return status_text, dataset_info, gr.update(
         value=pd.DataFrame(df_data, columns=["Original Name", "Rename To", "Max Images", "Remove"])
     )
 
 def update_class_counts_handler(class_df, dataset_info):
-    """Live preview of merged class counts given the current mapping/removals."""
     if class_df is None or not dataset_info:
         return None
-
     class_df = pd.DataFrame(class_df)
     mapping = {}
     for _, row in class_df.iterrows():
@@ -462,10 +470,8 @@ def update_class_counts_handler(class_df, dataset_info):
 
     final_names = sorted(set(v for v in mapping.values() if v))
     counts = {k: 0 for k in final_names}
-
     for loc, names, splits, _ in dataset_info:
         id_to_final = {idx: mapping.get(n, None) for idx, n in enumerate(names)}
-
         for split in splits:
             labels_dir = os.path.join(loc, split, 'labels')
             if not os.path.exists(labels_dir):
@@ -493,7 +499,6 @@ def update_class_counts_handler(class_df, dataset_info):
     return summary_df
 
 def finalize_handler(dataset_info, class_df, progress=gr.Progress()):
-    """Create the merged dataset directory with relabeled .txts and data.yaml."""
     if not dataset_info:
         raise gr.Error("Load datasets first in Tab 1.")
     if class_df is None:
@@ -515,20 +520,19 @@ def finalize_handler(dataset_info, class_df, progress=gr.Progress()):
 
 def training_handler_rtdetrv2(dataset_path, repo_dir, model_choice, run_name, epochs, batch, imgsz, lr, opt,
                               cmd_template, progress=gr.Progress()):
-    """
-    Train using RT-DETRv2 repo via a configurable command template.
-    We stream logs, parse simple metrics when patterns match, and try to locate a best checkpoint on completion.
-    """
     if not dataset_path:
         raise gr.Error("Finalize a dataset in Tab 2 before training.")
 
-    #
+    # Clone + deps (idempotent)
     try:
         ensure_repo(repo_dir)
+        ensure_python_deps(repo_dir)
     except subprocess.CalledProcessError as e:
-        raise gr.Error(f"Failed to clone
+        raise gr.Error(f"Failed to clone repo: {e}")
+    except Exception as e:
+        raise gr.Error(f"Dependency setup failed: {e}")
 
-    #
+    # Output dir
     output_dir = os.path.join("runs", "train", str(run_name))
     os.makedirs(output_dir, exist_ok=True)
 
@@ -536,7 +540,7 @@ def training_handler_rtdetrv2(dataset_path, repo_dir, model_choice, run_name, ep
     if not os.path.isfile(data_yaml):
         raise gr.Error(f"'data.yaml' was not found in: {dataset_path}")
 
-    # Build
+    # Build command from template
     cmd = make_train_command(
         template=cmd_template,
         data_yaml=data_yaml,
@@ -549,54 +553,36 @@ def training_handler_rtdetrv2(dataset_path, repo_dir, model_choice, run_name, ep
         output_dir=output_dir
     )
 
-    # Launch training subprocess in repo_dir
     logging.info(f"Running training command in {repo_dir}: {cmd}")
     proc = subprocess.Popen(
-        cmd,
-        cwd=repo_dir,
-        shell=True,
-        stdout=subprocess.PIPE,
-        stderr=subprocess.STDOUT,
-        bufsize=1,
-        universal_newlines=True,
-        env={**os.environ}  # inherit env (CUDA, etc.)
+        cmd, cwd=repo_dir, shell=True,
+        stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
+        bufsize=1, universal_newlines=True, env={**os.environ}
     )
 
-    # Live metrics
     history = {k: [] for k in ['epoch', 'train_loss', 'val_loss', 'mAP50', 'mAP50_95']}
-    last_epoch = 0
-
-    # Stream logs and parse
     for line in iter(proc.stdout.readline, ''):
         line = line.rstrip()
-
-        if "epoch" in line.lower():
-            progress(0.0, desc=line[-120:])  # show last part of the line
-        else:
-            progress(0.0, desc=line[-120:])
-
+        progress(0.0, desc=line[-120:])
         metrics = parse_metrics_from_line(line)
         if metrics:
            for k, v in metrics.items():
                 history[k].append(v)
-
-        #
+
+        # plot loss
         fig_loss = plt.figure()
         ax_loss = fig_loss.add_subplot(111)
         ax_loss.plot(history['epoch'], history['train_loss'], "o-", label='Train Loss')
         ax_loss.plot(history['epoch'], history['val_loss'], "o-", label='Val Loss')
-        ax_loss.legend()
-        ax_loss.set_title("Loss")
+        ax_loss.legend(); ax_loss.set_title("Loss")
 
-        #
+        # plot mAP
        fig_map = plt.figure()
         ax_map = fig_map.add_subplot(111)
         ax_map.plot(history['epoch'], history['mAP50'], "o-", label='mAP@0.5')
         ax_map.plot(history['epoch'], history['mAP50_95'], "o-", label='mAP@0.5:0.95')
-        ax_map.legend()
-        ax_map.set_title("mAP")
+        ax_map.legend(); ax_map.set_title("mAP")
 
-        # Emit an update to the UI (status text is the last log line)
         yield line[-200:], fig_loss, fig_map, None
 
     proc.stdout.close()
@@ -604,16 +590,14 @@ def training_handler_rtdetrv2(dataset_path, repo_dir, model_choice, run_name, ep
     if ret != 0:
         raise gr.Error(f"Training process exited with code {ret}. Check console/logs for details.")
 
-    # Try to locate a best checkpoint
     final_ckpt = guess_final_weights(output_dir)
     if final_ckpt and os.path.isfile(final_ckpt):
         yield "Training complete!", None, None, gr.File.update(value=final_ckpt, visible=True)
     else:
-
-
+        yield ("Training finished. Could not auto-detect a 'best' checkpoint; "
+               "please check the output directory."), None, None, gr.update(visible=False)
 
 def upload_handler(model_file, hf_token, hf_repo, gh_token, gh_repo, progress=gr.Progress()):
-    """Handles model upload to Hugging Face and GitHub."""
     if not model_file:
         raise gr.Error("No trained model file available to upload. Train a model first.")
 
@@ -640,7 +624,6 @@ def upload_handler(model_file, hf_token, hf_repo, gh_token, gh_repo, progress=gr
     try:
         if '/' not in gh_repo:
             raise ValueError("GitHub repo must be in the form 'username/repo'.")
-
         username, repo_name = gh_repo.split('/')
         api_url = f"https://api.github.com/repos/{username}/{repo_name}/contents/{os.path.basename(model_file.name)}"
         headers = {"Authorization": f"token {gh_token}"}
@@ -652,11 +635,9 @@ def upload_handler(model_file, hf_token, hf_repo, gh_token, gh_repo, progress=gr
         sha = get_resp.json().get('sha') if get_resp.ok else None
 
         data = {"message": "Upload trained model from Rolo app", "content": content}
-        if sha:
-            data["sha"] = sha
+        if sha: data["sha"] = sha
 
         put_resp = requests.put(api_url, headers=headers, json=data, timeout=60)
-
         if put_resp.ok:
             gh_status = f"Success! Model at: {put_resp.json()['content']['html_url']}"
         else:
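
The `content` variable is prepared outside this hunk; the GitHub contents API expects base64-encoded file bytes, so it is presumably built along these lines:

```python
# Assumed shape of the code that feeds `content` above.
import base64

with open(model_file.name, "rb") as f:
    content = base64.b64encode(f.read()).decode("utf-8")
```
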
@@ -668,27 +649,24 @@ def upload_handler(model_file, hf_token, hf_repo, gh_token, gh_repo, progress=gr
     progress(1)
     return hf_status, gh_status
 
-#
-# Gradio UI
-# ------------------------------
+# --- Gradio UI ---
 with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
-    gr.Markdown("# Rolo: RT-DETRv2 Training Dashboard (
+    gr.Markdown("# Rolo: RT-DETRv2 Training Dashboard (Auto-setup for Hugging Face)")
 
-    # State variables
     dataset_info_state = gr.State([])
     final_dataset_path_state = gr.State(None)
 
     with gr.Tabs():
         with gr.TabItem("1. Prepare Datasets"):
-            gr.Markdown("
+            gr.Markdown("Upload a `.txt` with Roboflow URLs or `workspace/project[/vN]` lines.")
             with gr.Row():
-                rf_api_key = gr.Textbox(label="Roboflow API Key (or set ROBOFLOW_API_KEY
+                rf_api_key = gr.Textbox(label="Roboflow API Key (or set ROBOFLOW_API_KEY)", type="password", scale=2)
                 rf_url_file = gr.File(label="Upload Roboflow URLs (.txt)", file_types=[".txt"], scale=1)
             load_btn = gr.Button("Load Datasets", variant="primary")
             dataset_status = gr.Textbox(label="Status", interactive=False)
 
         with gr.TabItem("2. Manage & Merge"):
-            gr.Markdown("
+            gr.Markdown("Rename classes, set image limits, or remove them. Preview, then finalize.")
             with gr.Row():
                 class_df = gr.DataFrame(
                     headers=["Original Name", "Rename To", "Max Images", "Remove"],
@@ -702,18 +680,16 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
                     interactive=False
                 )
             update_counts_btn = gr.Button("Update Counts")
-
             finalize_btn = gr.Button("Finalize Merged Dataset", variant="primary")
             finalize_status = gr.Textbox(label="Status", interactive=False)
 
         with gr.TabItem("3. Configure & Train"):
-            gr.Markdown("
+            gr.Markdown("Set hyperparameters and the training command template.")
             with gr.Row():
                 with gr.Column(scale=1):
                     model_choice_dd = gr.Dropdown(
-                        label="Model Choice (label only
-                        choices=RTDETRV2_MODELS,
-                        value=DEFAULT_MODEL
+                        label="Model Choice (label only; use your config in the template)",
+                        choices=RTDETRV2_MODELS, value=DEFAULT_MODEL
                     )
                     run_name_tb = gr.Textbox(label="Run Name", value="rtdetrv2_run_1")
                     epochs_sl = gr.Slider(1, 500, 100, step=1, label="Epochs")
@@ -721,7 +697,6 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
                     imgsz_num = gr.Number(label="Image Size", value=640)
                     lr_num = gr.Number(label="Learning Rate", value=0.001)
                     opt_dd = gr.Dropdown(["Adam", "AdamW", "SGD"], value="AdamW", label="Optimizer")
-
                     repo_dir_tb = gr.Textbox(label="RT-DETRv2 repo directory", value=DEFAULT_REPO_DIR)
                     cmd_template_tb = gr.Textbox(
                         label="Train command template",
@@ -745,7 +720,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
         final_model_file = gr.File(label="Download Trained Model (best.*)", interactive=False, visible=False)
 
         with gr.TabItem("4. Upload Model"):
-            gr.Markdown("
+            gr.Markdown("Upload your best checkpoint to Hugging Face or GitHub.")
             with gr.Row():
                 with gr.Column():
                     gr.Markdown("#### Hugging Face")
@@ -760,7 +735,6 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
             hf_status = gr.Textbox(label="Hugging Face Status", interactive=False)
             gh_status = gr.Textbox(label="GitHub Status", interactive=False)
 
-    # Wire UI handlers
     load_btn.click(
         fn=load_datasets_handler,
         inputs=[rf_api_key, rf_url_file],
@@ -780,7 +754,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
         fn=training_handler_rtdetrv2,
         inputs=[
             final_dataset_path_state,   # dataset_path
-            repo_dir_tb,                # repo_dir
+            repo_dir_tb,                # repo_dir (auto clone + pip install)
             model_choice_dd,            # model_choice (label only)
             run_name_tb,
             epochs_sl,
@@ -799,5 +773,6 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
     )
 
 if __name__ == "__main__":
-    #
+    # Hugging Face Spaces: set server name/port via env if needed.
+    # Example: app.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)), debug=True)
     app.launch(debug=True)