wuhp commited on
Commit
891078c
·
verified ·
1 Parent(s): 5c0661b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +108 -28
app.py CHANGED
@@ -458,45 +458,37 @@ def _install_supervisely_logger_shim():
458
  """))
459
  return str(root)
460
 
461
- # ---- Runtime sitecustomize wrapper (SAFE) ------------------------------------
462
  def _install_workspace_env_fallback(cwd_for_train: str, module_default: str = "rtdetrv2_pytorch.src"):
463
  """
464
  Creates a sitecustomize.py in the training cwd that monkeypatches
465
- rtdetrv2_pytorch.src.core.workspace.create to ensure cfg['_pymodule'] exists
466
- and is a *module object* before delegating to the original create().
467
  """
468
  sc_path = os.path.join(cwd_for_train, "sitecustomize.py")
469
  code = textwrap.dedent(f"""
470
- import os, importlib, types
471
  try:
 
472
  ws_mod = importlib.import_module("rtdetrv2_pytorch.src.core.workspace")
473
  _orig_create = ws_mod.create
474
-
475
- def _ensure_module_object(name_or_none: str) -> types.ModuleType:
476
- target = name_or_none.strip() if isinstance(name_or_none, str) and name_or_none.strip() else os.environ.get("RTDETR_PYMODULE", "{module_default}")
477
- try:
478
- return importlib.import_module(target)
479
- except Exception:
480
- return importlib.import_module("{module_default}")
481
-
482
  def _safe_create(name, cfg, *args, **kwargs):
483
  pm = None
484
  try:
485
  pm = cfg.get("_pymodule", None)
486
  except Exception:
487
  pm = None
488
-
489
- if isinstance(pm, types.ModuleType):
 
 
 
 
 
 
 
490
  pass
491
- else:
492
- mod = _ensure_module_object(pm)
493
- try:
494
- cfg["_pymodule"] = mod
495
- except Exception:
496
- pass
497
-
498
  return _orig_create(name, cfg, *args, **kwargs)
499
-
500
  ws_mod.create = _safe_create
501
  except Exception:
502
  pass
