Spaces: Running on Zero

Sarah Solito committed
Commit · a7c6058
1 Parent(s): 2b7ff2f

Update: code cleaning

Files changed: whisper_cs_dev.py (+1 -51)

whisper_cs_dev.py CHANGED
@@ -55,9 +55,6 @@ def load_cudnn():
 
 def get_settings():
 
-    if DEBUG_MODE:
-        print(f"Entering get_settings function...")
-
     is_cuda_available = torch.cuda.is_available()
     if is_cuda_available:
         device = "cuda"
@@ -68,7 +65,6 @@ def get_settings():
     compute_type = "default"
 
     if DEBUG_MODE: print(f"[SETTINGS] Device: {device}")
-    if DEBUG_MODE: print(f"Exited get_settings function.")
 
     return device, compute_type
 
@@ -77,7 +73,6 @@ def get_settings():
 def load_model(use_v2_fast, device, compute_type):
 
     if DEBUG_MODE:
-        print(f"Entering load_model function...")
         print(f"[MODEL LOADING] use_v2_fast: {use_v2_fast}")
 
     if use_v2_fast:
@@ -94,16 +89,12 @@ def load_model(use_v2_fast, device, compute_type):
         device=device,
         token=os.getenv("HF_TOKEN")
     )
-
-    if DEBUG_MODE: print(f"Exiting load_model function...")
-
+
     return model
 
 
 def split_input_stereo_channels(audio_path):
 
-    if DEBUG_MODE: print(f"Entering split_input_stereo_channels function...")
-
     ext = os.path.splitext(audio_path)[1].lower()
 
     if ext == ".wav":
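The context lines above close a truncated call that passes device=device and token=os.getenv("HF_TOKEN"). Purely as an illustrative sketch, this is the shape of a Hugging Face ASR pipeline built with those two arguments; the task string and model id are assumptions, not values taken from this diff.

# Illustrative sketch only -- not part of this commit.
# The model id below is a placeholder assumption.
import os
from transformers import pipeline

def load_pipeline_sketch(device, model_id="openai/whisper-large-v3"):
    return pipeline(
        "automatic-speech-recognition",
        model=model_id,
        device=device,
        token=os.getenv("HF_TOKEN"),  # needed for gated or private models
    )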
@@ -121,10 +112,8 @@ def split_input_stereo_channels(audio_path):
     channels[0].export(RIGHT_CHANNEL_TEMP_PATH, format="wav") # Right
     channels[1].export(LEFT_CHANNEL_TEMP_PATH, format="wav") # Left
 
-    if DEBUG_MODE: print(f"Exited split_input_stereo_channels function.")
 
 def compute_type_to_audio_dtype(compute_type: str, device: str) -> np.dtype:
-    if DEBUG_MODE: print(f"Entering compute_type_to_audio_dtype function.")
 
     compute_type = compute_type.lower()
 
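The channels[...].export(..., format="wav") calls in the hunk above match pydub's AudioSegment API. A minimal sketch of how such a stereo split could be produced upstream, for illustration only; the helper name and paths are hypothetical, and the right/left ordering simply mirrors the comments in the diff.

# Illustrative sketch only -- not part of this commit.
from pydub import AudioSegment

def split_stereo_sketch(audio_path, right_path="right_tmp.wav", left_path="left_tmp.wav"):
    stereo = AudioSegment.from_file(audio_path)
    channels = stereo.split_to_mono()             # one mono segment per channel
    channels[0].export(right_path, format="wav")  # mapped to "Right" in the diff
    channels[1].export(left_path, format="wav")   # mapped to "Left" in the diff
    return left_path, right_path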
@@ -136,12 +125,10 @@ def compute_type_to_audio_dtype(compute_type: str, device: str) -> np.dtype:
     else:
         audio_np_dtype = np.float32
 
-    if DEBUG_MODE: print(f"Exited compute_type_to_audio_dtype function.")
     return audio_np_dtype
 
 
 def format_audio(audio_path: str, compute_type: str, device: str) -> np.ndarray:
-    if DEBUG_MODE: print(f"Entering format_audio function...")
 
     input_audio, sample_rate = torchaudio.load(audio_path)
 
@@ -158,44 +145,31 @@ def format_audio(audio_path: str, compute_type: str, device: str) -> np.ndarray:
 
     if DEBUG_MODE:
         print(f"[FORMAT AUDIO] Audio dtype for actual_compute_type: {input_audio.dtype}")
-        print(f"Exited format_audio function.")
     return input_audio
 
 
 
 def process_waveforms(device: str, compute_type: str):
 
-    if DEBUG_MODE: print(f"Entering process_waveforms function...")
-
     left_waveform = format_audio(LEFT_CHANNEL_TEMP_PATH, compute_type, device)
     right_waveform = format_audio(RIGHT_CHANNEL_TEMP_PATH, compute_type, device)
 
-    if DEBUG_MODE: print(f"Exited process_waveforms function.")
     return left_waveform, right_waveform
 
 
 def transcribe_pipeline(audio, model):
-    if DEBUG_MODE: print(f"Entering transcribe_pipeline function.")
-
     text = model(audio, batch_size=BATCH_SIZE, generate_kwargs={"task": TASK}, return_timestamps=True)["text"]
-
-    if DEBUG_MODE: print(f"Exited transcribe_pipeline function.")
-
     return text
 
 
 def transcribe_channels(left_waveform, right_waveform, model):
 
-    if DEBUG_MODE: print(f"Entering transcribe_channels function...")
-
     left_result, _ = model.transcribe(left_waveform, beam_size=5, task="transcribe")
     right_result, _ = model.transcribe(right_waveform, beam_size=5, task="transcribe")
 
     left_result = list(left_result)
     right_result = list(right_result)
 
-    if DEBUG_MODE: print(f"Exited transcribe_channels function.")
-
     return left_result, right_result
 
 
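The model.transcribe(..., beam_size=5, task="transcribe") calls and the list(...) materialization above follow the faster-whisper API. A self-contained sketch of that per-channel pattern is shown below; the model size, device, and compute type are assumptions rather than values from this Space.

# Illustrative sketch only -- not part of this commit.
from faster_whisper import WhisperModel

def transcribe_channels_sketch(left_waveform, right_waveform):
    model = WhisperModel("large-v3", device="cuda", compute_type="float16")

    # transcribe() returns a lazy segment generator plus an info object
    left_segments, _ = model.transcribe(left_waveform, beam_size=5, task="transcribe")
    right_segments, _ = model.transcribe(right_waveform, beam_size=5, task="transcribe")

    # Materialize the generators so both channels are fully decoded
    return list(left_segments), list(right_segments)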
@@ -255,23 +229,17 @@ def post_merge_consecutive_segments_from_text(transcription_text: str) -> str:
 
 def get_segments(result, speaker_label):
 
-    if DEBUG_MODE: print(f"Entering get_segments function...")
-
     segments = result
     final_segments = [
         (seg.start, seg.end, speaker_label, post_process_transcription(seg.text.strip()))
         for seg in segments if seg.text
     ]
 
-    if DEBUG_MODE: print(f"EXited get_segments function.")
-
     return final_segments
 
 
 def post_process_transcripts(left_result, right_result):
 
-    if DEBUG_MODE: print(f"Entering post_process_transcripts function...")
-
     left_segs = get_segments(left_result, "Speaker 1")
     right_segs = get_segments(right_result, "Speaker 2")
 
@@ -285,29 +253,20 @@ def post_process_transcripts(left_result, right_result):
         clean_output += f"[{speaker}]: {text}\n"
     clean_output = clean_output.strip()
 
-    if DEBUG_MODE: print(f"Exited post_process_transcripts function.")
-
    return clean_output
 
 
 def cleanup_temp_files(*file_paths):
-
-    if DEBUG_MODE: print(f"Entered cleanup_temp_files function...")
 
     for path in file_paths:
         if path and os.path.exists(path):
             if DEBUG_MODE: print(f"Removing path: {path}")
             os.remove(path)
 
-    if DEBUG_MODE: print(f"Exited cleanup_temp_files function.")
-
 
 
 
 def generate(audio_path, use_v2_fast):
-    if DEBUG_MODE: print(f"Entering generate function...")
-
-    start = time.time()
 
     load_cudnn()
     device, requested_compute_type = get_settings()
@@ -333,13 +292,4 @@ def generate(audio_path, use_v2_fast):
     merged_results = transcribe_pipeline(audio, model)
     output = post_process_transcription(merged_results)
 
-    end = time.time()
-
-    audio_duration = torchaudio.info(audio_path).num_frames / torchaudio.info(audio_path).sample_rate
-    rtf = (end - start) / audio_duration
-
-    if DEBUG_MODE: print(f"[LATENCY]: {end - start}")
-    if DEBUG_MODE: print(f"[RTF]: {rtf:.2f}")
-    if DEBUG_MODE: print(f"Exited generate function.")
-
     return output
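The final hunk drops the inline latency and real-time-factor instrumentation from generate. For reference, a minimal sketch of the same measurement factored into a standalone helper, reusing the torchaudio-based duration formula from the removed lines; the helper itself is illustrative and not part of this commit.

# Illustrative sketch only -- not part of this commit.
import time
import torchaudio

def measure_rtf(audio_path, fn):
    """Run fn(audio_path) and report latency and real-time factor (RTF).

    RTF = processing time / audio duration, so values below 1.0 mean
    faster-than-real-time transcription.
    """
    start = time.time()
    output = fn(audio_path)
    elapsed = time.time() - start

    info = torchaudio.info(audio_path)
    audio_duration = info.num_frames / info.sample_rate
    rtf = elapsed / audio_duration

    print(f"[LATENCY]: {elapsed:.2f}s")
    print(f"[RTF]: {rtf:.2f}")
    return output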