Spaces:
Runtime error
Runtime error
Commit
·
793fc05
1
Parent(s):
30a6bb8
update
Browse files
- app.py +2 -2
- diffrhythm/infer/infer.py +5 -1
- diffrhythm/infer/infer_utils.py +1 -2
- requirements.txt +3 -2
app.py
CHANGED
|
@@ -51,7 +51,7 @@ def infer_music(lrc, ref_audio_path, steps, file_type, max_frames=2048):
|
|
| 51 |
start_time=start_time,
|
| 52 |
file_type=file_type
|
| 53 |
)
|
| 54 |
-
|
| 55 |
gc.collect()
|
| 56 |
print(">4")
|
| 57 |
|
|
@@ -207,7 +207,7 @@ with gr.Blocks(css=css) as demo:
|
|
| 207 |
interactive=True,
|
| 208 |
elem_id="step_slider"
|
| 209 |
)
|
| 210 |
-
file_type = gr.Dropdown(["wav", "mp3", "ogg"], label="Output Format", value="
|
| 211 |
|
| 212 |
|
| 213 |
|
|
|
|
| 51 |
start_time=start_time,
|
| 52 |
file_type=file_type
|
| 53 |
)
|
| 54 |
+
devicetorch.empty_cache(torch)
|
| 55 |
gc.collect()
|
| 56 |
print(">4")
|
| 57 |
|
|
|
|
| 207 |
interactive=True,
|
| 208 |
elem_id="step_slider"
|
| 209 |
)
|
| 210 |
+
file_type = gr.Dropdown(["wav", "mp3", "ogg"], label="Output Format", value="mp3")
|
| 211 |
|
| 212 |
|
| 213 |
|
diffrhythm/infer/infer.py
CHANGED
|
@@ -134,7 +134,11 @@ if __name__ == "__main__":
|
|
| 134 |
parser.add_argument('--output-dir', type=str, default="example/output")
|
| 135 |
args = parser.parse_args()
|
| 136 |
|
| 137 |
-
device =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
|
| 139 |
audio_length = args.audio_length
|
| 140 |
if audio_length == 95:
|
|
|
|
| 134 |
parser.add_argument('--output-dir', type=str, default="example/output")
|
| 135 |
args = parser.parse_args()
|
| 136 |
|
| 137 |
+
device = "cpu"
|
| 138 |
+
if torch.cuda.is_available():
|
| 139 |
+
device = "cuda"
|
| 140 |
+
elif torch.mps.is_available():
|
| 141 |
+
device = "mps"
|
| 142 |
|
| 143 |
audio_length = args.audio_length
|
| 144 |
if audio_length == 95:
|
diffrhythm/infer/infer_utils.py
CHANGED
|
@@ -169,8 +169,7 @@ def get_lrc_token(text, tokenizer, device):
|
|
| 169 |
return lrc_emb, normalized_start_time
|
| 170 |
|
| 171 |
def load_checkpoint(model, ckpt_path, device, use_ema=True):
|
| 172 |
-
|
| 173 |
-
model = model.half()
|
| 174 |
|
| 175 |
ckpt_type = ckpt_path.split(".")[-1]
|
| 176 |
if ckpt_type == "safetensors":
|
|
|
|
| 169 |
return lrc_emb, normalized_start_time
|
| 170 |
|
| 171 |
def load_checkpoint(model, ckpt_path, device, use_ema=True):
|
| 172 |
+
model = model.half()
|
|
|
|
| 173 |
|
| 174 |
ckpt_type = ckpt_path.split(".")[-1]
|
| 175 |
if ckpt_type == "safetensors":
|
requirements.txt
CHANGED
|
@@ -12,7 +12,8 @@ pandas==2.2.3
|
|
| 12 |
pylance==0.23.2
|
| 13 |
ema-pytorch==0.7.7
|
| 14 |
prefigure==0.0.10
|
| 15 |
-
bitsandbytes==0.45.3
|
|
|
|
| 16 |
muq==0.1.0
|
| 17 |
mutagen==1.47.0
|
| 18 |
pyopenjtalk==0.4.0
|
|
@@ -21,7 +22,7 @@ jieba==0.42.1
|
|
| 21 |
cn2an==0.5.23
|
| 22 |
pypinyin==0.53.0
|
| 23 |
#onnxruntime==1.20.1
|
| 24 |
-
onnxruntime-gpu
|
| 25 |
Unidecode==1.3.8
|
| 26 |
phonemizer==3.3.0
|
| 27 |
LangSegment==0.3.5
|
|
|
|
| 12 |
pylance==0.23.2
|
| 13 |
ema-pytorch==0.7.7
|
| 14 |
prefigure==0.0.10
|
| 15 |
+
#bitsandbytes==0.45.3
|
| 16 |
+
bitsandbytes
|
| 17 |
muq==0.1.0
|
| 18 |
mutagen==1.47.0
|
| 19 |
pyopenjtalk==0.4.0
|
|
|
|
| 22 |
cn2an==0.5.23
|
| 23 |
pypinyin==0.53.0
|
| 24 |
#onnxruntime==1.20.1
|
| 25 |
+
#onnxruntime-gpu
|
| 26 |
Unidecode==1.3.8
|
| 27 |
phonemizer==3.3.0
|
| 28 |
LangSegment==0.3.5
|