""" |
|
|
Train and eval functions used in main.py |
|
|
""" |
|
|
import math |
|
|
import sys |
|
|
from typing import Iterable |
|
|
import random |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
|
|
|
import rfdetr.util.misc as utils |
|
|
from rfdetr.datasets.coco_eval import CocoEvaluator |
|
|
from rfdetr.datasets.coco import compute_multi_scale_scales |
|
|
|
|
|
try: |
|
|
from torch.amp import autocast, GradScaler |
|
|
DEPRECATED_AMP = False |
|
|
except ImportError: |
|
|
from torch.cuda.amp import autocast, GradScaler |
|
|
DEPRECATED_AMP = True |
|
|
from typing import DefaultDict, List, Callable |
|
|
from rfdetr.util.misc import NestedTensor |
|
|
import numpy as np |
|
|
|
|
|
def get_autocast_args(args):
    """Build autocast kwargs that work with both the new and deprecated AMP APIs."""
    if DEPRECATED_AMP:
        return {'enabled': args.amp, 'dtype': torch.bfloat16}
    else:
        return {'device_type': 'cuda', 'enabled': args.amp, 'dtype': torch.bfloat16}
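
# Usage sketch (hypothetical `args` namespace with a boolean `amp` attribute):
#
#     with autocast(**get_autocast_args(args)):
#         outputs = model(samples)

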
def train_one_epoch(
    model: torch.nn.Module,
    criterion: torch.nn.Module,
    lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
    data_loader: Iterable,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    epoch: int,
    batch_size: int,
    max_norm: float = 0,
    ema_m: torch.nn.Module = None,
    schedules: dict = {},
    num_training_steps_per_epoch=None,
    vit_encoder_num_layers=None,
    args=None,
    callbacks: DefaultDict[str, List[Callable]] = None,
):
    """Run one training epoch with AMP, gradient accumulation, and optional EMA."""
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    metric_logger.add_meter(
        "class_error", utils.SmoothedValue(window_size=1, fmt="{value:.2f}")
    )
    header = "Epoch: [{}]".format(epoch)
    print_freq = 10
    start_steps = epoch * num_training_steps_per_epoch

    print("Grad accum steps: ", args.grad_accum_steps)
    print("Total batch size: ", batch_size * utils.get_world_size())

    if DEPRECATED_AMP:
        scaler = GradScaler(enabled=args.amp)
    else:
        scaler = GradScaler('cuda', enabled=args.amp)

    optimizer.zero_grad()
    assert batch_size % args.grad_accum_steps == 0, (
        "batch_size must be divisible by grad_accum_steps"
    )
    sub_batch_size = batch_size // args.grad_accum_steps
    print("Length of data loader:", len(data_loader))

    for data_iter_step, (samples, targets) in enumerate(
        metric_logger.log_every(data_loader, print_freq, header)
    ):
        it = start_steps + data_iter_step
        callback_dict = {
            "step": it,
            "model": model,
            "epoch": epoch,
        }
        # `callbacks` defaults to None; guard so the default is actually usable.
        if callbacks is not None:
            for callback in callbacks["on_train_batch_start"]:
                callback(callback_dict)
        if "dp" in schedules:
            if args.distributed:
                model.module.update_drop_path(
                    schedules["dp"][it], vit_encoder_num_layers
                )
            else:
                model.update_drop_path(schedules["dp"][it], vit_encoder_num_layers)
        if "do" in schedules:
            if args.distributed:
                model.module.update_dropout(schedules["do"][it])
            else:
                model.update_dropout(schedules["do"][it])

        if args.multi_scale and not args.do_random_resize_via_padding:
            scales = compute_multi_scale_scales(args.resolution, args.expanded_scales, args.patch_size, args.num_windows)
            # Seed with the global step so every distributed rank picks the same scale.
            random.seed(it)
            scale = random.choice(scales)
            # no_grad (rather than inference_mode) so the resized tensors can
            # still feed the autograd-recorded forward pass below.
            with torch.no_grad():
                samples.tensors = F.interpolate(samples.tensors, size=scale, mode='bilinear', align_corners=False)
                samples.mask = F.interpolate(samples.mask.unsqueeze(1).float(), size=scale, mode='nearest').squeeze(1).bool()

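        # Split the batch into grad_accum_steps sub-batches; their gradients
        # accumulate in .grad until the single optimizer step further down.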
        for i in range(args.grad_accum_steps):
            start_idx = i * sub_batch_size
            final_idx = start_idx + sub_batch_size
            new_samples_tensors = samples.tensors[start_idx:final_idx]
            new_samples = NestedTensor(new_samples_tensors, samples.mask[start_idx:final_idx])
            new_samples = new_samples.to(device)
            new_targets = [{k: v.to(device) for k, v in t.items()} for t in targets[start_idx:final_idx]]

            with autocast(**get_autocast_args(args)):
                outputs = model(new_samples, new_targets)
                loss_dict = criterion(outputs, new_targets)
                weight_dict = criterion.weight_dict
                # Scale each sub-batch loss by 1/grad_accum_steps so the
                # accumulated gradient matches one full-batch step.
                losses = sum(
                    (1 / args.grad_accum_steps) * loss_dict[k] * weight_dict[k]
                    for k in loss_dict.keys()
                    if k in weight_dict
                )

            scaler.scale(losses).backward()

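        # Note: only the last sub-batch's loss_dict is logged below; reduce_dict
        # averages it across distributed ranks so all processes log the same values.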
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_unscaled = {
            f"{k}_unscaled": v for k, v in loss_dict_reduced.items()
        }
        loss_dict_reduced_scaled = {
            k: v * weight_dict[k]
            for k, v in loss_dict_reduced.items()
            if k in weight_dict
        }
        losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())

        loss_value = losses_reduced_scaled.item()

        if not math.isfinite(loss_value):
            print(loss_dict_reduced)
            raise ValueError("Loss is {}, stopping training".format(loss_value))

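        # Unscale before clipping so max_norm applies to the true gradients,
        # not the loss-scaled ones.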
        if max_norm > 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

        scaler.step(optimizer)
        scaler.update()
        lr_scheduler.step()
        optimizer.zero_grad()
        if ema_m is not None:
            ema_m.update(model)
        metric_logger.update(
            loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled
        )
        metric_logger.update(class_error=loss_dict_reduced["class_error"])
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


def coco_extended_metrics(coco_eval):
    """
    Safe extension of the COCO metrics: ignores the -1 sentinel entries in the
    precision tensor so precision/F1 never explode.
    """
    iou_thrs, rec_thrs = coco_eval.params.iouThrs, coco_eval.params.recThrs
    # area_idx 0 -> "all" areas; maxdet_idx 2 -> maxDets = 100
    iou50_idx, area_idx, maxdet_idx = (
        int(np.argwhere(np.isclose(iou_thrs, 0.50))), 0, 2)

P = coco_eval.eval["precision"] |
|
|
S = coco_eval.eval["scores"] |
|
|
|
|
|
prec_raw = P[iou50_idx, :, :, area_idx, maxdet_idx] |
|
|
|
|
|
prec = prec_raw.copy().astype(float) |
|
|
prec[prec < 0] = np.nan |
|
|
|
|
|
f1_cls = 2 * prec * rec_thrs[:, None] / (prec + rec_thrs[:, None]) |
|
|
f1_macro = np.nanmean(f1_cls, axis=1) |
|
|
|
|
|
best_j = int(f1_macro.argmax()) |
|
|
|
|
|
macro_precision = float(np.nanmean(prec[best_j])) |
|
|
macro_recall = float(rec_thrs[best_j]) |
|
|
macro_f1 = float(f1_macro[best_j]) |
|
|
|
|
|
score_vec = S[iou50_idx, best_j, :, area_idx, maxdet_idx].astype(float) |
|
|
score_vec[prec_raw[best_j] < 0] = np.nan |
|
|
score_thr = float(np.nanmean(score_vec)) |
|
|
|
|
|
map_50_95, map_50 = float(coco_eval.stats[0]), float(coco_eval.stats[1]) |
|
|
|
|
|
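    # Per-class table: AP@[.5:.95] and AP@.5 from the precision tensor, plus
    # precision/recall at the macro-F1 operating point chosen above.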
    per_class = []
    cat_ids = coco_eval.params.catIds
    cat_id_to_name = {c["id"]: c["name"] for c in coco_eval.cocoGt.loadCats(cat_ids)}
    for k, cid in enumerate(cat_ids):
        p_slice = P[:, :, k, area_idx, maxdet_idx]
        valid = p_slice > -1
        ap_50_95 = float(p_slice[valid].mean()) if valid.any() else float("nan")
        ap_50 = float(p_slice[iou50_idx][p_slice[iou50_idx] > -1].mean()) if (p_slice[iou50_idx] > -1).any() else float("nan")

        pc = float(prec[best_j, k]) if prec_raw[best_j, k] > -1 else float("nan")
        rc = macro_recall

        # Skip classes with no valid entries (e.g. absent from the ground truth).
        if np.isnan(ap_50_95) or np.isnan(ap_50) or np.isnan(pc) or np.isnan(rc):
            continue

        per_class.append({
            "class"     : cat_id_to_name[int(cid)],
            "map@50:95" : ap_50_95,
            "map@50"    : ap_50,
            "precision" : pc,
            "recall"    : rc,
        })

    per_class.append({
        "class"     : "all",
        "map@50:95" : map_50_95,
        "map@50"    : map_50,
        "precision" : macro_precision,
        "recall"    : macro_recall,
    })

    return {
        "class_map": per_class,
        "map"      : map_50,
        "precision": macro_precision,
        "recall"   : macro_recall,
    }
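

# Usage sketch (hypothetical variable names; assumes evaluate() below has
# already run accumulate()/summarize() on the evaluator):
#
#     metrics = coco_extended_metrics(coco_evaluator.coco_eval["bbox"])
#     print(metrics["map"], metrics["precision"], metrics["recall"])

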
def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, args=None):
    """Run COCO evaluation; returns (stats dict, CocoEvaluator)."""
    model.eval()
    if args.fp16_eval:
        model.half()
    criterion.eval()

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        "class_error", utils.SmoothedValue(window_size=1, fmt="{value:.2f}")
    )
    header = "Test:"

    iou_types = tuple(k for k in ("segm", "bbox") if k in postprocessors.keys())
    coco_evaluator = CocoEvaluator(base_ds, iou_types)

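    # Accumulate detections batch-by-batch; COCO metrics are computed once at the end.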
    for samples, targets in metric_logger.log_every(data_loader, 10, header):
        samples = samples.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        if args.fp16_eval:
            samples.tensors = samples.tensors.half()

        with autocast(**get_autocast_args(args)):
            outputs = model(samples)

        if args.fp16_eval:
            # Cast outputs back to fp32 so the loss and postprocessing run in
            # full precision.
            for key in outputs.keys():
                if key == "enc_outputs":
                    for sub_key in outputs[key].keys():
                        outputs[key][sub_key] = outputs[key][sub_key].float()
                elif key == "aux_outputs":
                    for idx in range(len(outputs[key])):
                        for sub_key in outputs[key][idx].keys():
                            outputs[key][idx][sub_key] = outputs[key][idx][sub_key].float()
                else:
                    outputs[key] = outputs[key].float()

        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict

        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_scaled = {
            k: v * weight_dict[k]
            for k, v in loss_dict_reduced.items()
            if k in weight_dict
        }
        loss_dict_reduced_unscaled = {
            f"{k}_unscaled": v for k, v in loss_dict_reduced.items()
        }
        metric_logger.update(
            loss=sum(loss_dict_reduced_scaled.values()),
            **loss_dict_reduced_scaled,
            **loss_dict_reduced_unscaled,
        )
        metric_logger.update(class_error=loss_dict_reduced["class_error"])

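        # Convert outputs to absolute boxes in the original image sizes, keyed
        # by image_id for the COCO evaluator.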
orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0) |
|
|
results = postprocessors["bbox"](outputs, orig_target_sizes) |
|
|
res = { |
|
|
target["image_id"].item(): output |
|
|
for target, output in zip(targets, results) |
|
|
} |
|
|
if coco_evaluator is not None: |
|
|
coco_evaluator.update(res) |
|
|
|
|
|
|
|
|
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    if coco_evaluator is not None:
        coco_evaluator.synchronize_between_processes()
        coco_evaluator.accumulate()
        coco_evaluator.summarize()
    stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
    if coco_evaluator is not None:
        results_json = coco_extended_metrics(coco_evaluator.coco_eval["bbox"])
        stats["results_json"] = results_json
        if "bbox" in postprocessors.keys():
            stats["coco_eval_bbox"] = coco_evaluator.coco_eval["bbox"].stats.tolist()
        if "segm" in postprocessors.keys():
            stats["coco_eval_masks"] = coco_evaluator.coco_eval["segm"].stats.tolist()
    return stats, coco_evaluator
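

# Typical call sites in main.py (hypothetical names, shown for orientation only):
#
#     train_stats = train_one_epoch(
#         model, criterion, lr_scheduler, data_loader_train, optimizer, device,
#         epoch, batch_size=args.batch_size, max_norm=args.clip_max_norm,
#         num_training_steps_per_epoch=len(data_loader_train), args=args,
#     )
#     test_stats, coco_evaluator = evaluate(
#         model, criterion, postprocessors, data_loader_val, base_ds, device, args=args,
#     )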