# WJ88's picture
# Update app.py
# a21ea5d verified
from __future__ import annotations
import os
import copy
import uuid
import logging
from typing import List, Optional, Tuple, Dict
# Reduce progress/log spam before heavy imports.
# setdefault keeps any value already exported in the environment.
# NOTE(review): these run before the numpy/torch/NeMo imports below, presumably
# because those libraries read the variables when imported — confirm.
os.environ.setdefault("TQDM_DISABLE", "1")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
import numpy as np
import torch
import torchaudio
import soundfile as sf
import gradio as gr
# NeMo
from nemo.collections.asr.models import ASRModel
from omegaconf import OmegaConf
from nemo.utils import logging as nemo_logging
# ----------------------------
# Config
# ----------------------------
# Each setting may be overridden through the named environment variable;
# the literal fallback is used otherwise.
MODEL_NAME = os.getenv("PARAKEET_MODEL", "nvidia/parakeet-tdt-0.6b-v3")  # NeMo model id
TARGET_SR = 16000  # sample rate (Hz) every input is resampled to
BEAM_SIZE = int(os.getenv("PARAKEET_BEAM_SIZE", "16"))  # beam width (raised for quality)
OFFLINE_BATCH = int(os.getenv("PARAKEET_BATCH", "8"))  # batch size for multi-item transcription
CHUNK_S = float(os.getenv("PARAKEET_CHUNK_S", "2.0"))  # streaming chunk length in seconds (longer = more context)
FLUSH_PAD_S = float(os.getenv("PARAKEET_FLUSH_PAD_S", "2.0"))  # seconds of silence appended when flushing
# ----------------------------
# Logging (unified)
# ----------------------------
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()

# App logger: a single stream handler (stderr by default) and no propagation,
# so records are not duplicated through the root logger.
logger = logging.getLogger("parakeet_app")
logger.setLevel(getattr(logging, LOG_LEVEL, logging.INFO))  # names missing from `logging` fall back to INFO
_handler = logging.StreamHandler()
_handler.setFormatter(
    logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")
)
logger.propagate = False
logger.handlers = [_handler]

# Quiet NeMo logs: its own logging wrapper plus the stdlib loggers it uses.
nemo_logging.setLevel(logging.ERROR)
logging.getLogger("nemo.collections.asr").setLevel(logging.ERROR)
logging.getLogger("nemo").setLevel(logging.ERROR)

# Inference-only app: disable autograd for the whole process.
torch.set_grad_enabled(False)
# ----------------------------
# Audio utils
# ----------------------------
def to_mono_np(x: np.ndarray) -> np.ndarray:
    """Return *x* as float32 audio, averaging axis 1 of 2-D input.

    NOTE(review): axis 1 is presumably the channel axis of a (samples, channels)
    buffer — confirm against callers. Arrays of other ranks are only cast; the
    cast avoids a copy when *x* is already float32.
    """
    mono = x.mean(axis=1) if x.ndim == 2 else x
    return mono.astype(np.float32, copy=False)
class ResamplerCache:
    """Keep one torchaudio resampler per source sample rate, all targeting ``TARGET_SR``."""

    def __init__(self):
        # source sample rate -> Resample transform from that rate to TARGET_SR
        self._cache: Dict[int, torchaudio.transforms.Resample] = {}

    def resample(self, wav: np.ndarray, src_sr: int) -> np.ndarray:
        """Convert *wav* from *src_sr* to ``TARGET_SR``.

        If the rates already match, *wav* is returned as-is (same object).
        A 1-D signal gets a temporary leading batch axis for the transform,
        which is squeezed off again before converting back to numpy.
        """
        if src_sr == TARGET_SR:
            return wav
        transform = self._cache.get(src_sr)
        if transform is None:
            logger.debug(f"create_resampler src_sr={src_sr} -> {TARGET_SR}")
            transform = torchaudio.transforms.Resample(orig_freq=src_sr, new_freq=TARGET_SR)
            self._cache[src_sr] = transform
        signal = torch.from_numpy(wav)
        batched = signal.unsqueeze(0) if signal.ndim == 1 else signal
        return transform(batched).squeeze(0).numpy()


