xlgeng committed on
Commit
13f013f
·
1 Parent(s): 58a2540

开始部署

Browse files
Files changed (1) hide show
  1. app.py +25 -24
app.py CHANGED
@@ -30,6 +30,11 @@ except ImportError:
30
  print("torch_npu is not available. if you want to use npu, please install it.")
31
 
32
 
 
 
 
 
 
33
 
34
 
35
  from huggingface_hub import hf_hub_download
@@ -51,7 +56,7 @@ cosyvoice_model_path="./CosyVoice-300M-25Hz"
51
  device = torch.device("cuda")
52
  print("开始加载模型 A...")
53
  model_a, tokenizer_a = load_model_and_tokenizer(CHECKPOINT_PATH_A, CONFIG_PATH)
54
- model_a.eval().cuda()
55
 
56
  print("\n开始加载模型 B...")
57
  if CHECKPOINT_PATH_B is not None:
@@ -61,15 +66,15 @@ else:
61
  model_b, tokenizer_b = None, None
62
 
63
  loaded_models = {
64
- NAME_A: {"model": model_a, "tokenizer": tokenizer_a},
65
  NAME_B: {"model": model_b, "tokenizer": tokenizer_b},
66
  } if model_b is not None else {
67
- NAME_A: {"model": model_a, "tokenizer": tokenizer_a},
68
  }
69
  print("\n所有模型已加载完毕。")
70
 
71
- cosyvoice = CosyVoice(cosyvoice_model_path)
72
- cosyvoice.eval().cuda()
73
 
74
  # 将图片转换为 Base64
75
  with open("./tts/assert/实验室.png", "rb") as image_file:
@@ -114,11 +119,6 @@ for item in prompt_audio_choices:
114
 
115
  logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s')
116
 
117
- import time
118
- import datetime
119
- import torch
120
- from common_utils.utils4infer import get_feat_from_wav_path, token_list2wav
121
-
122
 
123
  @spaces.GPU
124
  def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice, prompt_speech_data):
@@ -135,13 +135,14 @@ def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice,
135
  # 通用初始化:模型设备设置
136
  start_time = time.time()
137
  res_text = None
 
138
 
139
  try:
140
  # 1. 处理TTS任务
141
  if input_prompt.endswith("_TTS"):
142
  text_for_tts = input_prompt.replace("_TTS", "")
143
  # T2S推理逻辑
144
- res_tensor = model.generate_tts(device=device, text=text_for_tts)[0]
145
  res_token_list = res_tensor.tolist()
146
  res_text = res_token_list[:-1]
147
  print(f"T2S 推理消耗时间: {time.time() - start_time:.2f} 秒")
@@ -153,7 +154,7 @@ def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice,
153
  feat, feat_lens = get_feat_from_wav_path(input_wav_path)
154
  print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
155
  if is_npu: torch_npu.npu.synchronize()
