wuhp committed · Commit 2a0a7c9 · verified · 1 Parent(s): ae4cf01

Update app.py

Files changed (1): app.py (+326 -279)

app.py CHANGED
@@ -1,4 +1,12 @@
+# app.py
+# Rolo: RT-DETRv2-only Training Dashboard (Supervisely ecosystem)
+# - No Ultralytics import or usage
+# - Auto-installs deps in HF Spaces
+# - Only supports models that ship with https://github.com/supervisely-ecosystem/RT-DETRv2
+
 import os
+import sys
+import subprocess
 import shutil
 import stat
 import yaml
@@ -11,33 +19,88 @@ import logging
 import requests
 import json
 from PIL import Image
+import torch
 import pandas as pd
-import matplotlib
-matplotlib.use("Agg")  # headless (HF Spaces)
 import matplotlib.pyplot as plt
 from threading import Thread
 from queue import Queue
-from huggingface_hub import HfApi, HfFolder
-import base64
-import subprocess
-import sys
+from glob import glob
 import time
-import glob
+import base64
 
 # --- Logging ---
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
-# --- RT-DETRv2 backend defaults (Supervisely ecosystem) ---
-RTDETRV2_REPO_URL = "https://github.com/supervisely-ecosystem/RT-DETRv2"
-DEFAULT_REPO_DIR = os.path.join("third_party", "rtdetrv2")
+REPO_URL = "https://github.com/supervisely-ecosystem/RT-DETRv2"
+REPO_DIR = os.path.join(os.getcwd(), "third_party", "RT-DETRv2")
+PY_IMPL_DIR = os.path.join(REPO_DIR, "rtdetrv2_pytorch")  # contains the pytorch impl (models, training)
+WEIGHTS_DIR = os.path.join(PY_IMPL_DIR, "weights")
+
+# ------------------------------
+# Environment bootstrap (HF Spaces)
+# ------------------------------
+
+COMMON_REQUIREMENTS = [
+    "gradio>=4.36.1",
+    "roboflow>=1.1.28",
+    "pandas>=2.0.0",
+    "matplotlib>=3.7.0",
+    "pyyaml>=6.0.1",
+    "Pillow>=10.0.0",
+    "requests>=2.31.0",
+    "huggingface_hub>=0.22.0",
+]
+
+def pip_install(args):
+    logging.info(f"pip install {' '.join(args)}")
+    subprocess.check_call([sys.executable, "-m", "pip", "install"] + args)
+
+def ensure_repo_and_requirements():
+    os.makedirs(os.path.dirname(REPO_DIR), exist_ok=True)
+    if not os.path.exists(REPO_DIR):
+        logging.info(f"Cloning RT-DETRv2 repo to {REPO_DIR} ...")
+        subprocess.check_call(["git", "clone", "--depth", "1", REPO_URL, REPO_DIR])
+    else:
+        logging.info("RT-DETRv2 repo already present, pulling latest...")
+        try:
+            subprocess.check_call(["git", "-C", REPO_DIR, "pull", "--ff-only"])
+        except Exception:
+            logging.warning("Could not pull latest; continuing with current checkout.")
 
-RTDETRV2_MODELS = [
-    "rtdetrv2-l-640",  # labels only; match your config via the command template
-    "rtdetrv2-x-640"
+    # Install common libs
+    pip_install(COMMON_REQUIREMENTS)
+
+    # Install rtdetrv2_pytorch requirements if present
+    req_file = os.path.join(PY_IMPL_DIR, "requirements.txt")
+    if os.path.exists(req_file):
+        pip_install(["-r", req_file])
+    else:
+        logging.info("No rtdetrv2_pytorch/requirements.txt found; relying on common reqs.")
+
+# Do the bootstrap once at import time (HF Spaces-friendly).
+try:
+    ensure_repo_and_requirements()
+except Exception as e:
+    logging.exception("Bootstrap failed")
+    # Still allow UI to load so user can see the error
+    pass
+
+# ------------------------------
+# Model options (strictly from RT-DETRv2 repo)
+# ------------------------------
+# We expose only the canonical small/large/xlarge variants that ship with the repo.
+# If the repo adds/removes variants, you can read from weights dir dynamically.
+MODEL_CHOICES = [
+    ("rtdetrv2_s", "Small (default)"),
+    ("rtdetrv2_l", "Large"),
+    ("rtdetrv2_x", "X-Large")
 ]
-DEFAULT_MODEL = RTDETRV2_MODELS[0]
+DEFAULT_MODEL_KEY = "rtdetrv2_s"  # Small as default
+
+# ------------------------------
+# Utilities
+# ------------------------------
 
-# --- Utilities ---
 def handle_remove_readonly(func, path, exc_info):
     try:
         os.chmod(path, stat.S_IWRITE)
@@ -51,15 +114,11 @@ _ROBO_URL_RX = re.compile(
     (?:https?://)?(?:universe|app|www)?\.?roboflow\.com/
     (?P<ws>[A-Za-z0-9\-_]+)/
     (?P<proj>[A-Za-z0-9\-_]+)/?
-    (?:
-        (?:dataset/[^/]+/)?
-        (?:v?(?P<ver>\d+))?
-    )?
+    (?:(?:dataset/[^/]+/)?(?:v?(?P<ver>\d+))?)?
     |
     (?P<ws2>[A-Za-z0-9\-_]+)/(?P<proj2>[A-Za-z0-9\-_]+)(?:/(?:v)?(?P<ver2>\d+))?
     )$
-    """,
-    re.VERBOSE | re.IGNORECASE
+    """, re.VERBOSE | re.IGNORECASE
 )
 
 def parse_roboflow_url(s: str):
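The pattern accepts both full Roboflow/Universe URLs and bare `workspace/project[/vN]` identifiers. A quick sanity check of the inputs it is meant to parse (a sketch with made-up names; per the handler below, `parse_roboflow_url` returns a `(workspace, project, version)` tuple, with the version absent when omitted):

```python
for s in [
    "https://universe.roboflow.com/acme/traffic-signs/3",  # full URL with version
    "app.roboflow.com/acme/traffic-signs",                 # version omitted
    "acme/traffic-signs/v2",                               # bare ws/proj[/vN] form
]:
    print(parse_roboflow_url(s))  # e.g. ("acme", "traffic-signs", "3") for the first
```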
@@ -111,20 +170,25 @@ def _extract_class_names(data_yaml):
     names = data_yaml.get('names', None)
     if isinstance(names, dict):
         def _k(x):
-            try: return int(x)
-            except Exception: return str(x)
-        ordered = sorted(names.keys(), key=_k)
-        names_list = [names[k] for k in ordered]
+            try:
+                return int(x)
+            except Exception:
+                return str(x)
+        ordered_keys = sorted(names.keys(), key=_k)
+        names_list = [names[k] for k in ordered_keys]
     elif isinstance(names, list):
         names_list = names
     else:
         nc = data_yaml.get('nc', 0)
-        try: nc = int(nc)
-        except Exception: nc = 0
+        try:
+            nc = int(nc)
+        except Exception:
+            nc = 0
         names_list = [f"class_{i}" for i in range(nc)]
     return [str(x) for x in names_list]
 
 def download_dataset(api_key, workspace, project, version):
+    """Download a Roboflow dataset in YOLOv8 format (labels are compatible with our merger)."""
     try:
         rf = Roboflow(api_key=api_key)
         proj = rf.workspace(workspace).project(project)
@@ -143,16 +207,14 @@ def download_dataset(api_key, workspace, project, version):
         if len(class_names) != nc:
             logging.warning(f"[{project}-v{version}] names length ({len(class_names)}) != nc ({nc}); using normalized names.")
 
-        splits = [s for s in ['train', 'valid', 'test']
-                  if os.path.exists(os.path.join(dataset.location, s))]
-
+        splits = [s for s in ['train', 'valid', 'test'] if os.path.exists(os.path.join(dataset.location, s))]
         return dataset.location, class_names, splits, f"{project}-v{version}"
     except Exception as e:
         logging.error(f"Failed to download {workspace}/{project}/v{version}: {e}")
         return None, [], [], None
 
 def label_path_for(img_path: str) -> str:
-    split_dir = os.path.dirname(os.path.dirname(img_path))  # .../split
+    split_dir = os.path.dirname(os.path.dirname(img_path))
     base = os.path.splitext(os.path.basename(img_path))[0] + '.txt'
     return os.path.join(split_dir, 'labels', base)
 
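For reference, the `images/` → `labels/` sibling-directory mapping this helper assumes (illustrated with a made-up POSIX path):

```python
label_path_for("merged_dataset/train/images/img_001.jpg")
# -> "merged_dataset/train/labels/img_001.txt"
```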
@@ -161,6 +223,7 @@ def gather_class_counts(dataset_info, class_mapping):
         return {}
     final_names = set(v for v in class_mapping.values() if v is not None)
     counts = {name: 0 for name in final_names}
+
     for loc, names, splits, _ in dataset_info:
         id_to_name = {idx: class_mapping.get(n, None) for idx, n in enumerate(names)}
         for split in splits:
@@ -179,7 +242,7 @@ def gather_class_counts(dataset_info, class_mapping):
                 try:
                     cls_id = int(parts[0])
                     mapped = id_to_name.get(cls_id, None)
-                    if mapped in final_names:
+                    if mapped:
                         found.add(mapped)
                 except Exception:
                     continue
@@ -197,7 +260,7 @@ def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=
         os.makedirs(os.path.join(merged_dir, split, 'images'), exist_ok=True)
         os.makedirs(os.path.join(merged_dir, split, 'labels'), exist_ok=True)
 
-    active_classes = sorted(set([cls for cls, limit in class_limits.items() if limit > 0]))
+    active_classes = sorted({cls for cls, limit in class_limits.items() if limit > 0})
     final_class_map = {name: i for i, name in enumerate(active_classes)}
 
     all_images = []
@@ -216,6 +279,7 @@ def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=
     current_counts = {cls: 0 for cls in active_classes}
     loc_to_names = {info[0]: info[1] for info in dataset_info}
 
+    # progress.tqdm is available on Gradio Progress objects
    for img_path, split, source_loc in progress.tqdm(all_images, desc="Analyzing images"):
         lbl_path = label_path_for(img_path)
         if not os.path.exists(lbl_path):
@@ -239,6 +303,7 @@ def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=
 
         if not image_classes:
             continue
+
         if any(current_counts[c] >= class_limits[c] for c in image_classes):
             continue
 
@@ -288,128 +353,116 @@ def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=
 
     return f"Dataset finalized with {len(selected_images)} images.", os.path.abspath(merged_dir)
 
-# --- Repo + deps helpers (auto-install for HF Spaces) ---
-
-def run_pip_install(args, desc="pip install"):
-    logging.info(f"{desc}: {args}")
-    cmd = [sys.executable, "-m", "pip", "install"] + args
-    proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
-    logging.info(proc.stdout)
-    if proc.returncode != 0:
-        raise RuntimeError(f"{desc} failed with code {proc.returncode}")
+# ------------------------------
+# Training integration (RT-DETRv2 repo)
+# ------------------------------
 
-def ensure_repo(repo_dir: str, repo_url: str = RTDETRV2_REPO_URL):
-    if os.path.isdir(repo_dir) and os.path.isdir(os.path.join(repo_dir, ".git")):
-        return
-    os.makedirs(os.path.dirname(repo_dir), exist_ok=True)
-    logging.info(f"Cloning RT-DETRv2 repo into {repo_dir} ...")
-    subprocess.run(["git", "clone", "--depth", "1", repo_url, repo_dir], check=True)
-
-def ensure_python_deps(repo_dir: str):
+def detect_training_entrypoint():
     """
-    Auto-install dependencies (idempotent).
-    - Tries to install pinned basics that are often needed.
-    - If repo has requirements*.txt, install them.
-    - Creates a .deps_installed marker to skip on next run.
+    We try a couple of common patterns inside the Supervisely repo:
+      1) rtdetrv2_pytorch/train.py
+      2) tools/train.py
+    Returns (python_file, style) where style hints how to build args.
     """
-    marker = os.path.join(repo_dir, ".deps_installed")
-    if os.path.exists(marker):
-        logging.info("Dependencies already installed; skipping.")
-        return
-
-    # 1) Common essentials for vision training environments on HF Spaces
-    basics = [
-        "numpy<2",  # safer with many libs
-        "pillow",
-        "tqdm",
-        "pyyaml",
-        "matplotlib",
-        "pandas",
-        "scipy",
-        "opencv-python-headless",
-        "packaging",
-        "requests",
-        "pycocotools-windows; platform_system=='Windows'",
-        "pycocotools; platform_system!='Windows'",
-    ]
-    try:
-        run_pip_install(basics, desc="Installing common basics")
-    except Exception as e:
-        logging.warning(f"Basic installs had issues: {e}")
-
-    # 2) Repo requirements
-    req_files = []
-    for name in ["requirements.txt", "requirements-dev.txt", "requirements.in"]:
-        p = os.path.join(repo_dir, name)
-        if os.path.isfile(p):
-            req_files.append(p)
-
-    for rf in req_files:
-        try:
-            run_pip_install(["-r", rf], desc=f"Installing repo requirements from {rf}")
-        except Exception as e:
-            logging.warning(f"Installing {rf} failed: {e}")
-
-    # 3) Optional: torch if not present (CPU-only by default on Spaces)
-    try:
-        import torch  # noqa: F401
-    except Exception:
-        # Try a CPU-friendly torch; change version/cuda wheels if needed
-        try:
-            run_pip_install(["torch", "torchvision", "torchaudio", "--index-url", "https://download.pytorch.org/whl/cpu"], desc="Installing PyTorch (CPU)")
-        except Exception as e:
-            logging.warning(f"PyTorch installation failed/skipped: {e}")
-
-    # Mark done
-    with open(marker, "w") as f:
-        f.write("ok\n")
-
-def make_train_command(template: str, data_yaml: str, epochs: int, batch: int, imgsz: int,
-                       lr: float, optimizer: str, run_name: str, output_dir: str) -> str:
-    return template.format(
-        data_yaml=data_yaml,
-        epochs=int(epochs),
-        batch=int(batch),
-        imgsz=int(imgsz),
-        lr=float(lr),
-        optimizer=str(optimizer),
-        run_name=str(run_name),
-        output_dir=output_dir
-    )
-
-_METRIC_PATTERNS = [
-    (re.compile(r"mAP@0\.5[:/]?0\.95[^0-9]*([0-9]*\.?[0-9]+)"), "mAP50_95"),
-    (re.compile(r"mAP50[^0-9]*([0-9]*\.?[0-9]+)"), "mAP50"),
-    (re.compile(r"\bval[_/ ]?loss[^0-9\-]*([0-9]*\.?[0-9]+)"), "val_loss"),
-    (re.compile(r"\btrain[_/ ]?loss[^0-9\-]*([0-9]*\.?[0-9]+)"), "train_loss"),
-    (re.compile(r"\bepoch[^0-9]*([0-9]+)"), "epoch"),
-]
-
-def parse_metrics_from_line(line: str):
-    result = {}
-    for pat, key in _METRIC_PATTERNS:
-        m = pat.search(line)
-        if m:
-            val = m.group(1)
-            try:
-                result[key] = int(val) if key == "epoch" else float(val)
-            except Exception:
-                pass
-    return result
-
-def guess_final_weights(output_dir: str):
+    cand1 = os.path.join(PY_IMPL_DIR, "train.py")
+    cand2 = os.path.join(REPO_DIR, "tools", "train.py")
+    if os.path.exists(cand1):
+        return cand1, "pytorch_train"
+    if os.path.exists(cand2):
+        return cand2, "tools_train"
+    # Fallback: just try main.py if present
+    cand3 = os.path.join(REPO_DIR, "src", "main.py")
+    if os.path.exists(cand3):
+        return cand3, "app_main"
+    return None, None
+
+def build_command(entrypoint, style, dataset_path, model_key, run_name, epochs, batch, imgsz, lr, optimizer):
+    """
+    Build a best-guess command for the detected style.
+    Users never have to edit the CLI; we do it for them.
+    We keep args conservative and standard (data, epochs, batch, img size).
+    """
+    data_yaml = os.path.join(dataset_path, "data.yaml")
+    out_dir = os.path.join("runs", "train", str(run_name))
+    os.makedirs(out_dir, exist_ok=True)
+
+    # Some repos expect a weight/model name; we pass model_key (e.g., rtdetrv2_s) and let their script resolve it.
+    # Learning rate / optimizer flags may differ; include only when style suggests they're supported.
+    if style == "pytorch_train":
+        # Hypothetical common args for a train.py inside rtdetrv2_pytorch
+        cmd = [
+            sys.executable, entrypoint,
+            "--data", data_yaml,
+            "--model", model_key,
+            "--epochs", str(int(epochs)),
+            "--batch", str(int(batch)),
+            "--imgsz", str(int(imgsz)),
+            "--project", os.path.abspath(out_dir)
+        ]
+        if lr is not None:
+            cmd += ["--lr", str(float(lr))]
+        if optimizer:
+            cmd += ["--optimizer", str(optimizer)]
+        return cmd, out_dir
+
+    if style == "tools_train":
+        # Alternate style (tools/train.py). We keep flags generic.
+        cmd = [
+            sys.executable, entrypoint,
+            "--data", data_yaml,
+            "--model", model_key,
+            "--epochs", str(int(epochs)),
+            "--batch-size", str(int(batch)),
+            "--imgsz", str(int(imgsz)),
+            "--project", os.path.abspath(out_dir),
+            "--name", "exp"
+        ]
+        if lr is not None:
+            cmd += ["--lr0", str(float(lr))]
+        if optimizer:
+            cmd += ["--optimizer", str(optimizer)]
+        return cmd, out_dir
+
+    if style == "app_main":
+        # If app_main exists, it may require an options file; we still try a generic mapping.
+        cmd = [
+            sys.executable, entrypoint,
+            "--data", data_yaml,
+            "--model", model_key,
+            "--epochs", str(int(epochs)),
+            "--batch", str(int(batch)),
+            "--imgsz", str(int(imgsz)),
+            "--output", os.path.abspath(out_dir)
+        ]
+        if lr is not None:
+            cmd += ["--lr", str(float(lr))]
+        if optimizer:
+            cmd += ["--optimizer", str(optimizer)]
+        return cmd, out_dir
+
+    raise gr.Error("Could not locate a training script inside RT-DETRv2 repo. Please check the repo layout.")
+
+def find_best_checkpoint(out_dir):
+    # Look for common patterns
     patterns = [
-        os.path.join(output_dir, "**", "best.*"),
-        os.path.join(output_dir, "**", "best_model.*"),
-        os.path.join(output_dir, "**", "checkpoint_best.*"),
+        os.path.join(out_dir, "**", "best*.pt"),
+        os.path.join(out_dir, "**", "best*.pth"),
+        os.path.join(out_dir, "**", "model_best*.pt"),
+        os.path.join(out_dir, "**", "model_best*.pth"),
     ]
     for p in patterns:
-        hits = glob.glob(p, recursive=True)
-        if hits:
-            return hits[0]
-    return None
+        files = sorted(glob(p, recursive=True))
+        if files:
+            return files[0]
+    # Fall back to latest .pt/.pth
+    any_ckpt = sorted(glob(os.path.join(out_dir, "**", "*.pt"), recursive=True) +
+                      glob(os.path.join(out_dir, "**", "*.pth"), recursive=True))
+    return any_ckpt[-1] if any_ckpt else None
+
+# ------------------------------
+# Gradio Handlers
+# ------------------------------
 
-# --- Gradio handlers ---
 def load_datasets_handler(api_key, url_file, progress=gr.Progress()):
     api_key = api_key or os.getenv("ROBOFLOW_API_KEY", "")
     if not api_key:
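Since the RT-DETRv2 repo's CLI is not pinned down here, `build_command` guesses flag spellings per entrypoint style. A sketch of what it returns for the hypothetical `pytorch_train` style (the entrypoint path is assumed, and the flags are this commit's guesses, not a documented interface; the call also creates `runs/train/<run_name>`):

```python
cmd, out_dir = build_command(
    entrypoint="third_party/RT-DETRv2/rtdetrv2_pytorch/train.py",  # assumed path
    style="pytorch_train",
    dataset_path="merged_dataset",
    model_key="rtdetrv2_s",
    run_name="demo",
    epochs=10, batch=8, imgsz=640, lr=0.001, optimizer="Adam",
)
# cmd ≈ [sys.executable, ".../train.py", "--data", "merged_dataset/data.yaml",
#        "--model", "rtdetrv2_s", "--epochs", "10", "--batch", "8", "--imgsz", "640",
#        "--project", "<abs>/runs/train/demo", "--lr", "0.001", "--optimizer", "Adam"]
# out_dir == "runs/train/demo"
```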
@@ -420,8 +473,7 @@ def load_datasets_handler(api_key, url_file, progress=gr.Progress()):
     with open(url_file.name, 'r', encoding='utf-8', errors='ignore') as f:
         urls = [line.strip() for line in f if line.strip()]
 
-    dataset_info = []
-    failures = []
+    dataset_info, failures = [], []
     for i, raw in enumerate(urls):
         progress((i + 1) / max(1, len(urls)), desc=f"Parsing {i+1}/{len(urls)}")
         ws, proj, ver = parse_roboflow_url(raw)
@@ -444,32 +496,33 @@ def load_datasets_handler(api_key, url_file, progress=gr.Progress()):
         msg = "No datasets were loaded successfully.\n" + "\n".join([f"- {u}: {why}" for u, why in failures[:10]])
         raise gr.Error(msg)
 
+    # Make sure names are strings before sorting to avoid mixed-type comparison
     all_names = sorted({str(n) for _, names, _, _ in dataset_info for n in names})
     class_map = {name: name for name in all_names}
+
     initial_counts = gather_class_counts(dataset_info, class_map)
-    df_data = [[name, name, initial_counts.get(name, 0), False] for name in all_names]
+    df = pd.DataFrame([[name, name, initial_counts.get(name, 0), False] for name in all_names],
+                      columns=["Original Name", "Rename To", "Max Images", "Remove"])
     status_text = "Datasets loaded successfully."
     if failures:
         status_text += f" ({len(dataset_info)} OK, {len(failures)} failed; see console logs)."
 
-    return status_text, dataset_info, gr.update(
-        value=pd.DataFrame(df_data, columns=["Original Name", "Rename To", "Max Images", "Remove"])
-    )
+    # Return the DataFrame value directly (works across Gradio versions)
+    return status_text, dataset_info, df
 
 def update_class_counts_handler(class_df, dataset_info):
     if class_df is None or not dataset_info:
         return None
+
     class_df = pd.DataFrame(class_df)
     mapping = {}
     for _, row in class_df.iterrows():
         orig = row["Original Name"]
-        if bool(row["Remove"]):
-            mapping[orig] = None
-        else:
-            mapping[orig] = row["Rename To"]
+        mapping[orig] = None if bool(row["Remove"]) else row["Rename To"]
 
     final_names = sorted(set(v for v in mapping.values() if v))
     counts = {k: 0 for k in final_names}
+
     for loc, names, splits, _ in dataset_info:
         id_to_final = {idx: mapping.get(n, None) for idx, n in enumerate(names)}
         for split in splits:
@@ -495,8 +548,7 @@ def update_class_counts_handler(class_df, dataset_info):
             for m in found:
                 counts[m] += 1
 
-    summary_df = pd.DataFrame(list(counts.items()), columns=["Final Class Name", "Est. Total Images"])
-    return summary_df
+    return pd.DataFrame(list(counts.items()), columns=["Final Class Name", "Est. Total Images"])
 
 def finalize_handler(dataset_info, class_df, progress=gr.Progress()):
     if not dataset_info:
@@ -505,8 +557,7 @@ def finalize_handler(dataset_info, class_df, progress=gr.Progress()):
         raise gr.Error("Class data is missing.")
 
     class_df = pd.DataFrame(class_df)
-    class_mapping = {}
-    class_limits = {}
+    class_mapping, class_limits = {}, {}
     for _, row in class_df.iterrows():
         orig = row["Original Name"]
         if bool(row["Remove"]):
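The loop body that follows is unchanged (so not shown in the hunk); based on how `gather_class_counts` and `finalize_merged_dataset` consume the two dicts, a row is expected to contribute roughly like this (illustrative values, an assumption about the elided lines):

```python
# row = {"Original Name": "car", "Rename To": "vehicle", "Max Images": 500, "Remove": False}
# -> class_mapping["car"] = "vehicle"   # original name -> final name (None when removed)
# -> class_limits["vehicle"] = 500      # per-final-class image cap used by finalize
```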
@@ -518,89 +569,100 @@ def finalize_handler(dataset_info, class_df, progress=gr.Progress()):
     status, path = finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress)
     return status, path
 
-def training_handler_rtdetrv2(dataset_path, repo_dir, model_choice, run_name, epochs, batch, imgsz, lr, opt,
-                              cmd_template, progress=gr.Progress()):
+def training_handler(dataset_path, model_choice_key, run_name, epochs, batch, imgsz, lr, opt, progress=gr.Progress()):
     if not dataset_path:
         raise gr.Error("Finalize a dataset in Tab 2 before training.")
 
-    # Clone + deps (idempotent)
-    try:
-        ensure_repo(repo_dir)
-        ensure_python_deps(repo_dir)
-    except subprocess.CalledProcessError as e:
-        raise gr.Error(f"Failed to clone repo: {e}")
-    except Exception as e:
-        raise gr.Error(f"Dependency setup failed: {e}")
+    # Verify repo entrypoint
+    entrypoint, style = detect_training_entrypoint()
+    if not entrypoint:
+        raise gr.Error("RT-DETRv2 training script not found in the repo. Please check repo contents.")
+
+    # Build and run command (users never touch the CLI)
+    cmd, out_dir = build_command(
+        entrypoint=entrypoint,
+        style=style,
+        dataset_path=dataset_path,
+        model_key=model_choice_key,
+        run_name=run_name,
+        epochs=epochs,
+        batch=batch,
+        imgsz=imgsz,
+        lr=lr,
+        optimizer=opt
+    )
+    logging.info(f"Training command: {' '.join(cmd)}")
 
-    # Output dir
-    output_dir = os.path.join("runs", "train", str(run_name))
-    os.makedirs(output_dir, exist_ok=True)
+    # Live-run in a thread and stream logs
+    q = Queue()
 
-    data_yaml = os.path.join(dataset_path, "data.yaml")
-    if not os.path.isfile(data_yaml):
-        raise gr.Error(f"'data.yaml' was not found in: {dataset_path}")
-
-    # Build command from template
-    cmd = make_train_command(
-        template=cmd_template,
-        data_yaml=data_yaml,
-        epochs=int(epochs),
-        batch=int(batch),
-        imgsz=int(imgsz),
-        lr=float(lr),
-        optimizer=str(opt),
-        run_name=str(run_name),
-        output_dir=output_dir
-    )
+    def run_train():
+        try:
+            env = os.environ.copy()
+            env["PYTHONPATH"] = REPO_DIR + os.pathsep + env.get("PYTHONPATH", "")
+            proc = subprocess.Popen(cmd, cwd=REPO_DIR, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, bufsize=1, text=True, env=env)
+            for line in proc.stdout:
+                q.put(line.rstrip())
+            proc.wait()
+            q.put(f"__EXITCODE__:{proc.returncode}")
+        except Exception as e:
+            q.put(f"__ERROR__:{e}")
+
+    Thread(target=run_train, daemon=True).start()
+
+    log_lines = []
+    last_epoch = 0
+    total_epochs = int(epochs)
+    while True:
+        line = q.get()
+        if line.startswith("__EXITCODE__"):
+            code = int(line.split(":", 1)[1])
+            if code != 0:
+                raise gr.Error(f"Training process exited with code {code}. Check logs above.")
+            break
+        if line.startswith("__ERROR__"):
+            raise gr.Error(f"Training failed: {line.split(':',1)[1]}")
 
-    logging.info(f"Running training command in {repo_dir}: {cmd}")
-    proc = subprocess.Popen(
-        cmd, cwd=repo_dir, shell=True,
-        stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
-        bufsize=1, universal_newlines=True, env={**os.environ}
-    )
+        log_lines.append(line)
+        # try to parse "Epoch X/Y" style hints for progress
+        m = re.search(r"[Ee]poch\s+(\d+)\s*/\s*(\d+)", line)
+        if m:
+            try:
+                last_epoch = int(m.group(1))
+                total_epochs = max(total_epochs, int(m.group(2)))
+            except Exception:
+                pass
 
-    history = {k: [] for k in ['epoch', 'train_loss', 'val_loss', 'mAP50', 'mAP50_95']}
-    for line in iter(proc.stdout.readline, ''):
-        line = line.rstrip()
-        progress(0.0, desc=line[-120:])
-        metrics = parse_metrics_from_line(line)
-        if metrics:
-            for k, v in metrics.items():
-                history[k].append(v)
+        frac = min(max(last_epoch / max(1, total_epochs), 0.0), 1.0)
+        progress(frac, desc=f"Epoch {last_epoch}/{total_epochs}")
 
-        # plot loss
-        fig_loss = plt.figure()
-        ax_loss = fig_loss.add_subplot(111)
-        ax_loss.plot(history['epoch'], history['train_loss'], "o-", label='Train Loss')
-        ax_loss.plot(history['epoch'], history['val_loss'], "o-", label='Val Loss')
-        ax_loss.legend(); ax_loss.set_title("Loss")
-
-        # plot mAP
-        fig_map = plt.figure()
-        ax_map = fig_map.add_subplot(111)
-        ax_map.plot(history['epoch'], history['mAP50'], "o-", label='mAP@0.5')
-        ax_map.plot(history['epoch'], history['mAP50_95'], "o-", label='mAP@0.5:0.95')
-        ax_map.legend(); ax_map.set_title("mAP")
-
-        yield line[-200:], fig_loss, fig_map, None
-
-    proc.stdout.close()
-    ret = proc.wait()
-    if ret != 0:
-        raise gr.Error(f"Training process exited with code {ret}. Check console/logs for details.")
-
-    final_ckpt = guess_final_weights(output_dir)
-    if final_ckpt and os.path.isfile(final_ckpt):
-        yield "Training complete!", None, None, gr.File.update(value=final_ckpt, visible=True)
-    else:
-        yield ("Training finished. Could not auto-detect a 'best' checkpoint; "
-               "please check the output directory."), None, None, gr.update(visible=False)
+        # Lightweight placeholder plots (no metrics dicts here; just keep the UI from breaking)
+        fig_loss = plt.figure()
+        ax_loss = fig_loss.add_subplot(111)
+        ax_loss.set_title("Loss (see logs)")
+        fig_map = plt.figure()
+        ax_map = fig_map.add_subplot(111)
+        ax_map.set_title("mAP (see logs)")
+
+        yield "\n".join(log_lines[-30:]), fig_loss, fig_map, None
+
+    # Look for the best checkpoint
+    ckpt = find_best_checkpoint(out_dir)
+    if not ckpt or not os.path.exists(ckpt):
+        # try to give the user any artifact
+        alt = find_best_checkpoint("runs")
+        if not alt or not os.path.exists(alt):
+            raise gr.Error("Training finished, but checkpoint file was not found. See logs for details.")
+        ckpt = alt
+
+    yield "Training complete!", None, None, gr.update(value=ckpt, visible=True)
 
 def upload_handler(model_file, hf_token, hf_repo, gh_token, gh_repo, progress=gr.Progress()):
     if not model_file:
         raise gr.Error("No trained model file available to upload. Train a model first.")
 
+    from huggingface_hub import HfApi, HfFolder
+
     hf_status = "Skipped Hugging Face (credentials not provided)."
     if hf_token and hf_repo:
         progress(0, desc="Uploading to Hugging Face...")
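The new handler replaces the old blocking `Popen` read loop with a worker thread that pumps subprocess output into a `Queue`, using sentinel lines to signal exit status. Reduced to its essentials (a runnable sketch of the same pattern, with a trivial child process standing in for the trainer):

```python
import subprocess, sys
from queue import Queue
from threading import Thread

q = Queue()

def pump():
    proc = subprocess.Popen([sys.executable, "-c", "print('hello'); print('world')"],
                            stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
    for line in proc.stdout:
        q.put(line.rstrip())
    proc.wait()
    q.put(f"__EXITCODE__:{proc.returncode}")  # sentinel: process finished

Thread(target=pump, daemon=True).start()
while True:
    line = q.get()
    if line.startswith("__EXITCODE__"):
        break
    print(line)  # in the app, this feeds the Gradio generator instead
```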
@@ -624,6 +686,7 @@ def upload_handler(model_file, hf_token, hf_repo, gh_token, gh_repo, progress=gr
         try:
             if '/' not in gh_repo:
                 raise ValueError("GitHub repo must be in the form 'username/repo'.")
+
             username, repo_name = gh_repo.split('/')
             api_url = f"https://api.github.com/repos/{username}/{repo_name}/contents/{os.path.basename(model_file.name)}"
             headers = {"Authorization": f"token {gh_token}"}
@@ -635,9 +698,11 @@ def upload_handler(model_file, hf_token, hf_repo, gh_token, gh_repo, progress=gr
             sha = get_resp.json().get('sha') if get_resp.ok else None
 
             data = {"message": "Upload trained model from Rolo app", "content": content}
-            if sha: data["sha"] = sha
+            if sha:
+                data["sha"] = sha
 
             put_resp = requests.put(api_url, headers=headers, json=data, timeout=60)
+
             if put_resp.ok:
                 gh_status = f"Success! Model at: {put_resp.json()['content']['html_url']}"
             else:
@@ -649,24 +714,26 @@ def upload_handler(model_file, hf_token, hf_repo, gh_token, gh_repo, progress=gr
     progress(1)
     return hf_status, gh_status
 
-# --- Gradio UI ---
+# ------------------------------
+# Gradio UI
+# ------------------------------
 with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
-    gr.Markdown("# Rolo: RT-DETRv2 Training Dashboard (Auto-setup for Hugging Face)")
+    gr.Markdown("# Rolo: RT-DETRv2 Training (Supervisely ecosystem only)")
 
     dataset_info_state = gr.State([])
     final_dataset_path_state = gr.State(None)
 
     with gr.Tabs():
         with gr.TabItem("1. Prepare Datasets"):
-            gr.Markdown("Upload a `.txt` with Roboflow URLs or `workspace/project[/vN]` lines.")
+            gr.Markdown("### Load Roboflow Datasets\nProvide your Roboflow API key and upload a `.txt` file containing one Roboflow dataset URL or `workspace/project[/vN]` per line.")
             with gr.Row():
-                rf_api_key = gr.Textbox(label="Roboflow API Key (or set ROBOFLOW_API_KEY)", type="password", scale=2)
+                rf_api_key = gr.Textbox(label="Roboflow API Key (or set ROBOFLOW_API_KEY env)", type="password", scale=2)
                 rf_url_file = gr.File(label="Upload Roboflow URLs (.txt)", file_types=[".txt"], scale=1)
             load_btn = gr.Button("Load Datasets", variant="primary")
             dataset_status = gr.Textbox(label="Status", interactive=False)
 
         with gr.TabItem("2. Manage & Merge"):
-            gr.Markdown("Rename classes, set image limits, or remove them. Preview, then finalize.")
+            gr.Markdown("### Configure Classes and Finalize Dataset\nRename classes to merge them, set image limits, or remove them. Click **Update Counts** to preview, then **Finalize** to create the dataset.")
             with gr.Row():
                 class_df = gr.DataFrame(
                     headers=["Original Name", "Rename To", "Max Images", "Remove"],
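For context, the GitHub branch of `upload_handler` uses the REST Contents API: GET the file to learn its `sha` (required when replacing an existing file), then PUT the base64-encoded bytes. A standalone sketch of that flow (token, repo, and path are placeholders):

```python
import base64, requests

def push_to_github(token, owner_repo, path, payload: bytes):
    url = f"https://api.github.com/repos/{owner_repo}/contents/{path}"
    headers = {"Authorization": f"token {token}"}
    get_resp = requests.get(url, headers=headers, timeout=30)
    sha = get_resp.json().get("sha") if get_resp.ok else None  # present iff the file exists
    data = {"message": "Upload trained model", "content": base64.b64encode(payload).decode()}
    if sha:
        data["sha"] = sha  # omit for a brand-new file
    return requests.put(url, headers=headers, json=data, timeout=60)
```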
@@ -684,43 +751,33 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
             finalize_status = gr.Textbox(label="Status", interactive=False)
 
         with gr.TabItem("3. Configure & Train"):
-            gr.Markdown("Set hyperparameters and the training command template.")
+            gr.Markdown("### Set Hyperparameters and Train the RT-DETRv2 Model")
             with gr.Row():
                 with gr.Column(scale=1):
-                    model_choice_dd = gr.Dropdown(
-                        label="Model Choice (label only; use your config in the template)",
-                        choices=RTDETRV2_MODELS, value=DEFAULT_MODEL
+                    model_file_dd = gr.Dropdown(
+                        label="Model (only RT-DETRv2 from Supervisely)",
+                        choices=[k for k, _ in MODEL_CHOICES],
+                        value=DEFAULT_MODEL_KEY
+                    )
+                    model_hints = gr.Markdown(
+                        "Choices: " +
+                        ", ".join([f"`{k}` ({label})" for k, label in MODEL_CHOICES])
                     )
                     run_name_tb = gr.Textbox(label="Run Name", value="rtdetrv2_run_1")
                     epochs_sl = gr.Slider(1, 500, 100, step=1, label="Epochs")
                     batch_sl = gr.Slider(1, 64, 16, step=1, label="Batch Size")
                     imgsz_num = gr.Number(label="Image Size", value=640)
                     lr_num = gr.Number(label="Learning Rate", value=0.001)
-                    opt_dd = gr.Dropdown(["Adam", "AdamW", "SGD"], value="AdamW", label="Optimizer")
-                    repo_dir_tb = gr.Textbox(label="RT-DETRv2 repo directory", value=DEFAULT_REPO_DIR)
-                    cmd_template_tb = gr.Textbox(
-                        label="Train command template",
-                        value=(
-                            "python tools/train.py "
-                            "--data {data_yaml} "
-                            "--epochs {epochs} "
-                            "--batch {batch} "
-                            "--imgsz {imgsz} "
-                            "--lr {lr} "
-                            "--optimizer {optimizer} "
-                            "--output {output_dir}"
-                        ),
-                        lines=4
-                    )
+                    opt_dd = gr.Dropdown(["Adam", "AdamW", "SGD"], value="Adam", label="Optimizer")
                     train_btn = gr.Button("Start Training", variant="primary")
                 with gr.Column(scale=2):
-                    train_status = gr.Textbox(label="Live Status / Logs", interactive=False)
-                    loss_plot = gr.Plot(label="Loss Curves")
-                    map_plot = gr.Plot(label="mAP Curves")
-                    final_model_file = gr.File(label="Download Trained Model (best.*)", interactive=False, visible=False)
+                    train_status = gr.Textbox(label="Live Logs (tail)", interactive=False, lines=12)
+                    loss_plot = gr.Plot(label="Loss")
+                    map_plot = gr.Plot(label="mAP")
+                    final_model_file = gr.File(label="Download Trained Model", interactive=False, visible=False)
 
         with gr.TabItem("4. Upload Model"):
-            gr.Markdown("Upload your best checkpoint to Hugging Face or GitHub.")
+            gr.Markdown("### Upload Your Trained Model")
             with gr.Row():
                 with gr.Column():
                     gr.Markdown("#### Hugging Face")
@@ -735,6 +792,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
             hf_status = gr.Textbox(label="Hugging Face Status", interactive=False)
             gh_status = gr.Textbox(label="GitHub Status", interactive=False)
 
+    # Wire UI handlers
     load_btn.click(
         fn=load_datasets_handler,
         inputs=[rf_api_key, rf_url_file],
@@ -751,19 +809,8 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
         outputs=[finalize_status, final_dataset_path_state]
     )
     train_btn.click(
-        fn=training_handler_rtdetrv2,
-        inputs=[
-            final_dataset_path_state,  # dataset_path
-            repo_dir_tb,               # repo_dir (auto clone + pip install)
-            model_choice_dd,           # model_choice (label only)
-            run_name_tb,
-            epochs_sl,
-            batch_sl,
-            imgsz_num,
-            lr_num,
-            opt_dd,
-            cmd_template_tb
-        ],
+        fn=training_handler,
+        inputs=[final_dataset_path_state, model_file_dd, run_name_tb, epochs_sl, batch_sl, imgsz_num, lr_num, opt_dd],
         outputs=[train_status, loss_plot, map_plot, final_model_file]
     )
     upload_btn.click(
@@ -773,6 +820,6 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
     )
 
 if __name__ == "__main__":
-    # Hugging Face Spaces: set server name/port via env if needed.
-    # Example: app.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)), debug=True)
+    # Silence Ultralytics warnings if present in the env (we don't use Ultralytics at all)
+    os.environ.setdefault("YOLO_CONFIG_DIR", "/tmp/Ultralytics")
     app.launch(debug=True)