wuhp committed
Commit ae4cf01 · verified · 1 parent: f735495

Update app.py

Files changed (1)
  1. app.py +111 -136
app.py CHANGED
@@ -12,6 +12,8 @@ import requests
 import json
 from PIL import Image
 import pandas as pd
+import matplotlib
+matplotlib.use("Agg")  # headless (HF Spaces)
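+# Note: choose the backend before importing pyplot so Agg is in effect from the start.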
 import matplotlib.pyplot as plt
 from threading import Thread
 from queue import Queue
@@ -22,27 +24,21 @@ import sys
 import time
 import glob
 
-# --- Configuration ---
+# --- Logging ---
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
-# Defaults for RT-DETRv2 (Supervisely ecosystem) integration
+# --- RT-DETRv2 backend defaults (Supervisely ecosystem) ---
 RTDETRV2_REPO_URL = "https://github.com/supervisely-ecosystem/RT-DETRv2"
 DEFAULT_REPO_DIR = os.path.join("third_party", "rtdetrv2")
 
-# You can still offer "model size" choices to hint the user which config to use,
-# but the actual command is controlled by the template.
 RTDETRV2_MODELS = [
-    "rtdetrv2-l-640",  # label only; adapt your command template to use real config/weights
+    "rtdetrv2-l-640",  # labels only; match your config via the command template
     "rtdetrv2-x-640"
 ]
 DEFAULT_MODEL = RTDETRV2_MODELS[0]
 
-# ------------------------------
-# Utilities
-# ------------------------------
-
+# --- Utilities ---
 def handle_remove_readonly(func, path, exc_info):
-    """Error handler for shutil.rmtree."""
     try:
         os.chmod(path, stat.S_IWRITE)
     except Exception:
@@ -67,14 +63,6 @@ _ROBO_URL_RX = re.compile(
 )
 
 def parse_roboflow_url(s: str):
-    """
-    Support:
-    - https://universe.roboflow.com/<workspace>/<project>[/vN]
-    - https://app.roboflow.com/<workspace>/<project>[/vN]
-    - https://roboflow.com/<workspace>/<project>[/vN]
-    - raw: <workspace>/<project>[/vN]
-    Returns: (workspace, project, version_or_None)
-    """
     s = s.strip()
     m = _ROBO_URL_RX.match(s)
     if m:
@@ -110,7 +98,6 @@ def parse_roboflow_url(s: str):
     return None, None, None
 
 def get_latest_version(api_key, workspace, project):
-    """Gets the latest version number of a Roboflow project."""
     try:
         rf = Roboflow(api_key=api_key)
         proj = rf.workspace(workspace).project(project)
@@ -120,39 +107,24 @@ def get_latest_version(api_key, workspace, project):
         logging.error(f"Could not get latest version for {workspace}/{project}: {e}")
         return None
 
-# --- Normalize class names from data.yaml ---
 def _extract_class_names(data_yaml):
-    """
-    Return list[str] of class names in index order.
-    Supports:
-    - list
-    - dict with numeric keys {0:'cat',1:'dog'}
-    - fallback to ['class_0', ...]
-    """
     names = data_yaml.get('names', None)
-
     if isinstance(names, dict):
         def _k(x):
-            try:
-                return int(x)
-            except Exception:
-                return str(x)
+            try: return int(x)
+            except Exception: return str(x)
         ordered = sorted(names.keys(), key=_k)
         names_list = [names[k] for k in ordered]
     elif isinstance(names, list):
         names_list = names
     else:
         nc = data_yaml.get('nc', 0)
-        try:
-            nc = int(nc)
-        except Exception:
-            nc = 0
+        try: nc = int(nc)
+        except Exception: nc = 0
         names_list = [f"class_{i}" for i in range(nc)]
-
     return [str(x) for x in names_list]
 
 def download_dataset(api_key, workspace, project, version):
-    """Download Roboflow dataset in 'yolov8' layout (works fine for RT-DETR variants)."""
     try:
         rf = Roboflow(api_key=api_key)
         proj = rf.workspace(workspace).project(project)
@@ -180,25 +152,17 @@ def download_dataset(api_key, workspace, project, version):
         return None, [], [], None
 
 def label_path_for(img_path: str) -> str:
-    """Convert .../split/images/file.jpg -> .../split/labels/file.txt."""
     split_dir = os.path.dirname(os.path.dirname(img_path))  # .../split
     base = os.path.splitext(os.path.basename(img_path))[0] + '.txt'
     return os.path.join(split_dir, 'labels', base)
 
 def gather_class_counts(dataset_info, class_mapping):
-    """
-    Count per final class how many images contain that class at least once (counted once per image).
-    class_mapping: original_name -> final_name (or None if removed).
-    """
     if not dataset_info:
         return {}
-
     final_names = set(v for v in class_mapping.values() if v is not None)
     counts = {name: 0 for name in final_names}
-
     for loc, names, splits, _ in dataset_info:
         id_to_name = {idx: class_mapping.get(n, None) for idx, n in enumerate(names)}
-
         for split in splits:
             labels_dir = os.path.join(loc, split, 'labels')
             if not os.path.exists(labels_dir):
@@ -221,11 +185,9 @@ def gather_class_counts(dataset_info, class_mapping):
                 continue
             for m in found:
                 counts[m] += 1
-
     return counts
 
 def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=gr.Progress()):
-    """Merge datasets following mapping and per-class image limits."""
     merged_dir = 'rolo_merged_dataset'
     if os.path.exists(merged_dir):
         shutil.rmtree(merged_dir, onerror=handle_remove_readonly)
@@ -238,7 +200,6 @@ def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=
     active_classes = sorted(set([cls for cls, limit in class_limits.items() if limit > 0]))
     final_class_map = {name: i for i, name in enumerate(active_classes)}
 
-    # Collect candidates
     all_images = []
     for loc, _, splits, _ in dataset_info:
         for split in splits:
@@ -327,18 +288,81 @@ def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=
 
     return f"Dataset finalized with {len(selected_images)} images.", os.path.abspath(merged_dir)
 
-# ------------------------------
-# RT-DETRv2 backend helpers
-# ------------------------------
+# --- Repo + deps helpers (auto-install for HF Spaces) ---
+
+def run_pip_install(args, desc="pip install"):
+    logging.info(f"{desc}: {args}")
+    cmd = [sys.executable, "-m", "pip", "install"] + args
+    proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
+    logging.info(proc.stdout)
+    if proc.returncode != 0:
+        raise RuntimeError(f"{desc} failed with code {proc.returncode}")
 
 def ensure_repo(repo_dir: str, repo_url: str = RTDETRV2_REPO_URL):
-    """Clone the repo into repo_dir if not present."""
     if os.path.isdir(repo_dir) and os.path.isdir(os.path.join(repo_dir, ".git")):
         return
     os.makedirs(os.path.dirname(repo_dir), exist_ok=True)
     logging.info(f"Cloning RT-DETRv2 repo into {repo_dir} ...")
-    cmd = ["git", "clone", "--depth", "1", repo_url, repo_dir]
-    subprocess.run(cmd, check=True)
+    subprocess.run(["git", "clone", "--depth", "1", repo_url, repo_dir], check=True)
+
+def ensure_python_deps(repo_dir: str):
+    """
+    Auto-install dependencies (idempotent).
+    - Tries to install pinned basics that are often needed.
+    - If repo has requirements*.txt, install them.
+    - Creates a .deps_installed marker to skip on next run.
+    """
+    marker = os.path.join(repo_dir, ".deps_installed")
+    if os.path.exists(marker):
+        logging.info("Dependencies already installed; skipping.")
+        return
+
+    # 1) Common essentials for vision training environments on HF Spaces
+    basics = [
+        "numpy<2",  # safer with many libs
+        "pillow",
+        "tqdm",
+        "pyyaml",
+        "matplotlib",
+        "pandas",
+        "scipy",
+        "opencv-python-headless",
+        "packaging",
+        "requests",
+        "pycocotools-windows; platform_system=='Windows'",
+        "pycocotools; platform_system!='Windows'",
+    ]
+    try:
+        run_pip_install(basics, desc="Installing common basics")
+    except Exception as e:
+        logging.warning(f"Basic installs had issues: {e}")
+
+    # 2) Repo requirements
+    req_files = []
+    for name in ["requirements.txt", "requirements-dev.txt", "requirements.in"]:
+        p = os.path.join(repo_dir, name)
+        if os.path.isfile(p):
+            req_files.append(p)
+
+    for rf in req_files:
+        try:
+            run_pip_install(["-r", rf], desc=f"Installing repo requirements from {rf}")
+        except Exception as e:
+            logging.warning(f"Installing {rf} failed: {e}")
+
+    # 3) Optional: torch if not present (CPU-only by default on Spaces)
+    try:
+        import torch  # noqa: F401
+    except Exception:
+        # Try a CPU-friendly torch; change version/cuda wheels if needed
+        try:
+            run_pip_install(["torch", "torchvision", "torchaudio", "--index-url", "https://download.pytorch.org/whl/cpu"], desc="Installing PyTorch (CPU)")
+        except Exception as e:
+            logging.warning(f"PyTorch installation failed/skipped: {e}")
+
+    # Mark done
+    with open(marker, "w") as f:
+        f.write("ok\n")
 
 def make_train_command(template: str, data_yaml: str, epochs: int, batch: int, imgsz: int,
                        lr: float, optimizer: str, run_name: str, output_dir: str) -> str:
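+    # Hypothetical example template, assuming the template string is .format()-ed with
+    # the keyword arguments above (adapt to the repo's real entry point and flags):
+    #   python train.py --data {data_yaml} --epochs {epochs} --batch {batch} --imgsz {imgsz} \
+    #       --lr {lr} --optimizer {optimizer} --name {run_name} --output {output_dir}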
@@ -354,7 +378,6 @@ def make_train_command(template: str, data_yaml: str, epochs: int, batch: int, i
     )
 
 _METRIC_PATTERNS = [
-    # add more patterns if your repo prints differently
     (re.compile(r"mAP@0\.5[:/]?0\.95[^0-9]*([0-9]*\.?[0-9]+)"), "mAP50_95"),
    (re.compile(r"mAP50[^0-9]*([0-9]*\.?[0-9]+)"), "mAP50"),
    (re.compile(r"\bval[_/ ]?loss[^0-9\-]*([0-9]*\.?[0-9]+)"), "val_loss"),
@@ -375,10 +398,6 @@ def parse_metrics_from_line(line: str):
     return result
 
 def guess_final_weights(output_dir: str):
-    """
-    Try to locate a 'best' checkpoint in output_dir.
-    Supports .pt/.pth/.pdparams etc. Return first match or None.
-    """
     patterns = [
         os.path.join(output_dir, "**", "best.*"),
         os.path.join(output_dir, "**", "best_model.*"),
@@ -390,12 +409,8 @@ def guess_final_weights(output_dir: str):
         return hits[0]
     return None
 
-# ------------------------------
-# Gradio UI Event Handlers
-# ------------------------------
-
+# --- Gradio handlers ---
 def load_datasets_handler(api_key, url_file, progress=gr.Progress()):
-    """Handles the 'Load Datasets' button click."""
     api_key = api_key or os.getenv("ROBOFLOW_API_KEY", "")
     if not api_key:
         raise gr.Error("Roboflow API Key is required (or set ROBOFLOW_API_KEY).")
@@ -407,14 +422,12 @@ def load_datasets_handler(api_key, url_file, progress=gr.Progress()):
 
     dataset_info = []
     failures = []
-
     for i, raw in enumerate(urls):
         progress((i + 1) / max(1, len(urls)), desc=f"Parsing {i+1}/{len(urls)}")
         ws, proj, ver = parse_roboflow_url(raw)
         if not (ws and proj):
             failures.append((raw, "ParseError: could not resolve workspace/project"))
             continue
-
         if ver is None:
             ver = get_latest_version(api_key, ws, proj)
             if ver is None:
@@ -431,26 +444,21 @@ def load_datasets_handler(api_key, url_file, progress=gr.Progress()):
         msg = "No datasets were loaded successfully.\n" + "\n".join([f"- {u}: {why}" for u, why in failures[:10]])
         raise gr.Error(msg)
 
-    # ensure names are strings before sorting
     all_names = sorted({str(n) for _, names, _, _ in dataset_info for n in names})
     class_map = {name: name for name in all_names}
-
     initial_counts = gather_class_counts(dataset_info, class_map)
     df_data = [[name, name, initial_counts.get(name, 0), False] for name in all_names]
     status_text = "Datasets loaded successfully."
     if failures:
         status_text += f" ({len(dataset_info)} OK, {len(failures)} failed; see console logs)."
 
-    # FIX: gr.update(...) (not gr.DataFrame.update)
     return status_text, dataset_info, gr.update(
         value=pd.DataFrame(df_data, columns=["Original Name", "Rename To", "Max Images", "Remove"])
     )
 
 def update_class_counts_handler(class_df, dataset_info):
-    """Live preview of merged class counts given the current mapping/removals."""
     if class_df is None or not dataset_info:
         return None
-
     class_df = pd.DataFrame(class_df)
     mapping = {}
     for _, row in class_df.iterrows():
@@ -462,10 +470,8 @@ def update_class_counts_handler(class_df, dataset_info):
 
     final_names = sorted(set(v for v in mapping.values() if v))
     counts = {k: 0 for k in final_names}
-
     for loc, names, splits, _ in dataset_info:
         id_to_final = {idx: mapping.get(n, None) for idx, n in enumerate(names)}
-
         for split in splits:
             labels_dir = os.path.join(loc, split, 'labels')
             if not os.path.exists(labels_dir):
@@ -493,7 +499,6 @@ def update_class_counts_handler(class_df, dataset_info):
     return summary_df
 
 def finalize_handler(dataset_info, class_df, progress=gr.Progress()):
-    """Create the merged dataset directory with relabeled .txts and data.yaml."""
     if not dataset_info:
         raise gr.Error("Load datasets first in Tab 1.")
     if class_df is None:
@@ -515,20 +520,19 @@ def finalize_handler(dataset_info, class_df, progress=gr.Progress()):
 
 def training_handler_rtdetrv2(dataset_path, repo_dir, model_choice, run_name, epochs, batch, imgsz, lr, opt,
                               cmd_template, progress=gr.Progress()):
-    """
-    Train using RT-DETRv2 repo via a configurable command template.
-    We stream logs, parse simple metrics when patterns match, and try to locate a best checkpoint on completion.
-    """
     if not dataset_path:
         raise gr.Error("Finalize a dataset in Tab 2 before training.")
 
-    # Make sure repo exists
+    # Clone + deps (idempotent)
     try:
         ensure_repo(repo_dir)
+        ensure_python_deps(repo_dir)
     except subprocess.CalledProcessError as e:
-        raise gr.Error(f"Failed to clone RT-DETRv2 repo: {e}")
+        raise gr.Error(f"Failed to clone repo: {e}")
+    except Exception as e:
+        raise gr.Error(f"Dependency setup failed: {e}")
 
-    # Prepare output directory
+    # Output dir
     output_dir = os.path.join("runs", "train", str(run_name))
     os.makedirs(output_dir, exist_ok=True)
 
@@ -536,7 +540,7 @@ def training_handler_rtdetrv2(dataset_path, repo_dir, model_choice, run_name, ep
     if not os.path.isfile(data_yaml):
         raise gr.Error(f"'data.yaml' was not found in: {dataset_path}")
 
-    # Build the command
+    # Build command from template
     cmd = make_train_command(
         template=cmd_template,
         data_yaml=data_yaml,
@@ -549,54 +553,36 @@ def training_handler_rtdetrv2(dataset_path, repo_dir, model_choice, run_name, ep
         output_dir=output_dir
     )
 
-    # Launch training subprocess in repo_dir
     logging.info(f"Running training command in {repo_dir}: {cmd}")
     proc = subprocess.Popen(
-        cmd,
-        cwd=repo_dir,
-        shell=True,
-        stdout=subprocess.PIPE,
-        stderr=subprocess.STDOUT,
-        bufsize=1,
-        universal_newlines=True,
-        env={**os.environ}  # inherit env (CUDA, etc.)
+        cmd, cwd=repo_dir, shell=True,
+        stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
+        bufsize=1, universal_newlines=True, env={**os.environ}
     )
 
-    # Live metrics
     history = {k: [] for k in ['epoch', 'train_loss', 'val_loss', 'mAP50', 'mAP50_95']}
-    last_epoch = 0
-
-    # Stream logs and parse
     for line in iter(proc.stdout.readline, ''):
         line = line.rstrip()
-        # Update progress indeterminately (we don't know total epochs from logs generically)
-        if "epoch" in line.lower():
-            progress(0.0, desc=line[-120:])  # show last part of the line
-        else:
-            progress(0.0, desc=line[-120:])
-
+        progress(0.0, desc=line[-120:])
         metrics = parse_metrics_from_line(line)
         if metrics:
             for k, v in metrics.items():
                 history[k].append(v)
-        # Plot when we detect an epoch number or mAP/loss update
-        # Plot Loss
+
+        # plot loss
         fig_loss = plt.figure()
         ax_loss = fig_loss.add_subplot(111)
         ax_loss.plot(history['epoch'], history['train_loss'], "o-", label='Train Loss')
         ax_loss.plot(history['epoch'], history['val_loss'], "o-", label='Val Loss')
-        ax_loss.legend()
-        ax_loss.set_title("Loss")
+        ax_loss.legend(); ax_loss.set_title("Loss")
 
-        # Plot mAP
+        # plot mAP
         fig_map = plt.figure()
         ax_map = fig_map.add_subplot(111)
         ax_map.plot(history['epoch'], history['mAP50'], "o-", label='mAP@0.5')
         ax_map.plot(history['epoch'], history['mAP50_95'], "o-", label='mAP@0.5:0.95')
-        ax_map.legend()
-        ax_map.set_title("mAP")
+        ax_map.legend(); ax_map.set_title("mAP")
 
-        # Emit an update to the UI (status text is the last log line)
         yield line[-200:], fig_loss, fig_map, None
 
     proc.stdout.close()
@@ -604,16 +590,14 @@ def training_handler_rtdetrv2(dataset_path, repo_dir, model_choice, run_name, ep
     if ret != 0:
         raise gr.Error(f"Training process exited with code {ret}. Check console/logs for details.")
 
-    # Try to locate a best checkpoint
     final_ckpt = guess_final_weights(output_dir)
     if final_ckpt and os.path.isfile(final_ckpt):
         yield "Training complete!", None, None, gr.update(value=final_ckpt, visible=True)
     else:
-        # Still complete, but we couldn't find a checkpoint automatically
-        yield "Training finished. Could not auto-detect 'best' checkpoint; please check the output directory.", None, None, gr.update(visible=False)
+        yield ("Training finished. Could not auto-detect a 'best' checkpoint; "
+               "please check the output directory."), None, None, gr.update(visible=False)
 
 def upload_handler(model_file, hf_token, hf_repo, gh_token, gh_repo, progress=gr.Progress()):
-    """Handles model upload to Hugging Face and GitHub."""
     if not model_file:
         raise gr.Error("No trained model file available to upload. Train a model first.")
 
@@ -640,7 +624,6 @@ def upload_handler(model_file, hf_token, hf_repo, gh_token, gh_repo, progress=gr
     try:
         if '/' not in gh_repo:
             raise ValueError("GitHub repo must be in the form 'username/repo'.")
-
         username, repo_name = gh_repo.split('/')
         api_url = f"https://api.github.com/repos/{username}/{repo_name}/contents/{os.path.basename(model_file.name)}"
         headers = {"Authorization": f"token {gh_token}"}
@@ -652,11 +635,9 @@ def upload_handler(model_file, hf_token, hf_repo, gh_token, gh_repo, progress=gr
         sha = get_resp.json().get('sha') if get_resp.ok else None
 
         data = {"message": "Upload trained model from Rolo app", "content": content}
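+        # GitHub's contents API requires the existing file's sha to overwrite a path.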
-        if sha:
-            data["sha"] = sha
+        if sha: data["sha"] = sha
 
         put_resp = requests.put(api_url, headers=headers, json=data, timeout=60)
-
         if put_resp.ok:
             gh_status = f"Success! Model at: {put_resp.json()['content']['html_url']}"
         else:
@@ -668,27 +649,24 @@ def upload_handler(model_file, hf_token, hf_repo, gh_token, gh_repo, progress=gr
     progress(1)
     return hf_status, gh_status
 
-# ------------------------------
-# Gradio UI
-# ------------------------------
+# --- Gradio UI ---
 with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
-    gr.Markdown("# Rolo: RT-DETRv2 Training Dashboard (Supervisely Ecosystem Backend)")
+    gr.Markdown("# Rolo: RT-DETRv2 Training Dashboard (Auto-setup for Hugging Face)")
 
-    # State variables
     dataset_info_state = gr.State([])
     final_dataset_path_state = gr.State(None)
 
     with gr.Tabs():
         with gr.TabItem("1. Prepare Datasets"):
-            gr.Markdown("### Load Roboflow Datasets\nProvide your Roboflow API key and upload a `.txt` file containing one Roboflow dataset URL or `workspace/project[/vN]` per line.")
+            gr.Markdown("Upload a `.txt` with Roboflow URLs or `workspace/project[/vN]` lines.")
             with gr.Row():
-                rf_api_key = gr.Textbox(label="Roboflow API Key (or set ROBOFLOW_API_KEY env)", type="password", scale=2)
+                rf_api_key = gr.Textbox(label="Roboflow API Key (or set ROBOFLOW_API_KEY)", type="password", scale=2)
                 rf_url_file = gr.File(label="Upload Roboflow URLs (.txt)", file_types=[".txt"], scale=1)
             load_btn = gr.Button("Load Datasets", variant="primary")
             dataset_status = gr.Textbox(label="Status", interactive=False)
 
         with gr.TabItem("2. Manage & Merge"):
-            gr.Markdown("### Configure Classes and Finalize Dataset\nRename classes to merge them, set image limits, or remove them. Click **Update Counts** to preview, then **Finalize** to create the dataset.")
+            gr.Markdown("Rename classes, set image limits, or remove them. Preview, then finalize.")
             with gr.Row():
                 class_df = gr.DataFrame(
                     headers=["Original Name", "Rename To", "Max Images", "Remove"],
@@ -702,18 +680,16 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
                     interactive=False
                 )
                 update_counts_btn = gr.Button("Update Counts")
-
             finalize_btn = gr.Button("Finalize Merged Dataset", variant="primary")
             finalize_status = gr.Textbox(label="Status", interactive=False)
 
         with gr.TabItem("3. Configure & Train"):
-            gr.Markdown("### Set Hyperparameters and Train with RT-DETRv2")
+            gr.Markdown("Set hyperparameters and the training command template.")
             with gr.Row():
                 with gr.Column(scale=1):
                     model_choice_dd = gr.Dropdown(
-                        label="Model Choice (label only adjust your command template to use the right config)",
-                        choices=RTDETRV2_MODELS,
-                        value=DEFAULT_MODEL
+                        label="Model Choice (label only; use your config in the template)",
+                        choices=RTDETRV2_MODELS, value=DEFAULT_MODEL
                     )
                     run_name_tb = gr.Textbox(label="Run Name", value="rtdetrv2_run_1")
                     epochs_sl = gr.Slider(1, 500, 100, step=1, label="Epochs")
@@ -721,7 +697,6 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
                     imgsz_num = gr.Number(label="Image Size", value=640)
                     lr_num = gr.Number(label="Learning Rate", value=0.001)
                     opt_dd = gr.Dropdown(["Adam", "AdamW", "SGD"], value="AdamW", label="Optimizer")
-
                     repo_dir_tb = gr.Textbox(label="RT-DETRv2 repo directory", value=DEFAULT_REPO_DIR)
                     cmd_template_tb = gr.Textbox(
                         label="Train command template",
@@ -745,7 +720,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
             final_model_file = gr.File(label="Download Trained Model (best.*)", interactive=False, visible=False)
 
         with gr.TabItem("4. Upload Model"):
-            gr.Markdown("### Upload Your Trained Model\nAfter training, you can upload the best checkpoint to Hugging Face and/or GitHub.")
+            gr.Markdown("Upload your best checkpoint to Hugging Face or GitHub.")
             with gr.Row():
                 with gr.Column():
                     gr.Markdown("#### Hugging Face")
@@ -760,7 +735,6 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
             hf_status = gr.Textbox(label="Hugging Face Status", interactive=False)
             gh_status = gr.Textbox(label="GitHub Status", interactive=False)
 
-    # Wire UI handlers
     load_btn.click(
         fn=load_datasets_handler,
         inputs=[rf_api_key, rf_url_file],
@@ -780,7 +754,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
         fn=training_handler_rtdetrv2,
         inputs=[
             final_dataset_path_state,   # dataset_path
-            repo_dir_tb,                # repo_dir
+            repo_dir_tb,                # repo_dir (auto clone + pip install)
             model_choice_dd,            # model_choice (label only)
             run_name_tb,
             epochs_sl,
@@ -799,5 +773,6 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
     )
 
 if __name__ == "__main__":
-    # If Ultralytics warnings annoy you, set: export YOLO_CONFIG_DIR=/tmp/Ultralytics
+    # Hugging Face Spaces: set server name/port via env if needed.
+    # Example: app.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)), debug=True)
     app.launch(debug=True)