hash-map commited on
Commit
2015ff1
·
verified ·
1 Parent(s): 4ffa9fc

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +153 -0
app.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import json
4
+ import numpy as np
5
+ import os
6
+ from datetime import datetime
7
+ from model import Image2Phoneme
8
+ from utils import ctc_post_process, audio_to_mel, mel_to_image, text_to_phonemes
9
+ import soundfile as sf
10
+ import shutil
11
+ import pronouncing
12
+ import time
13
+
14
+ # Configuration
15
+ DEVICE = torch.device("cpu")
16
+ PHMAP = "phoneme_to_id.json"
17
+ AUDIO_DIR = "audio_inputs"
18
+
19
+ # Ensure audio directory exists
20
+ os.makedirs(AUDIO_DIR, exist_ok=True)
21
+
22
+ # Load phoneme vocabulary
23
+ try:
24
+ vocab = json.load(open(PHMAP, "r"))
25
+ id_to_ph = {v: k for k, v in vocab.items()}
26
+ except FileNotFoundError:
27
+ raise FileNotFoundError(f"Phoneme mapping file not found at {PHMAP}")
28
+
29
+ # Build model
30
+ vocab_size = max(vocab.values()) + 1
31
+ model = Image2Phoneme(vocab_size=vocab_size).to(DEVICE)
32
+ try:
33
+ ckpt = torch.load("last_checkpoint.pt", map_location=DEVICE, weights_only=True)
34
+ model.load_state_dict(ckpt["model_state_dict"])
35
+ model.eval()
36
+ except FileNotFoundError:
37
+ raise FileNotFoundError(f"Checkpoint file not found at last_checkpoint.pt")
38
+
39
+ def process_audio(audio_input):
40
+ """Process audio to predict phonemes and display mel spectrogram."""
41
+ try:
42
+ print(f"Received audio_input before processing: {audio_input}")
43
+ # Generate unique filename based on timestamp
44
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
45
+ audio_path = os.path.join(AUDIO_DIR, f"input_{timestamp}.wav")
46
+
47
+ # Handle audio input
48
+ if audio_input is None:
49
+ print("Audio input is None after stopping recording")
50
+ return {"error": "No audio input provided"}, None, None, None
51
+
52
+ if isinstance(audio_input, str):
53
+ # File upload: Copy the uploaded file to audio_inputs/
54
+ print(f"Processing uploaded file: {audio_input}")
55
+ if not os.path.exists(audio_input):
56
+ return {"error": f"Uploaded file not found: {audio_input}"}, None, None, None
57
+ if audio_input.endswith(".mp3"):
58
+ print("Converting .mp3 to .wav")
59
+ from pydub import AudioSegment
60
+ audio = AudioSegment.from_mp3(audio_input)
61
+ audio_path = audio_path.replace(".wav", "_converted.wav")
62
+ audio.export(audio_path, format="wav")
63
+ print(f"Converted file saved to: {audio_path}")
64
+ else:
65
+ shutil.copy(audio_input, audio_path)
66
+ print(f"Copied file to: {audio_path}")
67
+ else:
68
+ # Microphone input: (sample_rate, audio_data)
69
+ print("Processing microphone input")
70
+ sample_rate, audio_data = audio_input
71
+ print(f"Sample rate: {sample_rate}, Audio data shape: {audio_data.shape if hasattr(audio_data, 'shape') else 'None'}")
72
+ if audio_data is None or len(audio_data) == 0:
73
+ print("Microphone audio data is empty or invalid")
74
+ return {"error": "Microphone input data is empty or invalid"}, None, None, None
75
+ # Add a small delay to ensure audio data is fully captured
76
+ time.sleep(1)
77
+ sf.write(audio_path, audio_data, sample_rate)
78
+ print(f"Saved microphone audio to: {audio_path}")
79
+ # Verify the file exists
80
+ if not os.path.exists(audio_path):
81
+ print(f"Failed to save audio file at: {audio_path}")
82
+ return {"error": "Failed to save recorded audio file"}, None, None, None
83
+
84
+ # Process audio to mel spectrogram
85
+ mel_path = audio_to_mel(audio_path)
86
+ print(f"Generated mel spectrogram: {mel_path}")
87
+ if not os.path.exists(mel_path):
88
+ return {"error": f"Mel spectrogram file not found: {mel_path}"}, None, None, None
89
+
90
+ mel_image_path = mel_to_image(mel_path)
91
+ print(f"Generated mel spectrogram image: {mel_image_path}")
92
+ if not os.path.exists(mel_image_path):
93
+ return {"error": f"Mel spectrogram image not found: {mel_image_path}"}, None, None, None
94
+
95
+ # Load mel spectrogram
96
+ mel = np.load(mel_path) # shape (n_mels, T)
97
+ print(f"Loaded mel spectrogram shape: {mel.shape}")
98
+ mel_tensor = torch.tensor(mel).unsqueeze(0).to(DEVICE) # add batch dim
99
+ mel_lens = torch.tensor([mel.shape[1]]).to(DEVICE)
100
+
101
+ # Predict phonemes
102
+ with torch.no_grad():
103
+ ph_pred = model(mel_tensor) # shape (B, seq_len, vocab_size)
104
+ ph_ids = ph_pred.argmax(-1)[0].cpu().numpy() # pick first batch
105
+ print(f"Predicted phoneme IDs: {ph_ids}")
106
+
107
+ # Convert IDs to phonemes
108
+ ph_seq = [id_to_ph[i] for i in ph_ids if i > 0]
109
+ print(f"Raw phonemes: {ph_seq}")
110
+
111
+ # Post-process phonemes
112
+ post_processed = ctc_post_process(ph_seq)
113
+ print(f"Post-processed phonemes: {post_processed}")
114
+
115
+ # Return results
116
+ return {
117
+ "audio_path": audio_path,
118
+ "phonemes": " ".join(ph_seq),
119
+ "post_processed_phonemes": " ".join(post_processed)
120
+ }, mel_image_path, " ".join(ph_seq), " ".join(post_processed)
121
+ except Exception as e:
122
+ print(f"Error in process_audio: {str(e)}")
123
+ return {"error": f"Processing failed: {str(e)}"}, None, None, None
124
+
125
+ # Gradio interface
126
+ with gr.Blocks() as iface:
127
+ gr.Markdown("# Speech to Phonemes Converter")
128
+ gr.Markdown("Record or upload audio to predict phonemes and display mel spectrogram. Paste input text if available.")
129
+
130
+ audio_input = gr.Audio(sources=[ "upload"], type="filepath", label="Upload Audio (.wav or .mp3)", interactive=True)
131
+ text_input = gr.Textbox(label="Enter Text", placeholder="Type a sentence to convert to phonemes")
132
+ process_button = gr.Button("Process")
133
+
134
+ audio_output = gr.JSON(label="Audio Processing Results (Audio Path, Phonemes, Post-Processed Phonemes)")
135
+ mel_image = gr.Image(label="Mel Spectrogram", type="filepath")
136
+ raw_phonemes = gr.Textbox(label="Raw Phonemes")
137
+ post_processed_phonemes = gr.Textbox(label="Post-Processed Phonemes")
138
+ text_output = gr.JSON(label="Text-to-Phoneme Results")
139
+
140
+ def process(audio_input, text_input):
141
+ print(f"Processing inputs - Audio: {audio_input}, Text: {text_input}")
142
+ audio_result, mel_image_path, raw_ph, post_ph = process_audio(audio_input) if audio_input else ({}, None, None, None)
143
+ text_result = text_to_phonemes(text_input) if text_input else {}
144
+ return audio_result, mel_image_path, raw_ph, post_ph, text_result
145
+
146
+ process_button.click(
147
+ fn=process,
148
+ inputs=[audio_input, text_input],
149
+ outputs=[audio_output, mel_image, raw_phonemes, post_processed_phonemes, text_output]
150
+ )
151
+
152
+ if __name__ == "__main__":
153
+ iface.launch(debug=True)