Spaces:
Running
on
T4
Running
on
T4
fix cpu and add hub download
Browse files - tortoise/api.py +6 -38
tortoise/api.py
CHANGED
|
@@ -38,45 +38,13 @@ MODELS = {
|
|
| 38 |
'hifidecoder.pth': 'https://huggingface.co/Manmay/tortoise-tts/resolve/main/hifidecoder.pth',
|
| 39 |
}
|
| 40 |
|
| 41 |
-
def download_models(specific_models=None):
    """
    Call to download all the models that Tortoise uses.

    Every entry of MODELS that is not yet present in MODELS_DIR is fetched
    from its URL. Each file is first written to a temporary ``.part`` path
    and only moved to its final name once the transfer has completed, so an
    interrupted download is never mistaken for a finished checkpoint on the
    next call.

    :param specific_models: Optional collection of model file names. When
                            given, only those entries of MODELS are
                            downloaded; when None, all of them are.
    """
    global pbar
    os.makedirs(MODELS_DIR, exist_ok=True)

    def show_progress(block_num, block_size, total_size):
        # Report hook for urlretrieve; drives the shared module-level `pbar`.
        global pbar
        if total_size <= 0:
            # urlretrieve reports -1 when the server sends no size; there is
            # nothing to scale a progress bar against in that case.
            return
        if pbar is None:
            pbar = progressbar.ProgressBar(maxval=total_size)
            pbar.start()

        downloaded = block_num * block_size
        if downloaded < total_size:
            pbar.update(downloaded)
        else:
            pbar.finish()
            pbar = None

    for model_name, url in MODELS.items():
        if specific_models is not None and model_name not in specific_models:
            continue
        model_path = os.path.join(MODELS_DIR, model_name)
        if os.path.exists(model_path):
            continue
        print(f'Downloading {model_name} from {url}...')
        partial_path = model_path + '.part'
        try:
            request.urlretrieve(url, partial_path, show_progress)
            # Same-directory rename: the final name only ever refers to a
            # fully downloaded file.
            os.replace(partial_path, model_path)
        finally:
            # A failed transfer must not leave a half-finished bar behind for
            # the next model, nor a truncated file on disk.
            pbar = None
            if os.path.exists(partial_path):
                os.remove(partial_path)
        print('Done.')
|
| 69 |
-
|
| 70 |
-
|
| 71 |
def get_model_path(model_name, models_dir=MODELS_DIR):
    """
    Get path to given model, download it if it doesn't exist.

    :param model_name: File name of the checkpoint; must be a key of MODELS.
    :param models_dir: Directory expected to contain the checkpoint.
    :return: ``models_dir`` joined with ``model_name``.
    :raises ValueError: If ``model_name`` is not a key of MODELS.
    """
    if model_name not in MODELS:
        raise ValueError(f'Model {model_name} not found in available models.')
    model_path = os.path.join(models_dir, model_name)
    # Only the default directory is populated automatically; for a custom
    # models_dir the path is returned as-is, even if the file is missing.
    if not os.path.exists(model_path) and models_dir == MODELS_DIR:
        download_models([model_name])
    return model_path
