asr-inference / audio_utils.py
AbirMessaoudi's picture
fase_1, fase_2 releases (#46)
1619dcb verified
raw
history blame
5.86 kB
import os
import torch
import torchaudio
import numpy as np
import re
from pydub import AudioSegment
from settings import DEBUG_MODE, LEFT_CHANNEL_TEMP_PATH, RIGHT_CHANNEL_TEMP_PATH, RESAMPLING_FREQ
import soundfile as sf
# ------------------ DEBUG UTILITIES ------------------
def debug_print(*args, **kwargs):
    """Forward every argument to print(), but only while DEBUG_MODE is truthy."""
    if not DEBUG_MODE:
        return
    print(*args, **kwargs)
# ------------------ Device Settings ------------------
def get_settings():
    """Choose the torch device and compute type used for inference.

    Returns:
        tuple: ``(device, compute_type)`` where device is ``"cuda"`` when
        ``torch.cuda.is_available()`` is true, else ``"cpu"``; compute_type is
        always ``"default"``.
    """
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    compute_type = "default"
    debug_print(f"[SETTINGS] Device: {device}")
    return device, compute_type
# ------------------ Audio Utilities ------------------
def split_input_stereo_channels(audio_path):
    """Split a two-channel .wav/.mp3 file into two mono WAV temp files.

    Channel 0 is exported to LEFT_CHANNEL_TEMP_PATH and channel 1 to
    RIGHT_CHANNEL_TEMP_PATH. Nothing is returned.

    Raises:
        ValueError: if the extension (case-insensitive) is not .wav or .mp3,
            or if the decoded audio does not have exactly two channels.
    """
    extension = os.path.splitext(audio_path)[1].lower()
    loaders = {
        ".wav": lambda path: AudioSegment.from_wav(path),
        ".mp3": lambda path: AudioSegment.from_file(path, format="mp3"),
    }
    load = loaders.get(extension)
    if load is None:
        raise ValueError(f"[FORMAT AUDIO] Unsupported file format for: {audio_path}")
    mono_tracks = load(audio_path).split_to_mono()
    track_count = len(mono_tracks)
    if track_count != 2:
        raise ValueError(f"[FORMAT AUDIO] Audio {audio_path} has {track_count} channels (instead of 2).")
    left_track, right_track = mono_tracks
    left_track.export(LEFT_CHANNEL_TEMP_PATH, format="wav")
    right_track.export(RIGHT_CHANNEL_TEMP_PATH, format="wav")
def compute_type_to_audio_dtype(compute_type: str, device: str) -> np.dtype:
    """Pick the numpy dtype the audio array should be cast to.

    Half precision (``np.float16``) is chosen only when ``device`` starts with
    ``"cuda"`` AND ``compute_type`` (case-insensitive) contains ``"float16"``
    or ``"int8"``. Every other combination yields ``np.float32``.

    Note: the value returned is a numpy scalar type (e.g. ``np.float32``),
    usable directly with ``ndarray.astype``.
    """
    normalized = compute_type.lower()
    on_gpu = device.startswith("cuda")
    wants_half = any(tag in normalized for tag in ("float16", "int8"))
    return np.float16 if on_gpu and wants_half else np.float32
def format_audio(audio_path: str, compute_type: str, device: str) -> np.ndarray:
    """Load an audio file as a mono, resampled, 1-D numpy waveform.

    Args:
        audio_path: path to a file readable by ``torchaudio.load``.
        compute_type: inference compute type; selects the output dtype via
            ``compute_type_to_audio_dtype``.
        device: target device string (e.g. ``"cuda"`` or ``"cpu"``).

    Returns:
        1-D numpy array sampled at RESAMPLING_FREQ, cast to ``np.float16`` or
        ``np.float32`` as chosen by ``compute_type_to_audio_dtype``.
    """
    # torchaudio.load returns a (channels, frames) tensor by default.
    input_audio, sample_rate = torchaudio.load(audio_path)
    # Down-mix any multi-channel input (not just stereo) so the result is
    # always a single channel; otherwise 3+-channel files stay 2-D.
    if input_audio.shape[0] > 1:
        input_audio = torch.mean(input_audio, dim=0, keepdim=True)
    # Resampling between identical rates is a no-op, so skip building the
    # transform in that case.
    if sample_rate != RESAMPLING_FREQ:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=RESAMPLING_FREQ)
        input_audio = resampler(input_audio)
    # Drop only the channel axis so even a 1-frame clip stays 1-D.
    input_audio = input_audio.squeeze(0)
    np_dtype = compute_type_to_audio_dtype(compute_type, device)
    input_audio = input_audio.numpy().astype(np_dtype)
    debug_print(f"[FORMAT AUDIO] Audio dtype for actual_compute_type: {input_audio.dtype}")
    return input_audio
