import os
import shutil
import tempfile
import time
from pathlib import Path

import librosa
import torch
from huggingface_hub import snapshot_download

from fam.llm.adapters import FlattenedInterleavedEncodec2Codebook
from fam.llm.decoders import EncodecDecoder
from fam.llm.fast_inference_utils import build_model, main
from fam.llm.inference import (
    EncodecDecoder,
    InferenceConfig,
    Model,
    TiltedEncodec,
    TrainedBPETokeniser,
    get_cached_embedding,
    get_cached_file,
    get_enhancer,
)
from fam.llm.utils import (
    check_audio_file,
    get_default_dtype,
    get_device,
    normalize_text,
)


class TTS:
    END_OF_AUDIO_TOKEN = 1024

    def __init__(
        self, model_name: str = "metavoiceio/metavoice-1B-v0.1", *, seed: int = 1337, output_dir: str = "outputs"
    ):
        """
        model_name (str): model identifier from the Hugging Face Model Hub (https://huggingface.co/metavoiceio)
        """
        # NOTE: this needs to come first so that we don't change global state when we want to use
        # the torch.compile'd model.
        self._dtype = get_default_dtype()
        self._device = get_device()
        self._model_dir = snapshot_download(repo_id=model_name, cache_dir="/proj/afosr/metavoice/cache")
        self.first_stage_adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=self.END_OF_AUDIO_TOKEN)
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)

        # second-stage LLM + decoder: turns first-stage EnCodec tokens into a waveform
        second_stage_ckpt_path = f"{self._model_dir}/second_stage.pt"
        config_second_stage = InferenceConfig(
            ckpt_path=second_stage_ckpt_path,
            num_samples=1,
            seed=seed,
            device=self._device,
            dtype=self._dtype,
            compile=False,
            init_from="resume",
            output_dir=self.output_dir,
        )
        data_adapter_second_stage = TiltedEncodec(end_of_audio_token=self.END_OF_AUDIO_TOKEN)
        self.llm_second_stage = Model(
            config_second_stage, TrainedBPETokeniser, EncodecDecoder, data_adapter_fn=data_adapter_second_stage.decode
        )

        # DeepFilterNet ("df") enhancer applied to the generated audio
        self.enhancer = get_enhancer("df")

        # first-stage LLM (torch.compile'd), tokenizer and speaker embedding model
        self.precision = {"float16": torch.float16, "bfloat16": torch.bfloat16}[self._dtype]
        self.model, self.tokenizer, self.smodel, self.model_size = build_model(
            precision=self.precision,
            checkpoint_path=Path(f"{self._model_dir}/first_stage.pt"),
            spk_emb_ckpt_path=Path(f"{self._model_dir}/speaker_encoder.pt"),
            device=self._device,
            compile=True,
            compile_prefill=True,
        )

    def synthesise(self, text: str, spk_ref_path: str, top_p=0.95, guidance_scale=3.0, temperature=1.0) -> str:
        """
        text: text to speak.
        spk_ref_path: path to the speaker reference file; a minimum of 30s of audio is required. Supports both local paths & public URIs. Audio formats: wav, flac & mp3.
        top_p: top-p sampling applied to the first-stage model. Values in the range [0.9, 1.0] work well. Controls speech stability - higher values improve text following for a challenging speaker.
        guidance_scale: guidance scale for sampling, in the range [1.0, 3.0]. Controls speaker similarity - how closely the output matches the speaker identity and speech style.
        temperature: sampling temperature applied to both LLMs (first & second stage).
        returns: path to the synthesised speech .wav file.
        """
        text = normalize_text(text)
        spk_ref_path = get_cached_file(spk_ref_path)
        check_audio_file(spk_ref_path)
        spk_emb = get_cached_embedding(
            spk_ref_path,
            self.smodel,
        ).to(device=self._device, dtype=self.precision)

        start = time.time()
        # first-stage LLM: text + speaker embedding -> coarse EnCodec tokens
        tokens = main(
            model=self.model,
            tokenizer=self.tokenizer,
            model_size=self.model_size,
            prompt=text,
            spk_emb=spk_emb,
            top_p=torch.tensor(top_p, device=self._device, dtype=self.precision),
            guidance_scale=torch.tensor(guidance_scale, device=self._device, dtype=self.precision),
            temperature=torch.tensor(temperature, device=self._device, dtype=self.precision),
        )
        _, extracted_audio_ids = self.first_stage_adapter.decode([tokens])
        b_speaker_embs = spk_emb.unsqueeze(0)

        # second-stage LLM + multi-band diffusion model: coarse tokens -> waveform
        wav_files = self.llm_second_stage(
            texts=[text],
            encodec_tokens=[torch.tensor(extracted_audio_ids, dtype=torch.int32, device=self._device).unsqueeze(0)],
            speaker_embs=b_speaker_embs,
            batch_size=1,
            guidance_scale=None,
            top_p=None,
            top_k=200,
            temperature=1.0,
            max_new_tokens=None,
        )

        # enhance using DeepFilterNet, overwriting the generated file in place
        wav_file = wav_files[0]
        with tempfile.NamedTemporaryFile(suffix=".wav") as enhanced_tmp:
            self.enhancer(str(wav_file) + ".wav", enhanced_tmp.name)
            shutil.copy2(enhanced_tmp.name, str(wav_file) + ".wav")
            print(f"\nSaved audio to {wav_file}.wav")

        # calculate the real-time factor (RTF)
        time_to_synth_s = time.time() - start
        audio, sr = librosa.load(str(wav_file) + ".wav")
        duration_s = librosa.get_duration(y=audio, sr=sr)
        print(f"\nTotal time to synth (s): {time_to_synth_s}")
        print(f"Real-time factor: {time_to_synth_s / duration_s:.2f}")

        return str(wav_file) + ".wav"


if __name__ == "__main__":
    tts = TTS()
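    # Usage sketch: the text and reference-clip path below are illustrative assumptions,
    # not part of the original script - any wav/flac/mp3 file of at least 30s (or a
    # public URI) should work as spk_ref_path.
    wav_path = tts.synthesise(
        text="This is a short demo of the MetaVoice text-to-speech pipeline.",
        spk_ref_path="assets/bria.mp3",  # assumed sample clip; replace with your own reference audio
        top_p=0.95,
        guidance_scale=3.0,
    )
    print(f"Synthesised speech saved to {wav_path}")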