Update app.py
Browse files
app.py
CHANGED
|
@@ -145,7 +145,7 @@ def parse_roboflow_url(s: str):
|
|
| 145 |
version = None
|
| 146 |
if len(p) >= 3:
|
| 147 |
v = p[2]
|
| 148 |
-
if v.lower().startswith('v') and v[1:].isdigit():
|
| 149 |
version = int(v[1:])
|
| 150 |
elif v.isdigit():
|
| 151 |
version = int(v)
|
|
@@ -459,7 +459,7 @@ def _install_supervisely_logger_shim():
|
|
| 459 |
"""))
|
| 460 |
return str(root)
|
| 461 |
|
| 462 |
-
# ---- [
|
| 463 |
def _install_workspace_shim_v3(dest_dir: str, module_default: str = "rtdetrv2_pytorch.src"):
|
| 464 |
"""
|
| 465 |
sitecustomize shim that:
|
|
@@ -566,23 +566,6 @@ if not _try_patch_now():
|
|
| 566 |
f.write(code)
|
| 567 |
return sc_path
|
| 568 |
|
| 569 |
-
# ---- Deprecated: on-disk workspace patch (no-op now) -------------------------
|
| 570 |
-
def _patch_workspace_create(repo_root: str, module_default: str = "rtdetrv2_pytorch.src") -> str | None:
|
| 571 |
-
"""
|
| 572 |
-
Deprecated: we no longer edit third-party files on disk.
|
| 573 |
-
The shim in sitecustomize.py handles cfg/_pymodule safely.
|
| 574 |
-
"""
|
| 575 |
-
return None
|
| 576 |
-
|
| 577 |
-
def _unpatch_workspace_create(repo_root: str):
|
| 578 |
-
ws_path = os.path.join(repo_root, "rtdetrv2_pytorch", "src", "core", "workspace.py")
|
| 579 |
-
bak_path = ws_path + ".bak"
|
| 580 |
-
if os.path.exists(bak_path):
|
| 581 |
-
try:
|
| 582 |
-
shutil.copy2(bak_path, ws_path)
|
| 583 |
-
except Exception:
|
| 584 |
-
pass
|
| 585 |
-
|
| 586 |
def _ensure_checkpoint(model_key: str, out_dir: str) -> str | None:
|
| 587 |
url = CKPT_URLS.get(model_key)
|
| 588 |
if not url:
|
|
@@ -731,7 +714,7 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
|
|
| 731 |
"shuffle": bool(default_shuffle),
|
| 732 |
"num_workers": 2,
|
| 733 |
"drop_last": bool(dl_key == "train_dataloader"),
|
| 734 |
-
"collate_fn": {"type": "BatchImageCollateFunction"},
|
| 735 |
"total_batch_size": int(batch),
|
| 736 |
}
|
| 737 |
cfg[dl_key] = block
|
|
@@ -751,7 +734,7 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
|
|
| 751 |
block.setdefault("shuffle", bool(default_shuffle))
|
| 752 |
block.setdefault("drop_last", bool(dl_key == "train_dataloader"))
|
| 753 |
|
| 754 |
-
# ---- FORCE-FIX collate name even if it existed already
|
| 755 |
cf = block.get("collate_fn", {})
|
| 756 |
if isinstance(cf, dict):
|
| 757 |
t = str(cf.get("type", ""))
|
|
@@ -805,7 +788,6 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
|
|
| 805 |
_maybe_set_model_field(cfg, "pretrain", p)
|
| 806 |
_maybe_set_model_field(cfg, "pretrained", p)
|
| 807 |
|
| 808 |
-
# Defensive: if after keeping includes we still don't have a model block, add a stub
|
| 809 |
if not cfg.get("model"):
|
| 810 |
cfg["model"] = {"type": "RTDETR", "num_classes": int(class_count)}
|
| 811 |
|
|
@@ -961,21 +943,13 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
|
|
| 961 |
try:
|
| 962 |
train_cwd = os.path.dirname(train_script)
|
| 963 |
|
| 964 |
-
# --- NEW: create a temp dir for sitecustomize and put it FIRST on PYTHONPATH
|
| 965 |
shim_dir = tempfile.mkdtemp(prefix="rtdetr_site_")
|
| 966 |
_install_workspace_shim_v3(shim_dir, module_default="rtdetrv2_pytorch.src")
|
| 967 |
|
| 968 |
env = os.environ.copy()
|
| 969 |
|
| 970 |
-
# Supervisely logger shim (can be later in path)
|
| 971 |
sly_shim_root = _install_supervisely_logger_shim()
|
| 972 |
|
| 973 |
-
# Build PYTHONPATH — order matters!
|
| 974 |
-
# 1) shim_dir (so sitecustomize auto-imports)
|
| 975 |
-
# 2) train_cwd (belt & suspenders; makes local imports easy)
|
| 976 |
-
# 3) PY_IMPL_DIR + REPO_DIR (RT-DETRv2 code)
|
| 977 |
-
# 4) sly_shim_root (optional)
|
| 978 |
-
# 5) existing PYTHONPATH
|
| 979 |
env["PYTHONPATH"] = os.pathsep.join(filter(None, [
|
| 980 |
shim_dir,
|
| 981 |
train_cwd,
|
|
@@ -987,8 +961,7 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
|
|
| 987 |
|
| 988 |
env.setdefault("WANDB_DISABLED", "true")
|
| 989 |
env.setdefault("RTDETR_PYMODULE", "rtdetrv2_pytorch.src")
|
| 990 |
-
env.setdefault("PYTHONUNBUFFERED", "1")
|
| 991 |
-
# Optional tiny guard: pick a single visible GPU if available
|
| 992 |
if torch.cuda.is_available():
|
| 993 |
env.setdefault("CUDA_VISIBLE_DEVICES", "0")
|
| 994 |
|
|
@@ -1147,7 +1120,9 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
|
|
| 1147 |
gr.Markdown("Pick RT-DETRv2 model, set hyper-params, press Start.")
|
| 1148 |
with gr.Row():
|
| 1149 |
with gr.Column(scale=1):
|
| 1150 |
-
|
|
|
|
|
|
|
| 1151 |
label="Model (RT-DETRv2)")
|
| 1152 |
run_name_tb = gr.Textbox(label="Run Name", value="rtdetrv2_run_1")
|
| 1153 |
epochs_sl = gr.Slider(1, 500, 100, step=1, label="Epochs")
|
|
|
|
| 145 |
version = None
|
| 146 |
if len(p) >= 3:
|
| 147 |
v = p[2]
|
| 148 |
+
if v.lower().startswith('v') and v[1:].isdigit():
|
| 149 |
version = int(v[1:])
|
| 150 |
elif v.isdigit():
|
| 151 |
version = int(v)
|
|
|
|
| 459 |
"""))
|
| 460 |
return str(root)
|
| 461 |
|
| 462 |
+
# ---- [!! CORRECTED !!] robust sitecustomize shim with lazy import hook --------------------
|
| 463 |
def _install_workspace_shim_v3(dest_dir: str, module_default: str = "rtdetrv2_pytorch.src"):
|
| 464 |
"""
|
| 465 |
sitecustomize shim that:
|
|
|
|
| 566 |
f.write(code)
|
| 567 |
return sc_path
|
| 568 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 569 |
def _ensure_checkpoint(model_key: str, out_dir: str) -> str | None:
|
| 570 |
url = CKPT_URLS.get(model_key)
|
| 571 |
if not url:
|
|
|
|
| 714 |
"shuffle": bool(default_shuffle),
|
| 715 |
"num_workers": 2,
|
| 716 |
"drop_last": bool(dl_key == "train_dataloader"),
|
| 717 |
+
"collate_fn": {"type": "BatchImageCollateFunction"},
|
| 718 |
"total_batch_size": int(batch),
|
| 719 |
}
|
| 720 |
cfg[dl_key] = block
|
|
|
|
| 734 |
block.setdefault("shuffle", bool(default_shuffle))
|
| 735 |
block.setdefault("drop_last", bool(dl_key == "train_dataloader"))
|
| 736 |
|
| 737 |
+
# ---- FORCE-FIX collate name typo even if it existed already
|
| 738 |
cf = block.get("collate_fn", {})
|
| 739 |
if isinstance(cf, dict):
|
| 740 |
t = str(cf.get("type", ""))
|
|
|
|
| 788 |
_maybe_set_model_field(cfg, "pretrain", p)
|
| 789 |
_maybe_set_model_field(cfg, "pretrained", p)
|
| 790 |
|
|
|
|
| 791 |
if not cfg.get("model"):
|
| 792 |
cfg["model"] = {"type": "RTDETR", "num_classes": int(class_count)}
|
| 793 |
|
|
|
|
| 943 |
try:
|
| 944 |
train_cwd = os.path.dirname(train_script)
|
| 945 |
|
|
|
|
| 946 |
shim_dir = tempfile.mkdtemp(prefix="rtdetr_site_")
|
| 947 |
_install_workspace_shim_v3(shim_dir, module_default="rtdetrv2_pytorch.src")
|
| 948 |
|
| 949 |
env = os.environ.copy()
|
| 950 |
|
|
|
|
| 951 |
sly_shim_root = _install_supervisely_logger_shim()
|
| 952 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 953 |
env["PYTHONPATH"] = os.pathsep.join(filter(None, [
|
| 954 |
shim_dir,
|
| 955 |
train_cwd,
|
|
|
|
| 961 |
|
| 962 |
env.setdefault("WANDB_DISABLED", "true")
|
| 963 |
env.setdefault("RTDETR_PYMODULE", "rtdetrv2_pytorch.src")
|
| 964 |
+
env.setdefault("PYTHONUNBUFFERED", "1")
|
|
|
|
| 965 |
if torch.cuda.is_available():
|
| 966 |
env.setdefault("CUDA_VISIBLE_DEVICES", "0")
|
| 967 |
|
|
|
|
| 1120 |
gr.Markdown("Pick RT-DETRv2 model, set hyper-params, press Start.")
|
| 1121 |
with gr.Row():
|
| 1122 |
with gr.Column(scale=1):
|
| 1123 |
+
# [UI IMPROVEMENT] Using (label, value) format for a better user experience
|
| 1124 |
+
model_dd = gr.Dropdown(choices=[(label, value) for value, label in MODEL_CHOICES],
|
| 1125 |
+
value=DEFAULT_MODEL_KEY,
|
| 1126 |
label="Model (RT-DETRv2)")
|
| 1127 |
run_name_tb = gr.Textbox(label="Run Name", value="rtdetrv2_run_1")
|
| 1128 |
epochs_sl = gr.Slider(1, 500, 100, step=1, label="Epochs")
|