Update app.py
Browse files
app.py
CHANGED
|
@@ -458,7 +458,7 @@ def _install_supervisely_logger_shim():
|
|
| 458 |
"""))
|
| 459 |
return str(root)
|
| 460 |
|
| 461 |
-
# ----
|
| 462 |
def _install_workspace_env_fallback(cwd_for_train: str, module_default: str = "rtdetrv2_pytorch.src"):
|
| 463 |
"""
|
| 464 |
Creates a sitecustomize.py in the training cwd that monkeypatches
|
|
@@ -497,6 +497,78 @@ def _install_workspace_env_fallback(cwd_for_train: str, module_default: str = "r
|
|
| 497 |
f.write(code)
|
| 498 |
return sc_path
|
| 499 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 500 |
def _ensure_checkpoint(model_key: str, out_dir: str) -> str | None:
|
| 501 |
url = CKPT_URLS.get(model_key)
|
| 502 |
if not url:
|
|
@@ -608,9 +680,8 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
|
|
| 608 |
cfg["task"] = cfg.get("task", "detection")
|
| 609 |
cfg["_pymodule"] = cfg.get("_pymodule", "rtdetrv2_pytorch.src") # <= HINT for loader
|
| 610 |
|
| 611 |
-
# Disable SyncBN for single GPU/CPU runs
|
| 612 |
cfg["sync_bn"] = False
|
| 613 |
-
# Guardrails for single-process runs
|
| 614 |
cfg.setdefault("device", "")
|
| 615 |
cfg["find_unused_parameters"] = False
|
| 616 |
|
|
@@ -837,6 +908,13 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
|
|
| 837 |
if not base_cfg:
|
| 838 |
raise gr.Error("Could not find a matching RT-DETRv2 config in the repo (S/M/M*/L/X).")
|
| 839 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 840 |
data_yaml = os.path.join(dataset_path, "data.yaml")
|
| 841 |
with open(data_yaml, "r", encoding="utf-8") as f:
|
| 842 |
dy = yaml.safe_load(f)
|
|
@@ -867,7 +945,7 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
|
|
| 867 |
q = Queue()
|
| 868 |
def run_train():
|
| 869 |
try:
|
| 870 |
-
#
|
| 871 |
train_cwd = os.path.dirname(train_script)
|
| 872 |
_install_workspace_env_fallback(train_cwd)
|
| 873 |
|
|
|
|
| 458 |
"""))
|
| 459 |
return str(root)
|
| 460 |
|
| 461 |
+
# ---- OPTIONAL: legacy sitecustomize fallback (kept for belt & suspenders) ----
|
| 462 |
def _install_workspace_env_fallback(cwd_for_train: str, module_default: str = "rtdetrv2_pytorch.src"):
|
| 463 |
"""
|
| 464 |
Creates a sitecustomize.py in the training cwd that monkeypatches
|
|
|
|
| 497 |
f.write(code)
|
| 498 |
return sc_path
|
| 499 |
|
| 500 |
+
# ---- NEW: direct, deterministic patch of workspace.create --------------------
|
| 501 |
+
def _patch_workspace_create(repo_root: str, module_default: str = "rtdetrv2_pytorch.src") -> str | None:
|
| 502 |
+
"""
|
| 503 |
+
Idempotently patches third_party/RT-DETRv2/rtdetrv2_pytorch/src/core/workspace.py
|
| 504 |
+
so that workspace.create() tolerates missing cfg['_pymodule'] by falling back
|
| 505 |
+
to $RTDETR_PYMODULE or module_default. Returns the path of the file patched (or None).
|
| 506 |
+
"""
|
| 507 |
+
ws_path = os.path.join(
|
| 508 |
+
repo_root, "rtdetrv2_pytorch", "src", "core", "workspace.py"
|
| 509 |
+
)
|
| 510 |
+
if not os.path.isfile(ws_path):
|
| 511 |
+
return None
|
| 512 |
+
|
| 513 |
+
with open(ws_path, "r", encoding="utf-8") as f:
|
| 514 |
+
src = f.read()
|
| 515 |
+
|
| 516 |
+
if "RTDETR_PYMODULE" in src and "_SAFE_CREATE_PATCH" in src:
|
| 517 |
+
return ws_path # already patched
|
| 518 |
+
|
| 519 |
+
# Backup once
|
| 520 |
+
bak_path = ws_path + ".bak"
|
| 521 |
+
if not os.path.exists(bak_path):
|
| 522 |
+
try:
|
| 523 |
+
shutil.copy2(ws_path, bak_path)
|
| 524 |
+
except Exception:
|
| 525 |
+
pass
|
| 526 |
+
|
| 527 |
+
marker = "_SAFE_CREATE_PATCH"
|
| 528 |
+
inject_code = f"""
|
| 529 |
+
# --- {marker}: begin ---
|
| 530 |
+
import os as _os
|
| 531 |
+
import importlib as _importlib
|
| 532 |
+
|
| 533 |
+
def _resolve_pymodule_from_cfg_or_env(_cfg):
|
| 534 |
+
try:
|
| 535 |
+
pm = _cfg.get("_pymodule", None)
|
| 536 |
+
except Exception:
|
| 537 |
+
pm = None
|
| 538 |
+
if not pm:
|
| 539 |
+
pm = _os.environ.get("RTDETR_PYMODULE", "{module_default}")
|
| 540 |
+
try:
|
| 541 |
+
_importlib.import_module(pm)
|
| 542 |
+
except Exception:
|
| 543 |
+
pass
|
| 544 |
+
try:
|
| 545 |
+
_cfg["_pymodule"] = pm
|
| 546 |
+
except Exception:
|
| 547 |
+
pass
|
| 548 |
+
return pm
|
| 549 |
+
# --- {marker}: end ---
|
| 550 |
+
"""
|
| 551 |
+
|
| 552 |
+
if "def create(" in src:
|
| 553 |
+
if "_resolve_pymodule_from_cfg_or_env" not in src:
|
| 554 |
+
src = inject_code + src
|
| 555 |
+
src = src.replace("cfg['_pymodule']", "_resolve_pymodule_from_cfg_or_env(cfg)")
|
| 556 |
+
else:
|
| 557 |
+
return None
|
| 558 |
+
|
| 559 |
+
with open(ws_path, "w", encoding="utf-8") as f:
|
| 560 |
+
f.write(src)
|
| 561 |
+
return ws_path
|
| 562 |
+
|
| 563 |
+
def _unpatch_workspace_create(repo_root: str):
|
| 564 |
+
ws_path = os.path.join(repo_root, "rtdetrv2_pytorch", "src", "core", "workspace.py")
|
| 565 |
+
bak_path = ws_path + ".bak"
|
| 566 |
+
if os.path.exists(bak_path):
|
| 567 |
+
try:
|
| 568 |
+
shutil.copy2(bak_path, ws_path)
|
| 569 |
+
except Exception:
|
| 570 |
+
pass
|
| 571 |
+
|
| 572 |
def _ensure_checkpoint(model_key: str, out_dir: str) -> str | None:
|
| 573 |
url = CKPT_URLS.get(model_key)
|
| 574 |
if not url:
|
|
|
|
| 680 |
cfg["task"] = cfg.get("task", "detection")
|
| 681 |
cfg["_pymodule"] = cfg.get("_pymodule", "rtdetrv2_pytorch.src") # <= HINT for loader
|
| 682 |
|
| 683 |
+
# Disable SyncBN for single GPU/CPU runs; guard DDP flags
|
| 684 |
cfg["sync_bn"] = False
|
|
|
|
| 685 |
cfg.setdefault("device", "")
|
| 686 |
cfg["find_unused_parameters"] = False
|
| 687 |
|
|
|
|
| 908 |
if not base_cfg:
|
| 909 |
raise gr.Error("Could not find a matching RT-DETRv2 config in the repo (S/M/M*/L/X).")
|
| 910 |
|
| 911 |
+
# Apply deterministic patch so workspace.create never KeyErrors on '_pymodule'
|
| 912 |
+
patched = _patch_workspace_create(REPO_DIR, module_default="rtdetrv2_pytorch.src")
|
| 913 |
+
if patched:
|
| 914 |
+
logging.info(f"Patched workspace.create at: {patched}")
|
| 915 |
+
else:
|
| 916 |
+
logging.warning("workspace.create patch not applied (file missing or atypical).")
|
| 917 |
+
|
| 918 |
data_yaml = os.path.join(dataset_path, "data.yaml")
|
| 919 |
with open(data_yaml, "r", encoding="utf-8") as f:
|
| 920 |
dy = yaml.safe_load(f)
|
|
|
|
| 945 |
q = Queue()
|
| 946 |
def run_train():
|
| 947 |
try:
|
| 948 |
+
# Optional: legacy sitecustomize helper (kept; patch above already handles it)
|
| 949 |
train_cwd = os.path.dirname(train_script)
|
| 950 |
_install_workspace_env_fallback(train_cwd)
|
| 951 |
|