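"""Audio and transcript utilities for a two-channel (stereo) call-transcription
Space: device selection, stereo channel splitting, resampling to RESAMPLING_FREQ,
and post-processing of per-channel ASR segments into a speaker-labelled transcript."""
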
import os
import torch
import torchaudio
import numpy as np
import re
from pydub import AudioSegment
from settings import DEBUG_MODE, LEFT_CHANNEL_TEMP_PATH, RIGHT_CHANNEL_TEMP_PATH, RESAMPLING_FREQ
import soundfile as sf

# ------------------ DEBUG UTILITIES ------------------
def debug_print(*args, **kwargs):
    if DEBUG_MODE:
        print(*args, **kwargs)

# ------------------ Device Settings ------------------
def get_settings():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    compute_type = "default"
    debug_print(f"[SETTINGS] Device: {device}")
    return device, compute_type

# ------------------ Audio Utilities ------------------
def split_input_stereo_channels(audio_path):
    """Split a stereo .wav/.mp3 file into two mono temp files (left/right channel)."""
    ext = os.path.splitext(audio_path)[1].lower()
    if ext == ".wav":
        audio = AudioSegment.from_wav(audio_path)
    elif ext == ".mp3":
        audio = AudioSegment.from_file(audio_path, format="mp3")
    else:
        raise ValueError(f"[FORMAT AUDIO] Unsupported file format for: {audio_path}")
    channels = audio.split_to_mono()
    if len(channels) != 2:
        raise ValueError(f"[FORMAT AUDIO] Audio {audio_path} has {len(channels)} channels (instead of 2).")
    channels[0].export(LEFT_CHANNEL_TEMP_PATH, format="wav")
    channels[1].export(RIGHT_CHANNEL_TEMP_PATH, format="wav")

def compute_type_to_audio_dtype(compute_type: str, device: str) -> np.dtype:
    """Pick the numpy dtype for the audio array based on device and compute type."""
    compute_type = compute_type.lower()
    if device.startswith("cuda"):
        if "float16" in compute_type or "int8" in compute_type:
            audio_np_dtype = np.float16
        else:
            audio_np_dtype = np.float32
    else:
        audio_np_dtype = np.float32
    return audio_np_dtype

def format_audio(audio_path: str, compute_type: str, device: str) -> np.ndarray:
    """Load an audio file, downmix to mono, resample to RESAMPLING_FREQ and cast for inference."""
    input_audio, sample_rate = torchaudio.load(audio_path)
    if input_audio.shape[0] == 2:  # stereo -> mono
        input_audio = torch.mean(input_audio, dim=0, keepdim=True)
    resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=RESAMPLING_FREQ)
    input_audio = resampler(input_audio)
    input_audio = input_audio.squeeze()
    np_dtype = compute_type_to_audio_dtype(compute_type, device)
    input_audio = input_audio.numpy().astype(np_dtype)
    debug_print(f"[FORMAT AUDIO] Audio dtype for compute type '{compute_type}': {input_audio.dtype}")
    return input_audio

def process_waveforms(device: str, compute_type: str):
    left_waveform = format_audio(LEFT_CHANNEL_TEMP_PATH, compute_type, device)
    right_waveform = format_audio(RIGHT_CHANNEL_TEMP_PATH, compute_type, device)
    return left_waveform, right_waveform

# ------------------ Post-processing ------------------
def get_segments(result, speaker_label):
    final_segments = [
        (seg.start, seg.end, speaker_label, post_process_transcription(seg.text.strip()))
        for seg in result if seg.text
    ]
    return final_segments

def post_process_transcripts(left_result, right_result, civil_channel):
    """Merge per-channel ASR segments into a single speaker-labelled transcript."""
    if civil_channel == "Left":
        civil_segs = get_segments(left_result, "Civil")
        operador_segs = get_segments(right_result, "Operador")
    else:
        civil_segs = get_segments(right_result, "Civil")
        operador_segs = get_segments(left_result, "Operador")
    merged_transcript = sorted(
        operador_segs + civil_segs,
        key=lambda x: float(x[0]) if x[0] is not None else float("inf")
    )
    clean_output_asr = ""
    clean_output_meteo = ""
    for start, end, speaker, text in merged_transcript:
        clean_output_asr += f"[{speaker}]: {text}\n"
        clean_output_meteo += f"{text} "  # space-separate segments in the plain-text output
    clean_output_asr = clean_output_asr.strip()
    clean_output_meteo = clean_output_meteo.strip()
    return clean_output_asr, clean_output_meteo

def post_process_transcription(transcription, max_repeats=2):
    """Collapse in-token character repetitions and cap repeated tokens at max_repeats."""
    tokens = re.findall(r'\b\w+\'?\w*\b[.,!?]?', transcription)
    cleaned_tokens = []
    repetition_count = 0
    previous_token = None
    for token in tokens:
        # Collapse a short substring repeated 3+ times inside a token (e.g. "jajaja" -> "ja").
        reduced_token = re.sub(r"(\w{1,3})(\1{2,})", r"\1", token)
        if reduced_token == previous_token:
            repetition_count += 1
            if repetition_count <= max_repeats:
                cleaned_tokens.append(reduced_token)
        else:
            repetition_count = 1
            cleaned_tokens.append(reduced_token)
        previous_token = reduced_token
    cleaned_transcription = " ".join(cleaned_tokens)
    cleaned_transcription = re.sub(r'\s+', ' ', cleaned_transcription).strip()
    return cleaned_transcription

# TODO: not used right now, decide whether to keep it
def post_merge_consecutive_segments_from_text(transcription_text: str) -> str:
    """Merge consecutive segments that share the same [SPEAKER_XX] tag."""
    segments = re.split(r'(\[SPEAKER_\d{2}\])', transcription_text)
    merged_transcription = ''
    current_speaker = None
    current_segment = []
    for i in range(1, len(segments) - 1, 2):
        speaker_tag = segments[i]
        text = segments[i + 1].strip()
        speaker = re.search(r'\d{2}', speaker_tag).group()
        if speaker == current_speaker:
            current_segment.append(text)
        else:
            if current_speaker is not None:
                merged_transcription += f'[SPEAKER_{current_speaker}] {" ".join(current_segment)}\n'
            current_speaker = speaker
            current_segment = [text]
    if current_speaker is not None:
        merged_transcription += f'[SPEAKER_{current_speaker}] {" ".join(current_segment)}\n'
    return merged_transcription.strip()

def cleanup_temp_files(*file_paths):
    """Delete the given temporary files if they exist."""
    for path in file_paths:
        if path and os.path.exists(path):
            os.remove(path)

def sec_to_hhmmss(seconds):
    """Format a duration in seconds as HH:MM:SS."""
    h = int(seconds // 3600)
    m = int((seconds % 3600) // 60)
    s = int(seconds % 60)
    return f"{h:02d}:{m:02d}:{s:02d}"