Spaces:
Runtime error
Runtime error
Commit
·
9b15f17
1
Parent(s):
8e35048
Create server.py
Browse files
server.py
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from onnx_modules.V230_OnnxInference import OnnxInferenceSession
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
from scipy.io.wavfile import write
|
| 5 |
+
from text import cleaned_text_to_sequence, get_bert
|
| 6 |
+
from text.cleaner import clean_text
|
| 7 |
+
import utils
|
| 8 |
+
import commons
|
| 9 |
+
import uuid
|
| 10 |
+
from flask import Flask, request, jsonify, render_template_string
|
| 11 |
+
from flask_cors import CORS
|
| 12 |
+
import gradio as gr
|
| 13 |
+
import os
|
| 14 |
+
from threading import Thread
|
| 15 |
+
|
| 16 |
+
# Model hyperparameters (sample rate, speaker->id map, etc.) loaded from the
# exported ONNX model's config file.
hps = utils.get_hparams_from_file('onnx/BangDreamApi.json')
# Inference runs on CPU only (matches the CPUExecutionProvider below).
device = 'cpu'

# Speaker display names grouped by band / school, for front-end pickers.
# NOTE(review): these values double as speaker ids looked up in hps.spk2id —
# do not translate or reformat them.
BandList = {
    "PoppinParty":["香澄","有咲","たえ","りみ","沙綾"],
    "Afterglow":["蘭","モカ","ひまり","巴","つぐみ"],
    "HelloHappyWorld":["こころ","美咲","薫","花音","はぐみ"],
    "PastelPalettes":["彩","日菜","千聖","イヴ","麻弥"],
    "Roselia":["友希那","紗夜","リサ","燐子","あこ"],
    "RaiseASuilen":["レイヤ","ロック","ますき","チュチュ","パレオ"],
    "Morfonica":["ましろ","瑠唯","つくし","七深","透子"],
    "MyGo":["燈","愛音","そよ","立希","楽奈"],
    "AveMujica":["祥子","睦","海鈴","にゃむ","初華"],
    "圣翔音乐学园":["華戀","光","香子","雙葉","真晝","純那","克洛迪娜","真矢","奈奈"],
    "凛明馆女子学校":["珠緒","壘","文","悠悠子","一愛"],
    "弗隆提亚艺术学校":["艾露","艾露露","菈樂菲","司","靜羽"],
    "西克菲尔特音乐学院":["晶","未知留","八千代","栞","美帆"]
}

