wuhp commited on
Commit
ddcd05d
·
verified ·
1 Parent(s): 7411ec8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -4
app.py CHANGED
@@ -458,7 +458,7 @@ def _install_supervisely_logger_shim():
458
  """))
459
  return str(root)
460
 
461
- # ---- NEW: robust fallback for cfg['_pymodule'] via sitecustomize -------------
462
  def _install_workspace_env_fallback(cwd_for_train: str, module_default: str = "rtdetrv2_pytorch.src"):
463
  """
464
  Creates a sitecustomize.py in the training cwd that monkeypatches
@@ -497,6 +497,78 @@ def _install_workspace_env_fallback(cwd_for_train: str, module_default: str = "r
497
  f.write(code)
498
  return sc_path
499
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
500
  def _ensure_checkpoint(model_key: str, out_dir: str) -> str | None:
501
  url = CKPT_URLS.get(model_key)
502
  if not url:
@@ -608,9 +680,8 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
608
  cfg["task"] = cfg.get("task", "detection")
609
  cfg["_pymodule"] = cfg.get("_pymodule", "rtdetrv2_pytorch.src") # <= HINT for loader
610
 
611
- # Disable SyncBN for single GPU/CPU runs
612
  cfg["sync_bn"] = False
613
- # Guardrails for single-process runs
614
  cfg.setdefault("device", "")
615
  cfg["find_unused_parameters"] = False
616
 
@@ -837,6 +908,13 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
837
  if not base_cfg:
838
  raise gr.Error("Could not find a matching RT-DETRv2 config in the repo (S/M/M*/L/X).")
839
 
 
 
 
 
 
 
 
840
  data_yaml = os.path.join(dataset_path, "data.yaml")
841
  with open(data_yaml, "r", encoding="utf-8") as f:
842
  dy = yaml.safe_load(f)
@@ -867,7 +945,7 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
867
  q = Queue()
868
  def run_train():
869
  try:
870
- # Ensure our fallback hook is available in the train process (CWD on sys.path)
871
  train_cwd = os.path.dirname(train_script)
872
  _install_workspace_env_fallback(train_cwd)
873
 
 
458
  """))
459
  return str(root)
460
 
461
+ # ---- OPTIONAL: legacy sitecustomize fallback (kept for belt & suspenders) ----
462
  def _install_workspace_env_fallback(cwd_for_train: str, module_default: str = "rtdetrv2_pytorch.src"):
463
  """
464
  Creates a sitecustomize.py in the training cwd that monkeypatches
 
497
  f.write(code)
498
  return sc_path
499
 
500
+ # ---- NEW: direct, deterministic patch of workspace.create --------------------
501
+ def _patch_workspace_create(repo_root: str, module_default: str = "rtdetrv2_pytorch.src") -> str | None:
502
+ """
503
+ Idempotently patches third_party/RT-DETRv2/rtdetrv2_pytorch/src/core/workspace.py
504
+ so that workspace.create() tolerates missing cfg['_pymodule'] by falling back
505
+ to $RTDETR_PYMODULE or module_default. Returns the path of the file patched (or None).
506
+ """
507
+ ws_path = os.path.join(
508
+ repo_root, "rtdetrv2_pytorch", "src", "core", "workspace.py"
509
+ )
510
+ if not os.path.isfile(ws_path):
511
+ return None
512
+
513
+ with open(ws_path, "r", encoding="utf-8") as f:
514
+ src = f.read()
515
+
516
+ if "RTDETR_PYMODULE" in src and "_SAFE_CREATE_PATCH" in src:
517
+ return ws_path # already patched
518
+
519
+ # Backup once
520
+ bak_path = ws_path + ".bak"
521
+ if not os.path.exists(bak_path):
522
+ try:
523
+ shutil.copy2(ws_path, bak_path)
524
+ except Exception:
525
+ pass
526
+
527
+ marker = "_SAFE_CREATE_PATCH"
528
+ inject_code = f"""
529
+ # --- {marker}: begin ---
530
+ import os as _os
531
+ import importlib as _importlib
532
+
533
+ def _resolve_pymodule_from_cfg_or_env(_cfg):
534
+ try:
535
+ pm = _cfg.get("_pymodule", None)
536
+ except Exception:
537
+ pm = None
538
+ if not pm:
539
+ pm = _os.environ.get("RTDETR_PYMODULE", "{module_default}")
540
+ try:
541
+ _importlib.import_module(pm)
542
+ except Exception:
543
+ pass
544
+ try:
545
+ _cfg["_pymodule"] = pm
546
+ except Exception:
547
+ pass
548
+ return pm
549
+ # --- {marker}: end ---
550
+ """
551
+
552
+ if "def create(" in src:
553
+ if "_resolve_pymodule_from_cfg_or_env" not in src:
554
+ src = inject_code + src
555
+ src = src.replace("cfg['_pymodule']", "_resolve_pymodule_from_cfg_or_env(cfg)")
556
+ else:
557
+ return None
558
+
559
+ with open(ws_path, "w", encoding="utf-8") as f:
560
+ f.write(src)
561
+ return ws_path
562
+
563
+ def _unpatch_workspace_create(repo_root: str):
564
+ ws_path = os.path.join(repo_root, "rtdetrv2_pytorch", "src", "core", "workspace.py")
565
+ bak_path = ws_path + ".bak"
566
+ if os.path.exists(bak_path):
567
+ try:
568
+ shutil.copy2(bak_path, ws_path)
569
+ except Exception:
570
+ pass
571
+
572
  def _ensure_checkpoint(model_key: str, out_dir: str) -> str | None:
573
  url = CKPT_URLS.get(model_key)
574
  if not url:
 
680
  cfg["task"] = cfg.get("task", "detection")
681
  cfg["_pymodule"] = cfg.get("_pymodule", "rtdetrv2_pytorch.src") # <= HINT for loader
682
 
683
+ # Disable SyncBN for single GPU/CPU runs; guard DDP flags
684
  cfg["sync_bn"] = False
 
685
  cfg.setdefault("device", "")
686
  cfg["find_unused_parameters"] = False
687
 
 
908
  if not base_cfg:
909
  raise gr.Error("Could not find a matching RT-DETRv2 config in the repo (S/M/M*/L/X).")
910
 
911
+ # Apply deterministic patch so workspace.create never KeyErrors on '_pymodule'
912
+ patched = _patch_workspace_create(REPO_DIR, module_default="rtdetrv2_pytorch.src")
913
+ if patched:
914
+ logging.info(f"Patched workspace.create at: {patched}")
915
+ else:
916
+ logging.warning("workspace.create patch not applied (file missing or atypical).")
917
+
918
  data_yaml = os.path.join(dataset_path, "data.yaml")
919
  with open(data_yaml, "r", encoding="utf-8") as f:
920
  dy = yaml.safe_load(f)
 
945
  q = Queue()
946
  def run_train():
947
  try:
948
+ # Optional: legacy sitecustomize helper (kept; patch above already handles it)
949
  train_cwd = os.path.dirname(train_script)
950
  _install_workspace_env_fallback(train_cwd)
951