# Process-wide shared cache used by the loaders and streaming sessions.
RESAMPLER = ResamplerCache()
def load_mono16k(path: str) -> np.ndarray:
    """Load any audio file, convert to mono float32 at 16 kHz (``TARGET_SR``).

    ``soundfile`` is tried first. If it cannot read the file, ``torchaudio.load``
    is used instead. Multi-channel audio is down-mixed by averaging the channels.

    Args:
        path: Filesystem path of the audio file.

    Returns:
        1-D float32 array at ``TARGET_SR``.

    Raises:
        Whatever ``torchaudio.load`` raises when neither backend can read the
        file. The soundfile error is chained as context.
    """
    # Only the decode step is guarded. A failure in the down-mix/resample of a
    # successfully decoded file should surface directly rather than trigger a
    # second, pointless decode with torchaudio.
    try:
        wav2d, sr = sf.read(path, dtype="float32", always_2d=True)  # (T, C)
    except Exception as err:
        logger.debug(f"soundfile could not read {path!r} ({err}); falling back to torchaudio")
        wav_t, sr = torchaudio.load(path)  # (C, T)
        wav = wav_t.float().mean(dim=0).numpy()
        return RESAMPLER.resample(wav, int(sr))
    wav = wav2d.mean(axis=1).astype(np.float32, copy=False)
    return RESAMPLER.resample(wav, int(sr))
# ----------------------------
# Model manager (MALSD batched beam everywhere, loop_labels=True)
# ----------------------------
class ParakeetManager:
    """Own the Parakeet ASR model and expose offline and streaming transcription.

    The decoding strategy is switched to ``malsd_batch`` beam search once, at
    construction. Both ``transcribe_files`` and ``stream_step`` then share it.
    """

    def __init__(self, device: str = "cpu"):
        """Download/load ``MODEL_NAME`` onto *device*, freeze it, and set beam decoding."""
        self.device = torch.device(device)
        logger.info(f"loading_model name={MODEL_NAME} device={self.device}")
        self.model: ASRModel = ASRModel.from_pretrained(model_name=MODEL_NAME)
        self.model.to(self.device)
        self.model.eval()
        for p in self.model.parameters():
            p.requires_grad = False
        # Base decoding cfg differs by class: prefer the nested decoder's cfg when
        # present, otherwise the model-level ``cfg.decoding``.
        if hasattr(self.model, "decoder") and hasattr(self.model.decoder, "decoder"):
            self._base_decoding = copy.deepcopy(self.model.decoder.decoder.cfg)
        else:
            self._base_decoding = copy.deepcopy(self.model.cfg.decoding)
        self._set_malsd_beam()
        logger.info(f"model_loaded strategy=malsd_batch beam_size={BEAM_SIZE}")

    def _set_malsd_beam(self):
        """Apply ``malsd_batch`` beam decoding, built from a fresh copy of the base config.

        Copying ``self._base_decoding`` means repeated calls never compound edits.
        """
        cfg = copy.deepcopy(self._base_decoding)
        cfg.strategy = "malsd_batch"
        # NOTE(review): this replaces the whole ``beam`` section, dropping any
        # other beam keys present in the base config — confirm that is intended.
        cfg.beam = OmegaConf.create({
            "beam_size": BEAM_SIZE,
            "return_best_hypothesis": True,
            "score_norm": True,
            "allow_cuda_graphs": False,  # CPU-only
            "max_symbols_per_step": 10,  # Added for stability in MALSD
        })
        # Leave struct mode so keys absent from the base schema can be added below.
        OmegaConf.set_struct(cfg, False)
        cfg["loop_labels"] = True
        cfg["fused_batch_size"] = -1  # Added for CPU compatibility
        cfg["compute_timestamps"] = False  # Added to match legacy, avoid overhead
        if hasattr(cfg, "greedy"):
            cfg.greedy.use_cuda_graph_decoder = False
        self.model.change_decoding_strategy(cfg)
        logger.info("decoding_set strategy=malsd_batch loop_labels=True")

    def _transcribe(self, items: List, *, partial=None):
        """Run ``model.transcribe`` on *items* under ``torch.inference_mode``.

        A single item uses ``batch_size=1``; otherwise ``OFFLINE_BATCH`` is used.
        *partial* is forwarded as ``partial_hypothesis``.
        """
        with torch.inference_mode():
            return self.model.transcribe(
                items,
                batch_size=1 if len(items) == 1 else OFFLINE_BATCH,
                num_workers=0,
                return_hypotheses=True,
                partial_hypothesis=partial,
            )

    @staticmethod
    def _unwrap_outputs(out):
        """Normalize ``transcribe`` output to one entry per input item.

        Some NeMo versions return a ``(best_hypotheses, all_hypotheses)`` tuple
        from RNNT ``transcribe``, while others return the list directly. A tuple
        whose first element is a list is reduced to that list. Anything else is
        returned unchanged.
        """
        if isinstance(out, tuple) and out and isinstance(out[0], list):
            return out[0]
        return out

    # Offline batch
    def transcribe_files(self, paths: List[str]):
        """Transcribe audio files.

        Returns:
            A list of ``{"path": path, "text": text}`` dicts in input order,
            or ``[]`` for empty/None *paths*.
        """
        n = len(paths) if paths else 0
        logger.info(f"files_run start count={n} batch={OFFLINE_BATCH}")
        if not paths:
            return []
        arrays = [load_mono16k(p) for p in paths]
        # Unwrap first. Zipping a raw (best, all) tuple would pair the first
        # path with the whole best-list and the second with the all-list.
        out = self._unwrap_outputs(self._transcribe(arrays, partial=None))
        results = []
        for p, o in zip(paths, out):
            h = o[0] if isinstance(o, list) and o else o
            text = h if isinstance(h, str) else getattr(h, "text", "")
            results.append({"path": p, "text": text})
        logger.info("files_run ok")
        return results

    # Streaming step (rolling hypothesis)
    def stream_step(self, audio_16k: np.ndarray, prev_hyp) -> object:
        """Decode one 16 kHz chunk, continuing from *prev_hyp* when given.

        Returns the hypothesis object for this item (the first element when
        the per-item result is itself a list).
        """
        partial = [prev_hyp] if prev_hyp is not None else None
        out = self._unwrap_outputs(self._transcribe([audio_16k], partial=partial))
        h = out[0]
        return h[0] if isinstance(h, list) else h  # Hypothesis
