Update app.py
Browse files
app.py
CHANGED
|
@@ -573,7 +573,7 @@ def _maybe_set_model_field(cfg: dict, key: str, value):
|
|
| 573 |
return
|
| 574 |
cfg[key] = value # fallback
|
| 575 |
|
| 576 |
-
# --- CRITICAL
|
| 577 |
def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
|
| 578 |
epochs, batch, imgsz, lr, optimizer, pretrained_path: str | None):
|
| 579 |
if not base_cfg_path or not os.path.exists(base_cfg_path):
|
|
@@ -586,9 +586,16 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
|
|
| 586 |
cfg = yaml.safe_load(f)
|
| 587 |
_absify_any_paths_deep(cfg, template_dir)
|
| 588 |
|
| 589 |
-
#
|
| 590 |
cfg["sync_bn"] = False
|
| 591 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 592 |
ann_dir = os.path.join(merged_dir, "annotations")
|
| 593 |
paths = {
|
| 594 |
"train_json": os.path.abspath(os.path.join(ann_dir, "instances_train.json")),
|
|
@@ -600,15 +607,7 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
|
|
| 600 |
"out_dir": os.path.abspath(os.path.join("runs", "train", run_name)),
|
| 601 |
}
|
| 602 |
|
| 603 |
-
#
|
| 604 |
-
inc_key = "__include__"
|
| 605 |
-
if inc_key in cfg and isinstance(cfg[inc_key], list):
|
| 606 |
-
cfg[inc_key] = [
|
| 607 |
-
p for p in cfg[inc_key]
|
| 608 |
-
if not (isinstance(p, str) and "configs/dataset/coco" in p.replace("\\", "/"))
|
| 609 |
-
]
|
| 610 |
-
|
| 611 |
-
# Helper to ensure & patch dataloaders
|
| 612 |
def ensure_and_patch_dl(dl_key, img_key, json_key, default_shuffle):
|
| 613 |
block = cfg.get(dl_key)
|
| 614 |
if not isinstance(block, dict):
|
|
@@ -634,15 +633,18 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
|
|
| 634 |
"total_batch_size": int(batch),
|
| 635 |
}
|
| 636 |
cfg[dl_key] = block
|
|
|
|
|
|
|
| 637 |
ds = block.get("dataset", {})
|
| 638 |
if isinstance(ds, dict):
|
| 639 |
ds["img_folder"] = paths[img_key]
|
| 640 |
-
ds["ann_file"]
|
| 641 |
for k in ("img_dir", "image_root", "data_root"):
|
| 642 |
if k in ds: ds[k] = paths[img_key]
|
| 643 |
for k in ("ann_path", "annotation", "annotations"):
|
| 644 |
if k in ds: ds[k] = paths[json_key]
|
| 645 |
block["dataset"] = ds
|
|
|
|
| 646 |
block["total_batch_size"] = int(batch)
|
| 647 |
block.setdefault("num_workers", 2)
|
| 648 |
block.setdefault("shuffle", bool(default_shuffle))
|
|
@@ -650,57 +652,62 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
|
|
| 650 |
|
| 651 |
ensure_and_patch_dl("train_dataloader", "train_img", "train_json", default_shuffle=True)
|
| 652 |
ensure_and_patch_dl("val_dataloader", "val_img", "val_json", default_shuffle=False)
|
| 653 |
-
# Optional test loader
|
| 654 |
# ensure_and_patch_dl("test_dataloader", "test_img", "test_json", default_shuffle=False)
|
| 655 |
|
| 656 |
-
#
|
| 657 |
_set_num_classes_safely(cfg, int(class_count))
|
| 658 |
|
| 659 |
-
#
|
| 660 |
applied_epoch = False
|
| 661 |
for key in ("epoches", "max_epoch", "epochs", "num_epochs"):
|
| 662 |
if key in cfg:
|
| 663 |
-
cfg[key] = int(epochs)
|
|
|
|
|
|
|
| 664 |
if "solver" in cfg and isinstance(cfg["solver"], dict):
|
| 665 |
for key in ("epoches", "max_epoch", "epochs", "num_epochs"):
|
| 666 |
if key in cfg["solver"]:
|
| 667 |
-
cfg["solver"][key] = int(epochs)
|
|
|
|
|
|
|
| 668 |
if not applied_epoch:
|
| 669 |
cfg["epoches"] = int(epochs)
|
| 670 |
cfg["input_size"] = int(imgsz)
|
| 671 |
|
| 672 |
-
#
|
| 673 |
if "solver" not in cfg or not isinstance(cfg["solver"], dict):
|
| 674 |
cfg["solver"] = {}
|
| 675 |
sol = cfg["solver"]
|
| 676 |
for key in ("base_lr", "lr", "learning_rate"):
|
| 677 |
if key in sol:
|
| 678 |
-
sol[key] = float(lr)
|
|
|
|
| 679 |
else:
|
| 680 |
sol["base_lr"] = float(lr)
|
| 681 |
sol["optimizer"] = str(optimizer).lower()
|
| 682 |
if "train_dataloader" not in cfg or not isinstance(cfg["train_dataloader"], dict):
|
| 683 |
sol["batch_size"] = int(batch)
|
| 684 |
|
| 685 |
-
#
|
| 686 |
if "output_dir" in cfg:
|
| 687 |
cfg["output_dir"] = paths["out_dir"]
|
| 688 |
else:
|
| 689 |
sol["output_dir"] = paths["out_dir"]
|
| 690 |
|
| 691 |
-
#
|
| 692 |
if pretrained_path:
|
| 693 |
p = os.path.abspath(pretrained_path)
|
| 694 |
_maybe_set_model_field(cfg, "pretrain", p)
|
| 695 |
_maybe_set_model_field(cfg, "pretrained", p)
|
| 696 |
|
| 697 |
-
# Save near the template so
|
| 698 |
cfg_out_dir = os.path.join(template_dir, "generated")
|
| 699 |
os.makedirs(cfg_out_dir, exist_ok=True)
|
| 700 |
out_path = os.path.join(cfg_out_dir, f"{run_name}.yaml")
|
| 701 |
|
| 702 |
# Force block style for lists (no inline [a, b, c])
|
| 703 |
-
class _NoFlowDumper(yaml.SafeDumper):
|
| 704 |
def _repr_list_block(dumper, data):
|
| 705 |
return dumper.represent_sequence('tag:yaml.org,2002:seq', data, flow_style=False)
|
| 706 |
_NoFlowDumper.add_representer(list, _repr_list_block)
|
|
@@ -998,4 +1005,66 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
|
|
| 998 |
with gr.TabItem("2. Manage & Merge"):
|
| 999 |
gr.Markdown("Rename/merge/remove classes and set per-class image caps. Then finalize.")
|
| 1000 |
with gr.Row():
|
| 1001 |
-
class_df = gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 573 |
return
|
| 574 |
cfg[key] = value # fallback
|
| 575 |
|
| 576 |
+
# --- CRITICAL: dataset override + include cleanup + sync_bn off ---------------
|
| 577 |
def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
|
| 578 |
epochs, batch, imgsz, lr, optimizer, pretrained_path: str | None):
|
| 579 |
if not base_cfg_path or not os.path.exists(base_cfg_path):
|
|
|
|
| 586 |
cfg = yaml.safe_load(f)
|
| 587 |
_absify_any_paths_deep(cfg, template_dir)
|
| 588 |
|
| 589 |
+
# Disable SyncBN for single GPU/CPU runs
|
| 590 |
cfg["sync_bn"] = False
|
| 591 |
|
| 592 |
+
# Remove COCO dataset include so it can't override our dataset paths later
|
| 593 |
+
if "__include__" in cfg and isinstance(cfg["__include__"], list):
|
| 594 |
+
cfg["__include__"] = [
|
| 595 |
+
p for p in cfg["__include__"]
|
| 596 |
+
if not (isinstance(p, str) and "configs/dataset/coco" in p.replace("\\", "/"))
|
| 597 |
+
]
|
| 598 |
+
|
| 599 |
ann_dir = os.path.join(merged_dir, "annotations")
|
| 600 |
paths = {
|
| 601 |
"train_json": os.path.abspath(os.path.join(ann_dir, "instances_train.json")),
|
|
|
|
| 607 |
"out_dir": os.path.abspath(os.path.join("runs", "train", run_name)),
|
| 608 |
}
|
| 609 |
|
| 610 |
+
# Ensure/patch dataloaders to point to our dataset
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 611 |
def ensure_and_patch_dl(dl_key, img_key, json_key, default_shuffle):
|
| 612 |
block = cfg.get(dl_key)
|
| 613 |
if not isinstance(block, dict):
|
|
|
|
| 633 |
"total_batch_size": int(batch),
|
| 634 |
}
|
| 635 |
cfg[dl_key] = block
|
| 636 |
+
|
| 637 |
+
# Patch existing block
|
| 638 |
ds = block.get("dataset", {})
|
| 639 |
if isinstance(ds, dict):
|
| 640 |
ds["img_folder"] = paths[img_key]
|
| 641 |
+
ds["ann_file"] = paths[json_key]
|
| 642 |
for k in ("img_dir", "image_root", "data_root"):
|
| 643 |
if k in ds: ds[k] = paths[img_key]
|
| 644 |
for k in ("ann_path", "annotation", "annotations"):
|
| 645 |
if k in ds: ds[k] = paths[json_key]
|
| 646 |
block["dataset"] = ds
|
| 647 |
+
|
| 648 |
block["total_batch_size"] = int(batch)
|
| 649 |
block.setdefault("num_workers", 2)
|
| 650 |
block.setdefault("shuffle", bool(default_shuffle))
|
|
|
|
| 652 |
|
| 653 |
ensure_and_patch_dl("train_dataloader", "train_img", "train_json", default_shuffle=True)
|
| 654 |
ensure_and_patch_dl("val_dataloader", "val_img", "val_json", default_shuffle=False)
|
| 655 |
+
# Optional test loader
|
| 656 |
# ensure_and_patch_dl("test_dataloader", "test_img", "test_json", default_shuffle=False)
|
| 657 |
|
| 658 |
+
# num classes (handles model: "RTDETR")
|
| 659 |
_set_num_classes_safely(cfg, int(class_count))
|
| 660 |
|
| 661 |
+
# epochs / imgsz
|
| 662 |
applied_epoch = False
|
| 663 |
for key in ("epoches", "max_epoch", "epochs", "num_epochs"):
|
| 664 |
if key in cfg:
|
| 665 |
+
cfg[key] = int(epochs)
|
| 666 |
+
applied_epoch = True
|
| 667 |
+
break
|
| 668 |
if "solver" in cfg and isinstance(cfg["solver"], dict):
|
| 669 |
for key in ("epoches", "max_epoch", "epochs", "num_epochs"):
|
| 670 |
if key in cfg["solver"]:
|
| 671 |
+
cfg["solver"][key] = int(epochs)
|
| 672 |
+
applied_epoch = True
|
| 673 |
+
break
|
| 674 |
if not applied_epoch:
|
| 675 |
cfg["epoches"] = int(epochs)
|
| 676 |
cfg["input_size"] = int(imgsz)
|
| 677 |
|
| 678 |
+
# lr / optimizer / batch
|
| 679 |
if "solver" not in cfg or not isinstance(cfg["solver"], dict):
|
| 680 |
cfg["solver"] = {}
|
| 681 |
sol = cfg["solver"]
|
| 682 |
for key in ("base_lr", "lr", "learning_rate"):
|
| 683 |
if key in sol:
|
| 684 |
+
sol[key] = float(lr)
|
| 685 |
+
break
|
| 686 |
else:
|
| 687 |
sol["base_lr"] = float(lr)
|
| 688 |
sol["optimizer"] = str(optimizer).lower()
|
| 689 |
if "train_dataloader" not in cfg or not isinstance(cfg["train_dataloader"], dict):
|
| 690 |
sol["batch_size"] = int(batch)
|
| 691 |
|
| 692 |
+
# output dir
|
| 693 |
if "output_dir" in cfg:
|
| 694 |
cfg["output_dir"] = paths["out_dir"]
|
| 695 |
else:
|
| 696 |
sol["output_dir"] = paths["out_dir"]
|
| 697 |
|
| 698 |
+
# pretrained weights in the right model block
|
| 699 |
if pretrained_path:
|
| 700 |
p = os.path.abspath(pretrained_path)
|
| 701 |
_maybe_set_model_field(cfg, "pretrain", p)
|
| 702 |
_maybe_set_model_field(cfg, "pretrained", p)
|
| 703 |
|
| 704 |
+
# Save near the template so internal relative references still make sense
|
| 705 |
cfg_out_dir = os.path.join(template_dir, "generated")
|
| 706 |
os.makedirs(cfg_out_dir, exist_ok=True)
|
| 707 |
out_path = os.path.join(cfg_out_dir, f"{run_name}.yaml")
|
| 708 |
|
| 709 |
# Force block style for lists (no inline [a, b, c])
|
| 710 |
+
class _NoFlowDumper(yaml.SafeDumper): ...
|
| 711 |
def _repr_list_block(dumper, data):
|
| 712 |
return dumper.represent_sequence('tag:yaml.org,2002:seq', data, flow_style=False)
|
| 713 |
_NoFlowDumper.add_representer(list, _repr_list_block)
|
|
|
|
| 1005 |
with gr.TabItem("2. Manage & Merge"):
|
| 1006 |
gr.Markdown("Rename/merge/remove classes and set per-class image caps. Then finalize.")
|
| 1007 |
with gr.Row():
|
| 1008 |
+
class_df = gr.DataFrame(headers=["Original Name","Rename To","Max Images","Remove"],
|
| 1009 |
+
datatype=["str","str","number","bool"], label="Class Config", interactive=True, scale=3)
|
| 1010 |
+
with gr.Column(scale=1):
|
| 1011 |
+
class_count_summary_df = gr.DataFrame(label="Merged Class Counts Preview",
|
| 1012 |
+
headers=["Final Class Name","Est. Total Images"], interactive=False)
|
| 1013 |
+
update_counts_btn = gr.Button("Update Counts")
|
| 1014 |
+
finalize_btn = gr.Button("Finalize Merged Dataset", variant="primary")
|
| 1015 |
+
finalize_status = gr.Textbox(label="Status", interactive=False)
|
| 1016 |
+
|
| 1017 |
+
with gr.TabItem("3. Configure & Train"):
|
| 1018 |
+
gr.Markdown("Pick RT-DETRv2 model, set hyper-params, press Start.")
|
| 1019 |
+
with gr.Row():
|
| 1020 |
+
with gr.Column(scale=1):
|
| 1021 |
+
model_dd = gr.Dropdown(choices=[k for k,_ in MODEL_CHOICES], value=DEFAULT_MODEL_KEY,
|
| 1022 |
+
label="Model (RT-DETRv2)")
|
| 1023 |
+
run_name_tb = gr.Textbox(label="Run Name", value="rtdetrv2_run_1")
|
| 1024 |
+
epochs_sl = gr.Slider(1, 500, 100, step=1, label="Epochs")
|
| 1025 |
+
batch_sl = gr.Slider(1, 64, 16, step=1, label="Batch Size")
|
| 1026 |
+
imgsz_num = gr.Number(label="Image Size", value=640)
|
| 1027 |
+
lr_num = gr.Number(label="Learning Rate", value=0.001)
|
| 1028 |
+
opt_dd = gr.Dropdown(["Adam","AdamW","SGD"], value="Adam", label="Optimizer")
|
| 1029 |
+
train_btn = gr.Button("Start Training", variant="primary")
|
| 1030 |
+
with gr.Column(scale=2):
|
| 1031 |
+
train_status = gr.Textbox(label="Live Logs (tail)", interactive=False, lines=12)
|
| 1032 |
+
loss_plot = gr.Plot(label="Loss")
|
| 1033 |
+
map_plot = gr.Plot(label="mAP")
|
| 1034 |
+
final_model_file = gr.File(label="Download Trained Checkpoint", interactive=False, visible=False)
|
| 1035 |
+
|
| 1036 |
+
with gr.TabItem("4. Upload Model"):
|
| 1037 |
+
gr.Markdown("Optionally push your checkpoint to Hugging Face / GitHub.")
|
| 1038 |
+
with gr.Row():
|
| 1039 |
+
with gr.Column():
|
| 1040 |
+
gr.Markdown("**Hugging Face**")
|
| 1041 |
+
hf_token = gr.Textbox(label="HF Token", type="password")
|
| 1042 |
+
hf_repo = gr.Textbox(label="HF Repo (user/repo)")
|
| 1043 |
+
with gr.Column():
|
| 1044 |
+
gr.Markdown("**GitHub**")
|
| 1045 |
+
gh_token = gr.Textbox(label="GitHub PAT", type="password")
|
| 1046 |
+
gh_repo = gr.Textbox(label="GitHub Repo (user/repo)")
|
| 1047 |
+
upload_btn = gr.Button("Upload", variant="primary")
|
| 1048 |
+
with gr.Row():
|
| 1049 |
+
hf_status = gr.Textbox(label="Hugging Face Status", interactive=False)
|
| 1050 |
+
gh_status = gr.Textbox(label="GitHub Status", interactive=False)
|
| 1051 |
+
|
| 1052 |
+
load_btn.click(load_datasets_handler, [rf_api_key, rf_url_file],
|
| 1053 |
+
[dataset_status, dataset_info_state, class_df])
|
| 1054 |
+
update_counts_btn.click(update_class_counts_handler, [class_df, dataset_info_state],
|
| 1055 |
+
[class_count_summary_df])
|
| 1056 |
+
finalize_btn.click(finalize_handler, [dataset_info_state, class_df],
|
| 1057 |
+
[finalize_status, final_dataset_path_state])
|
| 1058 |
+
train_btn.click(training_handler,
|
| 1059 |
+
[final_dataset_path_state, model_dd, run_name_tb, epochs_sl, batch_sl, imgsz_num, lr_num, opt_dd],
|
| 1060 |
+
[train_status, loss_plot, map_plot, final_model_file])
|
| 1061 |
+
upload_btn.click(upload_handler, [final_model_file, hf_token, hf_repo, gh_token, gh_repo],
|
| 1062 |
+
[hf_status, gh_status])
|
| 1063 |
+
|
| 1064 |
+
if __name__ == "__main__":
|
| 1065 |
+
try:
|
| 1066 |
+
ts = find_training_script(REPO_DIR)
|
| 1067 |
+
logging.info(f"Startup check — training script at: {ts}")
|
| 1068 |
+
except Exception as e:
|
| 1069 |
+
logging.warning(f"Startup training-script check failed: {e}")
|
| 1070 |
+
app.launch(debug=True)
|