Spaces:
Runtime error
Runtime error
| # Copyright (c) 2023 Binbin Zhang (binbzha@qq.com) | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import math | |
| from collections import defaultdict | |
| from typing import List, Dict | |
| import torch | |
| from torch.nn.utils.rnn import pad_sequence | |
| from wenet.utils.common import (add_sos_eos, log_add, add_whisper_tokens, | |
| mask_to_bias) | |
| from wenet.utils.ctc_utils import remove_duplicates_and_blank | |
| from wenet.utils.mask import (make_pad_mask, mask_finished_preds, | |
| mask_finished_scores, subsequent_mask) | |
| from wenet.utils.context_graph import ContextGraph, ContextState | |
class DecodeResult:

    def __init__(self,
                 tokens: List[int],
                 score: float = 0.0,
                 confidence: float = 0.0,
                 tokens_confidence: List[float] = None,
                 times: List[int] = None,
                 nbest: List[List[int]] = None,
                 nbest_scores: List[float] = None,
                 nbest_times: List[List[int]] = None):
        """Plain container for the decoding output of one utterance.

        Every argument is stored unchanged on the instance under the
        same name; no validation or copying is performed.

        Args:
            tokens: decoded token ids of the selected hypothesis
            score: total decode score of the selected hypothesis
            confidence: overall confidence of the result (expected to
                lie in 0~1)
            tokens_confidence: one confidence value per entry of `tokens`
            times: one timestamp per entry of `tokens`
            nbest: token ids of every n-best hypothesis
            nbest_scores: score of each hypothesis in `nbest`
            nbest_times: per-token timestamps of each hypothesis in
                `nbest`
        """
        # Selected (best) hypothesis and its scalar scores.
        self.tokens, self.score = tokens, score
        self.confidence = confidence
        # Per-token details of the selected hypothesis.
        self.tokens_confidence, self.times = tokens_confidence, times
        # Full n-best list, kept so that a later pass can re-rank it.
        self.nbest, self.nbest_scores = nbest, nbest_scores
        self.nbest_times = nbest_times
class PrefixScore:
    """Score bookkeeping for one prefix during CTC prefix beam search.

    Two families of scores are tracked, each split by whether the
    underlying alignment currently ends in blank (`s`) or in a
    non-blank token (`ns`): the scores merged with `log_add`, and the
    best-single-path ("viterbi") scores together with their timestamps.
    An optional context-biasing state and score ride along.
    """

    def __init__(self,
                 s: float = float('-inf'),
                 ns: float = float('-inf'),
                 v_s: float = float('-inf'),
                 v_ns: float = float('-inf'),
                 context_state: ContextState = None,
                 context_score: float = 0.0):
        # Scores merged over alignments ending in blank / non-blank.
        self.s = s
        self.ns = ns
        # Viterbi (single best path) scores, blank / non-blank ending.
        self.v_s = v_s
        self.v_ns = v_ns
        # Score of the current (last) token of the prefix.
        self.cur_token_prob = float('-inf')
        # Timestamps along the viterbi blank / non-blank ending path.
        self.times_s = []
        self.times_ns = []
        # Context biasing bookkeeping.
        self.context_state = context_state
        self.context_score = context_score
        # Set by the search once context info was attached this step.
        self.has_context = False

    def score(self):
        """Return `log_add` of the blank and non-blank ending scores."""
        return log_add(self.s, self.ns)

    def viterbi_score(self):
        """Return the larger viterbi score (non-blank wins on a tie)."""
        if self.v_s > self.v_ns:
            return self.v_s
        return self.v_ns

    def times(self):
        """Return the timestamps that belong to `viterbi_score()`."""
        if self.v_s > self.v_ns:
            return self.times_s
        return self.times_ns

    def total_score(self):
        """Return `score()` plus the accumulated context score."""
        return self.score() + self.context_score

    def copy_context(self, prefix_score):
        """Take over context score and state from `prefix_score`."""
        self.context_score, self.context_state = (
            prefix_score.context_score, prefix_score.context_state)

    def update_context(self, context_graph, prefix_score, word_id):
        """Inherit the context of `prefix_score`, then advance by a token.

        The graph is stepped from the state of `prefix_score` with
        `word_id`; the returned score is added to the inherited context
        score and the returned state becomes the new context state.
        """
        self.copy_context(prefix_score)
        step_score, next_state = context_graph.forward_one_step(
            prefix_score.context_state, word_id)
        self.context_state = next_state
        self.context_score += step_score
