Update app.py
Browse files
app.py
CHANGED
|
@@ -10,7 +10,7 @@ import matplotlib.pyplot as plt
|
|
| 10 |
from roboflow import Roboflow
|
| 11 |
from PIL import Image
|
| 12 |
import torch
|
| 13 |
-
from string import Template # <--
|
| 14 |
|
| 15 |
# Quiet some noisy libs on Spaces (harmless locally)
|
| 16 |
os.environ.setdefault("YOLO_CONFIG_DIR", "/tmp/Ultralytics")
|
|
@@ -459,14 +459,14 @@ def _install_supervisely_logger_shim():
|
|
| 459 |
"""))
|
| 460 |
return str(root)
|
| 461 |
|
| 462 |
-
# ---- NEW:
|
| 463 |
def _install_workspace_shim_v3(dest_dir: str, module_default: str = "rtdetrv2_pytorch.src"):
|
| 464 |
"""
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
"""
|
| 471 |
os.makedirs(dest_dir, exist_ok=True)
|
| 472 |
sc_path = os.path.join(dest_dir, "sitecustomize.py")
|
|
@@ -475,51 +475,51 @@ def _install_workspace_shim_v3(dest_dir: str, module_default: str = "rtdetrv2_py
|
|
| 475 |
tmpl = Template(r"""
|
| 476 |
import os, importlib, types
|
| 477 |
try:
|
| 478 |
-
|
| 479 |
ws_mod = importlib.import_module("rtdetrv2_pytorch.src.core.workspace")
|
| 480 |
_orig_create = ws_mod.create
|
| 481 |
|
| 482 |
-
def _ensure_pymodule_object(cfg):
|
| 483 |
-
pm = None
|
| 484 |
-
try:
|
| 485 |
-
pm = cfg.get("_pymodule", None)
|
| 486 |
-
except Exception:
|
| 487 |
-
pm = None
|
| 488 |
-
if isinstance(pm, str) or pm is None:
|
| 489 |
-
name = pm.strip() if isinstance(pm, str) and pm.strip() else mod_default
|
| 490 |
-
try:
|
| 491 |
-
mod = importlib.import_module(name)
|
| 492 |
-
except Exception:
|
| 493 |
-
mod = importlib.import_module(mod_default)
|
| 494 |
-
try:
|
| 495 |
-
cfg["_pymodule"] = mod
|
| 496 |
-
except Exception:
|
| 497 |
-
pass
|
| 498 |
-
return mod
|
| 499 |
if isinstance(pm, types.ModuleType):
|
| 500 |
return pm
|
|
|
|
|
|
|
|
|
|
| 501 |
try:
|
| 502 |
-
mod = importlib.import_module(
|
| 503 |
-
cfg["_pymodule"] = mod
|
| 504 |
-
return mod
|
| 505 |
except Exception:
|
| 506 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 507 |
|
| 508 |
-
def create(name, **kwargs):
|
| 509 |
-
cfg = kwargs.get("cfg")
|
| 510 |
if not isinstance(cfg, dict):
|
| 511 |
cfg = {} if cfg is None else dict(cfg)
|
| 512 |
-
|
| 513 |
_ensure_pymodule_object(cfg)
|
| 514 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 515 |
|
| 516 |
ws_mod.create = create
|
| 517 |
except Exception:
|
| 518 |
-
#
|
| 519 |
pass
|
| 520 |
""")
|
| 521 |
code = tmpl.substitute(module_default=module_default)
|
| 522 |
-
|
| 523 |
with open(sc_path, "w", encoding="utf-8") as f:
|
| 524 |
f.write(code)
|
| 525 |
return sc_path
|
|
@@ -650,7 +650,7 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
|
|
| 650 |
|
| 651 |
# Ensure the runtime knows which Python module hosts builders
|
| 652 |
cfg["task"] = cfg.get("task", "detection")
|
| 653 |
-
cfg["_pymodule"] = cfg.get("_pymodule", "rtdetrv2_pytorch.src") # <=
|
| 654 |
|
| 655 |
# Disable SyncBN for single GPU/CPU runs; guard DDP flags
|
| 656 |
cfg["sync_bn"] = False
|
|
@@ -696,7 +696,7 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
|
|
| 696 |
"shuffle": bool(default_shuffle),
|
| 697 |
"num_workers": 2,
|
| 698 |
"drop_last": bool(dl_key == "train_dataloader"),
|
| 699 |
-
"collate_fn": {"type": "BatchImageCollateFunction"}, #
|
| 700 |
"total_batch_size": int(batch),
|
| 701 |
}
|
| 702 |
cfg[dl_key] = block
|
|
@@ -716,6 +716,16 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
|
|
| 716 |
block.setdefault("shuffle", bool(default_shuffle))
|
| 717 |
block.setdefault("drop_last", bool(dl_key == "train_dataloader"))
|
| 718 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 719 |
ensure_and_patch_dl("train_dataloader", "train_img", "train_json", default_shuffle=True)
|
| 720 |
ensure_and_patch_dl("val_dataloader", "val_img", "val_json", default_shuffle=False)
|
| 721 |
|
|
|
|
| 10 |
from roboflow import Roboflow
|
| 11 |
from PIL import Image
|
| 12 |
import torch
|
| 13 |
+
from string import Template # <-- used by the shim
|
| 14 |
|
| 15 |
# Quiet some noisy libs on Spaces (harmless locally)
|
| 16 |
os.environ.setdefault("YOLO_CONFIG_DIR", "/tmp/Ultralytics")
|
|
|
|
| 459 |
"""))
|
| 460 |
return str(root)
|
| 461 |
|
| 462 |
+
# ---- NEW: signature-agnostic sitecustomize shim (safe, non-invasive) ---------
|
| 463 |
def _install_workspace_shim_v3(dest_dir: str, module_default: str = "rtdetrv2_pytorch.src"):
|
| 464 |
"""
|
| 465 |
+
sitecustomize shim that patches rtdetrv2_pytorch.src.core.workspace.create
|
| 466 |
+
to be robust for BOTH call styles:
|
| 467 |
+
- create(name, cfg, **extras) (positional cfg)
|
| 468 |
+
- create(name, cfg=<dict>, **extras) (keyword cfg)
|
| 469 |
+
It guarantees cfg is a dict and cfg['_pymodule'] is a *module object*.
|
| 470 |
"""
|
| 471 |
os.makedirs(dest_dir, exist_ok=True)
|
| 472 |
sc_path = os.path.join(dest_dir, "sitecustomize.py")
|
|
|
|
| 475 |
tmpl = Template(r"""
|
| 476 |
import os, importlib, types
|
| 477 |
try:
|
| 478 |
+
MOD_DEFAULT = os.environ.get("RTDETR_PYMODULE", "$module_default") or "$module_default"
|
| 479 |
ws_mod = importlib.import_module("rtdetrv2_pytorch.src.core.workspace")
|
| 480 |
_orig_create = ws_mod.create
|
| 481 |
|
| 482 |
+
def _ensure_pymodule_object(cfg: dict):
|
| 483 |
+
pm = cfg.get("_pymodule", None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 484 |
if isinstance(pm, types.ModuleType):
|
| 485 |
return pm
|
| 486 |
+
name = (pm or "").strip() if isinstance(pm, str) else MOD_DEFAULT
|
| 487 |
+
if not name:
|
| 488 |
+
name = MOD_DEFAULT
|
| 489 |
try:
|
| 490 |
+
mod = importlib.import_module(name)
|
|
|
|
|
|
|
| 491 |
except Exception:
|
| 492 |
+
mod = importlib.import_module(MOD_DEFAULT)
|
| 493 |
+
cfg["_pymodule"] = mod
|
| 494 |
+
return mod
|
| 495 |
+
|
| 496 |
+
def create(name, *args, **kwargs):
|
| 497 |
+
# Accept both positional and keyword cfg
|
| 498 |
+
if args:
|
| 499 |
+
args = list(args)
|
| 500 |
+
cfg = args[0]
|
| 501 |
+
else:
|
| 502 |
+
cfg = kwargs.get("cfg", None)
|
| 503 |
|
|
|
|
|
|
|
| 504 |
if not isinstance(cfg, dict):
|
| 505 |
cfg = {} if cfg is None else dict(cfg)
|
| 506 |
+
|
| 507 |
_ensure_pymodule_object(cfg)
|
| 508 |
+
|
| 509 |
+
if args:
|
| 510 |
+
args[0] = cfg
|
| 511 |
+
args = tuple(args)
|
| 512 |
+
else:
|
| 513 |
+
kwargs["cfg"] = cfg
|
| 514 |
+
|
| 515 |
+
return _orig_create(name, *args, **kwargs)
|
| 516 |
|
| 517 |
ws_mod.create = create
|
| 518 |
except Exception:
|
| 519 |
+
# Never block training on shim failure
|
| 520 |
pass
|
| 521 |
""")
|
| 522 |
code = tmpl.substitute(module_default=module_default)
|
|
|
|
| 523 |
with open(sc_path, "w", encoding="utf-8") as f:
|
| 524 |
f.write(code)
|
| 525 |
return sc_path
|
|
|
|
| 650 |
|
| 651 |
# Ensure the runtime knows which Python module hosts builders
|
| 652 |
cfg["task"] = cfg.get("task", "detection")
|
| 653 |
+
cfg["_pymodule"] = cfg.get("_pymodule", "rtdetrv2_pytorch.src") # <= hint for loader
|
| 654 |
|
| 655 |
# Disable SyncBN for single GPU/CPU runs; guard DDP flags
|
| 656 |
cfg["sync_bn"] = False
|
|
|
|
| 696 |
"shuffle": bool(default_shuffle),
|
| 697 |
"num_workers": 2,
|
| 698 |
"drop_last": bool(dl_key == "train_dataloader"),
|
| 699 |
+
"collate_fn": {"type": "BatchImageCollateFunction"}, # correct spelling
|
| 700 |
"total_batch_size": int(batch),
|
| 701 |
}
|
| 702 |
cfg[dl_key] = block
|
|
|
|
| 716 |
block.setdefault("shuffle", bool(default_shuffle))
|
| 717 |
block.setdefault("drop_last", bool(dl_key == "train_dataloader"))
|
| 718 |
|
| 719 |
+
# ---- FORCE-FIX collate name even if it existed already
|
| 720 |
+
cf = block.get("collate_fn", {})
|
| 721 |
+
if isinstance(cf, dict):
|
| 722 |
+
t = str(cf.get("type", ""))
|
| 723 |
+
if t.lower() == "batchimagecollatefuncion" or "Funcion" in t:
|
| 724 |
+
cf["type"] = "BatchImageCollateFunction"
|
| 725 |
+
block["collate_fn"] = cf
|
| 726 |
+
else:
|
| 727 |
+
block["collate_fn"] = {"type": "BatchImageCollateFunction"}
|
| 728 |
+
|
| 729 |
ensure_and_patch_dl("train_dataloader", "train_img", "train_json", default_shuffle=True)
|
| 730 |
ensure_and_patch_dl("val_dataloader", "val_img", "val_json", default_shuffle=False)
|
| 731 |
|