Update app.py
Browse files
app.py
CHANGED
|
@@ -133,6 +133,38 @@ def get_latest_version(api_key, workspace, project):
|
|
| 133 |
return None
|
| 134 |
|
| 135 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
def download_dataset(api_key, workspace, project, version):
|
| 137 |
"""Downloads a single dataset from Roboflow (yolov8 format works fine for RT-DETR)."""
|
| 138 |
try:
|
|
@@ -145,7 +177,15 @@ def download_dataset(api_key, workspace, project, version):
|
|
| 145 |
with open(data_yaml_path, 'r') as f:
|
| 146 |
data_yaml = yaml.safe_load(f)
|
| 147 |
|
| 148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
splits = [s for s in ['train', 'valid', 'test']
|
| 150 |
if os.path.exists(os.path.join(dataset.location, s))]
|
| 151 |
|
|
@@ -358,7 +398,8 @@ def load_datasets_handler(api_key, url_file, progress=gr.Progress()):
|
|
| 358 |
msg = "No datasets were loaded successfully.\n" + "\n".join([f"- {u}: {why}" for u, why in failures[:10]])
|
| 359 |
raise gr.Error(msg)
|
| 360 |
|
| 361 |
-
|
|
|
|
| 362 |
class_map = {name: name for name in all_names}
|
| 363 |
|
| 364 |
# Initial preview uses "keep all" mapping
|
|
@@ -448,9 +489,6 @@ def finalize_handler(dataset_info, class_df, progress=gr.Progress()):
|
|
| 448 |
# Sum limits for final_name over any merged originals
|
| 449 |
class_limits[final_name] = class_limits.get(final_name, 0) + int(row["Max Images"])
|
| 450 |
|
| 451 |
-
# Any original not present in mapping will map to itself (keep behavior)
|
| 452 |
-
# BUT we do not want to include classes with 0 limit in the final dataset
|
| 453 |
-
# finalize_merged_dataset uses the limits dict to decide active classes.
|
| 454 |
status, path = finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress)
|
| 455 |
return status, path
|
| 456 |
|
|
|
|
| 133 |
return None
|
| 134 |
|
| 135 |
|
| 136 |
+
# --- NEW: normalize class names from data.yaml ---
|
| 137 |
+
def _extract_class_names(data_yaml):
|
| 138 |
+
"""
|
| 139 |
+
Return a list[str] of class names in index order.
|
| 140 |
+
Handles:
|
| 141 |
+
- list (possibly containing non-str types)
|
| 142 |
+
- dict with numeric keys (e.g., {0: 'cat', 1: 'dog'})
|
| 143 |
+
- fallback to ['class_0', ..., f'class_{nc-1}'] if names missing
|
| 144 |
+
"""
|
| 145 |
+
names = data_yaml.get('names', None)
|
| 146 |
+
|
| 147 |
+
if isinstance(names, dict):
|
| 148 |
+
def _k(x):
|
| 149 |
+
try:
|
| 150 |
+
return int(x)
|
| 151 |
+
except Exception:
|
| 152 |
+
return str(x)
|
| 153 |
+
ordered_keys = sorted(names.keys(), key=_k)
|
| 154 |
+
names_list = [names[k] for k in ordered_keys]
|
| 155 |
+
elif isinstance(names, list):
|
| 156 |
+
names_list = names
|
| 157 |
+
else:
|
| 158 |
+
nc = data_yaml.get('nc', 0)
|
| 159 |
+
try:
|
| 160 |
+
nc = int(nc)
|
| 161 |
+
except Exception:
|
| 162 |
+
nc = 0
|
| 163 |
+
names_list = [f"class_{i}" for i in range(nc)]
|
| 164 |
+
|
| 165 |
+
return [str(x) for x in names_list]
|
| 166 |
+
|
| 167 |
+
|
| 168 |
def download_dataset(api_key, workspace, project, version):
|
| 169 |
"""Downloads a single dataset from Roboflow (yolov8 format works fine for RT-DETR)."""
|
| 170 |
try:
|
|
|
|
| 177 |
with open(data_yaml_path, 'r') as f:
|
| 178 |
data_yaml = yaml.safe_load(f)
|
| 179 |
|
| 180 |
+
# --- UPDATED: use normalized names and optional sanity log ---
|
| 181 |
+
class_names = _extract_class_names(data_yaml)
|
| 182 |
+
try:
|
| 183 |
+
nc = int(data_yaml.get('nc', len(class_names)))
|
| 184 |
+
except Exception:
|
| 185 |
+
nc = len(class_names)
|
| 186 |
+
if len(class_names) != nc:
|
| 187 |
+
logging.warning(f"[{project}-v{version}] names length ({len(class_names)}) != nc ({nc}); using normalized names.")
|
| 188 |
+
|
| 189 |
splits = [s for s in ['train', 'valid', 'test']
|
| 190 |
if os.path.exists(os.path.join(dataset.location, s))]
|
| 191 |
|
|
|
|
| 398 |
msg = "No datasets were loaded successfully.\n" + "\n".join([f"- {u}: {why}" for u, why in failures[:10]])
|
| 399 |
raise gr.Error(msg)
|
| 400 |
|
| 401 |
+
# --- UPDATED: ensure all names are strings before sorting
|
| 402 |
+
all_names = sorted({str(n) for _, names, _, _ in dataset_info for n in names})
|
| 403 |
class_map = {name: name for name in all_names}
|
| 404 |
|
| 405 |
# Initial preview uses "keep all" mapping
|
|
|
|
| 489 |
# Sum limits for final_name over any merged originals
|
| 490 |
class_limits[final_name] = class_limits.get(final_name, 0) + int(row["Max Images"])
|
| 491 |
|
|
|
|
|
|
|
|
|
|
| 492 |
status, path = finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress)
|
| 493 |
return status, path
|
| 494 |
|