wuhp committed on
Commit c54a7a8 · verified · 1 Parent(s): a68cd78

Update app.py

Files changed (1)
  1. app.py +440 -208
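The updated loader accepts either full Roboflow URLs or bare workspace/project identifiers, one per line in the uploaded .txt file (see the reworked parse_roboflow_url in the diff below). A minimal example file, using made-up workspace and project names, could look like:

    https://universe.roboflow.com/acme-vision/widgets/3
    https://app.roboflow.com/acme-vision/gadgets
    acme-vision/sprockets/v2

Lines without an explicit version fall back to the project's latest version via get_latest_version.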
app.py CHANGED
@@ -3,7 +3,7 @@ import shutil
3
  import stat
4
  import yaml
5
  import gradio as gr
6
- from ultralytics import YOLO # Note: This is the training engine for RT-DETR in this library.
7
  from roboflow import Roboflow
8
  import re
9
  from urllib.parse import urlparse
@@ -40,22 +40,87 @@ RTDETR_MODELS = {
40
  }
41
  DEFAULT_MODEL = "rtdetr-l.pt"
42
 
43
- # --- Helper & Core Logic Functions ---
 
 
 
44
 
45
  def handle_remove_readonly(func, path, exc_info):
46
  """Error handler for shutil.rmtree."""
47
- os.chmod(path, stat.S_IWRITE)
 
 
 
48
  func(path)
49
 
50
- def parse_roboflow_url(url):
51
- """Parses Roboflow URL to get workspace, project, and version."""
52
- parsed = urlparse(url.strip())
53
- path_parts = parsed.path.strip('/').split('/')
54
- if len(path_parts) >= 3 and 'roboflow.com' in parsed.netloc:
55
- # Format: /workspace/project-id/version
56
- return path_parts[1], path_parts[2], path_parts[3] if len(path_parts) > 3 else None
57
  return None, None, None
58
 
 
59
  def get_latest_version(api_key, workspace, project):
60
  """Gets the latest version number of a Roboflow project."""
61
  try:
@@ -67,246 +132,348 @@ def get_latest_version(api_key, workspace, project):
67
  logging.error(f"Could not get latest version for {workspace}/{project}: {e}")
68
  return None
69
 
 
70
  def download_dataset(api_key, workspace, project, version):
71
- """Downloads a single dataset from Roboflow."""
72
  try:
73
  rf = Roboflow(api_key=api_key)
74
  proj = rf.workspace(workspace).project(project)
75
- # RT-DETR trains perfectly with the yolov8 format.
76
- dataset = proj.version(int(version)).download("yolov8")
77
-
78
- with open(os.path.join(dataset.location, 'data.yaml'), 'r') as f:
 
79
  data_yaml = yaml.safe_load(f)
80
-
81
  class_names = data_yaml.get('names', [])
82
- splits = [s for s in ['train', 'valid', 'test'] if os.path.exists(os.path.join(dataset.location, s))]
83
-
 
84
  return dataset.location, class_names, splits, f"{project}-v{version}"
85
  except Exception as e:
86
  logging.error(f"Failed to download {workspace}/{project}/v{version}: {e}")
87
  return None, [], [], None
88
89
  def gather_class_counts(dataset_info, class_mapping):
90
- """Calculates image counts for each final class based on the mapping."""
91
- unified_names = set(class_mapping.values())
92
- counts = {name: 0 for name in unified_names}
93
- if not dataset_info: return counts
94
 
95
  for loc, names, splits, _ in dataset_info:
96
  for split in splits:
97
  labels_dir = os.path.join(loc, split, 'labels')
98
- if not os.path.exists(labels_dir): continue
 
99
  for label_file in os.listdir(labels_dir):
100
- found_in_file = set()
 
 
101
  with open(os.path.join(labels_dir, label_file), 'r') as f:
102
  for line in f:
 
 
 
103
  try:
104
- class_id = int(line.split()[0])
105
- original_name = names[class_id]
106
- mapped_name = class_mapping.get(original_name, original_name)
107
- if mapped_name in unified_names:
108
- found_in_file.add(mapped_name)
109
- except (ValueError, IndexError):
110
  continue
111
- for cls in found_in_file:
112
- counts[cls] += 1
 
113
  return counts
114
 
 
115
  def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=gr.Progress()):
116
- """The core function to merge datasets based on user rules."""
117
  merged_dir = 'rolo_merged_dataset'
118
  if os.path.exists(merged_dir):
119
  shutil.rmtree(merged_dir, onerror=handle_remove_readonly)
120
-
121
  progress(0, desc="Creating directories...")
122
  for split in ['train', 'valid', 'test']:
123
  os.makedirs(os.path.join(merged_dir, split, 'images'), exist_ok=True)
124
  os.makedirs(os.path.join(merged_dir, split, 'labels'), exist_ok=True)
125
 
126
- active_classes = sorted([cls for cls, limit in class_limits.items() if limit > 0])
 
 
127
  final_class_map = {name: i for i, name in enumerate(active_classes)}
128
 
 
129
  all_images = []
130
  for loc, _, splits, _ in dataset_info:
131
  for split in splits:
132
  img_dir = os.path.join(loc, split, 'images')
133
- if not os.path.exists(img_dir): continue
 
134
  for img_file in os.listdir(img_dir):
135
  if img_file.lower().endswith(('.jpg', '.jpeg', '.png')):
136
  all_images.append((os.path.join(img_dir, img_file), split, loc))
137
  random.shuffle(all_images)
138
-
139
  progress(0.2, desc="Selecting images based on limits...")
140
- selected_images = set()
141
  current_counts = {cls: 0 for cls in active_classes}
142
 
143
- for img_path, original_split, source_loc in progress.tqdm(all_images, desc="Analyzing images"):
144
- lbl_path = img_path.replace('/images/', '/labels/').rsplit('.', 1)[0] + '.txt'
145
- if not os.path.exists(lbl_path): continue
146
-
147
  image_classes = set()
148
  with open(lbl_path, 'r') as f:
149
  for line in f:
 
 
 
150
  try:
151
- source_names = next(info[1] for info in dataset_info if info[0] == source_loc)
152
- original_name = source_names[int(line.split()[0])]
153
- mapped_name = class_mapping.get(original_name, original_name)
154
- if mapped_name in active_classes:
155
- image_classes.add(mapped_name)
156
- except (ValueError, IndexError, StopIteration): continue
157
-
158
- can_add = True
159
- for cls in image_classes:
160
- if current_counts[cls] >= class_limits[cls]:
161
- can_add = False
162
- break
163
-
164
- if can_add and image_classes:
165
- selected_images.add((img_path, original_split))
166
- for cls in image_classes:
167
- current_counts[cls] += 1
168
-
 
169
  progress(0.6, desc=f"Copying {len(selected_images)} files...")
170
  for img_path, split in progress.tqdm(selected_images, desc="Finalizing files"):
171
- lbl_path = img_path.replace('/images/', '/labels/').rsplit('.', 1)[0] + '.txt'
172
- shutil.copy(img_path, os.path.join(merged_dir, split, 'images'))
173
-
174
- with open(lbl_path, 'r') as f_in, open(os.path.join(merged_dir, split, 'labels', os.path.basename(lbl_path)), 'w') as f_out:
175
  for line in f_in:
176
- parts = line.split()
 
 
177
  try:
178
- source_loc = next(info[0] for info in dataset_info if img_path.startswith(info[0]))
179
- source_names = next(info[1] for info in dataset_info if info[0] == source_loc)
180
- original_name = source_names[int(parts[0])]
181
  mapped_name = class_mapping.get(original_name, original_name)
182
  if mapped_name in final_class_map:
183
  new_id = final_class_map[mapped_name]
184
  f_out.write(f"{new_id} {' '.join(parts[1:])}\n")
185
- except (ValueError, IndexError, StopIteration): continue
 
186
 
187
  progress(0.95, desc="Creating data.yaml...")
188
  with open(os.path.join(merged_dir, 'data.yaml'), 'w') as f:
189
  yaml.dump({
190
- 'path': os.path.abspath(merged_dir), 'train': 'train/images', 'val': 'valid/images', 'test': 'test/images',
191
- 'nc': len(active_classes), 'names': active_classes
 
 
 
 
192
  }, f)
193
-
194
  return f"Dataset finalized with {len(selected_images)} images.", os.path.abspath(merged_dir)
195
 
196
 
197
- # --- Gradio UI Event Handlers ---
 
 
198
 
199
  def load_datasets_handler(api_key, url_file, progress=gr.Progress()):
200
  """Handles the 'Load Datasets' button click."""
201
- if not api_key: raise gr.Error("Roboflow API Key is required.")
202
- if not url_file: raise gr.Error("Please upload a .txt file with Roboflow URLs.")
 
 
 
203
 
204
- with open(url_file.name, 'r') as f:
205
  urls = [line.strip() for line in f if line.strip()]
206
-
207
  dataset_info = []
208
- for i, url in enumerate(urls):
209
- progress((i+1)/len(urls), desc=f"Processing URL {i+1}/{len(urls)}")
210
- workspace, project, version = parse_roboflow_url(url)
211
- if not all([workspace, project]):
212
- logging.warning(f"Invalid URL skipped: {url}")
 
 
213
  continue
214
- if not version:
215
- version = get_latest_version(api_key, workspace, project)
216
- if not version:
217
- logging.warning(f"Could not find version for {project}. Skipping.")
 
218
  continue
219
-
220
- loc, names, splits, name_str = download_dataset(api_key, workspace, project, str(version))
221
  if loc:
222
  dataset_info.append((loc, names, splits, name_str))
223
-
224
- if not dataset_info: raise gr.Error("No datasets were loaded successfully.")
 
 
 
 
 
225
 
226
  all_names = sorted(list(set(n for _, names, _, _ in dataset_info for n in names)))
227
  class_map = {name: name for name in all_names}
 
 
228
  initial_counts = gather_class_counts(dataset_info, class_map)
229
  df_data = [[name, name, initial_counts.get(name, 0), False] for name in all_names]
230
-
231
- return "Datasets loaded successfully. Proceed to the next tab to manage classes.", dataset_info, gr.DataFrame.update(value=pd.DataFrame(df_data, columns=["Original Name", "Rename To", "Max Images", "Remove"]))
232
 
233
  def update_class_counts_handler(class_df, dataset_info):
234
- """Provides live feedback on class counts as the user edits the DataFrame."""
235
- if class_df is None or not dataset_info: return None
236
- class_mapping = dict(zip(class_df["Original Name"], class_df["Rename To"]))
237
- updated_counts = gather_class_counts(dataset_info, class_mapping)
238
-
239
- merged_summary = {}
240
  for _, row in class_df.iterrows():
241
- if not row["Remove"]:
242
- rename_to = row["Rename To"]
243
- # This logic needs to be careful: sum counts of all original classes that map to the same `rename_to`
244
- # Let's recalculate based on mapping
245
- merged_summary[rename_to] = 0 # reset
246
-
247
- for original_name, rename_to in class_mapping.items():
248
- if rename_to in merged_summary:
249
- # find count for original name in its original mapped state
250
- original_count = gather_class_counts(dataset_info, {k:k for k in class_mapping.keys()}).get(original_name,0)
251
- is_removed = class_df.loc[class_df['Original Name'] == original_name, 'Remove'].iloc[0]
252
- if not is_removed:
253
- merged_summary[rename_to] += original_count
254
-
255
- final_summary = {}
256
- # Recalculate from scratch for simplicity and accuracy
257
- class_map_for_summary = dict(zip(class_df["Original Name"], class_df["Rename To"]))
258
- all_final_names = set(class_df[~class_df['Remove']]['Rename To'])
259
-
260
- final_counts = {name: 0 for name in all_final_names}
261
 
262
  for loc, names, splits, _ in dataset_info:
 
 
 
 
263
  for split in splits:
264
  labels_dir = os.path.join(loc, split, 'labels')
265
- if not os.path.exists(labels_dir): continue
 
266
  for label_file in os.listdir(labels_dir):
267
- found_in_file = set()
 
 
268
  with open(os.path.join(labels_dir, label_file), 'r') as f:
269
  for line in f:
 
 
 
270
  try:
271
- class_id = int(line.split()[0])
272
- original_name = names[class_id]
273
- is_removed = class_df.loc[class_df['Original Name'] == original_name, 'Remove'].iloc[0]
274
- if not is_removed:
275
- mapped_name = class_map_for_summary.get(original_name)
276
- if mapped_name:
277
- found_in_file.add(mapped_name)
278
- except (ValueError, IndexError, KeyError): continue
279
- for cls in found_in_file:
280
- final_counts[cls] += 1
281
-
282
- summary_df = pd.DataFrame(list(final_counts.items()), columns=["Final Class Name", "Est. Total Images"])
283
  return summary_df
284
 
 
285
  def finalize_handler(dataset_info, class_df, progress=gr.Progress()):
286
  """Handles the 'Finalize' button click."""
287
- if not dataset_info: raise gr.Error("Load datasets first in Tab 1.")
288
- if class_df is None: raise gr.Error("Class data is missing.")
289
-
290
- class_mapping = dict(zip(class_df["Original Name"], class_df["Rename To"]))
 
 
 
 
291
  class_limits = {}
292
  for _, row in class_df.iterrows():
293
- if not row["Remove"]:
294
- rename_to = row["Rename To"]
295
- class_limits[rename_to] = class_limits.get(rename_to, 0) + int(row["Max Images"])
296
-
297
  status, path = finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress)
298
  return status, path
299
 
 
300
  def training_handler(dataset_path, model_filename, run_name, epochs, batch, imgsz, lr, opt, progress=gr.Progress()):
301
- """Handles the training process with real-time feedback."""
302
- if not dataset_path: raise gr.Error("Finalize a dataset in Tab 2 before training.")
 
 
 
 
303
 
304
  metrics_queue = Queue()
 
305
  def on_epoch_end(trainer):
 
 
306
  metrics_queue.put({
307
- 'epoch': trainer.epoch + 1, 'train_loss': trainer.metrics.get('train/loss'),
308
- 'val_loss': trainer.metrics.get('val/loss'), 'mAP50': trainer.metrics.get('metrics/mAP50(B)'),
309
- 'mAP50_95': trainer.metrics.get('metrics/mAP50-95(B)')
 
 
310
  })
311
 
312
  def train_thread_func():
@@ -315,22 +482,30 @@ def training_handler(dataset_path, model_filename, run_name, epochs, batch, imgs
315
  weights_path = os.path.join('pretrained_models', model_filename)
316
  if not os.path.exists(weights_path):
317
  os.makedirs('pretrained_models', exist_ok=True)
318
- r = requests.get(model_url, stream=True)
319
  r.raise_for_status()
320
  with open(weights_path, 'wb') as f:
321
  for chunk in r.iter_content(chunk_size=8192):
322
  f.write(chunk)
323
-
324
  model = YOLO(weights_path)
325
  model.add_callback("on_train_epoch_end", on_epoch_end)
 
326
  model.train(
327
- data=os.path.join(dataset_path, 'data.yaml'), epochs=epochs, batch=batch, imgsz=imgsz,
328
- lr0=lr, optimizer=opt, project='runs/train', name=run_name, exist_ok=True,
329
- device=0 if torch.cuda.is_available() else 'cpu'
330
  )
331
  metrics_queue.put("done")
332
  except Exception as e:
333
- logging.error(f"Training thread error: {e}")
334
  metrics_queue.put(f"error: {e}")
335
 
336
  Thread(target=train_thread_func, daemon=True).start()
@@ -339,38 +514,52 @@ def training_handler(dataset_path, model_filename, run_name, epochs, batch, imgs
339
  while True:
340
  item = metrics_queue.get()
341
  if isinstance(item, str):
342
- if item == "done": break
343
- if item.startswith("error"): raise gr.Error(f"Training failed: {item}")
344
-
345
- for key, val in item.items():
346
- if val is not None: history[key].append(val)
347
-
348
- current_epoch = history['epoch'][-1]
349
- progress(current_epoch / epochs, desc=f"Epoch {current_epoch}/{epochs}")
350
-
351
- # Gradio Plotting does not require clearing figures
352
- fig_loss = plt.figure(); ax_loss = fig_loss.add_subplot(111)
353
  ax_loss.plot(history['epoch'], history['train_loss'], "o-", label='Train Loss')
354
  ax_loss.plot(history['epoch'], history['val_loss'], "o-", label='Val Loss')
355
- ax_loss.legend(); ax_loss.set_title("Loss")
356
-
357
- fig_map = plt.figure(); ax_map = fig_map.add_subplot(111)
 
 
 
358
  ax_map.plot(history['epoch'], history['mAP50'], "o-", label='mAP@0.5')
359
  ax_map.plot(history['epoch'], history['mAP50_95'], "o-", label='mAP@0.5:0.95')
360
- ax_map.legend(); ax_map.set_title("mAP")
361
-
362
- yield f"Epoch {current_epoch}/{epochs} complete.", fig_loss, fig_map, None
 
363
 
364
- final_path = os.path.join('runs/train', run_name, 'weights/best.pt')
365
  if not os.path.exists(final_path):
366
  raise gr.Error("Training finished, but 'best.pt' was not found.")
367
-
368
  yield "Training complete!", None, None, gr.File.update(value=final_path, visible=True)
369
 
 
370
  def upload_handler(model_file, hf_token, hf_repo, gh_token, gh_repo, progress=gr.Progress()):
371
- """Handles the model upload to Hugging Face and GitHub."""
372
- if not model_file: raise gr.Error("No trained model file available to upload. Train a model first.")
373
-
 
374
  hf_status = "Skipped Hugging Face (credentials not provided)."
375
  if hf_token and hf_repo:
376
  progress(0, desc="Uploading to Hugging Face...")
@@ -379,54 +568,71 @@ def upload_handler(model_file, hf_token, hf_repo, gh_token, gh_repo, progress=gr
379
  HfFolder.save_token(hf_token)
380
  repo_url = api.create_repo(repo_id=hf_repo, exist_ok=True, token=hf_token)
381
  api.upload_file(
382
- path_or_fileobj=model_file.name, path_in_repo=os.path.basename(model_file.name),
383
- repo_id=hf_repo, token=hf_token
 
 
384
  )
385
  hf_status = f"Success! Model at: {repo_url}"
386
- except Exception as e: hf_status = f"Hugging Face Error: {e}"
 
387
 
388
  gh_status = "Skipped GitHub (credentials not provided)."
389
  if gh_token and gh_repo:
390
  progress(0.5, desc="Uploading to GitHub...")
391
  try:
 
 
 
392
  username, repo_name = gh_repo.split('/')
393
  api_url = f"https://api.github.com/repos/{username}/{repo_name}/contents/{os.path.basename(model_file.name)}"
394
  headers = {"Authorization": f"token {gh_token}"}
395
-
396
- with open(model_file.name, "rb") as f: content = base64.b64encode(f.read()).decode()
397
-
398
- get_resp = requests.get(api_url, headers=headers)
 
399
  sha = get_resp.json().get('sha') if get_resp.ok else None
400
-
401
- data = {"message": "Upload trained model from Rolo app", "content": content, "sha": sha}
402
- put_resp = requests.put(api_url, headers=headers, json={k: v for k, v in data.items() if v is not None})
403
-
404
- if put_resp.ok: gh_status = f"Success! Model at: {put_resp.json()['content']['html_url']}"
405
- else: gh_status = f"GitHub Error: {put_resp.json().get('message', 'Unknown')}"
406
- except Exception as e: gh_status = f"GitHub Error: {e}"
407
-
 
  progress(1)
409
  return hf_status, gh_status
410
 
411
- # --- Gradio UI Layout ---
 
 
 
412
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
413
  gr.Markdown("# Rolo: A Dedicated RT-DETR Training Dashboard")
414
-
415
  # State variables
416
  dataset_info_state = gr.State([])
417
  final_dataset_path_state = gr.State(None)
418
 
419
  with gr.Tabs():
420
  with gr.TabItem("1. Prepare Datasets"):
421
- gr.Markdown("### Load Roboflow Datasets\nProvide your Roboflow API key and upload a `.txt` file containing one Roboflow dataset URL per line.")
422
  with gr.Row():
423
- rf_api_key = gr.Textbox(label="Roboflow API Key", type="password", scale=2)
424
  rf_url_file = gr.File(label="Upload Roboflow URLs (.txt)", file_types=[".txt"], scale=1)
425
  load_btn = gr.Button("Load Datasets", variant="primary")
426
  dataset_status = gr.Textbox(label="Status", interactive=False)
427
-
428
  with gr.TabItem("2. Manage & Merge"):
429
- gr.Markdown("### Configure Classes and Finalize Dataset\nRename classes to merge them, set image limits, or remove them. Click **Update Counts** to see a preview of your changes, then click **Finalize** to create the dataset.")
430
  with gr.Row():
431
  class_df = gr.DataFrame(
432
  headers=["Original Name", "Rename To", "Max Images", "Remove"],
@@ -434,9 +640,13 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
434
  label="Class Configuration", interactive=True, scale=3
435
  )
436
  with gr.Column(scale=1):
437
- class_count_summary_df = gr.DataFrame(label="Merged Class Counts Preview", headers=["Final Class Name", "Est. Total Images"], interactive=False)
 
 
 
 
438
  update_counts_btn = gr.Button("Update Counts")
439
-
440
  finalize_btn = gr.Button("Finalize Merged Dataset", variant="primary")
441
  finalize_status = gr.Textbox(label="Status", interactive=False)
442
 
@@ -444,13 +654,15 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
444
  gr.Markdown("### Set Hyperparameters and Train the RT-DETR Model")
445
  with gr.Row():
446
  with gr.Column(scale=1):
447
- model_file_dd = gr.Dropdown(label="Select Pre-Trained RT-DETR Model",
448
- choices=[m["filename"] for m in RTDETR_MODELS["detection"]], value=DEFAULT_MODEL)
 
 
 
449
  run_name_tb = gr.Textbox(label="Run Name", value="rtdetr_run_1")
450
  epochs_sl = gr.Slider(1, 500, 100, step=1, label="Epochs")
451
  batch_sl = gr.Slider(1, 32, 8, step=1, label="Batch Size")
452
  imgsz_num = gr.Number(label="Image Size", value=640)
453
- # <<< FIXED: Removed the 'format' argument which is not supported.
454
  lr_num = gr.Number(label="Learning Rate", value=0.001)
455
  opt_dd = gr.Dropdown(["Adam", "AdamW", "SGD"], value="Adam", label="Optimizer")
456
  train_btn = gr.Button("Start Training", variant="primary")
@@ -476,14 +688,34 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
476
  hf_status = gr.Textbox(label="Hugging Face Status", interactive=False)
477
  gh_status = gr.Textbox(label="GitHub Status", interactive=False)
478
 
479
- # --- Wire UI handlers ---
480
- load_btn.click(fn=load_datasets_handler, inputs=[rf_api_key, rf_url_file], outputs=[dataset_status, dataset_info_state, class_df])
481
- update_counts_btn.click(fn=update_class_counts_handler, inputs=[class_df, dataset_info_state], outputs=[class_count_summary_df])
482
- finalize_btn.click(fn=finalize_handler, inputs=[dataset_info_state, class_df], outputs=[finalize_status, final_dataset_path_state])
483
- train_btn.click(fn=training_handler,
484
- inputs=[final_dataset_path_state, model_file_dd, run_name_tb, epochs_sl, batch_sl, imgsz_num, lr_num, opt_dd],
485
- outputs=[train_status, loss_plot, map_plot, final_model_file])
486
- upload_btn.click(fn=upload_handler, inputs=[final_model_file, hf_token, hf_repo, gh_token, gh_repo], outputs=[hf_status, gh_status])
487
 
488
  if __name__ == "__main__":
489
- app.launch(debug=True)
 
 
 
3
  import stat
4
  import yaml
5
  import gradio as gr
6
+ from ultralytics import YOLO # Ultralytics RT-DETR runner
7
  from roboflow import Roboflow
8
  import re
9
  from urllib.parse import urlparse
 
40
  }
41
  DEFAULT_MODEL = "rtdetr-l.pt"
42
 
43
+
44
+ # ------------------------------
45
+ # Utilities
46
+ # ------------------------------
47
 
48
  def handle_remove_readonly(func, path, exc_info):
49
  """Error handler for shutil.rmtree."""
50
+ try:
51
+ os.chmod(path, stat.S_IWRITE)
52
+ except Exception:
53
+ pass
54
  func(path)
55
 
56
+
57
+ _ROBO_URL_RX = re.compile(
58
+ r"""
59
+ ^(?:
60
+ (?:https?://)?(?:universe|app|www)?\.?roboflow\.com/ # Any roboflow host
61
+ (?P<ws>[A-Za-z0-9\-_]+)/ # workspace
62
+ (?P<proj>[A-Za-z0-9\-_]+)/? # project
63
+ (?:
64
+ (?:dataset/[^/]+/)? # optional 'dataset/<fmt>/'
65
+ (?:v?(?P<ver>\d+))? # optional version 'vN' or 'N'
66
+ )?
67
+ |
68
+ (?P<ws2>[A-Za-z0-9\-_]+)/(?P<proj2>[A-Za-z0-9\-_]+)(?:/(?:v)?(?P<ver2>\d+))? # raw ws/proj[/vN]
69
+ )$
70
+ """,
71
+ re.VERBOSE | re.IGNORECASE
72
+ )
73
+
74
+ def parse_roboflow_url(s: str):
75
+ """
76
+ Accepts:
77
+ - https://universe.roboflow.com/<workspace>/<project>[/vN | /N]
78
+ - https://app.roboflow.com/<workspace>/<project>[/vN | /N]
79
+ - https://roboflow.com/<workspace>/<project>[/vN | /N]
80
+ - raw: <workspace>/<project>[/vN | /N]
81
+ Returns: (workspace, project, version_or_None)
82
+ """
83
+ s = s.strip()
84
+ # Fast path: try regex
85
+ m = _ROBO_URL_RX.match(s)
86
+ if m:
87
+ ws = m.group('ws') or m.group('ws2')
88
+ proj = m.group('proj') or m.group('proj2')
89
+ ver = m.group('ver') or m.group('ver2')
90
+ return ws, proj, (int(ver) if ver else None)
91
+
92
+ # Fallback: parse like URL and split path
93
+ parsed = urlparse(s)
94
+ parts = [p for p in parsed.path.strip('/').split('/') if p]
95
+ if len(parts) >= 2:
96
+ # Try to pull raw version from the 3rd part if it exists
97
+ version = None
98
+ if len(parts) >= 3:
99
+ # Accept 'vN' or 'N'
100
+ vpart = parts[2]
101
+ if vpart.lower().startswith('v') and vpart[1:].isdigit():
102
+ version = int(vpart[1:])
103
+ elif vpart.isdigit():
104
+ version = int(vpart)
105
+ return parts[0], parts[1], version
106
+
107
+ # Fallback raw "ws/proj" without slashes in URL
108
+ if '/' in s and 'roboflow' not in s:
109
+ p = s.split('/')
110
+ if len(p) >= 2:
111
+ # Accept trailing version if present
112
+ version = None
113
+ if len(p) >= 3:
114
+ v = p[2]
115
+ if v.lower().startswith('v') and v[1:].isdigit():
116
+ version = int(v[1:])
117
+ elif v.isdigit():
118
+ version = int(v)
119
+ return p[0], p[1], version
120
+
121
  return None, None, None
122
 
123
+
124
  def get_latest_version(api_key, workspace, project):
125
  """Gets the latest version number of a Roboflow project."""
126
  try:
 
132
  logging.error(f"Could not get latest version for {workspace}/{project}: {e}")
133
  return None
134
 
135
+
136
  def download_dataset(api_key, workspace, project, version):
137
+ """Downloads a single dataset from Roboflow (yolov8 format works fine for RT-DETR)."""
138
  try:
139
  rf = Roboflow(api_key=api_key)
140
  proj = rf.workspace(workspace).project(project)
141
+ ver = proj.version(int(version))
142
+ dataset = ver.download("yolov8")
143
+
144
+ data_yaml_path = os.path.join(dataset.location, 'data.yaml')
145
+ with open(data_yaml_path, 'r') as f:
146
  data_yaml = yaml.safe_load(f)
147
+
148
  class_names = data_yaml.get('names', [])
149
+ splits = [s for s in ['train', 'valid', 'test']
150
+ if os.path.exists(os.path.join(dataset.location, s))]
151
+
152
  return dataset.location, class_names, splits, f"{project}-v{version}"
153
  except Exception as e:
154
  logging.error(f"Failed to download {workspace}/{project}/v{version}: {e}")
155
  return None, [], [], None
156
 
157
+
158
+ def label_path_for(img_path: str) -> str:
159
+ """Convert .../split/images/file.jpg -> .../split/labels/file.txt in a safe way."""
160
+ split_dir = os.path.dirname(os.path.dirname(img_path)) # .../split
161
+ base = os.path.splitext(os.path.basename(img_path))[0] + '.txt'
162
+ return os.path.join(split_dir, 'labels', base)
163
+
164
+
165
  def gather_class_counts(dataset_info, class_mapping):
166
+ """
167
+ Count, per final class, how many images contain at least one instance of that class
168
+ (counted once per image). class_mapping maps original_name -> final_name.
169
+ """
170
+ if not dataset_info:
171
+ return {}
172
+
173
+ final_names = set(class_mapping.values())
174
+ counts = {name: 0 for name in final_names}
175
 
176
  for loc, names, splits, _ in dataset_info:
177
+ # Map from original idx -> mapped name (or None if removed later)
178
+ id_to_name = {}
179
+ for idx, n in enumerate(names):
180
+ id_to_name[idx] = class_mapping.get(n, None)
181
+
182
  for split in splits:
183
  labels_dir = os.path.join(loc, split, 'labels')
184
+ if not os.path.exists(labels_dir):
185
+ continue
186
  for label_file in os.listdir(labels_dir):
187
+ if not label_file.endswith('.txt'):
188
+ continue
189
+ found = set()
190
  with open(os.path.join(labels_dir, label_file), 'r') as f:
191
  for line in f:
192
+ parts = line.strip().split()
193
+ if not parts:
194
+ continue
195
  try:
196
+ cls_id = int(parts[0])
197
+ mapped = id_to_name.get(cls_id, None)
198
+ if mapped in final_names:
199
+ found.add(mapped)
200
+ except Exception:
 
201
  continue
202
+ for m in found:
203
+ counts[m] += 1
204
+
205
  return counts
206
 
207
+
208
  def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=gr.Progress()):
209
+ """Core function to merge datasets based on user rules."""
210
  merged_dir = 'rolo_merged_dataset'
211
  if os.path.exists(merged_dir):
212
  shutil.rmtree(merged_dir, onerror=handle_remove_readonly)
213
+
214
  progress(0, desc="Creating directories...")
215
  for split in ['train', 'valid', 'test']:
216
  os.makedirs(os.path.join(merged_dir, split, 'images'), exist_ok=True)
217
  os.makedirs(os.path.join(merged_dir, split, 'labels'), exist_ok=True)
218
 
219
+ # Only classes with positive limits are active
220
+ active_classes = [cls for cls, limit in class_limits.items() if limit > 0]
221
+ active_classes = sorted(set(active_classes))
222
  final_class_map = {name: i for i, name in enumerate(active_classes)}
223
 
224
+ # Collect all candidate images
225
  all_images = []
226
  for loc, _, splits, _ in dataset_info:
227
  for split in splits:
228
  img_dir = os.path.join(loc, split, 'images')
229
+ if not os.path.exists(img_dir):
230
+ continue
231
  for img_file in os.listdir(img_dir):
232
  if img_file.lower().endswith(('.jpg', '.jpeg', '.png')):
233
  all_images.append((os.path.join(img_dir, img_file), split, loc))
234
  random.shuffle(all_images)
235
+
236
  progress(0.2, desc="Selecting images based on limits...")
237
+ selected_images = []
238
  current_counts = {cls: 0 for cls in active_classes}
239
 
240
+ # Build a quick lookup: source_loc -> names list
241
+ loc_to_names = {info[0]: info[1] for info in dataset_info}
242
+
243
+ for img_path, split, source_loc in progress.tqdm(all_images, desc="Analyzing images"):
244
+ lbl_path = label_path_for(img_path)
245
+ if not os.path.exists(lbl_path):
246
+ continue
247
+
248
+ source_names = loc_to_names.get(source_loc, [])
249
  image_classes = set()
250
  with open(lbl_path, 'r') as f:
251
  for line in f:
252
+ parts = line.strip().split()
253
+ if not parts:
254
+ continue
255
  try:
256
+ cls_id = int(parts[0])
257
+ orig = source_names[cls_id]
258
+ mapped = class_mapping.get(orig, orig)
259
+ if mapped in active_classes:
260
+ image_classes.add(mapped)
261
+ except Exception:
262
+ continue
263
+
264
+ if not image_classes:
265
+ continue
266
+
267
+ # Check limits
268
+ if any(current_counts[c] >= class_limits[c] for c in image_classes):
269
+ continue
270
+
271
+ selected_images.append((img_path, split))
272
+ for c in image_classes:
273
+ current_counts[c] += 1
274
+
275
  progress(0.6, desc=f"Copying {len(selected_images)} files...")
276
  for img_path, split in progress.tqdm(selected_images, desc="Finalizing files"):
277
+ lbl_path = label_path_for(img_path)
278
+ out_img = os.path.join(merged_dir, split, 'images', os.path.basename(img_path))
279
+ out_lbl = os.path.join(merged_dir, split, 'labels', os.path.basename(lbl_path))
280
+ shutil.copy(img_path, out_img)
281
+
282
+ # Determine source names by matching the parent dataset root
283
+ source_loc = None
284
+ for info in dataset_info:
285
+ if img_path.startswith(info[0]):
286
+ source_loc = info[0]
287
+ break
288
+ source_names = loc_to_names.get(source_loc, [])
289
+
290
+ with open(lbl_path, 'r') as f_in, open(out_lbl, 'w') as f_out:
291
  for line in f_in:
292
+ parts = line.strip().split()
293
+ if not parts:
294
+ continue
295
  try:
296
+ old_id = int(parts[0])
297
+ original_name = source_names[old_id]
 
298
  mapped_name = class_mapping.get(original_name, original_name)
299
  if mapped_name in final_class_map:
300
  new_id = final_class_map[mapped_name]
301
  f_out.write(f"{new_id} {' '.join(parts[1:])}\n")
302
+ except Exception:
303
+ continue
304
 
305
  progress(0.95, desc="Creating data.yaml...")
306
  with open(os.path.join(merged_dir, 'data.yaml'), 'w') as f:
307
  yaml.dump({
308
+ 'path': os.path.abspath(merged_dir),
309
+ 'train': 'train/images',
310
+ 'val': 'valid/images',
311
+ 'test': 'test/images',
312
+ 'nc': len(active_classes),
313
+ 'names': active_classes
314
  }, f)
315
+
316
  return f"Dataset finalized with {len(selected_images)} images.", os.path.abspath(merged_dir)
317
 
318
 
319
+ # ------------------------------
320
+ # Gradio UI Event Handlers
321
+ # ------------------------------
322
 
323
  def load_datasets_handler(api_key, url_file, progress=gr.Progress()):
324
  """Handles the 'Load Datasets' button click."""
325
+ api_key = api_key or os.getenv("ROBOFLOW_API_KEY", "")
326
+ if not api_key:
327
+ raise gr.Error("Roboflow API Key is required (or set ROBOFLOW_API_KEY).")
328
+ if not url_file:
329
+ raise gr.Error("Please upload a .txt file with Roboflow URLs or lines like 'workspace/project[/vN]'.")
330
 
331
+ with open(url_file.name, 'r', encoding='utf-8', errors='ignore') as f:
332
  urls = [line.strip() for line in f if line.strip()]
333
+
334
  dataset_info = []
335
+ failures = []
336
+
337
+ for i, raw in enumerate(urls):
338
+ progress((i + 1) / max(1, len(urls)), desc=f"Parsing {i+1}/{len(urls)}")
339
+ ws, proj, ver = parse_roboflow_url(raw)
340
+ if not (ws and proj):
341
+ failures.append((raw, "ParseError: could not resolve workspace/project"))
342
  continue
343
+
344
+ if ver is None:
345
+ ver = get_latest_version(api_key, ws, proj)
346
+ if ver is None:
347
+ failures.append((raw, f"Could not resolve latest version for {ws}/{proj}"))
348
  continue
349
+
350
+ loc, names, splits, name_str = download_dataset(api_key, ws, proj, int(ver))
351
  if loc:
352
  dataset_info.append((loc, names, splits, name_str))
353
+ else:
354
+ failures.append((raw, f"DownloadError: {ws}/{proj}/v{ver}"))
355
+
356
+ if not dataset_info:
357
+ # Show a compact failure report to the UI
358
+ msg = "No datasets were loaded successfully.\n" + "\n".join([f"- {u}: {why}" for u, why in failures[:10]])
359
+ raise gr.Error(msg)
360
 
361
  all_names = sorted(list(set(n for _, names, _, _ in dataset_info for n in names)))
362
  class_map = {name: name for name in all_names}
363
+
364
+ # Initial preview uses "keep all" mapping
365
  initial_counts = gather_class_counts(dataset_info, class_map)
366
  df_data = [[name, name, initial_counts.get(name, 0), False] for name in all_names]
367
+ status_text = "Datasets loaded successfully."
368
+ if failures:
369
+ status_text += f" ({len(dataset_info)} OK, {len(failures)} failed; see console logs)."
370
+
371
+ return status_text, dataset_info, gr.DataFrame.update(
372
+ value=pd.DataFrame(df_data, columns=["Original Name", "Rename To", "Max Images", "Remove"])
373
+ )
374
+
375
 
376
  def update_class_counts_handler(class_df, dataset_info):
377
+ """
378
+ Provides live feedback on class counts as the user edits the DataFrame.
379
+ We compute a mapping of original -> final (or None if removed), then count images
380
+ for each final name.
381
+ """
382
+ if class_df is None or not dataset_info:
383
+ return None
384
+
385
+ # Build mapping original_name -> final_name or None if removed
386
+ class_df = pd.DataFrame(class_df)
387
+ mapping = {}
388
  for _, row in class_df.iterrows():
389
+ orig = row["Original Name"]
390
+ if bool(row["Remove"]):
391
+ mapping[orig] = None
392
+ else:
393
+ mapping[orig] = row["Rename To"]
394
+
395
+ # Build final set
396
+ final_names = sorted(set(v for v in mapping.values() if v))
397
+ counts = {k: 0 for k in final_names}
398
 
399
  for loc, names, splits, _ in dataset_info:
400
+ id_to_final = {}
401
+ for idx, n in enumerate(names):
402
+ id_to_final[idx] = mapping.get(n, None)
403
+
404
  for split in splits:
405
  labels_dir = os.path.join(loc, split, 'labels')
406
+ if not os.path.exists(labels_dir):
407
+ continue
408
  for label_file in os.listdir(labels_dir):
409
+ if not label_file.endswith('.txt'):
410
+ continue
411
+ found = set()
412
  with open(os.path.join(labels_dir, label_file), 'r') as f:
413
  for line in f:
414
+ parts = line.strip().split()
415
+ if not parts:
416
+ continue
417
  try:
418
+ cls_id = int(parts[0])
419
+ mapped = id_to_final.get(cls_id, None)
420
+ if mapped:
421
+ found.add(mapped)
422
+ except Exception:
423
+ continue
424
+ for m in found:
425
+ counts[m] += 1
426
+
427
+ summary_df = pd.DataFrame(list(counts.items()), columns=["Final Class Name", "Est. Total Images"])
 
 
428
  return summary_df
429
 
430
+
431
  def finalize_handler(dataset_info, class_df, progress=gr.Progress()):
432
  """Handles the 'Finalize' button click."""
433
+ if not dataset_info:
434
+ raise gr.Error("Load datasets first in Tab 1.")
435
+ if class_df is None:
436
+ raise gr.Error("Class data is missing.")
437
+
438
+ # Mapping and limits
439
+ class_df = pd.DataFrame(class_df)
440
+ class_mapping = {}
441
  class_limits = {}
442
  for _, row in class_df.iterrows():
443
+ orig = row["Original Name"]
444
+ if bool(row["Remove"]):
445
+ continue
446
+ final_name = row["Rename To"]
447
+ class_mapping[orig] = final_name
448
+ # Sum limits for final_name over any merged originals
449
+ class_limits[final_name] = class_limits.get(final_name, 0) + int(row["Max Images"])
450
+
451
+ # Any original not present in mapping will map to itself (keep behavior)
452
+ # BUT we do not want to include classes with 0 limit in the final dataset
453
+ # finalize_merged_dataset uses the limits dict to decide active classes.
454
  status, path = finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress)
455
  return status, path
456
 
457
+
458
  def training_handler(dataset_path, model_filename, run_name, epochs, batch, imgsz, lr, opt, progress=gr.Progress()):
459
+ """Handles the training process with live feedback."""
460
+ if not dataset_path:
461
+ raise gr.Error("Finalize a dataset in Tab 2 before training.")
462
+
463
+ # Ultralytics expects device string, e.g. '0' or 'cpu'
464
+ device_str = "0" if torch.cuda.is_available() else "cpu"
465
 
466
  metrics_queue = Queue()
467
+
468
  def on_epoch_end(trainer):
469
+ # Be defensive about metric keys
470
+ m = trainer.metrics or {}
471
  metrics_queue.put({
472
+ 'epoch': (trainer.epoch or 0) + 1,
473
+ 'train_loss': m.get('train/loss') or m.get('loss'),
474
+ 'val_loss': m.get('val/loss'),
475
+ 'mAP50': m.get('metrics/mAP50(B)') or m.get('metrics/mAP50'),
476
+ 'mAP50_95': m.get('metrics/mAP50-95(B)') or m.get('metrics/mAP50-95')
477
  })
478
 
479
  def train_thread_func():
 
482
  weights_path = os.path.join('pretrained_models', model_filename)
483
  if not os.path.exists(weights_path):
484
  os.makedirs('pretrained_models', exist_ok=True)
485
+ r = requests.get(model_url, stream=True, timeout=60)
486
  r.raise_for_status()
487
  with open(weights_path, 'wb') as f:
488
  for chunk in r.iter_content(chunk_size=8192):
489
  f.write(chunk)
490
+
491
  model = YOLO(weights_path)
492
  model.add_callback("on_train_epoch_end", on_epoch_end)
493
+
494
  model.train(
495
+ data=os.path.join(dataset_path, 'data.yaml'),
496
+ epochs=int(epochs),
497
+ batch=int(batch),
498
+ imgsz=int(imgsz),
499
+ lr0=float(lr),
500
+ optimizer=str(opt),
501
+ project='runs/train',
502
+ name=str(run_name),
503
+ exist_ok=True,
504
+ device=device_str
505
  )
506
  metrics_queue.put("done")
507
  except Exception as e:
508
+ logging.exception("Training thread error")
509
  metrics_queue.put(f"error: {e}")
510
 
511
  Thread(target=train_thread_func, daemon=True).start()
 
514
  while True:
515
  item = metrics_queue.get()
516
  if isinstance(item, str):
517
+ if item == "done":
518
+ break
519
+ if item.startswith("error"):
520
+ raise gr.Error(f"Training failed: {item}")
521
+
522
+ # Append metrics
523
+ for key in ['epoch', 'train_loss', 'val_loss', 'mAP50', 'mAP50_95']:
524
+ val = item.get(key, None)
525
+ if val is not None:
526
+ history[key].append(val)
527
+
528
+ current_epoch = history['epoch'][-1] if history['epoch'] else 0
529
+ total_epochs = int(epochs)
530
+ frac = min(max(current_epoch / max(1, total_epochs), 0.0), 1.0)
531
+ progress(frac, desc=f"Epoch {current_epoch}/{total_epochs}")
532
+
533
+ # Plot Loss
534
+ fig_loss = plt.figure()
535
+ ax_loss = fig_loss.add_subplot(111)
536
  ax_loss.plot(history['epoch'], history['train_loss'], "o-", label='Train Loss')
537
  ax_loss.plot(history['epoch'], history['val_loss'], "o-", label='Val Loss')
538
+ ax_loss.legend()
539
+ ax_loss.set_title("Loss")
540
+
541
+ # Plot mAP
542
+ fig_map = plt.figure()
543
+ ax_map = fig_map.add_subplot(111)
544
  ax_map.plot(history['epoch'], history['mAP50'], "o-", label='mAP@0.5')
545
  ax_map.plot(history['epoch'], history['mAP50_95'], "o-", label='mAP@0.5:0.95')
546
+ ax_map.legend()
547
+ ax_map.set_title("mAP")
548
+
549
+ yield f"Epoch {current_epoch}/{total_epochs} complete.", fig_loss, fig_map, None
550
 
551
+ final_path = os.path.join('runs', 'train', str(run_name), 'weights', 'best.pt')
552
  if not os.path.exists(final_path):
553
  raise gr.Error("Training finished, but 'best.pt' was not found.")
554
+
555
  yield "Training complete!", None, None, gr.File.update(value=final_path, visible=True)
556
 
557
+
558
  def upload_handler(model_file, hf_token, hf_repo, gh_token, gh_repo, progress=gr.Progress()):
559
+ """Handles model upload to Hugging Face and GitHub."""
560
+ if not model_file:
561
+ raise gr.Error("No trained model file available to upload. Train a model first.")
562
+
563
  hf_status = "Skipped Hugging Face (credentials not provided)."
564
  if hf_token and hf_repo:
565
  progress(0, desc="Uploading to Hugging Face...")
 
568
  HfFolder.save_token(hf_token)
569
  repo_url = api.create_repo(repo_id=hf_repo, exist_ok=True, token=hf_token)
570
  api.upload_file(
571
+ path_or_fileobj=model_file.name,
572
+ path_in_repo=os.path.basename(model_file.name),
573
+ repo_id=hf_repo,
574
+ token=hf_token
575
  )
576
  hf_status = f"Success! Model at: {repo_url}"
577
+ except Exception as e:
578
+ hf_status = f"Hugging Face Error: {e}"
579
 
580
  gh_status = "Skipped GitHub (credentials not provided)."
581
  if gh_token and gh_repo:
582
  progress(0.5, desc="Uploading to GitHub...")
583
  try:
584
+ if '/' not in gh_repo:
585
+ raise ValueError("GitHub repo must be in the form 'username/repo'.")
586
+
587
  username, repo_name = gh_repo.split('/')
588
  api_url = f"https://api.github.com/repos/{username}/{repo_name}/contents/{os.path.basename(model_file.name)}"
589
  headers = {"Authorization": f"token {gh_token}"}
590
+
591
+ with open(model_file.name, "rb") as f:
592
+ content = base64.b64encode(f.read()).decode()
593
+
594
+ get_resp = requests.get(api_url, headers=headers, timeout=30)
595
  sha = get_resp.json().get('sha') if get_resp.ok else None
596
+
597
+ data = {"message": "Upload trained model from Rolo app", "content": content}
598
+ if sha:
599
+ data["sha"] = sha
600
+
601
+ put_resp = requests.put(api_url, headers=headers, json=data, timeout=60)
602
+
603
+ if put_resp.ok:
604
+ gh_status = f"Success! Model at: {put_resp.json()['content']['html_url']}"
605
+ else:
606
+ msg = put_resp.json().get('message', 'Unknown')
607
+ gh_status = f"GitHub Error: {msg}"
608
+ except Exception as e:
609
+ gh_status = f"GitHub Error: {e}"
610
+
611
  progress(1)
612
  return hf_status, gh_status
613
 
614
+
615
+ # ------------------------------
616
+ # Gradio UI
617
+ # ------------------------------
618
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
619
  gr.Markdown("# Rolo: A Dedicated RT-DETR Training Dashboard")
620
+
621
  # State variables
622
  dataset_info_state = gr.State([])
623
  final_dataset_path_state = gr.State(None)
624
 
625
  with gr.Tabs():
626
  with gr.TabItem("1. Prepare Datasets"):
627
+ gr.Markdown("### Load Roboflow Datasets\nProvide your Roboflow API key and upload a `.txt` file containing one Roboflow dataset URL or `workspace/project[/vN]` per line.")
628
  with gr.Row():
629
+ rf_api_key = gr.Textbox(label="Roboflow API Key (or set ROBOFLOW_API_KEY env)", type="password", scale=2)
630
  rf_url_file = gr.File(label="Upload Roboflow URLs (.txt)", file_types=[".txt"], scale=1)
631
  load_btn = gr.Button("Load Datasets", variant="primary")
632
  dataset_status = gr.Textbox(label="Status", interactive=False)
633
+
634
  with gr.TabItem("2. Manage & Merge"):
635
+ gr.Markdown("### Configure Classes and Finalize Dataset\nRename classes to merge them, set image limits, or remove them. Click **Update Counts** to preview, then **Finalize** to create the dataset.")
636
  with gr.Row():
637
  class_df = gr.DataFrame(
638
  headers=["Original Name", "Rename To", "Max Images", "Remove"],
 
640
  label="Class Configuration", interactive=True, scale=3
641
  )
642
  with gr.Column(scale=1):
643
+ class_count_summary_df = gr.DataFrame(
644
+ label="Merged Class Counts Preview",
645
+ headers=["Final Class Name", "Est. Total Images"],
646
+ interactive=False
647
+ )
648
  update_counts_btn = gr.Button("Update Counts")
649
+
650
  finalize_btn = gr.Button("Finalize Merged Dataset", variant="primary")
651
  finalize_status = gr.Textbox(label="Status", interactive=False)
652
 
 
654
  gr.Markdown("### Set Hyperparameters and Train the RT-DETR Model")
655
  with gr.Row():
656
  with gr.Column(scale=1):
657
+ model_file_dd = gr.Dropdown(
658
+ label="Select Pre-Trained RT-DETR Model",
659
+ choices=[m["filename"] for m in RTDETR_MODELS["detection"]],
660
+ value=DEFAULT_MODEL
661
+ )
662
  run_name_tb = gr.Textbox(label="Run Name", value="rtdetr_run_1")
663
  epochs_sl = gr.Slider(1, 500, 100, step=1, label="Epochs")
664
  batch_sl = gr.Slider(1, 32, 8, step=1, label="Batch Size")
665
  imgsz_num = gr.Number(label="Image Size", value=640)
 
666
  lr_num = gr.Number(label="Learning Rate", value=0.001)
667
  opt_dd = gr.Dropdown(["Adam", "AdamW", "SGD"], value="Adam", label="Optimizer")
668
  train_btn = gr.Button("Start Training", variant="primary")
 
688
  hf_status = gr.Textbox(label="Hugging Face Status", interactive=False)
689
  gh_status = gr.Textbox(label="GitHub Status", interactive=False)
690
 
691
+ # Wire UI handlers
692
+ load_btn.click(
693
+ fn=load_datasets_handler,
694
+ inputs=[rf_api_key, rf_url_file],
695
+ outputs=[dataset_status, dataset_info_state, class_df]
696
+ )
697
+ update_counts_btn.click(
698
+ fn=update_class_counts_handler,
699
+ inputs=[class_df, dataset_info_state],
700
+ outputs=[class_count_summary_df]
701
+ )
702
+ finalize_btn.click(
703
+ fn=finalize_handler,
704
+ inputs=[dataset_info_state, class_df],
705
+ outputs=[finalize_status, final_dataset_path_state]
706
+ )
707
+ train_btn.click(
708
+ fn=training_handler,
709
+ inputs=[final_dataset_path_state, model_file_dd, run_name_tb, epochs_sl, batch_sl, imgsz_num, lr_num, opt_dd],
710
+ outputs=[train_status, loss_plot, map_plot, final_model_file]
711
+ )
712
+ upload_btn.click(
713
+ fn=upload_handler,
714
+ inputs=[final_model_file, hf_token, hf_repo, gh_token, gh_repo],
715
+ outputs=[hf_status, gh_status]
716
+ )
717
 
718
  if __name__ == "__main__":
719
+ # Tip: silence Ultralytics settings warning by setting env var:
720
+ # export YOLO_CONFIG_DIR=/tmp/Ultralytics
721
+ app.launch(debug=True)
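
As a quick sanity check of the new parser, the sketch below exercises the formats listed in the parse_roboflow_url docstring. It assumes app.py is importable from the working directory (importing it builds, but does not launch, the Gradio app), and the workspace/project names are invented:

    # Sketch only: hypothetical names, assumes app.py is on the import path.
    from app import parse_roboflow_url

    examples = [
        "https://universe.roboflow.com/acme-vision/widgets/3",  # full URL with version
        "https://app.roboflow.com/acme-vision/widgets",         # full URL, no version
        "acme-vision/widgets/v3",                               # raw workspace/project/vN
        "not-a-roboflow-link",                                  # should not parse
    ]
    for s in examples:
        print(s, "->", parse_roboflow_url(s))
    # Expected: ('acme-vision', 'widgets', 3), ('acme-vision', 'widgets', None),
    # ('acme-vision', 'widgets', 3), and (None, None, None), respectively.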