Update app.py
Browse files
app.py
CHANGED
|
@@ -458,38 +458,57 @@ def _install_supervisely_logger_shim():
|
|
| 458 |
"""))
|
| 459 |
return str(root)
|
| 460 |
|
| 461 |
-
# ----
|
| 462 |
-
def
|
| 463 |
"""
|
| 464 |
-
|
| 465 |
-
rtdetrv2_pytorch.src.core.workspace.create
|
| 466 |
-
|
|
|
|
| 467 |
"""
|
| 468 |
sc_path = os.path.join(cwd_for_train, "sitecustomize.py")
|
| 469 |
code = textwrap.dedent(f"""
|
| 470 |
-
import os, importlib
|
| 471 |
try:
|
| 472 |
-
|
| 473 |
ws_mod = importlib.import_module("rtdetrv2_pytorch.src.core.workspace")
|
| 474 |
_orig_create = ws_mod.create
|
| 475 |
-
|
|
|
|
| 476 |
pm = None
|
| 477 |
try:
|
| 478 |
pm = cfg.get("_pymodule", None)
|
| 479 |
except Exception:
|
| 480 |
pm = None
|
| 481 |
-
if
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 487 |
try:
|
| 488 |
-
|
|
|
|
|
|
|
| 489 |
except Exception:
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 493 |
except Exception:
|
| 494 |
pass
|
| 495 |
""")
|
|
@@ -497,92 +516,13 @@ def _install_workspace_env_fallback(cwd_for_train: str, module_default: str = "r
|
|
| 497 |
f.write(code)
|
| 498 |
return sc_path
|
| 499 |
|
| 500 |
-
# ----
|
| 501 |
def _patch_workspace_create(repo_root: str, module_default: str = "rtdetrv2_pytorch.src") -> str | None:
|
| 502 |
"""
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
not a string. If missing or string, it imports the module and writes it back.
|
| 506 |
-
Returns the path of the file patched (or None).
|
| 507 |
"""
|
| 508 |
-
|
| 509 |
-
if not os.path.isfile(ws_path):
|
| 510 |
-
return None
|
| 511 |
-
|
| 512 |
-
# If previously patched with V2, nothing to do
|
| 513 |
-
with open(ws_path, "r", encoding="utf-8") as f:
|
| 514 |
-
current_src = f.read()
|
| 515 |
-
bak_path = ws_path + ".bak"
|
| 516 |
-
if "_SAFE_CREATE_PATCH_V2" in current_src:
|
| 517 |
-
return ws_path
|
| 518 |
-
|
| 519 |
-
# If there was an older patch, restore from .bak to re-apply cleanly
|
| 520 |
-
if "_SAFE_CREATE_PATCH" in current_src and os.path.exists(bak_path):
|
| 521 |
-
try:
|
| 522 |
-
shutil.copy2(bak_path, ws_path)
|
| 523 |
-
except Exception:
|
| 524 |
-
pass
|
| 525 |
-
with open(ws_path, "r", encoding="utf-8") as f:
|
| 526 |
-
current_src = f.read()
|
| 527 |
-
|
| 528 |
-
# Backup original once
|
| 529 |
-
if not os.path.exists(bak_path):
|
| 530 |
-
try:
|
| 531 |
-
shutil.copy2(ws_path, bak_path)
|
| 532 |
-
except Exception:
|
| 533 |
-
pass
|
| 534 |
-
|
| 535 |
-
marker = "_SAFE_CREATE_PATCH_V2"
|
| 536 |
-
helper_code = f"""
|
| 537 |
-
# --- {marker}: begin ---
|
| 538 |
-
import importlib as _importlib
|
| 539 |
-
import types as _types
|
| 540 |
-
|
| 541 |
-
def _ensure_pymodule_object(_cfg, _default="{module_default}"):
|
| 542 |
-
pm = None
|
| 543 |
-
try:
|
| 544 |
-
pm = _cfg.get("_pymodule", None)
|
| 545 |
-
except Exception:
|
| 546 |
-
pm = None
|
| 547 |
-
if isinstance(pm, str) or pm is None:
|
| 548 |
-
name = pm if isinstance(pm, str) and pm.strip() else _default
|
| 549 |
-
try:
|
| 550 |
-
mod = _importlib.import_module(name)
|
| 551 |
-
except Exception:
|
| 552 |
-
mod = _importlib.import_module(_default)
|
| 553 |
-
try:
|
| 554 |
-
_cfg["_pymodule"] = mod
|
| 555 |
-
except Exception:
|
| 556 |
-
pass
|
| 557 |
-
return mod
|
| 558 |
-
if isinstance(pm, _types.ModuleType):
|
| 559 |
-
return pm
|
| 560 |
-
try:
|
| 561 |
-
mod = _importlib.import_module(_default)
|
| 562 |
-
_cfg["_pymodule"] = mod
|
| 563 |
-
return mod
|
| 564 |
-
except Exception:
|
| 565 |
-
return pm
|
| 566 |
-
# --- {marker}: end ---
|
| 567 |
-
"""
|
| 568 |
-
|
| 569 |
-
# Insert helper at top if missing
|
| 570 |
-
src = current_src
|
| 571 |
-
if marker not in src:
|
| 572 |
-
src = helper_code + src
|
| 573 |
-
|
| 574 |
-
# Inject normalization line right after def create(
|
| 575 |
-
inject_line = " cfg['_pymodule'] = _ensure_pymodule_object(cfg)\n"
|
| 576 |
-
lines = src.splitlines(keepends=True)
|
| 577 |
-
for i, line in enumerate(lines):
|
| 578 |
-
if line.strip().startswith("def create(") and "_ensure_pymodule_object" not in "".join(lines[i:i+5]):
|
| 579 |
-
lines.insert(i + 1, inject_line)
|
| 580 |
-
break
|
| 581 |
-
new_src = "".join(lines)
|
| 582 |
-
|
| 583 |
-
with open(ws_path, "w", encoding="utf-8") as f:
|
| 584 |
-
f.write(new_src)
|
| 585 |
-
return ws_path
|
| 586 |
|
| 587 |
def _unpatch_workspace_create(repo_root: str):
|
| 588 |
ws_path = os.path.join(repo_root, "rtdetrv2_pytorch", "src", "core", "workspace.py")
|
|
@@ -932,12 +872,7 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
|
|
| 932 |
if not base_cfg:
|
| 933 |
raise gr.Error("Could not find a matching RT-DETRv2 config in the repo (S/M/M*/L/X).")
|
| 934 |
|
| 935 |
-
#
|
| 936 |
-
patched = _patch_workspace_create(REPO_DIR, module_default="rtdetrv2_pytorch.src")
|
| 937 |
-
if patched:
|
| 938 |
-
logging.info(f"Patched workspace.create at: {patched}")
|
| 939 |
-
else:
|
| 940 |
-
logging.warning("workspace.create patch not applied (file missing or atypical).")
|
| 941 |
|
| 942 |
data_yaml = os.path.join(dataset_path, "data.yaml")
|
| 943 |
with open(data_yaml, "r", encoding="utf-8") as f:
|
|
@@ -969,9 +904,10 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
|
|
| 969 |
q = Queue()
|
| 970 |
def run_train():
|
| 971 |
try:
|
| 972 |
-
# Optional: legacy sitecustomize helper (kept; patch above already handles it)
|
| 973 |
train_cwd = os.path.dirname(train_script)
|
| 974 |
-
|
|
|
|
|
|
|
| 975 |
|
| 976 |
env = os.environ.copy()
|
| 977 |
# Make sure repo code can be imported
|
|
|
|
| 458 |
"""))
|
| 459 |
return str(root)
|
| 460 |
|
| 461 |
+
# ---- NEW: kwargs-aware sitecustomize shim (safe, non-invasive) ---------------
|
| 462 |
+
def _install_workspace_shim_v3(cwd_for_train: str, module_default: str = "rtdetrv2_pytorch.src"):
|
| 463 |
"""
|
| 464 |
+
Writes a sitecustomize.py that monkeypatches
|
| 465 |
+
rtdetrv2_pytorch.src.core.workspace.create so it works with the real signature:
|
| 466 |
+
def create(name, **kwargs):
|
| 467 |
+
The shim ensures kwargs['cfg'] is a dict, then guarantees cfg['_pymodule'] is a *module object*.
|
| 468 |
"""
|
| 469 |
sc_path = os.path.join(cwd_for_train, "sitecustomize.py")
|
| 470 |
code = textwrap.dedent(f"""
|
| 471 |
+
import os, importlib, types
|
| 472 |
try:
|
| 473 |
+
mod_default = os.environ.get("RTDETR_PYMODULE", "{module_default}") or "{module_default}"
|
| 474 |
ws_mod = importlib.import_module("rtdetrv2_pytorch.src.core.workspace")
|
| 475 |
_orig_create = ws_mod.create
|
| 476 |
+
|
| 477 |
+
def _ensure_pymodule_object(cfg):
|
| 478 |
pm = None
|
| 479 |
try:
|
| 480 |
pm = cfg.get("_pymodule", None)
|
| 481 |
except Exception:
|
| 482 |
pm = None
|
| 483 |
+
if isinstance(pm, str) or pm is None:
|
| 484 |
+
name = pm.strip() if isinstance(pm, str) and pm.strip() else mod_default
|
| 485 |
+
try:
|
| 486 |
+
mod = importlib.import_module(name)
|
| 487 |
+
except Exception:
|
| 488 |
+
mod = importlib.import_module(mod_default)
|
| 489 |
+
try:
|
| 490 |
+
cfg["_pymodule"] = mod
|
| 491 |
+
except Exception:
|
| 492 |
+
pass
|
| 493 |
+
return mod
|
| 494 |
+
if isinstance(pm, types.ModuleType):
|
| 495 |
+
return pm
|
| 496 |
try:
|
| 497 |
+
mod = importlib.import_module(mod_default)
|
| 498 |
+
cfg["_pymodule"] = mod
|
| 499 |
+
return mod
|
| 500 |
except Exception:
|
| 501 |
+
return pm
|
| 502 |
+
|
| 503 |
+
def create(name, **kwargs):
|
| 504 |
+
cfg = kwargs.get("cfg")
|
| 505 |
+
if not isinstance(cfg, dict):
|
| 506 |
+
cfg = {} if cfg is None else dict(cfg)
|
| 507 |
+
kwargs["cfg"] = cfg
|
| 508 |
+
_ensure_pymodule_object(cfg)
|
| 509 |
+
return _orig_create(name, **kwargs)
|
| 510 |
+
|
| 511 |
+
ws_mod.create = create
|
| 512 |
except Exception:
|
| 513 |
pass
|
| 514 |
""")
|
|
|
|
| 516 |
f.write(code)
|
| 517 |
return sc_path
|
| 518 |
|
| 519 |
+
# ---- Deprecated: on-disk workspace patch (no-op now) -------------------------
|
| 520 |
def _patch_workspace_create(repo_root: str, module_default: str = "rtdetrv2_pytorch.src") -> str | None:
|
| 521 |
"""
|
| 522 |
+
Deprecated: we no longer edit third-party files on disk.
|
| 523 |
+
The shim in sitecustomize.py handles cfg/_pymodule safely.
|
|
|
|
|
|
|
| 524 |
"""
|
| 525 |
+
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 526 |
|
| 527 |
def _unpatch_workspace_create(repo_root: str):
|
| 528 |
ws_path = os.path.join(repo_root, "rtdetrv2_pytorch", "src", "core", "workspace.py")
|
|
|
|
| 872 |
if not base_cfg:
|
| 873 |
raise gr.Error("Could not find a matching RT-DETRv2 config in the repo (S/M/M*/L/X).")
|
| 874 |
|
| 875 |
+
# No longer patch files on disk; we use a runtime shim instead.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 876 |
|
| 877 |
data_yaml = os.path.join(dataset_path, "data.yaml")
|
| 878 |
with open(data_yaml, "r", encoding="utf-8") as f:
|
|
|
|
| 904 |
q = Queue()
|
| 905 |
def run_train():
|
| 906 |
try:
|
|
|
|
| 907 |
train_cwd = os.path.dirname(train_script)
|
| 908 |
+
|
| 909 |
+
# Install kwargs-aware sitecustomize shim (safe, non-invasive)
|
| 910 |
+
_install_workspace_shim_v3(train_cwd, module_default="rtdetrv2_pytorch.src")
|
| 911 |
|
| 912 |
env = os.environ.copy()
|
| 913 |
# Make sure repo code can be imported
|