Update app.py
app.py CHANGED
@@ -573,6 +573,7 @@ def _maybe_set_model_field(cfg: dict, key: str, value):
         return
     cfg[key] = value  # fallback
 
+# --- CRITICAL FIX: force custom dataloaders & disable sync_bn -----------------
 def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
                       epochs, batch, imgsz, lr, optimizer, pretrained_path: str | None):
     if not base_cfg_path or not os.path.exists(base_cfg_path):
@@ -585,6 +586,9 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
         cfg = yaml.safe_load(f)
     _absify_any_paths_deep(cfg, template_dir)
 
+    # Safer on single GPU/CPU
+    cfg["sync_bn"] = False
+
     ann_dir = os.path.join(merged_dir, "annotations")
     paths = {
         "train_json": os.path.abspath(os.path.join(ann_dir, "instances_train.json")),
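A note on the new `sync_bn` override: `SyncBatchNorm` needs an initialized `torch.distributed` process group at forward time, so leaving it enabled in a single-GPU/CPU Space breaks training. A minimal sketch of the failure mode in plain PyTorch (independent of this repo's trainer):

import torch.nn as nn

# Plain model with an ordinary BatchNorm layer.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))

# convert_sync_batchnorm swaps BatchNorm2d for SyncBatchNorm in place.
# The conversion itself succeeds, but a training-mode forward pass then
# queries torch.distributed for the world size and raises when no process
# group is initialized -- which is what `sync_bn: False` sidesteps here.
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)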
@@ -596,83 +600,107 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
         "out_dir": os.path.abspath(os.path.join("runs", "train", run_name)),
     }
 
-    # … (old lines 599-620 removed; their content is truncated in the diff view)
+    # Remove COCO dataset include so it can't override our dataloaders
+    inc_key = "__include__"
+    if inc_key in cfg and isinstance(cfg[inc_key], list):
+        cfg[inc_key] = [
+            p for p in cfg[inc_key]
+            if not (isinstance(p, str) and "configs/dataset/coco" in p.replace("\\", "/"))
+        ]
+
+    # Helper to ensure & patch dataloaders
+    def ensure_and_patch_dl(dl_key, img_key, json_key, default_shuffle):
+        block = cfg.get(dl_key)
+        if not isinstance(block, dict):
+            block = {
+                "type": "DataLoader",
+                "dataset": {
+                    "type": "CocoDetection",
+                    "img_folder": paths[img_key],
+                    "ann_file": paths[json_key],
+                    "return_masks": False,
+                    "transforms": {
+                        "type": "Compose",
+                        "ops": [
+                            {"type": "Resize", "size": [int(imgsz), int(imgsz)]},
+                            {"type": "ConvertPILImage", "dtype": "float32", "scale": True},
+                        ],
+                    },
+                },
+                "shuffle": bool(default_shuffle),
+                "num_workers": 2,
+                "drop_last": bool(dl_key == "train_dataloader"),
+                "collate_fn": {"type": "BatchImageCollateFuncion"},
+                "total_batch_size": int(batch),
+            }
+            cfg[dl_key] = block
+        ds = block.get("dataset", {})
+        if isinstance(ds, dict):
+            ds["img_folder"] = paths[img_key]
+            ds["ann_file"] = paths[json_key]
+            for k in ("img_dir", "image_root", "data_root"):
+                if k in ds: ds[k] = paths[img_key]
+            for k in ("ann_path", "annotation", "annotations"):
+                if k in ds: ds[k] = paths[json_key]
+            block["dataset"] = ds
+        block["total_batch_size"] = int(batch)
+        block.setdefault("num_workers", 2)
+        block.setdefault("shuffle", bool(default_shuffle))
+        block.setdefault("drop_last", bool(dl_key == "train_dataloader"))
+
+    ensure_and_patch_dl("train_dataloader", "train_img", "train_json", default_shuffle=True)
+    ensure_and_patch_dl("val_dataloader", "val_img", "val_json", default_shuffle=False)
+    # Optional test loader if needed:
+    # ensure_and_patch_dl("test_dataloader", "test_img", "test_json", default_shuffle=False)
+
+    # Classes
     _set_num_classes_safely(cfg, int(class_count))
 
-    #
+    # Epochs / imgsz
     applied_epoch = False
     for key in ("epoches", "max_epoch", "epochs", "num_epochs"):
         if key in cfg:
-            cfg[key] = int(epochs)
-            applied_epoch = True
-            break
+            cfg[key] = int(epochs); applied_epoch = True; break
     if "solver" in cfg and isinstance(cfg["solver"], dict):
         for key in ("epoches", "max_epoch", "epochs", "num_epochs"):
             if key in cfg["solver"]:
-                cfg["solver"][key] = int(epochs)
-                applied_epoch = True
-                break
+                cfg["solver"][key] = int(epochs); applied_epoch = True; break
     if not applied_epoch:
-        cfg["epoches"] = int(epochs)
-
-    # image size knobs: unify on top-level input_size (respected by templates)
+        cfg["epoches"] = int(epochs)
     cfg["input_size"] = int(imgsz)
 
-    #
+    # LR / optimizer / batch fallbacks
     if "solver" not in cfg or not isinstance(cfg["solver"], dict):
        cfg["solver"] = {}
    sol = cfg["solver"]
    for key in ("base_lr", "lr", "learning_rate"):
        if key in sol:
-            sol[key] = float(lr)
-            break
+            sol[key] = float(lr); break
    else:
        sol["base_lr"] = float(lr)
    sol["optimizer"] = str(optimizer).lower()
    if "train_dataloader" not in cfg or not isinstance(cfg["train_dataloader"], dict):
        sol["batch_size"] = int(batch)
 
-    #
+    # Output dir
    if "output_dir" in cfg:
        cfg["output_dir"] = paths["out_dir"]
    else:
        sol["output_dir"] = paths["out_dir"]
 
-    #
+    # Pretrained weights in correct block
    if pretrained_path:
        p = os.path.abspath(pretrained_path)
        _maybe_set_model_field(cfg, "pretrain", p)
        _maybe_set_model_field(cfg, "pretrained", p)
 
-    # Save near the template so
+    # Save near the template so any remaining relative includes still resolve
    cfg_out_dir = os.path.join(template_dir, "generated")
    os.makedirs(cfg_out_dir, exist_ok=True)
    out_path = os.path.join(cfg_out_dir, f"{run_name}.yaml")
 
    # Force block style for lists (no inline [a, b, c])
-    class _NoFlowDumper(yaml.SafeDumper):
-        pass
+    class _NoFlowDumper(yaml.SafeDumper): pass
    def _repr_list_block(dumper, data):
        return dumper.represent_sequence('tag:yaml.org,2002:seq', data, flow_style=False)
    _NoFlowDumper.add_representer(list, _repr_list_block)
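To see what the `__include__` filter above keeps and drops: `_absify_any_paths_deep` has already rewritten the template's includes to absolute paths, so matching the `configs/dataset/coco` substring removes only the COCO dataset include. A self-contained sketch with made-up paths:

# Hypothetical include list, as it looks after _absify_any_paths_deep:
cfg = {"__include__": [
    "/repo/rtdetrv2_pytorch/configs/dataset/coco_detection.yml",        # dropped
    "/repo/rtdetrv2_pytorch/configs/runtime.yml",                       # kept
    "/repo/rtdetrv2_pytorch/configs/rtdetrv2/include/optimizer.yml",    # kept
]}
cfg["__include__"] = [
    p for p in cfg["__include__"]
    if not (isinstance(p, str) and "configs/dataset/coco" in p.replace("\\", "/"))
]
print(cfg["__include__"])  # -> only the runtime and optimizer includes survive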
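And a quick self-contained check of the `_NoFlowDumper` trick: registering a block-style representer for `list` makes `yaml.dump` write nested sequences one item per line instead of inline `[a, b, c]`, matching the hand-written template style:

import yaml

class _NoFlowDumper(yaml.SafeDumper):
    pass

def _repr_list_block(dumper, data):
    return dumper.represent_sequence('tag:yaml.org,2002:seq', data, flow_style=False)

_NoFlowDumper.add_representer(list, _repr_list_block)

# A fragment shaped like the generated transforms config:
print(yaml.dump({"ops": [{"type": "Resize", "size": [640, 640]}]},
                Dumper=_NoFlowDumper, sort_keys=False))
# ops:
# - type: Resize
#   size:
#   - 640
#   - 640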
@@ -970,66 +998,4 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
         with gr.TabItem("2. Manage & Merge"):
             gr.Markdown("Rename/merge/remove classes and set per-class image caps. Then finalize.")
             with gr.Row():
-                class_df = gr
-                    datatype=["str","str","number","bool"], label="Class Config", interactive=True, scale=3)
-                with gr.Column(scale=1):
-                    class_count_summary_df = gr.DataFrame(label="Merged Class Counts Preview",
-                                                          headers=["Final Class Name","Est. Total Images"], interactive=False)
-                    update_counts_btn = gr.Button("Update Counts")
-            finalize_btn = gr.Button("Finalize Merged Dataset", variant="primary")
-            finalize_status = gr.Textbox(label="Status", interactive=False)
-
-        with gr.TabItem("3. Configure & Train"):
-            gr.Markdown("Pick RT-DETRv2 model, set hyper-params, press Start.")
-            with gr.Row():
-                with gr.Column(scale=1):
-                    model_dd = gr.Dropdown(choices=[k for k,_ in MODEL_CHOICES], value=DEFAULT_MODEL_KEY,
-                                           label="Model (RT-DETRv2)")
-                    run_name_tb = gr.Textbox(label="Run Name", value="rtdetrv2_run_1")
-                    epochs_sl = gr.Slider(1, 500, 100, step=1, label="Epochs")
-                    batch_sl = gr.Slider(1, 64, 16, step=1, label="Batch Size")
-                    imgsz_num = gr.Number(label="Image Size", value=640)
-                    lr_num = gr.Number(label="Learning Rate", value=0.001)
-                    opt_dd = gr.Dropdown(["Adam","AdamW","SGD"], value="Adam", label="Optimizer")
-                    train_btn = gr.Button("Start Training", variant="primary")
-                with gr.Column(scale=2):
-                    train_status = gr.Textbox(label="Live Logs (tail)", interactive=False, lines=12)
-                    loss_plot = gr.Plot(label="Loss")
-                    map_plot = gr.Plot(label="mAP")
-                    final_model_file = gr.File(label="Download Trained Checkpoint", interactive=False, visible=False)
-
-        with gr.TabItem("4. Upload Model"):
-            gr.Markdown("Optionally push your checkpoint to Hugging Face / GitHub.")
-            with gr.Row():
-                with gr.Column():
-                    gr.Markdown("**Hugging Face**")
-                    hf_token = gr.Textbox(label="HF Token", type="password")
-                    hf_repo = gr.Textbox(label="HF Repo (user/repo)")
-                with gr.Column():
-                    gr.Markdown("**GitHub**")
-                    gh_token = gr.Textbox(label="GitHub PAT", type="password")
-                    gh_repo = gr.Textbox(label="GitHub Repo (user/repo)")
-            upload_btn = gr.Button("Upload", variant="primary")
-            with gr.Row():
-                hf_status = gr.Textbox(label="Hugging Face Status", interactive=False)
-                gh_status = gr.Textbox(label="GitHub Status", interactive=False)
-
-    load_btn.click(load_datasets_handler, [rf_api_key, rf_url_file],
-                   [dataset_status, dataset_info_state, class_df])
-    update_counts_btn.click(update_class_counts_handler, [class_df, dataset_info_state],
-                            [class_count_summary_df])
-    finalize_btn.click(finalize_handler, [dataset_info_state, class_df],
-                       [finalize_status, final_dataset_path_state])
-    train_btn.click(training_handler,
-                    [final_dataset_path_state, model_dd, run_name_tb, epochs_sl, batch_sl, imgsz_num, lr_num, opt_dd],
-                    [train_status, loss_plot, map_plot, final_model_file])
-    upload_btn.click(upload_handler, [final_model_file, hf_token, hf_repo, gh_token, gh_repo],
-                     [hf_status, gh_status])
-
-if __name__ == "__main__":
-    try:
-        ts = find_training_script(REPO_DIR)
-        logging.info(f"Startup check — training script at: {ts}")
-    except Exception as e:
-        logging.warning(f"Startup training-script check failed: {e}")
-    app.launch(debug=True)
+                class_df = gr
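For reference, a hedged sketch of how `training_handler` presumably feeds the Tab-3 values into `patch_base_config`; the template path and merged-dataset layout here are assumptions, and the function is taken to return the path of the generated YAML:

# Hypothetical call inside app.py, mirroring the UI defaults above:
generated_cfg = patch_base_config(
    base_cfg_path="rtdetrv2_pytorch/configs/rtdetrv2/rtdetrv2_r18vd_120e_coco.yml",  # assumed template
    merged_dir="merged",          # expects merged/annotations/instances_train.json etc.
    class_count=3,
    run_name="rtdetrv2_run_1",
    epochs=100,
    batch=16,
    imgsz=640,
    lr=0.001,                     # written to solver.base_lr (or an existing lr key)
    optimizer="Adam",             # stored lower-cased in solver.optimizer
    pretrained_path=None,
)
# The patched config is saved as <template_dir>/generated/rtdetrv2_run_1.yaml,
# with output_dir pointing at runs/train/rtdetrv2_run_1.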