File size: 1,359 Bytes
066effd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# ------------------------------------------------------------------------
# LW-DETR
# Copyright (c) 2024 Baidu. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
"""util for drop scheduler."""
import numpy as np


def drop_scheduler(drop_rate, epochs, niter_per_ep, cutoff_epoch=0, mode='standard', schedule='constant'):
    """Build a per-iteration drop-rate schedule.

    Args:
        drop_rate: Drop rate used while the schedule is "on".
        epochs: Total number of training epochs.
        niter_per_ep: Number of iterations in each epoch.
        cutoff_epoch: Epoch at which the schedule switches from its early
            phase to its late phase. Must lie in ``[0, epochs]``. Ignored
            when ``mode == 'standard'``.
        mode: ``'standard'`` uses ``drop_rate`` for every iteration;
            ``'early'`` applies the drop only during the first
            ``cutoff_epoch`` epochs and 0 afterwards; ``'late'`` uses 0 for
            the first ``cutoff_epoch`` epochs and ``drop_rate`` afterwards.
        schedule: Shape of the "on" phase. In ``'early'`` mode either
            ``'constant'`` or ``'linear'`` (decays from ``drop_rate`` to 0
            across the early phase). ``'late'`` mode only supports
            ``'constant'``. Ignored when ``mode == 'standard'``.

    Returns:
        1-D float ``np.ndarray`` of length ``epochs * niter_per_ep`` holding
        the drop rate for each iteration.

    Raises:
        ValueError: If ``mode`` or ``schedule`` is unsupported, or if
            ``cutoff_epoch`` is outside ``[0, epochs]`` for a non-standard
            mode.
    """
    total_iters = epochs * niter_per_ep

    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    if mode not in ('standard', 'early', 'late'):
        raise ValueError(
            f"mode must be 'standard', 'early' or 'late', got {mode!r}")
    if mode == 'standard':
        return np.full(total_iters, drop_rate, dtype=float)

    # A cutoff beyond the training length (or negative) would otherwise
    # surface as an opaque negative-dimension error from numpy.
    if not 0 <= cutoff_epoch <= epochs:
        raise ValueError(
            f"cutoff_epoch must be in [0, {epochs}], got {cutoff_epoch!r}")

    early_iters = cutoff_epoch * niter_per_ep
    late_iters = total_iters - early_iters

    if mode == 'early':
        if schedule == 'constant':
            early_schedule = np.full(early_iters, drop_rate, dtype=float)
        elif schedule == 'linear':
            early_schedule = np.linspace(drop_rate, 0.0, early_iters, dtype=float)
        else:
            raise ValueError(
                f"schedule must be 'constant' or 'linear' in 'early' mode, "
                f"got {schedule!r}")
        # Float zeros keep the result dtype independent of drop_rate's type.
        late_schedule = np.zeros(late_iters, dtype=float)
    else:  # mode == 'late'
        if schedule != 'constant':
            raise ValueError(
                f"schedule must be 'constant' in 'late' mode, got {schedule!r}")
        early_schedule = np.zeros(early_iters, dtype=float)
        late_schedule = np.full(late_iters, drop_rate, dtype=float)

    final_schedule = np.concatenate((early_schedule, late_schedule))
    # Internal invariant: the two phases exactly tile the training run.
    assert len(final_schedule) == total_iters
    return final_schedule