Sarah Solito committed on
Commit a7c6058 · 1 Parent(s): 2b7ff2f

Update: code cleaning

Files changed (1)
  1. whisper_cs_dev.py +1 -51
whisper_cs_dev.py CHANGED
@@ -55,9 +55,6 @@ def load_cudnn():
 
 def get_settings():
 
-    if DEBUG_MODE:
-        print(f"Entering get_settings function...")
-
     is_cuda_available = torch.cuda.is_available()
     if is_cuda_available:
         device = "cuda"
@@ -68,7 +65,6 @@ def get_settings():
     compute_type = "default"
 
     if DEBUG_MODE: print(f"[SETTINGS] Device: {device}")
-    if DEBUG_MODE: print(f"Exited get_settings function.")
 
     return device, compute_type
 
@@ -77,7 +73,6 @@ def get_settings():
 def load_model(use_v2_fast, device, compute_type):
 
     if DEBUG_MODE:
-        print(f"Entering load_model function...")
         print(f"[MODEL LOADING] use_v2_fast: {use_v2_fast}")
 
     if use_v2_fast:
@@ -94,16 +89,12 @@ def load_model(use_v2_fast, device, compute_type):
             device=device,
             token=os.getenv("HF_TOKEN")
         )
-
-    if DEBUG_MODE: print(f"Exiting load_model function...")
-
+
     return model
 
 
 def split_input_stereo_channels(audio_path):
 
-    if DEBUG_MODE: print(f"Entering split_input_stereo_channels function...")
-
     ext = os.path.splitext(audio_path)[1].lower()
 
     if ext == ".wav":
@@ -121,10 +112,8 @@ def split_input_stereo_channels(audio_path):
         channels[0].export(RIGHT_CHANNEL_TEMP_PATH, format="wav") # Right
         channels[1].export(LEFT_CHANNEL_TEMP_PATH, format="wav") # Left
 
-    if DEBUG_MODE: print(f"Exited split_input_stereo_channels function.")
 
 def compute_type_to_audio_dtype(compute_type: str, device: str) -> np.dtype:
-    if DEBUG_MODE: print(f"Entering compute_type_to_audio_dtype function.")
 
     compute_type = compute_type.lower()
 
@@ -136,12 +125,10 @@ def compute_type_to_audio_dtype(compute_type: str, device: str) -> np.dtype:
     else:
         audio_np_dtype = np.float32
 
-    if DEBUG_MODE: print(f"Exited compute_type_to_audio_dtype function.")
     return audio_np_dtype
 
 
 def format_audio(audio_path: str, compute_type: str, device: str) -> np.ndarray:
-    if DEBUG_MODE: print(f"Entering format_audio function...")
 
     input_audio, sample_rate = torchaudio.load(audio_path)
 
@@ -158,44 +145,31 @@ def format_audio(audio_path: str, compute_type: str, device: str) -> np.ndarray:
 
     if DEBUG_MODE:
         print(f"[FORMAT AUDIO] Audio dtype for actual_compute_type: {input_audio.dtype}")
-        print(f"Exited format_audio function.")
     return input_audio
 
 
 
 def process_waveforms(device: str, compute_type: str):
 
-    if DEBUG_MODE: print(f"Entering process_waveforms function...")
-
     left_waveform = format_audio(LEFT_CHANNEL_TEMP_PATH, compute_type, device)
     right_waveform = format_audio(RIGHT_CHANNEL_TEMP_PATH, compute_type, device)
 
-    if DEBUG_MODE: print(f"Exited process_waveforms function.")
     return left_waveform, right_waveform
 
 
 def transcribe_pipeline(audio, model):
-    if DEBUG_MODE: print(f"Entering transcribe_pipeline function.")
-
     text = model(audio, batch_size=BATCH_SIZE, generate_kwargs={"task": TASK}, return_timestamps=True)["text"]
-
-    if DEBUG_MODE: print(f"Exited transcribe_pipeline function.")
-
     return text
 
 
 def transcribe_channels(left_waveform, right_waveform, model):
 
-    if DEBUG_MODE: print(f"Entering transcribe_channels function...")
-
     left_result, _ = model.transcribe(left_waveform, beam_size=5, task="transcribe")
     right_result, _ = model.transcribe(right_waveform, beam_size=5, task="transcribe")
 
     left_result = list(left_result)
     right_result = list(right_result)
 
-    if DEBUG_MODE: print(f"Exited transcribe_channels function.")
-
     return left_result, right_result
 
 
@@ -255,23 +229,17 @@ def post_merge_consecutive_segments_from_text(transcription_text: str) -> str:
 
 def get_segments(result, speaker_label):
 
-    if DEBUG_MODE: print(f"Entering get_segments function...")
-
     segments = result
     final_segments = [
         (seg.start, seg.end, speaker_label, post_process_transcription(seg.text.strip()))
         for seg in segments if seg.text
     ]
 
-    if DEBUG_MODE: print(f"EXited get_segments function.")
-
     return final_segments
 
 
 def post_process_transcripts(left_result, right_result):
 
-    if DEBUG_MODE: print(f"Entering post_process_transcripts function...")
-
     left_segs = get_segments(left_result, "Speaker 1")
     right_segs = get_segments(right_result, "Speaker 2")
 
@@ -285,29 +253,20 @@ def post_process_transcripts(left_result, right_result):
         clean_output += f"[{speaker}]: {text}\n"
     clean_output = clean_output.strip()
 
-    if DEBUG_MODE: print(f"Exited post_process_transcripts function.")
-
    return clean_output
 
 
 def cleanup_temp_files(*file_paths):
-
-    if DEBUG_MODE: print(f"Entered cleanup_temp_files function...")
 
     for path in file_paths:
         if path and os.path.exists(path):
             if DEBUG_MODE: print(f"Removing path: {path}")
             os.remove(path)
 
-    if DEBUG_MODE: print(f"Exited cleanup_temp_files function.")
-
 
 
 
 def generate(audio_path, use_v2_fast):
-    if DEBUG_MODE: print(f"Entering generate function...")
-
-    start = time.time()
 
     load_cudnn()
     device, requested_compute_type = get_settings()
@@ -333,13 +292,4 @@ def generate(audio_path, use_v2_fast):
         merged_results = transcribe_pipeline(audio, model)
         output = post_process_transcription(merged_results)
 
-    end = time.time()
-
-    audio_duration = torchaudio.info(audio_path).num_frames / torchaudio.info(audio_path).sample_rate
-    rtf = (end - start) / audio_duration
-
-    if DEBUG_MODE: print(f"[LATENCY]: {end - start}")
-    if DEBUG_MODE: print(f"[RTF]: {rtf:.2f}")
-    if DEBUG_MODE: print(f"Exited generate function.")
-
     return output
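
Note: the block deleted from generate() timed the full run and reported latency and real-time factor (RTF = processing time / audio duration). If that measurement is still wanted outside the main path, a minimal sketch of an external wrapper is shown below; measure_rtf and transcribe_fn are illustrative names introduced here, not part of the repository, while the time.time() and torchaudio.info() calls mirror the deleted lines.

import time
import torchaudio

def measure_rtf(audio_path, transcribe_fn):
    # transcribe_fn stands in for a call such as generate(audio_path, use_v2_fast)
    start = time.time()
    output = transcribe_fn(audio_path)
    end = time.time()

    info = torchaudio.info(audio_path)
    audio_duration = info.num_frames / info.sample_rate  # audio length in seconds
    rtf = (end - start) / audio_duration  # < 1.0 means faster than real time

    print(f"[LATENCY]: {end - start}")
    print(f"[RTF]: {rtf:.2f}")
    return output, rtf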
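The remaining deleted lines are the per-function "Entering/Exited" prints. If that tracing is ever needed again, a decorator keeps it out of the function bodies. This is a hypothetical sketch (debug_trace is not defined anywhere in the repository) that assumes the module-level DEBUG_MODE flag already used in whisper_cs_dev.py.

import functools

DEBUG_MODE = True  # assumption: mirrors the module's existing flag

def debug_trace(func):
    # Prints entry/exit messages equivalent to the removed DEBUG_MODE prints.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if DEBUG_MODE:
            print(f"Entering {func.__name__}...")
        result = func(*args, **kwargs)
        if DEBUG_MODE:
            print(f"Exited {func.__name__}.")
        return result
    return wrapper

Applying @debug_trace to functions such as get_settings or load_model would restore the old trace output without repeating the boilerplate in every body.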
 