wuhp commited on
Commit
99ee02b
·
verified ·
1 Parent(s): 891078c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -109
app.py CHANGED
@@ -458,38 +458,57 @@ def _install_supervisely_logger_shim():
458
  """))
459
  return str(root)
460
 
461
- # ---- OPTIONAL: legacy sitecustomize fallback (kept for belt & suspenders) ----
462
- def _install_workspace_env_fallback(cwd_for_train: str, module_default: str = "rtdetrv2_pytorch.src"):
463
  """
464
- Creates a sitecustomize.py in the training cwd that monkeypatches
465
- rtdetrv2_pytorch.src.core.workspace.create to gracefully handle missing
466
- cfg['_pymodule'] by falling back to $RTDETR_PYMODULE or module_default.
 
467
  """
468
  sc_path = os.path.join(cwd_for_train, "sitecustomize.py")
469
  code = textwrap.dedent(f"""
470
- import os, importlib
471
  try:
472
- mod_path = os.environ.get("RTDETR_PYMODULE", "{module_default}")
473
  ws_mod = importlib.import_module("rtdetrv2_pytorch.src.core.workspace")
474
  _orig_create = ws_mod.create
475
- def _safe_create(name, cfg, *args, **kwargs):
 
476
  pm = None
477
  try:
478
  pm = cfg.get("_pymodule", None)
479
  except Exception:
480
  pm = None
481
- if not pm:
482
- pm = os.environ.get("RTDETR_PYMODULE", "{module_default}")
483
- try:
484
- importlib.import_module(pm)
485
- except Exception:
486
- pass
 
 
 
 
 
 
 
487
  try:
488
- cfg["_pymodule"] = pm
 
 
489
  except Exception:
490
- pass
491
- return _orig_create(name, cfg, *args, **kwargs)
492
- ws_mod.create = _safe_create
 
 
 
 
 
 
 
 
493
  except Exception:
494
  pass
495
  """)
@@ -497,92 +516,13 @@ def _install_workspace_env_fallback(cwd_for_train: str, module_default: str = "r
497
  f.write(code)
498
  return sc_path
499
 
500
- # ---- NEW: direct, deterministic patch of workspace.create (V2) ---------------
501
  def _patch_workspace_create(repo_root: str, module_default: str = "rtdetrv2_pytorch.src") -> str | None:
502
  """
503
- Idempotently patches third_party/RT-DETRv2/rtdetrv2_pytorch/src/core/workspace.py
504
- so that workspace.create() guarantees cfg['_pymodule'] is a *module object*,
505
- not a string. If missing or string, it imports the module and writes it back.
506
- Returns the path of the file patched (or None).
507
  """
508
- ws_path = os.path.join(repo_root, "rtdetrv2_pytorch", "src", "core", "workspace.py")
509
- if not os.path.isfile(ws_path):
510
- return None
511
-
512
- # If previously patched with V2, nothing to do
513
- with open(ws_path, "r", encoding="utf-8") as f:
514
- current_src = f.read()
515
- bak_path = ws_path + ".bak"
516
- if "_SAFE_CREATE_PATCH_V2" in current_src:
517
- return ws_path
518
-
519
- # If there was an older patch, restore from .bak to re-apply cleanly
520
- if "_SAFE_CREATE_PATCH" in current_src and os.path.exists(bak_path):
521
- try:
522
- shutil.copy2(bak_path, ws_path)
523
- except Exception:
524
- pass
525
- with open(ws_path, "r", encoding="utf-8") as f:
526
- current_src = f.read()
527
-
528
- # Backup original once
529
- if not os.path.exists(bak_path):
530
- try:
531
- shutil.copy2(ws_path, bak_path)
532
- except Exception:
533
- pass
534
-
535
- marker = "_SAFE_CREATE_PATCH_V2"
536
- helper_code = f"""
537
- # --- {marker}: begin ---
538
- import importlib as _importlib
539
- import types as _types
540
-
541
- def _ensure_pymodule_object(_cfg, _default="{module_default}"):
542
- pm = None
543
- try:
544
- pm = _cfg.get("_pymodule", None)
545
- except Exception:
546
- pm = None
547
- if isinstance(pm, str) or pm is None:
548
- name = pm if isinstance(pm, str) and pm.strip() else _default
549
- try:
550
- mod = _importlib.import_module(name)
551
- except Exception:
552
- mod = _importlib.import_module(_default)
553
- try:
554
- _cfg["_pymodule"] = mod
555
- except Exception:
556
- pass
557
- return mod
558
- if isinstance(pm, _types.ModuleType):
559
- return pm
560
- try:
561
- mod = _importlib.import_module(_default)
562
- _cfg["_pymodule"] = mod
563
- return mod
564
- except Exception:
565
- return pm
566
- # --- {marker}: end ---
567
- """
568
-
569
- # Insert helper at top if missing
570
- src = current_src
571
- if marker not in src:
572
- src = helper_code + src
573
-
574
- # Inject normalization line right after def create(
575
- inject_line = " cfg['_pymodule'] = _ensure_pymodule_object(cfg)\n"
576
- lines = src.splitlines(keepends=True)
577
- for i, line in enumerate(lines):
578
- if line.strip().startswith("def create(") and "_ensure_pymodule_object" not in "".join(lines[i:i+5]):
579
- lines.insert(i + 1, inject_line)
580
- break
581
- new_src = "".join(lines)
582
-
583
- with open(ws_path, "w", encoding="utf-8") as f:
584
- f.write(new_src)
585
- return ws_path
586
 
587
  def _unpatch_workspace_create(repo_root: str):
588
  ws_path = os.path.join(repo_root, "rtdetrv2_pytorch", "src", "core", "workspace.py")
@@ -932,12 +872,7 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
932
  if not base_cfg:
933
  raise gr.Error("Could not find a matching RT-DETRv2 config in the repo (S/M/M*/L/X).")
934
 
935
- # Apply deterministic patch so workspace.create never fails on '_pymodule'
936
- patched = _patch_workspace_create(REPO_DIR, module_default="rtdetrv2_pytorch.src")
937
- if patched:
938
- logging.info(f"Patched workspace.create at: {patched}")
939
- else:
940
- logging.warning("workspace.create patch not applied (file missing or atypical).")
941
 
942
  data_yaml = os.path.join(dataset_path, "data.yaml")
943
  with open(data_yaml, "r", encoding="utf-8") as f:
@@ -969,9 +904,10 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
969
  q = Queue()
970
  def run_train():
971
  try:
972
- # Optional: legacy sitecustomize helper (kept; patch above already handles it)
973
  train_cwd = os.path.dirname(train_script)
974
- _install_workspace_env_fallback(train_cwd)
 
 
975
 
976
  env = os.environ.copy()
977
  # Make sure repo code can be imported
 
458
  """))
