|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""util for drop scheduler.""" |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
def drop_scheduler(drop_rate, epochs, niter_per_ep, cutoff_epoch=0, mode='standard', schedule='constant'):
    """Build a per-iteration drop-rate schedule.

    The returned array has one entry per training iteration
    (``epochs * niter_per_ep`` entries in total).

    Args:
        drop_rate: Drop rate used while dropping is active.
        epochs: Total number of epochs.
        niter_per_ep: Number of iterations in one epoch.
        cutoff_epoch: Epoch at which the schedule switches phase. Ignored
            when ``mode == 'standard'``.
        mode: One of
            ``'standard'`` - ``drop_rate`` for every iteration;
            ``'early'``    - dropping is active for the first
                             ``cutoff_epoch`` epochs, then the rate is 0;
            ``'late'``     - the rate is 0 for the first ``cutoff_epoch``
                             epochs, then ``drop_rate``.
        schedule: Shape of the active phase. ``'constant'`` keeps the rate
            at ``drop_rate``; ``'linear'`` (``'early'`` mode only) ramps
            from ``drop_rate`` down to 0 over the early phase. Not checked
            when ``mode == 'standard'``.

    Returns:
        A 1-D float64 ``numpy.ndarray`` of length ``epochs * niter_per_ep``.

    Raises:
        ValueError: If ``mode`` is unknown, ``schedule`` is not supported
            for the chosen mode, or ``cutoff_epoch`` yields a negative
            number of iterations for either phase.
    """
    # Explicit raises instead of ``assert``: assertions are removed under
    # ``python -O``, which would let invalid arguments slip through.
    if mode not in ('standard', 'early', 'late'):
        raise ValueError(
            f"mode must be 'standard', 'early' or 'late', got {mode!r}"
        )

    total_iters = epochs * niter_per_ep

    if mode == 'standard':
        # dtype is pinned so the result is float regardless of whether
        # drop_rate was passed as an int or a float.
        return np.full(total_iters, drop_rate, dtype=np.float64)

    allowed_schedules = ('constant', 'linear') if mode == 'early' else ('constant',)
    if schedule not in allowed_schedules:
        raise ValueError(
            f"schedule must be one of {allowed_schedules} for mode {mode!r}, "
            f"got {schedule!r}"
        )

    early_iters = cutoff_epoch * niter_per_ep
    late_iters = (epochs - cutoff_epoch) * niter_per_ep
    # A negative count would otherwise fail inside NumPy with an opaque
    # "negative dimensions" ValueError; report the actual cause instead.
    if early_iters < 0 or late_iters < 0:
        raise ValueError(
            f"cutoff_epoch={cutoff_epoch!r} is out of range for epochs={epochs!r} "
            f"(early phase: {early_iters} iterations, late phase: {late_iters})"
        )

    if mode == 'early':
        if schedule == 'constant':
            early_schedule = np.full(early_iters, drop_rate, dtype=np.float64)
        else:  # 'linear'
            early_schedule = np.linspace(drop_rate, 0, early_iters)
        late_schedule = np.zeros(late_iters, dtype=np.float64)
    else:  # mode == 'late'
        early_schedule = np.zeros(early_iters, dtype=np.float64)
        late_schedule = np.full(late_iters, drop_rate, dtype=np.float64)

    final_schedule = np.concatenate((early_schedule, late_schedule))

    # Internal invariant, not input validation: both phases together must
    # cover every iteration.
    assert len(final_schedule) == total_iters
    return final_schedule
|
|
|