pure ai refactoring
Browse files
app.py
CHANGED
|
@@ -1,118 +1,234 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import nemo.collections.asr as nemo_asr
|
| 3 |
from omegaconf import OmegaConf
|
| 4 |
-
from dataclasses import dataclass
|
| 5 |
from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig
|
| 6 |
-
from nemo.collections.asr.parts.utils.rnnt_utils import
|
| 7 |
from nemo.collections.asr.parts.utils.streaming_utils import ContextSize, StreamingBatchedAudioBuffer
|
| 8 |
|
| 9 |
-
def _div(a, b): return (a // b) * b
|
| 10 |
|
|
|
|
|
|
|
|
|
|
| 11 |
@dataclass
|
| 12 |
-
class
|
| 13 |
-
|
| 14 |
left_s: float = 10.0
|
| 15 |
chunk_s: float = 2.0
|
| 16 |
right_s: float = 2.0
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
device: str = "cpu"
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
l = 0
|
| 74 |
-
r = min(ctx_samp.chunk + ctx_samp.right, a.shape[1])
|
| 75 |
-
|
| 76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
while l < a.shape[1]:
|
| 78 |
-
clen = min(r, a.shape[1]) - l
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
l = r
|
| 93 |
-
r = min(r + ctx_samp.chunk, a.shape[1])
|
| 94 |
-
|
| 95 |
-
|
|
|
|
|
|
|
|
|
|
| 96 |
return outs[0].text if outs else ""
|
| 97 |
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
|
| 117 |
if __name__ == "__main__":
|
| 118 |
-
demo
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
"""Refactored Gradio app for streaming ASR with NVIDIA NeMo Parakeet-TDT-0.6B-v3.
|
| 3 |
+
|
| 4 |
+
Functionality preserved. Structure simplified and documented.
|
| 5 |
+
- Buffered streaming on CPU by default (configurable device).
|
| 6 |
+
- Monophonic conversion and resampling to model sample rate.
|
| 7 |
+
- Greedy batched RNNT decoding with label-looping.
|
| 8 |
+
"""
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
from typing import Optional, Tuple
|
| 13 |
+
|
| 14 |
+
import copy
|
| 15 |
+
import numpy as np
|
| 16 |
+
import torch
|
| 17 |
+
import torchaudio
|
| 18 |
+
import gradio as gr
|
| 19 |
+
|
| 20 |
import nemo.collections.asr as nemo_asr
|
| 21 |
from omegaconf import OmegaConf
|
|
|
|
| 22 |
from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig
|
| 23 |
+
from nemo.collections.asr.parts.utils.rnnt_utils import batched_hyps_to_hypotheses
|
| 24 |
from nemo.collections.asr.parts.utils.streaming_utils import ContextSize, StreamingBatchedAudioBuffer
|
| 25 |
|
|
|
|
| 26 |
|
| 27 |
+
# ----------------------------
|
| 28 |
+
# Config
|
| 29 |
+
# ----------------------------
|
| 30 |
@dataclass
class AppConfig:
    """Settings consumed by the streaming ASR app.

    Field order is part of the positional constructor signature.
    """

    # Pretrained model identifier handed to ``from_pretrained``.
    model_name: str = "nvidia/parakeet-tdt-0.6b-v3"

    # Streaming context window, in seconds: left context, chunk, right context.
    left_s: float = 10.0
    chunk_s: float = 2.0
    right_s: float = 2.0

    # Upper bound, in seconds, on the rolling audio buffer kept between calls.
    max_buffer_s: float = 40.0

    # Batch size passed to the streaming audio buffer and hypothesis conversion.
    batch_size: int = 1

    # Torch device string; "cuda" to force GPU if available.
    device: str = "cpu"
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# ----------------------------
|
| 42 |
+
# Utility
|
| 43 |
+
# ----------------------------
|
| 44 |
+
def _floor_multiple(a: int, b: int) -> int:
|
| 45 |
+
"""Largest multiple of b not exceeding a."""
|
| 46 |
+
return (a // b) * b
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# ----------------------------
|
| 50 |
+
# ASR Engine
|
| 51 |
+
# ----------------------------
|
| 52 |
+
class ParakeetStreamer:
    """Encapsulates model, preprocessor settings, and decoding.

    The instance loads the model once and exposes :meth:`transcribe`, a
    callback that keeps a rolling audio buffer and re-decodes it on every
    incoming chunk.
    """

    def __init__(self, cfg: AppConfig) -> None:
        """Load the model and derive the streaming context sizes from ``cfg``."""
        self.cfg = cfg

        # Load model, move it to the configured device and freeze it.
        self.model = (
            nemo_asr.models.EncDecRNNTModel.from_pretrained(cfg.model_name)
            .to(cfg.device)
            .eval()
        )
        for p in self.model.parameters():
            p.requires_grad_(False)

        # Decoding strategy: greedy-batch with label-looping for batched efficiency
        dec_cfg = RNNTDecodingConfig(
            strategy="greedy_batch", fused_batch_size=-1, compute_timestamps=False
        )
        dec_cfg.greedy.loop_labels = True
        self.model.change_decoding_strategy(dec_cfg)
        self._decoding_computer = self.model.decoding.decoding.decoding_computer

        # Clone the model config and set dither/pad_to for inference.
        # NOTE(review): this tweaked copy is only read below (sample_rate,
        # window_stride); nothing here applies it back to the model's
        # preprocessor -- confirm whether that is intended.
        mcfg = copy.deepcopy(self.model.cfg)
        OmegaConf.set_struct(mcfg.preprocessor, False)
        mcfg.preprocessor.dither = 0.0
        mcfg.preprocessor.pad_to = 0
        OmegaConf.set_struct(mcfg.preprocessor, True)

        # Derived constants
        self.sample_rate: int = int(mcfg.preprocessor.sample_rate)
        window_stride: float = float(mcfg.preprocessor.window_stride)
        self.frames_per_second: float = 1.0 / window_stride
        self.subsampling: int = int(self.model.encoder.subsampling_factor)

        # Feature->audio and encoder->audio subsampling alignment
        feat_f2a = _floor_multiple(int(self.sample_rate * window_stride), self.subsampling)
        self.enc_f2a = feat_f2a * self.subsampling

        # Context sizes in encoder frames ...
        self.ctx_enc = ContextSize(
            left=int(cfg.left_s * self.frames_per_second / self.subsampling),
            chunk=int(cfg.chunk_s * self.frames_per_second / self.subsampling),
            right=int(cfg.right_s * self.frames_per_second / self.subsampling),
        )
        # ... and the matching sizes in audio samples.
        self.ctx_samp = ContextSize(
            left=self.ctx_enc.left * self.subsampling * feat_f2a,
            chunk=self.ctx_enc.chunk * self.subsampling * feat_f2a,
            right=self.ctx_enc.right * self.subsampling * feat_f2a,
        )

        # Cap on the rolling buffer length, in samples.
        self.max_samples = int(cfg.max_buffer_s * self.sample_rate)

    # -------- audio helpers --------
    @staticmethod
    def _to_mono(x: np.ndarray) -> np.ndarray:
        """Return ``x`` as a mono float32 array.

        1-D input is only cast. 2-D input is averaged over its channel axis:
        ``(samples, channels)`` is the default reading, and ``(channels,
        samples)`` is assumed when the first axis is the shorter one and the
        second axis is not exactly 2.
        """
        x = np.asarray(x)
        if x.ndim == 2:
            # The previous expression averaged over axis 1 in both branches,
            # which collapsed a (channels, samples) array across time.
            channels_first = x.shape[0] < x.shape[1] and x.shape[1] != 2
            x = x.mean(axis=0 if channels_first else 1)
        return x.astype(np.float32, copy=False)

    def _resample_if_needed(self, x: np.ndarray, in_sr: int) -> np.ndarray:
        """Resample ``x`` from ``in_sr`` to ``self.sample_rate`` when they differ."""
        if int(in_sr) == self.sample_rate:
            return x
        y = torchaudio.functional.resample(
            torch.from_numpy(x), in_sr, self.sample_rate
        )
        return y.numpy().astype(np.float32, copy=False)

    # -------- core decoding --------
    @torch.inference_mode()
    def _decode_buffer(self, audio_np: np.ndarray) -> str:
        """Run buffered streaming decoding over the entire audio buffer.

        Args:
            audio_np: 1-D array of audio samples (a batch axis is added here).

        Returns:
            The decoded text of the first hypothesis, or ``""`` when the
            buffer is empty or nothing was decoded.
        """
        if audio_np.size == 0:
            return ""

        a = torch.from_numpy(audio_np).unsqueeze(0).to(torch.float32).to(self.cfg.device)
        n_samples = a.shape[1]
        total_len = torch.tensor([n_samples], dtype=torch.long, device=self.cfg.device)

        cur_hyps = None
        prev_state = None

        # [left, right) is the slice of new audio fed on each iteration; the
        # first slice covers chunk + right context.
        left = 0
        right = min(self.ctx_samp.chunk + self.ctx_samp.right, n_samples)

        buf = StreamingBatchedAudioBuffer(
            batch_size=self.cfg.batch_size,
            context_samples=self.ctx_samp,
            dtype=a.dtype,
            device=self.cfg.device,
        )

        remaining = total_len.clone()

        while left < n_samples:
            clen = int(min(right, n_samples) - left)
            is_last = right >= n_samples

            is_last_b = torch.tensor([clen >= remaining[0]], dtype=torch.bool, device=self.cfg.device)
            clen_b = torch.where(is_last_b, remaining, torch.full_like(remaining, fill_value=clen))

            buf.add_audio_batch_(
                a[:, left:right], audio_lengths=clen_b, is_last_chunk=is_last, is_last_chunk_batch=is_last_b
            )

            enc, _ = self.model(input_signal=buf.samples, input_signal_length=buf.context_size_batch.total())
            enc = enc.transpose(1, 2)  # [B, T, C]

            enc_ctx = buf.context_size.subsample(factor=self.enc_f2a)
            enc_ctx_b = buf.context_size_batch.subsample(factor=self.enc_f2a)

            enc = enc[:, enc_ctx.left:]  # drop left context before decoding

            # Decoder state is threaded through so each chunk continues
            # from the previous one.
            hyps, _, prev_state = self._decoding_computer(
                x=enc, out_len=enc_ctx_b.chunk, prev_batched_state=prev_state
            )

            if cur_hyps is None:
                cur_hyps = hyps
            else:
                cur_hyps.merge_(hyps)

            remaining -= clen_b
            left = right
            right = min(right + self.ctx_samp.chunk, n_samples)

        outs = batched_hyps_to_hypotheses(cur_hyps, None, batch_size=self.cfg.batch_size) if cur_hyps is not None else []
        for h in outs:
            h.text = self.model.tokenizer.ids_to_text(h.y_sequence.tolist())

        return outs[0].text if outs else ""

    # -------- public API for Gradio --------
    def transcribe(self, stream: Optional[np.ndarray], new_chunk: Optional[Tuple[int, np.ndarray]]):
        """Gradio callback. Maintains rolling buffer in `stream` and returns transcript.

        Args:
            stream: rolling buffer carried in gr.State()
            new_chunk: tuple (sample_rate, np.ndarray) provided by gr.Audio with type='numpy'

        Returns:
            Tuple ``(buffer, text)``: the updated rolling buffer and the
            transcript of that buffer. When ``new_chunk`` is ``None`` the
            buffer is returned unchanged with an empty transcript.
        """
        if new_chunk is None:
            return stream, ""

        in_sr, data = new_chunk
        y = self._to_mono(data)
        y = self._resample_if_needed(y, int(in_sr))

        if stream is None or len(stream) == 0:
            a = y
        else:
            a = np.concatenate([stream, y])

        # Keep only the most recent max_samples samples.
        if a.size > self.max_samples:
            a = a[-self.max_samples:]

        text = self._decode_buffer(a) if a.size else ""
        return a, text
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
# ----------------------------
|
| 216 |
+
# UI
|
| 217 |
+
# ----------------------------
|
| 218 |
+
def build_demo(cfg: Optional[AppConfig] = None) -> gr.Interface:
    """Create the live Gradio interface backed by a ``ParakeetStreamer``.

    Args:
        cfg: App settings; a default ``AppConfig()`` is used when omitted.

    Returns:
        A ``gr.Interface`` wired to the engine's ``transcribe`` callback,
        with a state slot and a streaming microphone as inputs and a state
        slot and a transcript textbox as outputs.
    """
    if not cfg:
        cfg = AppConfig()
    engine = ParakeetStreamer(cfg)

    # Components are created in the same order as the input/output lists.
    state_in = gr.State()
    mic = gr.Audio(sources=["microphone"], type="numpy", streaming=True, label="Mic")
    state_out = gr.State()
    transcript = gr.Textbox(label="Transcript", lines=3)

    interface = gr.Interface(
        fn=engine.transcribe,
        inputs=[state_in, mic],
        outputs=[state_out, transcript],
        title="Parakeet-TDT-0.6B-v3 — CPU streaming",
        description="Multilingual buffered streaming (10-2-2) in memory",
        live=True,
    )
    return interface
|
| 230 |
+
|
| 231 |
|
| 232 |
# Script entry point: build the interface with default settings and launch it.
if __name__ == "__main__":
    demo = build_demo()
    demo.launch()
|