zhzluke96 committed · Commit d2b7e94 · 1 Parent(s): 9d9fe0d

update

This view is limited to 50 files because it contains too many changes.
- .env.webui +2 -2
- README.md +2 -2
- launch.py +11 -10
- modules/ChatTTS/ChatTTS/__init__.py +1 -1
- modules/ChatTTS/ChatTTS/core.py +7 -7
- modules/ChatTTS/ChatTTS/infer/api.py +1 -0
- modules/ChatTTS/ChatTTS/model/dvae.py +79 -48
- modules/ChatTTS/ChatTTS/model/gpt.py +167 -87
- modules/ChatTTS/ChatTTS/utils/infer_utils.py +1 -0
- modules/ChatTTS/ChatTTS/utils/io_utils.py +6 -6
- modules/Denoiser/AudioDenoiser.py +5 -3
- modules/Denoiser/AudioNosiseModel.py +2 -3
- modules/Enhancer/ResembleEnhance.py +6 -9
- modules/SentenceSplitter.py +1 -0
- modules/SynthesizeSegments.py +24 -18
- modules/api/Api.py +3 -5
- modules/api/api_setup.py +11 -13
- modules/api/impl/google_api.py +2 -6
- modules/api/impl/handler/AudioHandler.py +2 -1
- modules/api/impl/handler/SSMLHandler.py +3 -3
- modules/api/impl/handler/TTSHandler.py +2 -2
- modules/api/impl/model/enhancer_model.py +1 -0
- modules/api/impl/models_api.py +1 -1
- modules/api/impl/openai_api.py +6 -11
- modules/api/impl/ping_api.py +1 -2
- modules/api/impl/refiner_api.py +0 -3
- modules/api/impl/speaker_api.py +3 -2
- modules/api/impl/ssml_api.py +3 -8
- modules/api/impl/style_api.py +1 -1
- modules/api/impl/tts_api.py +3 -7
- modules/api/impl/xtts_v2_api.py +5 -7
- modules/api/utils.py +3 -7
- modules/api/worker.py +2 -1
- modules/config.py +2 -2
- modules/data.py +0 -1
- modules/denoise.py +3 -5
- modules/devices/devices.py +4 -3
- modules/devices/mac_devices.py +3 -2
- modules/ffmpeg_env.py +2 -1
- modules/finetune/train_speaker.py +8 -5
- modules/finetune/utils/dataset.py +6 -6
- modules/finetune/utils/logger.py +3 -4
- modules/generate_audio.py +7 -10
- modules/models.py +5 -5
- modules/normalization.py +5 -3
- modules/prompts/news_oral_prompt.txt +23 -4
- modules/refiner.py +1 -2
- modules/repos_static/resemble_enhance/common.py +3 -1
- modules/repos_static/resemble_enhance/data/dataset.py +21 -7
- modules/repos_static/resemble_enhance/data/distorter/base.py +1 -1
.env.webui  CHANGED
@@ -14,9 +14,9 @@ DEBUG_GENERATE=True
 PRELOAD_MODELS=True
 
 # Text-to-Speech (TTS) configuration
-TTS_MAX_LEN=
+TTS_MAX_LEN=2000
 SSML_MAX_LEN=3000
 MAX_BATCH_SIZE=12
 
-V_GIT_TAG="🤗hf(0.6.1
+V_GIT_TAG="🤗hf(0.6.1)"
 V_GIT_COMMIT=main
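The limits configured above are plain environment variables read at runtime. As a hedged illustration only (the variable name comes from this diff; the helper below is hypothetical, not code from this repo), enforcing such a limit could look like:

    import os

    # Hypothetical enforcement sketch for the TTS limit configured above.
    TTS_MAX_LEN = int(os.getenv("TTS_MAX_LEN", "2000"))

    def check_text_length(text: str) -> str:
        # Reject requests whose text exceeds the configured limit.
        if len(text) > TTS_MAX_LEN:
            raise ValueError(f"text length {len(text)} exceeds TTS_MAX_LEN={TTS_MAX_LEN}")
        return text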
README.md  CHANGED
@@ -16,7 +16,7 @@ sdk_version: 4.36.1
 
 | 类型 | 最大字符数 |
 |------|-----------|
-| TTS |
+| TTS | 2000 字符 |
 | SSML | 3000 字符(不计算 SSML 标签,只计算文本) |
 
 # HuggingFace Space Limit
@@ -25,7 +25,7 @@ Due to the runtime limit for GPU usage on HuggingFace, extremely long tasks will
 
 | Type | Maximum Characters |
 |------|---------------------|
-| TTS |
+| TTS | 2000 characters |
 | SSML | 3000 characters (excluding SSML tags, only counting text) |
 
 # 🗣️ ChatTTS-Forge
launch.py  CHANGED
@@ -1,23 +1,24 @@
-import os
 import logging
+import os
 
-from modules.api.api_setup import setup_api_args, setup_model_args, setup_uvicon_args
 from modules.ffmpeg_env import setup_ffmpeg_path
 
-
-
-
-
-)
+try:
+    setup_ffmpeg_path()
+    logging.basicConfig(
+        level=os.getenv("LOG_LEVEL", "INFO"),
+        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+    )
+except BaseException:
+    pass
 
 import argparse
+
 import uvicorn
 
-from modules import
+from modules.api.api_setup import setup_api_args, setup_model_args, setup_uvicon_args
 from modules.utils import env
 
-from fastapi import FastAPI
-
 logger = logging.getLogger(__name__)
 
 if __name__ == "__main__":
modules/ChatTTS/ChatTTS/__init__.py  CHANGED
@@ -1 +1 @@
-from .core import Chat
+from .core import Chat
modules/ChatTTS/ChatTTS/core.py  CHANGED
@@ -1,21 +1,21 @@
-import os
 import logging
-
+import os
 
 import torch
+from huggingface_hub import snapshot_download
+from omegaconf import OmegaConf
 from vocos import Vocos
+
+from .infer.api import infer_code, refine_text
 from .model.dvae import DVAE
 from .model.gpt import GPT_warpper
 from .utils.infer_utils import (
-    count_invalid_characters,
-    detect_language,
     apply_character_map,
     apply_half2full_map,
+    count_invalid_characters,
+    detect_language,
 )
 from .utils.io_utils import get_latest_modified_file
-from .infer.api import refine_text, infer_code
-
-from huggingface_hub import snapshot_download
 
 logging.basicConfig(level=logging.INFO)
modules/ChatTTS/ChatTTS/infer/api.py  CHANGED
@@ -1,6 +1,7 @@
 import torch
 import torch.nn.functional as F
 from transformers.generation import TopKLogitsWarper, TopPLogitsWarper
+
 from ..utils.infer_utils import CustomRepetitionPenaltyLogitsProcessorRepeat
 
 
modules/ChatTTS/ChatTTS/model/dvae.py  CHANGED
@@ -1,28 +1,36 @@
 import math
-from einops import rearrange
-from vector_quantize_pytorch import GroupedResidualFSQ
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from einops import rearrange
+from vector_quantize_pytorch import GroupedResidualFSQ
+
 
 class ConvNeXtBlock(nn.Module):
     def __init__(
         self,
         dim: int,
         intermediate_dim: int,
         kernel,
+        dilation,
         layer_scale_init_value: float = 1e-6,
     ):
         # ConvNeXt Block copied from Vocos.
         super().__init__()
         self.dwconv = nn.Conv1d(
-
-
-
-
+            dim,
+            dim,
+            kernel_size=kernel,
+            padding=dilation * (kernel // 2),
+            dilation=dilation,
+            groups=dim,
+        )  # depthwise conv
+
         self.norm = nn.LayerNorm(dim, eps=1e-6)
         self.pwconv1 = nn.Linear(
+            dim, intermediate_dim
+        )  # pointwise/1x1 convs, implemented with linear layers
         self.act = nn.GELU()
         self.pwconv2 = nn.Linear(intermediate_dim, dim)
         self.gamma = (
@@ -31,7 +39,7 @@ class ConvNeXtBlock(nn.Module):
             else None
         )
 
-    def forward(self, x: torch.Tensor, cond
+    def forward(self, x: torch.Tensor, cond=None) -> torch.Tensor:
         residual = x
         x = self.dwconv(x)
         x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
@@ -45,14 +53,11 @@ class ConvNeXtBlock(nn.Module):
 
         x = residual + x
         return x
-
 
 
 class GFSQ(nn.Module):
 
-    def __init__(self,
-        dim, levels, G, R, eps=1e-5, transpose = True
-    ):
+    def __init__(self, dim, levels, G, R, eps=1e-5, transpose=True):
         super(GFSQ, self).__init__()
         self.quantizer = GroupedResidualFSQ(
             dim=dim,
@@ -65,50 +70,74 @@ class GFSQ(nn.Module):
         self.transpose = transpose
         self.G = G
         self.R = R
 
     def _embed(self, x):
         if self.transpose:
-            x = x.transpose(1,2)
+            x = x.transpose(1, 2)
         x = rearrange(
             x,
-
+            "b t (g r) -> g b t r",
+            g=self.G,
+            r=self.R,
+        )
         feat = self.quantizer.get_output_from_indices(x)
-        return feat.transpose(1,2) if self.transpose else feat
+        return feat.transpose(1, 2) if self.transpose else feat
 
     def forward(
+        self,
+        x,
+    ):
         if self.transpose:
-            x = x.transpose(1,2)
+            x = x.transpose(1, 2)
         feat, ind = self.quantizer(x)
         ind = rearrange(
             ind,
-
+            "g b t r ->b t (g r)",
+        )
         embed_onehot = F.one_hot(ind.long(), self.n_ind).to(x.dtype)
-        e_mean = torch.mean(embed_onehot, dim=[0,1])
+        e_mean = torch.mean(embed_onehot, dim=[0, 1])
         e_mean = e_mean / (e_mean.sum(dim=1) + self.eps).unsqueeze(1)
         perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + self.eps), dim=1))
 
         return (
             torch.zeros(perplexity.shape, dtype=x.dtype, device=x.device),
-            feat.transpose(1,2) if self.transpose else feat,
+            feat.transpose(1, 2) if self.transpose else feat,
             perplexity,
             None,
-            ind.transpose(1,2) if self.transpose else ind,
+            ind.transpose(1, 2) if self.transpose else ind,
         )
 
 
 class DVAEDecoder(nn.Module):
     def __init__(
-
-
-
+        self,
+        idim,
+        odim,
+        n_layer=12,
+        bn_dim=64,
+        hidden=256,
+        kernel=7,
+        dilation=2,
+        up=False,
+    ):
         super().__init__()
         self.up = up
         self.conv_in = nn.Sequential(
-            nn.Conv1d(idim, bn_dim, 3, 1, 1),
-            nn.
-        self.decoder_block = nn.ModuleList([
-            ConvNeXtBlock(hidden, hidden* 4, kernel, dilation,)
-            for _ in range(n_layer)])
+            nn.Conv1d(idim, bn_dim, 3, 1, 1),
+            nn.GELU(),
+            nn.Conv1d(bn_dim, hidden, 3, 1, 1),
+        )
+        self.decoder_block = nn.ModuleList(
+            [
+                ConvNeXtBlock(
+                    hidden,
+                    hidden * 4,
+                    kernel,
+                    dilation,
+                )
+                for _ in range(n_layer)
+            ]
+        )
         self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False)
@@ -117,17 +146,15 @@ class DVAEDecoder(nn.Module):
         x = self.conv_in(x)
         for f in self.decoder_block:
             x = f(x, conditioning)
 
         x = self.conv_out(x)
         return x.transpose(1, 2)
 
 
 class DVAE(nn.Module):
-    def __init__(
-        self, decoder_config, vq_config, dim=512
-    ):
+    def __init__(self, decoder_config, vq_config, dim=512):
         super().__init__()
-        self.register_buffer(
+        self.register_buffer("coef", torch.randn(1, 100, 1))
 
         self.decoder = DVAEDecoder(**decoder_config)
         self.out_conv = nn.Conv1d(dim, 100, 3, 1, 1, bias=False)
@@ -142,10 +169,14 @@ class DVAE(nn.Module):
             vq_feats = self.vq_layer._embed(inp)
         else:
             vq_feats = inp.detach().clone()
 
-        vq_feats =
-
-
+        vq_feats = (
+            vq_feats.view(
+                (vq_feats.size(0), 2, vq_feats.size(1) // 2, vq_feats.size(2)),
+            )
+            .permute(0, 2, 3, 1)
+            .flatten(2)
+        )
 
         vq_feats = vq_feats.transpose(1, 2)
         dec_out = self.decoder(input=vq_feats)
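For reference, the grouped-index rearrange at the heart of GFSQ._embed above can be exercised standalone. A minimal sketch with illustrative shapes (G groups, R residual levels; only torch and einops assumed):

    import torch
    from einops import rearrange

    # Illustrative shapes: batch b=2, time t=5, G=2 groups, R=3 residual levels.
    G, R = 2, 3
    ind = torch.randint(0, 10, (2, 5, G * R))  # indices stored fused as (b, t, g*r)

    # Split the fused (g r) axis and move groups to the front -- the layout
    # GroupedResidualFSQ.get_output_from_indices expects.
    grouped = rearrange(ind, "b t (g r) -> g b t r", g=G, r=R)
    print(grouped.shape)  # torch.Size([2, 2, 5, 3])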
modules/ChatTTS/ChatTTS/model/gpt.py  CHANGED
@@ -1,19 +1,20 @@
 import os
+
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 import logging
-from tqdm import tqdm
-from einops import rearrange
-from transformers.cache_utils import Cache
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.nn.utils.parametrize as P
+from einops import rearrange
 from torch.nn.utils.parametrizations import weight_norm
-from
-
-
+from tqdm import tqdm
+from transformers import LlamaConfig, LlamaModel
+from transformers.cache_utils import Cache
+
 
 class LlamaMLP(nn.Module):
     def __init__(self, hidden_size, intermediate_size):
         super().__init__()
@@ -27,70 +28,106 @@ class LlamaMLP(nn.Module):
     def forward(self, x):
         down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
         return down_proj
 
 
 class GPT_warpper(nn.Module):
     def __init__(
         self,
         gpt_config,
         num_audio_tokens,
         num_text_tokens,
         num_vq=4,
         **kwargs,
     ):
         super().__init__()
 
         self.logger = logging.getLogger(__name__)
         self.gpt = self.build_model(gpt_config)
         self.model_dim = self.gpt.config.hidden_size
 
         self.num_vq = num_vq
         self.emb_code = nn.ModuleList(
+            [nn.Embedding(num_audio_tokens, self.model_dim) for i in range(self.num_vq)]
+        )
         self.emb_text = nn.Embedding(num_text_tokens, self.model_dim)
         self.head_text = weight_norm(
-
+            nn.Linear(self.model_dim, num_text_tokens, bias=False), name="weight"
+        )
+        self.head_code = nn.ModuleList(
+            [
+                weight_norm(
+                    nn.Linear(self.model_dim, num_audio_tokens, bias=False),
+                    name="weight",
+                )
+                for i in range(self.num_vq)
+            ]
+        )
 
     def build_model(self, config):
 
         configuration = LlamaConfig(**config)
         model = LlamaModel(configuration)
         del model.embed_tokens
 
         return model
 
    def get_emb(self, input_ids, text_mask, **kwargs):
 
         emb_text = self.emb_text(input_ids[text_mask][:, 0])
 
         emb_code = [
+            self.emb_code[i](input_ids[~text_mask][:, i]) for i in range(self.num_vq)
+        ]
         emb_code = torch.stack(emb_code, 2).sum(2)
 
         emb = torch.zeros(
+            (input_ids.shape[:-1]) + (emb_text.shape[-1],),
+            device=emb_text.device,
+            dtype=emb_text.dtype,
+        )
         emb[text_mask] = emb_text
         emb[~text_mask] = emb_code.to(emb.dtype)
 
         return emb
 
     def prepare_inputs_for_generation(
         self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        cache_position=None,
+        **kwargs,
     ):
         # With static cache, the `past_key_values` is None
         # TODO joao: standardize interface for the different Cache classes and remove of this if
         has_static_cache = False
         if past_key_values is None:
             past_key_values = getattr(
+                self.gpt.layers[0].self_attn, "past_key_value", None
+            )
             has_static_cache = past_key_values is not None
 
         past_length = 0
         if past_key_values is not None:
             if isinstance(past_key_values, Cache):
-                past_length =
+                past_length = (
+                    cache_position[0]
+                    if cache_position is not None
+                    else past_key_values.get_seq_length()
+                )
                 max_cache_length = (
                     torch.tensor(
+                        past_key_values.get_max_length(), device=input_ids.device
+                    )
                     if past_key_values.get_max_length() is not None
                     else None
                 )
-                cache_length =
+                cache_length = (
+                    past_length
+                    if max_cache_length is None
+                    else torch.min(max_cache_length, past_length)
+                )
             # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
             else:
                 cache_length = past_length = past_key_values[0][0].shape[2]
@@ -100,7 +137,10 @@ class GPT_warpper(nn.Module):
         # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
         # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
         # input)
-        if
+        if (
+            attention_mask is not None
+            and attention_mask.shape[1] > input_ids.shape[1]
+        ):
             input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
         # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
         # input_ids based on the past_length.
@@ -133,9 +173,13 @@ class GPT_warpper(nn.Module):
         # TODO: use `next_tokens` directly instead.
         model_inputs = {"input_ids": input_ids.contiguous()}
 
-        input_length =
+        input_length = (
+            position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
+        )
         if cache_position is None:
-            cache_position = torch.arange(
+            cache_position = torch.arange(
+                past_length, past_length + input_length, device=input_ids.device
+            )
         else:
             cache_position = cache_position[-input_length:]
@@ -152,118 +196,154 @@ class GPT_warpper(nn.Module):
             }
         )
         return model_inputs
 
     def generate(
         self,
         emb,
         inputs_ids,
         temperature,
         eos_token,
-        attention_mask
-        max_new_token
-        min_new_token
-        LogitsWarpers
-        LogitsProcessors
+        attention_mask=None,
+        max_new_token=2048,
+        min_new_token=0,
+        LogitsWarpers=[],
+        LogitsProcessors=[],
         infer_text=False,
         return_attn=False,
         return_hidden=False,
-        disable_tqdm=False
+        disable_tqdm=False,
     ):
         if disable_tqdm:
             tqdm = lambda x: x
         else:
             from tqdm import tqdm
 
         with torch.no_grad():
 
             attentions = []
             hiddens = []
 
             start_idx, end_idx = inputs_ids.shape[1], torch.zeros(
+                inputs_ids.shape[0], device=inputs_ids.device, dtype=torch.long
+            )
             finish = torch.zeros(inputs_ids.shape[0], device=inputs_ids.device).bool()
 
             temperature = temperature[None].expand(inputs_ids.shape[0], -1)
             temperature = rearrange(temperature, "b n -> (b n) 1")
 
             attention_mask_cache = torch.ones(
+                (
+                    inputs_ids.shape[0],
+                    inputs_ids.shape[1] + max_new_token,
+                ),
+                dtype=torch.bool,
+                device=inputs_ids.device,
+            )
             if attention_mask is not None:
-                attention_mask_cache[:, :attention_mask.shape[1]] = attention_mask
+                attention_mask_cache[:, : attention_mask.shape[1]] = attention_mask
 
             for i in tqdm(range(max_new_token)):
                 if finish.all():
                     continue
 
                 model_input = self.prepare_inputs_for_generation(
+                    inputs_ids,
+                    outputs.past_key_values if i != 0 else None,
+                    attention_mask_cache[:, : inputs_ids.shape[1]],
+                    use_cache=True,
+                )
+
                 if i == 0:
-                    model_input[
+                    model_input["inputs_embeds"] = emb
                 else:
                     if infer_text:
-                        model_input[
+                        model_input["inputs_embeds"] = self.emb_text(
+                            model_input["input_ids"][:, :, 0]
+                        )
                     else:
                         code_emb = [
-
-
+                            self.emb_code[i](model_input["input_ids"][:, :, i])
+                            for i in range(self.num_vq)
+                        ]
+                        model_input["inputs_embeds"] = torch.stack(code_emb, 3).sum(3)
+
+                model_input["input_ids"] = None
                 outputs = self.gpt.forward(**model_input, output_attentions=return_attn)
                 attentions.append(outputs.attentions)
-                hidden_states = outputs[0]
+                hidden_states = outputs[0]  # 🐻
                 if return_hidden:
                     hiddens.append(hidden_states[:, -1])
 
                 with P.cached():
                     if infer_text:
                         logits = self.head_text(hidden_states)
                     else:
                         logits = torch.stack(
-
+                            [
+                                self.head_code[i](hidden_states)
+                                for i in range(self.num_vq)
+                            ],
+                            3,
+                        )
+
                 logits = logits[:, -1].float()
 
                 if not infer_text:
                     logits = rearrange(logits, "b c n -> (b n) c")
                     logits_token = rearrange(
+                        inputs_ids[:, start_idx:], "b c n -> (b n) c"
+                    )
                 else:
                     logits_token = inputs_ids[:, start_idx:, 0]
 
                 logits = logits / temperature
 
                 for logitsProcessors in LogitsProcessors:
                     logits = logitsProcessors(logits_token, logits)
 
                 for logitsWarpers in LogitsWarpers:
                     logits = logitsWarpers(logits_token, logits)
 
                 if i < min_new_token:
                     logits[:, eos_token] = -torch.inf
 
                 scores = F.softmax(logits, dim=-1)
 
                 idx_next = torch.multinomial(scores, num_samples=1)
 
                 if not infer_text:
                     idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
                     finish = finish | (idx_next == eos_token).any(1)
                     inputs_ids = torch.cat([inputs_ids, idx_next.unsqueeze(1)], 1)
                 else:
                     finish = finish | (idx_next == eos_token).any(1)
                     inputs_ids = torch.cat(
+                        [
+                            inputs_ids,
+                            idx_next.unsqueeze(-1).expand(-1, -1, self.num_vq),
+                        ],
+                        1,
+                    )
 
                 end_idx = end_idx + (~finish).int()
 
             inputs_ids = [
+                inputs_ids[idx, start_idx : start_idx + i]
+                for idx, i in enumerate(end_idx.int())
+            ]
             inputs_ids = [i[:, 0] for i in inputs_ids] if infer_text else inputs_ids
 
             if return_hidden:
                 hiddens = torch.stack(hiddens, 1)
                 hiddens = [hiddens[idx, :i] for idx, i in enumerate(end_idx.int())]
 
             if not finish.all():
                 self.logger.warn(
+                    f"Incomplete result. hit max_new_token: {max_new_token}"
+                )
+
             return {
-
-
-
-            }
+                "ids": inputs_ids,
+                "attentions": attentions,
+                "hiddens": hiddens,
+            }
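The audio branch of GPT_warpper.generate above samples every VQ codebook with a single multinomial call by folding the codebook axis into the batch. A standalone sketch of that flatten/sample/unflatten round-trip, with illustrative shapes:

    import torch
    from einops import rearrange

    b, vocab, num_vq = 2, 8, 4
    logits = torch.randn(b, vocab, num_vq)        # (b, c, n): stacked per-codebook head outputs

    flat = rearrange(logits, "b c n -> (b n) c")  # one row per (batch, codebook) pair
    idx_next = torch.multinomial(flat.softmax(-1), num_samples=1)  # (b*n, 1)
    idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=num_vq)

    print(idx_next.shape)  # torch.Size([2, 4]): one new token per codebook per batch item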
modules/ChatTTS/ChatTTS/utils/infer_utils.py  CHANGED
@@ -1,4 +1,5 @@
 import re
+
 import torch
 import torch.nn.functional as F
 
modules/ChatTTS/ChatTTS/utils/io_utils.py  CHANGED
@@ -1,14 +1,14 @@
-
-import os
 import logging
+import os
+
 
 def get_latest_modified_file(directory):
     logger = logging.getLogger(__name__)
 
     files = [os.path.join(directory, f) for f in os.listdir(directory)]
     if not files:
-        logger.log(logging.WARNING, f
+        logger.log(logging.WARNING, f"No files found in the directory: {directory}")
         return None
     latest_file = max(files, key=os.path.getmtime)
 
     return latest_file
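A short usage sketch for get_latest_modified_file as defined above (the directory path is illustrative; the function returns None and logs a warning for an empty directory):

    from modules.ChatTTS.ChatTTS.utils.io_utils import get_latest_modified_file

    # Path below is an example only; os.listdir requires it to exist.
    latest = get_latest_modified_file("./models")
    print(latest)  # most recently modified entry, or None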
modules/Denoiser/AudioDenoiser.py  CHANGED
@@ -1,15 +1,17 @@
 import logging
 import math
 from typing import Union
+
 import torch
 import torchaudio
-from torch import nn
-from audio_denoiser.helpers.torch_helper import batched_apply
-from modules.Denoiser.AudioNosiseModel import load_audio_denosier_model
 from audio_denoiser.helpers.audio_helper import (
     create_spectrogram,
     reconstruct_from_spectrogram,
 )
+from audio_denoiser.helpers.torch_helper import batched_apply
+from torch import nn
+
+from modules.Denoiser.AudioNosiseModel import load_audio_denosier_model
 
 _expected_t_std = 0.23
 _recommended_backend = "soundfile"
modules/Denoiser/AudioNosiseModel.py  CHANGED
@@ -1,12 +1,11 @@
+import json
+
 import torch
 import torch.nn as nn
-
 from audio_denoiser.modules.Permute import Permute
 from audio_denoiser.modules.SimpleRoberta import SimpleRoberta
 from audio_denoiser.modules.SpectrogramScaler import SpectrogramScaler
 
-import json
-
 
 class AudioNoiseModel(nn.Module):
     def __init__(self, config: dict):
modules/Enhancer/ResembleEnhance.py  CHANGED
@@ -1,20 +1,17 @@
 import gc
+import logging
+from pathlib import Path
+from threading import Lock
 from typing import Literal
 
 import numpy as np
+import torch
+
 from modules.devices import devices
 from modules.repos_static.resemble_enhance.enhancer.enhancer import Enhancer
 from modules.repos_static.resemble_enhance.enhancer.hparams import HParams
 from modules.repos_static.resemble_enhance.inference import inference
-
-import torch
-
 from modules.utils.constants import MODELS_DIR
-from pathlib import Path
-
-from threading import Lock
-
-import logging
 
 logger = logging.getLogger(__name__)
 
@@ -155,8 +152,8 @@ def apply_audio_enhance(
 
 
 if __name__ == "__main__":
-    import torchaudio
     import gradio as gr
+    import torchaudio
 
     device = torch.device("cuda")
 
modules/SentenceSplitter.py  CHANGED
@@ -1,4 +1,5 @@
 import re
+
 import zhon
 
 
modules/SynthesizeSegments.py  CHANGED
@@ -1,31 +1,37 @@
 import copy
+import json
+import logging
 import re
+from typing import List, Union
+
+import numpy as np
 from box import Box
 from pydub import AudioSegment
-
-from scipy.io.wavfile import write
-import io
-from modules.SentenceSplitter import SentenceSplitter
-from modules.api.utils import calc_spk_style
-from modules.ssml_parser.SSMLParser import SSMLSegment, SSMLBreak, SSMLContext
-from modules.utils import rng
-from modules.utils.audio import time_stretch, pitch_shift
+
 from modules import generate_audio
+from modules.api.utils import calc_spk_style
 from modules.normalization import text_normalize
-import
-import
-
-from modules.
+from modules.SentenceSplitter import SentenceSplitter
+from modules.speaker import Speaker
+from modules.ssml_parser.SSMLParser import SSMLBreak, SSMLContext, SSMLSegment
+from modules.utils import rng
+from modules.utils.audio import pitch_shift, time_stretch
 
 logger = logging.getLogger(__name__)
 
 
-def audio_data_to_segment(audio_data, sr):
-
-
-
-
-
+def audio_data_to_segment(audio_data: np.ndarray, sr: int):
+    """
+    optimize: https://github.com/lenML/ChatTTS-Forge/issues/57
+    """
+    audio_data = (audio_data * 32767).astype(np.int16)
+    audio_segment = AudioSegment(
+        audio_data.tobytes(),
+        frame_rate=sr,
+        sample_width=audio_data.dtype.itemsize,
+        channels=1,
+    )
+    return audio_segment
 
 
 def combine_audio_segments(audio_segments: list[AudioSegment]) -> AudioSegment:
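audio_data_to_segment above skips a WAV round-trip by scaling the float waveform to 16-bit PCM and handing the raw bytes to pydub directly (the optimization referenced in issue #57). The same conversion on a synthetic tone, as a self-contained sketch:

    import numpy as np
    from pydub import AudioSegment

    sr = 24000  # sample rate for the illustrative tone below
    audio_data = np.sin(2 * np.pi * 440 * np.arange(sr) / sr).astype(np.float32)

    pcm = (audio_data * 32767).astype(np.int16)  # float [-1, 1] -> int16 PCM
    segment = AudioSegment(
        pcm.tobytes(),
        frame_rate=sr,
        sample_width=pcm.dtype.itemsize,  # 2 bytes per int16 sample
        channels=1,
    )
    print(len(segment))  # duration in ms, ~1000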
modules/api/Api.py  CHANGED
@@ -1,12 +1,10 @@
-
-from fastapi.middleware.cors import CORSMiddleware
-
+import fnmatch
 import logging
 
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
 from fastapi.staticfiles import StaticFiles
 
-import fnmatch
-
 
 def is_excluded(path, exclude_patterns):
     """
modules/api/api_setup.py  CHANGED
@@ -1,26 +1,24 @@
-import logging
-from modules.Enhancer.ResembleEnhance import load_enhancer
-from modules.devices import devices
 import argparse
+import logging
 
-from modules import config
-from modules.models import load_chat_tts
-from modules.utils import env
-from modules import generate_audio
+from modules import config, generate_audio
 from modules.api.Api import APIManager
-
 from modules.api.impl import (
-    style_api,
-    tts_api,
-    ssml_api,
     google_api,
+    models_api,
     openai_api,
+    ping_api,
     refiner_api,
     speaker_api,
-
-
+    ssml_api,
+    style_api,
+    tts_api,
     xtts_v2_api,
 )
+from modules.devices import devices
+from modules.Enhancer.ResembleEnhance import load_enhancer
+from modules.models import load_chat_tts
+from modules.utils import env
 
 logger = logging.getLogger(__name__)
 
modules/api/impl/google_api.py  CHANGED
@@ -1,22 +1,18 @@
 from typing import Union
-from fastapi import HTTPException
 
+from fastapi import HTTPException
 from pydantic import BaseModel
 
-
+from modules.api import utils as api_utils
 from modules.api.Api import APIManager
 from modules.api.impl.handler.SSMLHandler import SSMLHandler
 from modules.api.impl.handler.TTSHandler import TTSHandler
 from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
 from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
 from modules.api.impl.model.enhancer_model import EnhancerConfig
-
 from modules.speaker import Speaker, speaker_mgr
 
 
-from modules.api import utils as api_utils
-
-
 class SynthesisInput(BaseModel):
     text: Union[str, None] = None
     ssml: Union[str, None] = None
modules/api/impl/handler/AudioHandler.py  CHANGED
@@ -1,10 +1,11 @@
 import base64
 import io
+
 import numpy as np
 import soundfile as sf
 
-from modules.api.impl.model.audio_model import AudioFormat
 from modules.api import utils as api_utils
+from modules.api.impl.model.audio_model import AudioFormat
 
 
 class AudioHandler:
modules/api/impl/handler/SSMLHandler.py  CHANGED
@@ -1,14 +1,14 @@
-from fastapi import HTTPException
 import numpy as np
+from fastapi import HTTPException
 
-from modules.Enhancer.ResembleEnhance import apply_audio_enhance_full
-from modules.SynthesizeSegments import SynthesizeSegments, combine_audio_segments
 from modules.api.impl.handler.AudioHandler import AudioHandler
 from modules.api.impl.model.audio_model import AdjustConfig
 from modules.api.impl.model.chattts_model import InferConfig
 from modules.api.impl.model.enhancer_model import EnhancerConfig
+from modules.Enhancer.ResembleEnhance import apply_audio_enhance_full
 from modules.normalization import text_normalize
 from modules.ssml_parser.SSMLParser import create_ssml_parser
+from modules.SynthesizeSegments import SynthesizeSegments, combine_audio_segments
 from modules.utils import audio
 
 
modules/api/impl/handler/TTSHandler.py  CHANGED
@@ -1,13 +1,13 @@
 import numpy as np
 
 from modules.api.impl.handler.AudioHandler import AudioHandler
 from modules.api.impl.model.audio_model import AdjustConfig
 from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
 from modules.api.impl.model.enhancer_model import EnhancerConfig
+from modules.Enhancer.ResembleEnhance import apply_audio_enhance_full
 from modules.normalization import text_normalize
 from modules.speaker import Speaker
 from modules.synthesize_audio import synthesize_audio
-
 from modules.utils.audio import apply_prosody_to_audio_data
 
 
modules/api/impl/model/enhancer_model.py  CHANGED
@@ -1,4 +1,5 @@
 from typing import Literal
+
 from pydantic import BaseModel
 
 
modules/api/impl/models_api.py  CHANGED
@@ -1,6 +1,6 @@
-from modules.Enhancer.ResembleEnhance import reload_enhancer, unload_enhancer
 from modules.api import utils as api_utils
 from modules.api.Api import APIManager
+from modules.Enhancer.ResembleEnhance import reload_enhancer, unload_enhancer
 from modules.models import reload_chat_tts, unload_chat_tts
 
 
modules/api/impl/openai_api.py  CHANGED
@@ -1,23 +1,18 @@
-from
+from typing import List, Optional
 
+from fastapi import Body, File, Form, HTTPException, UploadFile
+from fastapi.responses import StreamingResponse
 from numpy import clip
 from pydantic import BaseModel, Field
-from fastapi.responses import StreamingResponse
-
 
+from modules.api import utils as api_utils
+from modules.api.Api import APIManager
 from modules.api.impl.handler.TTSHandler import TTSHandler
 from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
 from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
 from modules.api.impl.model.enhancer_model import EnhancerConfig
-
-
-from typing import List, Optional
-
-from modules.api import utils as api_utils
-from modules.api.Api import APIManager
-
-from modules.speaker import Speaker, speaker_mgr
 from modules.data import styles_mgr
+from modules.speaker import Speaker, speaker_mgr
 
 
 class AudioSpeechRequest(BaseModel):
modules/api/impl/ping_api.py  CHANGED
@@ -1,8 +1,7 @@
+from modules import config
 from modules.api import utils as api_utils
 from modules.api.Api import APIManager
 
-from modules import config
-
 
 def setup(app: APIManager):
     @app.get("/v1/ping", response_model=api_utils.BaseResponse)
modules/api/impl/refiner_api.py  CHANGED
@@ -1,10 +1,7 @@
 from fastapi import HTTPException
-
 from pydantic import BaseModel
 
-
 from modules import refiner
-
 from modules.api import utils as api_utils
 from modules.api.Api import APIManager
 from modules.normalization import text_normalize
modules/api/impl/speaker_api.py  CHANGED
@@ -1,9 +1,10 @@
+import torch
 from fastapi import HTTPException
 from pydantic import BaseModel
-
-from modules.speaker import speaker_mgr
+
 from modules.api import utils as api_utils
 from modules.api.Api import APIManager
+from modules.speaker import speaker_mgr
 
 
 class CreateSpeaker(BaseModel):
modules/api/impl/ssml_api.py  CHANGED
@@ -1,19 +1,14 @@
-from fastapi import
-from fastapi.responses import StreamingResponse
-
+from fastapi import Body, HTTPException
+from fastapi.responses import FileResponse, StreamingResponse
 from pydantic import BaseModel
-from fastapi.responses import FileResponse
-
 
+from modules.api.Api import APIManager
 from modules.api.impl.handler.SSMLHandler import SSMLHandler
 from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
 from modules.api.impl.model.chattts_model import InferConfig
 from modules.api.impl.model.enhancer_model import EnhancerConfig
 
 
-from modules.api.Api import APIManager
-
-
 class SSMLRequest(BaseModel):
     ssml: str
     format: AudioFormat = "mp3"
modules/api/impl/style_api.py  CHANGED
@@ -1,6 +1,6 @@
-from modules.data import styles_mgr
 from modules.api import utils as api_utils
 from modules.api.Api import APIManager
+from modules.data import styles_mgr
 
 
 async def list_styles():
modules/api/impl/tts_api.py
CHANGED
@@ -1,17 +1,13 @@
 from fastapi import Depends, HTTPException, Query
-from fastapi.responses import StreamingResponse
-
+from fastapi.responses import FileResponse, StreamingResponse
 from pydantic import BaseModel
-from fastapi.responses import FileResponse
-
 
+from modules.api import utils as api_utils
+from modules.api.Api import APIManager
 from modules.api.impl.handler.TTSHandler import TTSHandler
 from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
 from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
 from modules.api.impl.model.enhancer_model import EnhancerConfig
-
-from modules.api import utils as api_utils
-from modules.api.Api import APIManager
 from modules.speaker import Speaker
 
 
modules/api/impl/xtts_v2_api.py
CHANGED
@@ -1,19 +1,17 @@
 import io
+import logging
+
+import soundfile as sf
 from fastapi import HTTPException
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
-from modules.api import utils as api_utils
-from modules.api.Api import APIManager
-
-import soundfile as sf
 
 from modules import config
+from modules.api import utils as api_utils
+from modules.api.Api import APIManager
 from modules.normalization import text_normalize
 from modules.speaker import speaker_mgr
 from modules.synthesize_audio import synthesize_audio
-
-import logging
-
 from modules.utils.audio import apply_prosody_to_audio_data
 
 logger = logging.getLogger(__name__)
modules/api/utils.py
CHANGED
@@ -1,14 +1,10 @@
-from pydantic import BaseModel
 from typing import Any, Union
 
-
-from modules.speaker import speaker_mgr
-
-
-from modules.data import styles_mgr
-
+from pydantic import BaseModel
 from pydub import AudioSegment
 
+from modules.data import styles_mgr
+from modules.speaker import speaker_mgr
 from modules.ssml import merge_prompt
 
 
modules/api/worker.py
CHANGED
@@ -1,6 +1,7 @@
 import argparse
 import logging
 import os
+
 import dotenv
 from fastapi import FastAPI
 
@@ -12,6 +13,7 @@ logging.basicConfig(
     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
 )
 
+from modules import config
 from modules.api.api_setup import (
     process_api_args,
     process_model_args,
@@ -20,7 +22,6 @@ from modules.api.api_setup import (
     setup_uvicon_args,
 )
 from modules.api.app_config import app_description, app_title, app_version
-from modules import config
 from modules.utils.torch_opt import configure_torch_optimizations
 
 dotenv.load_dotenv(
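One detail worth flagging in `worker.py`: `from modules import config` now sits with the other imports, above the `dotenv.load_dotenv(...)` call at the bottom of the block. If `modules.config` ever reads `os.environ` at import time, those reads happen before the env file is loaded. A hedged sketch of the ordering that avoids that pitfall (the `ENV_FILE` variable name is illustrative, not from the repo):

```python
import os

import dotenv

# Load the env file *before* importing anything that inspects
# os.environ at import time.
dotenv.load_dotenv(dotenv_path=os.getenv("ENV_FILE", ".env"))

from modules import config  # noqa: E402 -- deliberately after load_dotenv
```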
modules/config.py
CHANGED
@@ -1,9 +1,9 @@
 import sys
 
 import torch
-from modules.utils.JsonObject import JsonObject
 
-from modules.utils import
+from modules.utils import ffmpeg, git
+from modules.utils.JsonObject import JsonObject
 
 # TODO impl RuntimeEnvVars() class
 runtime_env_vars = JsonObject({})
modules/data.py
CHANGED
@@ -1,6 +1,5 @@
 from modules.utils.CsvMgr import BaseManager
 
-
 # speakers_mgr = BaseManager("./data/speakers.csv")
 styles_mgr = BaseManager("./data/styles.csv")
 
modules/denoise.py
CHANGED
@@ -1,15 +1,13 @@
 import os
 from typing import Union
 
+import soundfile as sf
 import torch
 import torchaudio
-from modules.Denoiser.AudioDenoiser import AudioDenoiser
-
-from modules.utils.constants import MODELS_DIR
 
+from modules.Denoiser.AudioDenoiser import AudioDenoiser
 from modules.devices import devices
-
-import soundfile as sf
+from modules.utils.constants import MODELS_DIR
 
 ad: Union[AudioDenoiser, None] = None
 
modules/devices/devices.py
CHANGED
@@ -1,9 +1,10 @@
-
+import logging
 import sys
+from functools import lru_cache
+
 import torch
-from modules import config
 
-import
+from modules import config
 
 logger = logging.getLogger(__name__)
 
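`devices.py` now pulls in `functools.lru_cache`, which suggests device resolution is being memoized so the CUDA/MPS probing runs only once. A sketch of that pattern; the function name `resolve_device` is illustrative, not necessarily the one in the repo:

```python
from functools import lru_cache

import torch


@lru_cache(maxsize=None)
def resolve_device() -> torch.device:
    # Probe the accelerators once; later calls return the cached result.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
```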
modules/devices/mac_devices.py
CHANGED
@@ -1,8 +1,9 @@
-import torch
 import logging
 
+import torch
 import torch.backends
 import torch.backends.mps
+from packaging import version
 
 logger = logging.getLogger(__name__)
 
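`mac_devices.py` adds `from packaging import version`, the usual way to gate MPS usage on the installed torch version. An illustrative check — the exact version floor the repo uses is an assumption here:

```python
import torch
from packaging import version


def mps_usable() -> bool:
    # torch.backends.mps first shipped in torch 1.12; strip any local
    # version suffix such as "+cpu" before parsing.
    if version.parse(torch.__version__.split("+")[0]) < version.parse("1.12.0"):
        return False
    mps = getattr(torch.backends, "mps", None)
    return bool(mps is not None and mps.is_available())
```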
modules/ffmpeg_env.py
CHANGED
@@ -1,6 +1,7 @@
+import logging
 import os
+
 from modules.utils.constants import ROOT_DIR
-import logging
 
 logger = logging.getLogger(__name__)
 
modules/finetune/train_speaker.py
CHANGED
@@ -3,9 +3,10 @@ import torch.nn.functional as F
 import transformers
 
 from modules.finetune.model.encoder import DVAEEncoder, get_encoder_config
-from modules.finetune.utils.output import get_ansi_len, output_iter
+from modules.finetune.utils.output import ansi, get_ansi_len, output_iter
 
 from .utils.dataset import AudioCollator, XzListTar
+from .utils.logger import MetricLogger
 from .utils.model import quantize
 
 IGNORE_TOKEN_ID = transformers.trainer_pt_utils.LabelSmoother.ignore_index
@@ -201,11 +202,13 @@ def train_speaker_embeddings(
 if __name__ == "__main__":
     import argparse
     import os
-    import numpy as np
     import pathlib
 
-
+    import numpy as np
+
     from modules import config
+    from modules.devices import devices
+    from modules.models import load_chat_tts
     from modules.speaker import Speaker
 
     config.runtime_env_vars.no_half = True
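`train_speaker.py` now imports `MetricLogger` from `.utils.logger`; judging by the `SmoothedValue`/`MetricLogger` pair that module exports, it follows the torchvision-style logging helper. A hypothetical usage sketch under that assumption — `model`, `optimizer`, `loader`, and `compute_loss` are stand-ins, and the actual signatures in this repo may differ:

```python
# Assumed torchvision-style API: update() accumulates smoothed values,
# log_every() wraps an iterable and prints every `print_freq` steps.
logger = MetricLogger()
for batch in logger.log_every(loader, print_freq=10, header="train"):
    loss = compute_loss(model, batch)  # stand-in for the real loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    logger.update(loss=loss.item())
```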
modules/finetune/utils/dataset.py
CHANGED
@@ -1,21 +1,21 @@
-import
+import abc
 import functools
-import json
-import tarfile
 import io
+import json
 import logging
-import
+import os
+import tarfile
 import typing
 
 import torch.utils.data
 import torchaudio
-from torchvision.datasets.utils import download_url
 import transformers
 import vocos
+from torchvision.datasets.utils import download_url
 
 from modules.ChatTTS.ChatTTS.utils.infer_utils import (
-    count_invalid_characters,
     apply_character_map,
+    count_invalid_characters,
 )
 
 
modules/finetune/utils/logger.py
CHANGED
@@ -3,15 +3,14 @@
 import statistics
 import time
 from collections import defaultdict, deque
-from tqdm import tqdm as tqdm_class
-
 from typing import Generator, Iterable, TypeVar
-from typing_extensions import Self
 
 import torch
 import torch.distributed as dist
+from tqdm import tqdm as tqdm_class
+from typing_extensions import Self
 
-from .output import ansi,
+from .output import ansi, get_ansi_len, prints
 
 __all__ = ["SmoothedValue", "MetricLogger"]
 
modules/generate_audio.py
CHANGED
@@ -1,18 +1,15 @@
+import gc
+import logging
+from typing import Union
+
 import numpy as np
 import torch
 
-from modules
-from modules.utils.SeedContext import SeedContext
-
-from modules import models, config
-
-import logging
-import gc
-
+from modules import config, models
 from modules.devices import devices
-from
-
+from modules.speaker import Speaker
 from modules.utils.cache import conditional_cache
+from modules.utils.SeedContext import SeedContext
 
 logger = logging.getLogger(__name__)
 
modules/models.py
CHANGED
@@ -1,13 +1,13 @@
+import gc
+import logging
 import threading
+
 import torch
 
 from modules import config
+from modules.ChatTTS import ChatTTS
 from modules.devices import devices
 
-import logging
-import gc
-
-
 logger = logging.getLogger(__name__)
 
 chat_tts = None
modules/normalization.py
CHANGED
@@ -1,9 +1,11 @@
+import re
 from functools import lru_cache
 
 import emojiswitch
 
 from modules import models
-import
+from modules.utils.markdown import markdown_to_text
+from modules.utils.zh_normalization.text_normlization import *
 
 # 是否关闭 unk token 检查
 # NOTE: 单测的时候用于跳过模型加载
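The new imports outline the shape of the normalization pipeline: markdown is flattened to plain text, emoji are spelled out as words, and PaddleSpeech-style `zh_normalization` rules (numbers, dates, units) are applied. A rough sketch assembled from those imports — the ordering and the `emojiswitch.demojize` parameters are assumptions, not necessarily what the repo's `text_normalize` does:

```python
import emojiswitch

from modules.utils.markdown import markdown_to_text
from modules.utils.zh_normalization.text_normlization import TextNormalizer


def normalize_for_tts(text: str) -> str:
    text = markdown_to_text(text)  # strip markdown structure
    # Spell emoji out as words so the TTS model can read them.
    text = emojiswitch.demojize(text, delimiters=("", ""), lang="zh")
    # Normalize numbers, dates and units, sentence by sentence.
    return "".join(TextNormalizer().normalize(text))
```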
modules/prompts/news_oral_prompt.txt
CHANGED
@@ -1,7 +1,7 @@
-
-
+#任务要求
+任务:新闻稿口播化
 
-
+你需要将一个新闻稿改写为口语化的口播文本,以提供给新闻主播在晚间新闻节目中播报
 同时,适当的添加一些 附语言 标签为文本增加多样性
 
 目前可以使用的附语言标签如下:
@@ -10,5 +10,24 @@
 - `[v_break]`: 表示有声停顿,如“嗯”、“啊”等
 - `[lbreak]`: 表示一个长停顿一般表示段落结束
 
-#
+# examples
+## case 1
+- input: `天气预报显示,今天会有小雨,请大家出门时记得带伞。降温的天气也提醒我们要适时添衣保暖`
+- output: `天气预报显示,今天会有小雨,请大家出门时记得带伞[uv_break]。那降温的天气[uv_break]也提醒我们要适时添衣保暖[lbreak]`
+
+## case 2
+- input: `请注意,电梯将在下午两点进行例行维护,预计需要一个小时的时间,请大家在此期间使用楼梯`
+- output: `请注意啊,这个电梯将在下午两点进行[uv_break]例行维护[uv_break],预计需要一个小时的时间[uv_break],请大家在此期间使用楼梯[lbreak]`
+
+## case 3
+- input: `它的任务是简化记者编辑的工作流程。记者写稿时可以用标签来标明关键词、标题或主题。随着时间推移,数据积累到一定程度后,机器编辑就能自动识别这些标签`
+- output: `它的任务呢是简化记者编辑的工作流程[uv_break]。记者写稿时呢可以用标签来标明关键词啊、标题啊或主题[uv_break]。那随着时间推移呢,数据积累到一定程度后[uv_break],机器编辑就能自动识别这些标签[uv_break]`
+
+## case 4
+- input: `有一天,小明问他爸爸:“爸爸,我是不是傻孩子啊?”
+
+爸爸说:“傻孩子,你怎么会是傻孩子呢?”`
+- output: `然后有一天呢,小明问他[uv_break]爸爸[uv_break],爸爸,我是不是傻孩[uv_break]子啊?爸爸说,傻孩[laugh]子啊,你怎么会是傻孩子呢[laugh]?`
+
+# 用户输入
 {{USER_INPUT}}
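The prompt keeps `{{USER_INPUT}}` as its only placeholder, so rendering it is a plain string substitution before the text is handed to an LLM. A minimal sketch (the helper name is illustrative):

```python
from pathlib import Path

PROMPT_PATH = Path("modules/prompts/news_oral_prompt.txt")


def render_news_prompt(user_input: str) -> str:
    # Substitute the news copy into the {{USER_INPUT}} slot.
    template = PROMPT_PATH.read_text(encoding="utf-8")
    return template.replace("{{USER_INPUT}}", user_input)


# Example: feed one of the prompt's own cases through the template.
print(render_news_prompt("天气预报显示,今天会有小雨,请大家出门时记得带伞。"))
```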
modules/refiner.py
CHANGED
@@ -1,10 +1,9 @@
 import numpy as np
 import torch
 
+from modules import config, models
 from modules.utils.SeedContext import SeedContext
 
-from modules import models, config
-
 
 @torch.inference_mode()
 def refine_text(
modules/repos_static/resemble_enhance/common.py
CHANGED
@@ -42,7 +42,9 @@ class Normalizer(nn.Module):
             self.running_var_unsafe = x.var()
         else:
             self.running_mean_unsafe = self._ema(self.running_mean_unsafe, x.mean())
-            self.running_var_unsafe = self._ema(self.running_var_unsafe, (x - self.running_mean).pow(2).mean())
+            self.running_var_unsafe = self._ema(
+                self.running_var_unsafe, (x - self.running_mean).pow(2).mean()
+            )
 
     def forward(self, x: Tensor, update=True):
         if self.training and update:
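The re-wrapped line in `Normalizer` is an exponential moving average of the biased variance estimate `E[(x - μ)²]`, taken around the running mean rather than the batch mean. The update in isolation (the decay constant here is an assumption; the repo's `_ema` keeps its own):

```python
import torch


def ema(running: torch.Tensor, new: torch.Tensor, decay: float = 0.999) -> torch.Tensor:
    # running <- decay * running + (1 - decay) * new
    return decay * running + (1.0 - decay) * new


x = torch.randn(4096)
running_mean = ema(torch.zeros(()), x.mean())
running_var = ema(torch.ones(()), (x - running_mean).pow(2).mean())
```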
modules/repos_static/resemble_enhance/data/dataset.py
CHANGED
@@ -44,7 +44,9 @@ def praat_augment(wav, sr):
     sound = parselmouth.Sound(wav, sr)
     formant_shift_ratio = random.uniform(1.1, 1.5)
     pitch_range_factor = random.uniform(0.5, 2.0)
-    sound = parselmouth.praat.call(sound, "Change gender", 75, 600, formant_shift_ratio, 0, pitch_range_factor, 1.0)
+    sound = parselmouth.praat.call(
+        sound, "Change gender", 75, 600, formant_shift_ratio, 0, pitch_range_factor, 1.0
+    )
     wav = np.array(sound.values)[0].astype(np.float32)
     return wav
 
@@ -73,7 +75,9 @@ class Dataset(DatasetBase):
         if len(self.bg_paths) == 0:
             raise ValueError(f"No background audio files found in {hp.bg_dir}")
 
-        logger.info(f"Found {len(self.fg_paths)} foreground files and {len(self.bg_paths)} background files")
+        logger.info(
+            f"Found {len(self.fg_paths)} foreground files and {len(self.bg_paths)} background files"
+        )
 
         self.training = training
         self.max_retries = max_retries
@@ -121,7 +125,9 @@ class Dataset(DatasetBase):
         fg_path = self.fg_paths[index]
 
         if self.training and random.random() < self.silent_fg_prob:
-            fg_wav = np.zeros(int(self.hp.training_seconds * self.hp.wav_rate), dtype=np.float32)
+            fg_wav = np.zeros(
+                int(self.hp.training_seconds * self.hp.wav_rate), dtype=np.float32
+            )
         else:
             fg_wav = self._load_wav(fg_path)
             if random.random() < self.hp.praat_augment_prob and self.training:
@@ -132,14 +138,20 @@ class Dataset(DatasetBase):
             fg_dwav = None
             bg_dwav = None
         else:
-            fg_dwav = _normalize(self.distorter(fg_wav, self.hp.wav_rate)).astype(np.float32)
+            fg_dwav = _normalize(self.distorter(fg_wav, self.hp.wav_rate)).astype(
+                np.float32
+            )
             if self.training:
                 bg_path = random.choice(self.bg_paths)
             else:
                 # Deterministic for validation
                 bg_path = self.bg_paths[index % len(self.bg_paths)]
-            bg_wav = self._load_wav(bg_path, length=len(fg_wav), random_crop=self.training)
-            bg_dwav = _normalize(self.distorter(bg_wav, self.hp.wav_rate)).astype(np.float32)
+            bg_wav = self._load_wav(
+                bg_path, length=len(fg_wav), random_crop=self.training
+            )
+            bg_dwav = _normalize(self.distorter(bg_wav, self.hp.wav_rate)).astype(
+                np.float32
+            )
 
         return dict(
             fg_wav=fg_wav,
@@ -154,7 +166,9 @@ class Dataset(DatasetBase):
                 return self._getitem_unsafe(index)
             except Exception as e:
                 if i == self.max_retries - 1:
-                    raise RuntimeError(f"Failed to load {self.fg_paths[index]} after {self.max_retries} retries") from e
+                    raise RuntimeError(
+                        f"Failed to load {self.fg_paths[index]} after {self.max_retries} retries"
+                    ) from e
                 logger.debug(f"Error loading {self.fg_paths[index]}: {e}, skipping")
                 index = np.random.randint(0, len(self))
 
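The last hunk completes the dataset's retry loop: `__getitem__` retries with a fresh random index when a sample fails to load and only raises after `max_retries` attempts, chaining the original exception. The same pattern in isolation, as a generic sketch rather than the repo's exact class:

```python
import random


class RetryingDataset:
    """Skip unreadable samples by retrying at a random index."""

    def __init__(self, loaders, max_retries: int = 3):
        self.loaders = loaders          # callables that each load one sample
        self.max_retries = max_retries

    def __len__(self) -> int:
        return len(self.loaders)

    def __getitem__(self, index: int):
        for i in range(self.max_retries):
            try:
                return self.loaders[index]()
            except Exception as e:
                if i == self.max_retries - 1:
                    raise RuntimeError(
                        f"failed to load sample {index} after {self.max_retries} retries"
                    ) from e
                index = random.randrange(len(self))  # try a different sample
```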
modules/repos_static/resemble_enhance/data/distorter/base.py
CHANGED
@@ -2,8 +2,8 @@ import itertools
 import os
 import random
 import time
-from typing import Union
 import warnings
+from typing import Union
 
 import numpy as np
 