File size: 2,938 Bytes
066effd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
"""
Early stopping callback for RF-DETR training
"""
from logging import getLogger
logger = getLogger(__name__)
class EarlyStoppingCallback:
    """
    Early stopping callback that monitors mAP and stops training if no improvement
    over a threshold is observed for a specified number of epochs.

    Args:
        model: Object notified when early stopping triggers; it must expose a
            ``request_early_stop()`` method. May be ``None``, in which case the
            callback only tracks state and reports.
        patience (int): Number of epochs with no improvement to wait before stopping
        min_delta (float): Minimum change in mAP to qualify as improvement
        use_ema (bool): Whether to use EMA model metrics for early stopping
        verbose (bool): Whether to print early stopping messages
    """

    def __init__(self, model, patience=5, min_delta=0.001, use_ema=False, verbose=True):
        self.patience = patience
        self.min_delta = min_delta
        self.use_ema = use_ema
        self.verbose = verbose
        # Best mAP seen so far; a new epoch must exceed it by more than min_delta.
        self.best_map = 0.0
        # Number of consecutive epochs without a sufficient improvement.
        self.counter = 0
        self.model = model

    def _select_map(self, log_stats):
        """Return ``(mAP, source_label)`` chosen from ``log_stats``, or ``None``.

        Reads the first element of ``test_coco_eval_bbox`` (regular model) and of
        ``ema_test_coco_eval_bbox`` (EMA model). When both are present, the EMA
        value is used if ``use_ema`` is set, otherwise the larger of the two.
        """
        regular_map = None
        ema_map = None
        if 'test_coco_eval_bbox' in log_stats:
            regular_map = log_stats['test_coco_eval_bbox'][0]
        if 'ema_test_coco_eval_bbox' in log_stats:
            ema_map = log_stats['ema_test_coco_eval_bbox'][0]

        if regular_map is not None and ema_map is not None:
            if self.use_ema:
                return ema_map, "EMA"
            return max(regular_map, ema_map), "max(regular, EMA)"
        if ema_map is not None:
            return ema_map, "EMA"
        if regular_map is not None:
            return regular_map, "regular"
        return None

    def update(self, log_stats):
        """Update early stopping state based on epoch validation metrics.

        Args:
            log_stats (dict): Per-epoch stats. The first element of
                ``test_coco_eval_bbox`` and/or ``ema_test_coco_eval_bbox`` is read
                as the mAP value.

        Raises:
            ValueError: If neither metric is present and ``verbose`` is True. When
                ``verbose`` is False, a warning is logged and the epoch is skipped.
        """
        selected = self._select_map(log_stats)
        if selected is None:
            if self.verbose:
                raise ValueError("No valid mAP metric found!")
            # Skip this epoch, but leave a trace instead of failing silently.
            logger.warning("Early stopping: no valid mAP metric found in log_stats; skipping epoch")
            return
        current_map, metric_source = selected

        if self.verbose:
            print(f"Early stopping: Current mAP ({metric_source}): {current_map:.4f}, Best: {self.best_map:.4f}, Diff: {current_map - self.best_map:.4f}, Min delta: {self.min_delta}")

        if current_map > self.best_map + self.min_delta:
            self.best_map = current_map
            self.counter = 0
            logger.info("Early stopping: mAP improved to %.4f using %s metric", current_map, metric_source)
            return

        self.counter += 1
        if self.verbose:
            print(f"Early stopping: No improvement in mAP for {self.counter} epochs (best: {self.best_map:.4f}, current: {current_map:.4f})")

        if self.counter >= self.patience:
            message = f"Early stopping triggered: No improvement above {self.min_delta} threshold for {self.patience} epochs"
            # Honour `verbose` for printing, but never lose the trigger notice entirely.
            if self.verbose:
                print(message)
            else:
                logger.info(message)
            # Identity check, not truthiness: a model defining __len__/__bool__
            # could evaluate falsy and would then never be asked to stop.
            if self.model is not None:
                self.model.request_early_stop()