from typing import List, Optional, Tuple

import numpy as np
import librosa
import torch
import torch.nn.functional as F

from s3tokenizer.utils import padding
from s3tokenizer.model_v2 import (
    S3TokenizerV2,
    ModelConfig,
)

# Sampling rate of the inputs to S3TokenizerV2
S3_SR = 16_000
S3_HOP = 160  # 100 frames/sec
S3_TOKEN_HOP = 640  # 25 tokens/sec
S3_TOKEN_RATE = 25
SPEECH_VOCAB_SIZE = 6561
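
# Rate arithmetic implied by the constants above: at S3_SR = 16 kHz, a mel hop
# of 160 samples is 10 ms per frame (100 frames/sec) and a token hop of 640
# samples is 40 ms per token (25 tokens/sec), so each speech token spans
# S3_TOKEN_HOP // S3_HOP == 4 mel frames -- the factor behind the
# `max_len * 4` slicing in `forward` below.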


class S3Tokenizer(S3TokenizerV2):
    """
    s3tokenizer.S3TokenizerV2 with the following changes:
    - a more integrated `forward`
    - compute `log_mel_spectrogram` using `_mel_filters` and `window`,
      which are registered as buffers
    """
    ignore_state_dict_missing = ("_mel_filters", "window")

    def __init__(
        self,
        name: str = "speech_tokenizer_v2_25hz",
        config: ModelConfig = ModelConfig(),
    ):
        super().__init__(name)

        self.n_fft = 400
        _mel_filters = librosa.filters.mel(
            sr=S3_SR,
            n_fft=self.n_fft,
            n_mels=config.n_mels
        )
        self.register_buffer(
            "_mel_filters",
            torch.FloatTensor(_mel_filters),
        )
        self.register_buffer(
            "window",
            torch.hann_window(self.n_fft),
        )
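
    # Registering the mel filterbank and STFT window as buffers keeps them on
    # the right device after `.to(device)` without making them trainable
    # parameters; `ignore_state_dict_missing` above presumably tells the
    # checkpoint loader to tolerate their absence in older saved weights.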

    def pad(self, wavs, sr) -> List[torch.Tensor]:
        """
        Given a list of wavs with the same `sample_rate`, pad them so that the
        length is a multiple of 40 ms (S3 runs at 25 tokens/sec).
        """
        processed_wavs = []
        for wav in wavs:
            if isinstance(wav, np.ndarray):
                wav = torch.from_numpy(wav)
            if wav.dim() == 1:
                wav = wav.unsqueeze(0)

            n_tokens = (wav.shape[1] / sr) * S3_TOKEN_RATE
            n_tokens = np.ceil(n_tokens)
            intended_wav_len = n_tokens * (sr / S3_TOKEN_RATE)
            intended_wav_len = int(intended_wav_len)
            wav = torch.nn.functional.pad(
                wav,
                (0, intended_wav_len - wav.shape[-1]),
                mode="constant",
                value=0,
            )
            processed_wavs.append(wav)
        return processed_wavs
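
    # Worked example (illustrative): a 1.23 s wav at 16 kHz yields
    # ceil(1.23 * 25) = 31 tokens, so it is right-padded from 19680 to
    # 31 * (16000 / 25) = 19840 samples (i.e. 31 * 40 ms = 1.24 s).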

    def _prepare_audio(self, wavs):
        """Prepare a list of audios for s3tokenizer processing."""
        processed_wavs = []
        for wav in wavs:
            if isinstance(wav, np.ndarray):
                wav = torch.from_numpy(wav)
            if wav.dim() == 1:
                wav = wav.unsqueeze(0)
            processed_wavs.append(wav)
        return processed_wavs

    def forward(
        self,
        wavs: torch.Tensor,
        accelerator: Optional["Accelerator"] = None,
        max_len: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """
        NOTE: mel-spec has a hop size of 160 points (100 frames/sec).
        FIXME: this class inherits `nn.Module` but doesn't accept
        `torch.Tensor` and handles a list of wavs one by one, which is
        unexpected.

        Args
        ----
        - `wavs`: 16 kHz speech audio
        - `max_len`: max length to truncate the output sequence to
          (25 tokens/sec). NOTE: please pad the waveform if a longer
          sequence is needed.
        """
        processed_wavs = self._prepare_audio(wavs)
        mels, mel_lens = [], []
        for wav in processed_wavs:
            wav = wav.to(self.device)
            mel = self.log_mel_spectrogram(wav)  # [B=1, F, T]
            if max_len is not None:
                mel = mel[..., :max_len * 4]  # num_mel_frames = 4 * num_tokens
            mels.append(mel.squeeze(0))

        mels, mel_lens = padding(mels)
        if accelerator is None:
            tokenizer = self
        else:
            tokenizer = accelerator.unwrap_model(self)

        speech_tokens, speech_token_lens = tokenizer.quantize(mels, mel_lens.to(self.device))
        return (
            speech_tokens.long().detach(),
            speech_token_lens.long().detach(),
        )
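
    # Note on `accelerator`: HuggingFace Accelerate's `unwrap_model` strips
    # the wrapper that `accelerator.prepare(...)` adds around the model, so
    # `quantize`, defined on the underlying model, stays reachable during
    # distributed training.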

    def log_mel_spectrogram(
        self,
        audio: torch.Tensor,
        padding: int = 0,
    ):
        """
        Compute the log-Mel spectrogram of the given audio.

        Parameters
        ----------
        audio: torch.Tensor, shape = (*)
            A NumPy array or Tensor containing the audio waveform at 16 kHz
        padding: int
            Number of zero samples to pad to the right

        Returns
        -------
        torch.Tensor, shape = (n_mels, n_frames)
            A Tensor that contains the log-Mel spectrogram
        """
        if not torch.is_tensor(audio):
            audio = torch.from_numpy(audio)

        audio = audio.to(self.device)
        if padding > 0:
            audio = F.pad(audio, (0, padding))
        stft = torch.stft(
            audio, self.n_fft, S3_HOP,
            window=self.window.to(self.device),
            return_complex=True
        )
        magnitudes = stft[..., :-1].abs() ** 2

        mel_spec = self._mel_filters.to(self.device) @ magnitudes

        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
        log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
        log_spec = (log_spec + 4.0) / 4.0
        return log_spec
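

# --- Usage sketch ---
# A minimal, illustrative call pattern, assuming 16 kHz inputs. Checkpoint
# loading is omitted here, so with this snippet alone the tokens come from
# randomly initialized weights; `wav_a`/`wav_b` are placeholder signals.
if __name__ == "__main__":
    tokenizer = S3Tokenizer()  # "speech_tokenizer_v2_25hz" by default
    tokenizer.eval()

    wav_a = np.random.randn(2 * S3_SR).astype(np.float32)  # 2 s
    wav_b = np.random.randn(S3_SR).astype(np.float32)      # 1 s
    wavs = tokenizer.pad([wav_a, wav_b], sr=S3_SR)

    with torch.no_grad():
        tokens, token_lens = tokenizer(wavs)
    # Roughly 25 tokens per second: expect lengths ~50 and ~25 here.
    print(tokens.shape, token_lens)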