File size: 3,069 Bytes
066effd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
# ------------------------------------------------------------------------
# RF-DETR
# Copyright (c) 2025 Roboflow. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
# Copyright (c) 2024 Baidu. All Rights Reserved.
# ------------------------------------------------------------------------
import argparse
from rf100vl import get_rf100vl_projects
import roboflow
from rfdetr import RFDETRBase
import torch
import os
def download_dataset(rf_project: roboflow.Project, dataset_version: int):
    """Return the local directory of a project's dataset, downloading it if absent.

    Args:
        rf_project: Roboflow project whose versions are listed and downloaded.
        dataset_version: Version number to use. ``None`` selects the
            highest-numbered version of the project.

    Returns:
        The dataset directory: ``datasets/<project name>_v<version>`` when that
        path already exists, otherwise the ``location`` reported by the
        download call.

    Raises:
        ValueError: If the project has no versions, or the requested version
            is not among them.
    """
    versions = rf_project.versions()
    if dataset_version is not None:
        wanted = str(dataset_version)
        version = next((v for v in versions if v.version == wanted), None)
        if version is None:
            raise ValueError(f"Dataset version {dataset_version} not found")
    else:
        if not versions:
            # max() on an empty sequence raises an opaque ValueError; say why.
            raise ValueError(f"Project {rf_project.name} has no dataset versions")
        # Order numerically rather than by the `id` string: string comparison
        # is lexicographic and would rank version 9 above version 10.
        # NOTE(review): assumes `version` is a decimal string, as the equality
        # check against str(dataset_version) above implies -- confirm.
        version = max(versions, key=lambda v: int(v.version))

    location = os.path.join("datasets/", f"{rf_project.name}_v{version.version}")
    # Skip the download when a directory for this version is already on disk.
    if not os.path.exists(location):
        location = version.download(
            model_format="coco", location=location, overwrite=False
        ).location
    return location
def train_from_rf_project(rf_project: roboflow.Project, dataset_version: int):
    """Fetch a Roboflow project's dataset and train RF-DETR on it for one epoch.

    Args:
        rf_project: Roboflow project to pull the dataset from.
        dataset_version: Version number forwarded to ``download_dataset``
            (``None`` lets that function choose the version).
    """
    dataset_dir = download_dataset(rf_project, dataset_version)
    print(dataset_dir)

    model = RFDETRBase()
    # Train on the GPU when CUDA is available, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    model.train(dataset_dir=dataset_dir, epochs=1, device=device)
def train_from_coco_dir(coco_dir: str):
    """Train RF-DETR for one epoch on a dataset directory already on disk.

    Args:
        coco_dir: Path passed to ``RFDETRBase.train`` as ``dataset_dir``
            (presumably a COCO-format dataset, going by the name -- confirm).
    """
    # The device must be resolved inside this function: no module-level
    # `device_supports_cuda` exists, so referencing one raises NameError.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    rf_detr = RFDETRBase()
    rf_detr.train(
        dataset_dir=coco_dir,
        epochs=1,
        device=device,
    )
def trainer():
    """Command-line entry point: parse arguments and start a training run.

    With ``--coco_dir`` the run trains directly on that directory. Otherwise
    a Roboflow project is resolved, either the one named by ``--workspace``
    and ``--project_name`` or the first project returned by
    ``get_rf100vl_projects``, and training runs on its dataset.

    Raises:
        ValueError: If only one of ``--workspace`` / ``--project_name`` is given.
    """
    cli = argparse.ArgumentParser()
    # Every option is optional and defaults to None.
    for flag, value_type in (
        ("--coco_dir", str),
        ("--api_key", str),
        ("--workspace", str),
        ("--project_name", str),
        ("--dataset_version", int),
    ):
        cli.add_argument(flag, type=value_type, required=False, default=None)
    opts = cli.parse_args()

    # A local dataset directory takes precedence over any Roboflow options.
    if opts.coco_dir is not None:
        train_from_coco_dir(opts.coco_dir)
        return

    has_workspace = opts.workspace is not None
    has_project_name = opts.project_name is not None
    # The two options only make sense together.
    if has_workspace != has_project_name:
        raise ValueError(
            "Either both workspace and project_name must be provided or none of them"
        )

    if has_workspace:
        client = roboflow.Roboflow(api_key=opts.api_key)
        workspace = client.workspace(opts.workspace)
        project = workspace.project(opts.project_name)
    else:
        rf100vl_projects = get_rf100vl_projects(api_key=opts.api_key)
        project = rf100vl_projects[0].rf_project

    train_from_rf_project(project, opts.dataset_version)
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    trainer()
|