# CU-1 / rfdetr / util / drop_scheduler.py
# Matis Despujols
# Upload 97 files
# 066effd verified
# ------------------------------------------------------------------------
# LW-DETR
# Copyright (c) 2024 Baidu. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
"""util for drop scheduler."""
import numpy as np
def drop_scheduler(drop_rate, epochs, niter_per_ep, cutoff_epoch=0, mode='standard', schedule='constant'):
    """Build a per-iteration drop-rate schedule.

    The result has one entry per training iteration
    (``epochs * niter_per_ep`` entries in total).

    Args:
        drop_rate: Drop rate used while dropping is active.
        epochs: Total number of epochs.
        niter_per_ep: Number of iterations in one epoch.
        cutoff_epoch: Epoch at which the schedule switches between its
            "early" and "late" segments. Ignored when ``mode`` is
            ``'standard'``.
        mode: One of:
            ``'standard'`` -- ``drop_rate`` for every iteration;
            ``'early'`` -- dropping active before ``cutoff_epoch``, then 0;
            ``'late'`` -- 0 before ``cutoff_epoch``, then ``drop_rate``.
        schedule: Shape of the active segment in ``'early'`` mode:
            ``'constant'`` keeps ``drop_rate``; ``'linear'`` interpolates
            from ``drop_rate`` down to 0 (the last early iteration is
            exactly 0). ``'late'`` mode only accepts ``'constant'``.
            Not checked in ``'standard'`` mode.

    Returns:
        A 1-D ``numpy.ndarray`` of length ``epochs * niter_per_ep``.

    Raises:
        ValueError: If ``mode`` is unknown, ``schedule`` is not supported
            for the selected mode, or ``cutoff_epoch`` yields a segment of
            negative length.
    """
    # Explicit raises instead of ``assert``: assertions are removed under
    # ``python -O``, which would let bad arguments through and fail later
    # with an unrelated UnboundLocalError.
    if mode not in ('standard', 'early', 'late'):
        raise ValueError(
            f"mode must be 'standard', 'early' or 'late', got {mode!r}"
        )

    total_iters = epochs * niter_per_ep
    if mode == 'standard':
        return np.full(total_iters, drop_rate)

    allowed_schedules = ('constant', 'linear') if mode == 'early' else ('constant',)
    if schedule not in allowed_schedules:
        raise ValueError(
            f"schedule must be one of {allowed_schedules} for mode {mode!r}, "
            f"got {schedule!r}"
        )

    early_iters = cutoff_epoch * niter_per_ep
    late_iters = (epochs - cutoff_epoch) * niter_per_ep
    # A negative segment length would otherwise surface as NumPy's generic
    # "negative dimensions are not allowed" error.
    if early_iters < 0 or late_iters < 0:
        raise ValueError(
            f"cutoff_epoch must lie in [0, epochs]; got cutoff_epoch="
            f"{cutoff_epoch!r} with epochs={epochs!r}"
        )

    if mode == 'early':
        if schedule == 'constant':
            early_schedule = np.full(early_iters, drop_rate)
        else:  # 'linear'
            early_schedule = np.linspace(drop_rate, 0, early_iters)
        late_schedule = np.full(late_iters, 0)
    else:  # 'late'
        early_schedule = np.full(early_iters, 0)
        late_schedule = np.full(late_iters, drop_rate)

    final_schedule = np.concatenate((early_schedule, late_schedule))
    # Internal sanity check only; the two segments sum to total_iters by construction.
    assert len(final_schedule) == total_iters
    return final_schedule