Update app.py
Browse files
app.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
# app.py — Rolo: RT-DETRv2-only (Supervisely) trainer with auto COCO conversion & safe config patching
|
| 2 |
-
import os, sys, subprocess, shutil, stat, yaml, gradio as gr, re, random, logging, requests, json, base64, time
|
| 3 |
from urllib.parse import urlparse
|
| 4 |
from glob import glob
|
| 5 |
from threading import Thread
|
|
@@ -74,9 +74,32 @@ except Exception:
|
|
| 74 |
logging.exception("Bootstrap failed, UI will still load so you can see errors")
|
| 75 |
|
| 76 |
# === model choices (restricted to Supervisely RT-DETRv2) ======================
|
| 77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
DEFAULT_MODEL_KEY = "rtdetrv2_s"
|
| 79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
# === utilities ================================================================
|
| 81 |
def handle_remove_readonly(func, path, exc_info):
|
| 82 |
try:
|
|
@@ -372,15 +395,12 @@ def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=
|
|
| 372 |
|
| 373 |
# === entrypoint + config detection/generation =================================
|
| 374 |
def find_training_script(repo_root):
|
| 375 |
-
# Hard-prefer the canonical path widely used in the repo/issues
|
| 376 |
canonical = os.path.join(repo_root, "rtdetrv2_pytorch", "tools", "train.py")
|
| 377 |
if os.path.exists(canonical):
|
| 378 |
return canonical
|
| 379 |
-
|
| 380 |
candidates = []
|
| 381 |
for pat in ["**/tools/train.py", "**/train.py", "**/tools/train_net.py"]:
|
| 382 |
candidates.extend(glob(os.path.join(repo_root, pat), recursive=True))
|
| 383 |
-
# Prefer anything inside rtdetrv2_pytorch, then shorter paths
|
| 384 |
def _score(p):
|
| 385 |
pl = p.replace("\\", "/").lower()
|
| 386 |
return (0 if "rtdetrv2_pytorch" in pl else 1, len(p))
|
|
@@ -388,40 +408,13 @@ def find_training_script(repo_root):
|
|
| 388 |
return candidates[0] if candidates else None
|
| 389 |
|
| 390 |
def find_model_config_template(model_key):
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
yamls = glob(os.path.join(REPO_DIR, "**", "*.yml"), recursive=True) + \
|
| 398 |
-
glob(os.path.join(REPO_DIR, "**", "*.yaml"), recursive=True)
|
| 399 |
-
|
| 400 |
-
def score(p):
|
| 401 |
-
pl = p.lower()
|
| 402 |
-
s = 0
|
| 403 |
-
if "/rtdetrv2_pytorch/" in pl:
|
| 404 |
-
s += 4
|
| 405 |
-
if "/config" in pl:
|
| 406 |
-
s += 3
|
| 407 |
-
for token in want_tokens:
|
| 408 |
-
if token in os.path.basename(pl):
|
| 409 |
-
s += 3
|
| 410 |
-
if token in pl:
|
| 411 |
-
s += 2
|
| 412 |
-
if "coco" in pl:
|
| 413 |
-
s += 1
|
| 414 |
-
return -s, len(p)
|
| 415 |
-
|
| 416 |
-
yamls.sort(key=score)
|
| 417 |
-
return yamls[0] if yamls else None
|
| 418 |
|
| 419 |
def _set_first_existing_key(d: dict, keys: list, value, fallback_key: str | None = None):
|
| 420 |
-
"""
|
| 421 |
-
If any key from `keys` exists in dict `d`, set the first one found to `value`.
|
| 422 |
-
Otherwise, if `fallback_key` is given, create it with `value`.
|
| 423 |
-
Returns the key that was set, or None.
|
| 424 |
-
"""
|
| 425 |
for k in keys:
|
| 426 |
if k in d:
|
| 427 |
d[k] = value
|
|
@@ -431,8 +424,74 @@ def _set_first_existing_key(d: dict, keys: list, value, fallback_key: str | None
|
|
| 431 |
return fallback_key
|
| 432 |
return None
|
| 433 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 434 |
def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
|
| 435 |
-
epochs, batch, imgsz, lr, optimizer):
|
| 436 |
if not base_cfg_path or not os.path.exists(base_cfg_path):
|
| 437 |
raise gr.Error("Could not locate a model config inside the RT-DETRv2 repo.")
|
| 438 |
|
|
@@ -450,7 +509,7 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
|
|
| 450 |
"out_dir": os.path.abspath(os.path.join("runs", "train", run_name)),
|
| 451 |
}
|
| 452 |
|
| 453 |
-
# dataset block
|
| 454 |
for root_key in ["dataset", "data"]:
|
| 455 |
if root_key in cfg and isinstance(cfg[root_key], dict):
|
| 456 |
ds = cfg[root_key]
|
|
@@ -525,7 +584,6 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
|
|
| 525 |
break
|
| 526 |
else:
|
| 527 |
sol["base_lr"] = float(lr)
|
| 528 |
-
|
| 529 |
sol["optimizer"] = str(optimizer).lower()
|
| 530 |
if "train_dataloader" in cfg and isinstance(cfg["train_dataloader"], dict):
|
| 531 |
cfg["train_dataloader"]["batch_size"] = int(batch)
|
|
@@ -540,6 +598,14 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
|
|
| 540 |
else:
|
| 541 |
cfg["output_dir"] = paths["out_dir"]
|
| 542 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 543 |
cfg_out_dir = os.path.join("generated_configs")
|
| 544 |
os.makedirs(cfg_out_dir, exist_ok=True)
|
| 545 |
out_path = os.path.join(cfg_out_dir, f"{run_name}.yaml")
|
|
@@ -652,7 +718,7 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
|
|
| 652 |
|
| 653 |
base_cfg = find_model_config_template(model_key)
|
| 654 |
if not base_cfg:
|
| 655 |
-
raise gr.Error("Could not find a matching RT-DETRv2 config in the repo (S/L/X).")
|
| 656 |
|
| 657 |
data_yaml = os.path.join(dataset_path, "data.yaml")
|
| 658 |
with open(data_yaml, "r") as f:
|
|
@@ -660,6 +726,12 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
|
|
| 660 |
class_names = [str(x) for x in dy.get("names", [])]
|
| 661 |
make_coco_annotations(dataset_path, class_names)
|
| 662 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 663 |
cfg_path = patch_base_config(
|
| 664 |
base_cfg_path=base_cfg,
|
| 665 |
merged_dir=dataset_path,
|
|
@@ -670,11 +742,9 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
|
|
| 670 |
imgsz=imgsz,
|
| 671 |
lr=lr,
|
| 672 |
optimizer=opt,
|
|
|
|
| 673 |
)
|
| 674 |
|
| 675 |
-
out_dir = os.path.abspath(os.path.join("runs", "train", run_name))
|
| 676 |
-
os.makedirs(out_dir, exist_ok=True)
|
| 677 |
-
|
| 678 |
cmd = [sys.executable, train_script, "-c", os.path.abspath(cfg_path)]
|
| 679 |
logging.info(f"Training command: {' '.join(cmd)}")
|
| 680 |
|
|
@@ -685,6 +755,9 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
|
|
| 685 |
env["PYTHONPATH"] = os.pathsep.join(filter(None, [
|
| 686 |
PY_IMPL_DIR, REPO_DIR, env.get("PYTHONPATH", "")
|
| 687 |
]))
|
|
|
|
|
|
|
|
|
|
| 688 |
env.setdefault("WANDB_DISABLED", "true")
|
| 689 |
proc = subprocess.Popen(cmd, cwd=os.path.dirname(train_script),
|
| 690 |
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
|
|
|
|
| 1 |
# app.py — Rolo: RT-DETRv2-only (Supervisely) trainer with auto COCO conversion & safe config patching
|
| 2 |
+
import os, sys, subprocess, shutil, stat, yaml, gradio as gr, re, random, logging, requests, json, base64, time, pathlib, tempfile, textwrap
|
| 3 |
from urllib.parse import urlparse
|
| 4 |
from glob import glob
|
| 5 |
from threading import Thread
|
|
|
|
| 74 |
logging.exception("Bootstrap failed, UI will still load so you can see errors")
|
| 75 |
|
| 76 |
# === model choices (restricted to Supervisely RT-DETRv2) ======================
|
| 77 |
+
# Exact mapping to configs and reference COCO checkpoints you provided
# (key, UI label) pairs for the model-size dropdown; the keys index into
# CONFIG_PATHS and CKPT_URLS below, so the three tables must stay in sync.
MODEL_CHOICES = [
    ("rtdetrv2_s", "S (r18vd, 120e) — default"),
    ("rtdetrv2_m", "M (r34vd, 120e)"),
    ("rtdetrv2_msp", "M* (r50vd_m, 7x)"),
    ("rtdetrv2_l", "L (r50vd, 6x)"),
    ("rtdetrv2_x", "X (r101vd, 6x)"),
]
# Pre-selected entry for the UI; must be one of the MODEL_CHOICES keys.
DEFAULT_MODEL_KEY = "rtdetrv2_s"

# Repo-relative path of each model's YAML config inside the cloned
# RT-DETRv2 repository (joined with REPO_DIR by find_model_config_template).
CONFIG_PATHS = {
    "rtdetrv2_s": "rtdetrv2_pytorch/configs/rtdetrv2/rtdetrv2_r18vd_120e_coco.yml",
    "rtdetrv2_m": "rtdetrv2_pytorch/configs/rtdetrv2/rtdetrv2_r34vd_120e_coco.yml",
    "rtdetrv2_msp": "rtdetrv2_pytorch/configs/rtdetrv2/rtdetrv2_r50vd_m_7x_coco.yml",
    "rtdetrv2_l": "rtdetrv2_pytorch/configs/rtdetrv2/rtdetrv2_r50vd_6x_coco.yml",
    "rtdetrv2_x": "rtdetrv2_pytorch/configs/rtdetrv2/rtdetrv2_r101vd_6x_coco.yml",
}

# Reference COCO-pretrained checkpoints (lyuwenyu/storage GitHub releases),
# downloaded by _ensure_checkpoint to warm-start fine-tuning.
CKPT_URLS = {
    "rtdetrv2_s": "https://github.com/lyuwenyu/storage/releases/download/v0.2/rtdetrv2_r18vd_120e_coco_rerun_48.1.pth",
    "rtdetrv2_m": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetrv2_r34vd_120e_coco_ema.pth",
    "rtdetrv2_msp": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetrv2_r50vd_m_7x_coco_ema.pth",
    "rtdetrv2_l": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetrv2_r50vd_6x_coco_ema.pth",
    "rtdetrv2_x": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetrv2_r101vd_6x_coco_from_paddle.pth",
}
|
| 102 |
+
|
| 103 |
# === utilities ================================================================
|
| 104 |
def handle_remove_readonly(func, path, exc_info):
|
| 105 |
try:
|
|
|
|
| 395 |
|
| 396 |
# === entrypoint + config detection/generation =================================
|
| 397 |
def find_training_script(repo_root):
|
|
|
|
| 398 |
canonical = os.path.join(repo_root, "rtdetrv2_pytorch", "tools", "train.py")
|
| 399 |
if os.path.exists(canonical):
|
| 400 |
return canonical
|
|
|
|
| 401 |
candidates = []
|
| 402 |
for pat in ["**/tools/train.py", "**/train.py", "**/tools/train_net.py"]:
|
| 403 |
candidates.extend(glob(os.path.join(repo_root, pat), recursive=True))
|
|
|
|
| 404 |
def _score(p):
|
| 405 |
pl = p.replace("\\", "/").lower()
|
| 406 |
return (0 if "rtdetrv2_pytorch" in pl else 1, len(p))
|
|
|
|
| 408 |
return candidates[0] if candidates else None
|
| 409 |
|
| 410 |
def find_model_config_template(model_key):
    """Resolve the YAML config template for a UI model key.

    Looks the key up in CONFIG_PATHS and joins it onto REPO_DIR.
    Returns the resulting path only when the file actually exists on
    disk; returns None for an unknown key or a missing file.
    """
    relative = CONFIG_PATHS.get(model_key)
    if not relative:
        return None
    full_path = os.path.join(REPO_DIR, relative)
    if not os.path.exists(full_path):
        return None
    return full_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 416 |
|
| 417 |
def _set_first_existing_key(d: dict, keys: list, value, fallback_key: str | None = None):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 418 |
for k in keys:
|
| 419 |
if k in d:
|
| 420 |
d[k] = value
|
|
|
|
| 424 |
return fallback_key
|
| 425 |
return None
|
| 426 |
|
| 427 |
+
def _set_first_existing_key_deep(cfg: dict, keys: list, value):
|
| 428 |
+
"""
|
| 429 |
+
Try to set one of `keys` at top-level, under 'model', or under 'solver'.
|
| 430 |
+
"""
|
| 431 |
+
for scope in [cfg, cfg.get("model", {}), cfg.get("solver", {})]:
|
| 432 |
+
if isinstance(scope, dict):
|
| 433 |
+
for k in keys:
|
| 434 |
+
if k in scope:
|
| 435 |
+
scope[k] = value
|
| 436 |
+
return True
|
| 437 |
+
# If nowhere found, set on model
|
| 438 |
+
if "model" not in cfg or not isinstance(cfg["model"], dict):
|
| 439 |
+
cfg["model"] = {}
|
| 440 |
+
cfg["model"][keys[0]] = value
|
| 441 |
+
return True
|
| 442 |
+
|
| 443 |
+
def _install_supervisely_logger_shim():
|
| 444 |
+
"""
|
| 445 |
+
Creates a minimal shim so `from supervisely.nn.training import train_logger` works.
|
| 446 |
+
"""
|
| 447 |
+
base = pathlib.Path(tempfile.gettempdir()) / "sly_shim" / "supervisely" / "nn"
|
| 448 |
+
base.mkdir(parents=True, exist_ok=True)
|
| 449 |
+
for p in [base.parent.parent, base.parent, base]:
|
| 450 |
+
(p / "__init__.py").write_text("")
|
| 451 |
+
(base / "training.py").write_text(textwrap.dedent("""
|
| 452 |
+
# minimal shim for backward-compat with older Supervisely examples
|
| 453 |
+
class _TrainLogger:
|
| 454 |
+
def __init__(self): pass
|
| 455 |
+
def reset(self): pass
|
| 456 |
+
def log_metrics(self, metrics: dict, step: int | None = None): pass
|
| 457 |
+
def log_artifacts(self, *a, **k): pass
|
| 458 |
+
def log_image(self, *a, **k): pass
|
| 459 |
+
train_logger = _TrainLogger()
|
| 460 |
+
"""))
|
| 461 |
+
return str(base.parent.parent.parent) # .../sly_shim
|
| 462 |
+
|
| 463 |
+
def _ensure_checkpoint(model_key: str, out_dir: str) -> str | None:
    """
    Download the reference COCO checkpoint for the selected model if not present.

    Looks up the URL in CKPT_URLS, caches the file in `out_dir`, and streams
    the download in 1 MiB chunks. Best-effort: any download failure is logged
    and None is returned so training can proceed without a warm start.

    Returns local path (or None if not available).
    """
    url = CKPT_URLS.get(model_key)
    if not url:
        return None
    os.makedirs(out_dir, exist_ok=True)
    fname = os.path.join(out_dir, os.path.basename(url))
    if os.path.exists(fname) and os.path.getsize(fname) > 0:
        return fname
    logging.info(f"Downloading pretrained checkpoint for {model_key} from {url}")
    # ROBUSTNESS FIX: stream into a ".part" file and rename atomically on
    # success. Previously a process killed mid-download left a truncated
    # checkpoint that the size > 0 check above would accept on the next run.
    tmp_name = fname + ".part"
    try:
        with requests.get(url, stream=True, timeout=60) as r:
            r.raise_for_status()
            with open(tmp_name, "wb") as f:
                for chunk in r.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        f.write(chunk)
        os.replace(tmp_name, fname)
        return fname
    except Exception as e:
        logging.warning(f"Could not fetch checkpoint: {e}")
        try:
            if os.path.exists(tmp_name):
                os.remove(tmp_name)
        except Exception:
            pass
        return None
|
| 492 |
+
|
| 493 |
def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
|
| 494 |
+
epochs, batch, imgsz, lr, optimizer, pretrained_path: str | None):
|
| 495 |
if not base_cfg_path or not os.path.exists(base_cfg_path):
|
| 496 |
raise gr.Error("Could not locate a model config inside the RT-DETRv2 repo.")
|
| 497 |
|
|
|
|
| 509 |
"out_dir": os.path.abspath(os.path.join("runs", "train", run_name)),
|
| 510 |
}
|
| 511 |
|
| 512 |
+
# dataset block
|
| 513 |
for root_key in ["dataset", "data"]:
|
| 514 |
if root_key in cfg and isinstance(cfg[root_key], dict):
|
| 515 |
ds = cfg[root_key]
|
|
|
|
| 584 |
break
|
| 585 |
else:
|
| 586 |
sol["base_lr"] = float(lr)
|
|
|
|
| 587 |
sol["optimizer"] = str(optimizer).lower()
|
| 588 |
if "train_dataloader" in cfg and isinstance(cfg["train_dataloader"], dict):
|
| 589 |
cfg["train_dataloader"]["batch_size"] = int(batch)
|
|
|
|
| 598 |
else:
|
| 599 |
cfg["output_dir"] = paths["out_dir"]
|
| 600 |
|
| 601 |
+
# Set pretrained weights if available; try common keys at top/model/solver
|
| 602 |
+
if pretrained_path:
|
| 603 |
+
_set_first_existing_key_deep(
|
| 604 |
+
cfg,
|
| 605 |
+
keys=["pretrain", "pretrained", "weight", "weights", "pretrained_path"],
|
| 606 |
+
value=os.path.abspath(pretrained_path),
|
| 607 |
+
)
|
| 608 |
+
|
| 609 |
cfg_out_dir = os.path.join("generated_configs")
|
| 610 |
os.makedirs(cfg_out_dir, exist_ok=True)
|
| 611 |
out_path = os.path.join(cfg_out_dir, f"{run_name}.yaml")
|
|
|
|
| 718 |
|
| 719 |
base_cfg = find_model_config_template(model_key)
|
| 720 |
if not base_cfg:
|
| 721 |
+
raise gr.Error("Could not find a matching RT-DETRv2 config in the repo (S/M/M*/L/X).")
|
| 722 |
|
| 723 |
data_yaml = os.path.join(dataset_path, "data.yaml")
|
| 724 |
with open(data_yaml, "r") as f:
|
|
|
|
| 726 |
class_names = [str(x) for x in dy.get("names", [])]
|
| 727 |
make_coco_annotations(dataset_path, class_names)
|
| 728 |
|
| 729 |
+
out_dir = os.path.abspath(os.path.join("runs", "train", run_name))
|
| 730 |
+
os.makedirs(out_dir, exist_ok=True)
|
| 731 |
+
|
| 732 |
+
# Download matching COCO checkpoint for warm-start
|
| 733 |
+
pretrained_path = _ensure_checkpoint(model_key, out_dir)
|
| 734 |
+
|
| 735 |
cfg_path = patch_base_config(
|
| 736 |
base_cfg_path=base_cfg,
|
| 737 |
merged_dir=dataset_path,
|
|
|
|
| 742 |
imgsz=imgsz,
|
| 743 |
lr=lr,
|
| 744 |
optimizer=opt,
|
| 745 |
+
pretrained_path=pretrained_path,
|
| 746 |
)
|
| 747 |
|
|
|
|
|
|
|
|
|
|
| 748 |
cmd = [sys.executable, train_script, "-c", os.path.abspath(cfg_path)]
|
| 749 |
logging.info(f"Training command: {' '.join(cmd)}")
|
| 750 |
|
|
|
|
| 755 |
env["PYTHONPATH"] = os.pathsep.join(filter(None, [
|
| 756 |
PY_IMPL_DIR, REPO_DIR, env.get("PYTHONPATH", "")
|
| 757 |
]))
|
| 758 |
+
# put our shim at the very front so the import always resolves
|
| 759 |
+
shim_root = _install_supervisely_logger_shim()
|
| 760 |
+
env["PYTHONPATH"] = os.pathsep.join([shim_root, env["PYTHONPATH"]])
|
| 761 |
env.setdefault("WANDB_DISABLED", "true")
|
| 762 |
proc = subprocess.Popen(cmd, cwd=os.path.dirname(train_script),
|
| 763 |
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
|