wuhp commited on
Commit
0d25d2a
·
verified ·
1 Parent(s): a71c515

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -40
app.py CHANGED
@@ -459,65 +459,91 @@ def _install_supervisely_logger_shim():
459
  """))
460
  return str(root)
461
 
462
- # ---- NEW: signature-agnostic sitecustomize shim (safe, non-invasive) ---------
463
  def _install_workspace_shim_v3(dest_dir: str, module_default: str = "rtdetrv2_pytorch.src"):
464
  """
465
- sitecustomize shim that patches rtdetrv2_pytorch.src.core.workspace.create
466
- to be robust for BOTH call styles:
467
- - create(name, cfg, **extras) (positional cfg)
468
- - create(name, cfg=<dict>, **extras) (keyword cfg)
469
- It guarantees cfg is a dict and cfg['_pymodule'] is a *module object*.
470
  """
471
  os.makedirs(dest_dir, exist_ok=True)
472
  sc_path = os.path.join(dest_dir, "sitecustomize.py")
473
 
474
- # Use Template so braces in code remain literal; only $module_default is substituted.
475
  tmpl = Template(r"""
476
- import os, importlib, types
477
- try:
478
- MOD_DEFAULT = os.environ.get("RTDETR_PYMODULE", "$module_default") or "$module_default"
479
- ws_mod = importlib.import_module("rtdetrv2_pytorch.src.core.workspace")
480
- _orig_create = ws_mod.create
481
-
482
- def _ensure_pymodule_object(cfg: dict):
483
- pm = cfg.get("_pymodule", None)
484
- if isinstance(pm, types.ModuleType):
485
- return pm
486
- name = (pm or "").strip() if isinstance(pm, str) else MOD_DEFAULT
487
- if not name:
488
- name = MOD_DEFAULT
489
- try:
490
- mod = importlib.import_module(name)
491
- except Exception:
492
- mod = importlib.import_module(MOD_DEFAULT)
493
- cfg["_pymodule"] = mod
494
- return mod
495
 
 
 
 
 
496
  def create(name, *args, **kwargs):
497
- # Accept both positional and keyword cfg
498
  if args:
499
  args = list(args)
500
  cfg = args[0]
501
  else:
502
  cfg = kwargs.get("cfg", None)
503
-
504
  if not isinstance(cfg, dict):
505
  cfg = {} if cfg is None else dict(cfg)
506
-
507
  _ensure_pymodule_object(cfg)
508
-
509
  if args:
510
  args[0] = cfg
511
  args = tuple(args)
512
  else:
513
  kwargs["cfg"] = cfg
514
-
515
  return _orig_create(name, *args, **kwargs)
516
-
517
  ws_mod.create = create
518
- except Exception:
519
- # Never block training on shim failure
520
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
521
  """)
522
  code = tmpl.substitute(module_default=module_default)
523
  with open(sc_path, "w", encoding="utf-8") as f:
@@ -657,12 +683,16 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
657
  cfg.setdefault("device", "")
658
  cfg["find_unused_parameters"] = False
659
 
660
- # Remove COCO dataset include so it can't override our dataset paths later
661
  if "__include__" in cfg and isinstance(cfg["__include__"], list):
662
- cfg["__include__"] = [
663
- p for p in cfg["__include__"]
664
- if not (isinstance(p, str) and "configs/dataset/coco" in p.replace("\\", "/"))
665
- ]
 
 
 
 
666
 
667
  ann_dir = os.path.join(merged_dir, "annotations")
668
  paths = {
 
459
  """))
460
  return str(root)
461
 
462
+ # ---- NEW: robust sitecustomize shim with lazy import hook --------------------
463
  def _install_workspace_shim_v3(dest_dir: str, module_default: str = "rtdetrv2_pytorch.src"):
464
  """
465
+ sitecustomize shim that:
466
+ - patches workspace.create for positional/keyword cfg,
467
+ - ensures cfg is a dict,
468
+ - injects cfg['_pymodule'] as a *module object*,
469
+ even if the target module is imported after sitecustomize runs.
470
  """
