Upload 6 files
- infer.py +153 -0
- last_checkpoint.pt +3 -0
- model.py +83 -0
- phoneme_to_id.json +79 -0
- requirements.txt +11 -0
- utils.py +133 -0
infer.py
ADDED
@@ -0,0 +1,153 @@
import gradio as gr
import torch
import json
import numpy as np
import os
from datetime import datetime
from model import Image2Phoneme
from utils import ctc_post_process, audio_to_mel, mel_to_image, text_to_phonemes
import soundfile as sf
import shutil
import pronouncing
import time

# Configuration
DEVICE = torch.device("cpu")
PHMAP = "phoneme_to_id.json"
AUDIO_DIR = "audio_inputs"

# Ensure audio directory exists
os.makedirs(AUDIO_DIR, exist_ok=True)

# Load phoneme vocabulary
try:
    vocab = json.load(open(PHMAP, "r"))
    id_to_ph = {v: k for k, v in vocab.items()}
except FileNotFoundError:
    raise FileNotFoundError(f"Phoneme mapping file not found at {PHMAP}")

# Build model
vocab_size = max(vocab.values()) + 1
model = Image2Phoneme(vocab_size=vocab_size).to(DEVICE)
try:
    ckpt = torch.load("last_checkpoint.pt", map_location=DEVICE, weights_only=True)
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()
except FileNotFoundError:
    raise FileNotFoundError("Checkpoint file not found at last_checkpoint.pt")


def process_audio(audio_input):
    """Process audio to predict phonemes and display the mel spectrogram."""
    try:
        print(f"Received audio_input before processing: {audio_input}")
        # Generate a unique filename based on the timestamp
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        audio_path = os.path.join(AUDIO_DIR, f"input_{timestamp}.wav")

        # Handle audio input
        if audio_input is None:
            print("Audio input is None after stopping recording")
            return {"error": "No audio input provided"}, None, None, None

        if isinstance(audio_input, str):
            # File upload: copy the uploaded file to audio_inputs/
            print(f"Processing uploaded file: {audio_input}")
            if not os.path.exists(audio_input):
                return {"error": f"Uploaded file not found: {audio_input}"}, None, None, None
            if audio_input.endswith(".mp3"):
                print("Converting .mp3 to .wav")
                from pydub import AudioSegment
                audio = AudioSegment.from_mp3(audio_input)
                audio_path = audio_path.replace(".wav", "_converted.wav")
                audio.export(audio_path, format="wav")
                print(f"Converted file saved to: {audio_path}")
            else:
                shutil.copy(audio_input, audio_path)
                print(f"Copied file to: {audio_path}")
        else:
            # Microphone input: (sample_rate, audio_data)
            print("Processing microphone input")
            sample_rate, audio_data = audio_input
            print(f"Sample rate: {sample_rate}, Audio data shape: {audio_data.shape if hasattr(audio_data, 'shape') else 'None'}")
            if audio_data is None or len(audio_data) == 0:
                print("Microphone audio data is empty or invalid")
                return {"error": "Microphone input data is empty or invalid"}, None, None, None
            # Add a small delay to ensure audio data is fully captured
            time.sleep(1)
            sf.write(audio_path, audio_data, sample_rate)
            print(f"Saved microphone audio to: {audio_path}")
            # Verify the file exists
            if not os.path.exists(audio_path):
                print(f"Failed to save audio file at: {audio_path}")
                return {"error": "Failed to save recorded audio file"}, None, None, None

        # Process audio to mel spectrogram
        mel_path = audio_to_mel(audio_path)
        print(f"Generated mel spectrogram: {mel_path}")
        if not os.path.exists(mel_path):
            return {"error": f"Mel spectrogram file not found: {mel_path}"}, None, None, None

        mel_image_path = mel_to_image(mel_path)
        print(f"Generated mel spectrogram image: {mel_image_path}")
        if not os.path.exists(mel_image_path):
            return {"error": f"Mel spectrogram image not found: {mel_image_path}"}, None, None, None

        # Load mel spectrogram
        mel = np.load(mel_path)  # shape (n_mels, T)
        print(f"Loaded mel spectrogram shape: {mel.shape}")
        mel_tensor = torch.tensor(mel).unsqueeze(0).to(DEVICE)  # add batch dim
        mel_lens = torch.tensor([mel.shape[1]]).to(DEVICE)

        # Predict phonemes
        with torch.no_grad():
            ph_pred = model(mel_tensor)  # shape (B, seq_len, vocab_size)
            ph_ids = ph_pred.argmax(-1)[0].cpu().numpy()  # pick first batch
        print(f"Predicted phoneme IDs: {ph_ids}")

        # Convert IDs to phonemes (id 0 is <PAD>)
        ph_seq = [id_to_ph[i] for i in ph_ids if i > 0]
        print(f"Raw phonemes: {ph_seq}")

        # Post-process phonemes
        post_processed = ctc_post_process(ph_seq)
        print(f"Post-processed phonemes: {post_processed}")

        # Return results
        return {
            "audio_path": audio_path,
            "phonemes": " ".join(ph_seq),
            "post_processed_phonemes": " ".join(post_processed)
        }, mel_image_path, " ".join(ph_seq), " ".join(post_processed)
    except Exception as e:
        print(f"Error in process_audio: {str(e)}")
        return {"error": f"Processing failed: {str(e)}"}, None, None, None


# Gradio interface
with gr.Blocks() as iface:
    gr.Markdown("# Speech to Phonemes Converter")
    gr.Markdown("Record or upload audio to predict phonemes and display the mel spectrogram. Paste input text if available.")

    audio_input = gr.Audio(sources=["upload"], type="filepath", label="Upload Audio (.wav or .mp3)", interactive=True)
    text_input = gr.Textbox(label="Enter Text", placeholder="Type a sentence to convert to phonemes")
    process_button = gr.Button("Process")

    audio_output = gr.JSON(label="Audio Processing Results (Audio Path, Phonemes, Post-Processed Phonemes)")
    mel_image = gr.Image(label="Mel Spectrogram", type="filepath")
    raw_phonemes = gr.Textbox(label="Raw Phonemes")
    post_processed_phonemes = gr.Textbox(label="Post-Processed Phonemes")
    text_output = gr.JSON(label="Text-to-Phoneme Results")

    def process(audio_input, text_input):
        print(f"Processing inputs - Audio: {audio_input}, Text: {text_input}")
        audio_result, mel_image_path, raw_ph, post_ph = process_audio(audio_input) if audio_input else ({}, None, None, None)
        text_result = text_to_phonemes(text_input) if text_input else {}
        return audio_result, mel_image_path, raw_ph, post_ph, text_result

    process_button.click(
        fn=process,
        inputs=[audio_input, text_input],
        outputs=[audio_output, mel_image, raw_phonemes, post_processed_phonemes, text_output]
    )

if __name__ == "__main__":
    iface.launch(debug=True)
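For a quick local smoke test, the prediction path can be exercised without launching the Gradio UI by calling process_audio directly. The sketch below is not part of the upload; the sample.wav filename is only illustrative, and importing infer already loads the checkpoint and phoneme map.

# smoke_test.py (hypothetical) -- exercise the prediction path without the UI
from infer import process_audio

result, mel_png, raw_ph, clean_ph = process_audio("sample.wav")  # path input, same as an upload
print(result)    # dict with audio_path / phonemes / post_processed_phonemes, or an "error" key
print(mel_png)   # path to the rendered mel-spectrogram image
print(clean_ph)  # phoneme string after ctc_post_process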
last_checkpoint.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:94001abf3c674de3828f3aaf00ffea7964c8a85ee1942988463d1244cc33e978
size 13410944
model.py
ADDED
@@ -0,0 +1,83 @@
# model_img2ph.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class CNNEncoder(nn.Module):
    def __init__(self, in_channels=1, hidden_dim=256, dropout=0.2):
        super().__init__()
        # Convolutions mostly reduce the frequency dimension, not time
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=(2, 1), padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Dropout(dropout),

            nn.Conv2d(64, 128, kernel_size=3, stride=(2, 1), padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Dropout(dropout),

            nn.Conv2d(128, 256, kernel_size=3, stride=(2, 1), padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Dropout(dropout),

            nn.Conv2d(256, hidden_dim, kernel_size=3, stride=(2, 1), padding=1),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # x: (B, n_mels, T)
        x = x.unsqueeze(1)   # (B, 1, n_mels, T)
        feat = self.conv(x)  # (B, C, H', T)
        B, C, H, T = feat.size()
        # collapse frequency into features, keep time intact
        feat = feat.permute(0, 3, 1, 2).contiguous()  # (B, T, C, H)
        feat = feat.view(B, T, C * H)                 # (B, T, features)
        return feat


class PhonemeDecoder(nn.Module):
    def __init__(self, vocab_size, enc_dim=128 * 5, rnn_hidden=128, num_layers=2, dropout=0.3):
        super().__init__()
        self.rnn = nn.GRU(
            enc_dim, rnn_hidden,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=False  # Changed to unidirectional
        )
        self.proj = nn.Linear(rnn_hidden, 256)  # Single projection layer
        self.norm = nn.LayerNorm(256)           # Added LayerNorm
        self.fc_out = nn.Linear(256, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, enc_out):
        rnn_out, _ = self.rnn(enc_out)       # (B, T, rnn_hidden)
        dense_out = self.proj(rnn_out)       # (B, T, 256)
        dense_out = self.norm(dense_out)     # Normalize
        dense_out = F.relu(dense_out)        # Activation
        dense_out = self.dropout(dense_out)  # Dropout after activation
        logits = self.fc_out(dense_out)      # (B, T, vocab_size)
        return logits


class Image2Phoneme(nn.Module):
    def __init__(self, vocab_size, in_channels=1, enc_hidden=128, rnn_hidden=128):
        super().__init__()
        self.encoder = CNNEncoder(in_channels=in_channels, hidden_dim=enc_hidden)
        # enc_dim = enc_hidden * H'; after the convs H' ≈ 5 (80 mel bins, stride=(2,1) applied 4 times → 80/16 = 5)
        enc_dim = enc_hidden * 5
        self.decoder = PhonemeDecoder(vocab_size, enc_dim=enc_dim, rnn_hidden=rnn_hidden)

    def forward(self, mels):
        # mels: (B, n_mels, T)
        enc_out = self.encoder(mels)    # (B, T, enc_dim)
        logits = self.decoder(enc_out)  # (B, T, vocab_size)
        return logits
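As a quick shape sanity check (a sketch, not part of the upload): with the defaults above, the encoder halves the frequency axis four times (80 → 5 bins) while keeping the time axis, and the decoder emits one logit vector per frame over the 77-token vocabulary used in this Space.

import torch
from model import Image2Phoneme

model = Image2Phoneme(vocab_size=77)  # ids 1..76 plus <PAD> = 0
mels = torch.randn(2, 80, 400)        # (B, n_mels, T): 2 clips, 80 mel bins, 400 frames
with torch.no_grad():
    logits = model(mels)
print(logits.shape)                   # expected: torch.Size([2, 400, 77])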
phoneme_to_id.json
ADDED
@@ -0,0 +1,79 @@
{
  "!": 1,
  "'": 2,
  ",": 3,
  "-": 4,
  ".": 5,
  "..": 6,
  "?": 7,
  "AA0": 8,
  "AA1": 9,
  "AA2": 10,
  "AE0": 11,
  "AE1": 12,
  "AE2": 13,
  "AH0": 14,
  "AH1": 15,
  "AH2": 16,
  "AO0": 17,
  "AO1": 18,
  "AO2": 19,
  "AW0": 20,
  "AW1": 21,
  "AW2": 22,
  "AY0": 23,
  "AY1": 24,
  "AY2": 25,
  "B": 26,
  "CH": 27,
  "D": 28,
  "DH": 29,
  "EH0": 30,
  "EH1": 31,
  "EH2": 32,
  "ER0": 33,
  "ER1": 34,
  "ER2": 35,
  "EY0": 36,
  "EY1": 37,
  "EY2": 38,
  "F": 39,
  "G": 40,
  "HH": 41,
  "IH0": 42,
  "IH1": 43,
  "IH2": 44,
  "IY0": 45,
  "IY1": 46,
  "IY2": 47,
  "JH": 48,
  "K": 49,
  "L": 50,
  "M": 51,
  "N": 52,
  "NG": 53,
  "OW0": 54,
  "OW1": 55,
  "OW2": 56,
  "OY0": 57,
  "OY1": 58,
  "OY2": 59,
  "P": 60,
  "R": 61,
  "S": 62,
  "SH": 63,
  "T": 64,
  "TH": 65,
  "UH0": 66,
  "UH1": 67,
  "UH2": 68,
  "UW0": 69,
  "UW1": 70,
  "UW2": 71,
  "V": 72,
  "W": 73,
  "Y": 74,
  "Z": 75,
  "ZH": 76,
  "<PAD>": 0
}
requirements.txt
ADDED
@@ -0,0 +1,11 @@
torch>=2.0
torchaudio
librosa
numpy
pandas
g2p-en
soundfile
tqdm
nltk
pronouncing
gradio
utils.py
ADDED
@@ -0,0 +1,133 @@
# utils.py
import numpy as np
import librosa
from matplotlib import pyplot as plt

SR = 22050
HOP_LENGTH = 256


def mel_to_audio(mel_db, sr=22050, n_fft=1024, hop_length=256, win_length=1024, n_iter=60):
    # mel_db: (n_mels, T) in dB (as saved by preprocessing)
    S = librosa.db_to_power(mel_db)
    # invert mel to linear spectrogram
    mel_basis = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=S.shape[0])
    # Approximate inverse using the pseudo-inverse
    inv_mel = np.maximum(1e-10, np.linalg.pinv(mel_basis).dot(S))
    # Griffin-Lim
    audio = librosa.griffinlim(inv_mel, n_iter=n_iter, hop_length=hop_length, win_length=win_length)
    return audio


from g2p_en import G2p
g2p = G2p()


def text_to_phonemes(text):
    ph = g2p(text)
    # Keep ARPAbet tokens (letters plus an optional stress digit); drop the space/punctuation tokens g2p_en emits
    ph = [p for p in ph if any(c.isalpha() for c in p)]
    return " ".join(ph)


import librosa
import numpy as np
import os


def audio_to_mel(audio_path, save_dir="mels", sr=22050, n_fft=1024, hop_length=256, win_length=1024, n_mels=80):
    # Load audio
    y, _ = librosa.load(audio_path, sr=sr)

    # Compute STFT magnitude
    S = np.abs(librosa.stft(y, n_fft=n_fft, hop_length=hop_length, win_length=win_length))

    # Convert to mel spectrogram
    mel_basis = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels)
    mel = np.dot(mel_basis, S)

    # Convert to dB
    mel_db = librosa.power_to_db(mel)

    # Make sure the save directory exists
    os.makedirs(save_dir, exist_ok=True)

    # Save mel as a .npy file
    base_name = os.path.splitext(os.path.basename(audio_path))[0]
    mel_path = os.path.join(save_dir, base_name + "_mel.npy")
    np.save(mel_path, mel_db)

    return mel_path


def ctc_post_process(phonemes):
    """
    Collapse repeats + remove blanks ('-') in CTC output.
    phonemes: list of predicted phoneme tokens
    """
    new_seq = []
    prev = None
    for p in phonemes:
        if p == "-" or p == prev:
            continue
        new_seq.append(p)
        prev = p
    return new_seq


import numpy as np
import matplotlib.pyplot as plt
import librosa.display


def mel_to_image(mel_path, sr=22050, hop_length=256, save_fig=True):
    # Load mel spectrogram from .npy
    mel_db = np.load(mel_path)

    # Create figure
    plt.figure(figsize=(14, 6))

    # Plot mel spectrogram
    librosa.display.specshow(mel_db, sr=sr, hop_length=hop_length, x_axis='time', y_axis='mel', cmap='magma')
    plt.title("Mel Spectrogram (dB)")
    plt.colorbar(format="%+2.0f dB")

    save_path = mel_path.replace('.npy', '_mel.png')
    plt.savefig(save_path)
    plt.close()  # close the figure so repeated requests do not accumulate open figures
    print(f"Saved mel spectrogram image at: {save_path}")

    # plt.show()
    return save_path


# Load reverse lexicon: phoneme_seq -> [words]
import nltk
from collections import defaultdict

nltk.download('cmudict')
arpabet = nltk.corpus.cmudict.dict()

# Build reverse lexicon
reverse_lex = defaultdict(list)
for word, pron_list in arpabet.items():
    for pron in pron_list:
        reverse_lex[tuple(pron)].append(word)


def split_on_boundaries(phoneme_stream, boundary_token="<w>"):
    """Split on a special token representing word boundaries."""
    words = []
    current = []
    for phon in phoneme_stream:
        if phon == boundary_token:
            if current:
                words.append(current)
            current = []
        else:
            current.append(phon)
    if current:
        words.append(current)
    return words


def p2g_fallback(phoneme_word):
    # Placeholder for fallback pronunciation-to-spelling
    return "".join(phoneme_word).lower()


def phonemes_to_text(phoneme_stream):
    words = []
    for phoneme_word in split_on_boundaries(phoneme_stream):
        candidates = reverse_lex.get(tuple(phoneme_word), [])
        if candidates:
            words.append(candidates[0])
        else:
            words.append(p2g_fallback(phoneme_word))
    return " ".join(words)
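For reference, a minimal offline sketch of how these helpers chain together (not part of the upload). The example.wav filename is illustrative, and the printed phoneme/word strings depend on g2p_en and CMUdict, so they are shown as examples only.

from utils import audio_to_mel, mel_to_image, ctc_post_process, text_to_phonemes, phonemes_to_text

mel_path = audio_to_mel("example.wav")   # writes mels/example_mel.npy and returns its path
png_path = mel_to_image(mel_path)        # writes mels/example_mel_mel.png for display

# CTC-style cleanup: blanks ('-') dropped, consecutive repeats collapsed
print(ctc_post_process(["HH", "HH", "-", "AH0", "AH0", "L", "OW1"]))  # ['HH', 'AH0', 'L', 'OW1']

# text -> phonemes via g2p_en, and phonemes -> text via the CMUdict reverse lexicon
print(text_to_phonemes("hello world"))                              # e.g. "HH AH0 L OW1 W ER1 L D"
print(phonemes_to_text(["K", "AE1", "T", "<w>", "S", "AE1", "T"]))  # first CMUdict match per word, e.g. "cat sat"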