wuhp committed on
Commit
d193c16
·
verified ·
1 Parent(s): ff8714f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -103
app.py CHANGED
@@ -573,6 +573,7 @@ def _maybe_set_model_field(cfg: dict, key: str, value):
573
  return
574
  cfg[key] = value # fallback
575
 
 
576
  def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
577
  epochs, batch, imgsz, lr, optimizer, pretrained_path: str | None):
578
  if not base_cfg_path or not os.path.exists(base_cfg_path):
@@ -585,6 +586,9 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
585
  cfg = yaml.safe_load(f)
586
  _absify_any_paths_deep(cfg, template_dir)
587
 
 
 
 
588
  ann_dir = os.path.join(merged_dir, "annotations")
589
  paths = {
590
  "train_json": os.path.abspath(os.path.join(ann_dir, "instances_train.json")),
@@ -596,83 +600,107 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
596
  "out_dir": os.path.abspath(os.path.join("runs", "train", run_name)),
597
  }
598
 
599
- # --- Rewrite dataloaders to use your dataset ---
600
- def _patch_dl(dl_key, img_key, json_key):
601
- if dl_key in cfg and isinstance(cfg[dl_key], dict):
602
- ds = cfg[dl_key].get("dataset", {})
603
- if isinstance(ds, dict):
604
- if "img_folder" in ds: ds["img_folder"] = paths[img_key]
605
- if "ann_file" in ds: ds["ann_file"] = paths[json_key]
606
- # alternative key names occasionally used
607
- for k in ("img_dir", "image_root", "data_root"):
608
- if k in ds: ds[k] = paths[img_key]
609
- for k in ("ann_path", "annotation", "annotations"):
610
- if k in ds: ds[k] = paths[json_key]
611
- cfg[dl_key]["dataset"] = ds
612
- # batch size here if present
613
- if "batch_size" in cfg[dl_key]:
614
- cfg[dl_key]["batch_size"] = int(batch)
615
-
616
- _patch_dl("train_dataloader", "train_img", "train_json")
617
- _patch_dl("val_dataloader", "val_img", "val_json")
618
- _patch_dl("test_dataloader", "test_img", "test_json")
619
-
620
- # --- classes ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
621
  _set_num_classes_safely(cfg, int(class_count))
622
 
623
- # --- epochs / imgsz ---
624
  applied_epoch = False
625
  for key in ("epoches", "max_epoch", "epochs", "num_epochs"):
626
  if key in cfg:
627
- cfg[key] = int(epochs)
628
- applied_epoch = True
629
- break
630
  if "solver" in cfg and isinstance(cfg["solver"], dict):
631
  for key in ("epoches", "max_epoch", "epochs", "num_epochs"):
632
  if key in cfg["solver"]:
633
- cfg["solver"][key] = int(epochs)
634
- applied_epoch = True
635
- break
636
  if not applied_epoch:
637
- cfg["epoches"] = int(epochs) # common in this repo
638
-
639
- # image size knobs: unify on top-level input_size (respected by templates)
640
  cfg["input_size"] = int(imgsz)
641
 
642
- # --- lr / optimizer / batch fallbacks ---
643
  if "solver" not in cfg or not isinstance(cfg["solver"], dict):
644
  cfg["solver"] = {}
645
  sol = cfg["solver"]
646
  for key in ("base_lr", "lr", "learning_rate"):
647
  if key in sol:
648
- sol[key] = float(lr)
649
- break
650
  else:
651
  sol["base_lr"] = float(lr)
652
  sol["optimizer"] = str(optimizer).lower()
653
  if "train_dataloader" not in cfg or not isinstance(cfg["train_dataloader"], dict):
654
  sol["batch_size"] = int(batch)
655
 
656
- # output dir
657
  if "output_dir" in cfg:
658
  cfg["output_dir"] = paths["out_dir"]
659
  else:
660
  sol["output_dir"] = paths["out_dir"]