def process_waveforms(device: str, compute_type: str):
    """Load both exported channel temp files as model-ready waveforms.

    Returns:
        tuple: ``(left_waveform, right_waveform)``, each produced by
        ``format_audio`` from LEFT_CHANNEL_TEMP_PATH / RIGHT_CHANNEL_TEMP_PATH.
    """
    channel_paths = (LEFT_CHANNEL_TEMP_PATH, RIGHT_CHANNEL_TEMP_PATH)
    return tuple(format_audio(path, compute_type, device) for path in channel_paths)
# ------------------ Post-processing ------------------
def get_segments(result, speaker_label):
    """Convert ASR segments into ``(start, end, speaker, text)`` tuples.

    Args:
        result: iterable of segment objects exposing ``.start``, ``.end`` and
            ``.text`` (presumably the ASR model's segment output — confirm
            against the caller). It is iterated exactly once.
        speaker_label: value stored in the third tuple slot (the caller in
            this module passes "Civil" or "Operador").

    Returns:
        list of ``(start, end, speaker_label, cleaned_text)`` tuples in input
        order. Segments with no text, or whose text is empty after stripping
        and ``post_process_transcription``, are skipped.
    """
    final_segments = []
    for seg in result:
        if not seg.text:
            continue
        cleaned_text = post_process_transcription(seg.text.strip())
        # Whitespace-only or fully-filtered text would otherwise yield empty
        # "[Speaker]: " lines in the merged transcript.
        if cleaned_text:
            final_segments.append((seg.start, seg.end, speaker_label, cleaned_text))
    return final_segments
def post_process_transcripts(left_result, right_result, civil_channel):
    """Merge both channel transcripts into time-ordered outputs.

    Args:
        left_result: ASR segments for the left channel.
        right_result: ASR segments for the right channel.
        civil_channel: ``"Left"`` if the Civil speaker is on the left
            channel; any other value puts Civil on the right channel.

    Returns:
        tuple: ``(clean_output_asr, clean_output_meteo)``.
        ``clean_output_asr`` has one ``"[Speaker]: text"`` line per segment;
        ``clean_output_meteo`` is the same texts joined by single spaces.
        Segments are ordered by start time; segments whose start is None sort
        last, and ties keep Operador before Civil (stable sort).
    """
    if civil_channel == "Left":
        civil_result, operador_result = left_result, right_result
    else:
        civil_result, operador_result = right_result, left_result
    civil_segs = get_segments(civil_result, "Civil")
    operador_segs = get_segments(operador_result, "Operador")
    merged_transcript = sorted(
        operador_segs + civil_segs,
        key=lambda x: float(x[0]) if x[0] is not None else float("inf"),
    )
    clean_output_asr = "\n".join(
        f"[{speaker}]: {text}" for _, _, speaker, text in merged_transcript
    ).strip()
    # Separate segments with a space so the last word of one segment does not
    # fuse with the first word of the next.
    clean_output_meteo = " ".join(
        text for _, _, _, text in merged_transcript if text
    ).strip()
    return clean_output_asr, clean_output_meteo
