Sarah Solito committed
Commit 2b7ff2f · 1 Parent(s): 524de77

Release v1.0 including v2_fast improvements and dynamic compute_type selection

Files changed (3)
  1. app.py +2 -1
  2. settings.py +4 -2
  3. whisper_cs_dev.py +100 -113
app.py CHANGED
@@ -22,7 +22,8 @@ with gr.Blocks() as demo:
     gr.Markdown(description_string)
     with gr.Row():
         with gr.Column(scale=1):
-            model_version = gr.Dropdown(label="Model Version", choices=["v2_fast", "v2.0"], value="v2_fast")
+            model_version = gr.Dropdown(label="Model Version", choices=["v2_fast", "v1.0"], value="v2_fast")
+
             input = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Audio")

         with gr.Column(scale=1):
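
The dropdown's second choice is renamed from "v2.0" to "v1.0", matching the new MODEL_PATH_V1 setting below. This hunk does not show how the selection reaches the backend, so the following glue code is only a sketch of the assumed wiring: the dropdown string is converted into the boolean use_v2_fast flag that whisper_cs_dev.generate() expects.

# Hypothetical wiring (not part of this commit): map the dropdown value to
# the use_v2_fast flag consumed by generate().
from whisper_cs_dev import generate

def run(audio_path, model_version):
    use_v2_fast = (model_version == "v2_fast")
    return generate(audio_path, use_v2_fast)
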
settings.py CHANGED
@@ -1,6 +1,8 @@
 DEBUG_MODE = True
-MODEL_PATH_V2 = "langtech-veu/whisper-timestamped-cs"
+MODEL_PATH_V1 = "projecte-aina/whisper-large-v3-tiny-caesar"
 MODEL_PATH_V2_FAST = "langtech-veu/faster-whisper-timestamped-cs"
 LEFT_CHANNEL_TEMP_PATH = "temp_mono_speaker2.wav"
 RIGHT_CHANNEL_TEMP_PATH = "temp_mono_speaker1.wav"
-RESAMPLING_FREQ = 16000
+RESAMPLING_FREQ = 16000
+BATCH_SIZE = 1
+TASK = "transcribe"
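
MODEL_PATH_V1 points to the Hugging Face checkpoint used by the new pipeline path, while BATCH_SIZE and TASK feed the pipeline call in whisper_cs_dev.py. The "dynamic compute_type selection" named in the commit title comes from passing compute_type="default" to faster-whisper: CTranslate2 then keeps the quantization the converted model was saved with, falling back to a type the current device supports, and the resolved value can be read back from the loaded model. A minimal sketch under those assumptions:

# Sketch (illustration only): let faster-whisper/CTranslate2 resolve the compute type,
# then inspect what was actually selected, as generate() does via model.model.compute_type.
from faster_whisper import WhisperModel

model = WhisperModel(
    "langtech-veu/faster-whisper-timestamped-cs",
    device="cuda",
    compute_type="default",
)
print(model.model.compute_type)  # prints the effective compute type, e.g. "float16"
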
whisper_cs_dev.py CHANGED
@@ -1,5 +1,5 @@
 from faster_whisper import WhisperModel
-import whisper_timestamped as whisper_ts
+from transformers import pipeline
 from pydub import AudioSegment
 import os
 import torchaudio
@@ -10,8 +10,9 @@ import sys
 from pathlib import Path
 import glob
 import ctypes
+import numpy as np

-from settings import DEBUG_MODE, MODEL_PATH_V2_FAST, MODEL_PATH_V2, LEFT_CHANNEL_TEMP_PATH, RIGHT_CHANNEL_TEMP_PATH, RESAMPLING_FREQ
+from settings import DEBUG_MODE, MODEL_PATH_V2_FAST, MODEL_PATH_V1, LEFT_CHANNEL_TEMP_PATH, RIGHT_CHANNEL_TEMP_PATH, RESAMPLING_FREQ, BATCH_SIZE, TASK

 def load_cudnn():

@@ -54,44 +55,45 @@ def load_cudnn():

 def get_settings():

-    if DEBUG_MODE: print(f"Entering get_settings function...")
+    if DEBUG_MODE:
+        print(f"Entering get_settings function...")

     is_cuda_available = torch.cuda.is_available()
     if is_cuda_available:
         device = "cuda"
-        compute_type = "float16"
+        compute_type = "default"
+
     else:
         device = "cpu"
-        compute_type = "int8"
-    if DEBUG_MODE: print(f"is_cuda_available: {is_cuda_available}")
-    if DEBUG_MODE: print(f"device: {device}")
-    if DEBUG_MODE: print(f"compute_type: {compute_type}")
+        compute_type = "default"

+    if DEBUG_MODE: print(f"[SETTINGS] Device: {device}")
     if DEBUG_MODE: print(f"Exited get_settings function.")

     return device, compute_type


+
 def load_model(use_v2_fast, device, compute_type):

-    if DEBUG_MODE: print(f"Entering load_model function...")
-
-    if DEBUG_MODE: print(f"use_v2_fast: {use_v2_fast}")
+    if DEBUG_MODE:
+        print(f"Entering load_model function...")
+        print(f"[MODEL LOADING] use_v2_fast: {use_v2_fast}")

     if use_v2_fast:
-        if DEBUG_MODE: print(f"Loading {MODEL_PATH_V2_FAST} using {device} with {compute_type}...")
         model = WhisperModel(
             MODEL_PATH_V2_FAST,
             device = device,
             compute_type = compute_type,
         )
     else:
-        if DEBUG_MODE: print(f"Loading {MODEL_PATH_V2} using {device} with {compute_type}...")
-        # TODO add compute_type to load model
-        model = whisper_ts.load_model(
-            MODEL_PATH_V2,
-            device = device,
-        )
+        model = pipeline(
+            task="automatic-speech-recognition",
+            model=MODEL_PATH_V1,
+            chunk_length_s=30,
+            device=device,
+            token=os.getenv("HF_TOKEN")
+        )

     if DEBUG_MODE: print(f"Exiting load_model function...")

@@ -109,21 +111,36 @@ def split_input_stereo_channels(audio_path):
     elif ext == ".mp3":
         audio = AudioSegment.from_file(audio_path, format="mp3")
     else:
-        raise ValueError(f"Unsupported file format for: {audio_path}")
+        raise ValueError(f"[FORMAT AUDIO] Unsupported file format for: {audio_path}")

     channels = audio.split_to_mono()

     if len(channels) != 2:
-        raise ValueError(f"Audio {audio_path} has {len(channels)} channels (instead of 2).")
+        raise ValueError(f"[FORMAT AUDIO] Audio {audio_path} has {len(channels)} channels (instead of 2).")

     channels[0].export(RIGHT_CHANNEL_TEMP_PATH, format="wav") # Right
     channels[1].export(LEFT_CHANNEL_TEMP_PATH, format="wav") # Left

     if DEBUG_MODE: print(f"Exited split_input_stereo_channels function.")

+def compute_type_to_audio_dtype(compute_type: str, device: str) -> np.dtype:
+    if DEBUG_MODE: print(f"Entering compute_type_to_audio_dtype function.")
+
+    compute_type = compute_type.lower()
+
+    if device.startswith("cuda"):
+        if "float16" in compute_type or "int8" in compute_type:
+            audio_np_dtype = np.float16
+        else:
+            audio_np_dtype = np.float32
+    else:
+        audio_np_dtype = np.float32
+
+    if DEBUG_MODE: print(f"Exited compute_type_to_audio_dtype function.")
+    return audio_np_dtype


-def format_audio(audio_path):
+def format_audio(audio_path: str, compute_type: str, device: str) -> np.ndarray:
     if DEBUG_MODE: print(f"Entering format_audio function...")

     input_audio, sample_rate = torchaudio.load(audio_path)
@@ -135,81 +152,47 @@ def format_audio(audio_path):
     input_audio = resampler(input_audio)
     input_audio = input_audio.squeeze()

-    if DEBUG_MODE: print(f"Exited format_audio function.")
-
-    return input_audio, RESAMPLING_FREQ
+    np_dtype = compute_type_to_audio_dtype(compute_type, device)
+
+    input_audio = input_audio.numpy().astype(np_dtype)
+
+    if DEBUG_MODE:
+        print(f"[FORMAT AUDIO] Audio dtype for actual_compute_type: {input_audio.dtype}")
+        print(f"Exited format_audio function.")
+    return input_audio


-def process_waveforms():
+def process_waveforms(device: str, compute_type: str):

     if DEBUG_MODE: print(f"Entering process_waveforms function...")

-    left_waveform, _ = format_audio(LEFT_CHANNEL_TEMP_PATH)
-    right_waveform, _ = format_audio(RIGHT_CHANNEL_TEMP_PATH)
-
-    # TODO should this be equal to compute_type?
-    left_waveform = left_waveform.numpy().astype("float16")
-    right_waveform = right_waveform.numpy().astype("float16")
+    left_waveform = format_audio(LEFT_CHANNEL_TEMP_PATH, compute_type, device)
+    right_waveform = format_audio(RIGHT_CHANNEL_TEMP_PATH, compute_type, device)

     if DEBUG_MODE: print(f"Exited process_waveforms function.")
     return left_waveform, right_waveform


-def transcribe_audio_no_fast_model(model, audio_path):
-
-    if DEBUG_MODE: print(f"Entering transcribe_audio_no_fast_model function...")
-
-    result = whisper_ts.transcribe(
-        model,
-        audio_path,
-        beam_size=5,
-        best_of=5,
-        temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
-        vad=False,
-        detect_disfluencies=True,
-    )
-
-    words = []
-    for segment in result.get('segments', []):
-        for word in segment.get('words', []):
-            word_text = word.get('word', '').strip()
-            if word_text.startswith(' '):
-                word_text = word_text[1:]
-
-            words.append({
-                'word': word_text,
-                'start': word.get('start', 0),
-                'end': word.get('end', 0),
-                'confidence': word.get('confidence', 0)
-            })
-
-    return {
-        'audio_path': audio_path,
-        'text': result['text'].strip(),
-        'segments': result.get('segments', []),
-        'words': words,
-        'duration': result.get('duration', 0),
-        'success': True
-    }
-
-    if DEBUG_MODE: print(f"Exited transcribe_audio_no_fast_model function.")
+def transcribe_pipeline(audio, model):
+    if DEBUG_MODE: print(f"Entering transcribe_pipeline function.")
+
+    text = model(audio, batch_size=BATCH_SIZE, generate_kwargs={"task": TASK}, return_timestamps=True)["text"]
+
+    if DEBUG_MODE: print(f"Exited transcribe_pipeline function.")
+
+    return text


-def transcribe_channels(left_waveform, right_waveform, model, use_v2_fast):
+def transcribe_channels(left_waveform, right_waveform, model):

     if DEBUG_MODE: print(f"Entering transcribe_channels function...")

-    if DEBUG_MODE: print(f"Preparing to transcribe...")
-
-    if use_v2_fast:
-        left_result, _ = model.transcribe(left_waveform, beam_size=5, task="transcribe")
-        right_result, _ = model.transcribe(right_waveform, beam_size=5, task="transcribe")
-        left_result = list(left_result)
-        right_result = list(right_result)
-    else:
-        left_result = transcribe_audio_no_fast_model(model, left_waveform)
-        right_result = transcribe_audio_no_fast_model(model, right_waveform)
+    left_result, _ = model.transcribe(left_waveform, beam_size=5, task="transcribe")
+    right_result, _ = model.transcribe(right_waveform, beam_size=5, task="transcribe")
+
+    left_result = list(left_result)
+    right_result = list(right_result)

     if DEBUG_MODE: print(f"Exited transcribe_channels function.")

@@ -270,37 +253,27 @@ def post_merge_consecutive_segments_from_text(transcription_text: str) -> str:
     return merged_transcription.strip()


-def get_segments(result, speaker_label, use_v2_fast):
+def get_segments(result, speaker_label):

     if DEBUG_MODE: print(f"Entering get_segments function...")

-    if use_v2_fast:
-        segments = result
-        final_segments = [
-            (seg.start, seg.end, speaker_label, post_process_transcription(seg.text.strip()))
-            for seg in segments if seg.text
-        ]
-    else:
-        segments = result.get("segments", [])
-        if not segments:
-            final_segments = []
-        final_segments = [
-            (seg.get("start", 0.0), seg.get("end", 0.0), speaker_label,
-             post_process_transcription(seg.get("text", "").strip()))
-            for seg in segments if seg.get("text")
-        ]
+    segments = result
+    final_segments = [
+        (seg.start, seg.end, speaker_label, post_process_transcription(seg.text.strip()))
+        for seg in segments if seg.text
+    ]

     if DEBUG_MODE: print(f"EXited get_segments function.")

     return final_segments


-def post_process_transcripts(left_result, right_result, use_v2_fast):
+def post_process_transcripts(left_result, right_result):

     if DEBUG_MODE: print(f"Entering post_process_transcripts function...")

-    left_segs = get_segments(left_result, "Speaker 1", use_v2_fast)
-    right_segs = get_segments(right_result, "Speaker 2", use_v2_fast)
+    left_segs = get_segments(left_result, "Speaker 1")
+    right_segs = get_segments(right_result, "Speaker 2")

     merged_transcript = sorted(
         left_segs + right_segs,
@@ -320,8 +293,6 @@ def post_process_transcripts(left_result, right_result, use_v2_fast):
 def cleanup_temp_files(*file_paths):

     if DEBUG_MODE: print(f"Entered cleanup_temp_files function...")
-
-    if DEBUG_MODE: print(f"File paths to remove: {file_paths}")

     for path in file_paths:
         if path and os.path.exists(path):
@@ -331,28 +302,44 @@ def cleanup_temp_files(*file_paths):
     if DEBUG_MODE: print(f"Exited cleanup_temp_files function.")


+
 def generate(audio_path, use_v2_fast):
-
     if DEBUG_MODE: print(f"Entering generate function...")

     start = time.time()

     load_cudnn()
-    device, compute_type = get_settings()
-    model = load_model(use_v2_fast, device, compute_type)
-    split_input_stereo_channels(audio_path)
-    left_waveform, right_waveform = process_waveforms()
-    left_result, right_result = transcribe_channels(left_waveform, right_waveform, model, use_v2_fast)
-    output = post_process_transcripts(left_result, right_result, use_v2_fast)
-    cleanup_temp_files(LEFT_CHANNEL_TEMP_PATH, RIGHT_CHANNEL_TEMP_PATH)
-
-    end = time.time()
-    elapsed_secs = end - start
-
-    if DEBUG_MODE: print(f"elapsed_secs: {elapsed_secs}")
-
-    if DEBUG_MODE: print(f"Exited generate function.")
-
-    return output
+    device, requested_compute_type = get_settings()
+    model = load_model(use_v2_fast, device, requested_compute_type)
+
+    if use_v2_fast:
+        actual_compute_type = model.model.compute_type
+    else:
+        actual_compute_type = "float32"  # HF pipeline safe default
+
+    if DEBUG_MODE:
+        print(f"[SETTINGS] Requested compute_type: {requested_compute_type}")
+        print(f"[SETTINGS] Actual compute_type: {actual_compute_type}")
+
+    if use_v2_fast:
+        split_input_stereo_channels(audio_path)
+        left_waveform, right_waveform = process_waveforms(device, actual_compute_type)
+        left_result, right_result = transcribe_channels(left_waveform, right_waveform, model)
+        output = post_process_transcripts(left_result, right_result)
+        cleanup_temp_files(LEFT_CHANNEL_TEMP_PATH, RIGHT_CHANNEL_TEMP_PATH)
+    else:
+        audio = format_audio(audio_path, actual_compute_type, device)
+        merged_results = transcribe_pipeline(audio, model)
+        output = post_process_transcription(merged_results)
+
+    end = time.time()
+
+    audio_duration = torchaudio.info(audio_path).num_frames / torchaudio.info(audio_path).sample_rate
+    rtf = (end - start) / audio_duration
+
+    if DEBUG_MODE: print(f"[LATENCY]: {end - start}")
+    if DEBUG_MODE: print(f"[RTF]: {rtf:.2f}")
+    if DEBUG_MODE: print(f"Exited generate function.")
+
+    return output
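
With these changes, generate() covers both model versions: the stereo, per-channel faster-whisper path when use_v2_fast is True, and the mono Hugging Face pipeline path otherwise. A short usage sketch (the audio file name below is a placeholder):

# Usage sketch: call the updated entry point for both model versions.
from whisper_cs_dev import generate

fast_output = generate("example_stereo_call.wav", use_v2_fast=True)   # faster-whisper, Speaker 1/2 segments
v1_output = generate("example_stereo_call.wav", use_v2_fast=False)    # HF pipeline, single merged transcript
print(fast_output)
print(v1_output)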