471
  os.makedirs(dest_dir, exist_ok=True)
472
  sc_path = os.path.join(dest_dir, "sitecustomize.py")
473
 
 
474
  tmpl = Template(r"""
475
+ import os, sys, importlib, importlib.abc, importlib.util, importlib.machinery, types
476
+ MOD_DEFAULT = os.environ.get("RTDETR_PYMODULE", "$module_default") or "$module_default"
477
+ TARGET = "rtdetrv2_pytorch.src.core.workspace"
478
+
479
+ def _ensure_pymodule_object(cfg: dict):
480
+ pm = cfg.get("_pymodule", None)
481
+ if isinstance(pm, types.ModuleType):
482
+ return pm
483
+ name = (pm or "").strip() if isinstance(pm, str) else MOD_DEFAULT
484
+ if not name:
485
+ name = MOD_DEFAULT
486
+ try:
487
+ mod = importlib.import_module(name)
488
+ except Exception:
489
+ mod = importlib.import_module(MOD_DEFAULT)
490
+ cfg["_pymodule"] = mod
491
+ return mod
 
 
492
 
493
+ def _patch_ws(ws_mod):
494
+ if getattr(ws_mod, "__rolo_patched__", False):
495
+ return
496
+ _orig_create = ws_mod.create
497
  def create(name, *args, **kwargs):
 
498
  if args:
499
  args = list(args)
500
  cfg = args[0]
501
  else:
502
  cfg = kwargs.get("cfg", None)
 
503
  if not isinstance(cfg, dict):
504
  cfg = {} if cfg is None else dict(cfg)
 
505
  _ensure_pymodule_object(cfg)
 
506
  if args:
507
  args[0] = cfg
508
  args = tuple(args)
509
  else:
510
  kwargs["cfg"] = cfg
 
511
  return _orig_create(name, *args, **kwargs)
 
512
  ws_mod.create = create
513
+ ws_mod.__rolo_patched__ = True
514
+
515
+ def _try_patch_now():
516
+ try:
517
+ ws_mod = importlib.import_module(TARGET)
518
+ _patch_ws(ws_mod)
519
+ return True
520
+ except Exception:
521
+ return False
522
+
523
+ if not _try_patch_now():
524
+ class _RoloFinder(importlib.abc.MetaPathFinder):
525
+ def find_spec(self, fullname, path, target=None):
526
+ if fullname != TARGET:
527
+ return None
528
+ origin_spec = importlib.util.find_spec(fullname)
529
+ if origin_spec is None or origin_spec.loader is None:
530
+ return None
531
+ loader = origin_spec.loader
532
+ class _RoloLoader(importlib.abc.Loader):
533
+ def create_module(self, spec):
534
+ if hasattr(loader, "create_module"):
535
+ return loader.create_module(spec)
536
+ return None
537
+ def exec_module(self, module):
538
+ loader.exec_module(module)
539
+ try:
540
+ _patch_ws(module)
541
+ except Exception:
542
+ pass
543
+ spec = importlib.machinery.ModuleSpec(fullname, _RoloLoader(), origin=origin_spec.origin)
544
+ spec.submodule_search_locations = origin_spec.submodule_search_locations
545
+ return spec
546
+ sys.meta_path.insert(0, _RoloFinder())
547
  """)
548
  code = tmpl.substitute(module_default=module_default)
549
  with open(sc_path, "w", encoding="utf-8") as f:
 
683
  cfg.setdefault("device", "")
684
  cfg["find_unused_parameters"] = False
685
 
686
+ # STRONG include pruning: keep ONLY runtime.yml to avoid overrides
687
  if "__include__" in cfg and isinstance(cfg["__include__"], list):
688
+ kept = []
689
+ for p in cfg["__include__"]:
690
+ if not isinstance(p, str):
691
+ continue
692
+ pp = p.replace("\\", "/")
693
+ if pp.endswith("/configs/runtime.yml"):
694
+ kept.append(p)
695
+ cfg["__include__"] = kept
696
 
697
  ann_dir = os.path.join(merged_dir, "annotations")
698
  paths = {