wuhp commited on
Commit
a71c515
·
verified ·
1 Parent(s): 255f7e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -37
app.py CHANGED
@@ -10,7 +10,7 @@ import matplotlib.pyplot as plt
10
  from roboflow import Roboflow
11
  from PIL import Image
12
  import torch
13
- from string import Template # <-- NEW
14
 
15
  # Quiet some noisy libs on Spaces (harmless locally)
16
  os.environ.setdefault("YOLO_CONFIG_DIR", "/tmp/Ultralytics")
@@ -459,14 +459,14 @@ def _install_supervisely_logger_shim():
459
  """))
460
  return str(root)
461
 
462
- # ---- NEW: kwargs-aware sitecustomize shim (safe, non-invasive) ---------------
463
  def _install_workspace_shim_v3(dest_dir: str, module_default: str = "rtdetrv2_pytorch.src"):
464
  """
465
- Writes a sitecustomize.py that monkeypatches
466
- rtdetrv2_pytorch.src.core.workspace.create so it works with the real signature:
467
- def create(name, **kwargs):
468
- The shim ensures kwargs['cfg'] is a dict, then guarantees cfg['_pymodule'] is a *module object*.
469
- `dest_dir` MUST be on sys.path at interpreter startup (we'll prepend it to PYTHONPATH).
470
  """
471
  os.makedirs(dest_dir, exist_ok=True)
472
  sc_path = os.path.join(dest_dir, "sitecustomize.py")
@@ -475,51 +475,51 @@ def _install_workspace_shim_v3(dest_dir: str, module_default: str = "rtdetrv2_py
475
  tmpl = Template(r"""
476
  import os, importlib, types
477
  try:
478
- mod_default = os.environ.get("RTDETR_PYMODULE", "$module_default") or "$module_default"
479
  ws_mod = importlib.import_module("rtdetrv2_pytorch.src.core.workspace")
480
  _orig_create = ws_mod.create
481
 
482
- def _ensure_pymodule_object(cfg):
483
- pm = None
484
- try:
485
- pm = cfg.get("_pymodule", None)
486
- except Exception:
487
- pm = None
488
- if isinstance(pm, str) or pm is None:
489
- name = pm.strip() if isinstance(pm, str) and pm.strip() else mod_default
490
- try:
491
- mod = importlib.import_module(name)
492
- except Exception:
493
- mod = importlib.import_module(mod_default)
494
- try:
495
- cfg["_pymodule"] = mod
496
- except Exception:
497
- pass
498
- return mod
499
  if isinstance(pm, types.ModuleType):
500
  return pm
 
 
 
501
  try:
502
- mod = importlib.import_module(mod_default)
503
- cfg["_pymodule"] = mod
504
- return mod
505
  except Exception:
506
- return pm
 
 
 
 
 
 
 
 
 
 
507
 
508
- def create(name, **kwargs):
509
- cfg = kwargs.get("cfg")
510
  if not isinstance(cfg, dict):
511
  cfg = {} if cfg is None else dict(cfg)
512
- kwargs["cfg"] = cfg
513
  _ensure_pymodule_object(cfg)
514
- return _orig_create(name, **kwargs)
 
 
 
 
 
 
 
515
 
516
  ws_mod.create = create
517
  except Exception:
518
- # never block training if shim fails
519
  pass
520
  """)
521
  code = tmpl.substitute(module_default=module_default)
522
-
523
  with open(sc_path, "w", encoding="utf-8") as f:
524
  f.write(code)
525
  return sc_path
@@ -650,7 +650,7 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
650
 
651
  # Ensure the runtime knows which Python module hosts builders
652
  cfg["task"] = cfg.get("task", "detection")
653
- cfg["_pymodule"] = cfg.get("_pymodule", "rtdetrv2_pytorch.src") # <= HINT for loader
654
 
655
  # Disable SyncBN for single GPU/CPU runs; guard DDP flags
656
  cfg["sync_bn"] = False
@@ -696,7 +696,7 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
696
  "shuffle": bool(default_shuffle),
697
  "num_workers": 2,
698
  "drop_last": bool(dl_key == "train_dataloader"),
699
- "collate_fn": {"type": "BatchImageCollateFunction"}, # <-- FIXED name
700
  "total_batch_size": int(batch),
701
  }
702
  cfg[dl_key] = block
@@ -716,6 +716,16 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
716
  block.setdefault("shuffle", bool(default_shuffle))
717
  block.setdefault("drop_last", bool(dl_key == "train_dataloader"))
718
 
 
 
 
 
 
 
 
 
 
 
719
  ensure_and_patch_dl("train_dataloader", "train_img", "train_json", default_shuffle=True)
720
  ensure_and_patch_dl("val_dataloader", "val_img", "val_json", default_shuffle=False)
721
 
 
10
  from roboflow import Roboflow
11
  from PIL import Image
12
  import torch
13
+ from string import Template # <-- used by the shim
14
 
15
  # Quiet some noisy libs on Spaces (harmless locally)
16
  os.environ.setdefault("YOLO_CONFIG_DIR", "/tmp/Ultralytics")
 
459
  """))