|
| 81 |
|
| 82 |
|
|
@@ -243,14 +211,14 @@ class TextToSpeech:
|
|
| 243 |
self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30,
|
| 244 |
model_dim=1024,
|
| 245 |
heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
|
| 246 |
-
train_solo_embeddings=False).
|
| 247 |
self.autoregressive.load_state_dict(torch.load(get_model_path('autoregressive.pth', models_dir)), strict=False)
|
| 248 |
self.autoregressive.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=kv_cache, half=self.half)
|
| 249 |
|
| 250 |
self.hifi_decoder = HifiganGenerator(in_channels=1024, out_channels = 1, resblock_type = "1",
|
| 251 |
resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], resblock_kernel_sizes = [3, 7, 11],
|
| 252 |
upsample_kernel_sizes = [16, 16, 4, 4], upsample_initial_channel = 512, upsample_factors = [8, 8, 2, 2],
|
| 253 |
-
cond_channels=1024).
|
| 254 |
hifi_model = torch.load(get_model_path('hifidecoder.pth'))
|
| 255 |
self.hifi_decoder.load_state_dict(hifi_model, strict=False)
|
| 256 |
# Random latent generators (RLGs) are loaded lazily.
|
|
@@ -309,7 +277,7 @@ class TextToSpeech:
|
|
| 309 |
settings.update(kwargs) # allow overriding of preset settings with kwargs
|
| 310 |
for audio_frame in self.tts(text, **settings):
|
| 311 |
yield audio_frame
|
| 312 |
-
|
| 313 |
def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
|
| 314 |
"""Handle chunk formatting in streaming mode"""
|
| 315 |
wav_chunk = wav_gen[:-overlap_len]
|
|
@@ -413,7 +381,7 @@ class TextToSpeech:
|
|
| 413 |
wav_gen_prev = None
|
| 414 |
wav_overlap = None
|
| 415 |
is_end = False
|
| 416 |
-
first_buffer =
|
| 417 |
while not is_end:
|
| 418 |
try:
|
| 419 |
with torch.autocast(
|
|
@@ -428,7 +396,7 @@ class TextToSpeech:
|
|
| 428 |
if is_end or (stream_chunk_size > 0 and len(codes_) >= max(stream_chunk_size, first_buffer)):
|
| 429 |
first_buffer = 0
|
| 430 |
gpt_latents = torch.cat(all_latents, dim=0)[None, :]
|
| 431 |
-
wav_gen = self.hifi_decoder.inference(gpt_latents.
|
| 432 |
wav_gen = wav_gen.squeeze()
|
| 433 |
wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
|
| 434 |
wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
|
|
|
|
| 38 |
'hifidecoder.pth': 'https://huggingface.co/Manmay/tortoise-tts/resolve/main/hifidecoder.pth',
|
| 39 |
}
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
def get_model_path(model_name, models_dir=MODELS_DIR):
    """
    Get path to given model, download it if it doesn't exist.

    A checkpoint file placed directly inside ``models_dir`` is used as-is.
    Otherwise the file is fetched from the "Manmay/tortoise-tts" repository
    on the Hugging Face Hub, using ``models_dir`` as the download cache.

    :param model_name: File name of the checkpoint; must be a key of MODELS.
    :param models_dir: Directory searched for the checkpoint and used as the
                       hub cache directory. Previously this argument was
                       ignored and MODELS_DIR was always used.
    :return: Path to the checkpoint file.
    :raises ValueError: If ``model_name`` is not a key of MODELS.
    """
    # Function-scope import: this block did not previously use `os`.
    import os

    if model_name not in MODELS:
        raise ValueError(f'Model {model_name} not found in available models.')
    # Honor checkpoints the caller already placed in models_dir (flat layout)
    # instead of downloading them again.
    local_path = os.path.join(models_dir, model_name)
    if os.path.exists(local_path):
        return local_path
    model_path = hf_hub_download(repo_id="Manmay/tortoise-tts", filename=model_name, cache_dir=models_dir)
    return model_path
|
| 49 |
|
| 50 |
|
|
|
|
| 211 |
self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30,
|
| 212 |
model_dim=1024,
|
| 213 |
heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
|
| 214 |
+
train_solo_embeddings=False).to(self.device).eval()
|
| 215 |
self.autoregressive.load_state_dict(torch.load(get_model_path('autoregressive.pth', models_dir)), strict=False)
|
| 216 |
self.autoregressive.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=kv_cache, half=self.half)
|
| 217 |
|
| 218 |
self.hifi_decoder = HifiganGenerator(in_channels=1024, out_channels = 1, resblock_type = "1",
|
| 219 |
resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], resblock_kernel_sizes = [3, 7, 11],
|
| 220 |
upsample_kernel_sizes = [16, 16, 4, 4], upsample_initial_channel = 512, upsample_factors = [8, 8, 2, 2],
|
| 221 |
+
cond_channels=1024).to(self.device).eval()
|
| 222 |
hifi_model = torch.load(get_model_path('hifidecoder.pth'))
|
| 223 |
self.hifi_decoder.load_state_dict(hifi_model, strict=False)
|
| 224 |
# Random latent generators (RLGs) are loaded lazily.
|
|
|
|
| 277 |
settings.update(kwargs) # allow overriding of preset settings with kwargs
|
| 278 |
for audio_frame in self.tts(text, **settings):
|
| 279 |
yield audio_frame
|
| 280 |
+
|
| 281 |
def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
|
| 282 |
"""Handle chunk formatting in streaming mode"""
|
| 283 |
wav_chunk = wav_gen[:-overlap_len]
|
|
|
|
| 381 |
wav_gen_prev = None
|
| 382 |
wav_overlap = None
|
| 383 |
is_end = False
|
| 384 |
+
first_buffer = 40
|
| 385 |
while not is_end:
|
| 386 |
try:
|
| 387 |
with torch.autocast(
|
|
|
|
| 396 |
if is_end or (stream_chunk_size > 0 and len(codes_) >= max(stream_chunk_size, first_buffer)):
|
| 397 |
first_buffer = 0
|
| 398 |
gpt_latents = torch.cat(all_latents, dim=0)[None, :]
|
| 399 |
+
wav_gen = self.hifi_decoder.inference(gpt_latents.to(self.device), auto_conditioning)
|
| 400 |
wav_gen = wav_gen.squeeze()
|
| 401 |
wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
|
| 402 |
wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
|