import torch
import numpy as np
import yaml
import os

from torch.utils.data import Sampler


def load_yaml_with_includes(yaml_file):
    """Load a YAML file, resolving `!include` tags relative to the file's directory."""
    def loader_with_include(loader, node):
        # Resolve the included path relative to the parent YAML file
        include_path = os.path.join(os.path.dirname(yaml_file), loader.construct_scalar(node))
        with open(include_path, 'r') as f:
            return yaml.load(f, Loader=yaml.FullLoader)

    yaml.add_constructor('!include', loader_with_include, Loader=yaml.FullLoader)

    with open(yaml_file, 'r') as f:
        return yaml.load(f, Loader=yaml.FullLoader)
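
# Illustrative usage (the file names below are hypothetical):
#     config = load_yaml_with_includes('configs/train.yaml')
#     # any `key: !include other.yaml` entry is replaced by the parsed content
#     # of other.yaml, resolved relative to configs/.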


def customized_lr_scheduler(optimizer, warmup_steps=10000, decay_steps=1e6, end_factor=1e-4):
    """Linear warmup to the base LR, then linear decay to `end_factor` of it."""
    from torch.optim.lr_scheduler import LinearLR, SequentialLR

    # Warmup phase: ramp the LR linearly from ~0 up to the base value over `warmup_steps`
    warmup_scheduler = LinearLR(optimizer,
                                start_factor=min(1 / warmup_steps, 1),
                                end_factor=1.0,
                                total_iters=warmup_steps)
    # Decay phase: ramp the LR linearly down to `end_factor` over `decay_steps`
    decay_scheduler = LinearLR(optimizer,
                               start_factor=1.0,
                               end_factor=end_factor,
                               total_iters=decay_steps)
    # Chain the two phases, switching at `warmup_steps`
    scheduler = SequentialLR(optimizer,
                             schedulers=[warmup_scheduler, decay_scheduler],
                             milestones=[warmup_steps])
    return scheduler
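
# Sketch of the resulting schedule (toy optimizer; the numbers are illustrative):
#     opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
#     sched = customized_lr_scheduler(opt, warmup_steps=100, decay_steps=1000)
#     # the LR rises linearly to 1e-4 over the first 100 sched.step() calls,
#     # then decays linearly toward 1e-4 * end_factor over the next 1000.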


def get_lr_scheduler(optimizer, name, **kwargs):
    """Build a learning-rate scheduler by name ('customized' or 'cosine')."""
    if name == 'customized':
        return customized_lr_scheduler(optimizer, **kwargs)
    elif name == 'cosine':
        from torch.optim.lr_scheduler import CosineAnnealingLR
        return CosineAnnealingLR(optimizer, **kwargs)
    else:
        raise NotImplementedError(name)
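
# Illustrative usage, given an existing optimizer `opt` (kwargs are examples only):
#     scheduler = get_lr_scheduler(opt, 'customized', warmup_steps=5000, decay_steps=5e5)
#     scheduler = get_lr_scheduler(opt, 'cosine', T_max=100000, eta_min=1e-6)
#     # call scheduler.step() once per training step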


class ConcatDatasetBatchSampler(Sampler):
    """Batch sampler for a `torch.utils.data.ConcatDataset`: each batch draws a
    fixed number of indices from every constituent dataset's sampler."""

    def __init__(self, samplers, batch_sizes, epoch=0):
        self.batch_sizes = batch_sizes
        self.samplers = samplers
        # Starting offset of each sub-dataset inside the concatenated index space
        self.offsets = [0] + np.cumsum([len(x) for x in self.samplers]).tolist()[:-1]

        self.epoch = epoch
        self.set_epoch(self.epoch)

    def _iter_one_dataset(self, c_batch_size, c_sampler, c_offset):
        batch = []
        for idx in c_sampler:
            batch.append(c_offset + idx)
            if len(batch) == c_batch_size:
                yield batch
                batch = []  # start a new batch once the current one is full

    def set_epoch(self, epoch):
        # Forward the epoch to the underlying samplers (e.g. DistributedSampler)
        # so their shuffling changes every epoch.
        if hasattr(self.samplers[0], "epoch"):
            for s in self.samplers:
                s.set_epoch(epoch)

    def __iter__(self):
        iterators = [iter(i) for i in self.samplers]
        tot_batch = []
        for b_num in range(len(self)):
            # Take `batch_sizes[samp_idx]` indices from each sampler and concatenate them
            for samp_idx in range(len(self.samplers)):
                c_batch = []
                while len(c_batch) < self.batch_sizes[samp_idx]:
                    c_batch.append(self.offsets[samp_idx] + next(iterators[samp_idx]))
                tot_batch.extend(c_batch)
            yield tot_batch
            tot_batch = []

    def __len__(self):
        # The number of batches is limited by the sampler that runs out first
        min_len = float("inf")
        for idx, sampler in enumerate(self.samplers):
            c_len = len(sampler) // self.batch_sizes[idx]
            min_len = min(c_len, min_len)
        return min_len
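
# Illustrative usage (dataset names and batch sizes below are hypothetical):
#     from torch.utils.data import ConcatDataset, DataLoader, RandomSampler
#     datasets = [ds_a, ds_b]                          # two map-style datasets
#     samplers = [RandomSampler(d) for d in datasets]
#     batch_sampler = ConcatDatasetBatchSampler(samplers, batch_sizes=[12, 4])
#     loader = DataLoader(ConcatDataset(datasets), batch_sampler=batch_sampler)
#     # every batch then contains 12 items from ds_a and 4 items from ds_b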