# ----------------------------
# Streaming session (no overlap, rolling hypothesis)
# ----------------------------
class StreamingSession:
    """Buffer one client's microphone audio and decode it in fixed-size chunks.

    Incoming audio is down-mixed, resampled to ``TARGET_SR`` and appended to
    ``pending``. Each complete ``chunk_s``-second slice is decoded, and the
    previous hypothesis is passed back in. Slices do not overlap.
    """

    def __init__(self, manager: ParakeetManager, chunk_s: float, flush_pad_s: float):
        # A chunk that rounds to zero samples would make ``_drain`` slice empty
        # arrays forever without ever shrinking ``pending``.
        if not chunk_s > 0 or int(chunk_s * TARGET_SR) < 1:
            raise ValueError(f"chunk_s must cover at least one sample, got {chunk_s!r}")
        self.mgr = manager
        self.chunk_s = chunk_s
        self.flush_pad_s = flush_pad_s
        self.hyp = None  # rolling hypothesis returned by the last stream_step
        self.pending = np.zeros(0, dtype=np.float32)  # resampled audio not yet decoded
        self.text = ""  # latest transcript
        logger.info(f"mic_reset chunk={self.chunk_s}s flush_pad={self.flush_pad_s}s")

    def add_audio(self, audio: np.ndarray, src_sr: int):
        """Append *audio* (sampled at *src_sr*) to the buffer and decode any full chunks."""
        mono = to_mono_np(audio)
        res = RESAMPLER.resample(mono, src_sr)
        self.pending = np.concatenate([self.pending, res]) if self.pending.size else res
        self._drain()

    def _drain(self):
        """Decode every complete chunk currently in ``pending``.

        On a model error, the failing chunk has already been removed from the
        buffer and is dropped. Draining stops, and the remaining audio waits
        for the next call.
        """
        chunk_len = int(self.chunk_s * TARGET_SR)
        while self.pending.size >= chunk_len:
            chunk = self.pending[:chunk_len]
            self.pending = self.pending[chunk_len:]
            try:
                self.hyp = self.mgr.stream_step(chunk, self.hyp)
                self.text = getattr(self.hyp, "text", self.text)
            except Exception:
                logger.exception("mic_step failed")
                break

    def flush(self) -> str:
        """Decode leftover audio padded with ``flush_pad_s`` seconds of silence.

        The buffer is always cleared. Model errors are logged rather than raised.

        Returns:
            The latest transcript.
        """
        if self.pending.size:
            # Clamp so a negative pad setting cannot make np.zeros raise here,
            # outside the error handling below.
            pad_len = max(0, int(self.flush_pad_s * TARGET_SR))
            final = np.concatenate([self.pending, np.zeros(pad_len, dtype=np.float32)])
            try:
                self.hyp = self.mgr.stream_step(final, self.hyp)
                self.text = getattr(self.hyp, "text", self.text)
            except Exception:
                logger.exception("mic_flush failed")
        self.pending = np.zeros(0, dtype=np.float32)
        return self.text