# Word tokens with an optional inner apostrophe and at most one trailing
# ".", ",", "!" or "?".
_TOKEN_PATTERN = re.compile(r'\b\w+\'?\w*\b[.,!?]?')
# A 1-3 letter unit repeated three or more times in a row (e.g. "jajaja").
# Restricted to letters ([^\W\d_]) so digit runs inside numbers such as
# "1000" or "111" are kept intact.
_STUTTER_PATTERN = re.compile(r"([^\W\d_]{1,3})\1{2,}")


def post_process_transcription(transcription, max_repeats=2):
    """Clean a raw ASR transcription string.

    Steps:
      1. Tokenize into words (keeping one trailing ``. , ! ?``). Characters
         the tokenizer does not match (e.g. "¿", "¡", ";", quotes) are
         dropped.
      2. Remove letter stutters: any 1-3 letter unit repeated 3+ times in a
         row is deleted from the token.
      3. Keep at most ``max_repeats`` consecutive copies of the same
         (reduced) token; the first copy of a new token is always kept.

    Args:
        transcription: text to clean.
        max_repeats: maximum number of consecutive identical tokens kept.

    Returns:
        The cleaned tokens joined by single spaces, stripped.
    """
    cleaned_tokens = []
    repetition_count = 0
    previous_token = None
    for token in _TOKEN_PATTERN.findall(transcription):
        reduced_token = _STUTTER_PATTERN.sub("", token)
        if reduced_token == previous_token:
            repetition_count += 1
            if repetition_count <= max_repeats:
                cleaned_tokens.append(reduced_token)
        else:
            repetition_count = 1
            cleaned_tokens.append(reduced_token)
        previous_token = reduced_token
    cleaned_transcription = " ".join(cleaned_tokens)
    # Tokens reduced to "" leave double spaces behind; collapse them.
    cleaned_transcription = re.sub(r'\s+', ' ', cleaned_transcription).strip()
    return cleaned_transcription
# TODO: not used right now; decide whether to use it or remove it.
def post_merge_consecutive_segments_from_text(transcription_text: str) -> str:
    """Merge consecutive turns by the same speaker in a tagged transcript.

    The text is split on ``[SPEAKER_NN]`` tags (NN = two digits). Text before
    the first tag is ignored. Adjacent chunks with the same speaker id are
    stripped and joined with single spaces into one ``[SPEAKER_NN] text``
    line.

    Returns:
        The merged lines joined by newlines, with outer whitespace stripped;
        an empty string when the input contains no speaker tags.
    """
    pieces = re.split(r'(\[SPEAKER_\d{2}\])', transcription_text)
    # Each entry is [speaker_id, [text chunks]] for one uninterrupted turn.
    turns = []
    # re.split with one capture group alternates text, tag, text, tag, ...
    for tag, chunk in zip(pieces[1::2], pieces[2::2]):
        speaker_id = re.search(r'\d{2}', tag).group()
        if turns and turns[-1][0] == speaker_id:
            turns[-1][1].append(chunk.strip())
        else:
            turns.append([speaker_id, [chunk.strip()]])
    lines = [f'[SPEAKER_{speaker_id}] {" ".join(chunks)}' for speaker_id, chunks in turns]
    return "\n".join(lines).strip()
def cleanup_temp_files(*file_paths):
    """Delete each given file, ignoring falsy paths and files that are gone.

    Removal is attempted directly instead of checking ``os.path.exists``
    first, so a file removed concurrently between the check and the delete
    cannot raise FileNotFoundError. Other OS errors (e.g. the path is a
    directory) still propagate.

    Args:
        *file_paths: paths to delete; ``None`` or ``""`` entries are skipped.
    """
    for path in file_paths:
        if not path:
            continue
        try:
            os.remove(path)
        except FileNotFoundError:
            # Already removed: the goal of this function is satisfied.
            pass
def sec_to_hhmmss(seconds):
    """Format a duration in seconds as ``HH:MM:SS``.

    Fractional seconds are truncated, and hours are not wrapped at 24
    (90000 -> "25:00:00"). Intended for non-negative input.
    """
    hours = int(seconds // 3600)
    minutes = int(seconds % 3600 // 60)
    secs = int(seconds % 60)
    return ":".join(f"{part:02d}" for part in (hours, minutes, secs))