# Shared ONNX inference session wrapping the six exported sub-graphs of the
# model (text encoder, speaker embedding, duration predictors, flow, decoder).
Session = OnnxInferenceSession(
    {
        "enc" : "onnx/BangDreamApi/BangDreamApi_enc_p.onnx",
        "emb_g" : "onnx/BangDreamApi/BangDreamApi_emb.onnx",
        "dp" : "onnx/BangDreamApi/BangDreamApi_dp.onnx",
        "sdp" : "onnx/BangDreamApi/BangDreamApi_sdp.onnx",
        "flow" : "onnx/BangDreamApi/BangDreamApi_flow.onnx",
        "dec" : "onnx/BangDreamApi/BangDreamApi_dec.onnx"
    },
    Providers = ["CPUExecutionProvider"]
)
|
| 46 |
+
|
| 47 |
+
def get_text(text, language_str, hps, device, style_text=None, style_weight=0.7):
    """Convert raw text into the tensors the synthesis model consumes.

    Args:
        text: Raw input text.
        language_str: One of "ZH", "JP" or "EN".
        hps: Hyperparameter object (kept for interface compatibility; not
            read inside this function).
        device: Torch device string passed to the BERT feature extractor.
        style_text: Optional style-prompt text; "" is normalized to None.
        style_weight: Blend weight for the style prompt.

    Returns:
        Tuple (bert, ja_bert, en_bert, phone, tone, language) where the
        three BERT tensors have shape (1024, seq_len) and the last three
        are LongTensors of length seq_len.

    Raises:
        ValueError: If language_str is not ZH, JP or EN.
    """
    style_text = None if style_text == "" else style_text
    norm_text, phone, tone, word2ph = clean_text(text, language_str)
    phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)

    # Interleave a blank (0) token between every symbol; word2ph is scaled
    # to match the doubled sequence length.  (The original wrapped this in
    # a dead `if True:` block, removed here.)
    phone = commons.intersperse(phone, 0)
    tone = commons.intersperse(tone, 0)
    language = commons.intersperse(language, 0)
    word2ph = [2 * w for w in word2ph]
    word2ph[0] += 1

    bert_ori = get_bert(
        norm_text, word2ph, language_str, device, style_text, style_weight
    )
    del word2ph
    assert bert_ori.shape[-1] == len(phone), phone

    # Only the tensor matching language_str carries real BERT features; the
    # other two slots are filled with random placeholders of the same shape.
    if language_str == "ZH":
        bert = bert_ori
        ja_bert = torch.randn(1024, len(phone))
        en_bert = torch.randn(1024, len(phone))
    elif language_str == "JP":
        bert = torch.randn(1024, len(phone))
        ja_bert = bert_ori
        en_bert = torch.randn(1024, len(phone))
    elif language_str == "EN":
        bert = torch.randn(1024, len(phone))
        ja_bert = torch.randn(1024, len(phone))
        en_bert = bert_ori
    else:
        raise ValueError("language_str should be ZH, JP or EN")

    assert bert.shape[-1] == len(
        phone
    ), f"Bert seq len {bert.shape[-1]} != {len(phone)}"

    phone = torch.LongTensor(phone)
    tone = torch.LongTensor(tone)
    language = torch.LongTensor(language)
    return bert, ja_bert, en_bert, phone, tone, language
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def infer(
    text,
    sid,
    style_text=None,
    style_weight=0.7,
    sdp_ratio=0.5,
    noise_scale=0.6,
    noise_scale_w=0.667,
    length_scale=1,
    unique_filename='temp.wav'
):
    """Synthesize `text` as speaker `sid` through the shared ONNX session.

    Writes the audio to `unique_filename` at 44.1 kHz and returns the
    `(sample_rate, int16_audio)` tuple shape Gradio audio widgets expect.
    """
    # Pick the text front-end by script: any kana present -> Japanese,
    # otherwise treat the input as Chinese.
    detected_lang = 'JP' if is_japanese(text) else 'ZH'
    bert, ja_bert, en_bert, phones, tone, language = get_text(
        text,
        detected_lang,
        hps,
        device,
        style_text=style_text,
        style_weight=style_weight,
    )
    with torch.no_grad():
        x_tst = phones.unsqueeze(0).to(device).numpy()
        # The session receives zeroed tone/language tracks; the get_text
        # outputs for those are deliberately discarded here.
        tone = np.zeros_like(x_tst)
        language = np.zeros_like(x_tst)
        bert = bert.to(device).transpose(0, 1).numpy()
        ja_bert = ja_bert.to(device).transpose(0, 1).numpy()
        en_bert = en_bert.to(device).transpose(0, 1).numpy()
        del phones
        sid = np.array([hps.spk2id[sid]])
        audio = Session(
            x_tst,
            tone,
            language,
            bert,
            ja_bert,
            en_bert,
            sid,
            seed=114514,
            seq_noise_scale=noise_scale_w,
            sdp_noise_scale=noise_scale,
            length_scale=length_scale,
            sdp_ratio=sdp_ratio,
        )
        # Free the large intermediates before encoding the result.
        del x_tst, tone, language, bert, ja_bert, en_bert, sid
        write(unique_filename, 44100, audio)
        return (44100, gr.processing_utils.convert_to_16_bit_wav(audio))
