xlgeng committed on
Commit
aea4592
·
1 Parent(s): 13f013f

开始部署

Browse files
Files changed (1) hide show
  1. app.py +36 -12
app.py CHANGED
@@ -4,6 +4,8 @@ import datetime
4
  import json
5
  import logging
6
  import os
 
 
7
  import spaces
8
 
9
  import gradio as gr
@@ -12,6 +14,7 @@ import time
12
  import traceback
13
 
14
  import torch
 
15
 
16
  from common_utils.utils4infer import get_feat_from_wav_path, load_model_and_tokenizer, token_list2wav
17
 
@@ -53,7 +56,7 @@ cosyvoice_model_path="./CosyVoice-300M-25Hz"
53
 
54
 
55
 
56
- device = torch.device("cuda")
57
  print("开始加载模型 A...")
58
  model_a, tokenizer_a = load_model_and_tokenizer(CHECKPOINT_PATH_A, CONFIG_PATH)
59
  model_a
@@ -131,6 +134,29 @@ def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice,
131
  if input_wav_path is None and not input_prompt.endswith(("_TTS", "_T2T")):
132
  print("音频信息未输入,且不是T2S或T2T任务")
133
  return "错误:需要音频输入"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
  # 通用初始化:模型设备设置
136
  start_time = time.time()
@@ -142,7 +168,7 @@ def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice,
142
  if input_prompt.endswith("_TTS"):
143
  text_for_tts = input_prompt.replace("_TTS", "")
144
  # T2S推理逻辑
145
- res_tensor = model_a.generate_tts(device=device, text=text_for_tts)[0]
146
  res_token_list = res_tensor.tolist()
147
  res_text = res_token_list[:-1]
148
  print(f"T2S 推理消耗时间: {time.time() - start_time:.2f} 秒")
@@ -151,16 +177,14 @@ def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice,
151
  elif input_prompt.endswith("_self_prompt"):
152
  prompt = input_prompt.replace("_self_prompt", "")
153
  # S2T推理逻辑
154
- feat, feat_lens = get_feat_from_wav_path(input_wav_path)
155
- print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
156
- if is_npu: torch_npu.npu.synchronize()
157
  res_text = model_a.generate(
158
  wavs=feat,
159
  wavs_len=feat_lens,
160
  prompt=prompt,
161
  cache_implementation="static"
162
  )[0]
163
- if is_npu: torch_npu.npu.synchronize()
164
  print(f"S2T 推理消耗时间: {time.time() - start_time:.2f} 秒")
165
 
166
  # 3. 处理T2T任务
@@ -170,7 +194,7 @@ def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice,
170
  print(f'开始t2t推理, question_txt: {question_txt}')
171
  if is_npu: torch_npu.npu.synchronize()
172
  res_text = model_a.generate_text2text(
173
- device=device,
174
  text=question_txt
175
  )[0]
176
  if is_npu: torch_npu.npu.synchronize()
@@ -181,7 +205,7 @@ def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice,
181
  "请推断对这段语音回答时的情感,标注情感类型,撰写流畅自然的聊天回复,并生成情感语音token。",
182
  "s2s_no_think"]:
183
  # S2S推理逻辑
184
- feat, feat_lens = get_feat_from_wav_path(input_wav_path)
185
  print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
186
  if is_npu: torch_npu.npu.synchronize()
