wuhp committed on
Commit b0cabfb · verified · 1 Parent(s): 2a0a7c9

Update app.py

Files changed (1)
  1. app.py +331 -457
app.py CHANGED
@@ -1,56 +1,27 @@
1
- # app.py
2
- # Rolo: RT-DETRv2-only Training Dashboard (Supervisely ecosystem)
3
- # - No Ultralytics import or usage
4
- # - Auto-installs deps in HF Spaces
5
- # - Only supports models that ship with https://github.com/supervisely-ecosystem/RT-DETRv2
6
-
7
- import os
8
- import sys
9
- import subprocess
10
- import shutil
11
- import stat
12
- import yaml
13
- import gradio as gr
14
- from roboflow import Roboflow
15
- import re
16
  from urllib.parse import urlparse
17
- import random
18
- import logging
19
- import requests
20
- import json
21
- from PIL import Image
22
- import torch
23
- import pandas as pd
24
- import matplotlib.pyplot as plt
25
  from threading import Thread
26
  from queue import Queue
27
- from glob import glob
28
- import time
29
- import base64
30
 
31
- # --- Logging ---
32
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
33
-
34
- REPO_URL = "https://github.com/supervisely-ecosystem/RT-DETRv2" # :contentReference[oaicite:1]{index=1}
35
- REPO_DIR = os.path.join(os.getcwd(), "third_party", "RT-DETRv2")
36
- PY_IMPL_DIR = os.path.join(REPO_DIR, "rtdetrv2_pytorch") # contains the pytorch impl (models, training)
37
- WEIGHTS_DIR = os.path.join(PY_IMPL_DIR, "weights")
38
 
39
- # ------------------------------
40
- # Environment bootstrap (HF Spaces)
41
- # ------------------------------
42
43
  COMMON_REQUIREMENTS = [
44
- "gradio>=4.36.1",
45
- "roboflow>=1.1.28",
46
- "pandas>=2.0.0",
47
- "matplotlib>=3.7.0",
48
- "pyyaml>=6.0.1",
49
- "Pillow>=10.0.0",
50
- "requests>=2.31.0",
51
- "huggingface_hub>=0.22.0",
52
  ]
53
 
 
54
  def pip_install(args):
55
  logging.info(f"pip install {' '.join(args)}")
56
  subprocess.check_call([sys.executable, "-m", "pip", "install"] + args)
@@ -61,65 +32,39 @@ def ensure_repo_and_requirements():
61
  logging.info(f"Cloning RT-DETRv2 repo to {REPO_DIR} ...")
62
  subprocess.check_call(["git", "clone", "--depth", "1", REPO_URL, REPO_DIR])
63
  else:
64
- logging.info("RT-DETRv2 repo already present, pulling latest...")
65
  try:
66
  subprocess.check_call(["git", "-C", REPO_DIR, "pull", "--ff-only"])
67
  except Exception:
68
- logging.warning("Could not pull latest; continuing with current checkout.")
69
 
70
- # Install common libs
71
  pip_install(COMMON_REQUIREMENTS)
72
-
73
- # Install rtdetrv2_pytorch requirements if present
74
  req_file = os.path.join(PY_IMPL_DIR, "requirements.txt")
75
  if os.path.exists(req_file):
76
  pip_install(["-r", req_file])
77
- else:
78
- logging.info("No rtdetrv2_pytorch/requirements.txt found; relying on common reqs.")
79
 
80
- # Do the bootstrap once at import time (HF Spaces-friendly).
81
  try:
82
  ensure_repo_and_requirements()
83
- except Exception as e:
84
- logging.exception("Bootstrap failed")
85
- # Still allow UI to load so user can see the error
86
- pass
87
-
88
- # ------------------------------
89
- # Model options (strictly from RT-DETRv2 repo)
90
- # ------------------------------
91
- # We expose only the canonical small/large/xlarge variants that ship with the repo.
92
- # If the repo adds/removes variants, you can read from weights dir dynamically.
93
- MODEL_CHOICES = [
94
- ("rtdetrv2_s", "Small (default)"),
95
- ("rtdetrv2_l", "Large"),
96
- ("rtdetrv2_x", "X-Large")
97
- ]
98
- DEFAULT_MODEL_KEY = "rtdetrv2_s" # Small as default
99
 
100
- # ------------------------------
101
- # Utilities
102
- # ------------------------------
103
 
 
104
  def handle_remove_readonly(func, path, exc_info):
105
- try:
106
- os.chmod(path, stat.S_IWRITE)
107
- except Exception:
108
- pass
109
  func(path)
110
 
111
- _ROBO_URL_RX = re.compile(
112
- r"""
113
- ^(?:
114
- (?:https?://)?(?:universe|app|www)?\.?roboflow\.com/
115
- (?P<ws>[A-Za-z0-9\-_]+)/
116
- (?P<proj>[A-Za-z0-9\-_]+)/?
117
- (?:(?:dataset/[^/]+/)?(?:v?(?P<ver>\d+))?)?
118
- |
119
- (?P<ws2>[A-Za-z0-9\-_]+)/(?P<proj2>[A-Za-z0-9\-_]+)(?:/(?:v)?(?P<ver2>\d+))?
120
- )$
121
- """, re.VERBOSE | re.IGNORECASE
122
- )
123
 
124
  def parse_roboflow_url(s: str):
125
  s = s.strip()
@@ -129,31 +74,24 @@ def parse_roboflow_url(s: str):
129
  proj = m.group('proj') or m.group('proj2')
130
  ver = m.group('ver') or m.group('ver2')
131
  return ws, proj, (int(ver) if ver else None)
132
-
133
  parsed = urlparse(s)
134
  parts = [p for p in parsed.path.strip('/').split('/') if p]
135
  if len(parts) >= 2:
136
  version = None
137
  if len(parts) >= 3:
138
- vpart = parts[2]
139
- if vpart.lower().startswith('v') and vpart[1:].isdigit():
140
- version = int(vpart[1:])
141
- elif vpart.isdigit():
142
- version = int(vpart)
143
  return parts[0], parts[1], version
144
-
145
  if '/' in s and 'roboflow' not in s:
146
  p = s.split('/')
147
  if len(p) >= 2:
148
  version = None
149
  if len(p) >= 3:
150
  v = p[2]
151
- if v.lower().startswith('v') and v[1:].isdigit():
152
- version = int(v[1:])
153
- elif v.isdigit():
154
- version = int(v)
155
  return p[0], p[1], version
156
-
157
  return None, None, None
158
 
159
  def get_latest_version(api_key, workspace, project):
@@ -170,43 +108,26 @@ def _extract_class_names(data_yaml):
170
  names = data_yaml.get('names', None)
171
  if isinstance(names, dict):
172
  def _k(x):
173
- try:
174
- return int(x)
175
- except Exception:
176
- return str(x)
177
- ordered_keys = sorted(names.keys(), key=_k)
178
- names_list = [names[k] for k in ordered_keys]
179
  elif isinstance(names, list):
180
  names_list = names
181
  else:
182
- nc = data_yaml.get('nc', 0)
183
- try:
184
- nc = int(nc)
185
- except Exception:
186
- nc = 0
187
  names_list = [f"class_{i}" for i in range(nc)]
188
  return [str(x) for x in names_list]
189
 
190
  def download_dataset(api_key, workspace, project, version):
191
- """Download a Roboflow dataset in YOLOv8 format (labels are compatible with our merger)."""
192
  try:
193
  rf = Roboflow(api_key=api_key)
194
  proj = rf.workspace(workspace).project(project)
195
  ver = proj.version(int(version))
196
- dataset = ver.download("yolov8")
197
-
198
  data_yaml_path = os.path.join(dataset.location, 'data.yaml')
199
- with open(data_yaml_path, 'r') as f:
200
- data_yaml = yaml.safe_load(f)
201
-
202
  class_names = _extract_class_names(data_yaml)
203
- try:
204
- nc = int(data_yaml.get('nc', len(class_names)))
205
- except Exception:
206
- nc = len(class_names)
207
- if len(class_names) != nc:
208
- logging.warning(f"[{project}-v{version}] names length ({len(class_names)}) != nc ({nc}); using normalized names.")
209
-
210
  splits = [s for s in ['train', 'valid', 'test'] if os.path.exists(os.path.join(dataset.location, s))]
211
  return dataset.location, class_names, splits, f"{project}-v{version}"
212
  except Exception as e:
@@ -218,36 +139,97 @@ def label_path_for(img_path: str) -> str:
218
  base = os.path.splitext(os.path.basename(img_path))[0] + '.txt'
219
  return os.path.join(split_dir, 'labels', base)
220
221
  def gather_class_counts(dataset_info, class_mapping):
222
- if not dataset_info:
223
- return {}
224
  final_names = set(v for v in class_mapping.values() if v is not None)
225
  counts = {name: 0 for name in final_names}
226
-
227
  for loc, names, splits, _ in dataset_info:
228
  id_to_name = {idx: class_mapping.get(n, None) for idx, n in enumerate(names)}
229
  for split in splits:
230
  labels_dir = os.path.join(loc, split, 'labels')
231
- if not os.path.exists(labels_dir):
232
- continue
233
  for label_file in os.listdir(labels_dir):
234
- if not label_file.endswith('.txt'):
235
- continue
236
  found = set()
237
  with open(os.path.join(labels_dir, label_file), 'r') as f:
238
  for line in f:
239
  parts = line.strip().split()
240
- if not parts:
241
- continue
242
  try:
243
  cls_id = int(parts[0])
244
  mapped = id_to_name.get(cls_id, None)
245
- if mapped:
246
- found.add(mapped)
247
  except Exception:
248
  continue
249
- for m in found:
250
- counts[m] += 1
251
  return counts
252
 
253
  def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=gr.Progress()):
@@ -267,49 +249,36 @@ def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=
267
  for loc, _, splits, _ in dataset_info:
268
  for split in splits:
269
  img_dir = os.path.join(loc, split, 'images')
270
- if not os.path.exists(img_dir):
271
- continue
272
  for img_file in os.listdir(img_dir):
273
  if img_file.lower().endswith(('.jpg', '.jpeg', '.png')):
274
  all_images.append((os.path.join(img_dir, img_file), split, loc))
275
  random.shuffle(all_images)
276
 
277
  progress(0.2, desc="Selecting images based on limits...")
278
- selected_images = []
279
- current_counts = {cls: 0 for cls in active_classes}
280
  loc_to_names = {info[0]: info[1] for info in dataset_info}
281
 
282
- # progress.tqdm is available on Gradio Progress objects
283
  for img_path, split, source_loc in progress.tqdm(all_images, desc="Analyzing images"):
284
  lbl_path = label_path_for(img_path)
285
- if not os.path.exists(lbl_path):
286
- continue
287
-
288
  source_names = loc_to_names.get(source_loc, [])
289
  image_classes = set()
290
  with open(lbl_path, 'r') as f:
291
  for line in f:
292
  parts = line.strip().split()
293
- if not parts:
294
- continue
295
  try:
296
  cls_id = int(parts[0])
297
  orig = source_names[cls_id]
298
  mapped = class_mapping.get(orig, orig)
299
- if mapped in active_classes:
300
- image_classes.add(mapped)
301
  except Exception:
302
  continue
303
-
304
- if not image_classes:
305
- continue
306
-
307
- if any(current_counts[c] >= class_limits[c] for c in image_classes):
308
- continue
309
-
310
  selected_images.append((img_path, split))
311
- for c in image_classes:
312
- current_counts[c] += 1
313
 
314
  progress(0.6, desc=f"Copying {len(selected_images)} files...")
315
  for img_path, split in progress.tqdm(selected_images, desc="Finalizing files"):
@@ -320,16 +289,13 @@ def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=
320
 
321
  source_loc = None
322
  for info in dataset_info:
323
- if img_path.startswith(info[0]):
324
- source_loc = info[0]
325
- break
326
  source_names = loc_to_names.get(source_loc, [])
327
 
328
  with open(lbl_path, 'r') as f_in, open(out_lbl, 'w') as f_out:
329
  for line in f_in:
330
  parts = line.strip().split()
331
- if not parts:
332
- continue
333
  try:
334
  old_id = int(parts[0])
335
  original_name = source_names[old_id]
@@ -340,7 +306,7 @@ def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=
340
  except Exception:
341
  continue
342
 
343
- progress(0.95, desc="Creating data.yaml...")
344
  with open(os.path.join(merged_dir, 'data.yaml'), 'w') as f:
345
  yaml.dump({
346
  'path': os.path.abspath(merged_dir),
@@ -351,124 +317,111 @@ def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=
351
  'names': active_classes
352
  }, f)
353
 
 
 
 
354
  return f"Dataset finalized with {len(selected_images)} images.", os.path.abspath(merged_dir)
355
 
356
- # ------------------------------
357
- # Training integration (RT-DETRv2 repo)
358
- # ------------------------------
359
-
360
- def detect_training_entrypoint():
361
  """
362
- We try a couple of common patterns inside the Supervisely repo:
363
- 1) rtdetrv2_pytorch/train.py
364
- 2) tools/train.py
365
- Returns (python_file, style) where style hints how to build args.
366
  """
367
- cand1 = os.path.join(PY_IMPL_DIR, "train.py")
368
- cand2 = os.path.join(REPO_DIR, "tools", "train.py")
369
- if os.path.exists(cand1):
370
- return cand1, "pytorch_train"
371
- if os.path.exists(cand2):
372
- return cand2, "tools_train"
373
- # Fallback: just try main.py if present
374
- cand3 = os.path.join(REPO_DIR, "src", "main.py")
375
- if os.path.exists(cand3):
376
- return cand3, "app_main"
377
- return None, None
378
-
379
- def build_command(entrypoint, style, dataset_path, model_key, run_name, epochs, batch, imgsz, lr, optimizer):
380
  """
381
- Build a best-guess command for the detected style.
382
- Users never have to edit CLI; we do it for them.
383
- We keep args conservative and standard (data, epochs, batch, img size).
384
  """
385
- data_yaml = os.path.join(dataset_path, "data.yaml")
386
- out_dir = os.path.join("runs", "train", str(run_name))
387
- os.makedirs(out_dir, exist_ok=True)
388
-
389
- # Some repos expect weight/model name; we pass model_key (e.g., rtdetrv2_s) and let their script resolve it.
390
- # Learning rate / optimizer flags may differ; include only when style suggests they're supported.
391
- if style == "pytorch_train":
392
- # Hypothetical common args for a train.py inside rtdetrv2_pytorch
393
- cmd = [
394
- sys.executable, entrypoint,
395
- "--data", data_yaml,
396
- "--model", model_key,
397
- "--epochs", str(int(epochs)),
398
- "--batch", str(int(batch)),
399
- "--imgsz", str(int(imgsz)),
400
- "--project", os.path.abspath(out_dir)
401
- ]
402
- if lr is not None:
403
- cmd += ["--lr", str(float(lr))]
404
- if optimizer:
405
- cmd += ["--optimizer", str(optimizer)]
406
- return cmd, out_dir
407
-
408
- if style == "tools_train":
409
- # Alternate style (tools/train.py). We keep flags generic.
410
- cmd = [
411
- sys.executable, entrypoint,
412
- "--data", data_yaml,
413
- "--model", model_key,
414
- "--epochs", str(int(epochs)),
415
- "--batch-size", str(int(batch)),
416
- "--imgsz", str(int(imgsz)),
417
- "--project", os.path.abspath(out_dir),
418
- "--name", "exp"
419
- ]
420
- if lr is not None:
421
- cmd += ["--lr0", str(float(lr))]
422
- if optimizer:
423
- cmd += ["--optimizer", str(optimizer)]
424
- return cmd, out_dir
425
-
426
- if style == "app_main":
427
- # If app_main exists, it may require an options file; we still try a generic mapping.
428
- cmd = [
429
- sys.executable, entrypoint,
430
- "--data", data_yaml,
431
- "--model", model_key,
432
- "--epochs", str(int(epochs)),
433
- "--batch", str(int(batch)),
434
- "--imgsz", str(int(imgsz)),
435
- "--output", os.path.abspath(out_dir)
436
- ]
437
- if lr is not None:
438
- cmd += ["--lr", str(float(lr))]
439
- if optimizer:
440
- cmd += ["--optimizer", str(optimizer)]
441
- return cmd, out_dir
442
-
443
- raise gr.Error("Could not locate a training script inside RT-DETRv2 repo. Please check the repo layout.")
 
 
 
 
444
 
445
  def find_best_checkpoint(out_dir):
446
- # Look for common patterns
447
- patterns = [
448
  os.path.join(out_dir, "**", "best*.pt"),
449
  os.path.join(out_dir, "**", "best*.pth"),
450
  os.path.join(out_dir, "**", "model_best*.pt"),
451
  os.path.join(out_dir, "**", "model_best*.pth"),
452
  ]
453
- for p in patterns:
454
- files = sorted(glob(p, recursive=True))
455
- if files:
456
- return files[0]
457
- # Fall back to latest .pt/.pth
458
  any_ckpt = sorted(glob(os.path.join(out_dir, "**", "*.pt"), recursive=True) +
459
  glob(os.path.join(out_dir, "**", "*.pth"), recursive=True))
460
  return any_ckpt[-1] if any_ckpt else None
461
 
462
- # ------------------------------
463
- # Gradio Handlers
464
- # ------------------------------
465
-
466
  def load_datasets_handler(api_key, url_file, progress=gr.Progress()):
467
  api_key = api_key or os.getenv("ROBOFLOW_API_KEY", "")
468
- if not api_key:
469
- raise gr.Error("Roboflow API Key is required (or set ROBOFLOW_API_KEY).")
470
- if not url_file:
471
- raise gr.Error("Please upload a .txt file with Roboflow URLs or lines like 'workspace/project[/vN]'.")
472
 
473
  with open(url_file.name, 'r', encoding='utf-8', errors='ignore') as f:
474
  urls = [line.strip() for line in f if line.strip()]
@@ -483,126 +436,117 @@ def load_datasets_handler(api_key, url_file, progress=gr.Progress()):
483
  if ver is None:
484
  ver = get_latest_version(api_key, ws, proj)
485
  if ver is None:
486
- failures.append((raw, f"Could not resolve latest version for {ws}/{proj}"))
487
  continue
488
-
489
  loc, names, splits, name_str = download_dataset(api_key, ws, proj, int(ver))
490
- if loc:
491
- dataset_info.append((loc, names, splits, name_str))
492
- else:
493
- failures.append((raw, f"DownloadError: {ws}/{proj}/v{ver}"))
494
 
495
  if not dataset_info:
496
- msg = "No datasets were loaded successfully.\n" + "\n".join([f"- {u}: {why}" for u, why in failures[:10]])
497
  raise gr.Error(msg)
498
 
499
- # Make sure names are strings before sorting to avoid mixed-type comparison
500
  all_names = sorted({str(n) for _, names, _, _ in dataset_info for n in names})
501
  class_map = {name: name for name in all_names}
502
-
503
- initial_counts = gather_class_counts(dataset_info, class_map)
504
- df = pd.DataFrame([[name, name, initial_counts.get(name, 0), False] for name in all_names],
505
  columns=["Original Name", "Rename To", "Max Images", "Remove"])
506
- status_text = "Datasets loaded successfully."
507
- if failures:
508
- status_text += f" ({len(dataset_info)} OK, {len(failures)} failed; see console logs)."
509
-
510
- # Return the DataFrame value directly (works across Gradio versions)
511
- return status_text, dataset_info, df
512
 
513
  def update_class_counts_handler(class_df, dataset_info):
514
- if class_df is None or not dataset_info:
515
- return None
516
-
517
  class_df = pd.DataFrame(class_df)
518
- mapping = {}
519
- for _, row in class_df.iterrows():
520
- orig = row["Original Name"]
521
- mapping[orig] = None if bool(row["Remove"]) else row["Rename To"]
522
-
523
  final_names = sorted(set(v for v in mapping.values() if v))
524
  counts = {k: 0 for k in final_names}
525
-
526
  for loc, names, splits, _ in dataset_info:
527
  id_to_final = {idx: mapping.get(n, None) for idx, n in enumerate(names)}
528
  for split in splits:
529
  labels_dir = os.path.join(loc, split, 'labels')
530
- if not os.path.exists(labels_dir):
531
- continue
532
  for label_file in os.listdir(labels_dir):
533
- if not label_file.endswith('.txt'):
534
- continue
535
  found = set()
536
  with open(os.path.join(labels_dir, label_file), 'r') as f:
537
  for line in f:
538
  parts = line.strip().split()
539
- if not parts:
540
- continue
541
  try:
542
  cls_id = int(parts[0])
543
  mapped = id_to_final.get(cls_id, None)
544
- if mapped:
545
- found.add(mapped)
546
  except Exception:
547
  continue
548
- for m in found:
549
- counts[m] += 1
550
-
551
  return pd.DataFrame(list(counts.items()), columns=["Final Class Name", "Est. Total Images"])
552
 
553
  def finalize_handler(dataset_info, class_df, progress=gr.Progress()):
554
- if not dataset_info:
555
- raise gr.Error("Load datasets first in Tab 1.")
556
- if class_df is None:
557
- raise gr.Error("Class data is missing.")
558
-
559
  class_df = pd.DataFrame(class_df)
560
  class_mapping, class_limits = {}, {}
561
  for _, row in class_df.iterrows():
562
  orig = row["Original Name"]
563
- if bool(row["Remove"]):
564
- continue
565
  final_name = row["Rename To"]
566
  class_mapping[orig] = final_name
567
  class_limits[final_name] = class_limits.get(final_name, 0) + int(row["Max Images"])
568
-
569
  status, path = finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress)
570
  return status, path
571
 
572
- def training_handler(dataset_path, model_choice_key, run_name, epochs, batch, imgsz, lr, opt, progress=gr.Progress()):
573
- if not dataset_path:
574
- raise gr.Error("Finalize a dataset in Tab 2 before training.")
575
-
576
- # Verify repo entrypoint
577
- entrypoint, style = detect_training_entrypoint()
578
- if not entrypoint:
579
- raise gr.Error("RT-DETRv2 training script not found in the repo. Please check repo contents.")
580
-
581
- # Build and run command (users never touch CLI)
582
- cmd, out_dir = build_command(
583
- entrypoint=entrypoint,
584
- style=style,
585
- dataset_path=dataset_path,
586
- model_key=model_choice_key,
587
  run_name=run_name,
588
  epochs=epochs,
589
  batch=batch,
590
  imgsz=imgsz,
591
  lr=lr,
592
- optimizer=opt
593
  )
594
  logging.info(f"Training command: {' '.join(cmd)}")
595
 
596
- # Live-run in a thread and stream logs
597
  q = Queue()
598
-
599
  def run_train():
600
  try:
601
  env = os.environ.copy()
602
  env["PYTHONPATH"] = REPO_DIR + os.pathsep + env.get("PYTHONPATH", "")
603
- proc = subprocess.Popen(cmd, cwd=REPO_DIR, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, bufsize=1, text=True, env=env)
604
- for line in proc.stdout:
605
- q.put(line.rstrip())
 
606
  proc.wait()
607
  q.put(f"__EXITCODE__:{proc.returncode}")
608
  except Exception as e:
@@ -610,21 +554,19 @@ def training_handler(dataset_path, model_choice_key, run_name, epochs, batch, im
610
 
611
  Thread(target=run_train, daemon=True).start()
612
 
613
- log_lines = []
614
- last_epoch = 0
615
- total_epochs = int(epochs)
616
  while True:
617
  line = q.get()
618
  if line.startswith("__EXITCODE__"):
619
- code = int(line.split(":", 1)[1])
620
- if code != 0:
621
- raise gr.Error(f"Training process exited with code {code}. Check logs above.")
622
  break
623
  if line.startswith("__ERROR__"):
624
  raise gr.Error(f"Training failed: {line.split(':',1)[1]}")
625
 
626
- log_lines.append(line)
627
- # try to parse "Epoch X/Y" style hints for progress
 
628
  m = re.search(r"[Ee]poch\s+(\d+)\s*/\s*(\d+)", line)
629
  if m:
630
  try:
@@ -632,194 +574,126 @@ def training_handler(dataset_path, model_choice_key, run_name, epochs, batch, im
632
  total_epochs = max(total_epochs, int(m.group(2)))
633
  except Exception:
634
  pass
 
635
 
636
- frac = min(max(last_epoch / max(1, total_epochs), 0.0), 1.0)
637
- progress(frac, desc=f"Epoch {last_epoch}/{total_epochs}")
638
-
639
- # Light-weight plots (we won't have metrics dicts; just show empty placeholders so UI doesn't break)
640
- fig_loss = plt.figure()
641
- ax_loss = fig_loss.add_subplot(111)
642
- ax_loss.set_title("Loss (see logs)")
643
- fig_map = plt.figure()
644
- ax_map = fig_map.add_subplot(111)
645
- ax_map.set_title("mAP (see logs)")
646
-
647
- yield "\n".join(log_lines[-30:]), fig_loss, fig_map, None
648
 
649
- # Look for the best checkpoint
650
- ckpt = find_best_checkpoint(out_dir)
651
  if not ckpt or not os.path.exists(ckpt):
652
- # try give user any artifact
653
- alt = find_best_checkpoint("runs")
654
- if not alt or not os.path.exists(alt):
655
- raise gr.Error("Training finished, but checkpoint file was not found. See logs for details.")
656
- ckpt = alt
657
-
658
  yield "Training complete!", None, None, gr.File.update(value=ckpt, visible=True)
659
 
660
  def upload_handler(model_file, hf_token, hf_repo, gh_token, gh_repo, progress=gr.Progress()):
661
- if not model_file:
662
- raise gr.Error("No trained model file available to upload. Train a model first.")
663
-
664
  from huggingface_hub import HfApi, HfFolder
665
-
666
- hf_status = "Skipped Hugging Face (credentials not provided)."
667
  if hf_token and hf_repo:
668
  progress(0, desc="Uploading to Hugging Face...")
669
  try:
670
- api = HfApi()
671
- HfFolder.save_token(hf_token)
672
  repo_url = api.create_repo(repo_id=hf_repo, exist_ok=True, token=hf_token)
673
- api.upload_file(
674
- path_or_fileobj=model_file.name,
675
- path_in_repo=os.path.basename(model_file.name),
676
- repo_id=hf_repo,
677
- token=hf_token
678
- )
679
- hf_status = f"Success! Model at: {repo_url}"
680
  except Exception as e:
681
  hf_status = f"Hugging Face Error: {e}"
682
 
683
- gh_status = "Skipped GitHub (credentials not provided)."
684
  if gh_token and gh_repo:
685
  progress(0.5, desc="Uploading to GitHub...")
686
  try:
687
- if '/' not in gh_repo:
688
- raise ValueError("GitHub repo must be in the form 'username/repo'.")
689
-
690
  username, repo_name = gh_repo.split('/')
691
  api_url = f"https://api.github.com/repos/{username}/{repo_name}/contents/{os.path.basename(model_file.name)}"
692
  headers = {"Authorization": f"token {gh_token}"}
693
-
694
- with open(model_file.name, "rb") as f:
695
- content = base64.b64encode(f.read()).decode()
696
-
697
  get_resp = requests.get(api_url, headers=headers, timeout=30)
698
  sha = get_resp.json().get('sha') if get_resp.ok else None
699
-
700
  data = {"message": "Upload trained model from Rolo app", "content": content}
701
- if sha:
702
- data["sha"] = sha
703
-
704
  put_resp = requests.put(api_url, headers=headers, json=data, timeout=60)
705
-
706
- if put_resp.ok:
707
- gh_status = f"Success! Model at: {put_resp.json()['content']['html_url']}"
708
- else:
709
- msg = put_resp.json().get('message', 'Unknown')
710
- gh_status = f"GitHub Error: {msg}"
711
  except Exception as e:
712
  gh_status = f"GitHub Error: {e}"
 
713
 
714
- progress(1)
715
- return hf_status, gh_status
716
-
717
- # ------------------------------
718
- # Gradio UI
719
- # ------------------------------
720
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
721
- gr.Markdown("# Rolo: RT-DETRv2 Training (Supervisely ecosystem only)")
722
 
723
  dataset_info_state = gr.State([])
724
  final_dataset_path_state = gr.State(None)
725
 
726
  with gr.Tabs():
727
  with gr.TabItem("1. Prepare Datasets"):
728
- gr.Markdown("### Load Roboflow Datasets\nProvide your Roboflow API key and upload a `.txt` file containing one Roboflow dataset URL or `workspace/project[/vN]` per line.")
729
  with gr.Row():
730
- rf_api_key = gr.Textbox(label="Roboflow API Key (or set ROBOFLOW_API_KEY env)", type="password", scale=2)
731
- rf_url_file = gr.File(label="Upload Roboflow URLs (.txt)", file_types=[".txt"], scale=1)
732
  load_btn = gr.Button("Load Datasets", variant="primary")
733
  dataset_status = gr.Textbox(label="Status", interactive=False)
734
 
735
  with gr.TabItem("2. Manage & Merge"):
736
- gr.Markdown("### Configure Classes and Finalize Dataset\nRename classes to merge them, set image limits, or remove them. Click **Update Counts** to preview, then **Finalize** to create the dataset.")
737
  with gr.Row():
738
- class_df = gr.DataFrame(
739
- headers=["Original Name", "Rename To", "Max Images", "Remove"],
740
- datatype=["str", "str", "number", "bool"],
741
- label="Class Configuration", interactive=True, scale=3
742
- )
743
  with gr.Column(scale=1):
744
- class_count_summary_df = gr.DataFrame(
745
- label="Merged Class Counts Preview",
746
- headers=["Final Class Name", "Est. Total Images"],
747
- interactive=False
748
- )
749
  update_counts_btn = gr.Button("Update Counts")
750
  finalize_btn = gr.Button("Finalize Merged Dataset", variant="primary")
751
  finalize_status = gr.Textbox(label="Status", interactive=False)
752
 
753
  with gr.TabItem("3. Configure & Train"):
754
- gr.Markdown("### Set Hyperparameters and Train the RT-DETRv2 Model")
755
  with gr.Row():
756
  with gr.Column(scale=1):
757
- model_file_dd = gr.Dropdown(
758
- label="Model (only RT-DETRv2 from Supervisely)",
759
- choices=[k for k, _ in MODEL_CHOICES],
760
- value=DEFAULT_MODEL_KEY
761
- )
762
- model_hints = gr.Markdown(
763
- "Choices: " +
764
- ", ".join([f"`{k}` ({label})" for k, label in MODEL_CHOICES])
765
- )
766
  run_name_tb = gr.Textbox(label="Run Name", value="rtdetrv2_run_1")
767
  epochs_sl = gr.Slider(1, 500, 100, step=1, label="Epochs")
768
  batch_sl = gr.Slider(1, 64, 16, step=1, label="Batch Size")
769
  imgsz_num = gr.Number(label="Image Size", value=640)
770
  lr_num = gr.Number(label="Learning Rate", value=0.001)
771
- opt_dd = gr.Dropdown(["Adam", "AdamW", "SGD"], value="Adam", label="Optimizer")
772
  train_btn = gr.Button("Start Training", variant="primary")
773
  with gr.Column(scale=2):
774
  train_status = gr.Textbox(label="Live Logs (tail)", interactive=False, lines=12)
775
  loss_plot = gr.Plot(label="Loss")
776
  map_plot = gr.Plot(label="mAP")
777
- final_model_file = gr.File(label="Download Trained Model", interactive=False, visible=False)
778
 
779
  with gr.TabItem("4. Upload Model"):
780
- gr.Markdown("### Upload Your Trained Model")
781
  with gr.Row():
782
  with gr.Column():
783
- gr.Markdown("#### Hugging Face")
784
- hf_token = gr.Textbox(label="Hugging Face API Token", type="password")
785
- hf_repo = gr.Textbox(label="Hugging Face Repo ID", placeholder="e.g., username/my-rtdetrv2-model")
786
  with gr.Column():
787
- gr.Markdown("#### GitHub")
788
- gh_token = gr.Textbox(label="GitHub Personal Access Token", type="password")
789
- gh_repo = gr.Textbox(label="GitHub Repo", placeholder="e.g., username/my-rtdetrv2-repo")
790
- upload_btn = gr.Button("Upload Model", variant="primary")
791
  with gr.Row():
792
  hf_status = gr.Textbox(label="Hugging Face Status", interactive=False)
793
  gh_status = gr.Textbox(label="GitHub Status", interactive=False)
794
 
795
- # Wire UI handlers
796
- load_btn.click(
797
- fn=load_datasets_handler,
798
- inputs=[rf_api_key, rf_url_file],
799
- outputs=[dataset_status, dataset_info_state, class_df]
800
- )
801
- update_counts_btn.click(
802
- fn=update_class_counts_handler,
803
- inputs=[class_df, dataset_info_state],
804
- outputs=[class_count_summary_df]
805
- )
806
- finalize_btn.click(
807
- fn=finalize_handler,
808
- inputs=[dataset_info_state, class_df],
809
- outputs=[finalize_status, final_dataset_path_state]
810
- )
811
- train_btn.click(
812
- fn=training_handler,
813
- inputs=[final_dataset_path_state, model_file_dd, run_name_tb, epochs_sl, batch_sl, imgsz_num, lr_num, opt_dd],
814
- outputs=[train_status, loss_plot, map_plot, final_model_file]
815
- )
816
- upload_btn.click(
817
- fn=upload_handler,
818
- inputs=[final_model_file, hf_token, hf_repo, gh_token, gh_repo],
819
- outputs=[hf_status, gh_status]
820
- )
821
 
822
  if __name__ == "__main__":
823
- # Silence Ultralytics warnings if present in the env (we don't use Ultralytics at all)
824
- os.environ.setdefault("YOLO_CONFIG_DIR", "/tmp/Ultralytics")
825
  app.launch(debug=True)
 
1
+ # app.py — Rolo: RT-DETRv2-only (Supervisely) trainer with auto COCO conversion & config
2
+ import os, sys, subprocess, shutil, stat, yaml, gradio as gr, re, random, logging, requests, json, base64, time
3
  from urllib.parse import urlparse
4
+ from glob import glob
5
  from threading import Thread
6
  from queue import Queue
 
 
 
7
 
8
+ import pandas as pd
9
+ import matplotlib.pyplot as plt
10
+ from roboflow import Roboflow
11
+ from PIL import Image
12
+ import torch
 
 
13
 
14
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 
 
15
 
16
+ REPO_URL = "https://github.com/supervisely-ecosystem/RT-DETRv2"
17
+ REPO_DIR = os.path.join(os.getcwd(), "third_party", "RT-DETRv2")
18
+ PY_IMPL_DIR = os.path.join(REPO_DIR, "rtdetrv2_pytorch") # Supervisely keeps PyTorch impl here
19
  COMMON_REQUIREMENTS = [
20
+ "gradio>=4.36.1", "roboflow>=1.1.28", "pandas>=2.0.0", "matplotlib>=3.7.0",
21
+ "pyyaml>=6.0.1", "Pillow>=10.0.0", "requests>=2.31.0", "huggingface_hub>=0.22.0",
 
 
 
 
 
 
22
  ]
23
 
24
+ # === bootstrap (clone + pip) ===================================================
25
  def pip_install(args):
26
  logging.info(f"pip install {' '.join(args)}")
27
  subprocess.check_call([sys.executable, "-m", "pip", "install"] + args)
 
32
  logging.info(f"Cloning RT-DETRv2 repo to {REPO_DIR} ...")
33
  subprocess.check_call(["git", "clone", "--depth", "1", REPO_URL, REPO_DIR])
34
  else:
 
35
  try:
36
  subprocess.check_call(["git", "-C", REPO_DIR, "pull", "--ff-only"])
37
  except Exception:
38
+ logging.warning("git pull failed; continuing with current checkout")
39
 
 
40
  pip_install(COMMON_REQUIREMENTS)
 
 
41
  req_file = os.path.join(PY_IMPL_DIR, "requirements.txt")
42
  if os.path.exists(req_file):
43
  pip_install(["-r", req_file])
 
 
44
 
 
45
  try:
46
  ensure_repo_and_requirements()
47
+ except Exception:
48
+ logging.exception("Bootstrap failed, UI will still load so you can see errors")
49
 
50
+ # === model choices (restricted to Supervisely RT-DETRv2) ======================
51
+ MODEL_CHOICES = [("rtdetrv2_s", "Small (default)"), ("rtdetrv2_l", "Large"), ("rtdetrv2_x", "X-Large")]
52
+ DEFAULT_MODEL_KEY = "rtdetrv2_s"
53
 
54
+ # === utilities ================================================================
55
  def handle_remove_readonly(func, path, exc_info):
56
+ try: os.chmod(path, stat.S_IWRITE)
57
+ except Exception: pass
 
 
58
  func(path)
59
 
60
+ _ROBO_URL_RX = re.compile(r"""
61
+ ^(?:
62
+ (?:https?://)?(?:universe|app|www)?\.?roboflow\.com/
63
+ (?P<ws>[A-Za-z0-9\-_]+)/(?P<proj>[A-Za-z0-9\-_]+)/?(?:(?:dataset/[^/]+/)?(?:v?(?P<ver>\d+))?)?
64
+ |
65
+ (?P<ws2>[A-Za-z0-9\-_]+)/(?P<proj2>[A-Za-z0-9\-_]+)(?:/(?:v)?(?P<ver2>\d+))?
66
+ )$
67
+ """, re.VERBOSE | re.IGNORECASE)
 
 
 
 
68
 
69
  def parse_roboflow_url(s: str):
70
  s = s.strip()
 
74
  proj = m.group('proj') or m.group('proj2')
75
  ver = m.group('ver') or m.group('ver2')
76
  return ws, proj, (int(ver) if ver else None)
 
77
  parsed = urlparse(s)
78
  parts = [p for p in parsed.path.strip('/').split('/') if p]
79
  if len(parts) >= 2:
80
  version = None
81
  if len(parts) >= 3:
82
+ v = parts[2]
83
+ if v.lower().startswith('v') and v[1:].isdigit(): version = int(v[1:])
84
+ elif v.isdigit(): version = int(v)
 
 
85
  return parts[0], parts[1], version
 
86
  if '/' in s and 'roboflow' not in s:
87
  p = s.split('/')
88
  if len(p) >= 2:
89
  version = None
90
  if len(p) >= 3:
91
  v = p[2]
92
+ if v.lower().startswith('v') and v[1:].isdigit(): version = int(v[1:])
93
+ elif v.isdigit(): version = int(v)
 
 
94
  return p[0], p[1], version
 
95
  return None, None, None
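
For reference, a quick illustrative check of what the parser above accepts; the workspace/project names are placeholders, not real Roboflow projects, and the expected tuples assume the regex/fallback branches shown in this hunk:

```python
# Illustrative only; 'acme' / 'widgets' are made-up names.
print(parse_roboflow_url("https://universe.roboflow.com/acme/widgets/3"))  # -> ('acme', 'widgets', 3)
print(parse_roboflow_url("acme/widgets/v2"))                               # -> ('acme', 'widgets', 2)
print(parse_roboflow_url("acme/widgets"))                                  # -> ('acme', 'widgets', None)
```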
96
 
97
  def get_latest_version(api_key, workspace, project):
 
108
  names = data_yaml.get('names', None)
109
  if isinstance(names, dict):
110
  def _k(x):
111
+ try: return int(x)
112
+ except Exception: return str(x)
113
+ keys = sorted(names.keys(), key=_k)
114
+ names_list = [names[k] for k in keys]
 
 
115
  elif isinstance(names, list):
116
  names_list = names
117
  else:
118
+ nc = int(data_yaml.get('nc', 0) or 0)
119
  names_list = [f"class_{i}" for i in range(nc)]
120
  return [str(x) for x in names_list]
121
 
122
  def download_dataset(api_key, workspace, project, version):
 
123
  try:
124
  rf = Roboflow(api_key=api_key)
125
  proj = rf.workspace(workspace).project(project)
126
  ver = proj.version(int(version))
127
+ dataset = ver.download("yolov8") # labels in YOLO format (we'll convert to COCO)
 
128
  data_yaml_path = os.path.join(dataset.location, 'data.yaml')
129
+ with open(data_yaml_path, 'r') as f: data_yaml = yaml.safe_load(f)
 
 
130
  class_names = _extract_class_names(data_yaml)
 
 
 
 
 
 
 
131
  splits = [s for s in ['train', 'valid', 'test'] if os.path.exists(os.path.join(dataset.location, s))]
132
  return dataset.location, class_names, splits, f"{project}-v{version}"
133
  except Exception as e:
 
139
  base = os.path.splitext(os.path.basename(img_path))[0] + '.txt'
140
  return os.path.join(split_dir, 'labels', base)
141
 
142
+ # === YOLOv8 -> COCO converter =================================================
143
+ def yolo_to_coco(split_dir_images, split_dir_labels, class_names, out_json):
144
+ """
145
+ Convert YOLO txt labels to a COCO annotations json.
146
+ """
147
+ images, annotations = [], []
148
+ categories = [{"id": i, "name": n} for i, n in enumerate(class_names)]
149
+ ann_id = 1
150
+ img_id = 1
151
+
152
+ # Simple image size read (PIL); in Spaces this is fine.
153
+ for fname in sorted(os.listdir(split_dir_images)):
154
+ if not fname.lower().endswith((".jpg",".jpeg",".png")): continue
155
+ img_path = os.path.join(split_dir_images, fname)
156
+ try:
157
+ with Image.open(img_path) as im:
158
+ w, h = im.size
159
+ except Exception:
160
+ # skip unreadable images
161
+ continue
162
+ images.append({"id": img_id, "file_name": fname, "width": w, "height": h})
163
+
164
+ label_file = os.path.join(split_dir_labels, os.path.splitext(fname)[0] + ".txt")
165
+ if os.path.exists(label_file):
166
+ with open(label_file, "r") as f:
167
+ for line in f:
168
+ parts = line.strip().split()
169
+ if len(parts) < 5: continue
170
+ cls = int(float(parts[0]))
171
+ cx, cy, bw, bh = map(float, parts[1:5])
172
+ # convert normalized (cx,cy,bw,bh) to x,y,w,h in pixels
173
+ x = (cx - bw/2.0) * w
174
+ y = (cy - bh/2.0) * h
175
+ ww = bw * w
176
+ hh = bh * h
177
+ annotations.append({
178
+ "id": ann_id,
179
+ "image_id": img_id,
180
+ "category_id": cls,
181
+ "bbox": [max(0.0,x), max(0.0,y), max(1.0,ww), max(1.0,hh)],
182
+ "area": max(1.0, ww*hh),
183
+ "iscrowd": 0,
184
+ "segmentation": []
185
+ })
186
+ ann_id += 1
187
+ img_id += 1
188
+
189
+ coco = {"images": images, "annotations": annotations, "categories": categories}
190
+ os.makedirs(os.path.dirname(out_json), exist_ok=True)
191
+ with open(out_json, "w") as f: json.dump(coco, f)
192
+
193
+ def make_coco_annotations(merged_dir, class_names):
194
+ """
195
+ Build COCO jsons under merged_dir/annotations:
196
+ instances_train.json, instances_val.json, instances_test.json
197
+ """
198
+ ann_dir = os.path.join(merged_dir, "annotations")
199
+ os.makedirs(ann_dir, exist_ok=True)
200
+ mapping = {"train": "instances_train.json", "valid": "instances_val.json", "test": "instances_test.json"}
201
+ for split, outname in mapping.items():
202
+ img_dir = os.path.join(merged_dir, split, "images")
203
+ lbl_dir = os.path.join(merged_dir, split, "labels")
204
+ out_json = os.path.join(ann_dir, outname)
205
+ if os.path.exists(img_dir) and os.listdir(img_dir):
206
+ yolo_to_coco(img_dir, lbl_dir, class_names, out_json)
207
+ return ann_dir
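
As a sanity check of the box conversion used by `yolo_to_coco` above, here is the same normalized-to-pixel math in isolation (a minimal sketch; the converter additionally clamps x/y at 0 and width/height at 1 px):

```python
# YOLO stores (cx, cy, w, h) normalized to [0, 1]; COCO wants [x, y, w, h] in pixels.
def yolo_box_to_coco(cx, cy, bw, bh, img_w, img_h):
    x = (cx - bw / 2.0) * img_w   # left edge in pixels
    y = (cy - bh / 2.0) * img_h   # top edge in pixels
    return [x, y, bw * img_w, bh * img_h]

# A box centred in a 640x480 image covering half of each dimension:
print(yolo_box_to_coco(0.5, 0.5, 0.5, 0.5, 640, 480))  # [160.0, 120.0, 320.0, 240.0]
```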
208
+
209
+ # === dataset merging ==========================================================
210
  def gather_class_counts(dataset_info, class_mapping):
211
+ if not dataset_info: return {}
 
212
  final_names = set(v for v in class_mapping.values() if v is not None)
213
  counts = {name: 0 for name in final_names}
 
214
  for loc, names, splits, _ in dataset_info:
215
  id_to_name = {idx: class_mapping.get(n, None) for idx, n in enumerate(names)}
216
  for split in splits:
217
  labels_dir = os.path.join(loc, split, 'labels')
218
+ if not os.path.exists(labels_dir): continue
 
219
  for label_file in os.listdir(labels_dir):
220
+ if not label_file.endswith('.txt'): continue
 
221
  found = set()
222
  with open(os.path.join(labels_dir, label_file), 'r') as f:
223
  for line in f:
224
  parts = line.strip().split()
225
+ if not parts: continue
 
226
  try:
227
  cls_id = int(parts[0])
228
  mapped = id_to_name.get(cls_id, None)
229
+ if mapped: found.add(mapped)
 
230
  except Exception:
231
  continue
232
+ for m in found: counts[m] += 1
 
233
  return counts
234
 
235
  def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=gr.Progress()):
 
249
  for loc, _, splits, _ in dataset_info:
250
  for split in splits:
251
  img_dir = os.path.join(loc, split, 'images')
252
+ if not os.path.exists(img_dir): continue
 
253
  for img_file in os.listdir(img_dir):
254
  if img_file.lower().endswith(('.jpg', '.jpeg', '.png')):
255
  all_images.append((os.path.join(img_dir, img_file), split, loc))
256
  random.shuffle(all_images)
257
 
258
  progress(0.2, desc="Selecting images based on limits...")
259
+ selected_images, current_counts = [], {cls: 0 for cls in active_classes}
 
260
  loc_to_names = {info[0]: info[1] for info in dataset_info}
261
 
 
262
  for img_path, split, source_loc in progress.tqdm(all_images, desc="Analyzing images"):
263
  lbl_path = label_path_for(img_path)
264
+ if not os.path.exists(lbl_path): continue
 
 
265
  source_names = loc_to_names.get(source_loc, [])
266
  image_classes = set()
267
  with open(lbl_path, 'r') as f:
268
  for line in f:
269
  parts = line.strip().split()
270
+ if not parts: continue
 
271
  try:
272
  cls_id = int(parts[0])
273
  orig = source_names[cls_id]
274
  mapped = class_mapping.get(orig, orig)
275
+ if mapped in active_classes: image_classes.add(mapped)
 
276
  except Exception:
277
  continue
278
+ if not image_classes: continue
279
+ if any(current_counts[c] >= class_limits[c] for c in image_classes): continue
280
  selected_images.append((img_path, split))
281
+ for c in image_classes: current_counts[c] += 1
 
282
 
283
  progress(0.6, desc=f"Copying {len(selected_images)} files...")
284
  for img_path, split in progress.tqdm(selected_images, desc="Finalizing files"):
 
289
 
290
  source_loc = None
291
  for info in dataset_info:
292
+ if img_path.startswith(info[0]): source_loc = info[0]; break
 
 
293
  source_names = loc_to_names.get(source_loc, [])
294
 
295
  with open(lbl_path, 'r') as f_in, open(out_lbl, 'w') as f_out:
296
  for line in f_in:
297
  parts = line.strip().split()
298
+ if not parts: continue
 
299
  try:
300
  old_id = int(parts[0])
301
  original_name = source_names[old_id]
 
306
  except Exception:
307
  continue
308
 
309
+ progress(0.9, desc="Writing data.yaml + COCO annotations...")
310
  with open(os.path.join(merged_dir, 'data.yaml'), 'w') as f:
311
  yaml.dump({
312
  'path': os.path.abspath(merged_dir),
 
317
  'names': active_classes
318
  }, f)
319
 
320
+ # also create COCO jsons for RT-DETRv2 training
321
+ ann_dir = make_coco_annotations(merged_dir, active_classes)
322
+ progress(0.98, desc="Finalizing...")
323
  return f"Dataset finalized with {len(selected_images)} images.", os.path.abspath(merged_dir)
324
 
325
+ # === entrypoint + config detection/generation =================================
326
+ def find_training_script(repo_root):
 
 
 
327
  """
328
+ Recursively search for a tools/train.py (or train.py) suitable for RT-DETRv2.
 
 
 
329
  """
330
+ candidates = []
331
+ for pat in ["**/tools/train.py", "**/train.py"]:
332
+ candidates.extend(glob(os.path.join(repo_root, pat), recursive=True))
333
+ # Prefer ones inside rtdetrv2_pytorch
334
+ candidates.sort(key=lambda p: (0 if "rtdetrv2_pytorch" in p else 1, len(p)))
335
+ return candidates[0] if candidates else None
336
+
337
+ def find_model_config_template(model_key):
338
  """
339
+ Find a base config YAML in the repo that matches the chosen model key.
340
+ We look under any configs directory for a yaml containing 'rtdetrv2' and the model key.
 
341
  """
342
+ yamls = glob(os.path.join(REPO_DIR, "**", "*.yml"), recursive=True) + \
343
+ glob(os.path.join(REPO_DIR, "**", "*.yaml"), recursive=True)
344
+ # prioritize files with both rtdetrv2 and the exact key in the name
345
+ def score(p):
346
+ n = os.path.basename(p).lower()
347
+ s = 0
348
+ if "rtdetrv2" in n: s += 2
349
+ if model_key in n: s += 3
350
+ if "coco" in n: s += 1
351
+ return -s, len(p)
352
+ yamls.sort(key=score)
353
+ return yamls[0] if yamls else None
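
To make the filename heuristic concrete, a small illustrative ranking; the candidate paths are invented, not actual files from the repo:

```python
# Re-stating the score() heuristic inline for model_key = "rtdetrv2_s".
import os
model_key = "rtdetrv2_s"

def score(p):
    n = os.path.basename(p).lower()
    s = 2 * ("rtdetrv2" in n) + 3 * (model_key in n) + 1 * ("coco" in n)
    return -s, len(p)

candidates = [
    "configs/some_base.yaml",                # no hits               -> (0, ...)
    "configs/rtdetrv2/rtdetrv2_l_coco.yml",  # rtdetrv2 + coco       -> (-3, ...)
    "configs/rtdetrv2/rtdetrv2_s_coco.yml",  # rtdetrv2 + key + coco -> (-6, ...)
]
print(sorted(candidates, key=score)[0])      # configs/rtdetrv2/rtdetrv2_s_coco.yml
```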
354
+
355
+ def write_custom_config(base_cfg_path, merged_dir, class_count, model_key, run_name, epochs, batch, imgsz, lr, optimizer):
356
+ """
357
+ Generate a small override config that points to our COCO jsons and sets key hyperparams.
358
+ This YAML gets merged by the repo's config system if it supports '_base_' includes;
359
+ otherwise, it still provides reasonable keys many RT-DETRv2 forks accept.
360
+ """
361
+ ann_dir = os.path.join(merged_dir, "annotations")
362
+ cfg_out_dir = os.path.join("generated_configs")
363
+ os.makedirs(cfg_out_dir, exist_ok=True)
364
+ out_path = os.path.join(cfg_out_dir, f"{run_name}.yaml")
365
+
366
+ # Try a broadly compatible structure (kept simple on purpose)
367
+ override = {
368
+ "_base_": os.path.relpath(base_cfg_path, start=cfg_out_dir) if base_cfg_path else None,
369
+ "model": {"name": model_key, "num_classes": int(class_count)},
370
+ "input_size": int(imgsz),
371
+ "max_epoch": int(epochs),
372
+ "solver": {
373
+ "base_lr": float(lr),
374
+ "optimizer": str(optimizer).lower(), # "adam", "adamw", "sgd"
375
+ "batch_size": int(batch),
376
+ },
377
+ "dataset": {
378
+ "train": {
379
+ "name": "coco",
380
+ "ann_file": os.path.abspath(os.path.join(ann_dir, "instances_train.json")),
381
+ "img_prefix": os.path.abspath(os.path.join(merged_dir, "train", "images")),
382
+ },
383
+ "val": {
384
+ "name": "coco",
385
+ "ann_file": os.path.abspath(os.path.join(ann_dir, "instances_val.json")),
386
+ "img_prefix": os.path.abspath(os.path.join(merged_dir, "valid", "images")),
387
+ },
388
+ "test": {
389
+ "name": "coco",
390
+ "ann_file": os.path.abspath(os.path.join(ann_dir, "instances_test.json")),
391
+ "img_prefix": os.path.abspath(os.path.join(merged_dir, "test", "images")),
392
+ },
393
+ },
394
+ "output_dir": os.path.abspath(os.path.join("runs", "train", run_name)),
395
+ # some forks use these dataloader keys:
396
+ "train_dataloader": {"batch_size": int(batch)},
397
+ "val_dataloader": {"batch_size": int(batch)},
398
+ }
399
+ # drop None values cleanly
400
+ if override["_base_"] is None:
401
+ del override["_base_"]
402
+
403
+ with open(out_path, "w") as f: yaml.safe_dump(override, f, sort_keys=False)
404
+ return out_path
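
A quick way to see what this emits, as a sketch under assumed arguments (the merged-dataset directory and run name below are placeholders):

```python
# Illustrative only: write a throwaway override and inspect the keys it contains.
import yaml

cfg_path = write_custom_config(
    base_cfg_path=None,            # skip the '_base_' include for this sketch
    merged_dir="merged_dataset",   # hypothetical merged dataset directory
    class_count=3, model_key="rtdetrv2_s", run_name="demo",
    epochs=100, batch=16, imgsz=640, lr=0.001, optimizer="AdamW",
)
with open(cfg_path) as f:
    cfg = yaml.safe_load(f)

print(cfg["model"])    # {'name': 'rtdetrv2_s', 'num_classes': 3}
print(cfg["solver"])   # {'base_lr': 0.001, 'optimizer': 'adamw', 'batch_size': 16}
print(cfg["dataset"]["train"]["ann_file"])  # .../merged_dataset/annotations/instances_train.json
```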
405
 
406
  def find_best_checkpoint(out_dir):
407
+ pats = [
 
408
  os.path.join(out_dir, "**", "best*.pt"),
409
  os.path.join(out_dir, "**", "best*.pth"),
410
  os.path.join(out_dir, "**", "model_best*.pt"),
411
  os.path.join(out_dir, "**", "model_best*.pth"),
412
  ]
413
+ for p in pats:
414
+ f = sorted(glob(p, recursive=True))
415
+ if f: return f[0]
 
 
416
  any_ckpt = sorted(glob(os.path.join(out_dir, "**", "*.pt"), recursive=True) +
417
  glob(os.path.join(out_dir, "**", "*.pth"), recursive=True))
418
  return any_ckpt[-1] if any_ckpt else None
419
 
420
+ # === Gradio handlers ==========================================================
 
 
 
421
  def load_datasets_handler(api_key, url_file, progress=gr.Progress()):
422
  api_key = api_key or os.getenv("ROBOFLOW_API_KEY", "")
423
+ if not api_key: raise gr.Error("Roboflow API Key is required (or set ROBOFLOW_API_KEY).")
424
+ if not url_file: raise gr.Error("Upload a .txt with Roboflow URLs or 'workspace/project[/vN]' lines.")
 
 
425
 
426
  with open(url_file.name, 'r', encoding='utf-8', errors='ignore') as f:
427
  urls = [line.strip() for line in f if line.strip()]
 
436
  if ver is None:
437
  ver = get_latest_version(api_key, ws, proj)
438
  if ver is None:
439
+ failures.append((raw, f"No latest version for {ws}/{proj}"))
440
  continue
 
441
  loc, names, splits, name_str = download_dataset(api_key, ws, proj, int(ver))
442
+ if loc: dataset_info.append((loc, names, splits, name_str))
443
+ else: failures.append((raw, f"DownloadError: {ws}/{proj}/v{ver}"))
 
 
444
 
445
  if not dataset_info:
446
+ msg = "No datasets loaded.\n" + "\n".join([f"- {u}: {why}" for u, why in failures[:10]])
447
  raise gr.Error(msg)
448
 
 
449
  all_names = sorted({str(n) for _, names, _, _ in dataset_info for n in names})
450
  class_map = {name: name for name in all_names}
451
+ counts = gather_class_counts(dataset_info, class_map)
452
+ df = pd.DataFrame([[n, n, counts.get(n, 0), False] for n in all_names],
 
453
  columns=["Original Name", "Rename To", "Max Images", "Remove"])
454
+ status = "Datasets loaded successfully."
455
+ if failures: status += f" ({len(dataset_info)} OK, {len(failures)} failed; see logs)."
456
+ return status, dataset_info, df
 
 
 
457
 
458
  def update_class_counts_handler(class_df, dataset_info):
459
+ if class_df is None or not dataset_info: return None
 
 
460
  class_df = pd.DataFrame(class_df)
461
+ mapping = {row["Original Name"]: (None if bool(row["Remove"]) else row["Rename To"])
462
+ for _, row in class_df.iterrows()}
 
 
 
463
  final_names = sorted(set(v for v in mapping.values() if v))
464
  counts = {k: 0 for k in final_names}
 
465
  for loc, names, splits, _ in dataset_info:
466
  id_to_final = {idx: mapping.get(n, None) for idx, n in enumerate(names)}
467
  for split in splits:
468
  labels_dir = os.path.join(loc, split, 'labels')
469
+ if not os.path.exists(labels_dir): continue
 
470
  for label_file in os.listdir(labels_dir):
471
+ if not label_file.endswith('.txt'): continue
 
472
  found = set()
473
  with open(os.path.join(labels_dir, label_file), 'r') as f:
474
  for line in f:
475
  parts = line.strip().split()
476
+ if not parts: continue
 
477
  try:
478
  cls_id = int(parts[0])
479
  mapped = id_to_final.get(cls_id, None)
480
+ if mapped: found.add(mapped)
 
481
  except Exception:
482
  continue
483
+ for m in found: counts[m] += 1
 
 
484
  return pd.DataFrame(list(counts.items()), columns=["Final Class Name", "Est. Total Images"])
485
 
486
  def finalize_handler(dataset_info, class_df, progress=gr.Progress()):
487
+ if not dataset_info: raise gr.Error("Load datasets first in Tab 1.")
488
+ if class_df is None: raise gr.Error("Class data is missing.")
 
 
 
489
  class_df = pd.DataFrame(class_df)
490
  class_mapping, class_limits = {}, {}
491
  for _, row in class_df.iterrows():
492
  orig = row["Original Name"]
493
+ if bool(row["Remove"]): continue
 
494
  final_name = row["Rename To"]
495
  class_mapping[orig] = final_name
496
  class_limits[final_name] = class_limits.get(final_name, 0) + int(row["Max Images"])
 
497
  status, path = finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress)
498
  return status, path
499
 
500
+ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr, opt, progress=gr.Progress()):
501
+ if not dataset_path: raise gr.Error("Finalize a dataset in Tab 2 before training.")
502
+
503
+ # 1) find training script (nested-safe)
504
+ train_script = find_training_script(REPO_DIR)
505
+ if not train_script:
506
+ raise gr.Error("RT-DETRv2 training script not found inside the repo (looked for **/tools/train.py).")
507
+
508
+ # 2) pick a model config template from repo (best effort)
509
+ base_cfg = find_model_config_template(model_key)
510
+
511
+ # 3) read class names from our merged data.yaml to set num_classes + produce COCO JSONs
512
+ data_yaml = os.path.join(dataset_path, "data.yaml")
513
+ with open(data_yaml, "r") as f: dy = yaml.safe_load(f)
514
+ class_names = [str(x) for x in dy.get("names", [])]
515
+ ann_dir = make_coco_annotations(dataset_path, class_names)
516
+
517
+ # 4) write a small override config that points to our data and injects hyper-params
518
+ cfg_path = write_custom_config(
519
+ base_cfg_path=base_cfg,
520
+ merged_dir=dataset_path,
521
+ class_count=len(class_names),
522
+ model_key=model_key,
523
  run_name=run_name,
524
  epochs=epochs,
525
  batch=batch,
526
  imgsz=imgsz,
527
  lr=lr,
528
+ optimizer=opt,
529
  )
530
+
531
+ out_dir = os.path.abspath(os.path.join("runs", "train", run_name))
532
+ os.makedirs(out_dir, exist_ok=True)
533
+
534
+ # 5) build & run the command (single-GPU by default, no manual CLI edits)
535
+ cmd = [sys.executable, train_script, "-c", os.path.abspath(cfg_path)]
536
+ # many forks accept optional flags; pass safe ones if present
537
+ if "--use-amp" in open(train_script).read(): # cheap check
538
+ cmd += ["--use-amp"]
539
  logging.info(f"Training command: {' '.join(cmd)}")
540
 
 
541
  q = Queue()
 
542
  def run_train():
543
  try:
544
  env = os.environ.copy()
545
  env["PYTHONPATH"] = REPO_DIR + os.pathsep + env.get("PYTHONPATH", "")
546
+ proc = subprocess.Popen(cmd, cwd=os.path.dirname(train_script),
547
+ stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
548
+ bufsize=1, text=True, env=env)
549
+ for line in proc.stdout: q.put(line.rstrip())
550
  proc.wait()
551
  q.put(f"__EXITCODE__:{proc.returncode}")
552
  except Exception as e:
 
554
 
555
  Thread(target=run_train, daemon=True).start()
556
 
557
+ log_tail, last_epoch, total_epochs = [], 0, int(epochs)
 
 
558
  while True:
559
  line = q.get()
560
  if line.startswith("__EXITCODE__"):
561
+ code = int(line.split(":",1)[1])
562
+ if code != 0: raise gr.Error(f"Training exited with code {code}. See logs above.")
 
563
  break
564
  if line.startswith("__ERROR__"):
565
  raise gr.Error(f"Training failed: {line.split(':',1)[1]}")
566
 
567
+ log_tail.append(line)
568
+ log_tail = log_tail[-30:]
569
+
570
  m = re.search(r"[Ee]poch\s+(\d+)\s*/\s*(\d+)", line)
571
  if m:
572
  try:
 
574
  total_epochs = max(total_epochs, int(m.group(2)))
575
  except Exception:
576
  pass
577
+ progress(min(max(last_epoch / max(1,total_epochs),0.0),1.0), desc=f"Epoch {last_epoch}/{total_epochs}")
578
 
579
+ fig1 = plt.figure(); plt.title("Loss (see logs)")
580
+ fig2 = plt.figure(); plt.title("mAP (see logs)")
581
+ yield "\n".join(log_tail), fig1, fig2, None
582
 
583
+ ckpt = find_best_checkpoint(out_dir) or find_best_checkpoint("runs")
 
584
  if not ckpt or not os.path.exists(ckpt):
585
+ raise gr.Error("Training finished, but checkpoint file not found. Check logs/output directory.")
 
 
 
 
 
586
  yield "Training complete!", None, None, gr.File.update(value=ckpt, visible=True)
587
 
588
  def upload_handler(model_file, hf_token, hf_repo, gh_token, gh_repo, progress=gr.Progress()):
589
+ if not model_file: raise gr.Error("No trained model file to upload.")
 
 
590
  from huggingface_hub import HfApi, HfFolder
591
+ hf_status = "Skipped Hugging Face."
 
592
  if hf_token and hf_repo:
593
  progress(0, desc="Uploading to Hugging Face...")
594
  try:
595
+ api = HfApi(); HfFolder.save_token(hf_token)
 
596
  repo_url = api.create_repo(repo_id=hf_repo, exist_ok=True, token=hf_token)
597
+ api.upload_file(model_file.name, os.path.basename(model_file.name), repo_id=hf_repo, token=hf_token)
598
+ hf_status = f"Success! {repo_url}"
 
 
 
 
 
599
  except Exception as e:
600
  hf_status = f"Hugging Face Error: {e}"
601
 
602
+ gh_status = "Skipped GitHub."
603
  if gh_token and gh_repo:
604
  progress(0.5, desc="Uploading to GitHub...")
605
  try:
606
+ if '/' not in gh_repo: raise ValueError("GitHub repo must be 'username/repo'.")
 
 
607
  username, repo_name = gh_repo.split('/')
608
  api_url = f"https://api.github.com/repos/{username}/{repo_name}/contents/{os.path.basename(model_file.name)}"
609
  headers = {"Authorization": f"token {gh_token}"}
610
+ with open(model_file.name, "rb") as f: content = base64.b64encode(f.read()).decode()
 
 
 
611
  get_resp = requests.get(api_url, headers=headers, timeout=30)
612
  sha = get_resp.json().get('sha') if get_resp.ok else None
 
613
  data = {"message": "Upload trained model from Rolo app", "content": content}
614
+ if sha: data["sha"] = sha
 
 
615
  put_resp = requests.put(api_url, headers=headers, json=data, timeout=60)
616
+ if put_resp.ok: gh_status = f"Success! {put_resp.json()['content']['html_url']}"
617
+ else: gh_status = f"GitHub Error: {put_resp.json().get('message','Unknown')}"
618
  except Exception as e:
619
  gh_status = f"GitHub Error: {e}"
620
+ progress(1); return hf_status, gh_status
621
 
622
+ # === UI =======================================================================
623
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
624
+ gr.Markdown("# Rolo RT-DETRv2 Trainer (Supervisely repo only)")
625
 
626
  dataset_info_state = gr.State([])
627
  final_dataset_path_state = gr.State(None)
628
 
629
  with gr.Tabs():
630
  with gr.TabItem("1. Prepare Datasets"):
631
+ gr.Markdown("Upload a `.txt` with Roboflow URLs or `workspace/project[/vN]` per line. We’ll pull and merge them.")
632
  with gr.Row():
633
+ rf_api_key = gr.Textbox(label="Roboflow API Key (or set ROBOFLOW_API_KEY)", type="password", scale=2)
634
+ rf_url_file = gr.File(label="Roboflow URLs (.txt)", file_types=[".txt"], scale=1)
635
  load_btn = gr.Button("Load Datasets", variant="primary")
636
  dataset_status = gr.Textbox(label="Status", interactive=False)
637
 
638
  with gr.TabItem("2. Manage & Merge"):
639
+ gr.Markdown("Rename/merge/remove classes and set per-class image caps. Then finalize.")
640
  with gr.Row():
641
+ class_df = gr.DataFrame(headers=["Original Name","Rename To","Max Images","Remove"],
642
+ datatype=["str","str","number","bool"], label="Class Config", interactive=True, scale=3)
 
 
 
643
  with gr.Column(scale=1):
644
+ class_count_summary_df = gr.DataFrame(label="Merged Class Counts Preview",
645
+ headers=["Final Class Name","Est. Total Images"], interactive=False)
 
 
 
646
  update_counts_btn = gr.Button("Update Counts")
647
  finalize_btn = gr.Button("Finalize Merged Dataset", variant="primary")
648
  finalize_status = gr.Textbox(label="Status", interactive=False)
649
 
650
  with gr.TabItem("3. Configure & Train"):
651
+ gr.Markdown("Pick RT-DETRv2 model, set hyper-params, press Start.")
652
  with gr.Row():
653
  with gr.Column(scale=1):
654
+ model_dd = gr.Dropdown(choices=[k for k,_ in MODEL_CHOICES], value=DEFAULT_MODEL_KEY,
655
+ label="Model (RT-DETRv2)")
 
 
 
 
 
 
 
656
  run_name_tb = gr.Textbox(label="Run Name", value="rtdetrv2_run_1")
657
  epochs_sl = gr.Slider(1, 500, 100, step=1, label="Epochs")
658
  batch_sl = gr.Slider(1, 64, 16, step=1, label="Batch Size")
659
  imgsz_num = gr.Number(label="Image Size", value=640)
660
  lr_num = gr.Number(label="Learning Rate", value=0.001)
661
+ opt_dd = gr.Dropdown(["Adam","AdamW","SGD"], value="Adam", label="Optimizer")
662
  train_btn = gr.Button("Start Training", variant="primary")
663
  with gr.Column(scale=2):
664
  train_status = gr.Textbox(label="Live Logs (tail)", interactive=False, lines=12)
665
  loss_plot = gr.Plot(label="Loss")
666
  map_plot = gr.Plot(label="mAP")
667
+ final_model_file = gr.File(label="Download Trained Checkpoint", interactive=False, visible=False)
668
 
669
  with gr.TabItem("4. Upload Model"):
670
+ gr.Markdown("Optionally push your checkpoint to Hugging Face / GitHub.")
671
  with gr.Row():
672
  with gr.Column():
673
+ gr.Markdown("**Hugging Face**")
674
+ hf_token = gr.Textbox(label="HF Token", type="password")
675
+ hf_repo = gr.Textbox(label="HF Repo (user/repo)")
676
  with gr.Column():
677
+ gr.Markdown("**GitHub**")
678
+ gh_token = gr.Textbox(label="GitHub PAT", type="password")
679
+ gh_repo = gr.Textbox(label="GitHub Repo (user/repo)")
680
+ upload_btn = gr.Button("Upload", variant="primary")
681
  with gr.Row():
682
  hf_status = gr.Textbox(label="Hugging Face Status", interactive=False)
683
  gh_status = gr.Textbox(label="GitHub Status", interactive=False)
684
 
685
+ load_btn.click(load_datasets_handler, [rf_api_key, rf_url_file],
686
+ [dataset_status, dataset_info_state, class_df])
687
+ update_counts_btn.click(update_class_counts_handler, [class_df, dataset_info_state],
688
+ [class_count_summary_df])
689
+ finalize_btn.click(finalize_handler, [dataset_info_state, class_df],
690
+ [finalize_status, final_dataset_path_state])
691
+ train_btn.click(training_handler,
692
+ [final_dataset_path_state, model_dd, run_name_tb, epochs_sl, batch_sl, imgsz_num, lr_num, opt_dd],
693
+ [train_status, loss_plot, map_plot, final_model_file])
694
+ upload_btn.click(upload_handler, [final_model_file, hf_token, hf_repo, gh_token, gh_repo],
695
+ [hf_status, gh_status])
696
 
697
  if __name__ == "__main__":
698
+ os.environ.setdefault("YOLO_CONFIG_DIR", "/tmp/Ultralytics") # silence stray warnings from other libs
 
699
  app.launch(debug=True)