wuhp commited on
Commit
0257e16
·
verified ·
1 Parent(s): c54a7a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -5
app.py CHANGED
@@ -133,6 +133,38 @@ def get_latest_version(api_key, workspace, project):
133
  return None
134
 
135
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  def download_dataset(api_key, workspace, project, version):
137
  """Downloads a single dataset from Roboflow (yolov8 format works fine for RT-DETR)."""
138
  try:
@@ -145,7 +177,15 @@ def download_dataset(api_key, workspace, project, version):
145
  with open(data_yaml_path, 'r') as f:
146
  data_yaml = yaml.safe_load(f)
147
 
148
- class_names = data_yaml.get('names', [])
 
 
 
 
 
 
 
 
149
  splits = [s for s in ['train', 'valid', 'test']
150
  if os.path.exists(os.path.join(dataset.location, s))]
151
 
@@ -358,7 +398,8 @@ def load_datasets_handler(api_key, url_file, progress=gr.Progress()):
358
  msg = "No datasets were loaded successfully.\n" + "\n".join([f"- {u}: {why}" for u, why in failures[:10]])
359
  raise gr.Error(msg)
360
 
361
- all_names = sorted(list(set(n for _, names, _, _ in dataset_info for n in names)))
 
362
  class_map = {name: name for name in all_names}
363
 
364
  # Initial preview uses "keep all" mapping
@@ -448,9 +489,6 @@ def finalize_handler(dataset_info, class_df, progress=gr.Progress()):
448
  # Sum limits for final_name over any merged originals
449
  class_limits[final_name] = class_limits.get(final_name, 0) + int(row["Max Images"])
450
 
451
- # Any original not present in mapping will map to itself (keep behavior)
452
- # BUT we do not want to include classes with 0 limit in the final dataset
453
- # finalize_merged_dataset uses the limits dict to decide active classes.
454
  status, path = finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress)
455
  return status, path
456
 
 
133
  return None
134
 
135
 
136
+ # --- NEW: normalize class names from data.yaml ---
137
+ def _extract_class_names(data_yaml):
138
+ """
139
+ Return a list[str] of class names in index order.
140
+ Handles:
141
+ - list (possibly containing non-str types)
142
+ - dict with numeric keys (e.g., {0: 'cat', 1: 'dog'})
143
+ - fallback to ['class_0', ..., f'class_{nc-1}'] if names missing
144
+ """
145
+ names = data_yaml.get('names', None)
146
+
147
+ if isinstance(names, dict):
148
+ def _k(x):
149
+ try:
150
+ return int(x)
151
+ except Exception:
152
+ return str(x)
153
+ ordered_keys = sorted(names.keys(), key=_k)
154
+ names_list = [names[k] for k in ordered_keys]
155
+ elif isinstance(names, list):
156
+ names_list = names
157
+ else:
158
+ nc = data_yaml.get('nc', 0)
159
+ try:
160
+ nc = int(nc)
161
+ except Exception:
162
+ nc = 0
163
+ names_list = [f"class_{i}" for i in range(nc)]
164
+
165
+ return [str(x) for x in names_list]
166
+
167
+
168
  def download_dataset(api_key, workspace, project, version):
169
  """Downloads a single dataset from Roboflow (yolov8 format works fine for RT-DETR)."""
170
  try:
 
177
  with open(data_yaml_path, 'r') as f:
178
  data_yaml = yaml.safe_load(f)
179
 
180
+ # --- UPDATED: use normalized names and optional sanity log ---
181
+ class_names = _extract_class_names(data_yaml)
182
+ try:
183
+ nc = int(data_yaml.get('nc', len(class_names)))
184
+ except Exception:
185
+ nc = len(class_names)
186
+ if len(class_names) != nc:
187
+ logging.warning(f"[{project}-v{version}] names length ({len(class_names)}) != nc ({nc}); using normalized names.")
188
+
189
  splits = [s for s in ['train', 'valid', 'test']
190
  if os.path.exists(os.path.join(dataset.location, s))]
191
 
 
398
  msg = "No datasets were loaded successfully.\n" + "\n".join([f"- {u}: {why}" for u, why in failures[:10]])
399
  raise gr.Error(msg)
400
 
401
+ # --- UPDATED: ensure all names are strings before sorting
402
+ all_names = sorted({str(n) for _, names, _, _ in dataset_info for n in names})
403
  class_map = {name: name for name in all_names}
404
 
405
  # Initial preview uses "keep all" mapping
 
489
  # Sum limits for final_name over any merged originals
490
  class_limits[final_name] = class_limits.get(final_name, 0) + int(row["Max Images"])
491
 
 
 
 
492
  status, path = finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress)
493
  return status, path
494