hash-map committed on
Commit 4ffa9fc · verified · 1 Parent(s): 429e785

Upload 6 files

Files changed (6)
  1. infer.py +153 -0
  2. last_checkpoint.pt +3 -0
  3. model.py +83 -0
  4. phoneme_to_id.json +79 -0
  5. requirements.txt +11 -0
  6. utils.py +133 -0
infer.py ADDED
@@ -0,0 +1,153 @@
+ import gradio as gr
+ import torch
+ import json
+ import numpy as np
+ import os
+ from datetime import datetime
+ from model import Image2Phoneme
+ from utils import ctc_post_process, audio_to_mel, mel_to_image, text_to_phonemes
+ import soundfile as sf
+ import shutil
+ import pronouncing
+ import time
+
+ # Configuration
+ DEVICE = torch.device("cpu")
+ PHMAP = "phoneme_to_id.json"
+ AUDIO_DIR = "audio_inputs"
+
+ # Ensure audio directory exists
+ os.makedirs(AUDIO_DIR, exist_ok=True)
+
+ # Load phoneme vocabulary
+ try:
+     vocab = json.load(open(PHMAP, "r"))
+     id_to_ph = {v: k for k, v in vocab.items()}
+ except FileNotFoundError:
+     raise FileNotFoundError(f"Phoneme mapping file not found at {PHMAP}")
+
+ # Build model
+ vocab_size = max(vocab.values()) + 1
+ model = Image2Phoneme(vocab_size=vocab_size).to(DEVICE)
+ try:
+     ckpt = torch.load("last_checkpoint.pt", map_location=DEVICE, weights_only=True)
+     model.load_state_dict(ckpt["model_state_dict"])
+     model.eval()
+ except FileNotFoundError:
+     raise FileNotFoundError("Checkpoint file not found at last_checkpoint.pt")
+
+ def process_audio(audio_input):
+     """Process audio to predict phonemes and display the mel spectrogram."""
+     try:
+         print(f"Received audio_input before processing: {audio_input}")
+         # Generate unique filename based on timestamp
+         timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+         audio_path = os.path.join(AUDIO_DIR, f"input_{timestamp}.wav")
+
+         # Handle audio input
+         if audio_input is None:
+             print("Audio input is None after stopping recording")
+             return {"error": "No audio input provided"}, None, None, None
+
+         if isinstance(audio_input, str):
+             # File upload: copy the uploaded file to audio_inputs/
+             print(f"Processing uploaded file: {audio_input}")
+             if not os.path.exists(audio_input):
+                 return {"error": f"Uploaded file not found: {audio_input}"}, None, None, None
+             if audio_input.endswith(".mp3"):
+                 print("Converting .mp3 to .wav")
+                 from pydub import AudioSegment
+                 audio = AudioSegment.from_mp3(audio_input)
+                 audio_path = audio_path.replace(".wav", "_converted.wav")
+                 audio.export(audio_path, format="wav")
+                 print(f"Converted file saved to: {audio_path}")
+             else:
+                 shutil.copy(audio_input, audio_path)
+                 print(f"Copied file to: {audio_path}")
+         else:
+             # Microphone input: (sample_rate, audio_data)
+             print("Processing microphone input")
+             sample_rate, audio_data = audio_input
+             print(f"Sample rate: {sample_rate}, Audio data shape: {audio_data.shape if hasattr(audio_data, 'shape') else 'None'}")
+             if audio_data is None or len(audio_data) == 0:
+                 print("Microphone audio data is empty or invalid")
+                 return {"error": "Microphone input data is empty or invalid"}, None, None, None
+             # Add a small delay to ensure audio data is fully captured
+             time.sleep(1)
+             sf.write(audio_path, audio_data, sample_rate)
+             print(f"Saved microphone audio to: {audio_path}")
+             # Verify the file exists
+             if not os.path.exists(audio_path):
+                 print(f"Failed to save audio file at: {audio_path}")
+                 return {"error": "Failed to save recorded audio file"}, None, None, None
+
+         # Process audio to mel spectrogram
+         mel_path = audio_to_mel(audio_path)
+         print(f"Generated mel spectrogram: {mel_path}")
+         if not os.path.exists(mel_path):
+             return {"error": f"Mel spectrogram file not found: {mel_path}"}, None, None, None
+
+         mel_image_path = mel_to_image(mel_path)
+         print(f"Generated mel spectrogram image: {mel_image_path}")
+         if not os.path.exists(mel_image_path):
+             return {"error": f"Mel spectrogram image not found: {mel_image_path}"}, None, None, None
+
+         # Load mel spectrogram
+         mel = np.load(mel_path)  # shape (n_mels, T)
+         print(f"Loaded mel spectrogram shape: {mel.shape}")
+         mel_tensor = torch.tensor(mel).unsqueeze(0).to(DEVICE)  # add batch dim
+         mel_lens = torch.tensor([mel.shape[1]]).to(DEVICE)
+
+         # Predict phonemes
+         with torch.no_grad():
+             ph_pred = model(mel_tensor)  # shape (B, seq_len, vocab_size)
+         ph_ids = ph_pred.argmax(-1)[0].cpu().numpy()  # pick first batch
+         print(f"Predicted phoneme IDs: {ph_ids}")
+
+         # Convert IDs to phonemes
+         ph_seq = [id_to_ph[i] for i in ph_ids if i > 0]
+         print(f"Raw phonemes: {ph_seq}")
+
+         # Post-process phonemes
+         post_processed = ctc_post_process(ph_seq)
+         print(f"Post-processed phonemes: {post_processed}")
+
+         # Return results
+         return {
+             "audio_path": audio_path,
+             "phonemes": " ".join(ph_seq),
+             "post_processed_phonemes": " ".join(post_processed)
+         }, mel_image_path, " ".join(ph_seq), " ".join(post_processed)
+     except Exception as e:
+         print(f"Error in process_audio: {str(e)}")
+         return {"error": f"Processing failed: {str(e)}"}, None, None, None
+
+ # Gradio interface
+ with gr.Blocks() as iface:
+     gr.Markdown("# Speech to Phonemes Converter")
+     gr.Markdown("Record or upload audio to predict phonemes and display the mel spectrogram. Paste input text if available.")
+
+     audio_input = gr.Audio(sources=["upload"], type="filepath", label="Upload Audio (.wav or .mp3)", interactive=True)
+     text_input = gr.Textbox(label="Enter Text", placeholder="Type a sentence to convert to phonemes")
+     process_button = gr.Button("Process")
+
+     audio_output = gr.JSON(label="Audio Processing Results (Audio Path, Phonemes, Post-Processed Phonemes)")
+     mel_image = gr.Image(label="Mel Spectrogram", type="filepath")
+     raw_phonemes = gr.Textbox(label="Raw Phonemes")
+     post_processed_phonemes = gr.Textbox(label="Post-Processed Phonemes")
+     text_output = gr.JSON(label="Text-to-Phoneme Results")
+
+     def process(audio_input, text_input):
+         print(f"Processing inputs - Audio: {audio_input}, Text: {text_input}")
+         audio_result, mel_image_path, raw_ph, post_ph = process_audio(audio_input) if audio_input else ({}, None, None, None)
+         text_result = text_to_phonemes(text_input) if text_input else {}
+         return audio_result, mel_image_path, raw_ph, post_ph, text_result
+
+     process_button.click(
+         fn=process,
+         inputs=[audio_input, text_input],
+         outputs=[audio_output, mel_image, raw_phonemes, post_processed_phonemes, text_output]
+     )
+
+ if __name__ == "__main__":
+     iface.launch(debug=True)
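
A quick way to exercise this inference path without launching the Gradio UI is to call process_audio directly. The sketch below is illustrative only and is not part of the commit: it assumes it is run from the repository root (so phoneme_to_id.json and last_checkpoint.pt resolve) and that a recording exists at the hypothetical path sample.wav.

    # sketch only: sample.wav is a hypothetical local recording
    from infer import process_audio

    result, mel_image_path, raw_ph, post_ph = process_audio("sample.wav")
    print(result)           # dict with audio_path, phonemes, post_processed_phonemes
    print(mel_image_path)   # path of the saved mel-spectrogram .png under mels/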
last_checkpoint.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:94001abf3c674de3828f3aaf00ffea7964c8a85ee1942988463d1244cc33e978
+ size 13410944
model.py ADDED
@@ -0,0 +1,83 @@
+ # model_img2ph.py
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class CNNEncoder(nn.Module):
+     def __init__(self, in_channels=1, hidden_dim=256, dropout=0.2):
+         super().__init__()
+         # Convolutions mostly reduce frequency dimension, not time
+         self.conv = nn.Sequential(
+             nn.Conv2d(in_channels, 64, kernel_size=3, stride=(2,1), padding=1),
+             nn.BatchNorm2d(64),
+             nn.ReLU(),
+             nn.Dropout(dropout),
+
+             nn.Conv2d(64, 128, kernel_size=3, stride=(2,1), padding=1),
+             nn.BatchNorm2d(128),
+             nn.ReLU(),
+             nn.Dropout(dropout),
+
+             nn.Conv2d(128, 256, kernel_size=3, stride=(2,1), padding=1),
+             nn.BatchNorm2d(256),
+             nn.ReLU(),
+             nn.Dropout(dropout),
+
+             nn.Conv2d(256, hidden_dim, kernel_size=3, stride=(2,1), padding=1),
+             nn.BatchNorm2d(hidden_dim),
+             nn.ReLU(),
+             nn.Dropout(dropout),
+         )
+
+     def forward(self, x):
+         # x: (B, n_mels, T)
+         x = x.unsqueeze(1)  # (B, 1, n_mels, T)
+         feat = self.conv(x)  # (B, C, H', T)
+         B, C, H, T = feat.size()
+         # collapse frequency into features, keep time intact
+         feat = feat.permute(0, 3, 1, 2).contiguous()  # (B, T, C, H)
+         feat = feat.view(B, T, C * H)  # (B, T, features)
+         return feat
+
+
+ class PhonemeDecoder(nn.Module):
+     def __init__(self, vocab_size, enc_dim=128*5, rnn_hidden=128, num_layers=2, dropout=0.3):
+         super().__init__()
+         self.rnn = nn.GRU(
+             enc_dim, rnn_hidden,
+             num_layers=num_layers,
+             batch_first=True,
+             dropout=dropout,
+             bidirectional=False  # Changed to unidirectional
+         )
+         self.proj = nn.Linear(rnn_hidden, 256)  # Single projection layer
+         self.norm = nn.LayerNorm(256)  # Added LayerNorm
+         self.fc_out = nn.Linear(256, vocab_size)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, enc_out):
+         rnn_out, _ = self.rnn(enc_out)  # (B, T, rnn_hidden)
+         dense_out = self.proj(rnn_out)  # (B, T, 256)
+         dense_out = self.norm(dense_out)  # Normalize
+         dense_out = F.relu(dense_out)  # Activation
+         dense_out = self.dropout(dense_out)  # Dropout after activation
+         logits = self.fc_out(dense_out)  # (B, T, vocab_size)
+         return logits
+
+
+
+
+ class Image2Phoneme(nn.Module):
+     def __init__(self, vocab_size, in_channels=1, enc_hidden=128, rnn_hidden=128):
+         super().__init__()
+         self.encoder = CNNEncoder(in_channels=in_channels, hidden_dim=enc_hidden)
+         # enc_dim = enc_hidden * H'; after the convs H' = 5 (input mel = 80, stride (2,1) applied 4 times -> 80/16 = 5)
+         enc_dim = enc_hidden * 5
+         self.decoder = PhonemeDecoder(vocab_size, enc_dim=enc_dim, rnn_hidden=rnn_hidden)
+
+     def forward(self, mels):
+         # mels: (B, n_mels, T)
+         enc_out = self.encoder(mels)  # (B, T, enc_dim)
+         logits = self.decoder(enc_out)  # (B, T, vocab_size)
+         return logits
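
The enc_dim comment above can be sanity-checked with a dummy forward pass. The sketch below is only an illustrative shape check, not part of the commit; it assumes the default constructor arguments and a vocabulary of 77 ids (the maximum id in phoneme_to_id.json plus one).

    import torch
    from model import Image2Phoneme

    model = Image2Phoneme(vocab_size=77)   # 77 = max id in phoneme_to_id.json + 1
    mels = torch.randn(2, 80, 200)         # (B, n_mels, T): 80 mel bins, 200 frames
    logits = model(mels)
    print(logits.shape)                    # expected torch.Size([2, 200, 77]): time preserved, 80 -> 5 in frequency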
phoneme_to_id.json ADDED
@@ -0,0 +1,79 @@
+ {
+ "!": 1,
+ "'": 2,
+ ",": 3,
+ "-": 4,
+ ".": 5,
+ "..": 6,
+ "?": 7,
+ "AA0": 8,
+ "AA1": 9,
+ "AA2": 10,
+ "AE0": 11,
+ "AE1": 12,
+ "AE2": 13,
+ "AH0": 14,
+ "AH1": 15,
+ "AH2": 16,
+ "AO0": 17,
+ "AO1": 18,
+ "AO2": 19,
+ "AW0": 20,
+ "AW1": 21,
+ "AW2": 22,
+ "AY0": 23,
+ "AY1": 24,
+ "AY2": 25,
+ "B": 26,
+ "CH": 27,
+ "D": 28,
+ "DH": 29,
+ "EH0": 30,
+ "EH1": 31,
+ "EH2": 32,
+ "ER0": 33,
+ "ER1": 34,
+ "ER2": 35,
+ "EY0": 36,
+ "EY1": 37,
+ "EY2": 38,
+ "F": 39,
+ "G": 40,
+ "HH": 41,
+ "IH0": 42,
+ "IH1": 43,
+ "IH2": 44,
+ "IY0": 45,
+ "IY1": 46,
+ "IY2": 47,
+ "JH": 48,
+ "K": 49,
+ "L": 50,
+ "M": 51,
+ "N": 52,
+ "NG": 53,
+ "OW0": 54,
+ "OW1": 55,
+ "OW2": 56,
+ "OY0": 57,
+ "OY1": 58,
+ "OY2": 59,
+ "P": 60,
+ "R": 61,
+ "S": 62,
+ "SH": 63,
+ "T": 64,
+ "TH": 65,
+ "UH0": 66,
+ "UH1": 67,
+ "UH2": 68,
+ "UW0": 69,
+ "UW1": 70,
+ "UW2": 71,
+ "V": 72,
+ "W": 73,
+ "Y": 74,
+ "Z": 75,
+ "ZH": 76,
+ "<PAD>": 0
+ }
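
The mapping reserves id 0 for <PAD> and assigns ids 1-76 to punctuation marks and stress-marked ARPAbet phonemes, 77 ids in total. Inverting it for decoding (as infer.py does) is a one-liner; a small illustrative snippet, assuming the file sits in the working directory:

    import json

    with open("phoneme_to_id.json") as f:
        vocab = json.load(f)

    id_to_ph = {v: k for k, v in vocab.items()}
    print(max(vocab.values()) + 1)   # 77, used as the model's vocab_size
    print(id_to_ph[9])               # "AA1"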
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ torch>=2.0
+ torchaudio
+ librosa
+ numpy
+ pandas
+ g2p-en
+ soundfile
+ tqdm
+ nltk
+ pronouncing
+ gradio
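
Installing with pip install -r requirements.txt should cover most of the code above; note, though, that matplotlib (imported by utils.py) and pydub (imported by infer.py for .mp3 conversion) are not listed here and may need to be installed separately.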
utils.py ADDED
@@ -0,0 +1,133 @@
+ # utils.py
+ import numpy as np
+ import librosa
+ from matplotlib import pyplot as plt
+ SR = 22050
+ HOP_LENGTH = 256
+ def mel_to_audio(mel_db, sr=22050, n_fft=1024, hop_length=256, win_length=1024, n_iter=60):
+     # mel_db: (n_mels, T) in dB (as saved by preprocessing)
+     S = librosa.db_to_power(mel_db)
+     # invert mel to linear spectrogram
+     mel_basis = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=S.shape[0])
+     # Approximate inverse using the pseudo-inverse
+     inv_mel = np.maximum(1e-10, np.linalg.pinv(mel_basis).dot(S))
+     # Griffin-Lim
+     audio = librosa.griffinlim(inv_mel, n_iter=n_iter, hop_length=hop_length, win_length=win_length)
+     return audio
+
+ from g2p_en import G2p
+ g2p = G2p()
+
+ def text_to_phonemes(text):
+     ph = g2p(text)
+     # Remove spaces/punct tokens produced by g2p_en (keep stress-marked phonemes like "AH0")
+     ph = [p for p in ph if p and p[0].isalpha()]
+     return " ".join(ph)
+
+ import librosa
+ import numpy as np
+ import os
+
+ def audio_to_mel(audio_path, save_dir="mels", sr=22050, n_fft=1024, hop_length=256, win_length=1024, n_mels=80):
+     # Load audio
+     y, _ = librosa.load(audio_path, sr=sr)
+
+     # Compute STFT magnitude
+     S = np.abs(librosa.stft(y, n_fft=n_fft, hop_length=hop_length, win_length=win_length))
+
+     # Convert to mel spectrogram
+     mel_basis = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels)
+     mel = np.dot(mel_basis, S)
+
+     # Convert to dB
+     mel_db = librosa.power_to_db(mel)
+
+     # Make sure save directory exists
+     os.makedirs(save_dir, exist_ok=True)
+
+     # Save mel as .npy file
+     base_name = os.path.splitext(os.path.basename(audio_path))[0]
+     mel_path = os.path.join(save_dir, base_name + "_mel.npy")
+     np.save(mel_path, mel_db)
+
+     return mel_path
+
+
+ def ctc_post_process(phonemes):
+     """
+     Collapse repeats + remove blanks ('-') in CTC output.
+     phonemes: list of predicted phoneme tokens
+     """
+     new_seq = []
+     prev = None
+     for p in phonemes:
+         if p == "-" or p == prev:
+             continue
+         new_seq.append(p)
+         prev = p
+     return new_seq
+
+
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import librosa.display
+
+ def mel_to_image(mel_path, sr=22050, hop_length=256, save_fig=True):
+     # Load mel spectrogram from .npy
+     mel_db = np.load(mel_path)
+
+     # Create figure
+     plt.figure(figsize=(14, 6))
+
+     # Plot mel spectrogram
+     librosa.display.specshow(mel_db, sr=sr, hop_length=hop_length, x_axis='time', y_axis='mel', cmap='magma')
+     plt.title("Mel Spectrogram (dB)")
+     plt.colorbar(format="%+2.0f dB")
+
+     save_path = mel_path.replace('.npy', '_mel.png')
+     plt.savefig(save_path)
+     print(f"Saved mel spectrogram image at: {save_path}")
+
+     # plt.show()
+     return save_path
+ # load reverse lexicon: phoneme_seq -> [words]
+ import nltk
+ from collections import defaultdict
+
+ nltk.download('cmudict')
+ arpabet = nltk.corpus.cmudict.dict()
+
+ # Build reverse lexicon
+ reverse_lex = defaultdict(list)
+ for word, pron_list in arpabet.items():
+     for pron in pron_list:
+         reverse_lex[tuple(pron)].append(word)
+
+ def split_on_boundaries(phoneme_stream, boundary_token="<w>"):
+     """Split on a special token representing word boundaries."""
+     words = []
+     current = []
+     for phon in phoneme_stream:
+         if phon == boundary_token:
+             if current:
+                 words.append(current)
+             current = []
+         else:
+             current.append(phon)
+     if current:
+         words.append(current)
+     return words
+
+ def p2g_fallback(phoneme_word):
+     # Placeholder for fallback pronunciation-to-spelling
+     return "".join(phoneme_word).lower()
+
+ def phonemes_to_text(phoneme_stream):
+     words = []
+     for phoneme_word in split_on_boundaries(phoneme_stream):
+         candidates = reverse_lex.get(tuple(phoneme_word), [])
+         if candidates:
+             words.append(candidates[0])
+         else:
+             words.append(p2g_fallback(phoneme_word))
+     return " ".join(words)
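
To illustrate how ctc_post_process and phonemes_to_text compose, here is a small hypothetical example, not part of the commit. It assumes '-' is the CTC blank and '<w>' marks word boundaries, as the helpers above expect; the exact word returned for a pronunciation depends on which cmudict entry was inserted into the reverse lexicon first.

    from utils import ctc_post_process, phonemes_to_text

    # hypothetical frame-level decode with repeats and blanks
    decoded = ["HH", "HH", "-", "AH0", "L", "L", "OW1", "<w>", "W", "ER1", "-", "L", "D"]

    collapsed = ctc_post_process(decoded)
    print(collapsed)   # ['HH', 'AH0', 'L', 'OW1', '<w>', 'W', 'ER1', 'L', 'D']

    # looks up each word's phoneme tuple in the reverse cmudict lexicon
    print(phonemes_to_text(collapsed))   # e.g. "hello world" (homophones may differ)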