@@ -505,14 +497,99 @@ def _install_workspace_env_fallback(cwd_for_train: str, module_default: str = "r
505
  f.write(code)
506
  return sc_path
507
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
508
  def _unpatch_workspace_create(repo_root: str):
509
- """If a previous run modified workspace.py, restore it from its .bak (best-effort)."""
510
  ws_path = os.path.join(repo_root, "rtdetrv2_pytorch", "src", "core", "workspace.py")
511
  bak_path = ws_path + ".bak"
512
  if os.path.exists(bak_path):
513
  try:
514
  shutil.copy2(bak_path, ws_path)
515
- logging.info("Restored original workspace.py from backup.")
516
  except Exception:
517
  pass
518
 
@@ -855,9 +932,12 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
855
  if not base_cfg:
856
  raise gr.Error("Could not find a matching RT-DETRv2 config in the repo (S/M/M*/L/X).")
857
 
858
- # Ensure we are NOT using any previous on-disk patch. Rely on runtime wrapper.
859
- _unpatch_workspace_create(REPO_DIR)
860
- logging.info("Using runtime sitecustomize monkey-patch for workspace.create; no on-disk edits.")
 
 
 
861
 
862
  data_yaml = os.path.join(dataset_path, "data.yaml")
863
  with open(data_yaml, "r", encoding="utf-8") as f:
@@ -889,7 +969,7 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
889
  q = Queue()
890
  def run_train():
891
  try:
892
- # Runtime helper that ensures cfg['_pymodule'] is a module object.
893
  train_cwd = os.path.dirname(train_script)
894
  _install_workspace_env_fallback(train_cwd)
895
 
 
458
  """))
459
  return str(root)
460
 
461
+ # ---- OPTIONAL: legacy sitecustomize fallback (kept for belt & suspenders) ----
462
  def _install_workspace_env_fallback(cwd_for_train: str, module_default: str = "rtdetrv2_pytorch.src"):
463
  """
464
  Creates a sitecustomize.py in the training cwd that monkeypatches
465
+ rtdetrv2_pytorch.src.core.workspace.create to gracefully handle missing
466
+ cfg['_pymodule'] by falling back to $RTDETR_PYMODULE or module_default.
467
  """
468
  sc_path = os.path.join(cwd_for_train, "sitecustomize.py")
469
  code = textwrap.dedent(f"""
470
+ import os, importlib
471
  try:
472
+ mod_path = os.environ.get("RTDETR_PYMODULE", "{module_default}")
473
  ws_mod = importlib.import_module("rtdetrv2_pytorch.src.core.workspace")
474
  _orig_create = ws_mod.create
 
 
 
 
 
 
 
 
475
  def _safe_create(name, cfg, *args, **kwargs):
476
  pm = None
477
  try:
478
  pm = cfg.get("_pymodule", None)
479
  except Exception:
480
  pm = None
481
+ if not pm:
482
+ pm = os.environ.get("RTDETR_PYMODULE", "{module_default}")
483
+ try:
484
+ importlib.import_module(pm)
485
+ except Exception:
486
+ pass
487
+ try:
488
+ cfg["_pymodule"] = pm
489
+ except Exception:
490
  pass
 
 
 
 
 
 
 
491
  return _orig_create(name, cfg, *args, **kwargs)
 
492
  ws_mod.create = _safe_create
493
  except Exception:
494
  pass
 
497
  f.write(code)
498
  return sc_path
499
 
500
+ # ---- NEW: direct, deterministic patch of workspace.create (V2) ---------------
501
+ def _patch_workspace_create(repo_root: str, module_default: str = "rtdetrv2_pytorch.src") -> str | None:
502
+ """
503
+ Idempotently patches third_party/RT-DETRv2/rtdetrv2_pytorch/src/core/workspace.py
504
+ so that workspace.create() guarantees cfg['_pymodule'] is a *module object*,
505
+ not a string. If missing or string, it imports the module and writes it back.
506
+ Returns the path of the file patched (or None).
507
+ """
508
+ ws_path = os.path.join(repo_root, "rtdetrv2_pytorch", "src", "core", "workspace.py")
509
+ if not os.path.isfile(ws_path):
510
+ return None
511
+
512
+ # If previously patched with V2, nothing to do
513
+ with open(ws_path, "r", encoding="utf-8") as f:
514
+ current_src = f.read()
515
+ bak_path = ws_path + ".bak"
516
+ if "_SAFE_CREATE_PATCH_V2" in current_src:
517
+ return ws_path
518
+
519
+ # If there was an older patch, restore from .bak to re-apply cleanly
520
+ if "_SAFE_CREATE_PATCH" in current_src and os.path.exists(bak_path):
521
+ try:
522
+ shutil.copy2(bak_path, ws_path)
523
+ except Exception:
524
+ pass
525
+ with open(ws_path, "r", encoding="utf-8") as f:
526
+ current_src = f.read()
527
+
528
+ # Backup original once
529
+ if not os.path.exists(bak_path):
530
+ try:
531
+ shutil.copy2(ws_path, bak_path)
532
+ except Exception:
533
+ pass
534
+
535
+ marker = "_SAFE_CREATE_PATCH_V2"
536
+ helper_code = f"""
537
+ # --- {marker}: begin ---
538
+ import importlib as _importlib
539
+ import types as _types
540
+
541
+ def _ensure_pymodule_object(_cfg, _default="{module_default}"):
542
+ pm = None
543
+ try:
544
+ pm = _cfg.get("_pymodule", None)
545
+ except Exception:
546
+ pm = None
547
+ if isinstance(pm, str) or pm is None:
548
+ name = pm if isinstance(pm, str) and pm.strip() else _default
549
+ try:
550
+ mod = _importlib.import_module(name)
551
+ except Exception:
552
+ mod = _importlib.import_module(_default)
553
+ try:
554
+ _cfg["_pymodule"] = mod
555
+ except Exception:
556
+ pass
557
+ return mod
558
+ if isinstance(pm, _types.ModuleType):
559
+ return pm
560
+ try:
561
+ mod = _importlib.import_module(_default)
562
+ _cfg["_pymodule"] = mod
563
+ return mod
564
+ except Exception:
565
+ return pm
566
+ # --- {marker}: end ---
567
+ """
568
+
569
+ # Insert helper at top if missing
570
+ src = current_src
571
+ if marker not in src:
572
+ src = helper_code + src
573
+
574
+ # Inject normalization line right after def create(
575
+ inject_line = " cfg['_pymodule'] = _ensure_pymodule_object(cfg)\n"
576
+ lines = src.splitlines(keepends=True)
577
+ for i, line in enumerate(lines):
578
+ if line.strip().startswith("def create(") and "_ensure_pymodule_object" not in "".join(lines[i:i+5]):
579
+ lines.insert(i + 1, inject_line)
580
+ break
581
+ new_src = "".join(lines)
582
+
583
+ with open(ws_path, "w", encoding="utf-8") as f:
584
+ f.write(new_src)
585
+ return ws_path
586
+
587
  def _unpatch_workspace_create(repo_root: str):
 
588
  ws_path = os.path.join(repo_root, "rtdetrv2_pytorch", "src", "core", "workspace.py")
589
  bak_path = ws_path + ".bak"
590
  if os.path.exists(bak_path):
591
  try:
592
  shutil.copy2(bak_path, ws_path)
 
593
  except Exception:
594
  pass
595
 
 
932
  if not base_cfg:
933
  raise gr.Error("Could not find a matching RT-DETRv2 config in the repo (S/M/M*/L/X).")
934
 
935
+ # Apply deterministic patch so workspace.create never fails on '_pymodule'
936
+ patched = _patch_workspace_create(REPO_DIR, module_default="rtdetrv2_pytorch.src")
937
+ if patched:
938
+ logging.info(f"Patched workspace.create at: {patched}")
939
+ else:
940
+ logging.warning("workspace.create patch not applied (file missing or atypical).")
941
 
942
  data_yaml = os.path.join(dataset_path, "data.yaml")
943
  with open(data_yaml, "r", encoding="utf-8") as f:
 
969
  q = Queue()
970
  def run_train():
971
  try:
972
+ # Optional: legacy sitecustomize helper (kept; patch above already handles it)
973
  train_cwd = os.path.dirname(train_script)
974
  _install_workspace_env_fallback(train_cwd)
975