wuhp committed
Commit 2cc3e82 · verified · 1 parent: 0d7a9e1

Update app.py

Files changed (1):
  app.py  +117 -44
app.py CHANGED
@@ -1,5 +1,5 @@
 # app.py — Rolo: RT-DETRv2-only (Supervisely) trainer with auto COCO conversion & safe config patching
-import os, sys, subprocess, shutil, stat, yaml, gradio as gr, re, random, logging, requests, json, base64, time
+import os, sys, subprocess, shutil, stat, yaml, gradio as gr, re, random, logging, requests, json, base64, time, pathlib, tempfile, textwrap
 from urllib.parse import urlparse
 from glob import glob
 from threading import Thread
@@ -74,9 +74,32 @@ except Exception:
     logging.exception("Bootstrap failed, UI will still load so you can see errors")
 
 # === model choices (restricted to Supervisely RT-DETRv2) ======================
-MODEL_CHOICES = [("rtdetrv2_s", "Small (default)"), ("rtdetrv2_l", "Large"), ("rtdetrv2_x", "X-Large")]
+# Exact mapping to configs and reference COCO checkpoints
+MODEL_CHOICES = [
+    ("rtdetrv2_s", "S (r18vd, 120e) — default"),
+    ("rtdetrv2_m", "M (r34vd, 120e)"),
+    ("rtdetrv2_msp", "M* (r50vd_m, 7x)"),
+    ("rtdetrv2_l", "L (r50vd, 6x)"),
+    ("rtdetrv2_x", "X (r101vd, 6x)"),
+]
 DEFAULT_MODEL_KEY = "rtdetrv2_s"
 
+CONFIG_PATHS = {
+    "rtdetrv2_s": "rtdetrv2_pytorch/configs/rtdetrv2/rtdetrv2_r18vd_120e_coco.yml",
+    "rtdetrv2_m": "rtdetrv2_pytorch/configs/rtdetrv2/rtdetrv2_r34vd_120e_coco.yml",
+    "rtdetrv2_msp": "rtdetrv2_pytorch/configs/rtdetrv2/rtdetrv2_r50vd_m_7x_coco.yml",
+    "rtdetrv2_l": "rtdetrv2_pytorch/configs/rtdetrv2/rtdetrv2_r50vd_6x_coco.yml",
+    "rtdetrv2_x": "rtdetrv2_pytorch/configs/rtdetrv2/rtdetrv2_r101vd_6x_coco.yml",
+}
+
+CKPT_URLS = {
+    "rtdetrv2_s": "https://github.com/lyuwenyu/storage/releases/download/v0.2/rtdetrv2_r18vd_120e_coco_rerun_48.1.pth",
+    "rtdetrv2_m": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetrv2_r34vd_120e_coco_ema.pth",
+    "rtdetrv2_msp": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetrv2_r50vd_m_7x_coco_ema.pth",
+    "rtdetrv2_l": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetrv2_r50vd_6x_coco_ema.pth",
+    "rtdetrv2_x": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetrv2_r101vd_6x_coco_from_paddle.pth",
+}
+
 # === utilities ================================================================
 def handle_remove_readonly(func, path, exc_info):
     try:
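For context, model selection is now a fixed dict lookup instead of filesystem scoring. A minimal, self-contained sketch of the resolution path (illustrative only: `REPO_DIR` and the abbreviated dicts stand in for the values defined in app.py):

import os

REPO_DIR = "RT-DETR"  # assumption: root of the cloned RT-DETRv2 repo
CONFIG_PATHS = {"rtdetrv2_s": "rtdetrv2_pytorch/configs/rtdetrv2/rtdetrv2_r18vd_120e_coco.yml"}
CKPT_URLS = {"rtdetrv2_s": "https://github.com/lyuwenyu/storage/releases/download/v0.2/rtdetrv2_r18vd_120e_coco_rerun_48.1.pth"}

def resolve(model_key: str) -> tuple[str | None, str | None]:
    # Unknown keys resolve to (None, None) rather than a fuzzy filesystem match.
    rel = CONFIG_PATHS.get(model_key)
    cfg = os.path.join(REPO_DIR, rel) if rel else None
    return cfg, CKPT_URLS.get(model_key)

print(resolve("rtdetrv2_s")[0])  # RT-DETR/rtdetrv2_pytorch/configs/rtdetrv2/rtdetrv2_r18vd_120e_coco.yml
print(resolve("unknown"))        # (None, None)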
@@ -372,15 +395,12 @@ def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=
 
 # === entrypoint + config detection/generation =================================
 def find_training_script(repo_root):
-    # Hard-prefer the canonical path widely used in the repo/issues
     canonical = os.path.join(repo_root, "rtdetrv2_pytorch", "tools", "train.py")
     if os.path.exists(canonical):
         return canonical
-
     candidates = []
     for pat in ["**/tools/train.py", "**/train.py", "**/tools/train_net.py"]:
         candidates.extend(glob(os.path.join(repo_root, pat), recursive=True))
-    # Prefer anything inside rtdetrv2_pytorch, then shorter paths
     def _score(p):
         pl = p.replace("\\", "/").lower()
         return (0 if "rtdetrv2_pytorch" in pl else 1, len(p))
@@ -388,40 +408,13 @@ def find_training_script(repo_root):
     return candidates[0] if candidates else None
 
 def find_model_config_template(model_key):
-    want_tokens = {
-        "rtdetrv2_s": ["rtdetrv2", "r18", "coco"],
-        "rtdetrv2_l": ["rtdetrv2", "r50", "coco"],
-        "rtdetrv2_x": ["rtdetrv2", "r101", "coco"],
-    }.get(model_key, ["rtdetrv2", "r18", "coco"])
-
-    yamls = glob(os.path.join(REPO_DIR, "**", "*.yml"), recursive=True) + \
-            glob(os.path.join(REPO_DIR, "**", "*.yaml"), recursive=True)
-
-    def score(p):
-        pl = p.lower()
-        s = 0
-        if "/rtdetrv2_pytorch/" in pl:
-            s += 4
-        if "/config" in pl:
-            s += 3
-        for token in want_tokens:
-            if token in os.path.basename(pl):
-                s += 3
-            if token in pl:
-                s += 2
-        if "coco" in pl:
-            s += 1
-        return -s, len(p)
-
-    yamls.sort(key=score)
-    return yamls[0] if yamls else None
+    rel = CONFIG_PATHS.get(model_key)
+    if not rel:
+        return None
+    path = os.path.join(REPO_DIR, rel)
+    return path if os.path.exists(path) else None
 
 def _set_first_existing_key(d: dict, keys: list, value, fallback_key: str | None = None):
-    """
-    If any key from `keys` exists in dict `d`, set the first one found to `value`.
-    Otherwise, if `fallback_key` is given, create it with `value`.
-    Returns the key that was set, or None.
-    """
     for k in keys:
         if k in d:
             d[k] = value
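The alias-setting helper is easiest to see on a toy config. A short sketch (the function body is reconstructed from the fragments visible across this hunk and the next; treat it as illustrative):

def _set_first_existing_key(d: dict, keys: list, value, fallback_key: str | None = None):
    # Set the first key from `keys` that already exists in `d`; otherwise
    # create `fallback_key` if one was given. Returns the key that was set, or None.
    for k in keys:
        if k in d:
            d[k] = value
            return k
    if fallback_key is not None:
        d[fallback_key] = value
        return fallback_key
    return None

cfg = {"epoches": 72}  # RT-DETR configs spell this key "epoches"
print(_set_first_existing_key(cfg, ["epochs", "epoches"], 100))        # 'epoches'; cfg is now {'epoches': 100}
print(_set_first_existing_key(cfg, ["lr0"], 1e-4, fallback_key="lr"))  # 'lr' (created via fallback)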
@@ -431,8 +424,74 @@ def _set_first_existing_key(d: dict, keys: list, value, fallback_key: str | None
         return fallback_key
     return None
 
+def _set_first_existing_key_deep(cfg: dict, keys: list, value):
+    """
+    Try to set one of `keys` at top-level, under 'model', or under 'solver'.
+    """
+    for scope in [cfg, cfg.get("model", {}), cfg.get("solver", {})]:
+        if isinstance(scope, dict):
+            for k in keys:
+                if k in scope:
+                    scope[k] = value
+                    return True
+    # If nowhere found, set on model
+    if "model" not in cfg or not isinstance(cfg["model"], dict):
+        cfg["model"] = {}
+    cfg["model"][keys[0]] = value
+    return True
+
+def _install_supervisely_logger_shim():
+    """
+    Creates a minimal shim so `from supervisely.nn.training import train_logger` works.
+    """
+    base = pathlib.Path(tempfile.gettempdir()) / "sly_shim" / "supervisely" / "nn"
+    base.mkdir(parents=True, exist_ok=True)
+    for p in [base.parent.parent, base.parent, base]:
+        (p / "__init__.py").write_text("")
+    (base / "training.py").write_text(textwrap.dedent("""
+        # minimal shim for backward-compat with older Supervisely examples
+        class _TrainLogger:
+            def __init__(self): pass
+            def reset(self): pass
+            def log_metrics(self, metrics: dict, step: int | None = None): pass
+            def log_artifacts(self, *a, **k): pass
+            def log_image(self, *a, **k): pass
+        train_logger = _TrainLogger()
+    """))
+    return str(base.parent.parent)  # .../sly_shim (the dir that must go on PYTHONPATH)
+
+def _ensure_checkpoint(model_key: str, out_dir: str) -> str | None:
+    """
+    Download the reference COCO checkpoint for the selected model if not present.
+    Returns local path (or None if not available).
+    """
+    url = CKPT_URLS.get(model_key)
+    if not url:
+        return None
+    os.makedirs(out_dir, exist_ok=True)
+    fname = os.path.join(out_dir, os.path.basename(url))
+    if os.path.exists(fname) and os.path.getsize(fname) > 0:
+        return fname
+    logging.info(f"Downloading pretrained checkpoint for {model_key} from {url}")
+    try:
+        with requests.get(url, stream=True, timeout=60) as r:
+            r.raise_for_status()
+            with open(fname, "wb") as f:
+                for chunk in r.iter_content(chunk_size=1024 * 1024):
+                    if chunk:
+                        f.write(chunk)
+        return fname
+    except Exception as e:
+        logging.warning(f"Could not fetch checkpoint: {e}")
+        try:
+            if os.path.exists(fname):
+                os.remove(fname)
+        except Exception:
+            pass
+        return None
+
 def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
-                      epochs, batch, imgsz, lr, optimizer):
+                      epochs, batch, imgsz, lr, optimizer, pretrained_path: str | None):
     if not base_cfg_path or not os.path.exists(base_cfg_path):
         raise gr.Error("Could not locate a model config inside the RT-DETRv2 repo.")
 
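To check that the shim actually satisfies the import, install it and put its root on sys.path, the in-process equivalent of the PYTHONPATH prepend near the end of this diff. A hedged sketch, assuming `_install_supervisely_logger_shim` from this commit is in scope:

import importlib
import sys

shim_root = _install_supervisely_logger_shim()  # returns .../sly_shim
sys.path.insert(0, shim_root)                   # same precedence as the PYTHONPATH prepend
mod = importlib.import_module("supervisely.nn.training")
mod.train_logger.log_metrics({"loss": 0.123}, step=1)  # resolves; every method is a no-op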
 
@@ -450,7 +509,7 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
         "out_dir": os.path.abspath(os.path.join("runs", "train", run_name)),
     }
 
-    # dataset block: set an existing alias if present, otherwise add a common key
+    # dataset block
     for root_key in ["dataset", "data"]:
         if root_key in cfg and isinstance(cfg[root_key], dict):
             ds = cfg[root_key]
@@ -525,7 +584,6 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
             break
     else:
         sol["base_lr"] = float(lr)
-
     sol["optimizer"] = str(optimizer).lower()
     if "train_dataloader" in cfg and isinstance(cfg["train_dataloader"], dict):
         cfg["train_dataloader"]["batch_size"] = int(batch)
@@ -540,6 +598,14 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
     else:
         cfg["output_dir"] = paths["out_dir"]
 
+    # Set pretrained weights if available; try common keys at top/model/solver
+    if pretrained_path:
+        _set_first_existing_key_deep(
+            cfg,
+            keys=["pretrain", "pretrained", "weight", "weights", "pretrained_path"],
+            value=os.path.abspath(pretrained_path),
+        )
+
     cfg_out_dir = os.path.join("generated_configs")
     os.makedirs(cfg_out_dir, exist_ok=True)
     out_path = os.path.join(cfg_out_dir, f"{run_name}.yaml")
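To see which key actually receives the checkpoint path, run the deep setter against a toy RT-DETRv2-style config (key names here mirror the alias list above, not a guaranteed schema; assumes `_set_first_existing_key_deep` as defined earlier in this commit):

cfg = {"model": {"pretrained": None}, "solver": {"base_lr": 1e-4}}
# Scope order is: top level, then cfg["model"], then cfg["solver"].
_set_first_existing_key_deep(cfg, ["pretrain", "pretrained", "weight"], "/abs/ckpt.pth")
print(cfg["model"]["pretrained"])  # '/abs/ckpt.pth' (found under 'model'; top level untouched)

cfg2 = {}  # no matching key anywhere: falls back to creating cfg2["model"]["pretrain"]
_set_first_existing_key_deep(cfg2, ["pretrain", "pretrained"], "/abs/ckpt.pth")
print(cfg2)  # {'model': {'pretrain': '/abs/ckpt.pth'}}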
@@ -652,7 +718,7 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
 
     base_cfg = find_model_config_template(model_key)
     if not base_cfg:
-        raise gr.Error("Could not find a matching RT-DETRv2 config in the repo (S/L/X).")
+        raise gr.Error("Could not find a matching RT-DETRv2 config in the repo (S/M/M*/L/X).")
 
     data_yaml = os.path.join(dataset_path, "data.yaml")
     with open(data_yaml, "r") as f:
@@ -660,6 +726,12 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
     class_names = [str(x) for x in dy.get("names", [])]
     make_coco_annotations(dataset_path, class_names)
 
+    out_dir = os.path.abspath(os.path.join("runs", "train", run_name))
+    os.makedirs(out_dir, exist_ok=True)
+
+    # Download matching COCO checkpoint for warm-start
+    pretrained_path = _ensure_checkpoint(model_key, out_dir)
+
     cfg_path = patch_base_config(
         base_cfg_path=base_cfg,
         merged_dir=dataset_path,
@@ -670,11 +742,9 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
         imgsz=imgsz,
         lr=lr,
         optimizer=opt,
+        pretrained_path=pretrained_path,
     )
 
-    out_dir = os.path.abspath(os.path.join("runs", "train", run_name))
-    os.makedirs(out_dir, exist_ok=True)
-
     cmd = [sys.executable, train_script, "-c", os.path.abspath(cfg_path)]
     logging.info(f"Training command: {' '.join(cmd)}")
 
@@ -685,6 +755,9 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
     env["PYTHONPATH"] = os.pathsep.join(filter(None, [
         PY_IMPL_DIR, REPO_DIR, env.get("PYTHONPATH", "")
     ]))
+    # put our shim at the very front so the import always resolves
+    shim_root = _install_supervisely_logger_shim()
+    env["PYTHONPATH"] = os.pathsep.join([shim_root, env["PYTHONPATH"]])
     env.setdefault("WANDB_DISABLED", "true")
     proc = subprocess.Popen(cmd, cwd=os.path.dirname(train_script),
                             stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
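The resulting child environment puts the shim ahead of both repo paths, so `supervisely.nn.training` resolves even when the real package is absent. A standalone sketch of the precedence (all paths are assumptions for illustration):

import os

env = {"PYTHONPATH": "/site-packages-extra"}  # whatever the parent process inherited
PY_IMPL_DIR, REPO_DIR = "/app/RT-DETR/rtdetrv2_pytorch", "/app/RT-DETR"  # assumed clone layout
env["PYTHONPATH"] = os.pathsep.join(filter(None, [PY_IMPL_DIR, REPO_DIR, env.get("PYTHONPATH", "")]))
shim_root = "/tmp/sly_shim"  # what the shim installer returns
env["PYTHONPATH"] = os.pathsep.join([shim_root, env["PYTHONPATH"]])
print(env["PYTHONPATH"])
# /tmp/sly_shim:/app/RT-DETR/rtdetrv2_pytorch:/app/RT-DETR:/site-packages-extra  (POSIX pathsep)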
 