|
| 136 |
+
|
| 137 |
+
def is_japanese(string):
    """Return True if any character of `string` lies strictly between
    U+3040 and U+30FF (roughly the hiragana/katakana blocks)."""
    return any(0x3040 < ord(ch) < 0x30FF for ch in string)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
# Flask application serving the HTTP TTS API defined below.
Flaskapp = Flask(__name__)
# Allow cross-origin requests so browser front-ends hosted elsewhere can
# call this API directly.
CORS(Flaskapp)
|
| 146 |
+
@Flaskapp.route('/')

# Single GET endpoint: with `speaker` and `text` query params it returns
# synthesized wav bytes; without them it serves an HTML landing page that
# embeds the hosted demo UI.
def tts():
    """Flask handler for text-to-speech requests.

    Query parameters: speaker (required), text (required), sdp_ratio,
    noise_scale, noise_scale_w, length_scale, style_weight, style_text,
    is_chat.  Responds with raw wav bytes plus a 'Text' header carrying
    the temp file name.
    """
    # Module-level caches set under __main__; last_model is only touched by
    # the commented-out model-switching code below.
    global last_text, last_model
    speaker = request.args.get('speaker')
    sdp_ratio = float(request.args.get('sdp_ratio', 0.2))
    noise_scale = float(request.args.get('noise_scale', 0.6))
    noise_scale_w = float(request.args.get('noise_scale_w', 0.8))
    length_scale = float(request.args.get('length_scale', 1))
    style_weight = float(request.args.get('style_weight', 0.7))
    style_text = request.args.get('style_text', 'happy')
    text = request.args.get('text')
    is_chat = request.args.get('is_chat', 'false').lower() == 'true'
    #model = request.args.get('model',modelPaths[-1])

    # Missing required params: show the embedded demo page instead of an error.
    if not speaker or not text:
        return render_template_string("""
    <!DOCTYPE html>
    <html>
    <head>
      <title>TTS API Documentation</title>
    </head>
    <body>
      <iframe src="https://mahiruoshi-bangdream-bert-vits2.hf.space" style="width:100%; height:100vh; border:none;"></iframe>
    </body>
    </html>
    """)
    '''
    if model != last_model:
        unique_filename = loadmodel(model)
        last_model = model
    '''
    # Chat mode: if the same text is requested twice in a row, return one
    # second of silence instead of re-synthesizing it.
    if is_chat and text == last_text:
        # Generate 1 second of silence and return
        unique_filename = 'blank.wav'
        silence = np.zeros(44100, dtype=np.int16)
        write(unique_filename , 44100, silence)
    else:
        last_text = text
        # Unique file name per request so concurrent requests do not
        # clobber each other's output.
        unique_filename = f"temp{uuid.uuid4()}.wav"
        infer(text, sdp_ratio=sdp_ratio, noise_scale=noise_scale, noise_scale_w=noise_scale_w, length_scale=length_scale,sid = speaker, style_text=style_text, style_weight=style_weight,unique_filename=unique_filename)
    # Read the wav back, remove the temp file, and stream the bytes.
    with open(unique_filename ,'rb') as bit:
        wav_bytes = bit.read()
    os.remove(unique_filename)
    headers = {
        'Content-Type': 'audio/wav',
        'Text': unique_filename .encode('utf-8')}
    return wav_bytes, 200, headers
|
| 194 |
+
|
| 195 |
+
if __name__ == "__main__":
    # Speaker-name -> id map from the hparams file; `speakers` is not read
    # anywhere below but lists the valid values for the `speaker` param.
    speaker_ids = hps.spk2id
    speakers = list(speaker_ids.keys())
    # Cache of the last synthesized text, consumed by tts() in chat mode.
    last_text = ""
    # NOTE(review): debug=True enables the Werkzeug reloader/debugger —
    # disable this for any production deployment.
    Flaskapp.run(host="0.0.0.0", port=5000,debug=True)
|