Spaces:
Running
on
Zero
Running
on
Zero
开始部署
Browse files
app.py
CHANGED
|
@@ -4,6 +4,8 @@ import datetime
|
|
| 4 |
import json
|
| 5 |
import logging
|
| 6 |
import os
|
|
|
|
|
|
|
| 7 |
import spaces
|
| 8 |
|
| 9 |
import gradio as gr
|
|
@@ -12,6 +14,7 @@ import time
|
|
| 12 |
import traceback
|
| 13 |
|
| 14 |
import torch
|
|
|
|
| 15 |
|
| 16 |
from common_utils.utils4infer import get_feat_from_wav_path, load_model_and_tokenizer, token_list2wav
|
| 17 |
|
|
@@ -53,7 +56,7 @@ cosyvoice_model_path="./CosyVoice-300M-25Hz"
|
|
| 53 |
|
| 54 |
|
| 55 |
|
| 56 |
-
|
| 57 |
print("开始加载模型 A...")
|
| 58 |
model_a, tokenizer_a = load_model_and_tokenizer(CHECKPOINT_PATH_A, CONFIG_PATH)
|
| 59 |
model_a
|
|
@@ -131,6 +134,29 @@ def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice,
|
|
| 131 |
if input_wav_path is None and not input_prompt.endswith(("_TTS", "_T2T")):
|
| 132 |
print("音频信息未输入,且不是T2S或T2T任务")
|
| 133 |
return "错误:需要音频输入"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
# 通用初始化:模型设备设置
|
| 136 |
start_time = time.time()
|
|
@@ -142,7 +168,7 @@ def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice,
|
|
| 142 |
if input_prompt.endswith("_TTS"):
|
| 143 |
text_for_tts = input_prompt.replace("_TTS", "")
|
| 144 |
# T2S推理逻辑
|
| 145 |
-
res_tensor = model_a.generate_tts(device=device, text=text_for_tts)[0]
|
| 146 |
res_token_list = res_tensor.tolist()
|
| 147 |
res_text = res_token_list[:-1]
|
| 148 |
print(f"T2S 推理消耗时间: {time.time() - start_time:.2f} 秒")
|
|
@@ -151,16 +177,14 @@ def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice,
|
|
| 151 |
elif input_prompt.endswith("_self_prompt"):
|
| 152 |
prompt = input_prompt.replace("_self_prompt", "")
|
| 153 |
# S2T推理逻辑
|
| 154 |
-
feat, feat_lens = get_feat_from_wav_path(input_wav_path)
|
| 155 |
-
|
| 156 |
-
if is_npu: torch_npu.npu.synchronize()
|
| 157 |
res_text = model_a.generate(
|
| 158 |
wavs=feat,
|
| 159 |
wavs_len=feat_lens,
|
| 160 |
prompt=prompt,
|
| 161 |
cache_implementation="static"
|
| 162 |
)[0]
|
| 163 |
-
if is_npu: torch_npu.npu.synchronize()
|
| 164 |
print(f"S2T 推理消耗时间: {time.time() - start_time:.2f} 秒")
|
| 165 |
|
| 166 |
# 3. 处理T2T任务
|
|
@@ -170,7 +194,7 @@ def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice,
|
|
| 170 |
print(f'开始t2t推理, question_txt: {question_txt}')
|
| 171 |
if is_npu: torch_npu.npu.synchronize()
|
| 172 |
res_text = model_a.generate_text2text(
|
| 173 |
-
device=device,
|
| 174 |
text=question_txt
|
| 175 |
)[0]
|
| 176 |
if is_npu: torch_npu.npu.synchronize()
|
|
@@ -181,7 +205,7 @@ def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice,
|
|
| 181 |
"请推断对这段语音回答时的情感,标注情感类型,撰写流畅自然的聊天回复,并生成情感语音token。",
|
| 182 |
"s2s_no_think"]:
|
| 183 |
# S2S推理逻辑
|
| 184 |
-
feat, feat_lens = get_feat_from_wav_path(input_wav_path)
|
| 185 |
print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
|
| 186 |
if is_npu: torch_npu.npu.synchronize()
|
| 187 |
output_text, text_res, speech_res = model_a.generate_s2s_no_stream_with_repetition_penalty(
|
|
@@ -195,7 +219,7 @@ def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice,
|
|
| 195 |
# 5. 处理S2S有思考任务
|
| 196 |
elif input_prompt == "THINK":
|
| 197 |
# S2S带思考推理逻辑
|
| 198 |
-
feat, feat_lens = get_feat_from_wav_path(input_wav_path)
|
| 199 |
print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
|
| 200 |
if is_npu: torch_npu.npu.synchronize()
|
| 201 |
output_text, text_res, speech_res = model_a.generate_s2s_no_stream_think_with_repetition_penalty(
|
|
@@ -209,7 +233,7 @@ def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice,
|
|
| 209 |
# 6. 处理S2T4Chat无思考任务
|
| 210 |
elif input_prompt == "s2t_no_think":
|
| 211 |
# S2T4Chat推理逻辑
|
| 212 |
-
feat, feat_lens = get_feat_from_wav_path(input_wav_path)
|
| 213 |
print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
|
| 214 |
if is_npu: torch_npu.npu.synchronize()
|
| 215 |
res_text = model_a.generate4chat(
|
|
@@ -223,7 +247,7 @@ def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice,
|
|
| 223 |
# 7. 处理S2T4Chat有思考任务
|
| 224 |
elif input_prompt == "s2t_think":
|
| 225 |
# S2T4Chat带思考推理逻辑
|
| 226 |
-
feat, feat_lens = get_feat_from_wav_path(input_wav_path)
|
| 227 |
print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
|
| 228 |
if is_npu: torch_npu.npu.synchronize()
|
| 229 |
res_text = model_a.generate4chat_think(
|
|
@@ -237,7 +261,7 @@ def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice,
|
|
| 237 |
# 8. 处理默认S2T任务
|
| 238 |
else:
|
| 239 |
# 默认S2T推理逻辑
|
| 240 |
-
feat, feat_lens = get_feat_from_wav_path(input_wav_path)
|
| 241 |
print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
|
| 242 |
if is_npu: torch_npu.npu.synchronize()
|
| 243 |
res_text = model_a.generate(
|
|
|
|
| 4 |
import json
|
| 5 |
import logging
|
| 6 |
import os
|
| 7 |
+
|
| 8 |
+
import librosa
|
| 9 |
import spaces
|
| 10 |
|
| 11 |
import gradio as gr
|
|
|
|
| 14 |
import traceback
|
| 15 |
|
| 16 |
import torch
|
| 17 |
+
import torchaudio
|
| 18 |
|
| 19 |
from common_utils.utils4infer import get_feat_from_wav_path, load_model_and_tokenizer, token_list2wav
|
| 20 |
|
|
|
|
| 56 |
|
| 57 |
|
| 58 |
|
| 59 |
+
|
| 60 |
print("开始加载模型 A...")
|
| 61 |
model_a, tokenizer_a = load_model_and_tokenizer(CHECKPOINT_PATH_A, CONFIG_PATH)
|
| 62 |
model_a
|
|
|
|
| 134 |
if input_wav_path is None and not input_prompt.endswith(("_TTS", "_T2T")):
|
| 135 |
print("音频信息未输入,且不是T2S或T2T任务")
|
| 136 |
return "错误:需要音频输入"
|
| 137 |
+
if input_wav_path is not None:
|
| 138 |
+
waveform, sample_rate = torchaudio.load(input_wav_path)
|
| 139 |
+
if waveform.shape[0] > 1:
|
| 140 |
+
waveform = torch.mean(waveform, dim=0, keepdim=True)
|
| 141 |
+
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
|
| 142 |
+
waveform = resampler(waveform)
|
| 143 |
+
waveform = waveform.squeeze(0)
|
| 144 |
+
window = torch.hann_window(400)
|
| 145 |
+
stft = torch.stft(waveform, 400, 160, window=window, return_complex=True)
|
| 146 |
+
magnitudes = stft[..., :-1].abs() ** 2
|
| 147 |
+
filters = torch.from_numpy(librosa.filters.mel(sr=sample_rate, n_fft=400, n_mels=80))
|
| 148 |
+
mel_spec = filters @ magnitudes
|
| 149 |
+
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
| 150 |
+
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
| 151 |
+
log_spec = (log_spec + 4.0) / 4.0
|
| 152 |
+
feat = log_spec.transpose(0, 1)
|
| 153 |
+
feat_lens = torch.tensor([feat.shape[0]], dtype=torch.int64).cuda()
|
| 154 |
+
feat = feat.unsqueeze(0).cuda()
|
| 155 |
+
feat = feat.to(torch.bfloat16)
|
| 156 |
+
print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
|
| 157 |
+
else:
|
| 158 |
+
feat = None
|
| 159 |
+
feat_lens = None
|
| 160 |
|
| 161 |
# 通用初始化:模型设备设置
|
| 162 |
start_time = time.time()
|
|
|
|
| 168 |
if input_prompt.endswith("_TTS"):
|
| 169 |
text_for_tts = input_prompt.replace("_TTS", "")
|
| 170 |
# T2S推理逻辑
|
| 171 |
+
res_tensor = model_a.generate_tts(device=torch.device("cuda"), text=text_for_tts)[0]
|
| 172 |
res_token_list = res_tensor.tolist()
|
| 173 |
res_text = res_token_list[:-1]
|
| 174 |
print(f"T2S 推理消耗时间: {time.time() - start_time:.2f} 秒")
|
|
|
|
| 177 |
elif input_prompt.endswith("_self_prompt"):
|
| 178 |
prompt = input_prompt.replace("_self_prompt", "")
|
| 179 |
# S2T推理逻辑
|
| 180 |
+
# feat, feat_lens = get_feat_from_wav_path(input_wav_path)
|
| 181 |
+
# waveform, sample_rate = do_resample(input_wav_path)
|
|
|
|
| 182 |
res_text = model_a.generate(
|
| 183 |
wavs=feat,
|
| 184 |
wavs_len=feat_lens,
|
| 185 |
prompt=prompt,
|
| 186 |
cache_implementation="static"
|
| 187 |
)[0]
|
|
|
|
| 188 |
print(f"S2T 推理消耗时间: {time.time() - start_time:.2f} 秒")
|
| 189 |
|
| 190 |
# 3. 处理T2T任务
|
|
|
|
| 194 |
print(f'开始t2t推理, question_txt: {question_txt}')
|
| 195 |
if is_npu: torch_npu.npu.synchronize()
|
| 196 |
res_text = model_a.generate_text2text(
|
| 197 |
+
device=torch.device("cuda"),
|
| 198 |
text=question_txt
|
| 199 |
)[0]
|
| 200 |
if is_npu: torch_npu.npu.synchronize()
|
|
|
|
| 205 |
"请推断对这段语音回答时的情感,标注情感类型,撰写流畅自然的聊天回复,并生成情感语音token。",
|
| 206 |
"s2s_no_think"]:
|
| 207 |
# S2S推理逻辑
|
| 208 |
+
# feat, feat_lens = get_feat_from_wav_path(input_wav_path)
|
| 209 |
print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
|
| 210 |
if is_npu: torch_npu.npu.synchronize()
|
| 211 |
output_text, text_res, speech_res = model_a.generate_s2s_no_stream_with_repetition_penalty(
|
|
|
|
| 219 |
# 5. 处理S2S有思考任务
|
| 220 |
elif input_prompt == "THINK":
|
| 221 |
# S2S带思考推理逻辑
|
| 222 |
+
# feat, feat_lens = get_feat_from_wav_path(input_wav_path)
|
| 223 |
print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
|
| 224 |
if is_npu: torch_npu.npu.synchronize()
|
| 225 |
output_text, text_res, speech_res = model_a.generate_s2s_no_stream_think_with_repetition_penalty(
|
|
|
|
| 233 |
# 6. 处理S2T4Chat无思考任务
|
| 234 |
elif input_prompt == "s2t_no_think":
|
| 235 |
# S2T4Chat推理逻辑
|
| 236 |
+
# feat, feat_lens = get_feat_from_wav_path(input_wav_path)
|
| 237 |
print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
|
| 238 |
if is_npu: torch_npu.npu.synchronize()
|
| 239 |
res_text = model_a.generate4chat(
|
|
|
|
| 247 |
# 7. 处理S2T4Chat有思考任务
|
| 248 |
elif input_prompt == "s2t_think":
|
| 249 |
# S2T4Chat带思考推理逻辑
|
| 250 |
+
# feat, feat_lens = get_feat_from_wav_path(input_wav_path)
|
| 251 |
print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
|
| 252 |
if is_npu: torch_npu.npu.synchronize()
|
| 253 |
res_text = model_a.generate4chat_think(
|
|
|
|
| 261 |
# 8. 处理默认S2T任务
|
| 262 |
else:
|
| 263 |
# 默认S2T推理逻辑
|
| 264 |
+
# feat, feat_lens = get_feat_from_wav_path(input_wav_path)
|
| 265 |
print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
|
| 266 |
if is_npu: torch_npu.npu.synchronize()
|
| 267 |
res_text = model_a.generate(
|