Update app.py
app.py
CHANGED
@@ -16,9 +16,23 @@ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(
 REPO_URL = "https://github.com/supervisely-ecosystem/RT-DETRv2"
 REPO_DIR = os.path.join(os.getcwd(), "third_party", "RT-DETRv2")
 PY_IMPL_DIR = os.path.join(REPO_DIR, "rtdetrv2_pytorch")  # Supervisely keeps PyTorch impl here
+
+# Core deps + your requested packages; pinned as lower-bounds to avoid downgrades
 COMMON_REQUIREMENTS = [
-    "gradio>=4.36.1",
-    "
+    "gradio>=4.36.1",
+    "ultralytics>=8.2.0",
+    "roboflow>=1.1.28",
+    "requests>=2.31.0",
+    "huggingface_hub>=0.22.0",
+    "pandas>=2.0.0",
+    "matplotlib>=3.7.0",
+    "torch>=2.0.1",
+    "torchvision>=0.15.2",
+    "pyyaml>=6.0.1",
+    "Pillow>=10.0.0",
+    "supervisely>=6.0.0",   # <- fixes ModuleNotFoundError from repo trainer
+    "tensorboard>=2.13.0",  # convenience: sometimes used by forks
+    "pycocotools>=2.0.7",   # convenience: ensure wheels are present
 ]
 
 # === bootstrap (clone + pip) ===================================================
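These lists are fed to a pip_install() helper that is defined elsewhere in app.py and not shown in this diff. Judging from its call sites (pip_install(COMMON_REQUIREMENTS) and pip_install(["-r", req_file])) it takes a list of pip arguments; a minimal sketch, assuming that signature:

# Hypothetical sketch of the pip_install() helper (not part of this diff):
# run pip for the current interpreter with whatever args the caller passes.
import subprocess
import sys

def pip_install(args):
    # args: either requirement specs (["gradio>=4.36.1", ...]) or pip flags
    # such as ["-r", "path/to/requirements.txt"].
    subprocess.check_call([sys.executable, "-m", "pip", "install", *args])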
@@ -37,11 +51,21 @@ def ensure_repo_and_requirements():
     except Exception:
         logging.warning("git pull failed; continuing with current checkout")
 
+    # Make sure all our app/runtime deps (incl. supervisely & ultralytics) are present
     pip_install(COMMON_REQUIREMENTS)
+
+    # Then install repo-specific extras (pycocotools/tensorboard etc. if required)
     req_file = os.path.join(PY_IMPL_DIR, "requirements.txt")
     if os.path.exists(req_file):
        pip_install(["-r", req_file])
 
+    # Double-check supervisely importability; if not, try again explicitly.
+    try:
+        import supervisely  # noqa: F401
+    except Exception:
+        logging.warning("supervisely not importable after first pass; retrying install…")
+        pip_install(["supervisely>=6.0.0"])
+
 try:
     ensure_repo_and_requirements()
 except Exception:
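The retry above uses a plain import as the ground truth for "is it installed". If the import itself were too heavy to attempt twice, an equivalent probe (an alternative sketch, not what app.py does) could use the standard library:

# Alternative importability check (sketch only): find_spec() returns None
# when the module cannot be located, without executing the module.
import importlib.util

def is_importable(mod_name: str) -> bool:
    return importlib.util.find_spec(mod_name) is not None

# e.g. is_importable("supervisely") == False would trigger the retry install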
@@ -141,15 +165,10 @@ def label_path_for(img_path: str) -> str:
 
 # === YOLOv8 -> COCO converter =================================================
 def yolo_to_coco(split_dir_images, split_dir_labels, class_names, out_json):
-    """
-    Convert YOLO txt labels to a COCO annotations json.
-    """
     images, annotations = [], []
     categories = [{"id": i, "name": n} for i, n in enumerate(class_names)]
     ann_id = 1
     img_id = 1
-
-    # Simple image size read (PIL); in Spaces this is fine.
     for fname in sorted(os.listdir(split_dir_images)):
         if not fname.lower().endswith((".jpg",".jpeg",".png")): continue
         img_path = os.path.join(split_dir_images, fname)
@@ -157,10 +176,8 @@ def yolo_to_coco(split_dir_images, split_dir_labels, class_names, out_json):
             with Image.open(img_path) as im:
                 w, h = im.size
         except Exception:
-            # skip unreadable images
             continue
         images.append({"id": img_id, "file_name": fname, "width": w, "height": h})
-
         label_file = os.path.join(split_dir_labels, os.path.splitext(fname)[0] + ".txt")
         if os.path.exists(label_file):
             with open(label_file, "r") as f:
@@ -169,7 +186,6 @@ def yolo_to_coco(split_dir_images, split_dir_labels, class_names, out_json):
                     if len(parts) < 5: continue
                     cls = int(float(parts[0]))
                     cx, cy, bw, bh = map(float, parts[1:5])
-                    # convert normalized (cx,cy,bw,bh) to x,y,w,h in pixels
                     x = (cx - bw/2.0) * w
                     y = (cy - bh/2.0) * h
                     ww = bw * w
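To make the normalized-to-pixel conversion concrete, here is one YOLO row worked through by hand (illustrative numbers):

# A YOLO label row "0 0.5 0.5 0.25 0.5" on a 640x480 image:
w, h = 640, 480
cx, cy, bw, bh = 0.5, 0.5, 0.25, 0.5
x  = (cx - bw / 2.0) * w   # (0.5 - 0.125) * 640 = 240.0
y  = (cy - bh / 2.0) * h   # (0.5 - 0.25)  * 480 = 120.0
ww = bw * w                # 160.0
hh = bh * h                # 240.0 -> COCO bbox [240.0, 120.0, 160.0, 240.0]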
@@ -185,16 +201,11 @@ def yolo_to_coco(split_dir_images, split_dir_labels, class_names, out_json):
                     })
                     ann_id += 1
         img_id += 1
-
     coco = {"images": images, "annotations": annotations, "categories": categories}
     os.makedirs(os.path.dirname(out_json), exist_ok=True)
     with open(out_json, "w") as f: json.dump(coco, f)
 
 def make_coco_annotations(merged_dir, class_names):
-    """
-    Build COCO jsons under merged_dir/annotations:
-      instances_train.json, instances_val.json, instances_test.json
-    """
     ann_dir = os.path.join(merged_dir, "annotations")
     os.makedirs(ann_dir, exist_ok=True)
     mapping = {"train": "instances_train.json", "valid": "instances_val.json", "test": "instances_test.json"}
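For reference, the file written by yolo_to_coco() has the standard COCO detection layout. A minimal illustrative instance (values hypothetical; the per-annotation dict in this diff is truncated by the hunk boundary, so extra COCO fields such as "area"/"iscrowd" may also be present):

# Minimal shape of the JSON written by yolo_to_coco() (illustrative values):
coco = {
    "images": [
        {"id": 1, "file_name": "img001.jpg", "width": 640, "height": 480},
    ],
    "annotations": [
        {"id": 1, "image_id": 1, "category_id": 0,
         "bbox": [240.0, 120.0, 160.0, 240.0]},  # [x, y, w, h] in pixels
    ],
    "categories": [{"id": 0, "name": "person"}],
}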
@@ -317,33 +328,19 @@ def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=
         'names': active_classes
     }, f)
 
-    # also create COCO jsons for RT-DETRv2 training
     ann_dir = make_coco_annotations(merged_dir, active_classes)
     progress(0.98, desc="Finalizing...")
     return f"Dataset finalized with {len(selected_images)} images.", os.path.abspath(merged_dir)
 
 # === entrypoint + config detection/generation =================================
 def find_training_script(repo_root):
-    """
-    Recursively search for a tools/train.py (or train.py) suitable for RT-DETRv2.
-    """
     candidates = []
     for pat in ["**/tools/train.py", "**/train.py"]:
         candidates.extend(glob(os.path.join(repo_root, pat), recursive=True))
-    # Prefer ones inside rtdetrv2_pytorch
     candidates.sort(key=lambda p: (0 if "rtdetrv2_pytorch" in p else 1, len(p)))
     return candidates[0] if candidates else None
 
 def find_model_config_template(model_key):
-    """
-    Choose a native RT-DETRv2 config YAML from the Supervisely repo.
-
-    Heuristics:
-      - rtdetrv2_s -> r18 (Small)
-      - rtdetrv2_l -> r50 (Large)
-      - rtdetrv2_x -> r101 (X-Large)
-    Prefer files under rtdetrv2_pytorch/**/config(s) and with 'coco' in name.
-    """
     want_tokens = {
         "rtdetrv2_s": ["rtdetrv2", "r18", "coco"],
         "rtdetrv2_l": ["rtdetrv2", "r50", "coco"],
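The want_tokens table drives a filename heuristic; the matching code itself falls outside this hunk. One plausible sketch of how such tokens could rank candidate YAMLs (the scoring function and its name are assumptions, not the repo's exact logic):

# Hypothetical token-match scoring for candidate config YAMLs (sketch only):
import os

def score_candidate(path: str, tokens: list) -> int:
    name = os.path.basename(path).lower()
    return sum(1 for tok in tokens if tok in name)

# score_candidate("rtdetrv2_r18vd_coco.yml", ["rtdetrv2", "r18", "coco"]) == 3
# score_candidate("rtdetr_r50vd_coco.yml",   ["rtdetrv2", "r18", "coco"]) == 1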
@@ -369,10 +366,6 @@ def find_model_config_template(model_key):
 
 def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
                       epochs, batch, imgsz, lr, optimizer):
-    """
-    Load the chosen repo config and patch only the keys that already exist.
-    This avoids schema mismatches between forks.
-    """
     if not base_cfg_path or not os.path.exists(base_cfg_path):
         raise gr.Error("Could not locate a model config inside the RT-DETRv2 repo.")
 
@@ -385,12 +378,12 @@ patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
         "val_json": os.path.abspath(os.path.join(ann_dir, "instances_val.json")),
         "test_json": os.path.abspath(os.path.join(ann_dir, "instances_test.json")),
         "train_img": os.path.abspath(os.path.join(merged_dir, "train", "images")),
-        "val_img": os.path.abspath(os.path.join(merged_dir, "valid", "images")),
+        "val_img": os.path.abspath(os.path.join(merged_dir, "valid", "images")),
         "test_img": os.path.abspath(os.path.join(merged_dir, "test", "images")),
         "out_dir": os.path.abspath(os.path.join("runs", "train", run_name)),
     }
 
-    #
+    # dataset block
     for root_key in ["dataset", "data"]:
         if root_key in cfg and isinstance(cfg[root_key], dict):
             ds = cfg[root_key]
@@ -401,17 +394,14 @@ patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
             ]:
                 if split in ds and isinstance(ds[split], dict):
                     ds[split]["name"] = ds[split].get("name", "coco")
-                    # Common key variants across forks:
                     for k in ["ann_file", "ann_path", "annotation", "annotations"]:
                         if k in ds[split] or k in ["ann_file", "ann_path"]:
-                            ds[split][k] = paths[jf]
-                            break
+                            ds[split][k] = paths[jf]; break
                     for k in ["img_prefix", "img_dir", "image_root", "data_root"]:
                         if k in ds[split] or k in ["img_prefix", "img_dir"]:
-                            ds[split][k] = paths[ip]
-                            break
+                            ds[split][k] = paths[ip]; break
 
-    #
+    # num_classes
     def set_num_classes(node, n):
         if not isinstance(node, dict): return False
         if "num_classes" in node:
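The hunk cuts off after the first lines of set_num_classes(). A recursive walk is the natural shape for such a helper; a sketch consistent with the visible lines (the real body may differ):

# Sketch of a recursive set_num_classes(): overwrite every existing
# "num_classes" key in a nested dict config, reporting whether any was hit.
def set_num_classes(node, n):
    if not isinstance(node, dict): return False
    hit = False
    if "num_classes" in node:
        node["num_classes"] = int(n)
        hit = True
    for v in node.values():
        hit = set_num_classes(v, n) or hit
    return hit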
@@ -426,7 +416,7 @@ patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
     else:
         cfg["model"] = {"num_classes": int(class_count)}
 
-    #
+    # epochs / imgsz
     updated_epoch = False
     for key in ["max_epoch", "epochs", "num_epochs"]:
         if key in cfg:
@@ -442,7 +432,7 @@ patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
         if key in cfg: cfg[key] = int(imgsz)
     if "input_size" not in cfg: cfg["input_size"] = int(imgsz)
 
-    #
+    # lr / optimizer / batch
     if "solver" not in cfg or not isinstance(cfg["solver"], dict):
         cfg["solver"] = {}
     sol = cfg["solver"]
@@ -453,13 +443,12 @@ patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
         sol["base_lr"] = float(lr)
 
     sol["optimizer"] = str(optimizer).lower()
-
     if "train_dataloader" in cfg and isinstance(cfg["train_dataloader"], dict):
         cfg["train_dataloader"]["batch_size"] = int(batch)
     else:
         sol["batch_size"] = int(batch)
 
-    #
+    # output dir
     if "output_dir" in cfg:
         cfg["output_dir"] = paths["out_dir"]
     elif "solver" in cfg:
@@ -467,7 +456,6 @@ patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
     else:
         cfg["output_dir"] = paths["out_dir"]
 
-    # --- write patched config -------------------------------------------------
     cfg_out_dir = os.path.join("generated_configs"); os.makedirs(cfg_out_dir, exist_ok=True)
     out_path = os.path.join(cfg_out_dir, f"{run_name}.yaml")
     with open(out_path, "w") as f: yaml.safe_dump(cfg, f, sort_keys=False)
@@ -570,23 +558,19 @@ def finalize_handler(dataset_info, class_df, progress=gr.Progress()):
 def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr, opt, progress=gr.Progress()):
     if not dataset_path: raise gr.Error("Finalize a dataset in Tab 2 before training.")
 
-    # 1) training script (nested-safe)
     train_script = find_training_script(REPO_DIR)
     if not train_script:
         raise gr.Error("RT-DETRv2 training script not found inside the repo (looked for **/tools/train.py).")
 
-    # 2) base config = a real model template from the repo
     base_cfg = find_model_config_template(model_key)
     if not base_cfg:
         raise gr.Error("Could not find a matching RT-DETRv2 config in the repo (S/L/X).")
 
-    # 3) read classes + ensure COCO JSONs up to date
     data_yaml = os.path.join(dataset_path, "data.yaml")
     with open(data_yaml, "r") as f: dy = yaml.safe_load(f)
     class_names = [str(x) for x in dy.get("names", [])]
     make_coco_annotations(dataset_path, class_names)
 
-    # 4) patch the base config safely (no custom schema assumptions)
     cfg_path = patch_base_config(
         base_cfg_path=base_cfg,
         merged_dir=dataset_path,
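The handler reads class names from an Ultralytics-style data.yaml in the merged dataset. What yaml.safe_load() hands back for such a file (illustrative values):

# Typical result of dy = yaml.safe_load(f) for an Ultralytics data.yaml
# (values illustrative):
dy = {
    "train": "train/images",
    "val": "valid/images",
    "test": "test/images",
    "nc": 2,
    "names": ["person", "car"],
}
class_names = [str(x) for x in dy.get("names", [])]  # -> ["person", "car"]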
@@ -602,7 +586,6 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
     out_dir = os.path.abspath(os.path.join("runs", "train", run_name))
     os.makedirs(out_dir, exist_ok=True)
 
-    # 5) build & run command (no extra flags that might not exist)
     cmd = [sys.executable, train_script, "-c", os.path.abspath(cfg_path)]
     logging.info(f"Training command: {' '.join(cmd)}")
 
@@ -610,11 +593,9 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
     def run_train():
         try:
             env = os.environ.copy()
-            # Ensure both repo root and pytorch impl are on PYTHONPATH
             env["PYTHONPATH"] = os.pathsep.join(filter(None, [
                 PY_IMPL_DIR, REPO_DIR, env.get("PYTHONPATH", "")
             ]))
-            # Disable wandb in Spaces by default
             env.setdefault("WANDB_DISABLED", "true")
             proc = subprocess.Popen(cmd, cwd=os.path.dirname(train_script),
                                     stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
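The os.pathsep.join(filter(None, ...)) idiom drops empty entries so an unset PYTHONPATH does not leave a dangling separator. On Linux (os.pathsep == ":"), with hypothetical paths:

import os

# filter(None, ...) removes "" when PYTHONPATH was unset:
parts = ["/app/third_party/RT-DETRv2/rtdetrv2_pytorch",
         "/app/third_party/RT-DETRv2",
         os.environ.get("PYTHONPATH", "")]
print(os.pathsep.join(filter(None, parts)))
# -> /app/third_party/RT-DETRv2/rtdetrv2_pytorch:/app/third_party/RT-DETRv2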
@@ -628,7 +609,7 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
     Thread(target=run_train, daemon=True).start()
 
     log_tail, last_epoch, total_epochs = [], 0, int(epochs)
-    first_lines = []
+    first_lines = []
     while True:
         line = q.get()
         if line.startswith("__EXITCODE__"):
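The tail of the handler streams training logs through a Queue: the daemon thread pushes each subprocess output line, then a "__EXITCODE__" sentinel, and the UI loop blocks on q.get(). The same pattern in miniature (self-contained, with a fake worker standing in for the training subprocess):

# Miniature of the log-streaming pattern: worker pushes lines plus a
# sentinel, consumer drains the queue until the sentinel arrives.
from queue import Queue
from threading import Thread

q = Queue()

def worker():
    for i in range(3):
        q.put(f"epoch {i} done\n")   # stands in for subprocess stdout lines
    q.put("__EXITCODE__ 0")          # sentinel: training process exited

Thread(target=worker, daemon=True).start()
while True:
    line = q.get()
    if line.startswith("__EXITCODE__"):
        break
    print(line, end="")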