156
- res_text = model.generate(
157
  wavs=feat,
158
  wavs_len=feat_lens,
159
  prompt=prompt,
@@ -168,7 +169,7 @@ def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice,
168
  # T2T推理逻辑
169
  print(f'开始t2t推理, question_txt: {question_txt}')
170
  if is_npu: torch_npu.npu.synchronize()
171
- res_text = model.generate_text2text(
172
  device=device,
173
  text=question_txt
174
  )[0]
@@ -183,7 +184,7 @@ def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice,
183
  feat, feat_lens = get_feat_from_wav_path(input_wav_path)
184
  print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
185
  if is_npu: torch_npu.npu.synchronize()
186
- output_text, text_res, speech_res = model.generate_s2s_no_stream_with_repetition_penalty(
187
  wavs=feat,
188
  wavs_len=feat_lens,
189
  )
@@ -197,7 +198,7 @@ def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice,
197
  feat, feat_lens = get_feat_from_wav_path(input_wav_path)
198
  print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
199
  if is_npu: torch_npu.npu.synchronize()
200
- output_text, text_res, speech_res = model.generate_s2s_no_stream_think_with_repetition_penalty(
201
  wavs=feat,
202
  wavs_len=feat_lens,
203
  )
@@ -211,7 +212,7 @@ def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice,
211
  feat, feat_lens = get_feat_from_wav_path(input_wav_path)
212
  print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
213
  if is_npu: torch_npu.npu.synchronize()
214
- res_text = model.generate4chat(
215
  wavs=feat,
216
  wavs_len=feat_lens,
217
  cache_implementation="static"
@@ -225,7 +226,7 @@ def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice,
225
  feat, feat_lens = get_feat_from_wav_path(input_wav_path)
226
  print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
227
  if is_npu: torch_npu.npu.synchronize()
228
- res_text = model.generate4chat_think(
229
  wavs=feat,
230
  wavs_len=feat_lens,
231
  cache_implementation="static"
@@ -239,7 +240,7 @@ def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice,
239
  feat, feat_lens = get_feat_from_wav_path(input_wav_path)
240
  print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
241
  if is_npu: torch_npu.npu.synchronize()
242
- res_text = model.generate(
243
  wavs=feat,
244
  wavs_len=feat_lens,
245
  prompt=input_prompt,
@@ -260,20 +261,20 @@ def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice,
260
  wav_path_output = input_wav_path
261
  if task_choice == "TTS任务" or "empathetic_s2s_dialogue" in task_choice:
262
  if isinstance(output_res, list): # TTS case
263
- cosyvoice.eval()
264
- time_str = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
265
- wav_path = f"./tmp/{time_str}.wav"
266
- wav_path_output = token_list2wav(output_res, prompt_speech_data, wav_path, cosyvoice)
267
  # wav_path_output = get_wav_from_token_list(output_res, prompt_speech_data)
268
  output_res = "生成的token: " + str(output_res)
269
  elif isinstance(output_res, str) and "|" in output_res: # S2S case
270
  try:
271
  text_res, token_list_str = output_res.split("|")
272
  token_list = json.loads(token_list_str)
273
- cosyvoice.eval()
274
  time_str = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
275
  wav_path = f"./tmp/{time_str}.wav"
276
- wav_path_output = token_list2wav(token_list, prompt_speech_data, wav_path, cosyvoice)
277
  # wav_path_output = get_wav_from_token_list(token_list, prompt_speech_data)
278
  output_res = text_res
279
  except (ValueError, json.JSONDecodeError) as e:
 
30
  print("torch_npu is not available. if you want to use npu, please install it.")
31
 
32
 
33
+ import time
34
+ import datetime
35
+ import torch
36
+ from common_utils.utils4infer import get_feat_from_wav_path, token_list2wav
37
+
38
 
39
 
40
  from huggingface_hub import hf_hub_download
 
56
  device = torch.device("cuda")
57
  print("开始加载模型 A...")
58
  model_a, tokenizer_a = load_model_and_tokenizer(CHECKPOINT_PATH_A, CONFIG_PATH)
59
+ model_a
60
 
61
  print("\n开始加载模型 B...")
62
  if CHECKPOINT_PATH_B is not None:
 
66
  model_b, tokenizer_b = None, None
67
 
68
  loaded_models = {
69
+ NAME_A: {"model": model_b, "tokenizer": tokenizer_b},
70
  NAME_B: {"model": model_b, "tokenizer": tokenizer_b},
71
  } if model_b is not None else {
72
+ NAME_A: {"model": model_b, "tokenizer": tokenizer_b},
73
  }
74
  print("\n所有模型已加载完毕。")
75
 
76
+ # cosyvoice = CosyVoice(cosyvoice_model_path)
77
+ # cosyvoice.eval().cuda()
78
 
79
  # 将图片转换为 Base64
80
  with open("./tts/assert/实验室.png", "rb") as image_file:
 
119
 
120
  logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s')
121
 
 
 
 
 
 
122
 
123
  @spaces.GPU
124
  def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice, prompt_speech_data):
 
135
  # 通用初始化:模型设备设置
136
  start_time = time.time()
137
  res_text = None
138
+ model_a.eval().cuda()
139
 
140
  try:
141
  # 1. 处理TTS任务
142
  if input_prompt.endswith("_TTS"):
143
  text_for_tts = input_prompt.replace("_TTS", "")
144
  # T2S推理逻辑
145
+ res_tensor = model_a.generate_tts(device=device, text=text_for_tts)[0]
146
  res_token_list = res_tensor.tolist()
147
  res_text = res_token_list[:-1]
148
  print(f"T2S 推理消耗时间: {time.time() - start_time:.2f} 秒")
 
154
  feat, feat_lens = get_feat_from_wav_path(input_wav_path)
155
  print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
156
  if is_npu: torch_npu.npu.synchronize()
157
+ res_text = model_a.generate(
158
  wavs=feat,
159
  wavs_len=feat_lens,
160
  prompt=prompt,
 
169
  # T2T推理逻辑
170
  print(f'开始t2t推理, question_txt: {question_txt}')
171
  if is_npu: torch_npu.npu.synchronize()
172
+ res_text = model_a.generate_text2text(
173
  device=device,
174
  text=question_txt
175
  )[0]
 
184
  feat, feat_lens = get_feat_from_wav_path(input_wav_path)
185
  print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
186
  if is_npu: torch_npu.npu.synchronize()
187
+ output_text, text_res, speech_res = model_a.generate_s2s_no_stream_with_repetition_penalty(
188
  wavs=feat,
189
  wavs_len=feat_lens,
190
  )
 
198
  feat, feat_lens = get_feat_from_wav_path(input_wav_path)
199
  print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
200
  if is_npu: torch_npu.npu.synchronize()
201
+ output_text, text_res, speech_res = model_a.generate_s2s_no_stream_think_with_repetition_penalty(
202
  wavs=feat,
203
  wavs_len=feat_lens,
204
  )
 
212
  feat, feat_lens = get_feat_from_wav_path(input_wav_path)
213
  print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
214
  if is_npu: torch_npu.npu.synchronize()
215
+ res_text = model_a.generate4chat(
216
  wavs=feat,
217
  wavs_len=feat_lens,
218
  cache_implementation="static"
 
226
  feat, feat_lens = get_feat_from_wav_path(input_wav_path)
227
  print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
228
  if is_npu: torch_npu.npu.synchronize()
229
+ res_text = model_a.generate4chat_think(
230
  wavs=feat,
231
  wavs_len=feat_lens,
232
  cache_implementation="static"
 
240
  feat, feat_lens = get_feat_from_wav_path(input_wav_path)
241
  print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
242
  if is_npu: torch_npu.npu.synchronize()
243
+ res_text = model_a.generate(
244
  wavs=feat,
245
  wavs_len=feat_lens,
246
  prompt=input_prompt,
 
261
  wav_path_output = input_wav_path
262
  if task_choice == "TTS任务" or "empathetic_s2s_dialogue" in task_choice:
263
  if isinstance(output_res, list): # TTS case
264
+ # cosyvoice.eval()
265
+ # time_str = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
266
+ # wav_path = f"./tmp/{time_str}.wav"
267
+ # wav_path_output = token_list2wav(output_res, prompt_speech_data, wav_path, cosyvoice)
268
  # wav_path_output = get_wav_from_token_list(output_res, prompt_speech_data)
269
  output_res = "生成的token: " + str(output_res)
270
  elif isinstance(output_res, str) and "|" in output_res: # S2S case
271
  try:
272
  text_res, token_list_str = output_res.split("|")
273
  token_list = json.loads(token_list_str)
274
+ # cosyvoice.eval()
275
  time_str = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
276
  wav_path = f"./tmp/{time_str}.wav"
277
+ # wav_path_output = token_list2wav(token_list, prompt_speech_data, wav_path, cosyvoice)
278
  # wav_path_output = get_wav_from_token_list(token_list, prompt_speech_data)
279
  output_res = text_res
280
  except (ValueError, json.JSONDecodeError) as e: