import logging
from typing import Callable
from typing import Collection
from typing import Iterator
from typing import Optional

import numpy as np
from typeguard import check_argument_types

from espnet2.iterators.abs_iter_factory import AbsIterFactory
class MultipleIterFactory(AbsIterFactory):
    """Chain the iterators of several iterator factories into one stream.

    Each entry of ``build_funcs`` is a zero-argument callable that returns an
    ``AbsIterFactory``. The factories are constructed lazily, one at a time,
    only when iteration reaches them. Their iterators are then yielded from
    in sequence.
    """

    def __init__(
        self,
        build_funcs: Collection[Callable[[], AbsIterFactory]],
        seed: int = 0,
        shuffle: bool = False,
    ):
        """Store the sub-factory builders and the default shuffle settings.

        Args:
            build_funcs: Callables that each build one sub-factory.
            seed: Offset added to the epoch when seeding the order shuffle.
            shuffle: Default shuffle flag, used when ``build_iter`` is called
                with ``shuffle=None``.
        """
        assert check_argument_types()
        # Copy so later mutation of the caller's collection does not leak in.
        self.build_funcs = list(build_funcs)
        self.seed = seed
        self.shuffle = shuffle

    def build_iter(self, epoch: int, shuffle: Optional[bool] = None) -> Iterator:
        """Yield every item from each sub-factory's iterator for ``epoch``.

        Args:
            epoch: Epoch number. It is forwarded to each sub-factory and,
                together with ``self.seed``, seeds the order shuffle.
            shuffle: Whether to shuffle the order of the sub-factories. The
                flag is also forwarded to each sub-factory's ``build_iter``.
                ``None`` falls back to ``self.shuffle``.

        Raises:
            TypeError: If a build function returns something that is not an
                ``AbsIterFactory``.
        """
        if shuffle is None:
            shuffle = self.shuffle

        # Shuffle a copy so the stored order in self.build_funcs is untouched.
        build_funcs = list(self.build_funcs)
        if shuffle:
            # Seeding with epoch + seed makes the order reproducible per epoch.
            np.random.RandomState(epoch + self.seed).shuffle(build_funcs)

        for i, build_func in enumerate(build_funcs):
            logging.info(f"Building {i}th iter-factory...")
            # Built here, inside the generator, so only one factory is
            # constructed at a time as iteration progresses.
            iter_factory = build_func()
            # Explicit raise rather than assert so the check survives `-O`.
            if not isinstance(iter_factory, AbsIterFactory):
                raise TypeError(
                    f"Expected AbsIterFactory, but got {type(iter_factory)}"
                )
            yield from iter_factory.build_iter(epoch, shuffle)