661
 
662
- # pretrained weights in the right model block
663
  if pretrained_path:
664
  p = os.path.abspath(pretrained_path)
665
  _maybe_set_model_field(cfg, "pretrain", p)
666
  _maybe_set_model_field(cfg, "pretrained", p)
667
 
668
- # Save near the template so internal relative references still make sense
669
  cfg_out_dir = os.path.join(template_dir, "generated")
670
  os.makedirs(cfg_out_dir, exist_ok=True)
671
  out_path = os.path.join(cfg_out_dir, f"{run_name}.yaml")
672
 
673
  # Force block style for lists (no inline [a, b, c])
674
- class _NoFlowDumper(yaml.SafeDumper):
675
- pass
676
  def _repr_list_block(dumper, data):
677
  return dumper.represent_sequence('tag:yaml.org,2002:seq', data, flow_style=False)
678
  _NoFlowDumper.add_representer(list, _repr_list_block)
@@ -970,66 +998,4 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
970
  with gr.TabItem("2. Manage & Merge"):
971
  gr.Markdown("Rename/merge/remove classes and set per-class image caps. Then finalize.")
972
  with gr.Row():
973
- class_df = gr.DataFrame(headers=["Original Name","Rename To","Max Images","Remove"],
974
- datatype=["str","str","number","bool"], label="Class Config", interactive=True, scale=3)
975
- with gr.Column(scale=1):
976
- class_count_summary_df = gr.DataFrame(label="Merged Class Counts Preview",
977
- headers=["Final Class Name","Est. Total Images"], interactive=False)
978
- update_counts_btn = gr.Button("Update Counts")
979
- finalize_btn = gr.Button("Finalize Merged Dataset", variant="primary")
980
- finalize_status = gr.Textbox(label="Status", interactive=False)
981
-
982
- with gr.TabItem("3. Configure & Train"):
983
- gr.Markdown("Pick RT-DETRv2 model, set hyper-params, press Start.")
984
- with gr.Row():
985
- with gr.Column(scale=1):
986
- model_dd = gr.Dropdown(choices=[k for k,_ in MODEL_CHOICES], value=DEFAULT_MODEL_KEY,
987
- label="Model (RT-DETRv2)")
988
- run_name_tb = gr.Textbox(label="Run Name", value="rtdetrv2_run_1")
989
- epochs_sl = gr.Slider(1, 500, 100, step=1, label="Epochs")
990
- batch_sl = gr.Slider(1, 64, 16, step=1, label="Batch Size")
991
- imgsz_num = gr.Number(label="Image Size", value=640)
992
- lr_num = gr.Number(label="Learning Rate", value=0.001)
993
- opt_dd = gr.Dropdown(["Adam","AdamW","SGD"], value="Adam", label="Optimizer")
994
- train_btn = gr.Button("Start Training", variant="primary")
995
- with gr.Column(scale=2):
996
- train_status = gr.Textbox(label="Live Logs (tail)", interactive=False, lines=12)
997
- loss_plot = gr.Plot(label="Loss")
998
- map_plot = gr.Plot(label="mAP")
999
- final_model_file = gr.File(label="Download Trained Checkpoint", interactive=False, visible=False)
1000
-
1001
- with gr.TabItem("4. Upload Model"):
1002
- gr.Markdown("Optionally push your checkpoint to Hugging Face / GitHub.")
1003
- with gr.Row():
1004
- with gr.Column():
1005
- gr.Markdown("**Hugging Face**")
1006
- hf_token = gr.Textbox(label="HF Token", type="password")
1007
- hf_repo = gr.Textbox(label="HF Repo (user/repo)")
1008
- with gr.Column():
1009
- gr.Markdown("**GitHub**")
1010
- gh_token = gr.Textbox(label="GitHub PAT", type="password")
1011
- gh_repo = gr.Textbox(label="GitHub Repo (user/repo)")
1012
- upload_btn = gr.Button("Upload", variant="primary")
1013
- with gr.Row():
1014
- hf_status = gr.Textbox(label="Hugging Face Status", interactive=False)
1015
- gh_status = gr.Textbox(label="GitHub Status", interactive=False)
1016
-
1017
- load_btn.click(load_datasets_handler, [rf_api_key, rf_url_file],
1018
- [dataset_status, dataset_info_state, class_df])
1019
- update_counts_btn.click(update_class_counts_handler, [class_df, dataset_info_state],
1020
- [class_count_summary_df])
1021
- finalize_btn.click(finalize_handler, [dataset_info_state, class_df],
1022
- [finalize_status, final_dataset_path_state])
1023
- train_btn.click(training_handler,
1024
- [final_dataset_path_state, model_dd, run_name_tb, epochs_sl, batch_sl, imgsz_num, lr_num, opt_dd],
1025
- [train_status, loss_plot, map_plot, final_model_file])
1026
- upload_btn.click(upload_handler, [final_model_file, hf_token, hf_repo, gh_token, gh_repo],
1027
- [hf_status, gh_status])
1028
-
1029
- if __name__ == "__main__":
1030
- try:
1031
- ts = find_training_script(REPO_DIR)
1032
- logging.info(f"Startup check — training script at: {ts}")
1033
- except Exception as e:
1034
- logging.warning(f"Startup training-script check failed: {e}")
1035
- app.launch(debug=True)
 
