| from typing import List | |
| from typing import Sequence | |
| from typing import Tuple | |
| from typing import Union | |
| from typeguard import check_argument_types | |
| from typeguard import check_return_type | |
| from espnet2.samplers.abs_sampler import AbsSampler | |
| from espnet2.samplers.folded_batch_sampler import FoldedBatchSampler | |
| from espnet2.samplers.length_batch_sampler import LengthBatchSampler | |
| from espnet2.samplers.num_elements_batch_sampler import NumElementsBatchSampler | |
| from espnet2.samplers.sorted_batch_sampler import SortedBatchSampler | |
| from espnet2.samplers.unsorted_batch_sampler import UnsortedBatchSampler | |
# Help text for every supported batch type, keyed by the value accepted as
# the ``type`` argument of ``build_batch_sampler``. Each value is a
# human-readable description that includes an example of the expected
# key/shape file layout.
BATCH_TYPES = dict(
    unsorted="UnsortedBatchSampler has nothing in particular feature and "
    "just creates mini-batches which has constant batch_size. "
    "This sampler doesn't require any length "
    "information for each feature. "
    "'key_file' is just a text file which describes each sample name."
    "\n\n"
    " utterance_id_a\n"
    " utterance_id_b\n"
    " utterance_id_c\n"
    "\n"
    "The first column is referred, so 'shape file' can be used, too.\n\n"
    " utterance_id_a 100,80\n"
    " utterance_id_b 400,80\n"
    " utterance_id_c 512,80\n",
    sorted="SortedBatchSampler sorts samples by the length of the first input "
    "in order to make each sample in a mini-batch has close length. "
    "This sampler requires a text file which describes the length for each sample. "
    "\n\n"
    " utterance_id_a 1000\n"
    " utterance_id_b 1453\n"
    " utterance_id_c 1241\n"
    "\n"
    "The first element of feature dimensions is referred, "
    "so 'shape_file' can be also used.\n\n"
    " utterance_id_a 1000,80\n"
    " utterance_id_b 1453,80\n"
    " utterance_id_c 1241,80\n",
    folded="FoldedBatchSampler supports variable batch_size. "
    "The batch_size is decided by\n"
    " batch_size = base_batch_size // (L // fold_length)\n"
    "L is referred to the largest length of samples in the mini-batch. "
    "This sampler requires length information as same as SortedBatchSampler\n",
    length="LengthBatchSampler supports variable batch_size. "
    "This sampler makes mini-batches which have same number of 'bins' as possible "
    "counting by the total lengths of each feature in the mini-batch. "
    "This sampler requires a text file which describes the length for each sample. "
    "\n\n"
    " utterance_id_a 1000\n"
    " utterance_id_b 1453\n"
    " utterance_id_c 1241\n"
    "\n"
    "The first element of feature dimensions is referred, "
    "so 'shape_file' can be also used.\n\n"
    " utterance_id_a 1000,80\n"
    " utterance_id_b 1453,80\n"
    " utterance_id_c 1241,80\n",
    numel="NumElementsBatchSampler supports variable batch_size. "
    "Just like LengthBatchSampler, this sampler makes mini-batches"
    " which have same number of 'bins' as possible "
    "counting by the total number of elements of each feature "
    "instead of the length. "
    "Thus this sampler requires the full information of the dimension of the features. "
    "\n\n"
    " utterance_id_a 1000,80\n"
    " utterance_id_b 1453,80\n"
    " utterance_id_c 1241,80\n",
)
def build_batch_sampler(
    type: str,
    batch_size: int,
    batch_bins: int,
    shape_files: Union[Tuple[str, ...], List[str]],
    sort_in_batch: str = "descending",
    sort_batch: str = "ascending",
    drop_last: bool = False,
    min_batch_size: int = 1,
    fold_lengths: Sequence[int] = (),
    padding: bool = True,
    utt2category_file: Union[str, None] = None,
) -> AbsSampler:
    """Helper function to instantiate BatchSampler.

    Selects one of the sampler classes according to ``type`` and forwards
    the subset of arguments that the chosen sampler accepts.

    Args:
        type: mini-batch type. "unsorted", "sorted", "folded", "numel",
            or "length"
        batch_size: The mini-batch size. Used for "unsorted", "sorted",
            "folded" mode
        batch_bins: Used for "numel" and "length" mode
        shape_files: Text files describing the length and dimension
            of each features. e.g. uttA 1330,80
            Must not be empty. Only the first file is used for
            "unsorted" and "sorted" mode.
        sort_in_batch: Forwarded to the sampler for every mode
            except "unsorted"
        sort_batch: Forwarded to the sampler for every mode
            except "unsorted"
        drop_last: Forwarded to the sampler for every mode
        min_batch_size: Used for "folded", "numel" or "length" mode
        fold_lengths: Used for "folded" mode. Must contain as many
            entries as shape_files.
        padding: Whether sequences are input as a padded tensor or not.
            Used for "numel" and "length" mode
        utt2category_file: Optional file path forwarded to the sampler.
            Used for "folded" mode only

    Returns:
        The sampler instance built for the requested mini-batch type.

    Raises:
        ValueError: If shape_files is empty, if the number of
            fold_lengths differs from the number of shape_files in
            "folded" mode, or if type is not a supported value.
    """
    assert check_argument_types()

    if not shape_files:
        raise ValueError("No shape file are given")

    if type == "unsorted":
        # Only sample names are needed, so the first file serves as key file.
        retval = UnsortedBatchSampler(
            batch_size=batch_size, key_file=shape_files[0], drop_last=drop_last
        )

    elif type == "sorted":
        retval = SortedBatchSampler(
            batch_size=batch_size,
            shape_file=shape_files[0],
            sort_in_batch=sort_in_batch,
            sort_batch=sort_batch,
            drop_last=drop_last,
        )

    elif type == "folded":
        # One fold length is required per shape file.
        if len(fold_lengths) != len(shape_files):
            raise ValueError(
                f"The number of fold_lengths must be equal to "
                f"the number of shape_files: "
                f"{len(fold_lengths)} != {len(shape_files)}"
            )
        retval = FoldedBatchSampler(
            batch_size=batch_size,
            shape_files=shape_files,
            fold_lengths=fold_lengths,
            sort_in_batch=sort_in_batch,
            sort_batch=sort_batch,
            drop_last=drop_last,
            min_batch_size=min_batch_size,
            utt2category_file=utt2category_file,
        )

    elif type == "numel":
        retval = NumElementsBatchSampler(
            batch_bins=batch_bins,
            shape_files=shape_files,
            sort_in_batch=sort_in_batch,
            sort_batch=sort_batch,
            drop_last=drop_last,
            padding=padding,
            min_batch_size=min_batch_size,
        )

    elif type == "length":
        retval = LengthBatchSampler(
            batch_bins=batch_bins,
            shape_files=shape_files,
            sort_in_batch=sort_in_batch,
            sort_batch=sort_batch,
            drop_last=drop_last,
            padding=padding,
            min_batch_size=min_batch_size,
        )

    else:
        raise ValueError(f"Not supported: {type}")

    assert check_return_type(retval)
    return retval