459
  return str(root)
460
 
461
+ # ---- NEW: kwargs-aware sitecustomize shim (safe, non-invasive) ---------------
462
+ def _install_workspace_shim_v3(cwd_for_train: str, module_default: str = "rtdetrv2_pytorch.src"):
463
  """
464
+ Writes a sitecustomize.py that monkeypatches
465
+ rtdetrv2_pytorch.src.core.workspace.create so it works with the real signature:
466
+ def create(name, **kwargs):
467
+ The shim ensures kwargs['cfg'] is a dict, then guarantees cfg['_pymodule'] is a *module object*.
468
  """
469
  sc_path = os.path.join(cwd_for_train, "sitecustomize.py")
470
  code = textwrap.dedent(f"""
471
+ import os, importlib, types
472
  try:
473
+ mod_default = os.environ.get("RTDETR_PYMODULE", "{module_default}") or "{module_default}"
474
  ws_mod = importlib.import_module("rtdetrv2_pytorch.src.core.workspace")
475
  _orig_create = ws_mod.create
476
+
477
+ def _ensure_pymodule_object(cfg):
478
  pm = None
479
  try:
480
  pm = cfg.get("_pymodule", None)
481
  except Exception:
482
  pm = None
483
+ if isinstance(pm, str) or pm is None:
484
+ name = pm.strip() if isinstance(pm, str) and pm.strip() else mod_default
485
+ try:
486
+ mod = importlib.import_module(name)
487
+ except Exception:
488
+ mod = importlib.import_module(mod_default)
489
+ try:
490
+ cfg["_pymodule"] = mod
491
+ except Exception:
492
+ pass
493
+ return mod
494
+ if isinstance(pm, types.ModuleType):
495
+ return pm
496
  try:
497
+ mod = importlib.import_module(mod_default)
498
+ cfg["_pymodule"] = mod
499
+ return mod
500
  except Exception:
501
+ return pm
502
+
503
+ def create(name, **kwargs):
504
+ cfg = kwargs.get("cfg")
505
+ if not isinstance(cfg, dict):
506
+ cfg = {} if cfg is None else dict(cfg)
507
+ kwargs["cfg"] = cfg
508
+ _ensure_pymodule_object(cfg)
509
+ return _orig_create(name, **kwargs)
510
+
511
+ ws_mod.create = create
512
  except Exception:
513
  pass
514
  """)
 
516
  f.write(code)
517
  return sc_path
518
 
519
+ # ---- Deprecated: on-disk workspace patch (no-op now) -------------------------
520
  def _patch_workspace_create(repo_root: str, module_default: str = "rtdetrv2_pytorch.src") -> str | None:
521
  """
522
+ Deprecated: we no longer edit third-party files on disk.
523
+ The shim in sitecustomize.py handles cfg/_pymodule safely.
 
 
524
  """
525
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
526
 
527
  def _unpatch_workspace_create(repo_root: str):
528
  ws_path = os.path.join(repo_root, "rtdetrv2_pytorch", "src", "core", "workspace.py")
 
872
  if not base_cfg:
873
  raise gr.Error("Could not find a matching RT-DETRv2 config in the repo (S/M/M*/L/X).")
874
 
875
+ # No longer patch files on disk; we use a runtime shim instead.
 
 
 
 
 
876
 
877
  data_yaml = os.path.join(dataset_path, "data.yaml")
878
  with open(data_yaml, "r", encoding="utf-8") as f:
 
904
  q = Queue()
905
  def run_train():
906
  try:
 
907
  train_cwd = os.path.dirname(train_script)
908
+
909
+ # Install kwargs-aware sitecustomize shim (safe, non-invasive)
910
+ _install_workspace_shim_v3(train_cwd, module_default="rtdetrv2_pytorch.src")
911
 
912
  env = os.environ.copy()
913
  # Make sure repo code can be imported