wuhp commited on
Commit
255f7e6
·
verified ·
1 Parent(s): db5e783

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -46
app.py CHANGED
@@ -10,6 +10,7 @@ import matplotlib.pyplot as plt
10
  from roboflow import Roboflow
11
  from PIL import Image
12
  import torch
 
13
 
14
  # Quiet some noisy libs on Spaces (harmless locally)
15
  os.environ.setdefault("YOLO_CONFIG_DIR", "/tmp/Ultralytics")
@@ -144,7 +145,7 @@ def parse_roboflow_url(s: str):
144
  version = None
145
  if len(p) >= 3:
146
  v = p[2]
147
- if v.lower().startsWith('v') and v[1:].isdigit():
148
  version = int(v[1:])
149
  elif v.isdigit():
150
  version = int(v)
@@ -470,53 +471,54 @@ def _install_workspace_shim_v3(dest_dir: str, module_default: str = "rtdetrv2_py
470
  os.makedirs(dest_dir, exist_ok=True)
471
  sc_path = os.path.join(dest_dir, "sitecustomize.py")
472
 
473
- # NOTE: Not an f-string. Escape literal braces with {{ }} and only format {module_default}.
474
- code = textwrap.dedent("""\
475
- import os, importlib, types
476
- try:
477
- mod_default = os.environ.get("RTDETR_PYMODULE", "{module_default}") or "{module_default}"
478
- ws_mod = importlib.import_module("rtdetrv2_pytorch.src.core.workspace")
479
- _orig_create = ws_mod.create
480
 
481
- def _ensure_pymodule_object(cfg):
482
- pm = None
483
- try:
484
- pm = cfg.get("_pymodule", None)
485
- except Exception:
486
- pm = None
487
- if isinstance(pm, str) or pm is None:
488
- name = pm.strip() if isinstance(pm, str) and pm.strip() else mod_default
489
- try:
490
- mod = importlib.import_module(name)
491
- except Exception:
492
- mod = importlib.import_module(mod_default)
493
- try:
494
- cfg["_pymodule"] = mod
495
- except Exception:
496
- pass
497
- return mod
498
- if isinstance(pm, types.ModuleType):
499
- return pm
500
- try:
501
- mod = importlib.import_module(mod_default)
502
- cfg["_pymodule"] = mod
503
- return mod
504
- except Exception:
505
- return pm
506
 
507
- def create(name, **kwargs):
508
- cfg = kwargs.get("cfg")
509
- if not isinstance(cfg, dict):
510
- cfg = {} if cfg is None else dict(cfg)
511
- kwargs["cfg"] = cfg
512
- _ensure_pymodule_object(cfg)
513
- return _orig_create(name, **kwargs)
514
 
515
- ws_mod.create = create
516
- except Exception:
517
- # never block training if shim fails
518
- pass
519
- """).format(module_default=module_default)
 
520
 
521
  with open(sc_path, "w", encoding="utf-8") as f:
522
  f.write(code)
@@ -694,7 +696,7 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
694
  "shuffle": bool(default_shuffle),
695
  "num_workers": 2,
696
  "drop_last": bool(dl_key == "train_dataloader"),
697
- "collate_fn": {"type": "BatchImageCollateFuncion"},
698
  "total_batch_size": int(batch),
699
  }
700
  cfg[dl_key] = block
 
10
  from roboflow import Roboflow
11
  from PIL import Image
12
  import torch
13
+ from string import Template # <-- NEW
14
 
15
  # Quiet some noisy libs on Spaces (harmless locally)
16
  os.environ.setdefault("YOLO_CONFIG_DIR", "/tmp/Ultralytics")
 
145
  version = None
146
  if len(p) >= 3:
147
  v = p[2]
148
+ if v.lower().startswith('v') and v[1:].isdigit(): # <-- FIXED
149
  version = int(v[1:])
150
  elif v.isdigit():
151
  version = int(v)
 
471
  os.makedirs(dest_dir, exist_ok=True)
472
  sc_path = os.path.join(dest_dir, "sitecustomize.py")
473
 
474
+ # Use Template so braces in code remain literal; only $module_default is substituted.
475
+ tmpl = Template(r"""
476
+ import os, importlib, types
477
+ try:
478
+ mod_default = os.environ.get("RTDETR_PYMODULE", "$module_default") or "$module_default"
479
+ ws_mod = importlib.import_module("rtdetrv2_pytorch.src.core.workspace")
480
+ _orig_create = ws_mod.create
481
 
482
+ def _ensure_pymodule_object(cfg):
483
+ pm = None
484
+ try:
485
+ pm = cfg.get("_pymodule", None)
486
+ except Exception:
487
+ pm = None
488
+ if isinstance(pm, str) or pm is None:
489
+ name = pm.strip() if isinstance(pm, str) and pm.strip() else mod_default
490
+ try:
491
+ mod = importlib.import_module(name)
492
+ except Exception:
493
+ mod = importlib.import_module(mod_default)
494
+ try:
495
+ cfg["_pymodule"] = mod
496
+ except Exception:
497
+ pass
498
+ return mod
499
+ if isinstance(pm, types.ModuleType):
500
+ return pm
501
+ try:
502
+ mod = importlib.import_module(mod_default)
503
+ cfg["_pymodule"] = mod
504
+ return mod
505
+ except Exception:
506
+ return pm
507
 
508
+ def create(name, **kwargs):
509
+ cfg = kwargs.get("cfg")
510
+ if not isinstance(cfg, dict):
511
+ cfg = {} if cfg is None else dict(cfg)
512
+ kwargs["cfg"] = cfg
513
+ _ensure_pymodule_object(cfg)
514
+ return _orig_create(name, **kwargs)
515
 
516
+ ws_mod.create = create
517
+ except Exception:
518
+ # never block training if shim fails
519
+ pass
520
+ """)
521
+ code = tmpl.substitute(module_default=module_default)
522
 
523
  with open(sc_path, "w", encoding="utf-8") as f:
524
  f.write(code)
 
696
  "shuffle": bool(default_shuffle),
697
  "num_workers": 2,
698
  "drop_last": bool(dl_key == "train_dataloader"),
699
+ "collate_fn": {"type": "BatchImageCollateFunction"}, # <-- FIXED name
700
  "total_batch_size": int(batch),
701
  }
702
  cfg[dl_key] = block