Spaces:
Runtime error
Runtime error
| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """Deep speech decoder.""" | |
| from __future__ import absolute_import | |
| from __future__ import division | |
| from __future__ import print_function | |
| import itertools | |
| from nltk.metrics import distance | |
| import numpy as np | |
class DeepSpeechDecoder(object):
    """Greedy decoder implementation for Deep Speech model.

    Besides decoding, the class offers two edit-distance helpers (`wer` and
    `cer`) for scoring a decoded string against its ground-truth transcript.
    """

    def __init__(self, labels, blank_index=28):
        """Decoder initialization.

        Arguments:
          labels: a string specifying the speech labels for the decoder to
            use. The character at position i is the label of class index i.
          blank_index: an integer specifying index for the blank character.
            Defaults to 28.
        """
        # e.g. labels = "[a-z]' _"
        self.labels = labels
        self.blank_index = blank_index
        # Lookup table from class index to its label character.
        self.int_to_char = dict(enumerate(labels))

    def convert_to_string(self, sequence):
        """Convert a sequence of indexes into corresponding string.

        Args:
          sequence: an iterable of class indexes, each a key of
            `self.int_to_char`.

        Returns:
          The string formed by concatenating the label of every index, in
          order. An empty sequence yields an empty string.

        Raises:
          KeyError: if an index has no entry in `self.int_to_char`.
        """
        return ''.join(self.int_to_char[i] for i in sequence)

    def wer(self, decode, target):
        """Computes the word-level edit distance used for the WER.

        Both sentences are tokenized on whitespace and the edit distance is
        computed between the two word sequences.

        Args:
          decode: string of the decoded output.
          target: a string for the ground truth label.

        Returns:
          The edit distance between the two word sequences, as returned by
          `distance.edit_distance`. The value is NOT normalized here; a caller
          that needs a rate has to divide by the reference word count itself.
        """
        decode_words = decode.split()
        target_words = target.split()
        # Give every distinct word its own single character, so that a
        # character-level edit distance over the encoded strings equals the
        # word-level edit distance over the original sentences.
        vocabulary = set(decode_words + target_words)
        word2char = {word: chr(i) for i, word in enumerate(vocabulary)}
        new_decode = ''.join(word2char[w] for w in decode_words)
        new_target = ''.join(word2char[w] for w in target_words)
        return distance.edit_distance(new_decode, new_target)

    def cer(self, decode, target):
        """Computes the character-level edit distance used for the CER.

        Args:
          decode: a string of the decoded output.
          target: a string for the ground truth label.

        Returns:
          The edit distance between the two strings, as returned by
          `distance.edit_distance`. The value is NOT normalized here; a caller
          that needs a rate has to divide by the reference length itself.
        """
        return distance.edit_distance(decode, target)

    def decode(self, logits):
        """Decode the best guess from logits using greedy algorithm.

        Args:
          logits: an array-like with at least two axes; the arg-max is taken
            along axis 1, so axis 1 must range over the label classes
            (typically shaped (time_steps, num_classes) -- confirm against the
            caller).

        Returns:
          The decoded string: the best class per row, with consecutive
          repeats merged and blank symbols removed.
        """
        # Choose the class with maximum probability.
        best = list(np.argmax(logits, axis=1))
        # Merge repeated chars.
        merge = [k for k, _ in itertools.groupby(best)]
        # Remove the blank index in the decoded sequence. This must happen
        # after merging so that a blank between two equal chars keeps both.
        merge_remove_blank = [k for k in merge if k != self.blank_index]
        return self.convert_to_string(merge_remove_blank)