WJ88 committed
Commit 5ab166b · verified · 1 Parent(s): 4318ae4

pure ai refactoring

Files changed (1):
  1. app.py +217 -101
app.py CHANGED
@@ -1,118 +1,234 @@
-import gradio as gr, numpy as np, torch, torchaudio, copy
+
+"""Refactored Gradio app for streaming ASR with NVIDIA NeMo Parakeet-TDT-0.6B-v3.
+
+Functionality preserved. Structure simplified and documented.
+- Buffered streaming on CPU by default (configurable device).
+- Monophonic conversion and resampling to model sample rate.
+- Greedy batched RNNT decoding with label-looping.
+"""
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import copy
+import numpy as np
+import torch
+import torchaudio
+import gradio as gr
+
 import nemo.collections.asr as nemo_asr
 from omegaconf import OmegaConf
-from dataclasses import dataclass
 from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig
-from nemo.collections.asr.parts.utils.rnnt_utils import BatchedHyps, batched_hyps_to_hypotheses
+from nemo.collections.asr.parts.utils.rnnt_utils import batched_hyps_to_hypotheses
 from nemo.collections.asr.parts.utils.streaming_utils import ContextSize, StreamingBatchedAudioBuffer
 
-def _div(a, b): return (a // b) * b
 
+# ----------------------------
+# Config
+# ----------------------------
 @dataclass
-class Cfg:
-    name: str = "nvidia/parakeet-tdt-0.6b-v3"
+class AppConfig:
+    model_name: str = "nvidia/parakeet-tdt-0.6b-v3"
     left_s: float = 10.0
     chunk_s: float = 2.0
     right_s: float = 2.0
-    max_s: float = 40.0
-    batch: int = 1
-    device: str = "cpu"
-
-cfg = Cfg()
-
-m = nemo_asr.models.EncDecRNNTModel.from_pretrained(cfg.name).to(cfg.device).eval()
-for p in m.parameters(): p.requires_grad_(False)
-dec = RNNTDecodingConfig(strategy="greedy_batch", fused_batch_size=-1, compute_timestamps=False)
-dec.greedy.loop_labels = True
-m.change_decoding_strategy(dec)
-dc = m.decoding.decoding.decoding_computer
-
-mc = copy.deepcopy(m.cfg)
-OmegaConf.set_struct(mc.preprocessor, False)
-mc.preprocessor.dither = 0.0
-mc.preprocessor.pad_to = 0
-OmegaConf.set_struct(mc.preprocessor, True)
-
-sr = mc.preprocessor.sample_rate
-ws = mc.preprocessor.window_stride
-fps = 1.0 / ws
-sub = m.encoder.subsampling_factor
-feat_f2a = _div(int(sr * ws), sub)
-enc_f2a = feat_f2a * sub
-
-ctx_enc = ContextSize(
-    left=int(cfg.left_s * fps / sub),
-    chunk=int(cfg.chunk_s * fps / sub),
-    right=int(cfg.right_s * fps / sub),
-)
-ctx_samp = ContextSize(
-    left=ctx_enc.left * sub * feat_f2a,
-    chunk=ctx_enc.chunk * sub * feat_f2a,
-    right=ctx_enc.right * sub * feat_f2a,
-)
-
-max_samples = int(cfg.max_s * sr)
-
-def _mono(x):
-    x = np.asarray(x)
-    if x.ndim == 2:
-        if x.shape[1] == 2: x = x.mean(axis=1)
-        else: x = x.mean(axis=-1)
-    return x.astype(np.float32)
-
-def _resample(x, in_sr):
-    if in_sr == sr: return x
-    return torchaudio.functional.resample(torch.from_numpy(x), in_sr, sr).numpy().astype(np.float32)
-
-def _decode(a_np):
-    with torch.inference_mode():
-        a = torch.from_numpy(a_np).unsqueeze(0).to(torch.float32).to(cfg.device)
-        L = torch.tensor([a.shape[1]], dtype=torch.long, device=cfg.device)
-        cur = None
-        st = None
+    max_buffer_s: float = 40.0
+    batch_size: int = 1
+    device: str = "cpu"  # "cuda" to force GPU if available
+
+
+# ----------------------------
+# Utility
+# ----------------------------
+def _floor_multiple(a: int, b: int) -> int:
+    """Largest multiple of b not exceeding a."""
+    return (a // b) * b
+
+
+# ----------------------------
+# ASR Engine
+# ----------------------------
+class ParakeetStreamer:
+    """Encapsulates model, preprocessor settings, and decoding."""
+
+    def __init__(self, cfg: AppConfig) -> None:
+        self.cfg = cfg
+
+        # Load model
+        self.model = (
+            nemo_asr.models.EncDecRNNTModel.from_pretrained(cfg.model_name)
+            .to(cfg.device)
+            .eval()
+        )
+        for p in self.model.parameters():
+            p.requires_grad_(False)
+
+        # Decoding strategy: greedy-batch with label-looping for batched efficiency
+        dec_cfg = RNNTDecodingConfig(
+            strategy="greedy_batch", fused_batch_size=-1, compute_timestamps=False
+        )
+        dec_cfg.greedy.loop_labels = True
+        self.model.change_decoding_strategy(dec_cfg)
+        self._decoding_computer = self.model.decoding.decoding.decoding_computer
+
+        # Clone and tweak preprocessor to avoid dither and padding during inference
+        mcfg = copy.deepcopy(self.model.cfg)
+        OmegaConf.set_struct(mcfg.preprocessor, False)
+        mcfg.preprocessor.dither = 0.0
+        mcfg.preprocessor.pad_to = 0
+        OmegaConf.set_struct(mcfg.preprocessor, True)
+
+        # Derived constants
+        self.sample_rate: int = int(mcfg.preprocessor.sample_rate)
+        window_stride: float = float(mcfg.preprocessor.window_stride)
+        self.frames_per_second: float = 1.0 / window_stride
+        self.subsampling: int = int(self.model.encoder.subsampling_factor)
+
+        # Feature->audio and encoder->audio subsampling alignment
+        feat_f2a = _floor_multiple(int(self.sample_rate * window_stride), self.subsampling)
+        self.enc_f2a = feat_f2a * self.subsampling
+
+        # Context sizes
+        self.ctx_enc = ContextSize(
+            left=int(cfg.left_s * self.frames_per_second / self.subsampling),
+            chunk=int(cfg.chunk_s * self.frames_per_second / self.subsampling),
+            right=int(cfg.right_s * self.frames_per_second / self.subsampling),
+        )
+        self.ctx_samp = ContextSize(
+            left=self.ctx_enc.left * self.subsampling * feat_f2a,
+            chunk=self.ctx_enc.chunk * self.subsampling * feat_f2a,
+            right=self.ctx_enc.right * self.subsampling * feat_f2a,
+        )
+
+        self.max_samples = int(cfg.max_buffer_s * self.sample_rate)
+
+    # -------- audio helpers --------
+    @staticmethod
+    def _to_mono(x: np.ndarray) -> np.ndarray:
+        """Ensure mono float32 array."""
+        x = np.asarray(x)
+        if x.ndim == 2:
+            # Handle shape (samples, channels) or (channels, samples)
+            x = x.mean(axis=1) if x.shape[1] == 2 else x.mean(axis=-1)
+        return x.astype(np.float32, copy=False)
+
+    def _resample_if_needed(self, x: np.ndarray, in_sr: int) -> np.ndarray:
+        """Resample to model sample rate if required."""
+        if int(in_sr) == self.sample_rate:
+            return x
+        y = torchaudio.functional.resample(
+            torch.from_numpy(x), in_sr, self.sample_rate
+        )
+        return y.numpy().astype(np.float32, copy=False)
+
+    # -------- core decoding --------
+    @torch.inference_mode()
+    def _decode_buffer(self, audio_np: np.ndarray) -> str:
+        """Run buffered streaming decoding over the entire audio buffer."""
+        if audio_np.size == 0:
+            return ""
+
+        a = torch.from_numpy(audio_np).unsqueeze(0).to(torch.float32).to(self.cfg.device)
+        total_len = torch.tensor([a.shape[1]], dtype=torch.long, device=self.cfg.device)
+
+        cur_hyps = None
+        prev_state = None
+
         l = 0
-        r = min(ctx_samp.chunk + ctx_samp.right, a.shape[1])
-        buf = StreamingBatchedAudioBuffer(batch_size=cfg.batch, context_samples=ctx_samp, dtype=a.dtype, device=cfg.device)
-        rest = L.clone()
+        r = min(self.ctx_samp.chunk + self.ctx_samp.right, a.shape[1])
+
+        buf = StreamingBatchedAudioBuffer(
+            batch_size=self.cfg.batch_size,
+            context_samples=self.ctx_samp,
+            dtype=a.dtype,
+            device=self.cfg.device,
+        )
+
+        remaining = total_len.clone()
+
         while l < a.shape[1]:
-            clen = min(r, a.shape[1]) - l
-            last = r >= a.shape[1]
-            last_b = torch.tensor([clen >= rest[0]], dtype=torch.bool, device=cfg.device)
-            clen_b = torch.where(last_b, rest, torch.full_like(rest, fill_value=clen))
-            buf.add_audio_batch_(a[:, l:r], audio_lengths=clen_b, is_last_chunk=last, is_last_chunk_batch=last_b)
-            enc, _ = m(input_signal=buf.samples, input_signal_length=buf.context_size_batch.total())
-            enc = enc.transpose(1, 2)
-            enc_ctx = buf.context_size.subsample(factor=enc_f2a)
-            enc_ctx_b = buf.context_size_batch.subsample(factor=enc_f2a)
-            enc = enc[:, enc_ctx.left:]
-            hyps, _, st = dc(x=enc, out_len=enc_ctx_b.chunk, prev_batched_state=st)
-            if cur is None: cur = hyps
-            else: cur.merge_(hyps)
-            rest -= clen_b
+            clen = int(min(r, a.shape[1]) - l)
+            is_last = r >= a.shape[1]
+
+            is_last_b = torch.tensor([clen >= remaining[0]], dtype=torch.bool, device=self.cfg.device)
+            clen_b = torch.where(is_last_b, remaining, torch.full_like(remaining, fill_value=clen))
+
+            buf.add_audio_batch_(
+                a[:, l:r], audio_lengths=clen_b, is_last_chunk=is_last, is_last_chunk_batch=is_last_b
+            )
+
+            enc, _ = self.model(input_signal=buf.samples, input_signal_length=buf.context_size_batch.total())
+            enc = enc.transpose(1, 2)  # [B, T, C]
+
+            enc_ctx = buf.context_size.subsample(factor=self.enc_f2a)
+            enc_ctx_b = buf.context_size_batch.subsample(factor=self.enc_f2a)
+
+            enc = enc[:, enc_ctx.left:]  # drop left context before decoding
+
+            hyps, _, prev_state = self._decoding_computer(
+                x=enc, out_len=enc_ctx_b.chunk, prev_batched_state=prev_state
+            )
+
+            if cur_hyps is None:
+                cur_hyps = hyps
+            else:
+                cur_hyps.merge_(hyps)
+
+            remaining -= clen_b
             l = r
-            r = min(r + ctx_samp.chunk, a.shape[1])
-        outs = batched_hyps_to_hypotheses(cur, None, batch_size=cfg.batch) if cur is not None else []
-        for h in outs: h.text = m.tokenizer.ids_to_text(h.y_sequence.tolist())
+            r = min(r + self.ctx_samp.chunk, a.shape[1])
+
+        outs = batched_hyps_to_hypotheses(cur_hyps, None, batch_size=self.cfg.batch_size) if cur_hyps is not None else []
+        for h in outs:
+            h.text = self.model.tokenizer.ids_to_text(h.y_sequence.tolist())
+
         return outs[0].text if outs else ""
 
-def transcribe(stream, new_chunk):
-    if new_chunk is None: return stream, ""
-    in_sr, data = new_chunk
-    y = _mono(data)
-    y = _resample(y, int(in_sr))
-    a = y if stream is None or len(stream) == 0 else np.concatenate([stream, y])
-    if len(a) > max_samples: a = a[-max_samples:]
-    text = _decode(a) if a.size else ""
-    return a, text
-
-demo = gr.Interface(
-    fn=transcribe,
-    inputs=[gr.State(), gr.Audio(sources=["microphone"], type="numpy", streaming=True, label="Mic")],
-    outputs=[gr.State(), gr.Textbox(label="Transcript", lines=3)],
-    title="Parakeet-TDT-0.6B-v3 — CPU streaming",
-    description="Multilingual buffered streaming (10-2-2) in memory",
-    live=True,
-)
+    # -------- public API for Gradio --------
+    def transcribe(self, stream: Optional[np.ndarray], new_chunk: Optional[Tuple[int, np.ndarray]]):
+        """Gradio callback. Maintains rolling buffer in `stream` and returns transcript.
+
+        Args:
+            stream: rolling buffer carried in gr.State()
+            new_chunk: tuple (sample_rate, np.ndarray) provided by gr.Audio with type='numpy'
+        """
+        if new_chunk is None:
+            return stream, ""
+
+        in_sr, data = new_chunk
+        y = self._to_mono(data)
+        y = self._resample_if_needed(y, int(in_sr))
+
+        if stream is None or len(stream) == 0:
+            a = y
+        else:
+            a = np.concatenate([stream, y])
+
+        if a.size > self.max_samples:
+            a = a[-self.max_samples:]
+
+        text = self._decode_buffer(a) if a.size else ""
+        return a, text
+
+
+# ----------------------------
+# UI
+# ----------------------------
+def build_demo(cfg: Optional[AppConfig] = None) -> gr.Interface:
+    cfg = cfg or AppConfig()
+    engine = ParakeetStreamer(cfg)
+
+    return gr.Interface(
+        fn=engine.transcribe,
+        inputs=[gr.State(), gr.Audio(sources=["microphone"], type="numpy", streaming=True, label="Mic")],
+        outputs=[gr.State(), gr.Textbox(label="Transcript", lines=3)],
+        title="Parakeet-TDT-0.6B-v3 — CPU streaming",
+        description="Multilingual buffered streaming (10-2-2) in memory",
+        live=True,
+    )
+
 
 if __name__ == "__main__":
-    demo.launch()
+    demo = build_demo()
+    demo.launch()
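
The 10-2-2 context arithmetic in ParakeetStreamer.__init__ is easier to sanity-check with concrete numbers. A minimal sketch, assuming the usual Parakeet front end (16 kHz sample rate, 10 ms window stride, 8x encoder subsampling); the script reads these values from the model config at runtime, so they are illustrative assumptions, not guarantees:

sample_rate = 16_000   # assumed preprocessor.sample_rate
window_stride = 0.01   # assumed preprocessor.window_stride (10 ms)
subsampling = 8        # assumed encoder.subsampling_factor

fps = 1.0 / window_stride                                                   # 100 feature frames per second
feat_f2a = (int(sample_rate * window_stride) // subsampling) * subsampling  # 160 samples per feature frame
enc_f2a = feat_f2a * subsampling                                            # 1280 samples per encoder frame

# Encoder-frame context for left=10 s, chunk=2 s, right=2 s:
ctx_enc = {k: int(s * fps / subsampling) for k, s in
           {"left": 10.0, "chunk": 2.0, "right": 2.0}.items()}
assert ctx_enc == {"left": 125, "chunk": 25, "right": 25}

# Mapped back to raw audio samples: exactly 10 s / 2 s / 2 s at 16 kHz.
ctx_samp = {k: v * subsampling * feat_f2a for k, v in ctx_enc.items()}
assert ctx_samp == {"left": 160_000, "chunk": 32_000, "right": 32_000}

The floor-to-multiple step keeps every context boundary aligned to whole encoder frames, which is why the sample-level sizes land exactly on the configured durations.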
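
The l/r pointer walk in _decode_buffer is unchanged in behavior from the old _decode: the first window spans chunk+right samples so the first decoded chunk already sees its right context, and each later window advances l to the previous r while left context is carried inside StreamingBatchedAudioBuffer itself. A pure-Python sketch of just the pointer arithmetic, with hypothetical sizes:

def chunk_windows(total: int, chunk: int, right: int):
    """Yield the (l, r) sample windows the decode loop visits."""
    l, r = 0, min(chunk + right, total)
    while l < total:
        yield l, r
        l, r = r, min(r + chunk, total)

# With the 2 s chunk and 2 s right context at 16 kHz (32000 samples each):
print(list(chunk_windows(total=100_000, chunk=32_000, right=32_000)))
# [(0, 64000), (64000, 96000), (96000, 100000)]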
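
Because the refactor wraps everything in a class, the Gradio callback can also be exercised offline. A minimal sketch, assuming this module is saved as app.py and using soundfile (not a dependency of the app itself) to load a hypothetical sample.wav:

import soundfile as sf  # assumption: any loader returning (float32 samples, sr) works

from app import AppConfig, ParakeetStreamer

engine = ParakeetStreamer(AppConfig())
audio, sr = sf.read("sample.wav", dtype="float32")

stream, text = None, ""
step = sr  # feed 1 s chunks to mimic gr.Audio streaming callbacks
for off in range(0, len(audio), step):
    stream, text = engine.transcribe(stream, (sr, audio[off:off + step]))
print(text)

Note that the rolling buffer is capped at max_buffer_s (40 s), so each callback re-decodes at most the last 40 s of audio and earlier words eventually scroll out of the transcript.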