wuhp commited on
Commit
ff8714f
·
verified ·
1 Parent(s): ade6b4d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -78
app.py CHANGED
@@ -495,12 +495,12 @@ def _ensure_checkpoint(model_key: str, out_dir: str) -> str | None:
495
  pass
496
  return None
497
 
498
- # --- NEW: robust include absolutizer (no raw-text rewriting) ------------------
499
  def _absify_any_paths_deep(node, base_dir, include_keys=("base", "_base_", "BASE", "BASE_YAML",
500
  "includes", "include", "BASES", "__include__")):
501
  """
502
- Walk dict/list; if a string looks like a relative YAML include (../*.yml/.yaml) or
503
- appears under any of the known include keys, make it absolute against base_dir.
504
  """
505
  def _absify(s: str) -> str:
506
  if os.path.isabs(s):
@@ -510,7 +510,6 @@ def _absify_any_paths_deep(node, base_dir, include_keys=("base", "_base_", "BASE
510
  return s
511
 
512
  if isinstance(node, dict):
513
- # First, handle explicit include keys
514
  for k in list(node.keys()):
515
  v = node[k]
516
  if k in include_keys:
@@ -518,13 +517,11 @@ def _absify_any_paths_deep(node, base_dir, include_keys=("base", "_base_", "BASE
518
  node[k] = _absify(v)
519
  elif isinstance(v, list):
520
  node[k] = [_absify(x) if isinstance(x, str) else x for x in v]
521
- # Recurse and also absify stray string values that look like includes
522
  for k, v in list(node.items()):
523
  if isinstance(v, (dict, list)):
524
  _absify_any_paths_deep(v, base_dir, include_keys)
525
  elif isinstance(v, str):
526
  node[k] = _absify(v)
527
-
528
  elif isinstance(node, list):
529
  for i, v in enumerate(list(node)):
530
  if isinstance(v, (dict, list)):
@@ -532,6 +529,50 @@ def _absify_any_paths_deep(node, base_dir, include_keys=("base", "_base_", "BASE
532
  elif isinstance(v, str):
533
  node[i] = _absify(v)
534
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
535
  def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
536
  epochs, batch, imgsz, lr, optimizer, pretrained_path: str | None):
537
  if not base_cfg_path or not os.path.exists(base_cfg_path):
@@ -539,7 +580,7 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
539
 
540
  template_dir = os.path.dirname(base_cfg_path)
541
 
542
- # Load YAML directly (no raw-text editing), then absolutize known include keys and any '../*.yml'
543
  with open(base_cfg_path, "r", encoding="utf-8") as f:
544
  cfg = yaml.safe_load(f)
545
  _absify_any_paths_deep(cfg, template_dir)
@@ -555,109 +596,81 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
555
  "out_dir": os.path.abspath(os.path.join("runs", "train", run_name)),
556
  }
557
 
558
- # dataset block
559
- for root_key in ["dataset", "data"]:
560
- if root_key in cfg and isinstance(cfg[root_key], dict):
561
- ds = cfg[root_key]
562
- for split, jf, ip in [
563
- ("train", "train_json", "train_img"),
564
- ("val", "val_json", "val_img"),
565
- ("test", "test_json", "test_img"),
566
- ]:
567
- if split in ds and isinstance(ds[split], dict):
568
- node = ds[split]
569
- node["name"] = node.get("name", "coco")
570
- _set_first_existing_key(
571
- node,
572
- keys=["ann_file", "ann_path", "annotation", "annotations"],
573
- value=paths[jf],
574
- fallback_key="ann_file",
575
- )
576
- _set_first_existing_key(
577
- node,
578
- keys=["img_prefix", "img_dir", "image_root", "data_root"],
579
- value=paths[ip],
580
- fallback_key="img_prefix",
581
- )
582
-
583
- # num_classes
584
- def set_num_classes(node, n):
585
- if not isinstance(node, dict):
586
- return False
587
- if "num_classes" in node:
588
- node["num_classes"] = int(n)
589
- return True
590
- for k, v in node.items():
591
- if isinstance(v, dict) and set_num_classes(v, n):
592
- return True
593
- return False
594
-
595
- if "model" in cfg and isinstance(cfg["model"], dict):
596
- if not set_num_classes(cfg["model"], class_count):
597
- cfg["model"]["num_classes"] = int(class_count)
598
- else:
599
- cfg["model"] = {"num_classes": int(class_count)}
600
-
601
- # epochs / imgsz
602
- updated_epoch = False
603
- for key in ["max_epoch", "epochs", "num_epochs"]:
604
  if key in cfg:
605
  cfg[key] = int(epochs)
606
- updated_epoch = True
607
  break
608
  if "solver" in cfg and isinstance(cfg["solver"], dict):
609
- for key in ["max_epoch", "epochs", "num_epochs"]:
610
  if key in cfg["solver"]:
611
  cfg["solver"][key] = int(epochs)
612
- updated_epoch = True
613
  break
614
- if not updated_epoch:
615
- cfg["max_epoch"] = int(epochs)
616
 
617
- for key in ["input_size", "img_size", "imgsz"]:
618
- if key in cfg:
619
- cfg[key] = int(imgsz)
620
- if "input_size" not in cfg:
621
- cfg["input_size"] = int(imgsz)
622
 
623
- # lr / optimizer / batch
624
  if "solver" not in cfg or not isinstance(cfg["solver"], dict):
625
  cfg["solver"] = {}
626
  sol = cfg["solver"]
627
- for key in ["base_lr", "lr", "learning_rate"]:
628
  if key in sol:
629
  sol[key] = float(lr)
630
  break
631
  else:
632
  sol["base_lr"] = float(lr)
633
  sol["optimizer"] = str(optimizer).lower()
634
- if "train_dataloader" in cfg and isinstance(cfg["train_dataloader"], dict):
635
- cfg["train_dataloader"]["batch_size"] = int(batch)
636
- else:
637
  sol["batch_size"] = int(batch)
638
 
639
  # output dir
640
  if "output_dir" in cfg:
641
  cfg["output_dir"] = paths["out_dir"]
642
- elif "solver" in cfg:
643
- sol["output_dir"] = paths["out_dir"]
644
  else:
645
- cfg["output_dir"] = paths["out_dir"]
646
 
647
- # Set pretrained weights if available; try common keys at top/model/solver
648
  if pretrained_path:
649
- _set_first_existing_key_deep(
650
- cfg,
651
- keys=["pretrain", "pretrained", "weight", "weights", "pretrained_path"],
652
- value=os.path.abspath(pretrained_path),
653
- )
654
 
655
  # Save near the template so internal relative references still make sense
656
  cfg_out_dir = os.path.join(template_dir, "generated")
657
  os.makedirs(cfg_out_dir, exist_ok=True)
658
  out_path = os.path.join(cfg_out_dir, f"{run_name}.yaml")
659
 
660
- # Force block style so lists don’t emit as inline [a, b, c] flow sequences
661
  class _NoFlowDumper(yaml.SafeDumper):
662
  pass
663
  def _repr_list_block(dumper, data):
 
495
  pass
496
  return None
497
 
498
+ # --- include absolutizer ------------------------------------------------------
499
  def _absify_any_paths_deep(node, base_dir, include_keys=("base", "_base_", "BASE", "BASE_YAML",
500
  "includes", "include", "BASES", "__include__")):
501
  """
502
+ Walk dict/list; for known include keys or strings that look like ../*.yml/.yaml,
503
+ make them absolute against base_dir.
504
  """
505
  def _absify(s: str) -> str:
506
  if os.path.isabs(s):
 
510
  return s
511
 
512
  if isinstance(node, dict):
 
513
  for k in list(node.keys()):
514
  v = node[k]
515
  if k in include_keys:
 
517
  node[k] = _absify(v)
518
  elif isinstance(v, list):
519
  node[k] = [_absify(x) if isinstance(x, str) else x for x in v]
 
520
  for k, v in list(node.items()):
521
  if isinstance(v, (dict, list)):
522
  _absify_any_paths_deep(v, base_dir, include_keys)
523
  elif isinstance(v, str):
524
  node[k] = _absify(v)
 
525
  elif isinstance(node, list):
526
  for i, v in enumerate(list(node)):
527
  if isinstance(v, (dict, list)):
 
529
  elif isinstance(v, str):
530
  node[i] = _absify(v)
531
 
532
+ # --- NEW: safe model field setters --------------------------------------------
533
+ def _set_num_classes_safely(cfg: dict, n: int):
534
+ """
535
+ Set class count without breaking templates that use `model: "RTDETR"` indirection.
536
+ """
537
+ def set_num_classes(node):
538
+ if not isinstance(node, dict):
539
+ return False
540
+ if "num_classes" in node:
541
+ node["num_classes"] = int(n)
542
+ return True
543
+ for k, v in node.items():
544
+ if isinstance(v, dict) and set_num_classes(v):
545
+ return True
546
+ return False
547
+
548
+ m = cfg.get("model", None)
549
+ if isinstance(m, dict):
550
+ if not set_num_classes(m):
551
+ m["num_classes"] = int(n)
552
+ return
553
+
554
+ if isinstance(m, str):
555
+ block = cfg.get(m, None)
556
+ if isinstance(block, dict):
557
+ if not set_num_classes(block):
558
+ block["num_classes"] = int(n)
559
+ return
560
+
561
+ cfg["num_classes"] = int(n) # last resort
562
+
563
+ def _maybe_set_model_field(cfg: dict, key: str, value):
564
+ """
565
+ Place fields like 'pretrain' under the proper model dict, respecting string indirection.
566
+ """
567
+ m = cfg.get("model", None)
568
+ if isinstance(m, dict):
569
+ m[key] = value
570
+ return
571
+ if isinstance(m, str) and isinstance(cfg.get(m), dict):
572
+ cfg[m][key] = value
573
+ return
574
+ cfg[key] = value # fallback
575
+
576
  def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
577
  epochs, batch, imgsz, lr, optimizer, pretrained_path: str | None):
578
  if not base_cfg_path or not os.path.exists(base_cfg_path):
 
580
 
581
  template_dir = os.path.dirname(base_cfg_path)
582
 
583
+ # Load YAML then absolutize include-like paths
584
  with open(base_cfg_path, "r", encoding="utf-8") as f:
585
  cfg = yaml.safe_load(f)
586
  _absify_any_paths_deep(cfg, template_dir)
 
596
  "out_dir": os.path.abspath(os.path.join("runs", "train", run_name)),
597
  }
598
 
599
+ # --- Rewrite dataloaders to use your dataset ---
600
+ def _patch_dl(dl_key, img_key, json_key):
601
+ if dl_key in cfg and isinstance(cfg[dl_key], dict):
602
+ ds = cfg[dl_key].get("dataset", {})
603
+ if isinstance(ds, dict):
604
+ if "img_folder" in ds: ds["img_folder"] = paths[img_key]
605
+ if "ann_file" in ds: ds["ann_file"] = paths[json_key]
606
+ # alternative key names occasionally used
607
+ for k in ("img_dir", "image_root", "data_root"):
608
+ if k in ds: ds[k] = paths[img_key]
609
+ for k in ("ann_path", "annotation", "annotations"):
610
+ if k in ds: ds[k] = paths[json_key]
611
+ cfg[dl_key]["dataset"] = ds
612
+ # batch size here if present
613
+ if "batch_size" in cfg[dl_key]:
614
+ cfg[dl_key]["batch_size"] = int(batch)
615
+
616
+ _patch_dl("train_dataloader", "train_img", "train_json")
617
+ _patch_dl("val_dataloader", "val_img", "val_json")
618
+ _patch_dl("test_dataloader", "test_img", "test_json")
619
+
620
+ # --- classes ---
621
+ _set_num_classes_safely(cfg, int(class_count))
622
+
623
+ # --- epochs / imgsz ---
624
+ applied_epoch = False
625
+ for key in ("epoches", "max_epoch", "epochs", "num_epochs"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
626
  if key in cfg:
627
  cfg[key] = int(epochs)
628
+ applied_epoch = True
629
  break
630
  if "solver" in cfg and isinstance(cfg["solver"], dict):
631
+ for key in ("epoches", "max_epoch", "epochs", "num_epochs"):
632
  if key in cfg["solver"]:
633
  cfg["solver"][key] = int(epochs)
634
+ applied_epoch = True
635
  break
636
+ if not applied_epoch:
637
+ cfg["epoches"] = int(epochs) # common in this repo
638
 
639
+ # image size knobs: unify on top-level input_size (respected by templates)
640
+ cfg["input_size"] = int(imgsz)
 
 
 
641
 
642
+ # --- lr / optimizer / batch fallbacks ---
643
  if "solver" not in cfg or not isinstance(cfg["solver"], dict):
644
  cfg["solver"] = {}
645
  sol = cfg["solver"]
646
+ for key in ("base_lr", "lr", "learning_rate"):
647
  if key in sol:
648
  sol[key] = float(lr)
649
  break
650
  else:
651
  sol["base_lr"] = float(lr)
652
  sol["optimizer"] = str(optimizer).lower()
653
+ if "train_dataloader" not in cfg or not isinstance(cfg["train_dataloader"], dict):
 
 
654
  sol["batch_size"] = int(batch)
655
 
656
  # output dir
657
  if "output_dir" in cfg:
658
  cfg["output_dir"] = paths["out_dir"]
 
 
659
  else:
660
+ sol["output_dir"] = paths["out_dir"]
661
 
662
+ # pretrained weights in the right model block
663
  if pretrained_path:
664
+ p = os.path.abspath(pretrained_path)
665
+ _maybe_set_model_field(cfg, "pretrain", p)
666
+ _maybe_set_model_field(cfg, "pretrained", p)
 
 
667
 
668
  # Save near the template so internal relative references still make sense
669
  cfg_out_dir = os.path.join(template_dir, "generated")
670
  os.makedirs(cfg_out_dir, exist_ok=True)
671
  out_path = os.path.join(cfg_out_dir, f"{run_name}.yaml")
672
 
673
+ # Force block style for lists (no inline [a, b, c])
674
  class _NoFlowDumper(yaml.SafeDumper):
675
  pass
676
  def _repr_list_block(dumper, data):