Spaces:
Runtime error
Runtime error
Vaibhav Srivastav
committed on
Commit
·
379fa33
1
Parent(s):
3b8d409
adding greedy decoding
Browse files
app.py
CHANGED
|
@@ -22,8 +22,7 @@ def load_and_fix_data(input_file):
|
|
| 22 |
if sample_rate !=16000:
|
| 23 |
speech = librosa.resample(speech, sample_rate,16000)
|
| 24 |
return speech
|
| 25 |
-
|
| 26 |
-
|
| 27 |
def fix_transcription_casing(input_sentence):
    """Capitalize the first character of every sentence in ``input_sentence``.

    The text is split into sentences with ``nltk.sent_tokenize`` and the
    re-cased sentences are joined back together with single spaces.
    """
    recased = []
    for sentence in nltk.sent_tokenize(input_sentence):
        # Only the leading character is touched; the rest is kept verbatim.
        recased.append(sentence[0].capitalize() + sentence[1:])
    return ' '.join(recased)
|
|
@@ -41,10 +40,27 @@ def predict_and_decode(input_file):
|
|
| 41 |
transcribed_text = fix_transcription_casing(pred.lower())
|
| 42 |
|
| 43 |
return transcribed_text
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
gr.Interface(predict_and_decode,
|
| 46 |
inputs = gr.inputs.Audio(source="microphone", type="filepath", optional=True, label="Record/ Drop audio"),
|
| 47 |
-
outputs = gr.outputs.Textbox(label="
|
| 48 |
title="ASR using Wav2Vec 2.0 & pyctcdecode",
|
| 49 |
description = "Extending HF ASR models with pyctcdecode decoder",
|
| 50 |
layout = "horizontal",
|
|
|
|
| 22 |
if sample_rate !=16000:
|
| 23 |
speech = librosa.resample(speech, sample_rate,16000)
|
| 24 |
return speech
|
| 25 |
+
|
|
|
|
| 26 |
def fix_transcription_casing(input_sentence):
    """Capitalize the first character of every sentence in ``input_sentence``.

    The text is split into sentences with ``nltk.sent_tokenize`` and the
    re-cased sentences are joined back together with single spaces.
    """
    recased = []
    for sentence in nltk.sent_tokenize(input_sentence):
        # Only the leading character is touched; the rest is kept verbatim.
        recased.append(sentence[0].capitalize() + sentence[1:])
    return ' '.join(recased)
|
|
|
|
| 40 |
transcribed_text = fix_transcription_casing(pred.lower())
|
| 41 |
|
| 42 |
return transcribed_text
|
| 43 |
+
|
| 44 |
+
def predict_and_greedy_decode(input_file):
    """Transcribe an audio input using greedy (arg-max) decoding of the logits.

    Args:
        input_file: Audio input, forwarded unchanged to ``load_and_fix_data``.

    Returns:
        The decoded text, lower-cased and then passed through
        ``fix_transcription_casing``.
    """
    speech = load_and_fix_data(input_file)

    input_values = processor(speech, return_tensors="pt", sampling_rate=16000).input_values

    # Inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(input_values).logits

    predicted_ids = torch.argmax(logits, dim=-1)

    # batch_decode returns a list of strings (one per batch row), so calling
    # .lower() on it directly raised AttributeError. Presumably a single clip
    # is processed per call, so the first entry is used -- TODO confirm.
    pred = processor.batch_decode(predicted_ids)[0]

    transcribed_text = fix_transcription_casing(pred.lower())

    return transcribed_text
|
| 56 |
+
|
| 57 |
+
def return_all_predictions(input_file):
    """Run both decoding paths on ``input_file`` and return their results.

    Returns:
        A 2-tuple: the result of ``predict_and_decode`` followed by the
        result of ``predict_and_greedy_decode``, in that order.
    """
    decoded_text = predict_and_decode(input_file)
    greedy_text = predict_and_greedy_decode(input_file)
    return decoded_text, greedy_text
|
| 59 |
+
|
| 60 |
+
|
| 61 |
gr.Interface(predict_and_decode,
|
| 62 |
inputs = gr.inputs.Audio(source="microphone", type="filepath", optional=True, label="Record/ Drop audio"),
|
| 63 |
+
outputs = [gr.outputs.Textbox(label="Beam CTC Decoding"), gr.outputs.Textbox(label="Greedy Decoding")],
|
| 64 |
title="ASR using Wav2Vec 2.0 & pyctcdecode",
|
| 65 |
description = "Extending HF ASR models with pyctcdecode decoder",
|
| 66 |
layout = "horizontal",
|