Update app.py
Browse files
app.py
CHANGED
|
@@ -458,45 +458,37 @@ def _install_supervisely_logger_shim():
|
|
| 458 |
"""))
|
| 459 |
return str(root)
|
| 460 |
|
| 461 |
-
# ----
|
| 462 |
def _install_workspace_env_fallback(cwd_for_train: str, module_default: str = "rtdetrv2_pytorch.src"):
|
| 463 |
"""
|
| 464 |
Creates a sitecustomize.py in the training cwd that monkeypatches
|
| 465 |
-
rtdetrv2_pytorch.src.core.workspace.create to
|
| 466 |
-
|
| 467 |
"""
|
| 468 |
sc_path = os.path.join(cwd_for_train, "sitecustomize.py")
|
| 469 |
code = textwrap.dedent(f"""
|
| 470 |
-
import os, importlib
|
| 471 |
try:
|
|
|
|
| 472 |
ws_mod = importlib.import_module("rtdetrv2_pytorch.src.core.workspace")
|
| 473 |
_orig_create = ws_mod.create
|
| 474 |
-
|
| 475 |
-
def _ensure_module_object(name_or_none: str) -> types.ModuleType:
|
| 476 |
-
target = name_or_none.strip() if isinstance(name_or_none, str) and name_or_none.strip() else os.environ.get("RTDETR_PYMODULE", "{module_default}")
|
| 477 |
-
try:
|
| 478 |
-
return importlib.import_module(target)
|
| 479 |
-
except Exception:
|
| 480 |
-
return importlib.import_module("{module_default}")
|
| 481 |
-
|
| 482 |
def _safe_create(name, cfg, *args, **kwargs):
|
| 483 |
pm = None
|
| 484 |
try:
|
| 485 |
pm = cfg.get("_pymodule", None)
|
| 486 |
except Exception:
|
| 487 |
pm = None
|
| 488 |
-
|
| 489 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 490 |
pass
|
| 491 |
-
else:
|
| 492 |
-
mod = _ensure_module_object(pm)
|
| 493 |
-
try:
|
| 494 |
-
cfg["_pymodule"] = mod
|
| 495 |
-
except Exception:
|
| 496 |
-
pass
|
| 497 |
-
|
| 498 |
return _orig_create(name, cfg, *args, **kwargs)
|
| 499 |
-
|
| 500 |
ws_mod.create = _safe_create
|
| 501 |
except Exception:
|
| 502 |
pass
|
|
@@ -505,14 +497,99 @@ def _install_workspace_env_fallback(cwd_for_train: str, module_default: str = "r
|
|
| 505 |
f.write(code)
|
| 506 |
return sc_path
|
| 507 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 508 |
def _unpatch_workspace_create(repo_root: str):
|
| 509 |
-
"""If a previous run modified workspace.py, restore it from its .bak (best-effort)."""
|
| 510 |
ws_path = os.path.join(repo_root, "rtdetrv2_pytorch", "src", "core", "workspace.py")
|
| 511 |
bak_path = ws_path + ".bak"
|
| 512 |
if os.path.exists(bak_path):
|
| 513 |
try:
|
| 514 |
shutil.copy2(bak_path, ws_path)
|
| 515 |
-
logging.info("Restored original workspace.py from backup.")
|
| 516 |
except Exception:
|
| 517 |
pass
|
| 518 |
|
|
@@ -855,9 +932,12 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
|
|
| 855 |
if not base_cfg:
|
| 856 |
raise gr.Error("Could not find a matching RT-DETRv2 config in the repo (S/M/M*/L/X).")
|
| 857 |
|
| 858 |
-
#
|
| 859 |
-
|
| 860 |
-
|
|
|
|
|
|
|
|
|
|
| 861 |
|
| 862 |
data_yaml = os.path.join(dataset_path, "data.yaml")
|
| 863 |
with open(data_yaml, "r", encoding="utf-8") as f:
|
|
@@ -889,7 +969,7 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
|
|
| 889 |
q = Queue()
|
| 890 |
def run_train():
|
| 891 |
try:
|
| 892 |
-
#
|
| 893 |
train_cwd = os.path.dirname(train_script)
|
| 894 |
_install_workspace_env_fallback(train_cwd)
|
| 895 |
|
|
|
|
| 458 |
"""))
|
| 459 |
return str(root)
|
| 460 |
|
| 461 |
+
# ---- OPTIONAL: legacy sitecustomize fallback (kept for belt & suspenders) ----
|
| 462 |
def _install_workspace_env_fallback(cwd_for_train: str, module_default: str = "rtdetrv2_pytorch.src"):
|
| 463 |
"""
|
| 464 |
Creates a sitecustomize.py in the training cwd that monkeypatches
|
| 465 |
+
rtdetrv2_pytorch.src.core.workspace.create to gracefully handle missing
|
| 466 |
+
cfg['_pymodule'] by falling back to $RTDETR_PYMODULE or module_default.
|
| 467 |
"""
|
| 468 |
sc_path = os.path.join(cwd_for_train, "sitecustomize.py")
|
| 469 |
code = textwrap.dedent(f"""
|
| 470 |
+
import os, importlib
|
| 471 |
try:
|
| 472 |
+
mod_path = os.environ.get("RTDETR_PYMODULE", "{module_default}")
|
| 473 |
ws_mod = importlib.import_module("rtdetrv2_pytorch.src.core.workspace")
|
| 474 |
_orig_create = ws_mod.create
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 475 |
def _safe_create(name, cfg, *args, **kwargs):
|
| 476 |
pm = None
|
| 477 |
try:
|
| 478 |
pm = cfg.get("_pymodule", None)
|
| 479 |
except Exception:
|
| 480 |
pm = None
|
| 481 |
+
if not pm:
|
| 482 |
+
pm = os.environ.get("RTDETR_PYMODULE", "{module_default}")
|
| 483 |
+
try:
|
| 484 |
+
importlib.import_module(pm)
|
| 485 |
+
except Exception:
|
| 486 |
+
pass
|
| 487 |
+
try:
|
| 488 |
+
cfg["_pymodule"] = pm
|
| 489 |
+
except Exception:
|
| 490 |
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 491 |
return _orig_create(name, cfg, *args, **kwargs)
|
|
|
|
| 492 |
ws_mod.create = _safe_create
|
| 493 |
except Exception:
|
| 494 |
pass
|
|
|
|
| 497 |
f.write(code)
|
| 498 |
return sc_path
|
| 499 |
|
| 500 |
+
# ---- NEW: direct, deterministic patch of workspace.create (V2) ---------------
|
| 501 |
+
def _patch_workspace_create(repo_root: str, module_default: str = "rtdetrv2_pytorch.src") -> str | None:
|
| 502 |
+
"""
|
| 503 |
+
Idempotently patches third_party/RT-DETRv2/rtdetrv2_pytorch/src/core/workspace.py
|
| 504 |
+
so that workspace.create() guarantees cfg['_pymodule'] is a *module object*,
|
| 505 |
+
not a string. If missing or string, it imports the module and writes it back.
|
| 506 |
+
Returns the path of the file patched (or None).
|
| 507 |
+
"""
|
| 508 |
+
ws_path = os.path.join(repo_root, "rtdetrv2_pytorch", "src", "core", "workspace.py")
|
| 509 |
+
if not os.path.isfile(ws_path):
|
| 510 |
+
return None
|
| 511 |
+
|
| 512 |
+
# If previously patched with V2, nothing to do
|
| 513 |
+
with open(ws_path, "r", encoding="utf-8") as f:
|
| 514 |
+
current_src = f.read()
|
| 515 |
+
bak_path = ws_path + ".bak"
|
| 516 |
+
if "_SAFE_CREATE_PATCH_V2" in current_src:
|
| 517 |
+
return ws_path
|
| 518 |
+
|
| 519 |
+
# If there was an older patch, restore from .bak to re-apply cleanly
|
| 520 |
+
if "_SAFE_CREATE_PATCH" in current_src and os.path.exists(bak_path):
|
| 521 |
+
try:
|
| 522 |
+
shutil.copy2(bak_path, ws_path)
|
| 523 |
+
except Exception:
|
| 524 |
+
pass
|
| 525 |
+
with open(ws_path, "r", encoding="utf-8") as f:
|
| 526 |
+
current_src = f.read()
|
| 527 |
+
|
| 528 |
+
# Backup original once
|
| 529 |
+
if not os.path.exists(bak_path):
|
| 530 |
+
try:
|
| 531 |
+
shutil.copy2(ws_path, bak_path)
|
| 532 |
+
except Exception:
|
| 533 |
+
pass
|
| 534 |
+
|
| 535 |
+
marker = "_SAFE_CREATE_PATCH_V2"
|
| 536 |
+
helper_code = f"""
|
| 537 |
+
# --- {marker}: begin ---
|
| 538 |
+
import importlib as _importlib
|
| 539 |
+
import types as _types
|
| 540 |
+
|
| 541 |
+
def _ensure_pymodule_object(_cfg, _default="{module_default}"):
|
| 542 |
+
pm = None
|
| 543 |
+
try:
|
| 544 |
+
pm = _cfg.get("_pymodule", None)
|
| 545 |
+
except Exception:
|
| 546 |
+
pm = None
|
| 547 |
+
if isinstance(pm, str) or pm is None:
|
| 548 |
+
name = pm if isinstance(pm, str) and pm.strip() else _default
|
| 549 |
+
try:
|
| 550 |
+
mod = _importlib.import_module(name)
|
| 551 |
+
except Exception:
|
| 552 |
+
mod = _importlib.import_module(_default)
|
| 553 |
+
try:
|
| 554 |
+
_cfg["_pymodule"] = mod
|
| 555 |
+
except Exception:
|
| 556 |
+
pass
|
| 557 |
+
return mod
|
| 558 |
+
if isinstance(pm, _types.ModuleType):
|
| 559 |
+
return pm
|
| 560 |
+
try:
|
| 561 |
+
mod = _importlib.import_module(_default)
|
| 562 |
+
_cfg["_pymodule"] = mod
|
| 563 |
+
return mod
|
| 564 |
+
except Exception:
|
| 565 |
+
return pm
|
| 566 |
+
# --- {marker}: end ---
|
| 567 |
+
"""
|
| 568 |
+
|
| 569 |
+
# Insert helper at top if missing
|
| 570 |
+
src = current_src
|
| 571 |
+
if marker not in src:
|
| 572 |
+
src = helper_code + src
|
| 573 |
+
|
| 574 |
+
# Inject normalization line right after def create(
|
| 575 |
+
inject_line = " cfg['_pymodule'] = _ensure_pymodule_object(cfg)\n"
|
| 576 |
+
lines = src.splitlines(keepends=True)
|
| 577 |
+
for i, line in enumerate(lines):
|
| 578 |
+
if line.strip().startswith("def create(") and "_ensure_pymodule_object" not in "".join(lines[i:i+5]):
|
| 579 |
+
lines.insert(i + 1, inject_line)
|
| 580 |
+
break
|
| 581 |
+
new_src = "".join(lines)
|
| 582 |
+
|
| 583 |
+
with open(ws_path, "w", encoding="utf-8") as f:
|
| 584 |
+
f.write(new_src)
|
| 585 |
+
return ws_path
|
| 586 |
+
|
| 587 |
def _unpatch_workspace_create(repo_root: str):
|
|
|
|
| 588 |
ws_path = os.path.join(repo_root, "rtdetrv2_pytorch", "src", "core", "workspace.py")
|
| 589 |
bak_path = ws_path + ".bak"
|
| 590 |
if os.path.exists(bak_path):
|
| 591 |
try:
|
| 592 |
shutil.copy2(bak_path, ws_path)
|
|
|
|
| 593 |
except Exception:
|
| 594 |
pass
|
| 595 |
|
|
|
|
| 932 |
if not base_cfg:
|
| 933 |
raise gr.Error("Could not find a matching RT-DETRv2 config in the repo (S/M/M*/L/X).")
|
| 934 |
|
| 935 |
+
# Apply deterministic patch so workspace.create never fails on '_pymodule'
|
| 936 |
+
patched = _patch_workspace_create(REPO_DIR, module_default="rtdetrv2_pytorch.src")
|
| 937 |
+
if patched:
|
| 938 |
+
logging.info(f"Patched workspace.create at: {patched}")
|
| 939 |
+
else:
|
| 940 |
+
logging.warning("workspace.create patch not applied (file missing or atypical).")
|
| 941 |
|
| 942 |
data_yaml = os.path.join(dataset_path, "data.yaml")
|
| 943 |
with open(data_yaml, "r", encoding="utf-8") as f:
|
|
|
|
| 969 |
q = Queue()
|
| 970 |
def run_train():
|
| 971 |
try:
|
| 972 |
+
# Optional: legacy sitecustomize helper (kept; patch above already handles it)
|
| 973 |
train_cwd = os.path.dirname(train_script)
|
| 974 |
_install_workspace_env_fallback(train_cwd)
|
| 975 |
|