File size: 3,131 Bytes
1619dcb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from faster_whisper import WhisperModel
from transformers import pipeline
import os
import time
from settings import MODEL_PATH_V2_FAST, MODEL_PATH_V1, LEFT_CHANNEL_TEMP_PATH, RIGHT_CHANNEL_TEMP_PATH, BATCH_SIZE, TASK
from audio_utils import debug_print, get_settings, split_input_stereo_channels, format_audio, process_waveforms, post_process_transcripts, post_process_transcription, post_merge_consecutive_segments_from_text, cleanup_temp_files

# Hugging Face access token taken from the environment (None when HF_TOKEN is
# not set); it is forwarded as `token=` to the v1 transformers pipeline.
hf_token = os.environ.get("HF_TOKEN")

# Module-level model caches. Both start empty and are filled on first use by
# get_asr_model_v2 / get_asr_model_v1 respectively.
ASR_MODEL_V2 = None  # faster_whisper.WhisperModel once loaded
ASR_MODEL_V1 = None  # transformers ASR pipeline once loaded

def get_asr_model_v2(DEVICE, COMPUTE_TYPE):
    """Return the shared faster-whisper model, building it on the first call.

    The model is constructed from ``MODEL_PATH_V2_FAST`` and stored in the
    module global ``ASR_MODEL_V2``. Every later call returns that same object
    and does NOT look at ``DEVICE`` / ``COMPUTE_TYPE`` again, so the values
    passed on the first call win.

    Args:
        DEVICE: forwarded as ``device=`` to ``WhisperModel``.
        COMPUTE_TYPE: forwarded as ``compute_type=`` to ``WhisperModel``.

    Returns:
        The cached ``WhisperModel`` instance.
    """
    global ASR_MODEL_V2

    # Fast path: already loaded.
    if ASR_MODEL_V2 is not None:
        return ASR_MODEL_V2

    debug_print("[MODEL LOADING] Loading ASR v2_fast model...")
    ASR_MODEL_V2 = WhisperModel(MODEL_PATH_V2_FAST, device=DEVICE, compute_type=COMPUTE_TYPE)
    debug_print("[MODEL LOADING]v2_fast model loaded")
    return ASR_MODEL_V2

def get_asr_model_v1(DEVICE):
    """Return the shared transformers ASR pipeline, building it on first call.

    The pipeline is created from ``MODEL_PATH_V1`` with 30-second chunking and
    the ``hf_token`` read from the environment, then cached in the module
    global ``ASR_MODEL_V1``. Later calls return the cached pipeline and
    ignore ``DEVICE``.

    Args:
        DEVICE: device name; only the exact string ``"cuda"`` selects GPU 0,
            every other value (including e.g. ``"cuda:1"``) selects CPU.

    Returns:
        The cached transformers pipeline.
    """
    global ASR_MODEL_V1

    if ASR_MODEL_V1 is not None:
        return ASR_MODEL_V1

    debug_print("[MODEL LOADING]Loading ASR v1 pipeline model...")
    # transformers pipelines take an integer device index: 0 = first GPU, -1 = CPU.
    pipeline_device = 0 if DEVICE == "cuda" else -1
    ASR_MODEL_V1 = pipeline(
        task="automatic-speech-recognition",
        model=MODEL_PATH_V1,
        chunk_length_s=30,
        device=pipeline_device,
        token=hf_token,
    )
    debug_print("[MODEL LOADING]ASR v1 model loaded")
    return ASR_MODEL_V1

def transcribe_asr(audio, model):
    """Run a transformers ASR pipeline on ``audio`` and return just the text.

    The pipeline is called with ``batch_size=BATCH_SIZE``, the Whisper
    ``task`` taken from settings (``TASK``) via ``generate_kwargs``, and
    ``return_timestamps=True``. Only the ``"text"`` entry of the result is
    returned; any timestamp data is discarded here.

    Args:
        audio: input accepted by the pipeline's ``__call__``.
        model: the ASR pipeline (callable) to run.

    Returns:
        The ``"text"`` field of the pipeline output.
    """
    generation_options = {"task": TASK}
    output = model(
        audio,
        batch_size=BATCH_SIZE,
        generate_kwargs=generation_options,
        return_timestamps=True,
    )
    return output["text"]

