# test-detr / app.py
import os
import shutil
import stat
import yaml
import gradio as gr
from ultralytics import YOLO # Ultralytics RT-DETR runner
from roboflow import Roboflow
import re
from urllib.parse import urlparse
import random
import logging
import requests
import json
from PIL import Image
import torch
import pandas as pd
import matplotlib.pyplot as plt
from threading import Thread
from queue import Queue
from huggingface_hub import HfApi, HfFolder
import base64
# --- Configuration ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Hardcoded RT-DETR model configurations; no YOLO model options are exposed.
RTDETR_MODELS = {
"detection": [
{
"filename": "rtdetr-l.pt",
"url": "https://github.com/ultralytics/assets/releases/download/v8.0.0/rtdetr-l.pt",
"description": "RT-DETR Large model (Default)"
},
{
"filename": "rtdetr-x.pt",
"url": "https://github.com/ultralytics/assets/releases/download/v8.0.0/rtdetr-x.pt",
"description": "RT-DETR Extra-Large model."
}
]
}
DEFAULT_MODEL = "rtdetr-l.pt"
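# Illustrative lookup (mirrors how training_handler resolves the checkpoint URL below;
# this expression is an example only, it is not used elsewhere in the app):
#   next(m["url"] for m in RTDETR_MODELS["detection"] if m["filename"] == DEFAULT_MODEL)
#   -> "https://github.com/ultralytics/assets/releases/download/v8.0.0/rtdetr-l.pt"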
# ------------------------------
# Utilities
# ------------------------------
def handle_remove_readonly(func, path, exc_info):
"""Error handler for shutil.rmtree."""
try:
os.chmod(path, stat.S_IWRITE)
except Exception:
pass
func(path)
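# Used below as shutil.rmtree(merged_dir, onerror=handle_remove_readonly) so read-only
# files (common on Windows checkouts) do not abort dataset cleanup.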
_ROBO_URL_RX = re.compile(
r"""
^(?:
(?:https?://)?(?:universe|app|www)?\.?roboflow\.com/ # Any roboflow host
(?P<ws>[A-Za-z0-9\-_]+)/ # workspace
(?P<proj>[A-Za-z0-9\-_]+)/? # project
(?:
(?:dataset/[^/]+/)? # optional 'dataset/<fmt>/'
(?:v?(?P<ver>\d+))? # optional version 'vN' or 'N'
)?
|
(?P<ws2>[A-Za-z0-9\-_]+)/(?P<proj2>[A-Za-z0-9\-_]+)(?:/(?:v)?(?P<ver2>\d+))? # raw ws/proj[/vN]
)$
""",
re.VERBOSE | re.IGNORECASE
)
def parse_roboflow_url(s: str):
"""
Accepts:
- https://universe.roboflow.com/<workspace>/<project>[/vN | /N]
- https://app.roboflow.com/<workspace>/<project>[/vN | /N]
- https://roboflow.com/<workspace>/<project>[/vN | /N]
- raw: <workspace>/<project>[/vN | /N]
Returns: (workspace, project, version_or_None)
"""
s = s.strip()
# Fast path: try regex
m = _ROBO_URL_RX.match(s)
if m:
ws = m.group('ws') or m.group('ws2')
proj = m.group('proj') or m.group('proj2')
ver = m.group('ver') or m.group('ver2')
return ws, proj, (int(ver) if ver else None)
# Fallback: parse like URL and split path
parsed = urlparse(s)
parts = [p for p in parsed.path.strip('/').split('/') if p]
if len(parts) >= 2:
# Try to pull raw version from the 3rd part if it exists
version = None
if len(parts) >= 3:
# Accept 'vN' or 'N'
vpart = parts[2]
if vpart.lower().startswith('v') and vpart[1:].isdigit():
version = int(vpart[1:])
elif vpart.isdigit():
version = int(vpart)
return parts[0], parts[1], version
    # Final fallback: plain "ws/proj[/vN]" strings that don't mention roboflow.com
if '/' in s and 'roboflow' not in s:
p = s.split('/')
if len(p) >= 2:
# Accept trailing version if present
version = None
if len(p) >= 3:
v = p[2]
if v.lower().startswith('v') and v[1:].isdigit():
version = int(v[1:])
elif v.isdigit():
version = int(v)
return p[0], p[1], version
return None, None, None
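# Illustrative inputs/outputs (the workspace/project names are made up, not real Roboflow projects):
#   parse_roboflow_url("https://universe.roboflow.com/acme/widgets/3")  -> ("acme", "widgets", 3)
#   parse_roboflow_url("https://app.roboflow.com/acme/widgets/v2")      -> ("acme", "widgets", 2)
#   parse_roboflow_url("acme/widgets")                                  -> ("acme", "widgets", None)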
def get_latest_version(api_key, workspace, project):
"""Gets the latest version number of a Roboflow project."""
try:
rf = Roboflow(api_key=api_key)
proj = rf.workspace(workspace).project(project)
versions = sorted([int(v.version) for v in proj.versions()], reverse=True)
return versions[0] if versions else None
except Exception as e:
logging.error(f"Could not get latest version for {workspace}/{project}: {e}")
return None
# --- Normalize class names from data.yaml ---
def _extract_class_names(data_yaml):
"""
Return a list[str] of class names in index order.
Handles:
- list (possibly containing non-str types)
- dict with numeric keys (e.g., {0: 'cat', 1: 'dog'})
- fallback to ['class_0', ..., f'class_{nc-1}'] if names missing
"""
names = data_yaml.get('names', None)
if isinstance(names, dict):
def _k(x):
try:
return int(x)
except Exception:
return str(x)
ordered_keys = sorted(names.keys(), key=_k)
names_list = [names[k] for k in ordered_keys]
elif isinstance(names, list):
names_list = names
else:
nc = data_yaml.get('nc', 0)
try:
nc = int(nc)
except Exception:
nc = 0
names_list = [f"class_{i}" for i in range(nc)]
return [str(x) for x in names_list]
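# Illustrative behaviour of the normalizer (inputs are minimal hand-written dicts, not real exports):
#   _extract_class_names({"names": {1: "dog", 0: "cat"}, "nc": 2})  -> ["cat", "dog"]
#   _extract_class_names({"names": ["cat", 1]})                     -> ["cat", "1"]
#   _extract_class_names({"nc": 3})                                 -> ["class_0", "class_1", "class_2"]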
def download_dataset(api_key, workspace, project, version):
"""Downloads a single dataset from Roboflow (yolov8 format works fine for RT-DETR)."""
try:
rf = Roboflow(api_key=api_key)
proj = rf.workspace(workspace).project(project)
ver = proj.version(int(version))
dataset = ver.download("yolov8")
data_yaml_path = os.path.join(dataset.location, 'data.yaml')
with open(data_yaml_path, 'r') as f:
data_yaml = yaml.safe_load(f)
        # Use the normalized class names; warn if 'names' and 'nc' disagree
class_names = _extract_class_names(data_yaml)
try:
nc = int(data_yaml.get('nc', len(class_names)))
except Exception:
nc = len(class_names)
if len(class_names) != nc:
logging.warning(f"[{project}-v{version}] names length ({len(class_names)}) != nc ({nc}); using normalized names.")
splits = [s for s in ['train', 'valid', 'test']
if os.path.exists(os.path.join(dataset.location, s))]
return dataset.location, class_names, splits, f"{project}-v{version}"
except Exception as e:
logging.error(f"Failed to download {workspace}/{project}/v{version}: {e}")
return None, [], [], None
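# A successful call returns a tuple shaped like the sketch below (paths depend on where the
# Roboflow SDK extracts the export; the values shown are illustrative only):
#   ("/path/to/widgets-3", ["cat", "dog"], ["train", "valid", "test"], "widgets-v3")
# On failure it returns (None, [], [], None) and logs the error.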
def label_path_for(img_path: str) -> str:
"""Convert .../split/images/file.jpg -> .../split/labels/file.txt in a safe way."""
split_dir = os.path.dirname(os.path.dirname(img_path)) # .../split
base = os.path.splitext(os.path.basename(img_path))[0] + '.txt'
return os.path.join(split_dir, 'labels', base)
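# Example (POSIX-style paths, purely illustrative):
#   label_path_for("/data/widgets-3/train/images/img001.jpg")
#   -> "/data/widgets-3/train/labels/img001.txt"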
def gather_class_counts(dataset_info, class_mapping):
"""
Count, per final class, how many images contain at least one instance of that class
(counted once per image). class_mapping maps original_name -> final_name.
"""
if not dataset_info:
return {}
final_names = set(class_mapping.values())
counts = {name: 0 for name in final_names}
for loc, names, splits, _ in dataset_info:
# Map from original idx -> mapped name (or None if removed later)
id_to_name = {}
for idx, n in enumerate(names):
id_to_name[idx] = class_mapping.get(n, None)
for split in splits:
labels_dir = os.path.join(loc, split, 'labels')
if not os.path.exists(labels_dir):
continue
for label_file in os.listdir(labels_dir):
if not label_file.endswith('.txt'):
continue
found = set()
with open(os.path.join(labels_dir, label_file), 'r') as f:
for line in f:
parts = line.strip().split()
if not parts:
continue
try:
cls_id = int(parts[0])
mapped = id_to_name.get(cls_id, None)
if mapped in final_names:
found.add(mapped)
except Exception:
continue
for m in found:
counts[m] += 1
return counts
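# Illustrative call: with class_mapping = {"cat": "animal", "dog": "animal"}, the result maps
# each *final* class name to the number of images containing at least one such instance, e.g.
#   {"animal": 137}   # the actual count depends entirely on the loaded datasets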
def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=gr.Progress()):
"""Core function to merge datasets based on user rules."""
merged_dir = 'rolo_merged_dataset'
if os.path.exists(merged_dir):
shutil.rmtree(merged_dir, onerror=handle_remove_readonly)
progress(0, desc="Creating directories...")
for split in ['train', 'valid', 'test']:
os.makedirs(os.path.join(merged_dir, split, 'images'), exist_ok=True)
os.makedirs(os.path.join(merged_dir, split, 'labels'), exist_ok=True)
# Only classes with positive limits are active
active_classes = [cls for cls, limit in class_limits.items() if limit > 0]
active_classes = sorted(set(active_classes))
final_class_map = {name: i for i, name in enumerate(active_classes)}
# Collect all candidate images
all_images = []
for loc, _, splits, _ in dataset_info:
for split in splits:
img_dir = os.path.join(loc, split, 'images')
if not os.path.exists(img_dir):
continue
for img_file in os.listdir(img_dir):
if img_file.lower().endswith(('.jpg', '.jpeg', '.png')):
all_images.append((os.path.join(img_dir, img_file), split, loc))
random.shuffle(all_images)
progress(0.2, desc="Selecting images based on limits...")
selected_images = []
current_counts = {cls: 0 for cls in active_classes}
# Build a quick lookup: source_loc -> names list
loc_to_names = {info[0]: info[1] for info in dataset_info}
for img_path, split, source_loc in progress.tqdm(all_images, desc="Analyzing images"):
lbl_path = label_path_for(img_path)
if not os.path.exists(lbl_path):
continue
source_names = loc_to_names.get(source_loc, [])
image_classes = set()
with open(lbl_path, 'r') as f:
for line in f:
parts = line.strip().split()
if not parts:
continue
try:
cls_id = int(parts[0])
orig = source_names[cls_id]
mapped = class_mapping.get(orig, orig)
if mapped in active_classes:
image_classes.add(mapped)
except Exception:
continue
if not image_classes:
continue
# Check limits
if any(current_counts[c] >= class_limits[c] for c in image_classes):
continue
selected_images.append((img_path, split))
for c in image_classes:
current_counts[c] += 1
progress(0.6, desc=f"Copying {len(selected_images)} files...")
for img_path, split in progress.tqdm(selected_images, desc="Finalizing files"):
lbl_path = label_path_for(img_path)
out_img = os.path.join(merged_dir, split, 'images', os.path.basename(img_path))
out_lbl = os.path.join(merged_dir, split, 'labels', os.path.basename(lbl_path))
shutil.copy(img_path, out_img)
# Determine source names by matching the parent dataset root
source_loc = None
for info in dataset_info:
if img_path.startswith(info[0]):
source_loc = info[0]
break
source_names = loc_to_names.get(source_loc, [])
with open(lbl_path, 'r') as f_in, open(out_lbl, 'w') as f_out:
for line in f_in:
parts = line.strip().split()
if not parts:
continue
try:
old_id = int(parts[0])
original_name = source_names[old_id]
mapped_name = class_mapping.get(original_name, original_name)
if mapped_name in final_class_map:
new_id = final_class_map[mapped_name]
f_out.write(f"{new_id} {' '.join(parts[1:])}\n")
except Exception:
continue
progress(0.95, desc="Creating data.yaml...")
with open(os.path.join(merged_dir, 'data.yaml'), 'w') as f:
yaml.dump({
'path': os.path.abspath(merged_dir),
'train': 'train/images',
'val': 'valid/images',
'test': 'test/images',
'nc': len(active_classes),
'names': active_classes
}, f)
return f"Dataset finalized with {len(selected_images)} images.", os.path.abspath(merged_dir)
# ------------------------------
# Gradio UI Event Handlers
# ------------------------------
def load_datasets_handler(api_key, url_file, progress=gr.Progress()):
"""Handles the 'Load Datasets' button click."""
api_key = api_key or os.getenv("ROBOFLOW_API_KEY", "")
if not api_key:
raise gr.Error("Roboflow API Key is required (or set ROBOFLOW_API_KEY).")
if not url_file:
raise gr.Error("Please upload a .txt file with Roboflow URLs or lines like 'workspace/project[/vN]'.")
with open(url_file.name, 'r', encoding='utf-8', errors='ignore') as f:
urls = [line.strip() for line in f if line.strip()]
dataset_info = []
failures = []
for i, raw in enumerate(urls):
progress((i + 1) / max(1, len(urls)), desc=f"Parsing {i+1}/{len(urls)}")
ws, proj, ver = parse_roboflow_url(raw)
if not (ws and proj):
failures.append((raw, "ParseError: could not resolve workspace/project"))
continue
if ver is None:
ver = get_latest_version(api_key, ws, proj)
if ver is None:
failures.append((raw, f"Could not resolve latest version for {ws}/{proj}"))
continue
loc, names, splits, name_str = download_dataset(api_key, ws, proj, int(ver))
if loc:
dataset_info.append((loc, names, splits, name_str))
else:
failures.append((raw, f"DownloadError: {ws}/{proj}/v{ver}"))
if not dataset_info:
# Show a compact failure report to the UI
msg = "No datasets were loaded successfully.\n" + "\n".join([f"- {u}: {why}" for u, why in failures[:10]])
raise gr.Error(msg)
    # Ensure all class names are strings before sorting
all_names = sorted({str(n) for _, names, _, _ in dataset_info for n in names})
class_map = {name: name for name in all_names}
# Initial preview uses "keep all" mapping
initial_counts = gather_class_counts(dataset_info, class_map)
df_data = [[name, name, initial_counts.get(name, 0), False] for name in all_names]
status_text = "Datasets loaded successfully."
if failures:
status_text += f" ({len(dataset_info)} OK, {len(failures)} failed; see console logs)."
return status_text, dataset_info, gr.DataFrame.update(
value=pd.DataFrame(df_data, columns=["Original Name", "Rename To", "Max Images", "Remove"])
)
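# The table handed to the UI has one row per original class, shaped
#   [Original Name, Rename To, Max Images, Remove], e.g. ["cat", "cat", 120, False]
# (the 120 is illustrative); "Max Images" is pre-filled with the estimated image count
# for that class and later serves as the per-class limit in finalize_handler.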
def update_class_counts_handler(class_df, dataset_info):
"""
Provides live feedback on class counts as the user edits the DataFrame.
We compute a mapping of original -> final (or None if removed), then count images
for each final name.
"""
if class_df is None or not dataset_info:
return None
# Build mapping original_name -> final_name or None if removed
class_df = pd.DataFrame(class_df)
mapping = {}
for _, row in class_df.iterrows():
orig = row["Original Name"]
if bool(row["Remove"]):
mapping[orig] = None
else:
mapping[orig] = row["Rename To"]
# Build final set
final_names = sorted(set(v for v in mapping.values() if v))
counts = {k: 0 for k in final_names}
for loc, names, splits, _ in dataset_info:
id_to_final = {}
for idx, n in enumerate(names):
id_to_final[idx] = mapping.get(n, None)
for split in splits:
labels_dir = os.path.join(loc, split, 'labels')
if not os.path.exists(labels_dir):
continue
for label_file in os.listdir(labels_dir):
if not label_file.endswith('.txt'):
continue
found = set()
with open(os.path.join(labels_dir, label_file), 'r') as f:
for line in f:
parts = line.strip().split()
if not parts:
continue
try:
cls_id = int(parts[0])
mapped = id_to_final.get(cls_id, None)
if mapped:
found.add(mapped)
except Exception:
continue
for m in found:
counts[m] += 1
summary_df = pd.DataFrame(list(counts.items()), columns=["Final Class Name", "Est. Total Images"])
return summary_df
def finalize_handler(dataset_info, class_df, progress=gr.Progress()):
"""Handles the 'Finalize' button click."""
if not dataset_info:
raise gr.Error("Load datasets first in Tab 1.")
if class_df is None:
raise gr.Error("Class data is missing.")
# Mapping and limits
class_df = pd.DataFrame(class_df)
class_mapping = {}
class_limits = {}
for _, row in class_df.iterrows():
orig = row["Original Name"]
if bool(row["Remove"]):
continue
final_name = row["Rename To"]
class_mapping[orig] = final_name
# Sum limits for final_name over any merged originals
class_limits[final_name] = class_limits.get(final_name, 0) + int(row["Max Images"])
status, path = finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress)
return status, path
def training_handler(dataset_path, model_filename, run_name, epochs, batch, imgsz, lr, opt, progress=gr.Progress()):
"""Handles the training process with live feedback."""
if not dataset_path:
raise gr.Error("Finalize a dataset in Tab 2 before training.")
# Ultralytics expects device string, e.g. '0' or 'cpu'
device_str = "0" if torch.cuda.is_available() else "cpu"
metrics_queue = Queue()
def on_epoch_end(trainer):
# Be defensive about metric keys
m = trainer.metrics or {}
metrics_queue.put({
'epoch': (trainer.epoch or 0) + 1,
'train_loss': m.get('train/loss') or m.get('loss'),
'val_loss': m.get('val/loss'),
'mAP50': m.get('metrics/mAP50(B)') or m.get('metrics/mAP50'),
'mAP50_95': m.get('metrics/mAP50-95(B)') or m.get('metrics/mAP50-95')
})
def train_thread_func():
try:
model_url = next(m['url'] for m in RTDETR_MODELS['detection'] if m['filename'] == model_filename)
weights_path = os.path.join('pretrained_models', model_filename)
if not os.path.exists(weights_path):
os.makedirs('pretrained_models', exist_ok=True)
r = requests.get(model_url, stream=True, timeout=60)
r.raise_for_status()
with open(weights_path, 'wb') as f:
for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)
model = YOLO(weights_path)
model.add_callback("on_train_epoch_end", on_epoch_end)
model.train(
data=os.path.join(dataset_path, 'data.yaml'),
epochs=int(epochs),
batch=int(batch),
imgsz=int(imgsz),
lr0=float(lr),
optimizer=str(opt),
project='runs/train',
name=str(run_name),
exist_ok=True,
device=device_str
)
metrics_queue.put("done")
except Exception as e:
logging.exception("Training thread error")
metrics_queue.put(f"error: {e}")
Thread(target=train_thread_func, daemon=True).start()
history = {k: [] for k in ['epoch', 'train_loss', 'val_loss', 'mAP50', 'mAP50_95']}
while True:
item = metrics_queue.get()
        if isinstance(item, str):
            if item == "done":
                break
            if item.startswith("error"):
                raise gr.Error(f"Training failed: {item}")
            continue  # ignore any other string messages defensively
        # Append metrics; always append (possibly None) so all series stay the same length
        # and matplotlib renders missing values as gaps instead of raising on mismatched lengths.
        for key in ['epoch', 'train_loss', 'val_loss', 'mAP50', 'mAP50_95']:
            history[key].append(item.get(key, None))
current_epoch = history['epoch'][-1] if history['epoch'] else 0
total_epochs = int(epochs)
frac = min(max(current_epoch / max(1, total_epochs), 0.0), 1.0)
progress(frac, desc=f"Epoch {current_epoch}/{total_epochs}")
# Plot Loss
fig_loss = plt.figure()
ax_loss = fig_loss.add_subplot(111)
ax_loss.plot(history['epoch'], history['train_loss'], "o-", label='Train Loss')
ax_loss.plot(history['epoch'], history['val_loss'], "o-", label='Val Loss')
ax_loss.legend()
ax_loss.set_title("Loss")
# Plot mAP
fig_map = plt.figure()
ax_map = fig_map.add_subplot(111)
ax_map.plot(history['epoch'], history['mAP50'], "o-", label='mAP@0.5')
ax_map.plot(history['epoch'], history['mAP50_95'], "o-", label='mAP@0.5:0.95')
ax_map.legend()
ax_map.set_title("mAP")
yield f"Epoch {current_epoch}/{total_epochs} complete.", fig_loss, fig_map, None
final_path = os.path.join('runs', 'train', str(run_name), 'weights', 'best.pt')
if not os.path.exists(final_path):
raise gr.Error("Training finished, but 'best.pt' was not found.")
yield "Training complete!", None, None, gr.File.update(value=final_path, visible=True)
def upload_handler(model_file, hf_token, hf_repo, gh_token, gh_repo, progress=gr.Progress()):
"""Handles model upload to Hugging Face and GitHub."""
if not model_file:
raise gr.Error("No trained model file available to upload. Train a model first.")
hf_status = "Skipped Hugging Face (credentials not provided)."
if hf_token and hf_repo:
progress(0, desc="Uploading to Hugging Face...")
try:
api = HfApi()
HfFolder.save_token(hf_token)
repo_url = api.create_repo(repo_id=hf_repo, exist_ok=True, token=hf_token)
api.upload_file(
path_or_fileobj=model_file.name,
path_in_repo=os.path.basename(model_file.name),
repo_id=hf_repo,
token=hf_token
)
hf_status = f"Success! Model at: {repo_url}"
except Exception as e:
hf_status = f"Hugging Face Error: {e}"
gh_status = "Skipped GitHub (credentials not provided)."
if gh_token and gh_repo:
progress(0.5, desc="Uploading to GitHub...")
try:
if '/' not in gh_repo:
raise ValueError("GitHub repo must be in the form 'username/repo'.")
username, repo_name = gh_repo.split('/')
api_url = f"https://api.github.com/repos/{username}/{repo_name}/contents/{os.path.basename(model_file.name)}"
headers = {"Authorization": f"token {gh_token}"}
with open(model_file.name, "rb") as f:
content = base64.b64encode(f.read()).decode()
get_resp = requests.get(api_url, headers=headers, timeout=30)
sha = get_resp.json().get('sha') if get_resp.ok else None
data = {"message": "Upload trained model from Rolo app", "content": content}
if sha:
data["sha"] = sha
put_resp = requests.put(api_url, headers=headers, json=data, timeout=60)
if put_resp.ok:
gh_status = f"Success! Model at: {put_resp.json()['content']['html_url']}"
else:
msg = put_resp.json().get('message', 'Unknown')
gh_status = f"GitHub Error: {msg}"
except Exception as e:
gh_status = f"GitHub Error: {e}"
progress(1)
return hf_status, gh_status
# ------------------------------
# Gradio UI
# ------------------------------
with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
gr.Markdown("# Rolo: A Dedicated RT-DETR Training Dashboard")
# State variables
dataset_info_state = gr.State([])
final_dataset_path_state = gr.State(None)
with gr.Tabs():
with gr.TabItem("1. Prepare Datasets"):
gr.Markdown("### Load Roboflow Datasets\nProvide your Roboflow API key and upload a `.txt` file containing one Roboflow dataset URL or `workspace/project[/vN]` per line.")
with gr.Row():
rf_api_key = gr.Textbox(label="Roboflow API Key (or set ROBOFLOW_API_KEY env)", type="password", scale=2)
rf_url_file = gr.File(label="Upload Roboflow URLs (.txt)", file_types=[".txt"], scale=1)
load_btn = gr.Button("Load Datasets", variant="primary")
dataset_status = gr.Textbox(label="Status", interactive=False)
with gr.TabItem("2. Manage & Merge"):
gr.Markdown("### Configure Classes and Finalize Dataset\nRename classes to merge them, set image limits, or remove them. Click **Update Counts** to preview, then **Finalize** to create the dataset.")
with gr.Row():
class_df = gr.DataFrame(
headers=["Original Name", "Rename To", "Max Images", "Remove"],
datatype=["str", "str", "number", "bool"],
label="Class Configuration", interactive=True, scale=3
)
with gr.Column(scale=1):
class_count_summary_df = gr.DataFrame(
label="Merged Class Counts Preview",
headers=["Final Class Name", "Est. Total Images"],
interactive=False
)
update_counts_btn = gr.Button("Update Counts")
finalize_btn = gr.Button("Finalize Merged Dataset", variant="primary")
finalize_status = gr.Textbox(label="Status", interactive=False)
with gr.TabItem("3. Configure & Train"):
gr.Markdown("### Set Hyperparameters and Train the RT-DETR Model")
with gr.Row():
with gr.Column(scale=1):
model_file_dd = gr.Dropdown(
label="Select Pre-Trained RT-DETR Model",
choices=[m["filename"] for m in RTDETR_MODELS["detection"]],
value=DEFAULT_MODEL
)
run_name_tb = gr.Textbox(label="Run Name", value="rtdetr_run_1")
epochs_sl = gr.Slider(1, 500, 100, step=1, label="Epochs")
batch_sl = gr.Slider(1, 32, 8, step=1, label="Batch Size")
imgsz_num = gr.Number(label="Image Size", value=640)
lr_num = gr.Number(label="Learning Rate", value=0.001)
opt_dd = gr.Dropdown(["Adam", "AdamW", "SGD"], value="Adam", label="Optimizer")
train_btn = gr.Button("Start Training", variant="primary")
with gr.Column(scale=2):
train_status = gr.Textbox(label="Live Status", interactive=False)
loss_plot = gr.Plot(label="Loss Curves")
map_plot = gr.Plot(label="mAP Curves")
final_model_file = gr.File(label="Download Trained Model (best.pt)", interactive=False, visible=False)
with gr.TabItem("4. Upload Model"):
gr.Markdown("### Upload Your Trained Model\nAfter training, you can upload the `best.pt` file to Hugging Face and/or GitHub.")
with gr.Row():
with gr.Column():
gr.Markdown("#### Hugging Face")
hf_token = gr.Textbox(label="Hugging Face API Token", type="password")
hf_repo = gr.Textbox(label="Hugging Face Repo ID", placeholder="e.g., username/my-rtdetr-model")
with gr.Column():
gr.Markdown("#### GitHub")
gh_token = gr.Textbox(label="GitHub Personal Access Token", type="password")
gh_repo = gr.Textbox(label="GitHub Repo", placeholder="e.g., username/my-rtdetr-repo")
upload_btn = gr.Button("Upload Model", variant="primary")
with gr.Row():
hf_status = gr.Textbox(label="Hugging Face Status", interactive=False)
gh_status = gr.Textbox(label="GitHub Status", interactive=False)
# Wire UI handlers
load_btn.click(
fn=load_datasets_handler,
inputs=[rf_api_key, rf_url_file],
outputs=[dataset_status, dataset_info_state, class_df]
)
update_counts_btn.click(
fn=update_class_counts_handler,
inputs=[class_df, dataset_info_state],
outputs=[class_count_summary_df]
)
finalize_btn.click(
fn=finalize_handler,
inputs=[dataset_info_state, class_df],
outputs=[finalize_status, final_dataset_path_state]
)
train_btn.click(
fn=training_handler,
inputs=[final_dataset_path_state, model_file_dd, run_name_tb, epochs_sl, batch_sl, imgsz_num, lr_num, opt_dd],
outputs=[train_status, loss_plot, map_plot, final_model_file]
)
upload_btn.click(
fn=upload_handler,
inputs=[final_model_file, hf_token, hf_repo, gh_token, gh_repo],
outputs=[hf_status, gh_status]
)
if __name__ == "__main__":
# Tip: silence Ultralytics settings warning by setting env var:
# export YOLO_CONFIG_DIR=/tmp/Ultralytics
app.launch(debug=True)