from __future__ import annotations import os import copy import uuid import logging from typing import List, Optional, Tuple, Dict # Reduce progress/log spam before heavy imports os.environ.setdefault("TQDM_DISABLE", "1") os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") import numpy as np import torch import torchaudio import soundfile as sf import gradio as gr # NeMo from nemo.collections.asr.models import ASRModel from omegaconf import OmegaConf from nemo.utils import logging as nemo_logging # ---------------------------- # Config # ---------------------------- MODEL_NAME = os.environ.get("PARAKEET_MODEL", "nvidia/parakeet-tdt-0.6b-v3") TARGET_SR = 16_000 BEAM_SIZE = int(os.environ.get("PARAKEET_BEAM_SIZE", "16")) # Increased for quality OFFLINE_BATCH= int(os.environ.get("PARAKEET_BATCH", "8")) CHUNK_S = float(os.environ.get("PARAKEET_CHUNK_S", "2.0")) # Increased for better context FLUSH_PAD_S = float(os.environ.get("PARAKEET_FLUSH_PAD_S", "2.0")) # Increased for better finalization # ---------------------------- # Logging (unified) # ---------------------------- LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO").upper() logger = logging.getLogger("parakeet_app") logger.setLevel(getattr(logging, LOG_LEVEL, logging.INFO)) _handler = logging.StreamHandler() _handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")) logger.handlers = [_handler] logger.propagate = False # Quiet NeMo logs nemo_logging.setLevel(logging.ERROR) logging.getLogger("nemo").setLevel(logging.ERROR) logging.getLogger("nemo.collections.asr").setLevel(logging.ERROR) torch.set_grad_enabled(False) # ---------------------------- # Audio utils # ---------------------------- def to_mono_np(x: np.ndarray) -> np.ndarray: if x.ndim == 2: x = x.mean(axis=1) return x.astype(np.float32, copy=False) class ResamplerCache: def __init__(self): self._cache: Dict[int, torchaudio.transforms.Resample] = {} def resample(self, wav: np.ndarray, src_sr: int) -> np.ndarray: if src_sr == TARGET_SR: return wav if src_sr not in self._cache: logger.debug(f"create_resampler src_sr={src_sr} -> {TARGET_SR}") self._cache[src_sr] = torchaudio.transforms.Resample(orig_freq=src_sr, new_freq=TARGET_SR) t = torch.from_numpy(wav) if t.ndim == 1: t = t.unsqueeze(0) y = self._cache[src_sr](t) return y.squeeze(0).numpy() RESAMPLER = ResamplerCache() def load_mono16k(path: str) -> np.ndarray: """Load any audio file, convert to mono float32 at 16 kHz.""" try: wav, sr = sf.read(path, dtype="float32", always_2d=True) # (T,C) wav = wav.mean(axis=1).astype(np.float32, copy=False) return RESAMPLER.resample(wav, sr) except Exception: wav_t, sr = torchaudio.load(path) # (C,T) if wav_t.dtype != torch.float32: wav_t = wav_t.float() wav = wav_t.mean(dim=0).numpy() return RESAMPLER.resample(wav, int(sr)) # ---------------------------- # Model manager (MALSD batched beam everywhere, loop_labels=True) # ---------------------------- class ParakeetManager: def __init__(self, device: str = "cpu"): self.device = torch.device(device) logger.info(f"loading_model name={MODEL_NAME} device={self.device}") self.model: ASRModel = ASRModel.from_pretrained(model_name=MODEL_NAME) self.model.to(self.device) self.model.eval() for p in self.model.parameters(): p.requires_grad = False # Base decoding cfg differs by class if hasattr(self.model, "decoder") and hasattr(self.model.decoder, "decoder"): self._base_decoding = copy.deepcopy(self.model.decoder.decoder.cfg) else: self._base_decoding = copy.deepcopy(self.model.cfg.decoding) self._set_malsd_beam() logger.info(f"model_loaded strategy=malsd_batch beam_size={BEAM_SIZE}") def _set_malsd_beam(self): cfg = copy.deepcopy(self._base_decoding) cfg.strategy = "malsd_batch" cfg.beam = OmegaConf.create({ "beam_size": BEAM_SIZE, "return_best_hypothesis": True, "score_norm": True, "allow_cuda_graphs": False, # CPU-only "max_symbols_per_step": 10, # Added for stability in MALSD }) OmegaConf.set_struct(cfg, False) cfg["loop_labels"] = True cfg["fused_batch_size"] = -1 # Added for CPU compatibility cfg["compute_timestamps"] = False # Added to match legacy, avoid overhead if hasattr(cfg, "greedy"): cfg.greedy.use_cuda_graph_decoder = False self.model.change_decoding_strategy(cfg) logger.info("decoding_set strategy=malsd_batch loop_labels=True") def _transcribe(self, items: List, *, partial=None): with torch.inference_mode(): return self.model.transcribe( items, batch_size=1 if len(items) == 1 else OFFLINE_BATCH, num_workers=0, return_hypotheses=True, partial_hypothesis=partial, ) # Offline batch def transcribe_files(self, paths: List[str]): n = 0 if not paths else len(paths) logger.info(f"files_run start count={n} batch={OFFLINE_BATCH}") if not paths: return [] arrays = [load_mono16k(p) for p in paths] out = self._transcribe(arrays, partial=None) results = [] for p, o in zip(paths, out): h = o[0] if isinstance(o, list) and o else o text = h if isinstance(h, str) else getattr(h, "text", "") results.append({"path": p, "text": text}) logger.info("files_run ok") return results # Streaming step (rolling hypothesis) def stream_step(self, audio_16k: np.ndarray, prev_hyp) -> object: out = self._transcribe([audio_16k], partial=[prev_hyp] if prev_hyp is not None else None) h = out[0][0] if isinstance(out[0], list) else out[0] return h # Hypothesis # ---------------------------- # Streaming session (no overlap, rolling hypothesis) # ---------------------------- class StreamingSession: def __init__(self, manager: ParakeetManager, chunk_s: float, flush_pad_s: float): self.mgr = manager self.chunk_s = chunk_s self.flush_pad_s = flush_pad_s self.hyp = None self.pending = np.zeros(0, dtype=np.float32) self.text = "" logger.info(f"mic_reset chunk={self.chunk_s}s flush_pad={self.flush_pad_s}s") def add_audio(self, audio: np.ndarray, src_sr: int): mono = to_mono_np(audio) res = RESAMPLER.resample(mono, src_sr) self.pending = np.concatenate([self.pending, res]) if self.pending.size else res self._drain() def _drain(self): C = int(self.chunk_s * TARGET_SR) while self.pending.size >= C: chunk = self.pending[:C] self.pending = self.pending[C:] try: self.hyp = self.mgr.stream_step(chunk, self.hyp) self.text = getattr(self.hyp, "text", self.text) except Exception: logger.exception("mic_step failed") break def flush(self) -> str: if self.pending.size: pad = np.zeros(int(self.flush_pad_s * TARGET_SR), dtype=np.float32) final = np.concatenate([self.pending, pad]) try: self.hyp = self.mgr.stream_step(final, self.hyp) self.text = getattr(self.hyp, "text", self.text) except Exception: logger.exception("mic_flush failed") self.pending = np.zeros(0, dtype=np.float32) return self.text # ---------------------------- # Simple session registry (avoid deepcopy in gr.State) # ---------------------------- SESS: Dict[str, StreamingSession] = {} def _new_session_id() -> str: return uuid.uuid4().hex # ---------------------------- # Gradio callbacks # ---------------------------- MANAGER = ParakeetManager(device="cpu") def _parse_gr_audio(x) -> Tuple[np.ndarray, int]: if x is None: return np.zeros(0, dtype=np.float32), TARGET_SR if isinstance(x, tuple) and len(x) == 2: sr = int(x[0]); arr = np.array(x[1], dtype=np.float32); return arr, sr if isinstance(x, dict) and "data" in x and "sampling_rate" in x: arr = np.array(x["data"], dtype=np.float32); sr = int(x["sampling_rate"]); return arr, sr if isinstance(x, np.ndarray): return x.astype(np.float32, copy=False), TARGET_SR logger.error(f"unsupported_gr_audio_payload type={type(x)}"); raise ValueError("Unsupported audio payload") def mic_step(audio_chunk, sess_id: Optional[str]): if not sess_id or sess_id not in SESS: sess_id = _new_session_id() SESS[sess_id] = StreamingSession(MANAGER, CHUNK_S, FLUSH_PAD_S) sess = SESS[sess_id] try: wav, sr = _parse_gr_audio(audio_chunk) except Exception: logger.exception("mic_parse failed") return sess_id, sess.text if wav.size: sess.add_audio(wav, sr) return sess_id, sess.text def mic_flush(sess_id: Optional[str]): if not sess_id or sess_id not in SESS: return None, "" text = SESS[sess_id].flush() logger.info("mic_flush ok") return None, text def files_run(files): n = 0 if not files else len(files) logger.info(f"files_ui start count={n}") if not files: return [] paths: List[str] = [] for f in files: if isinstance(f, str): paths.append(f) elif hasattr(f, "name"): paths.append(f.name) try: results = MANAGER.transcribe_files(paths) except Exception: logger.exception("files_run failed"); raise table = [[os.path.basename(r["path"]), r["text"]] for r in results] logger.info("files_ui ok") return table # ---------------------------- # UI # ---------------------------- with gr.Blocks(title="Parakeet-TDT v3 (Unified MALSD Beam)") as demo: with gr.Tab("Mic"): mic = gr.Audio(sources=["microphone"], type="numpy", streaming=True, label="Speak") text_out = gr.Textbox(label="Transcript", lines=8) flush_btn = gr.Button("Flush") state_id = gr.State() # only a string id mic.stream(mic_step, inputs=[mic, state_id], outputs=[state_id, text_out]) flush_btn.click(mic_flush, inputs=[state_id], outputs=[state_id, text_out]) with gr.Tab("Files"): files = gr.File(file_count="multiple", type="filepath", label="Upload audio files") run_btn = gr.Button("Run") results_table = gr.Dataframe(headers=["file", "text"], label="Results", row_count=(0, "dynamic"), col_count=(2, "fixed")) run_btn.click(files_run, inputs=[files], outputs=[results_table]) demo.queue().launch(ssr_mode=False)