Spaces: Running on Zero

Commit: Start deployment (开始部署)

app.py CHANGED
@@ -30,6 +30,11 @@ except ImportError:
     print("torch_npu is not available. if you want to use npu, please install it.")
 
 
+import time
+import datetime
+import torch
+from common_utils.utils4infer import get_feat_from_wav_path, token_list2wav
+
 
 
 from huggingface_hub import hf_hub_download

@@ -51,7 +56,7 @@ cosyvoice_model_path="./CosyVoice-300M-25Hz"
 device = torch.device("cuda")
 print("Loading model A...")
 model_a, tokenizer_a = load_model_and_tokenizer(CHECKPOINT_PATH_A, CONFIG_PATH)
-model_a…
+model_a
 
 print("\nLoading model B...")
 if CHECKPOINT_PATH_B is not None:

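The new line 59 is now just the bare expression `model_a`, a no-op at run time; the tail of the old line is cut off by the diff viewer (marked `…` above), so what was removed is unrecoverable here. If the intent was to keep eval-mode setup on CPU while deferring CUDA to the GPU-scoped function, a one-line sketch of that (an assumption, not what the commit contains) would be:

model_a.eval()  # assumed intent: eval mode on CPU; .cuda() happens later inside the @spaces.GPU call
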
@@ -61,15 +66,15 @@ else:
     model_b, tokenizer_b = None, None
 
 loaded_models = {
-    NAME_A: {"model": …
+    NAME_A: {"model": model_b, "tokenizer": tokenizer_b},
     NAME_B: {"model": model_b, "tokenizer": tokenizer_b},
 } if model_b is not None else {
-    NAME_A: {"model": …
+    NAME_A: {"model": model_b, "tokenizer": tokenizer_b},
 }
 print("\nAll models loaded.")
 
-cosyvoice = CosyVoice(cosyvoice_model_path)
-cosyvoice.eval().cuda()
+# cosyvoice = CosyVoice(cosyvoice_model_path)
+# cosyvoice.eval().cuda()
 
 # Convert the image to Base64
 with open("./tts/assert/实验室.png", "rb") as image_file:

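Note that in both branches of the new `loaded_models` dict, NAME_A now points at `model_b`/`tokenizer_b`, so model A is never served; this looks accidental. A hedged sketch of the presumably intended mapping, reusing the names already defined in app.py:

# Presumed intent (an assumption, not what the commit contains):
# key each display name to its own model/tokenizer pair.
loaded_models = {
    NAME_A: {"model": model_a, "tokenizer": tokenizer_a},
    NAME_B: {"model": model_b, "tokenizer": tokenizer_b},
} if model_b is not None else {
    NAME_A: {"model": model_a, "tokenizer": tokenizer_a},
}
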
@@ -114,11 +119,6 @@ for item in prompt_audio_choices:
 
 logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s')
 
-import time
-import datetime
-import torch
-from common_utils.utils4infer import get_feat_from_wav_path, token_list2wav
-
 
 @spaces.GPU
 def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice, prompt_speech_data):

@@ -135,13 +135,14 @@ def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice,
     # Common initialization: model device setup
     start_time = time.time()
     res_text = None
+    model_a.eval().cuda()
 
     try:
         # 1. Handle the TTS task
         if input_prompt.endswith("_TTS"):
             text_for_tts = input_prompt.replace("_TTS", "")
             # T2S inference logic
-            res_tensor = …
+            res_tensor = model_a.generate_tts(device=device, text=text_for_tts)[0]
             res_token_list = res_tensor.tolist()
             res_text = res_token_list[:-1]
             print(f"T2S inference time: {time.time() - start_time:.2f} s")

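Adding `model_a.eval().cuda()` inside the function matches how ZeroGPU Spaces work: a GPU is attached only while a `@spaces.GPU`-decorated call runs, so this commit defers device movement to call time (and likewise pulls CosyVoice construction out of import scope). A minimal, self-contained sketch of that pattern, with a stand-in model rather than the app's (assumes the Hugging Face `spaces` package):

import spaces
import torch

model = torch.nn.Linear(4, 4)  # stand-in for model_a; stays on CPU at import time

@spaces.GPU  # ZeroGPU attaches a GPU only for the duration of this call
def infer(x: torch.Tensor) -> torch.Tensor:
    model.eval().cuda()  # safe here: CUDA is available inside the decorated call
    with torch.no_grad():
        return model(x.cuda()).cpu()
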
@@ -153,7 +154,7 @@ def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice,
             feat, feat_lens = get_feat_from_wav_path(input_wav_path)
             print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
             if is_npu: torch_npu.npu.synchronize()
-            res_text = …
+            res_text = model_a.generate(
                 wavs=feat,
                 wavs_len=feat_lens,
                 prompt=prompt,

@@ -168,7 +169,7 @@ def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice,
             # T2T inference logic
             print(f'Starting t2t inference, question_txt: {question_txt}')
             if is_npu: torch_npu.npu.synchronize()
-            res_text = …
+            res_text = model_a.generate_text2text(
                 device=device,
                 text=question_txt
             )[0]

@@ -183,7 +184,7 @@ def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice,
             feat, feat_lens = get_feat_from_wav_path(input_wav_path)
             print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
             if is_npu: torch_npu.npu.synchronize()
-            output_text, text_res, speech_res = …
+            output_text, text_res, speech_res = model_a.generate_s2s_no_stream_with_repetition_penalty(
                 wavs=feat,
                 wavs_len=feat_lens,
             )

@@ -197,7 +198,7 @@ def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice,
             feat, feat_lens = get_feat_from_wav_path(input_wav_path)
             print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
             if is_npu: torch_npu.npu.synchronize()
-            output_text, text_res, speech_res = …
+            output_text, text_res, speech_res = model_a.generate_s2s_no_stream_think_with_repetition_penalty(
                 wavs=feat,
                 wavs_len=feat_lens,
             )

@@ -211,7 +212,7 @@ def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice,
             feat, feat_lens = get_feat_from_wav_path(input_wav_path)
             print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
             if is_npu: torch_npu.npu.synchronize()
-            res_text = …
+            res_text = model_a.generate4chat(
                 wavs=feat,
                 wavs_len=feat_lens,
                 cache_implementation="static"

@@ -225,7 +226,7 @@ def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice,
             feat, feat_lens = get_feat_from_wav_path(input_wav_path)
             print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
             if is_npu: torch_npu.npu.synchronize()
-            res_text = …
+            res_text = model_a.generate4chat_think(
                 wavs=feat,
                 wavs_len=feat_lens,
                 cache_implementation="static"

@@ -239,7 +240,7 @@ def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice,
             feat, feat_lens = get_feat_from_wav_path(input_wav_path)
             print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
             if is_npu: torch_npu.npu.synchronize()
-            res_text = …
+            res_text = model_a.generate(
                 wavs=feat,
                 wavs_len=feat_lens,
                 prompt=input_prompt,

@@ -260,20 +261,20 @@ def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice,
     wav_path_output = input_wav_path
     if task_choice == "TTS任务" or "empathetic_s2s_dialogue" in task_choice:
         if isinstance(output_res, list): # TTS case
-            cosyvoice.eval()
-            time_str = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
-            wav_path = f"./tmp/{time_str}.wav"
-            wav_path_output = token_list2wav(output_res, prompt_speech_data, wav_path, cosyvoice)
+            # cosyvoice.eval()
+            # time_str = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
+            # wav_path = f"./tmp/{time_str}.wav"
+            # wav_path_output = token_list2wav(output_res, prompt_speech_data, wav_path, cosyvoice)
             # wav_path_output = get_wav_from_token_list(output_res, prompt_speech_data)
             output_res = "Generated tokens: " + str(output_res)
         elif isinstance(output_res, str) and "|" in output_res: # S2S case
             try:
                 text_res, token_list_str = output_res.split("|")
                 token_list = json.loads(token_list_str)
-                cosyvoice.eval()
+                # cosyvoice.eval()
                 time_str = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
                 wav_path = f"./tmp/{time_str}.wav"
-                wav_path_output = token_list2wav(token_list, prompt_speech_data, wav_path, cosyvoice)
+                # wav_path_output = token_list2wav(token_list, prompt_speech_data, wav_path, cosyvoice)
                 # wav_path_output = get_wav_from_token_list(token_list, prompt_speech_data)
                 output_res = text_res
             except (ValueError, json.JSONDecodeError) as e:
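
With the CosyVoice lines commented out, the TTS branch now returns only the raw token list and the S2S branch leaves `wav_path_output` as the input path, so no audio is synthesized. If vocoding is re-enabled later, a lazy initializer would keep CosyVoice construction out of import scope; a sketch under the assumption that `CosyVoice` and `cosyvoice_model_path` are as defined in this file (this refactor is not part of the commit):

_cosyvoice = None

def get_cosyvoice():
    # Build the vocoder on first use, from inside a @spaces.GPU call,
    # so no CUDA work happens at import time.
    global _cosyvoice
    if _cosyvoice is None:
        _cosyvoice = CosyVoice(cosyvoice_model_path)
        _cosyvoice.eval().cuda()
    return _cosyvoice

Calling `token_list2wav(token_list, prompt_speech_data, wav_path, get_cosyvoice())` would then restore the commented-out synthesis path.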