File size: 12,868 Bytes
066effd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 |
# ------------------------------------------------------------------------
# RF-DETR
# Copyright (c) 2025 Roboflow. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
# Copyright (c) 2024 Baidu. All Rights Reserved.
# ------------------------------------------------------------------------
# Conditional DETR
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Copied from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
"""
Train and eval functions used in main.py
"""
import math
import sys
from typing import Iterable
import random
import torch
import torch.nn.functional as F
import rfdetr.util.misc as utils
from rfdetr.datasets.coco_eval import CocoEvaluator
from rfdetr.datasets.coco import compute_multi_scale_scales
try:
from torch.amp import autocast, GradScaler
DEPRECATED_AMP = False
except ImportError:
from torch.cuda.amp import autocast, GradScaler
DEPRECATED_AMP = True
from typing import DefaultDict, List, Callable
from rfdetr.util.misc import NestedTensor
import numpy as np
def get_autocast_args(args):
if DEPRECATED_AMP:
return {'enabled': args.amp, 'dtype': torch.bfloat16}
else:
return {'device_type': 'cuda', 'enabled': args.amp, 'dtype': torch.bfloat16}
def train_one_epoch(
model: torch.nn.Module,
criterion: torch.nn.Module,
lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
data_loader: Iterable,
optimizer: torch.optim.Optimizer,
device: torch.device,
epoch: int,
batch_size: int,
max_norm: float = 0,
ema_m: torch.nn.Module = None,
schedules: dict = {},
num_training_steps_per_epoch=None,
vit_encoder_num_layers=None,
args=None,
callbacks: DefaultDict[str, List[Callable]] = None,
):
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
metric_logger.add_meter(
"class_error", utils.SmoothedValue(window_size=1, fmt="{value:.2f}")
)
header = "Epoch: [{}]".format(epoch)
print_freq = 10
start_steps = epoch * num_training_steps_per_epoch
print("Grad accum steps: ", args.grad_accum_steps)
print("Total batch size: ", batch_size * utils.get_world_size())
# Add gradient scaler for AMP
if DEPRECATED_AMP:
scaler = GradScaler(enabled=args.amp)
else:
scaler = GradScaler('cuda', enabled=args.amp)
optimizer.zero_grad()
assert batch_size % args.grad_accum_steps == 0
sub_batch_size = batch_size // args.grad_accum_steps
print("LENGTH OF DATA LOADER:", len(data_loader))
for data_iter_step, (samples, targets) in enumerate(
metric_logger.log_every(data_loader, print_freq, header)
):
it = start_steps + data_iter_step
callback_dict = {
"step": it,
"model": model,
"epoch": epoch,
}
for callback in callbacks["on_train_batch_start"]:
callback(callback_dict)
if "dp" in schedules:
if args.distributed:
model.module.update_drop_path(
schedules["dp"][it], vit_encoder_num_layers
)
else:
model.update_drop_path(schedules["dp"][it], vit_encoder_num_layers)
if "do" in schedules:
if args.distributed:
model.module.update_dropout(schedules["do"][it])
else:
model.update_dropout(schedules["do"][it])
if args.multi_scale and not args.do_random_resize_via_padding:
scales = compute_multi_scale_scales(args.resolution, args.expanded_scales, args.patch_size, args.num_windows)
random.seed(it)
scale = random.choice(scales)
with torch.inference_mode():
samples.tensors = F.interpolate(samples.tensors, size=scale, mode='bilinear', align_corners=False)
samples.mask = F.interpolate(samples.mask.unsqueeze(1).float(), size=scale, mode='nearest').squeeze(1).bool()
for i in range(args.grad_accum_steps):
start_idx = i * sub_batch_size
final_idx = start_idx + sub_batch_size
new_samples_tensors = samples.tensors[start_idx:final_idx]
new_samples = NestedTensor(new_samples_tensors, samples.mask[start_idx:final_idx])
new_samples = new_samples.to(device)
new_targets = [{k: v.to(device) for k, v in t.items()} for t in targets[start_idx:final_idx]]
with autocast(**get_autocast_args(args)):
outputs = model(new_samples, new_targets)
loss_dict = criterion(outputs, new_targets)
weight_dict = criterion.weight_dict
losses = sum(
(1 / args.grad_accum_steps) * loss_dict[k] * weight_dict[k]
for k in loss_dict.keys()
if k in weight_dict
)
scaler.scale(losses).backward()
# reduce losses over all GPUs for logging purposes
loss_dict_reduced = utils.reduce_dict(loss_dict)
loss_dict_reduced_unscaled = {
f"{k}_unscaled": v for k, v in loss_dict_reduced.items()
}
loss_dict_reduced_scaled = {
k: v * weight_dict[k]
for k, v in loss_dict_reduced.items()
if k in weight_dict
}
losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())
loss_value = losses_reduced_scaled.item()
if not math.isfinite(loss_value):
print(loss_dict_reduced)
raise ValueError("Loss is {}, stopping training".format(loss_value))
if max_norm > 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
scaler.step(optimizer)
scaler.update()
lr_scheduler.step()
optimizer.zero_grad()
if ema_m is not None:
if epoch >= 0:
ema_m.update(model)
metric_logger.update(
loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled
)
metric_logger.update(class_error=loss_dict_reduced["class_error"])
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
# gather the stats from all processes
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
def coco_extended_metrics(coco_eval):
"""
Safe version: ignores the –1 sentinel entries so precision/F1 never explode.
"""
iou_thrs, rec_thrs = coco_eval.params.iouThrs, coco_eval.params.recThrs
iou50_idx, area_idx, maxdet_idx = (
int(np.argwhere(np.isclose(iou_thrs, 0.50))), 0, 2)
P = coco_eval.eval["precision"]
S = coco_eval.eval["scores"]
prec_raw = P[iou50_idx, :, :, area_idx, maxdet_idx]
prec = prec_raw.copy().astype(float)
prec[prec < 0] = np.nan
f1_cls = 2 * prec * rec_thrs[:, None] / (prec + rec_thrs[:, None])
f1_macro = np.nanmean(f1_cls, axis=1)
best_j = int(f1_macro.argmax())
macro_precision = float(np.nanmean(prec[best_j]))
macro_recall = float(rec_thrs[best_j])
macro_f1 = float(f1_macro[best_j])
score_vec = S[iou50_idx, best_j, :, area_idx, maxdet_idx].astype(float)
score_vec[prec_raw[best_j] < 0] = np.nan
score_thr = float(np.nanmean(score_vec))
map_50_95, map_50 = float(coco_eval.stats[0]), float(coco_eval.stats[1])
per_class = []
cat_ids = coco_eval.params.catIds
cat_id_to_name = {c["id"]: c["name"] for c in coco_eval.cocoGt.loadCats(cat_ids)}
for k, cid in enumerate(cat_ids):
p_slice = P[:, :, k, area_idx, maxdet_idx]
valid = p_slice > -1
ap_50_95 = float(p_slice[valid].mean()) if valid.any() else float("nan")
ap_50 = float(p_slice[iou50_idx][p_slice[iou50_idx] > -1].mean()) if (p_slice[iou50_idx] > -1).any() else float("nan")
pc = float(prec[best_j, k]) if prec_raw[best_j, k] > -1 else float("nan")
rc = macro_recall
#Doing to this to filter out dataset class
if np.isnan(ap_50_95) or np.isnan(ap_50) or np.isnan(pc) or np.isnan(rc):
continue
per_class.append({
"class" : cat_id_to_name[int(cid)],
"map@50:95" : ap_50_95,
"map@50" : ap_50,
"precision" : pc,
"recall" : rc,
})
per_class.append({
"class" : "all",
"map@50:95" : map_50_95,
"map@50" : map_50,
"precision" : macro_precision,
"recall" : macro_recall,
})
return {
"class_map": per_class,
"map" : map_50,
"precision": macro_precision,
"recall" : macro_recall
}
def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, args=None):
model.eval()
if args.fp16_eval:
model.half()
criterion.eval()
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter(
"class_error", utils.SmoothedValue(window_size=1, fmt="{value:.2f}")
)
header = "Test:"
iou_types = tuple(k for k in ("segm", "bbox") if k in postprocessors.keys())
coco_evaluator = CocoEvaluator(base_ds, iou_types)
for samples, targets in metric_logger.log_every(data_loader, 10, header):
samples = samples.to(device)
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
if args.fp16_eval:
samples.tensors = samples.tensors.half()
# Add autocast for evaluation
with autocast(**get_autocast_args(args)):
outputs = model(samples)
if args.fp16_eval:
for key in outputs.keys():
if key == "enc_outputs":
for sub_key in outputs[key].keys():
outputs[key][sub_key] = outputs[key][sub_key].float()
elif key == "aux_outputs":
for idx in range(len(outputs[key])):
for sub_key in outputs[key][idx].keys():
outputs[key][idx][sub_key] = outputs[key][idx][
sub_key
].float()
else:
outputs[key] = outputs[key].float()
loss_dict = criterion(outputs, targets)
weight_dict = criterion.weight_dict
# reduce losses over all GPUs for logging purposes
loss_dict_reduced = utils.reduce_dict(loss_dict)
loss_dict_reduced_scaled = {
k: v * weight_dict[k]
for k, v in loss_dict_reduced.items()
if k in weight_dict
}
loss_dict_reduced_unscaled = {
f"{k}_unscaled": v for k, v in loss_dict_reduced.items()
}
metric_logger.update(
loss=sum(loss_dict_reduced_scaled.values()),
**loss_dict_reduced_scaled,
**loss_dict_reduced_unscaled,
)
metric_logger.update(class_error=loss_dict_reduced["class_error"])
orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
results = postprocessors["bbox"](outputs, orig_target_sizes)
res = {
target["image_id"].item(): output
for target, output in zip(targets, results)
}
if coco_evaluator is not None:
coco_evaluator.update(res)
# gather the stats from all processes
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
if coco_evaluator is not None:
coco_evaluator.synchronize_between_processes()
# accumulate predictions from all images
if coco_evaluator is not None:
coco_evaluator.accumulate()
coco_evaluator.summarize()
stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
if coco_evaluator is not None:
results_json = coco_extended_metrics(coco_evaluator.coco_eval["bbox"])
stats["results_json"] = results_json
if "bbox" in postprocessors.keys():
stats["coco_eval_bbox"] = coco_evaluator.coco_eval["bbox"].stats.tolist()
if "segm" in postprocessors.keys():
stats["coco_eval_masks"] = coco_evaluator.coco_eval["segm"].stats.tolist()
return stats, coco_evaluator |