def transcribe_faster_asr(left_waveform, right_waveform, model, *, beam_size=5, task="transcribe"):
    """Transcribe two channel waveforms with a faster-whisper model.

    Args:
        left_waveform: audio for the left channel, in any form accepted by
            ``model.transcribe``.
        right_waveform: audio for the right channel, same form as above.
        model: object with ``transcribe(audio, beam_size=..., task=...)``
            returning ``(segments, info)``, e.g. ``faster_whisper.WhisperModel``.
        beam_size: beam width forwarded to ``model.transcribe`` (default 5,
            the previously hard-coded value).
        task: Whisper task forwarded to ``model.transcribe`` (default
            ``"transcribe"``, the previously hard-coded value).

    Returns:
        ``(left_segments, right_segments)``, each a list of the segments
        produced for that channel.
    """

    def _transcribe_channel(waveform):
        # Convert the returned segments iterable into a list before returning.
        # faster-whisper yields segments lazily (decoding happens on
        # iteration), so each channel is fully decoded before the next starts.
        segments, _info = model.transcribe(waveform, beam_size=beam_size, task=task)
        return list(segments)

    return _transcribe_channel(left_waveform), _transcribe_channel(right_waveform)

def generate_fase_1(audio_path, model_version, civil_channel):
    """Run phase 1 (speech recognition) on ``audio_path`` and return the text.

    ``model_version == "v2_fast"`` selects the faster-whisper path. It splits
    the stereo input per channel, transcribes both channels, and merges them
    with ``post_process_transcripts`` using ``civil_channel``. Any other value
    falls back to the transformers pipeline (v1) path, which ignores
    ``civil_channel``.

    Args:
        audio_path: path of the input audio file.
        model_version: ``"v2_fast"`` for faster-whisper; anything else uses v1.
        civil_channel: forwarded to ``post_process_transcripts`` on the
            v2_fast path; unused on the v1 path.

    Returns:
        The transcript text produced by the path's post-processing helper.
    """
    DEVICE, COMPUTE_TYPE = get_settings()

    debug_print(f"[Fase1] Starting inference with model version: {model_version}")

    if model_version == "v2_fast":
        asr_model = get_asr_model_v2(DEVICE, COMPUTE_TYPE)
        # Read the compute type from the loaded model rather than settings:
        # get_asr_model_v2 only applies its arguments on the first load, so a
        # cached model may differ from the current COMPUTE_TYPE.
        actual_compute_type = asr_model.model.compute_type
        debug_print(f"[SETTINGS] Device: {DEVICE}, Compute type: {actual_compute_type}")

        # Presumably writes the channel files at LEFT/RIGHT_CHANNEL_TEMP_PATH
        # (the paths cleaned up below) — TODO confirm against audio_utils.
        # NOTE(review): these are fixed module-level paths, so concurrent
        # calls would overwrite each other's temp files.
        split_input_stereo_channels(audio_path)
        # Clean the temp files up even when loading, transcription or
        # post-processing raises; previously cleanup only ran on success.
        try:
            left_waveform, right_waveform = process_waveforms(DEVICE, actual_compute_type)

            debug_print(f"[SETTINGS] Civil channel: {civil_channel}")
            left_result, right_result = transcribe_faster_asr(left_waveform, right_waveform, asr_model)

            text, _ = post_process_transcripts(left_result, right_result, civil_channel)
        finally:
            cleanup_temp_files(LEFT_CHANNEL_TEMP_PATH, RIGHT_CHANNEL_TEMP_PATH)
    else:
        actual_compute_type = "float32"  # HF pipeline safe default
        debug_print(f"[SETTINGS] Device: {DEVICE}, Compute type: {actual_compute_type}")

        asr_model = get_asr_model_v1(DEVICE)
        audio = format_audio(audio_path, actual_compute_type, DEVICE)
        result = transcribe_asr(audio, asr_model)
        text = post_process_transcription(result)

    return text