573
  return
574
  cfg[key] = value # fallback
575
 
576
+ # --- CRITICAL FIX: force custom dataloaders & disable sync_bn -----------------
577
  def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
578
  epochs, batch, imgsz, lr, optimizer, pretrained_path: str | None):
579
  if not base_cfg_path or not os.path.exists(base_cfg_path):
 
586
  cfg = yaml.safe_load(f)
587
  _absify_any_paths_deep(cfg, template_dir)
588
 
589
+ # Safer on single GPU/CPU
590
+ cfg["sync_bn"] = False
591
+
592
  ann_dir = os.path.join(merged_dir, "annotations")
593
  paths = {
594
  "train_json": os.path.abspath(os.path.join(ann_dir, "instances_train.json")),
 
600
  "out_dir": os.path.abspath(os.path.join("runs", "train", run_name)),
601
  }
602
 
603
+ # Remove COCO dataset include so it can't override our dataloaders
604
+ inc_key = "__include__"
605
+ if inc_key in cfg and isinstance(cfg[inc_key], list):
606
+ cfg[inc_key] = [
607
+ p for p in cfg[inc_key]
608
+ if not (isinstance(p, str) and "configs/dataset/coco" in p.replace("\\", "/"))
609
+ ]
610
+
611
+ # Helper to ensure & patch dataloaders
612
+ def ensure_and_patch_dl(dl_key, img_key, json_key, default_shuffle):
613
+ block = cfg.get(dl_key)
614
+ if not isinstance(block, dict):
615
+ block = {
616
+ "type": "DataLoader",
617
+ "dataset": {
618
+ "type": "CocoDetection",
619
+ "img_folder": paths[img_key],
620
+ "ann_file": paths[json_key],
621
+ "return_masks": False,
622
+ "transforms": {
623
+ "type": "Compose",
624
+ "ops": [
625
+ {"type": "Resize", "size": [int(imgsz), int(imgsz)]},
626
+ {"type": "ConvertPILImage", "dtype": "float32", "scale": True},
627
+ ],
628
+ },
629
+ },
630
+ "shuffle": bool(default_shuffle),
631
+ "num_workers": 2,
632
+ "drop_last": bool(dl_key == "train_dataloader"),
633
+ "collate_fn": {"type": "BatchImageCollateFuncion"},
634
+ "total_batch_size": int(batch),
635
+ }
636
+ cfg[dl_key] = block
637
+ ds = block.get("dataset", {})
638
+ if isinstance(ds, dict):
639
+ ds["img_folder"] = paths[img_key]
640
+ ds["ann_file"] = paths[json_key]
641
+ for k in ("img_dir", "image_root", "data_root"):
642
+ if k in ds: ds[k] = paths[img_key]
643
+ for k in ("ann_path", "annotation", "annotations"):
644
+ if k in ds: ds[k] = paths[json_key]
645
+ block["dataset"] = ds
646
+ block["total_batch_size"] = int(batch)
647
+ block.setdefault("num_workers", 2)
648
+ block.setdefault("shuffle", bool(default_shuffle))
649
+ block.setdefault("drop_last", bool(dl_key == "train_dataloader"))
650
+
651
+ ensure_and_patch_dl("train_dataloader", "train_img", "train_json", default_shuffle=True)
652
+ ensure_and_patch_dl("val_dataloader", "val_img", "val_json", default_shuffle=False)
653
+ # Optional test loader if needed:
654
+ # ensure_and_patch_dl("test_dataloader", "test_img", "test_json", default_shuffle=False)
655
+
656
+ # Classes
657
  _set_num_classes_safely(cfg, int(class_count))