187
  output_text, text_res, speech_res = model_a.generate_s2s_no_stream_with_repetition_penalty(
@@ -195,7 +219,7 @@ def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice,
195
  # 5. 处理S2S有思考任务
196
  elif input_prompt == "THINK":
197
  # S2S带思考推理逻辑
198
- feat, feat_lens = get_feat_from_wav_path(input_wav_path)
199
  print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
200
  if is_npu: torch_npu.npu.synchronize()
201
  output_text, text_res, speech_res = model_a.generate_s2s_no_stream_think_with_repetition_penalty(
@@ -209,7 +233,7 @@ def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice,
209
  # 6. 处理S2T4Chat无思考任务
210
  elif input_prompt == "s2t_no_think":
211
  # S2T4Chat推理逻辑
212
- feat, feat_lens = get_feat_from_wav_path(input_wav_path)
213
  print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
214
  if is_npu: torch_npu.npu.synchronize()
215
  res_text = model_a.generate4chat(
@@ -223,7 +247,7 @@ def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice,
223
  # 7. 处理S2T4Chat有思考任务
224
  elif input_prompt == "s2t_think":
225
  # S2T4Chat带思考推理逻辑
226
- feat, feat_lens = get_feat_from_wav_path(input_wav_path)
227
  print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
228
  if is_npu: torch_npu.npu.synchronize()
229
  res_text = model_a.generate4chat_think(
@@ -237,7 +261,7 @@ def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice,
237
  # 8. 处理默认S2T任务
238
  else:
239
  # 默认S2T推理逻辑
240
- feat, feat_lens = get_feat_from_wav_path(input_wav_path)
241
  print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
242
  if is_npu: torch_npu.npu.synchronize()
243
  res_text = model_a.generate(
 
4
  import json
5
  import logging
6
  import os
7
+
8
+ import librosa
9
  import spaces
10
 
11
  import gradio as gr
 
14
  import traceback
15
 
16
  import torch
17
+ import torchaudio
18
 
19
  from common_utils.utils4infer import get_feat_from_wav_path, load_model_and_tokenizer, token_list2wav
20
 
 
56
 
57
 
58
 
59
+
60
  print("开始加载模型 A...")
61
  model_a, tokenizer_a = load_model_and_tokenizer(CHECKPOINT_PATH_A, CONFIG_PATH)
62
  model_a
 
134
  if input_wav_path is None and not input_prompt.endswith(("_TTS", "_T2T")):
135
  print("音频信息未输入,且不是T2S或T2T任务")
136
  return "错误:需要音频输入"
137
+ if input_wav_path is not None:
138
+ waveform, sample_rate = torchaudio.load(input_wav_path)
139
+ if waveform.shape[0] > 1:
140
+ waveform = torch.mean(waveform, dim=0, keepdim=True)
141
+ resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
142
+ waveform = resampler(waveform)
143
+ waveform = waveform.squeeze(0)
144
+ window = torch.hann_window(400)
145
+ stft = torch.stft(waveform, 400, 160, window=window, return_complex=True)
146
+ magnitudes = stft[..., :-1].abs() ** 2
147
+ filters = torch.from_numpy(librosa.filters.mel(sr=sample_rate, n_fft=400, n_mels=80))
148
+ mel_spec = filters @ magnitudes
149
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
150
+ log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
151
+ log_spec = (log_spec + 4.0) / 4.0
152
+ feat = log_spec.transpose(0, 1)
153
+ feat_lens = torch.tensor([feat.shape[0]], dtype=torch.int64).cuda()
154
+ feat = feat.unsqueeze(0).cuda()
155
+ feat = feat.to(torch.bfloat16)
156
+ print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
157
+ else:
158
+ feat = None
159
+ feat_lens = None
160
 
161
  # 通用初始化:模型设备设置
162
  start_time = time.time()
 
168
  if input_prompt.endswith("_TTS"):
169
  text_for_tts = input_prompt.replace("_TTS", "")
170
  # T2S推理逻辑
171
+ res_tensor = model_a.generate_tts(device=torch.device("cuda"), text=text_for_tts)[0]
172
  res_token_list = res_tensor.tolist()
173
  res_text = res_token_list[:-1]
174
  print(f"T2S 推理消耗时间: {time.time() - start_time:.2f} 秒")
 
177
  elif input_prompt.endswith("_self_prompt"):
178
  prompt = input_prompt.replace("_self_prompt", "")
179
  # S2T推理逻辑
180
+ # feat, feat_lens = get_feat_from_wav_path(input_wav_path)
181
+ # waveform, sample_rate = do_resample(input_wav_path)
 
182
  res_text = model_a.generate(
183
  wavs=feat,
184
  wavs_len=feat_lens,
185
  prompt=prompt,
186
  cache_implementation="static"
187
  )[0]
 
188
  print(f"S2T 推理消耗时间: {time.time() - start_time:.2f} 秒")
189
 
190
  # 3. 处理T2T任务
 
194
  print(f'开始t2t推理, question_txt: {question_txt}')
195
  if is_npu: torch_npu.npu.synchronize()
196
  res_text = model_a.generate_text2text(
197
+ device=torch.device("cuda"),
198
  text=question_txt
199
  )[0]
200
  if is_npu: torch_npu.npu.synchronize()
 
205
  "请推断对这段语音回答时的情感,标注情感类型,撰写流畅自然的聊天回复,并生成情感语音token。",
206
  "s2s_no_think"]:
207
  # S2S推理逻辑
208
+ # feat, feat_lens = get_feat_from_wav_path(input_wav_path)
209
  print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
210
  if is_npu: torch_npu.npu.synchronize()
211
  output_text, text_res, speech_res = model_a.generate_s2s_no_stream_with_repetition_penalty(
 
219
  # 5. 处理S2S有思考任务
220
  elif input_prompt == "THINK":
221
  # S2S带思考推理逻辑
222
+ # feat, feat_lens = get_feat_from_wav_path(input_wav_path)
223
  print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
224
  if is_npu: torch_npu.npu.synchronize()
225
  output_text, text_res, speech_res = model_a.generate_s2s_no_stream_think_with_repetition_penalty(
 
233
  # 6. 处理S2T4Chat无思考任务
234
  elif input_prompt == "s2t_no_think":
235
  # S2T4Chat推理逻辑
236
+ # feat, feat_lens = get_feat_from_wav_path(input_wav_path)
237
  print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
238
  if is_npu: torch_npu.npu.synchronize()
239
  res_text = model_a.generate4chat(
 
247
  # 7. 处理S2T4Chat有思考任务
248
  elif input_prompt == "s2t_think":
249
  # S2T4Chat带思考推理逻辑
250
+ # feat, feat_lens = get_feat_from_wav_path(input_wav_path)
251
  print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
252
  if is_npu: torch_npu.npu.synchronize()
253
  res_text = model_a.generate4chat_think(
 
261
  # 8. 处理默认S2T任务
262
  else:
263
  # 默认S2T推理逻辑
264
+ # feat, feat_lens = get_feat_from_wav_path(input_wav_path)
265
  print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
266
  if is_npu: torch_npu.npu.synchronize()
267
  res_text = model_a.generate(