# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
#
# ## Citations
#
# ```bibtex
# @inproceedings{yao2021wenet,
#   title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
#   author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
#   booktitle={Proc. Interspeech},
#   year={2021},
#   address={Brno, Czech Republic},
#   organization={IEEE}
# }
#
# @article{zhang2022wenet,
#   title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
#   author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
#   journal={arXiv preprint arXiv:2203.15455},
#   year={2022}
# }
# ```

import json
import logging
import random
import re
import tarfile
from subprocess import PIPE, Popen
from urllib.parse import urlparse

import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from torch.nn.utils.rnn import pad_sequence

AUDIO_FORMAT_SETS = {"flac", "mp3", "m4a", "ogg", "opus", "wav", "wma"}


def url_opener(data):
    """Given a URL or a local file path, open it and attach a byte stream.

    In-place operation.

    Args:
        data: Iterable[{src}], where src is a URL or a local file path

    Returns:
        Iterable[{src, stream}]
    """
    for sample in data:
        assert "src" in sample
        # TODO(Binbin Zhang): support HTTP
        url = sample["src"]
        try:
            pr = urlparse(url)
            # local file
            if pr.scheme == "" or pr.scheme == "file":
                stream = open(url, "rb")
            # network file, such as HTTP(HDFS/OSS/S3)/HTTPS/SCP
            else:
                cmd = f"wget -q -O - {url}"
                process = Popen(cmd, shell=True, stdout=PIPE)
                sample.update(process=process)
                stream = process.stdout
            sample.update(stream=stream)
            yield sample
        except Exception as ex:
            logging.warning("Failed to open {}: {}".format(url, ex))


def tar_file_and_group(data):
    """Expand a stream of open tar files into a stream of tar file contents,
    and group files with the same prefix into one example.

    Args:
        data: Iterable[{src, stream}]

    Returns:
        Iterable[{key, wav, txt, sample_rate}]
    """
    for sample in data:
        assert "stream" in sample
        stream = tarfile.open(fileobj=sample["stream"], mode="r|*")
        prev_prefix = None
        example = {}
        valid = True
        for tarinfo in stream:
            name = tarinfo.name
            pos = name.rfind(".")
            assert pos > 0
            prefix, postfix = name[:pos], name[pos + 1 :]
            if prev_prefix is not None and prefix != prev_prefix:
                example["key"] = prev_prefix
                if valid:
                    yield example
                example = {}
                valid = True
            with stream.extractfile(tarinfo) as file_obj:
                try:
                    if postfix == "txt":
                        example["txt"] = file_obj.read().decode("utf8").strip()
                    elif postfix in AUDIO_FORMAT_SETS:
                        waveform, sample_rate = torchaudio.load(file_obj)
                        example["wav"] = waveform
                        example["sample_rate"] = sample_rate
                    else:
                        example[postfix] = file_obj.read()
                except Exception as ex:
                    valid = False
                    logging.warning("Failed to parse {}: {}".format(name, ex))
            prev_prefix = prefix
        if prev_prefix is not None:
            example["key"] = prev_prefix
            yield example
        stream.close()
        if "process" in sample:
            sample["process"].communicate()
        sample["stream"].close()


def parse_raw(data):
    """Parse key/wav/txt from a JSON line.

    Args:
        data: Iterable[{src}], where each src is a JSON line with key/wav/txt

    Returns:
        Iterable[{key, wav, txt, sample_rate}]
    """
    for sample in data:
        assert "src" in sample
        json_line = sample["src"]
        obj = json.loads(json_line)
        assert "key" in obj
        assert "wav" in obj
        assert "txt" in obj
        key = obj["key"]
        wav_file = obj["wav"]
        txt = obj["txt"]
        try:
            if "start" in obj:
                assert "end" in obj
                sample_rate = torchaudio.backend.sox_io_backend.info(
                    wav_file
                ).sample_rate
                start_frame = int(obj["start"] * sample_rate)
                end_frame = int(obj["end"] * sample_rate)
                waveform, _ = torchaudio.backend.sox_io_backend.load(
                    filepath=wav_file,
                    num_frames=end_frame - start_frame,
                    frame_offset=start_frame,
                )
            else:
                waveform, sample_rate = torchaudio.load(wav_file)
            example = dict(key=key, txt=txt, wav=waveform, sample_rate=sample_rate)
            yield example
        except Exception as ex:
            logging.warning("Failed to read {}: {}".format(wav_file, ex))


def filter(
    data,
    max_length=10240,
    min_length=10,
    token_max_length=200,
    token_min_length=1,
    min_output_input_ratio=0.0005,
    max_output_input_ratio=1,
):
    """Filter samples according to feature and label length.

    In-place operation.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        max_length: drop utterances longer than max_length (in 10ms frames)
        min_length: drop utterances shorter than min_length (in 10ms frames)
        token_max_length: drop utterances whose labels are longer than
            token_max_length, especially when using char units for
            English modeling
        token_min_length: drop utterances whose labels are shorter than
            token_min_length
        min_output_input_ratio: minimum ratio of
            token_length / feats_length (in 10ms frames)
        max_output_input_ratio: maximum ratio of
            token_length / feats_length (in 10ms frames)

    Returns:
        Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        assert "sample_rate" in sample
        assert "wav" in sample
        assert "label" in sample
        # sample['wav'] is a torch.Tensor; there are 100 frames per second
        num_frames = sample["wav"].size(1) / sample["sample_rate"] * 100
        if num_frames < min_length:
            continue
        if num_frames > max_length:
            continue
        if len(sample["label"]) < token_min_length:
            continue
        if len(sample["label"]) > token_max_length:
            continue
        if num_frames != 0:
            if len(sample["label"]) / num_frames < min_output_input_ratio:
                continue
            if len(sample["label"]) / num_frames > max_output_input_ratio:
                continue
        yield sample
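
# For example, a 2 s waveform at 16 kHz has 32000 samples, so
# num_frames = 32000 / 16000 * 100 = 200 ten-millisecond frames; with a
# 5-token label the output/input ratio is 5 / 200 = 0.025, which passes
# the default bounds.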


def resample(data, resample_rate=16000):
    """Resample data.

    In-place operation.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        resample_rate: target resample rate

    Returns:
        Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        assert "sample_rate" in sample
        assert "wav" in sample
        sample_rate = sample["sample_rate"]
        waveform = sample["wav"]
        if sample_rate != resample_rate:
            sample["sample_rate"] = resample_rate
            sample["wav"] = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=resample_rate
            )(waveform)
        yield sample


def speed_perturb(data, speeds=None):
    """Apply speed perturbation to the data.

    In-place operation.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        speeds (List[float]): candidate speed factors; one is chosen at random

    Returns:
        Iterable[{key, wav, label, sample_rate}]
    """
    if speeds is None:
        speeds = [0.9, 1.0, 1.1]
    for sample in data:
        assert "sample_rate" in sample
        assert "wav" in sample
        sample_rate = sample["sample_rate"]
        waveform = sample["wav"]
        speed = random.choice(speeds)
        if speed != 1.0:
            wav, _ = torchaudio.sox_effects.apply_effects_tensor(
                waveform,
                sample_rate,
                [["speed", str(speed)], ["rate", str(sample_rate)]],
            )
            sample["wav"] = wav
        yield sample
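
# The sox "speed" effect changes tempo and pitch together (a factor of 1.1
# makes the utterance about 10% shorter); the trailing "rate" effect resamples
# back to the original sample_rate so downstream feature extraction sees a
# consistent rate. Labels are untouched, so this adds acoustic variety
# without changing the transcripts.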


def compute_fbank(data, num_mel_bins=23, frame_length=25, frame_shift=10, dither=0.0):
    """Extract fbank features.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]

    Returns:
        Iterable[{key, feat, label}]
    """
    for sample in data:
        assert "sample_rate" in sample
        assert "wav" in sample
        assert "key" in sample
        assert "label" in sample
        sample_rate = sample["sample_rate"]
        waveform = sample["wav"]
        # Kaldi-style features expect 16-bit signed integer range, so scale
        # the float waveform in [-1, 1] up by 2^15
        waveform = waveform * (1 << 15)
        # Only keep key, feat, label
        mat = kaldi.fbank(
            waveform,
            num_mel_bins=num_mel_bins,
            frame_length=frame_length,
            frame_shift=frame_shift,
            dither=dither,
            energy_floor=0.0,
            sample_frequency=sample_rate,
        )
        yield dict(key=sample["key"], label=sample["label"], feat=mat)
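
# With the defaults (25 ms window, 10 ms shift) a 1 s utterance at 16 kHz
# produces roughly (16000 - 400) / 160 + 1 = 98 frames, so `feat` has shape
# (98, num_mel_bins).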


def compute_mfcc(
    data,
    num_mel_bins=23,
    frame_length=25,
    frame_shift=10,
    dither=0.0,
    num_ceps=40,
    high_freq=0.0,
    low_freq=20.0,
):
    """Extract mfcc features.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]

    Returns:
        Iterable[{key, feat, label}]
    """
    for sample in data:
        assert "sample_rate" in sample
        assert "wav" in sample
        assert "key" in sample
        assert "label" in sample
        sample_rate = sample["sample_rate"]
        waveform = sample["wav"]
        # Scale the float waveform in [-1, 1] to 16-bit signed integer range
        waveform = waveform * (1 << 15)
        # Only keep key, feat, label
        mat = kaldi.mfcc(
            waveform,
            num_mel_bins=num_mel_bins,
            frame_length=frame_length,
            frame_shift=frame_shift,
            dither=dither,
            num_ceps=num_ceps,
            high_freq=high_freq,
            low_freq=low_freq,
            sample_frequency=sample_rate,
        )
        yield dict(key=sample["key"], label=sample["label"], feat=mat)


def __tokenize_by_bpe_model(sp, txt):
    tokens = []
    # CJK(China Japan Korea) unicode range is [U+4E00, U+9FFF], ref:
    # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
    pattern = re.compile(r"([\u4e00-\u9fff])")
    # Example:
    #   txt   = "你好 IT'S OKAY 的"
    #   chars = ["你", "好", " IT'S OKAY ", "的"]
    chars = pattern.split(txt.upper())
    mix_chars = [w for w in chars if len(w.strip()) > 0]
    for ch_or_w in mix_chars:
        # ch_or_w is a single CJK character (e.g., "你"): keep it as-is.
        if pattern.fullmatch(ch_or_w) is not None:
            tokens.append(ch_or_w)
        # ch_or_w contains non-CJK characters (e.g., " IT'S OKAY "):
        # encode it with the BPE model.
        else:
            for p in sp.encode_as_pieces(ch_or_w):
                tokens.append(p)
    return tokens
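
# So a mixed Chinese/English string is tokenized character-by-character on
# the CJK side and by BPE pieces on the English side; with a hypothetical
# BPE model the result could look like:
#   __tokenize_by_bpe_model(sp, "你好 it's okay 的")
#   -> ["你", "好", "▁IT", "'", "S", "▁OKAY", "的"]
# (the exact pieces depend on the trained BPE vocabulary).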


def tokenize(
    data, symbol_table, bpe_model=None, non_lang_syms=None, split_with_space=False
):
    """Decode text to chars or BPE pieces.

    In-place operation.

    Args:
        data: Iterable[{key, wav, txt, sample_rate}]

    Returns:
        Iterable[{key, wav, txt, tokens, label, sample_rate}]
    """
    if non_lang_syms is not None:
        non_lang_syms_pattern = re.compile(r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})")
    else:
        non_lang_syms = {}
        non_lang_syms_pattern = None
    if bpe_model is not None:
        import sentencepiece as spm

        sp = spm.SentencePieceProcessor()
        sp.load(bpe_model)
    else:
        sp = None
    for sample in data:
        assert "txt" in sample
        txt = sample["txt"].strip()
        if non_lang_syms_pattern is not None:
            parts = non_lang_syms_pattern.split(txt.upper())
            parts = [w for w in parts if len(w.strip()) > 0]
        else:
            parts = [txt]
        label = []
        tokens = []
        for part in parts:
            if part in non_lang_syms:
                tokens.append(part)
            else:
                if bpe_model is not None:
                    tokens.extend(__tokenize_by_bpe_model(sp, part))
                else:
                    if split_with_space:
                        part = part.split(" ")
                    for ch in part:
                        if ch == " ":
                            ch = "▁"
                        tokens.append(ch)
        for ch in tokens:
            if ch in symbol_table:
                label.append(symbol_table[ch])
            elif "<unk>" in symbol_table:
                label.append(symbol_table["<unk>"])
        sample["tokens"] = tokens
        sample["label"] = label
        yield sample
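
# Tokens are mapped to integer ids through symbol_table, falling back to
# <unk> for out-of-vocabulary tokens. For example, with a hypothetical
# symbol_table {"<unk>": 0, "你": 1, "好": 2}, the text "你好吗" becomes
# tokens ["你", "好", "吗"] and label [1, 2, 0].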


def spec_aug(data, num_t_mask=2, num_f_mask=2, max_t=50, max_f=10, max_w=80):
    """Apply spec augmentation.

    In-place operation.

    Args:
        data: Iterable[{key, feat, label}]
        num_t_mask: number of time masks to apply
        num_f_mask: number of freq masks to apply
        max_t: max width of a time mask
        max_f: max width of a freq mask
        max_w: max width of time warp (accepted for config compatibility;
            time warping is not applied in this implementation)

    Returns:
        Iterable[{key, feat, label}]
    """
    for sample in data:
        assert "feat" in sample
        x = sample["feat"]
        assert isinstance(x, torch.Tensor)
        y = x.clone().detach()
        max_frames = y.size(0)
        max_freq = y.size(1)
        # time mask
        for i in range(num_t_mask):
            start = random.randint(0, max_frames - 1)
            length = random.randint(1, max_t)
            end = min(max_frames, start + length)
            y[start:end, :] = 0
        # freq mask
        for i in range(num_f_mask):
            start = random.randint(0, max_freq - 1)
            length = random.randint(1, max_f)
            end = min(max_freq, start + length)
            y[:, start:end] = 0
        sample["feat"] = y
        yield sample


def spec_sub(data, max_t=20, num_t_sub=3):
    """Apply spec substitution.

    In-place operation.

    ref: U2++, section 3.2.3 [https://arxiv.org/abs/2106.05642]

    Args:
        data: Iterable[{key, feat, label}]
        max_t: max width of the time substitution
        num_t_sub: number of time substitutions to apply

    Returns:
        Iterable[{key, feat, label}]
    """
    for sample in data:
        assert "feat" in sample
        x = sample["feat"]
        assert isinstance(x, torch.Tensor)
        y = x.clone().detach()
        max_frames = y.size(0)
        for i in range(num_t_sub):
            start = random.randint(0, max_frames - 1)
            length = random.randint(1, max_t)
            end = min(max_frames, start + length)
            # substitute the chosen span with an earlier span of the same
            # width, shifted back by a random offset
            pos = random.randint(0, start)
            y[start:end, :] = x[start - pos : end - pos, :]
        sample["feat"] = y
        yield sample


def spec_trim(data, max_t=20):
    """Trim trailing frames.

    In-place operation.

    ref: TrimTail [https://arxiv.org/abs/2211.00522]

    Args:
        data: Iterable[{key, feat, label}]
        max_t: max width of length trimming

    Returns:
        Iterable[{key, feat, label}]
    """
    for sample in data:
        assert "feat" in sample
        x = sample["feat"]
        assert isinstance(x, torch.Tensor)
        max_frames = x.size(0)
        length = random.randint(1, max_t)
        if length < max_frames / 2:
            y = x.clone().detach()[: max_frames - length]
            sample["feat"] = y
        yield sample
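
# For example, with max_t=20 a random 1-20 trailing frames are dropped from
# each utterance; the `length < max_frames / 2` guard leaves very short
# utterances untrimmed rather than cutting away a large fraction of them.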


def shuffle(data, shuffle_size=10000):
    """Locally shuffle the data.

    Args:
        data: Iterable[{key, feat, label}]
        shuffle_size: buffer size for shuffling

    Returns:
        Iterable[{key, feat, label}]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= shuffle_size:
            random.shuffle(buf)
            for x in buf:
                yield x
            buf = []
    # The samples left over
    random.shuffle(buf)
    for x in buf:
        yield x


def sort(data, sort_size=500):
    """Sort the data by feature length.

    Sort is used after shuffle and before batch, so we can group
    utts with similar lengths into a batch, and `sort_size` should
    be less than `shuffle_size`.

    Args:
        data: Iterable[{key, feat, label}]
        sort_size: buffer size for sorting

    Returns:
        Iterable[{key, feat, label}]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= sort_size:
            buf.sort(key=lambda x: x["feat"].size(0))
            for x in buf:
                yield x
            buf = []
    # The samples left over
    buf.sort(key=lambda x: x["feat"].size(0))
    for x in buf:
        yield x


def static_batch(data, batch_size=16):
    """Statically batch the data by `batch_size`.

    Args:
        data: Iterable[{key, feat, label}]
        batch_size: batch size

    Returns:
        Iterable[List[{key, feat, label}]]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= batch_size:
            yield buf
            buf = []
    if len(buf) > 0:
        yield buf


def dynamic_batch(data, max_frames_in_batch=12000):
    """Dynamically batch the data until the total frames in the batch
    reach `max_frames_in_batch`.

    Args:
        data: Iterable[{key, feat, label}]
        max_frames_in_batch: max frames in one batch

    Returns:
        Iterable[List[{key, feat, label}]]
    """
    buf = []
    longest_frames = 0
    for sample in data:
        assert "feat" in sample
        assert isinstance(sample["feat"], torch.Tensor)
        new_sample_frames = sample["feat"].size(0)
        longest_frames = max(longest_frames, new_sample_frames)
        frames_after_padding = longest_frames * (len(buf) + 1)
        if frames_after_padding > max_frames_in_batch:
            yield buf
            buf = [sample]
            longest_frames = new_sample_frames
        else:
            buf.append(sample)
    if len(buf) > 0:
        yield buf
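
# The batch cost is measured after padding, i.e. longest_frames * batch_size.
# For example, with max_frames_in_batch=12000, samples of 900, 1000, and 1100
# frames cost 3 * 1100 = 3300 frames together; if the next sample had 5000
# frames, the padded cost would jump to 4 * 5000 = 20000 > 12000, so the
# first three would be emitted as one batch and the long sample would start
# a new one. This is why sort() runs first: grouping similar lengths reduces
# wasted padding. (filter()'s default max_length=10240 keeps any single
# utterance under the default budget.)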


def batch(data, batch_type="static", batch_size=16, max_frames_in_batch=12000):
    """Wrapper for static/dynamic batching."""
    if batch_type == "static":
        return static_batch(data, batch_size)
    elif batch_type == "dynamic":
        return dynamic_batch(data, max_frames_in_batch)
    else:
        logging.fatal("Unsupported batch type {}".format(batch_type))


def padding(data):
    """Pad the batched data into training tensors.

    Args:
        data: Iterable[List[{key, feat, label}]]

    Returns:
        Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
    """
    for sample in data:
        assert isinstance(sample, list)
        feats_length = torch.tensor(
            [x["feat"].size(0) for x in sample], dtype=torch.int32
        )
        order = torch.argsort(feats_length, descending=True)
        feats_lengths = torch.tensor(
            [sample[i]["feat"].size(0) for i in order], dtype=torch.int32
        )
        sorted_feats = [sample[i]["feat"] for i in order]
        sorted_keys = [sample[i]["key"] for i in order]
        sorted_labels = [
            torch.tensor(sample[i]["label"], dtype=torch.int64) for i in order
        ]
        label_lengths = torch.tensor(
            [x.size(0) for x in sorted_labels], dtype=torch.int32
        )
        padded_feats = pad_sequence(sorted_feats, batch_first=True, padding_value=0)
        padding_labels = pad_sequence(sorted_labels, batch_first=True, padding_value=-1)
        yield (sorted_keys, padded_feats, padding_labels, feats_lengths, label_lengths)
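

# A minimal sketch of how these generator stages compose into a training
# pipeline, mirroring WeNet's shard-based data flow. The shard path and
# symbol table below are hypothetical placeholders.
def _example_pipeline(shard="shards/train-000.tar"):
    symbol_table = {"<unk>": 0, "你": 1, "好": 2}  # hypothetical vocabulary
    data = url_opener([{"src": shard}])        # open the shard (local or remote)
    data = tar_file_and_group(data)            # tar entries -> {key, wav, txt, ...}
    data = tokenize(data, symbol_table)        # txt -> tokens + integer label
    data = filter(data)                        # drop too-short/too-long samples
    data = resample(data, resample_rate=16000)
    data = speed_perturb(data)                 # random 0.9/1.0/1.1 speed
    data = compute_fbank(data, num_mel_bins=80)
    data = spec_aug(data)                      # time/freq masking
    data = shuffle(data)                       # local shuffle ...
    data = sort(data)                          # ... then sort by feat length
    data = batch(data, batch_type="dynamic", max_frames_in_batch=12000)
    return padding(data)                       # padded tensors, ready for training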