# CU-1 / rfdetr/util/early_stopping.py
# Uploaded by Matis Despujols — commit 066effd ("Upload 97 files", verified)
"""
Early stopping callback for RF-DETR training
"""
from logging import getLogger
logger = getLogger(__name__)
class EarlyStoppingCallback:
    """
    Early stopping callback that monitors mAP and stops training if no improvement
    over a threshold is observed for a specified number of epochs.

    Args:
        model: Object to notify when early stopping triggers; its
            ``request_early_stop()`` method is called. May be ``None``, in which
            case only the internal state and messages are updated.
        patience (int): Number of epochs with no improvement to wait before stopping
        min_delta (float): Minimum change in mAP to qualify as improvement
        use_ema (bool): Whether to use EMA model metrics for early stopping
        verbose (bool): Whether to print early stopping messages
    """

    def __init__(self, model, patience=5, min_delta=0.001, use_ema=False, verbose=True):
        self.patience = patience
        self.min_delta = min_delta
        self.use_ema = use_ema
        self.verbose = verbose
        # Best mAP seen so far; an epoch counts as an improvement only if it
        # beats this by more than ``min_delta``.
        self.best_map = 0.0
        # Consecutive epochs without improvement.
        self.counter = 0
        self.model = model

    def _select_map(self, log_stats):
        """Choose the mAP value to monitor from ``log_stats``.

        Element ``[0]`` of each bbox-eval entry is used as the mAP (presumably the
        primary COCO AP@[.5:.95] — confirm against the evaluator).

        Returns:
            tuple: ``(current_map, metric_source)``, or ``(None, None)`` when
            neither the regular nor the EMA metric is available.
        """
        regular_map = None
        ema_map = None
        if 'test_coco_eval_bbox' in log_stats:
            regular_map = log_stats['test_coco_eval_bbox'][0]
        if 'ema_test_coco_eval_bbox' in log_stats:
            ema_map = log_stats['ema_test_coco_eval_bbox'][0]

        if regular_map is not None and ema_map is not None:
            if self.use_ema:
                return ema_map, "EMA"
            # Without an explicit EMA preference, track the better of the two.
            return max(regular_map, ema_map), "max(regular, EMA)"
        if ema_map is not None:
            return ema_map, "EMA"
        if regular_map is not None:
            return regular_map, "regular"
        return None, None

    def update(self, log_stats):
        """Update early stopping state based on epoch validation metrics.

        Args:
            log_stats (dict): Per-epoch statistics. Reads
                ``'test_coco_eval_bbox'`` and/or ``'ema_test_coco_eval_bbox'``.

        Raises:
            ValueError: If no mAP metric is present and ``verbose`` is True.
                With ``verbose=False`` a warning is logged instead and the epoch
                is skipped without changing the counter.
        """
        current_map, metric_source = self._select_map(log_stats)
        if current_map is None:
            if self.verbose:
                raise ValueError("No valid mAP metric found!")
            # Previously this returned silently; surface the skipped epoch.
            logger.warning("Early stopping: no valid mAP metric found in log_stats; skipping this epoch")
            return

        if self.verbose:
            print(f"Early stopping: Current mAP ({metric_source}): {current_map:.4f}, Best: {self.best_map:.4f}, Diff: {current_map - self.best_map:.4f}, Min delta: {self.min_delta}")

        if current_map > self.best_map + self.min_delta:
            self.best_map = current_map
            self.counter = 0
            logger.info(f"Early stopping: mAP improved to {current_map:.4f} using {metric_source} metric")
            return

        self.counter += 1
        if self.verbose:
            print(f"Early stopping: No improvement in mAP for {self.counter} epochs (best: {self.best_map:.4f}, current: {current_map:.4f})")

        if self.counter >= self.patience:
            print(f"Early stopping triggered: No improvement above {self.min_delta} threshold for {self.patience} epochs")
            # Explicit None check: a model object may define __len__/__bool__
            # and be falsy while still being a valid target.
            if self.model is not None:
                self.model.request_early_stop()