from typing import List, Tuple

import numpy as np
import librosa
import torch
import torch.nn.functional as F

from s3tokenizer.utils import padding
from s3tokenizer.model_v2 import (
    S3TokenizerV2,
    ModelConfig,
)


S3_SR = 16_000
S3_HOP = 160
S3_TOKEN_HOP = 640
S3_TOKEN_RATE = 25
SPEECH_VOCAB_SIZE = 6561
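
# How these constants relate (derived from the values above, not stated in the
# original source): one speech token spans S3_TOKEN_HOP = S3_SR // S3_TOKEN_RATE
# = 16_000 // 25 = 640 samples (40 ms), and the mel hop of S3_HOP = 160 samples
# gives 100 mel frames/sec, i.e. 4 mel frames per token. This is why `forward`
# truncates mels to `max_len * 4` frames.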


class S3Tokenizer(S3TokenizerV2):
    """
    s3tokenizer.S3TokenizerV2 with the following changes:
    - a more integrated `forward`
    - computes `log_mel_spectrogram` using `_mel_filters` and `window` registered as buffers
    """

    ignore_state_dict_missing = ("_mel_filters", "window")

    def __init__(
        self,
        name: str = "speech_tokenizer_v2_25hz",
        config: ModelConfig = ModelConfig(),
    ):
        super().__init__(name)

        self.n_fft = 400
        _mel_filters = librosa.filters.mel(
            sr=S3_SR,
            n_fft=self.n_fft,
            n_mels=config.n_mels,
        )
        self.register_buffer(
            "_mel_filters",
            torch.FloatTensor(_mel_filters),
        )

        self.register_buffer(
            "window",
            torch.hann_window(self.n_fft),
        )

    def pad(self, wavs, sr) -> List[torch.Tensor]:
        """
        Given a list of wavs with the same `sample_rate`, pad them so that the
        length is a multiple of 40 ms (S3 runs at 25 tokens/sec).
        """
        processed_wavs = []
        for wav in wavs:
            if isinstance(wav, np.ndarray):
                wav = torch.from_numpy(wav)
            if wav.dim() == 1:
                wav = wav.unsqueeze(0)

            n_tokens = (wav.shape[1] / sr) * S3_TOKEN_RATE
            n_tokens = np.ceil(n_tokens)
            intended_wav_len = n_tokens * (sr / S3_TOKEN_RATE)
            intended_wav_len = int(intended_wav_len)
            wav = torch.nn.functional.pad(
                wav,
                (0, intended_wav_len - wav.shape[-1]),
                mode="constant",
                value=0,
            )
            processed_wavs.append(wav)
        return processed_wavs
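
    # Worked example of the padding arithmetic above (illustrative numbers, not
    # from the original source): a 1.23 s clip at 16 kHz has 19_680 samples;
    # ceil(1.23 * 25) = 31 tokens, so the wav is right-padded with zeros to
    # 31 * (16_000 / 25) = 19_840 samples, i.e. a whole number of 40 ms tokens.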

    def _prepare_audio(self, wavs):
        """Prepare a list of audios for s3tokenizer processing."""
        processed_wavs = []
        for wav in wavs:
            if isinstance(wav, np.ndarray):
                wav = torch.from_numpy(wav)
            if wav.dim() == 1:
                wav = wav.unsqueeze(0)
            processed_wavs.append(wav)
        return processed_wavs

    @torch.no_grad()
    def forward(
        self,
        wavs: torch.Tensor,
        accelerator: "Accelerator" = None,
        max_len: int = None,
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """
        NOTE: mel-spec has a hop size of 160 samples (100 frames/sec).
        FIXME: this class inherits `nn.Module` but doesn't accept a `torch.Tensor`;
        it handles a list of wavs one by one, which is unexpected.

        Args
        ----
        - `wavs`: 16 kHz speech audio
        - `max_len`: max length to truncate the output sequence to (25 tokens/sec).
          NOTE: please pad the waveform if a longer sequence is needed.
        """
        processed_wavs = self._prepare_audio(wavs)
        mels, mel_lens = [], []
        for wav in processed_wavs:
            wav = wav.to(self.device)
            mel = self.log_mel_spectrogram(wav)
            if max_len is not None:
                mel = mel[..., :max_len * 4]  # num_mel_frames = 4 * num_tokens
            mels.append(mel.squeeze(0))

        mels, mel_lens = padding(mels)
        if accelerator is None:
            tokenizer = self
        else:
            tokenizer = accelerator.unwrap_model(self)

        speech_tokens, speech_token_lens = tokenizer.quantize(mels, mel_lens.to(self.device))
        return (
            speech_tokens.long().detach(),
            speech_token_lens.long().detach(),
        )
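
    # Note on the outputs (a reading of the code above, not a guarantee from the
    # original source): `speech_tokens` holds discrete speech-token ids per
    # batch item, `speech_token_lens` gives the valid length of each row, and at
    # S3_TOKEN_RATE = 25 tokens/sec, `speech_token_lens / 25` is the tokenized
    # duration in seconds.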

    def log_mel_spectrogram(
        self,
        audio: torch.Tensor,
        padding: int = 0,
    ):
        """
        Compute the log-Mel spectrogram of an audio waveform.

        Parameters
        ----------
        audio: torch.Tensor, shape = (*)
            A NumPy array or Tensor containing the audio waveform at 16 kHz

        padding: int
            Number of zero samples to pad to the right

        Returns
        -------
        torch.Tensor, shape = (n_mels, n_frames)
            A Tensor that contains the log-Mel spectrogram
        """
        if not torch.is_tensor(audio):
            audio = torch.from_numpy(audio)

        audio = audio.to(self.device)
        if padding > 0:
            audio = F.pad(audio, (0, padding))
        stft = torch.stft(
            audio, self.n_fft, S3_HOP,
            window=self.window.to(self.device),
            return_complex=True,
        )
        # Drop the last STFT frame and take the power spectrum.
        magnitudes = stft[..., :-1].abs() ** 2

        mel_spec = self._mel_filters.to(self.device) @ magnitudes

        # Log-compress, clamp to within 8 orders of magnitude of the peak, and rescale.
        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
        log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
        log_spec = (log_spec + 4.0) / 4.0
        return log_spec
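
    # Frame-count sanity check (derived; torch.stft defaults to center=True):
    # a waveform of T samples yields 1 + T // S3_HOP STFT frames, and dropping
    # the last one above leaves T // 160 mel frames, i.e. 4 mel frames per
    # 640-sample (40 ms) speech token, consistent with `forward`'s
    # `max_len * 4` truncation.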
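

# A minimal usage sketch (not part of the original module): it assumes a 16 kHz
# mono file at the hypothetical path "example.wav" and that the pretrained
# "speech_tokenizer_v2_25hz" weights have been loaded the way the parent
# S3TokenizerV2 expects; treat it as an illustration of the call sequence only.
if __name__ == "__main__":
    tokenizer = S3Tokenizer()  # default name/config defined above
    wav, _ = librosa.load("example.wav", sr=S3_SR, mono=True)  # hypothetical path
    wavs = tokenizer.pad([wav], sr=S3_SR)  # pad to a multiple of 40 ms
    speech_tokens, speech_token_lens = tokenizer(wavs)
    print(speech_tokens.shape, speech_token_lens)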