Spaces: Build error
jason-on-salt-a40 committed on
Commit 78774ba · 1 Parent(s): c1908d8
hf model download

Browse files
- app.py +11 -16
- models/voicecraft.py +8 -2
- requirements.txt +2 -1
app.py
CHANGED
@@ -93,27 +93,22 @@ def load_models(whisper_backend_name, whisper_model_name, alignment_model_name,
             transcribe_model = WhisperxModel(whisper_model_name, align_model)
 
     voicecraft_name = f"{voicecraft_model_name}.pth"
-    ckpt_fn = f"{MODELS_PATH}/{voicecraft_name}"
+    model = voicecraft.VoiceCraftHF.from_pretrained(f"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}")
+    phn2num = model.args.phn2num
+    config = model.args
+    model.to(device)
+
     encodec_fn = f"{MODELS_PATH}/encodec_4cb2048_giga.th"
-    if not os.path.exists(ckpt_fn):
-        os.system(f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/{voicecraft_name}\?download\=true")
-        os.system(f"mv {voicecraft_name}\?download\=true {MODELS_PATH}/{voicecraft_name}")
     if not os.path.exists(encodec_fn):
         os.system(f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/encodec_4cb2048_giga.th")
-        os.system(f"mv encodec_4cb2048_giga.th {MODELS_PATH}/encodec_4cb2048_giga.th")
 
-    ckpt = torch.load(ckpt_fn, map_location="cpu")
-    model = voicecraft.VoiceCraft(ckpt["config"])
-    model.load_state_dict(ckpt["model"])
-    model.to(device)
-    model.eval()
     voicecraft_model = {
-        "ckpt": ckpt,
+        "config": config,
+        "phn2num": phn2num,
         "model": model,
         "text_tokenizer": TextTokenizer(backend="espeak"),
         "audio_tokenizer": AudioTokenizer(signature=encodec_fn)
     }
-
     return gr.Accordion()
 
 
@@ -255,8 +250,8 @@ def run(seed, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p,
 
         prompt_end_frame = int(min(audio_dur, prompt_end_time) * info.sample_rate)
         _, gen_audio = inference_one_sample(voicecraft_model["model"],
-                                            voicecraft_model["ckpt"]["config"],
-                                            voicecraft_model["ckpt"]["phn2num"],
+                                            voicecraft_model["config"],
+                                            voicecraft_model["phn2num"],
                                             voicecraft_model["text_tokenizer"], voicecraft_model["audio_tokenizer"],
                                             audio_path, target_transcript, device, decode_config,
                                             prompt_end_frame)
@@ -284,8 +279,8 @@ def run(seed, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p,
         mask_interval = torch.LongTensor(mask_interval)
 
         _, gen_audio = inference_one_sample(voicecraft_model["model"],
-                                            voicecraft_model["ckpt"]["config"],
-                                            voicecraft_model["ckpt"]["phn2num"],
+                                            voicecraft_model["config"],
+                                            voicecraft_model["phn2num"],
                                             voicecraft_model["text_tokenizer"], voicecraft_model["audio_tokenizer"],
                                             audio_path, target_transcript, mask_interval, device, decode_config)
         gen_audio = gen_audio[0].cpu()
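
Note on the change above: the wget-and-torch.load checkpoint flow is replaced by a single from_pretrained call on the new VoiceCraftHF wrapper (defined in models/voicecraft.py below), so the model config and the phoneme vocabulary are read from the loaded model rather than from the raw checkpoint dict. A minimal sketch of the new loading path, assuming the pyp1/VoiceCraft_* Hub repos referenced in the diff exist; the model name used here is illustrative:

# Sketch of the Hub-based loading flow now used in load_models(); assumptions noted in comments.
import torch
from models import voicecraft  # as imported in app.py

device = "cuda" if torch.cuda.is_available() else "cpu"
voicecraft_model_name = "giga830M"  # illustrative; the app passes this in from a dropdown

# One call downloads (and caches) config.json plus the weights from the Hub repo
# and rebuilds the model, replacing the old wget + torch.load + load_state_dict steps.
model = voicecraft.VoiceCraftHF.from_pretrained(f"pyp1/VoiceCraft_{voicecraft_model_name}")
model.to(device)

# Hyperparameters and the phoneme-to-id mapping now come off the model itself,
# which is what run() passes to inference_one_sample().
config = model.args
phn2num = model.args.phn2num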
models/voicecraft.py
CHANGED
@@ -17,7 +17,8 @@ from .modules.transformer import (
     TransformerEncoderLayer,
 )
 from .codebooks_patterns import DelayedPatternProvider
-
+from huggingface_hub import PyTorchModelHubMixin
+from argparse import Namespace
 def top_k_top_p_filtering(
     logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
 ):
@@ -1403,4 +1404,9 @@ class VoiceCraft(nn.Module):
         res = res - int(self.args.n_special)
         flatten_gen = flatten_gen - int(self.args.n_special)
 
-        return res, flatten_gen[0].unsqueeze(0)
+        return res, flatten_gen[0].unsqueeze(0)
+
+class VoiceCraftHF(VoiceCraft, PyTorchModelHubMixin, repo_url="https://github.com/jasonppy/VoiceCraft", tags=["Text-to-Speech", "VoiceCraft"]):
+    def __init__(self, config: dict):
+        args = Namespace(**config)
+        super().__init__(args)
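
The new VoiceCraftHF class gets its Hub integration from huggingface_hub's PyTorchModelHubMixin: mixing it into an nn.Module adds from_pretrained, save_pretrained, and push_to_hub, with the config dict passed to __init__ serialized to a config.json stored next to the weights (the repo_url and tags class keywords only feed the generated model card metadata). A self-contained toy sketch of the same pattern; the class, folder, and dimensions are hypothetical and not part of this commit:

# Toy illustration of the PyTorchModelHubMixin pattern used by VoiceCraftHF above.
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin


class TinyModel(nn.Module, PyTorchModelHubMixin,
                repo_url="https://github.com/jasonppy/VoiceCraft",  # optional model-card metadata
                tags=["Text-to-Speech"]):
    def __init__(self, config: dict):
        super().__init__()
        self.linear = nn.Linear(config["hidden"], config["hidden"])

    def forward(self, x):
        return self.linear(x)


model = TinyModel(config={"hidden": 16})

# save_pretrained() writes config.json plus the weights to a local folder;
# push_to_hub() would upload the same files to a Hub repo (requires auth).
model.save_pretrained("tiny-demo-model")

# from_pretrained() accepts a local folder or a Hub repo id; because __init__
# declares a `config` argument, the mixin passes the stored dict back in and
# then loads the saved state dict, the same round trip VoiceCraftHF relies on.
reloaded = TinyModel.from_pretrained("tiny-demo-model")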
requirements.txt
CHANGED
@@ -5,4 +5,5 @@ nltk>=3.8.1
 openai-whisper>=20231117
 spaces
 aeneas==1.7.3.0
-whisperx==3.1.1
+whisperx==3.1.1
+huggingface-hub==0.22.2
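
huggingface-hub is pinned at 0.22.2, the release line that supports the class-keyword form of PyTorchModelHubMixin (repo_url=..., tags=[...]) used above. For completeness, a hypothetical one-off conversion sketch showing how a local .pth checkpoint could be exported into the kind of Hub repo app.py now downloads; the paths, model name, and repo id here are assumptions, not part of this commit:

# Hypothetical conversion script: wrap an existing VoiceCraft .pth checkpoint in
# VoiceCraftHF so VoiceCraftHF.from_pretrained() can rebuild it from the Hub.
import torch
from models.voicecraft import VoiceCraftHF

ckpt = torch.load("./pretrained_models/giga830M.pth", map_location="cpu")  # assumed local path

# The old app.py read ckpt["config"], ckpt["model"], and ckpt["phn2num"] directly;
# here the same pieces are packed into the mixin's config + state dict.
config = vars(ckpt["config"])          # argparse.Namespace -> plain dict for config.json
config["phn2num"] = ckpt["phn2num"]    # app.py reads model.args.phn2num after loading
model = VoiceCraftHF(config)
model.load_state_dict(ckpt["model"])

model.save_pretrained("VoiceCraft_giga830M")      # local export (config.json + weights)
# model.push_to_hub("pyp1/VoiceCraft_giga830M")   # upload would need write access to the repo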