460
  return str(root)
461
 
462
+ # ---- NEW: signature-agnostic sitecustomize shim (safe, non-invasive) ---------
463
  def _install_workspace_shim_v3(dest_dir: str, module_default: str = "rtdetrv2_pytorch.src"):
464
  """
465
+ sitecustomize shim that patches rtdetrv2_pytorch.src.core.workspace.create
466
+ to be robust for BOTH call styles:
467
+ - create(name, cfg, **extras) (positional cfg)
468
+ - create(name, cfg=<dict>, **extras) (keyword cfg)
469
+ It guarantees cfg is a dict and cfg['_pymodule'] is a *module object*.
470
  """
471
  os.makedirs(dest_dir, exist_ok=True)
472
  sc_path = os.path.join(dest_dir, "sitecustomize.py")
 
475
  tmpl = Template(r"""
476
  import os, importlib, types
477
  try:
478
+ MOD_DEFAULT = os.environ.get("RTDETR_PYMODULE", "$module_default") or "$module_default"
479
  ws_mod = importlib.import_module("rtdetrv2_pytorch.src.core.workspace")
480
  _orig_create = ws_mod.create
481
 
482
+ def _ensure_pymodule_object(cfg: dict):
483
+ pm = cfg.get("_pymodule", None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
484
  if isinstance(pm, types.ModuleType):
485
  return pm
486
+ name = (pm or "").strip() if isinstance(pm, str) else MOD_DEFAULT
487
+ if not name:
488
+ name = MOD_DEFAULT
489
  try:
490
+ mod = importlib.import_module(name)
 
 
491
  except Exception:
492
+ mod = importlib.import_module(MOD_DEFAULT)
493
+ cfg["_pymodule"] = mod
494
+ return mod
495
+
496
+ def create(name, *args, **kwargs):
497
+ # Accept both positional and keyword cfg
498
+ if args:
499
+ args = list(args)
500
+ cfg = args[0]
501
+ else:
502
+ cfg = kwargs.get("cfg", None)
503
 
 
 
504
  if not isinstance(cfg, dict):
505
  cfg = {} if cfg is None else dict(cfg)
506
+
507
  _ensure_pymodule_object(cfg)
508
+
509
+ if args:
510
+ args[0] = cfg
511
+ args = tuple(args)
512
+ else:
513
+ kwargs["cfg"] = cfg
514
+
515
+ return _orig_create(name, *args, **kwargs)
516
 
517
  ws_mod.create = create
518
  except Exception:
519
+ # Never block training on shim failure
520
  pass
521
  """)
522
  code = tmpl.substitute(module_default=module_default)
 
523
  with open(sc_path, "w", encoding="utf-8") as f:
524
  f.write(code)
525
  return sc_path
 
650
 
651
  # Ensure the runtime knows which Python module hosts builders
652
  cfg["task"] = cfg.get("task", "detection")
653
+ cfg["_pymodule"] = cfg.get("_pymodule", "rtdetrv2_pytorch.src") # <= hint for loader
654
 
655
  # Disable SyncBN for single GPU/CPU runs; guard DDP flags
656
  cfg["sync_bn"] = False
 
696
  "shuffle": bool(default_shuffle),
697
  "num_workers": 2,
698
  "drop_last": bool(dl_key == "train_dataloader"),
699
+ "collate_fn": {"type": "BatchImageCollateFunction"}, # correct spelling
700
  "total_batch_size": int(batch),
701
  }
702
  cfg[dl_key] = block
 
716
  block.setdefault("shuffle", bool(default_shuffle))
717
  block.setdefault("drop_last", bool(dl_key == "train_dataloader"))
718
 
719
+ # ---- FORCE-FIX collate name even if it existed already
720
+ cf = block.get("collate_fn", {})
721
+ if isinstance(cf, dict):
722
+ t = str(cf.get("type", ""))
723
+ if t.lower() == "batchimagecollatefuncion" or "Funcion" in t:
724
+ cf["type"] = "BatchImageCollateFunction"
725
+ block["collate_fn"] = cf
726
+ else:
727
+ block["collate_fn"] = {"type": "BatchImageCollateFunction"}
728
+
729
  ensure_and_patch_dl("train_dataloader", "train_img", "train_json", default_shuffle=True)
730
  ensure_and_patch_dl("val_dataloader", "val_img", "val_json", default_shuffle=False)
731