def ctc_greedy_search(ctc_probs: torch.Tensor,
                      ctc_lens: torch.Tensor,
                      blank_id: int = 0) -> List[DecodeResult]:
    """Greedy (best path) CTC decoding for a whole batch.

    For every frame the highest scoring token is chosen, frames that
    `make_pad_mask` flags as padding are forced to `blank_id`, and the
    frame-level sequence is collapsed with
    `remove_duplicates_and_blank`.

    Args:
        ctc_probs: per-frame token scores, (B, maxlen, vocab_size)
        ctc_lens: per-utterance lengths handed to `make_pad_mask`
            (presumably the number of valid frames, shape (B,))
        blank_id: id of the CTC blank token

    Returns:
        One `DecodeResult` per utterance with only `tokens` filled in.
    """
    maxlen = ctc_probs.size(1)
    # Best token per frame. No score is reported by this search, so
    # only the indices are needed.
    best_index = ctc_probs.argmax(dim=2)  # (B, maxlen)
    pad_mask = make_pad_mask(ctc_lens, maxlen)  # (B, maxlen)
    # Padded frames become blank so that they vanish when collapsing.
    best_index = best_index.masked_fill(pad_mask, blank_id)  # (B, maxlen)
    return [
        DecodeResult(remove_duplicates_and_blank(hyp, blank_id))
        for hyp in best_index.tolist()
    ]
