Sync from GitHub repo

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there.
- model/utils_infer.py +25 -7
model/utils_infer.py CHANGED
```diff
@@ -19,8 +19,14 @@ from model.utils import (
     convert_char_to_pinyin,
 )
 
-device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
-print(f"Using {device} device")
+# get device
+
+
+def get_device():
+    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+    # print(f"Using {device} device")
+    return device
+
 
 vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
 
@@ -76,7 +82,9 @@ def chunk_text(text, max_chars=135):
 
 
 # load vocoder
-def load_vocoder(is_local=False, local_path="", device=device):
+def load_vocoder(is_local=False, local_path="", device=None):
+    if device is None:
+        device = get_device()
     if is_local:
         print(f"Load vocos from local path {local_path}")
         vocos = Vocos.from_hparams(f"{local_path}/config.yaml")
@@ -94,8 +102,10 @@ def load_vocoder(is_local=False, local_path="", device=device):
 asr_pipe = None
 
 
-def initialize_asr_pipeline(device=device):
+def initialize_asr_pipeline(device=None):
     global asr_pipe
+    if device is None:
+        device = get_device()
 
     asr_pipe = pipeline(
         "automatic-speech-recognition",
@@ -108,7 +118,9 @@ def initialize_asr_pipeline(device=device):
 # load model for inference
 
 
-def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_method, use_ema=True, device=device):
+def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_method, use_ema=True, device=None):
+    if device is None:
+        device = get_device()
     if vocab_file == "":
         vocab_file = "Emilia_ZH_EN"
         tokenizer = "pinyin"
@@ -141,7 +153,9 @@ def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_method, use_ema=True, device=device):
 # preprocess reference audio and text
 
 
-def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print, device=device):
+def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print, device=None):
+    device = device or get_device()
+
     show_info("Converting audio...")
     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
         aseg = AudioSegment.from_file(ref_audio_orig)
@@ -243,7 +257,11 @@ def infer_batch_process(
     sway_sampling_coef=-1,
     speed=1,
     fix_duration=None,
+    device=None,
 ):
+    if device is None:
+        device = get_device()
+
     audio, sr = ref_audio
     if audio.shape[0] > 1:
         audio = torch.mean(audio, dim=0, keepdim=True)
@@ -254,7 +272,7 @@ def infer_batch_process(
     if sr != target_sample_rate:
         resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
         audio = resampler(audio)
-    audio = audio.to(device)
+    audio = audio.to(device)
 
     generated_waves = []
     spectrograms = []
```
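The substance of the change is the switch from `device=device` defaults, which are evaluated once at import time, to a `None` sentinel resolved inside each function at call time. The sketch below is not part of the commit; it is a minimal, self-contained illustration of the difference, using the hypothetical names `load_frozen` and `load_lazy`:

```python
import torch


def get_device():
    # Same fallback order as the commit: CUDA, then Apple MPS, then CPU.
    return "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"


_IMPORT_TIME_DEVICE = get_device()  # frozen once, when the module loads


def load_frozen(device=_IMPORT_TIME_DEVICE):
    # Old pattern: the default never re-checks the environment.
    return device


def load_lazy(device=None):
    # New pattern: resolve per call, but still honor an explicit override.
    if device is None:
        device = get_device()
    return device
```

Besides resolving later, the sentinel also removes a hidden ordering dependency: each `def` no longer requires a module-level `device` global to exist at the moment the function is defined.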
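Based on the signatures above, callers can rely on lazy resolution or pin a device explicitly. A hedged usage sketch (import path taken from this repo's layout; not an excerpt from the commit):

```python
from model.utils_infer import load_vocoder, initialize_asr_pipeline

# Lazy resolution: cuda / mps / cpu is picked when the call happens.
vocos = load_vocoder()
initialize_asr_pipeline()

# Explicit override, e.g. to keep the ASR pipeline on the CPU.
initialize_asr_pipeline(device="cpu")
```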