Spaces:
Runtime error
Runtime error
| # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu) | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from typing import List, Tuple | |
| import numpy as np | |
| import torch | |
| import torchaudio.functional as F | |
def remove_duplicates_and_blank(hyp: List[int],
                                blank_id: int = 0) -> List[int]:
    """Collapse a frame-level CTC sequence into its token sequence.

    Consecutive repeats of the same id are merged into a single token, and
    every run of ``blank_id`` is dropped.

    Args:
        hyp: frame-level token ids.
        blank_id: id of the CTC blank symbol.

    Returns:
        The collapsed, blank-free token ids in their original order.
    """
    # A frame begins a new run when it is the first frame or differs from
    # its left neighbour; keep the head of every run that is not blank.
    return [
        token for idx, token in enumerate(hyp)
        if token != blank_id and (idx == 0 or token != hyp[idx - 1])
    ]
def replace_duplicates_with_blank(hyp: List[int],
                                  blank_id: int = 0) -> List[int]:
    """Keep only the first frame of each non-blank run; blank out the rest.

    The output has the same length as ``hyp``. A frame whose id equals the
    previous frame's id (and is not already ``blank_id``) is replaced by
    ``blank_id``; every other frame is copied unchanged.

    Args:
        hyp: frame-level token ids.
        blank_id: id of the CTC blank symbol.

    Returns:
        A new list of the same length as ``hyp``.
    """
    result: List[int] = []
    for idx, token in enumerate(hyp):
        repeats_previous = idx > 0 and token == hyp[idx - 1]
        if repeats_previous and token != blank_id:
            result.append(blank_id)
        else:
            result.append(token)
    return result
def gen_ctc_peak_time(hyp: List[int], blank_id: int = 0) -> List[int]:
    """Return the frame index where each non-blank token run begins.

    Args:
        hyp: frame-level token ids.
        blank_id: id of the CTC blank symbol.

    Returns:
        One index per non-blank run in ``hyp``: the position of that run's
        first frame, in increasing order.
    """
    # A run starts at index 0 or wherever the id changes from the frame
    # before it; blank runs are skipped.
    return [
        idx for idx, token in enumerate(hyp)
        if token != blank_id and (idx == 0 or token != hyp[idx - 1])
    ]
def gen_timestamps_from_peak(
    peaks: List[int],
    max_duration: float,
    frame_rate: float = 0.04,
    max_token_duration: float = 1.0,
) -> List[Tuple[float, float]]:
    """Turn per-token peak frames into (start, end) times in seconds.

    Each token gets a window of ``max_token_duration`` centred on its peak.
    The window is clipped on each side at the midpoint to the neighbouring
    peak. The first window is clipped below at 0, and the last is clipped
    above at ``max_duration``.

    Args:
        peaks: peak position of each token, in frames.
        max_duration: duration of the whole sentence, in seconds; upper
            bound for the last token's end time.
        frame_rate: duration of one frame, in seconds (a peak at frame
            ``p`` sits at ``p * frame_rate`` seconds).
        max_token_duration: maximum duration of a single token, in seconds.

    Returns:
        A list with one ``(start, end)`` tuple per peak.
    """
    half_window = max_token_duration / 2
    n_peaks = len(peaks)
    spans = []
    for idx, peak in enumerate(peaks):
        centre = peak * frame_rate
        # Left limit: 0 for the first token, else the midpoint to the
        # previous peak.
        if idx > 0:
            lower = (peaks[idx - 1] + peak) / 2 * frame_rate
        else:
            lower = 0
        # Right limit: the sentence end for the last token, else the
        # midpoint to the next peak.
        if idx + 1 < n_peaks:
            upper = (peak + peaks[idx + 1]) / 2 * frame_rate
        else:
            upper = max_duration
        spans.append((max(lower, centre - half_window),
                      min(upper, centre + half_window)))
    return spans
def insert_blank(label, blank_id=0):
    """Surround every label token with ``blank_id``.

    For labels ``[a, b, c]`` the result is
    ``[blank, a, blank, b, blank, c, blank]``: a blank before each token
    plus one trailing blank, so ``2 * len(label) + 1`` entries in total.
    An empty label yields ``[blank_id]``.

    Args:
        label: 1-D sequence of token ids (``np.expand_dims(label, 1)`` followed
            by a column-wise concatenation requires a 1-D input).
        blank_id: id to insert.

    Returns:
        np.ndarray: the flattened, blank-interleaved sequence.
    """
    label = np.expand_dims(label, 1)
    blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id
    # Rows of [blank, token]; flattening row-major interleaves them.
    label = np.concatenate([blanks, label], axis=1)
    label = label.reshape(-1)
    # Trailing blank. Appending blank_id itself (rather than copying
    # label[0]) also works when label is empty and has no element 0.
    label = np.append(label, blank_id)
    return label
def force_align(ctc_probs: torch.Tensor,
                y: torch.Tensor,
                blank_id: int = 0) -> torch.Tensor:
    """CTC forced alignment of a single utterance.

    Args:
        ctc_probs: frame-level CTC scores, 2-D tensor of shape (T, D).
            NOTE(review): torchaudio's ``forced_align`` documents its input
            as log-probabilities (e.g. log-softmax output); confirm callers
            pass those rather than raw probabilities.
        y: token-id sequence, 1-D tensor of shape (L,).
        blank_id: index of the blank symbol.

    Returns:
        torch.Tensor: the frame-level alignment path for this utterance
        (the batch dimension added below is stripped again).
    """
    # torchaudio's forced_align takes batched inputs, so add a leading
    # batch dimension of size 1; both tensors are moved to CPU first.
    log_probs = ctc_probs.unsqueeze(0).cpu()
    targets = y.unsqueeze(0).cpu()
    alignments, _ = F.forced_align(log_probs, targets, blank=blank_id)
    return alignments[0]
def get_blank_id(configs, symbol_table):
    """Resolve the CTC blank id and record it in ``configs['ctc_conf']``.

    If ``symbol_table`` contains ``'<blank>'``, its value is the blank id.
    Any ``ctc_blank_id`` already in the config must agree with it;
    otherwise the id is written into the config. If the symbol table has no
    ``'<blank>'`` entry, the config must already provide ``ctc_blank_id``.

    Args:
        configs: mutable config mapping, modified in place. ``'ctc_conf'``
            is created as an empty dict if absent, and ``ctc_blank_id`` may
            be filled in.
        symbol_table: mapping that may contain a ``'<blank>'`` key whose
            value is the blank id.

    Returns:
        ``(configs, blank_id)``: the same ``configs`` object and the
        resolved blank id.

    Raises:
        AssertionError: if the configured id disagrees with the symbol
            table, or if no blank id can be determined. The error is raised
            explicitly rather than via ``assert``, so the check still runs
            under ``python -O``. The exception type is kept as
            ``AssertionError`` for callers that may already expect it.
    """
    ctc_conf = configs.setdefault('ctc_conf', {})
    if '<blank>' in symbol_table:
        table_blank_id = symbol_table['<blank>']
        if 'ctc_blank_id' in ctc_conf:
            if ctc_conf['ctc_blank_id'] != table_blank_id:
                raise AssertionError(
                    f"ctc_blank_id in config ({ctc_conf['ctc_blank_id']}) "
                    f"does not match '<blank>' in symbol table "
                    f"({table_blank_id})")
        else:
            ctc_conf['ctc_blank_id'] = table_blank_id
    elif 'ctc_blank_id' not in ctc_conf:
        raise AssertionError("PLZ set ctc_blank_id in yaml")
    return configs, ctc_conf['ctc_blank_id']