wuhp committed
Commit 04a0a1c · verified · 1 Parent(s): d193c16

Update app.py

Files changed (1):
  1. app.py +93 -24

app.py CHANGED
@@ -573,7 +573,7 @@ def _maybe_set_model_field(cfg: dict, key: str, value):
         return
     cfg[key] = value  # fallback
 
-# --- CRITICAL FIX: force custom dataloaders & disable sync_bn -----------------
+# --- CRITICAL: dataset override + include cleanup + sync_bn off ---------------
 def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
                       epochs, batch, imgsz, lr, optimizer, pretrained_path: str | None):
     if not base_cfg_path or not os.path.exists(base_cfg_path):
@@ -586,9 +586,16 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
         cfg = yaml.safe_load(f)
     _absify_any_paths_deep(cfg, template_dir)
 
-    # Safer on single GPU/CPU
+    # Disable SyncBN for single GPU/CPU runs
     cfg["sync_bn"] = False
 
+    # Remove COCO dataset include so it can't override our dataset paths later
+    if "__include__" in cfg and isinstance(cfg["__include__"], list):
+        cfg["__include__"] = [
+            p for p in cfg["__include__"]
+            if not (isinstance(p, str) and "configs/dataset/coco" in p.replace("\\", "/"))
+        ]
+
     ann_dir = os.path.join(merged_dir, "annotations")
     paths = {
         "train_json": os.path.abspath(os.path.join(ann_dir, "instances_train.json")),
@@ -600,15 +607,7 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
         "out_dir": os.path.abspath(os.path.join("runs", "train", run_name)),
     }
 
-    # Remove COCO dataset include so it can't override our dataloaders
-    inc_key = "__include__"
-    if inc_key in cfg and isinstance(cfg[inc_key], list):
-        cfg[inc_key] = [
-            p for p in cfg[inc_key]
-            if not (isinstance(p, str) and "configs/dataset/coco" in p.replace("\\", "/"))
-        ]
-
-    # Helper to ensure & patch dataloaders
+    # Ensure/patch dataloaders to point to our dataset
     def ensure_and_patch_dl(dl_key, img_key, json_key, default_shuffle):
         block = cfg.get(dl_key)
         if not isinstance(block, dict):
@@ -634,15 +633,18 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
                 "total_batch_size": int(batch),
             }
             cfg[dl_key] = block
+
+        # Patch existing block
         ds = block.get("dataset", {})
         if isinstance(ds, dict):
             ds["img_folder"] = paths[img_key]
-        ds["ann_file"] = paths[json_key]
+            ds["ann_file"] = paths[json_key]
             for k in ("img_dir", "image_root", "data_root"):
                 if k in ds: ds[k] = paths[img_key]
             for k in ("ann_path", "annotation", "annotations"):
                 if k in ds: ds[k] = paths[json_key]
         block["dataset"] = ds
+
         block["total_batch_size"] = int(batch)
         block.setdefault("num_workers", 2)
         block.setdefault("shuffle", bool(default_shuffle))
@@ -650,57 +652,62 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
 
     ensure_and_patch_dl("train_dataloader", "train_img", "train_json", default_shuffle=True)
     ensure_and_patch_dl("val_dataloader", "val_img", "val_json", default_shuffle=False)
-    # Optional test loader if needed:
+    # Optional test loader
     # ensure_and_patch_dl("test_dataloader", "test_img", "test_json", default_shuffle=False)
 
-    # Classes
+    # num classes (handles model: "RTDETR")
     _set_num_classes_safely(cfg, int(class_count))
 
-    # Epochs / imgsz
+    # epochs / imgsz
     applied_epoch = False
     for key in ("epoches", "max_epoch", "epochs", "num_epochs"):
         if key in cfg:
-            cfg[key] = int(epochs); applied_epoch = True; break
+            cfg[key] = int(epochs)
+            applied_epoch = True
+            break
     if "solver" in cfg and isinstance(cfg["solver"], dict):
         for key in ("epoches", "max_epoch", "epochs", "num_epochs"):
             if key in cfg["solver"]:
-                cfg["solver"][key] = int(epochs); applied_epoch = True; break
+                cfg["solver"][key] = int(epochs)
+                applied_epoch = True
+                break
     if not applied_epoch:
         cfg["epoches"] = int(epochs)
     cfg["input_size"] = int(imgsz)
 
-    # LR / optimizer / batch fallbacks
+    # lr / optimizer / batch
    if "solver" not in cfg or not isinstance(cfg["solver"], dict):
         cfg["solver"] = {}
     sol = cfg["solver"]
     for key in ("base_lr", "lr", "learning_rate"):
         if key in sol:
-            sol[key] = float(lr); break
+            sol[key] = float(lr)
+            break
     else:
         sol["base_lr"] = float(lr)
     sol["optimizer"] = str(optimizer).lower()
     if "train_dataloader" not in cfg or not isinstance(cfg["train_dataloader"], dict):
         sol["batch_size"] = int(batch)
 
-    # Output dir
+    # output dir
     if "output_dir" in cfg:
         cfg["output_dir"] = paths["out_dir"]
     else:
         sol["output_dir"] = paths["out_dir"]
 
-    # Pretrained weights in correct block
+    # pretrained weights in the right model block
     if pretrained_path:
         p = os.path.abspath(pretrained_path)
         _maybe_set_model_field(cfg, "pretrain", p)
         _maybe_set_model_field(cfg, "pretrained", p)
 
-    # Save near the template so any remaining relative includes still resolve
+    # Save near the template so internal relative references still make sense
     cfg_out_dir = os.path.join(template_dir, "generated")
     os.makedirs(cfg_out_dir, exist_ok=True)
     out_path = os.path.join(cfg_out_dir, f"{run_name}.yaml")
 
     # Force block style for lists (no inline [a, b, c])
-    class _NoFlowDumper(yaml.SafeDumper): pass
+    class _NoFlowDumper(yaml.SafeDumper): ...
     def _repr_list_block(dumper, data):
         return dumper.represent_sequence('tag:yaml.org,2002:seq', data, flow_style=False)
     _NoFlowDumper.add_representer(list, _repr_list_block)
@@ -998,4 +1005,66 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
         with gr.TabItem("2. Manage & Merge"):
             gr.Markdown("Rename/merge/remove classes and set per-class image caps. Then finalize.")
             with gr.Row():
-                class_df = gr
+                class_df = gr.DataFrame(headers=["Original Name","Rename To","Max Images","Remove"],
+                                        datatype=["str","str","number","bool"], label="Class Config", interactive=True, scale=3)
+                with gr.Column(scale=1):
+                    class_count_summary_df = gr.DataFrame(label="Merged Class Counts Preview",
+                                                          headers=["Final Class Name","Est. Total Images"], interactive=False)
+                    update_counts_btn = gr.Button("Update Counts")
+                    finalize_btn = gr.Button("Finalize Merged Dataset", variant="primary")
+                    finalize_status = gr.Textbox(label="Status", interactive=False)
+
+        with gr.TabItem("3. Configure & Train"):
+            gr.Markdown("Pick RT-DETRv2 model, set hyper-params, press Start.")
+            with gr.Row():
+                with gr.Column(scale=1):
+                    model_dd = gr.Dropdown(choices=[k for k,_ in MODEL_CHOICES], value=DEFAULT_MODEL_KEY,
+                                           label="Model (RT-DETRv2)")
+                    run_name_tb = gr.Textbox(label="Run Name", value="rtdetrv2_run_1")
+                    epochs_sl = gr.Slider(1, 500, 100, step=1, label="Epochs")
+                    batch_sl = gr.Slider(1, 64, 16, step=1, label="Batch Size")
+                    imgsz_num = gr.Number(label="Image Size", value=640)
+                    lr_num = gr.Number(label="Learning Rate", value=0.001)
+                    opt_dd = gr.Dropdown(["Adam","AdamW","SGD"], value="Adam", label="Optimizer")
+                    train_btn = gr.Button("Start Training", variant="primary")
+                with gr.Column(scale=2):
+                    train_status = gr.Textbox(label="Live Logs (tail)", interactive=False, lines=12)
+                    loss_plot = gr.Plot(label="Loss")
+                    map_plot = gr.Plot(label="mAP")
+                    final_model_file = gr.File(label="Download Trained Checkpoint", interactive=False, visible=False)
+
+        with gr.TabItem("4. Upload Model"):
+            gr.Markdown("Optionally push your checkpoint to Hugging Face / GitHub.")
+            with gr.Row():
+                with gr.Column():
+                    gr.Markdown("**Hugging Face**")
+                    hf_token = gr.Textbox(label="HF Token", type="password")
+                    hf_repo = gr.Textbox(label="HF Repo (user/repo)")
+                with gr.Column():
+                    gr.Markdown("**GitHub**")
+                    gh_token = gr.Textbox(label="GitHub PAT", type="password")
+                    gh_repo = gr.Textbox(label="GitHub Repo (user/repo)")
+            upload_btn = gr.Button("Upload", variant="primary")
+            with gr.Row():
+                hf_status = gr.Textbox(label="Hugging Face Status", interactive=False)
+                gh_status = gr.Textbox(label="GitHub Status", interactive=False)
+
+    load_btn.click(load_datasets_handler, [rf_api_key, rf_url_file],
+                   [dataset_status, dataset_info_state, class_df])
+    update_counts_btn.click(update_class_counts_handler, [class_df, dataset_info_state],
+                            [class_count_summary_df])
+    finalize_btn.click(finalize_handler, [dataset_info_state, class_df],
+                       [finalize_status, final_dataset_path_state])
+    train_btn.click(training_handler,
+                    [final_dataset_path_state, model_dd, run_name_tb, epochs_sl, batch_sl, imgsz_num, lr_num, opt_dd],
+                    [train_status, loss_plot, map_plot, final_model_file])
+    upload_btn.click(upload_handler, [final_model_file, hf_token, hf_repo, gh_token, gh_repo],
+                     [hf_status, gh_status])
+
+if __name__ == "__main__":
+    try:
+        ts = find_training_script(REPO_DIR)
+        logging.info(f"Startup check — training script at: {ts}")
+    except Exception as e:
+        logging.warning(f"Startup training-script check failed: {e}")
+    app.launch(debug=True)
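
Note on the list-dump trick near the end of patch_base_config: registering a representer on a yaml.SafeDumper subclass is what forces every list into block style. A minimal standalone sketch of the same technique follows; the config keys and the include filename in it are made-up placeholders, not the app's real config.

import yaml

class _NoFlowDumper(yaml.SafeDumper):
    ...  # empty subclass so the representer stays local to this dumper

def _repr_list_block(dumper, data):
    # Emit lists as block sequences ("- item") instead of inline [a, b, c]
    return dumper.represent_sequence('tag:yaml.org,2002:seq', data, flow_style=False)

_NoFlowDumper.add_representer(list, _repr_list_block)

cfg = {"epoches": 100, "__include__": ["../rtdetrv2_r18vd.yml"]}  # placeholder data
print(yaml.dump(cfg, Dumper=_NoFlowDumper, sort_keys=False))
# epoches: 100
# __include__:
# - ../rtdetrv2_r18vd.yml

PyYAML's add_representer copies the representer table into the subclass, so plain yaml.dump and yaml.safe_dump elsewhere keep their default flow-style behavior; that is presumably why the patch defines _NoFlowDumper inside the function rather than mutating yaml.SafeDumper globally.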