johnwang2026 commited on
Commit
6dbf71c
·
verified ·
1 Parent(s): 6a0aab9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -7
app.py CHANGED
@@ -1,28 +1,33 @@
1
  import gradio as gr
2
- from transformers import AutoModel, AutoTokenizer # 替换AutoModelForTextToSpeech为AutoModel
3
  import soundfile as sf
4
  import torch
5
  import os
6
 
7
- # 加载模型和Tokenizer(自动下载SoulX模型)
8
  model_name = "Soul-AILab/SoulX-Podcast-1.7B"
9
  tokenizer = AutoTokenizer.from_pretrained(model_name)
10
- model = AutoModel.from_pretrained( # 这里用AutoModel替代AutoModelForTextToSpeech
11
  model_name,
12
- torch_dtype=torch.float16,
13
- device_map="auto"
14
  )
 
 
 
15
 
16
- # 语音生成函数(逻辑不变)
17
  def generate_speech(text):
18
  if not text.strip():
19
  return None, "错误:请输入有效文本!"
20
 
21
- inputs = tokenizer(text, return_tensors="pt").to(model.device)
 
22
 
23
  with torch.no_grad():
24
  audio_output = model.generate(**inputs)
25
 
 
26
  output_path = "output.wav"
27
  sf.write(output_path, audio_output[0].cpu().numpy(), samplerate=24000)
28
 
 
1
  import gradio as gr
2
+ from transformers import AutoModel, AutoTokenizer
3
  import soundfile as sf
4
  import torch
5
  import os
6
 
7
+ # 加载模型和Tokenizer(修复参数+移除device_map)
8
  model_name = "Soul-AILab/SoulX-Podcast-1.7B"
9
  tokenizer = AutoTokenizer.from_pretrained(model_name)
10
+ model = AutoModel.from_pretrained(
11
  model_name,
12
+ dtype=torch.float16, # 替换 deprecated 的 torch_dtype
13
+ # 移除 device_map="auto",改用手动分配设备(兼容无accelerate环境)
14
  )
15
+ # 手动将模型移到GPU(无GPU自动用CPU)
16
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
+ model = model.to(device)
18
 
19
+ # 语音生成函数(补充设备适配)
20
  def generate_speech(text):
21
  if not text.strip():
22
  return None, "错误:请输入有效文本!"
23
 
24
+ # 文本编码并移到对应设备
25
+ inputs = tokenizer(text, return_tensors="pt").to(device)
26
 
27
  with torch.no_grad():
28
  audio_output = model.generate(**inputs)
29
 
30
+ # 保存音频
31
  output_path = "output.wav"
32
  sf.write(output_path, audio_output[0].cpu().numpy(), samplerate=24000)
33