658
 
659
+ # Epochs / imgsz
660
  applied_epoch = False
661
  for key in ("epoches", "max_epoch", "epochs", "num_epochs"):
662
  if key in cfg:
663
+ cfg[key] = int(epochs); applied_epoch = True; break
 
 
664
  if "solver" in cfg and isinstance(cfg["solver"], dict):
665
  for key in ("epoches", "max_epoch", "epochs", "num_epochs"):
666
  if key in cfg["solver"]:
667
+ cfg["solver"][key] = int(epochs); applied_epoch = True; break
 
 
668
  if not applied_epoch:
669
+ cfg["epoches"] = int(epochs)
 
 
670
  cfg["input_size"] = int(imgsz)
671
 
672
+ # LR / optimizer / batch fallbacks
673
  if "solver" not in cfg or not isinstance(cfg["solver"], dict):
674
  cfg["solver"] = {}
675
  sol = cfg["solver"]
676
  for key in ("base_lr", "lr", "learning_rate"):
677
  if key in sol:
678
+ sol[key] = float(lr); break
 
679
  else:
680
  sol["base_lr"] = float(lr)
681
  sol["optimizer"] = str(optimizer).lower()
682
  if "train_dataloader" not in cfg or not isinstance(cfg["train_dataloader"], dict):
683
  sol["batch_size"] = int(batch)
684
 
685
+ # Output dir
686
  if "output_dir" in cfg:
687
  cfg["output_dir"] = paths["out_dir"]
688
  else:
689
  sol["output_dir"] = paths["out_dir"]
690
 
691
+ # Pretrained weights in correct block
692
  if pretrained_path:
693
  p = os.path.abspath(pretrained_path)
694
  _maybe_set_model_field(cfg, "pretrain", p)
695
  _maybe_set_model_field(cfg, "pretrained", p)
696
 
697
+ # Save near the template so any remaining relative includes still resolve
698
  cfg_out_dir = os.path.join(template_dir, "generated")
699
  os.makedirs(cfg_out_dir, exist_ok=True)
700
  out_path = os.path.join(cfg_out_dir, f"{run_name}.yaml")
701
 
702
  # Force block style for lists (no inline [a, b, c])
703
+ class _NoFlowDumper(yaml.SafeDumper): pass
 
704
  def _repr_list_block(dumper, data):
705
  return dumper.represent_sequence('tag:yaml.org,2002:seq', data, flow_style=False)
706
  _NoFlowDumper.add_representer(list, _repr_list_block)
 
998
  with gr.TabItem("2. Manage & Merge"):
999
  gr.Markdown("Rename/merge/remove classes and set per-class image caps. Then finalize.")
1000
  with gr.Row():
1001
+ class_df = gr