# ----------------------------
# Simple session registry (avoid deepcopy in gr.State)
# ----------------------------
SESS: Dict[str, StreamingSession] = {}
def _new_session_id() -> str:
return uuid.uuid4().hex
# ----------------------------
# Gradio callbacks
# ----------------------------
# Single shared model instance, loaded on CPU when this module is imported.
MANAGER = ParakeetManager("cpu")
def _parse_gr_audio(x) -> Tuple[np.ndarray, int]:
if x is None:
return np.zeros(0, dtype=np.float32), TARGET_SR
if isinstance(x, tuple) and len(x) == 2:
sr = int(x[0]); arr = np.array(x[1], dtype=np.float32); return arr, sr
if isinstance(x, dict) and "data" in x and "sampling_rate" in x:
arr = np.array(x["data"], dtype=np.float32); sr = int(x["sampling_rate"]); return arr, sr
if isinstance(x, np.ndarray):
return x.astype(np.float32, copy=False), TARGET_SR
logger.error(f"unsupported_gr_audio_payload type={type(x)}"); raise ValueError("Unsupported audio payload")
def mic_step(audio_chunk, sess_id: Optional[str]):
    """Gradio streaming callback: push one microphone chunk into the caller's session.

    A new StreamingSession is registered (and its id returned) when *sess_id*
    is empty or unknown. A payload that fails to parse is logged and skipped.

    Returns:
        ``(session_id, transcript_so_far)``.
    """
    if sess_id and sess_id in SESS:
        sess = SESS[sess_id]
    else:
        sess_id = _new_session_id()
        sess = SESS[sess_id] = StreamingSession(MANAGER, CHUNK_S, FLUSH_PAD_S)
    try:
        wav, sr = _parse_gr_audio(audio_chunk)
    except Exception:
        logger.exception("mic_parse failed")
    else:
        if wav.size:
            sess.add_audio(wav, sr)
    return sess_id, sess.text
def mic_flush(sess_id: Optional[str]):
    """Finalize the caller's streaming session and drop it from the registry.

    Remaining buffered audio is decoded, then the session is removed from
    ``SESS``. The caller's ``gr.State`` id is reset to None, so nothing
    could reach the old entry again; keeping it would leak one session per
    flush. The next microphone chunk starts a fresh session.

    Returns:
        ``(None, final_transcript)``; the transcript is ``""`` when *sess_id*
        is empty or unknown.
    """
    sess = SESS.pop(sess_id, None) if sess_id else None
    if sess is None:
        return None, ""
    text = sess.flush()
    logger.info("mic_flush ok")
    return None, text
def files_run(files):
    """Click handler for the Files tab: transcribe uploads into table rows.

    *files* items may be path strings or objects exposing ``.name``; items of
    any other type are skipped. Transcription errors are logged and re-raised.

    Returns:
        ``[[basename, text], ...]`` rows, or ``[]`` when nothing was uploaded.
    """
    count = len(files) if files else 0
    logger.info(f"files_ui start count={count}")
    if not files:
        return []
    paths: List[str] = [
        item if isinstance(item, str) else item.name
        for item in files
        if isinstance(item, str) or hasattr(item, "name")
    ]
    try:
        results = MANAGER.transcribe_files(paths)
    except Exception:
        logger.exception("files_run failed")
        raise
    rows = [[os.path.basename(entry["path"]), entry["text"]] for entry in results]
    logger.info("files_ui ok")
    return rows
# ----------------------------
# UI
# ----------------------------
# Gradio lays components out in creation order inside the Blocks context, so
# the statements below are order-sensitive.
with gr.Blocks(title="Parakeet-TDT v3 (Unified MALSD Beam)") as demo:
    with gr.Tab("Mic"):
        # Streaming numpy microphone input; each chunk is passed to mic_step.
        mic = gr.Audio(sources=["microphone"], type="numpy", streaming=True, label="Speak")
        text_out = gr.Textbox(label="Transcript", lines=8)
        flush_btn = gr.Button("Flush")
        state_id = gr.State()  # only a string id
        # mic_step returns (session_id, transcript); the id round-trips through state_id.
        mic.stream(mic_step, inputs=[mic, state_id], outputs=[state_id, text_out])
        # mic_flush returns (None, final_text), which clears the stored session id.
        flush_btn.click(mic_flush, inputs=[state_id], outputs=[state_id, text_out])
    with gr.Tab("Files"):
        files = gr.File(file_count="multiple", type="filepath", label="Upload audio files")
        run_btn = gr.Button("Run")
        # files_run returns [[file, text], ...] rows matching these two columns.
        results_table = gr.Dataframe(headers=["file", "text"], label="Results",
                                     row_count=(0, "dynamic"), col_count=(2, "fixed"))
        run_btn.click(files_run, inputs=[files], outputs=[results_table])
# Enable the request queue and start the server (runs at import time).
demo.queue().launch(ssr_mode=False)