|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
|
from rf100vl import get_rf100vl_projects |
|
|
import roboflow |
|
|
from rfdetr import RFDETRBase |
|
|
import torch |
|
|
import os |
|
|
|
|
|
def download_dataset(rf_project: "roboflow.Project", dataset_version: int):
    """Download one version of a Roboflow project in COCO format, reusing a local copy.

    The dataset lives under ``datasets/<project name>_v<version>``. If that
    path already exists, nothing is downloaded and the path is returned as-is.

    Args:
        rf_project: Roboflow project to download from.
        dataset_version: Version number to download, or ``None`` to use the
            highest-numbered version of the project.

    Returns:
        Path of the local dataset directory.

    Raises:
        ValueError: If the requested version does not exist, or the project
            has no versions at all.
    """
    versions = rf_project.versions()

    if dataset_version is not None:
        # ``version.version`` is a string (it is concatenated into the path
        # below), so compare against the stringified request.
        matching = [v for v in versions if v.version == str(dataset_version)]
        if not matching:
            raise ValueError(f"Dataset version {dataset_version} not found")
        version = matching[0]
    else:
        if not versions:
            raise ValueError(f"Project {rf_project.name} has no dataset versions")
        # Pick the newest version numerically: comparing string ids would
        # rank e.g. ".../9" above ".../10".
        version = max(versions, key=lambda v: int(v.version))

    location = os.path.join("datasets/", rf_project.name + "_v" + version.version)

    # Skip the download when a previous run already fetched this version.
    if not os.path.exists(location):
        location = version.download(
            model_format="coco", location=location, overwrite=False
        ).location

    return location
|
|
|
|
|
|
|
|
def train_from_rf_project(rf_project: roboflow.Project, dataset_version: int):
    """Fetch a Roboflow project's dataset and train an RF-DETR Base model on it.

    Args:
        rf_project: Roboflow project whose dataset is used for training.
        dataset_version: Version number to train on, or ``None`` to let
            ``download_dataset`` choose one.
    """
    dataset_dir = download_dataset(rf_project, dataset_version)
    print(dataset_dir)

    model = RFDETRBase()

    # Use the GPU whenever PyTorch reports one is available.
    if torch.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"

    model.train(dataset_dir=dataset_dir, epochs=1, device=target_device)
|
|
|
|
|
|
|
|
def train_from_coco_dir(coco_dir: str):
    """Train an RF-DETR Base model on a local COCO-format dataset directory.

    Args:
        coco_dir: Path to the dataset directory, passed to ``train`` as
            ``dataset_dir``.
    """
    rf_detr = RFDETRBase()

    # Computed here: this function cannot see train_from_rf_project's locals,
    # so referencing its ``device_supports_cuda`` would raise NameError.
    device_supports_cuda = torch.cuda.is_available()

    rf_detr.train(
        dataset_dir=coco_dir,
        epochs=1,
        device="cuda" if device_supports_cuda else "cpu",
    )
|
|
|
|
|
|
|
|
def trainer():
    """Command-line entry point: train RF-DETR Base on a local or Roboflow dataset.

    The dataset source is chosen in this order of precedence:
      1. ``--coco_dir``: train directly on that local directory.
      2. ``--workspace`` + ``--project_name``: use that Roboflow project
         (``--dataset_version`` selects the version).
      3. Neither: use the first project returned by ``get_rf100vl_projects``.

    Raises:
        ValueError: If only one of ``--workspace`` / ``--project_name`` is
            given, or if no RF100-VL projects are returned.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--coco_dir", type=str, required=False,
        help="Local dataset directory; takes precedence over Roboflow options.",
    )
    parser.add_argument(
        "--api_key", type=str, required=False,
        help="Roboflow API key.",
    )
    parser.add_argument(
        "--workspace", type=str, required=False, default=None,
        help="Roboflow workspace (requires --project_name).",
    )
    parser.add_argument(
        "--project_name", type=str, required=False, default=None,
        help="Roboflow project name (requires --workspace).",
    )
    parser.add_argument(
        "--dataset_version", type=int, required=False, default=None,
        help="Dataset version number to train on.",
    )
    args = parser.parse_args()

    # A local directory short-circuits all Roboflow handling.
    if args.coco_dir is not None:
        train_from_coco_dir(args.coco_dir)
        return

    # workspace and project_name are only meaningful together.
    if (args.workspace is None) != (args.project_name is None):
        raise ValueError(
            "Either both workspace and project_name must be provided or none of them"
        )

    if args.workspace is not None:
        rf = roboflow.Roboflow(api_key=args.api_key)
        project = rf.workspace(args.workspace).project(args.project_name)
    else:
        projects = get_rf100vl_projects(api_key=args.api_key)
        # Fail with a clear message rather than an IndexError on an empty result.
        if not projects:
            raise ValueError("No RF100-VL projects were returned")
        project = projects[0].rf_project

    train_from_rf_project(project, args.dataset_version)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    trainer()
|
|
|