def ctc_prefix_beam_search(
    ctc_probs: torch.Tensor,
    ctc_lens: torch.Tensor,
    beam_size: int,
    context_graph: ContextGraph = None,
    blank_id: int = 0,
) -> List[DecodeResult]:
    """CTC prefix beam search, run utterance by utterance.

    Args:
        ctc_probs: per-frame token scores, indexed as
            `ctc_probs[utt][frame]` -> (vocab_size,). The values are
            summed along paths, i.e. they are treated as log scores.
        ctc_lens: number of frames to decode for each utterance
        beam_size: number of tokens kept per frame (first prune) and
            number of prefixes kept per frame (second prune)
        context_graph: optional context biasing graph; when given, its
            `root`, `forward_one_step` and `finalize` are used
        blank_id: id of the CTC blank token

    Returns:
        One `DecodeResult` per utterance. `tokens`, `score` and `times`
        describe the best prefix (a tuple of token ids); `nbest`,
        `nbest_scores` and `nbest_times` hold all surviving prefixes
        ordered by descending total score.
    """
    batch_size = ctc_probs.shape[0]
    results = []
    # CTC prefix beam search can not be paralleled, so search one by one
    for i in range(batch_size):
        ctc_prob = ctc_probs[i]
        num_t = int(ctc_lens[i])
        # topk() raises when k is larger than the vocabulary, so never
        # ask for more candidates than there are tokens.
        first_prune_size = min(beam_size, ctc_prob.size(-1))
        # 1. Start from the empty prefix
        cur_hyps = [(tuple(),
                     PrefixScore(s=0.0,
                                 ns=-float('inf'),
                                 v_s=0.0,
                                 v_ns=0.0,
                                 context_state=None if context_graph is None
                                 else context_graph.root,
                                 context_score=0.0))]
        # 2. CTC beam search step by step
        for t in range(num_t):
            logp = ctc_prob[t]  # (vocab_size,)
            # key: prefix, value: PrefixScore
            next_hyps = defaultdict(PrefixScore)
            # 2.1 First beam prune: select topk best
            top_k_logp, top_k_index = logp.topk(first_prune_size)
            # One bulk conversion instead of two .item() calls per token
            for u, prob in zip(top_k_index.tolist(), top_k_logp.tolist()):
                for prefix, prefix_score in cur_hyps:
                    last = prefix[-1] if len(prefix) > 0 else None
                    if u == blank_id:  # blank
                        next_score = next_hyps[prefix]
                        next_score.s = log_add(next_score.s,
                                               prefix_score.score() + prob)
                        next_score.v_s = prefix_score.viterbi_score() + prob
                        next_score.times_s = prefix_score.times().copy()
                        # prefix not changed, copy the context from prefix
                        if context_graph and not next_score.has_context:
                            next_score.copy_context(prefix_score)
                            next_score.has_context = True
                    elif u == last:
                        # Update *uu -> *u;
                        next_score1 = next_hyps[prefix]
                        next_score1.ns = log_add(next_score1.ns,
                                                 prefix_score.ns + prob)
                        if next_score1.v_ns < prefix_score.v_ns + prob:
                            next_score1.v_ns = prefix_score.v_ns + prob
                            if next_score1.cur_token_prob < prob:
                                next_score1.cur_token_prob = prob
                                next_score1.times_ns = (
                                    prefix_score.times_ns.copy())
                                # the repeated token moves to this frame
                                next_score1.times_ns[-1] = t
                        if context_graph and not next_score1.has_context:
                            next_score1.copy_context(prefix_score)
                            next_score1.has_context = True
                        # Update *u-u -> *uu, - is for blank
                        n_prefix = prefix + (u, )
                        next_score2 = next_hyps[n_prefix]
                        next_score2.ns = log_add(next_score2.ns,
                                                 prefix_score.s + prob)
                        if next_score2.v_ns < prefix_score.v_s + prob:
                            next_score2.v_ns = prefix_score.v_s + prob
                            next_score2.cur_token_prob = prob
                            next_score2.times_ns = prefix_score.times_s.copy()
                            next_score2.times_ns.append(t)
                        if context_graph and not next_score2.has_context:
                            next_score2.update_context(context_graph,
                                                       prefix_score, u)
                            next_score2.has_context = True
                    else:
                        n_prefix = prefix + (u, )
                        next_score = next_hyps[n_prefix]
                        next_score.ns = log_add(next_score.ns,
                                                prefix_score.score() + prob)
                        viterbi = prefix_score.viterbi_score() + prob
                        if next_score.v_ns < viterbi:
                            next_score.v_ns = viterbi
                            next_score.cur_token_prob = prob
                            next_score.times_ns = prefix_score.times().copy()
                            next_score.times_ns.append(t)
                        if context_graph and not next_score.has_context:
                            next_score.update_context(context_graph,
                                                      prefix_score, u)
                            next_score.has_context = True
            # 2.2 Second beam prune
            cur_hyps = sorted(next_hyps.items(),
                              key=lambda x: x[1].total_score(),
                              reverse=True)[:beam_size]
        # We should backoff the context score/state when the context is
        # not fully matched at the last time.
        if context_graph is not None:
            for _, final_score in cur_hyps:
                context_score, new_context_state = context_graph.finalize(
                    final_score.context_state)
                final_score.context_score = context_score
                final_score.context_state = new_context_state
        # 3. Collect the n-best list; cur_hyps is ordered by the last prune
        nbest = [y[0] for y in cur_hyps]
        nbest_scores = [y[1].total_score() for y in cur_hyps]
        nbest_times = [y[1].times() for y in cur_hyps]
        results.append(
            DecodeResult(tokens=nbest[0],
                         score=nbest_scores[0],
                         times=nbest_times[0],
                         nbest=nbest,
                         nbest_scores=nbest_scores,
                         nbest_times=nbest_times))
    return results
def attention_beam_search(
    model,
    encoder_out: torch.Tensor,
    encoder_mask: torch.Tensor,
    beam_size: int = 10,
    length_penalty: float = 0.0,
    infos: Dict[str, List[str]] = None,
) -> List[DecodeResult]:
    """Batched beam search with the attention decoder.

    Args:
        model: object providing `sos`, `eos`, `decoder.use_sdpa` and
            `decoder.forward_one_step`; optionally `special_tokens`
            (with `ignore_id`) to enable the whisper-style prompt, and
            `decode_maxlen` to override the maximum decode length
        encoder_out: encoder output; dim 0 is the batch and dim 1 the
            time axis, whose size is the default maximum decode length
        encoder_mask: mask passed on to the decoder together with
            `encoder_out`
        beam_size: number of hypotheses kept per utterance
        length_penalty: final scores are divided by
            `length ** length_penalty`, where length counts the non-eos
            tokens of a hypothesis
        infos: must contain "tasks" and "langs" (one entry per
            utterance) when the whisper-style prompt is used

    Returns:
        One `DecodeResult` per utterance holding only the tokens of the
        best hypothesis, without the prompt prefix and without eos.
    """
    device = encoder_out.device
    batch_size = encoder_out.shape[0]
    # Let's assume B = batch_size and N = beam_size
    # 1. Encoder
    maxlen = encoder_out.size(1)
    running_size = batch_size * beam_size
    if getattr(model, 'special_tokens', None) is not None \
            and "transcribe" in model.special_tokens:
        # Every utterance is decoded N times, so repeat its task/lang
        tasks = [task for task in infos["tasks"] for _ in range(beam_size)]
        langs = [lang for lang in infos["langs"] for _ in range(beam_size)]
        hyps = torch.ones([running_size, 0], dtype=torch.long,
                          device=device)  # (B*N, 0)
        hyps, _ = add_whisper_tokens(model.special_tokens,
                                     hyps,
                                     model.ignore_id,
                                     tasks=tasks,
                                     no_timestamp=True,
                                     langs=langs,
                                     use_prev=False)
    else:
        hyps = torch.ones([running_size, 1], dtype=torch.long,
                          device=device).fill_(model.sos)  # (B*N, 1)
    prefix_len = hyps.size(1)
    # Only the first beam of each utterance starts alive, otherwise the
    # first step would pick the same token N times.
    scores = torch.tensor([0.0] + [-float('inf')] * (beam_size - 1),
                          dtype=torch.float,
                          device=device)
    scores = scores.repeat([batch_size]).unsqueeze(1)  # (B*N, 1)
    end_flag = torch.zeros_like(scores, dtype=torch.bool, device=device)
    cache = {
        'self_att_cache': {},
        'cross_att_cache': {},
    }
    if model.decoder.use_sdpa:
        encoder_mask = mask_to_bias(encoder_mask, encoder_out.dtype)
    if hasattr(model, 'decode_maxlen'):
        maxlen = model.decode_maxlen
    # 2. Decoder forward step by step
    for cur_len in range(prefix_len, maxlen + 1):
        # Stop if all batch and all beam produce eos
        if end_flag.sum() == running_size:
            break
        # 2.1 Forward decoder step
        hyps_mask = subsequent_mask(cur_len).unsqueeze(0).repeat(
            running_size, 1, 1).to(device)  # (B*N, cur_len, cur_len)
        if model.decoder.use_sdpa:
            hyps_mask = mask_to_bias(hyps_mask, encoder_out.dtype)
        # logp: (B*N, vocab)
        logp = model.decoder.forward_one_step(encoder_out, encoder_mask, hyps,
                                              hyps_mask, cache)
        # 2.2 First beam prune: select topk best prob at current time
        top_k_logp, top_k_index = logp.topk(beam_size)  # (B*N, N)
        top_k_logp = mask_finished_scores(top_k_logp, end_flag)
        top_k_index = mask_finished_preds(top_k_index, end_flag, model.eos)
        # 2.3 Second beam prune: select topk score with history
        scores = scores + top_k_logp  # (B*N, N), broadcast add
        scores = scores.view(batch_size, beam_size * beam_size)  # (B, N*N)
        scores, offset_k_index = scores.topk(k=beam_size)  # (B, N)
        # Update cache to be consistent with new topk scores / hyps
        cache_index = (offset_k_index // beam_size).view(-1)  # (B*N)
        base_cache_index = (torch.arange(batch_size, device=device).view(
            -1, 1).repeat([1, beam_size]) * beam_size).view(-1)  # (B*N)
        cache_index = base_cache_index + cache_index
        cache['self_att_cache'] = {
            i_layer: (torch.index_select(value[0], dim=0, index=cache_index),
                      torch.index_select(value[1], dim=0, index=cache_index))
            for (i_layer, value) in cache['self_att_cache'].items()
        }
        # NOTE(Mddct): we don't need select cross att here
        torch.cuda.empty_cache()
        scores = scores.view(-1, 1)  # (B*N, 1)
        # 2.4. Compute base index in top_k_index,
        # regard top_k_index as (B*N*N),regard offset_k_index as (B*N),
        # then find offset_k_index in top_k_index
        base_k_index = torch.arange(batch_size, device=device).view(
            -1, 1).repeat([1, beam_size])  # (B, N)
        base_k_index = base_k_index * beam_size * beam_size
        best_k_index = base_k_index.view(-1) + offset_k_index.view(
            -1)  # (B*N)
        # 2.5 Update best hyps
        best_k_pred = torch.index_select(top_k_index.view(-1),
                                         dim=-1,
                                         index=best_k_index)  # (B*N)
        best_hyps_index = best_k_index // beam_size
        last_best_k_hyps = torch.index_select(
            hyps, dim=0, index=best_hyps_index)  # (B*N, cur_len)
        hyps = torch.cat((last_best_k_hyps, best_k_pred.view(-1, 1)),
                         dim=1)  # (B*N, cur_len+1)
        # 2.6 Update end flag
        end_flag = torch.eq(hyps[:, -1], model.eos).view(-1, 1)
    # 3. Select best of best
    scores = scores.view(batch_size, beam_size)
    lengths = hyps.ne(model.eos).sum(dim=1).view(batch_size,
                                                 beam_size).float()
    scores = scores / lengths.pow(length_penalty)
    best_scores, best_index = scores.max(dim=-1)
    best_hyps_index = best_index + torch.arange(
        batch_size, dtype=torch.long, device=device) * beam_size
    best_hyps = torch.index_select(hyps, dim=0, index=best_hyps_index)
    # Drop the sos / prompt prefix; eos is removed per hypothesis below
    best_hyps = best_hyps[:, prefix_len:]
    return [
        DecodeResult(hyp[hyp != model.eos].tolist()) for hyp in best_hyps
    ]
def attention_rescoring(
    model,
    ctc_prefix_results: List[DecodeResult],
    encoder_outs: torch.Tensor,
    encoder_lens: torch.Tensor,
    ctc_weight: float = 0.0,
    reverse_weight: float = 0.0,
    infos: Dict[str, List[str]] = None,
) -> List[DecodeResult]:
    """Re-rank CTC n-best hypotheses with the attention decoder.

    Args:
        model: object providing `sos_symbol()`, `eos_symbol()`,
            `ignore_id` and `forward_attention_decoder`; optionally
            `special_tokens` to enable the whisper-style prompt
        ctc_prefix_results(List[DecodeResult]): ctc prefix beam search
            results; `nbest`, `nbest_scores` and `nbest_times` are read
        encoder_outs: encoder output, one row per utterance
        encoder_lens: number of valid encoder frames per utterance
        ctc_weight: weight of the CTC score added to the decoder score
        reverse_weight: weight of the right to left decoder score
        infos: must contain "tasks" and "langs" (one entry per
            utterance) when the whisper-style prompt is used

    Returns:
        One `DecodeResult` per utterance for the hypothesis with the
        highest combined score, including its confidence, per-token
        confidences and the timestamps from the CTC result.
    """
    sos, eos = model.sos_symbol(), model.eos_symbol()
    device = encoder_outs.device
    assert encoder_outs.shape[0] == len(ctc_prefix_results)
    batch_size = encoder_outs.shape[0]
    results = []
    for b in range(batch_size):
        # Rescoring is done per utterance on its unpadded encoder output
        encoder_out = encoder_outs[b, :encoder_lens[b], :].unsqueeze(0)
        hyps = ctc_prefix_results[b].nbest
        ctc_scores = ctc_prefix_results[b].nbest_scores
        hyp_tensors = [
            torch.tensor(hyp, device=device, dtype=torch.long) for hyp in hyps
        ]
        hyps_pad = pad_sequence(hyp_tensors,
                                batch_first=True,
                                padding_value=model.ignore_id
                                )  # (beam_size, max_hyps_len)
        hyps_lens = torch.tensor([len(hyp) for hyp in hyps],
                                 device=device,
                                 dtype=torch.long)  # (beam_size,)
        if getattr(model, 'special_tokens', None) is not None \
                and "transcribe" in model.special_tokens:
            len_before = hyps_pad.size(1)
            hyps_pad, _ = add_whisper_tokens(
                model.special_tokens,
                hyps_pad,
                model.ignore_id,
                tasks=[infos["tasks"][b]] * len(hyps),
                no_timestamp=True,
                langs=[infos["langs"][b]] * len(hyps),
                use_prev=False)
            len_after = hyps_pad.size(1)
            # Lengths grow by however many prompt tokens were added
            hyps_lens = hyps_lens + len_after - len_before
            prefix_len = 4
        else:
            hyps_pad, _ = add_sos_eos(hyps_pad, sos, eos, model.ignore_id)
            hyps_lens = hyps_lens + 1  # Add <sos> at begining
            prefix_len = 1
        decoder_out, r_decoder_out = model.forward_attention_decoder(
            hyps_pad, hyps_lens, encoder_out, reverse_weight)
        # Decoder position that scores the first real token of a hyp
        shift = prefix_len - 1
        # Only use decoder score for rescoring
        best_score = -float('inf')
        best_index = 0
        confidences = []
        tokens_confidences = []
        for i, hyp in enumerate(hyps):
            num_tokens = len(hyp)
            fwd_out = decoder_out[i]
            score = 0.0
            tc = []  # tokens confidences
            # NOTE(review): the scores are exponentiated to obtain
            # confidences, so they are presumably log-probabilities --
            # confirm against forward_attention_decoder.
            for j, w in enumerate(hyp):
                s = fwd_out[j + shift][w]
                score += s
                tc.append(math.exp(s))
            score += fwd_out[num_tokens + shift][eos]
            # add right to left decoder score
            if reverse_weight > 0 and r_decoder_out.dim() > 0:
                bwd_out = r_decoder_out[i]
                r_score = 0.0
                for j, w in enumerate(hyp):
                    # token j sits at the mirrored position when reversed
                    s = bwd_out[num_tokens - j - 1 + shift][w]
                    r_score += s
                    tc[j] = (tc[j] + math.exp(s)) / 2
                r_score += bwd_out[num_tokens + shift][eos]
                score = score * (1 - reverse_weight) + r_score * reverse_weight
            # Length-normalised (tokens + eos) decoder score as confidence
            confidences.append(math.exp(score / (num_tokens + 1)))
            # add ctc score
            score += ctc_scores[i] * ctc_weight
            if score > best_score:
                best_score = score.item()
                best_index = i
            tokens_confidences.append(tc)
        best = DecodeResult(
            hyps[best_index],
            best_score,
            confidence=confidences[best_index],
            times=ctc_prefix_results[b].nbest_times[best_index],
            tokens_confidence=tokens_confidences[best_index])
        results.append(best)
    return results