import argparse
import io
import os
import random
import time

import numpy as np
import pydub
import torch
import torchaudio
from einops import rearrange

from diffrhythm.infer.infer_utils import (
    decode_audio,
    get_lrc_token,
    get_negative_style_prompt,
    get_reference_latent,
    get_style_prompt,
    prepare_model,
    eval_song,
)
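
# Helpers from diffrhythm.infer.infer_utils, as used below: prepare_model loads the
# CFM, tokenizer, MuQ, VAE and evaluation models; get_lrc_token tokenizes lyrics;
# get_style_prompt / get_negative_style_prompt build the style conditioning;
# get_reference_latent prepares the latent prompt; decode_audio runs the VAE decoder;
# eval_song picks the best candidate from a batch.
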
def inference(
    cfm_model,
    vae_model,
    eval_model,
    eval_muq,
    cond,
    text,
    duration,
    style_prompt,
    negative_style_prompt,
    steps,
    cfg_strength,
    sway_sampling_coef,
    start_time,
    file_type,
    vocal_flag,
    odeint_method,
    pred_frames,
    batch_infer_num,
    chunked=True,
):
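    """Sample latents with the CFM model, decode them with the VAE, and return the song.

    Returns (44100, float32 ndarray of shape [samples, channels]) when file_type == 'wav';
    otherwise the encoded mp3/ogg file as raw bytes. When batch_infer_num > 1, the best
    of the generated candidates is selected with eval_song.
    """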
    with torch.inference_mode():
        latents, _ = cfm_model.sample(
            cond=cond,
            text=text,
            duration=duration,
            style_prompt=style_prompt,
            negative_style_prompt=negative_style_prompt,
            steps=steps,
            cfg_strength=cfg_strength,
            sway_sampling_coef=sway_sampling_coef,
            start_time=start_time,
            vocal_flag=vocal_flag,
            odeint_method=odeint_method,
            latent_pred_segments=pred_frames,
            batch_infer_num=batch_infer_num,
        )

        outputs = []
        for latent in latents:
            latent = latent.to(torch.float32)
            latent = latent.transpose(1, 2)  # [b d t]
            output = decode_audio(latent, vae_model, chunked=chunked)
            # Rearrange audio batch to a single sequence
            output = rearrange(output, "b d n -> d (b n)")
            outputs.append(output)
        if batch_infer_num > 1:
            generated_song = eval_song(eval_model, eval_muq, outputs)
        else:
            generated_song = outputs[0]

        # Peak-normalize the selected song (not the last loop variable) and clip to [-1, 1]
        output_tensor = (
            generated_song.to(torch.float32)
            .div(torch.max(torch.abs(generated_song)))
            .clamp(-1, 1)
            .cpu()
        )
        output_np = output_tensor.numpy().T.astype(np.float32)
        if file_type == 'wav':
            return (44100, output_np)
        else:
            buffer = io.BytesIO()
            # Scale by 32767 (not 2**15) so a full-scale peak does not overflow int16
            output_np = np.int16(output_np * 32767)
            song = pydub.AudioSegment(output_np.tobytes(), frame_rate=44100, sample_width=2, channels=2)
            if file_type == 'mp3':
                song.export(buffer, format="mp3", bitrate="320k")
            else:
                song.export(buffer, format="ogg", bitrate="320k")
            return buffer.getvalue()
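
# Example (sketch): persisting the non-wav return value. For file_type "mp3" or
# "ogg", inference() returns the encoded file as raw bytes, so it can be written
# directly to disk. The output path below is illustrative only.
#
#   audio_bytes = inference(..., file_type="mp3")
#   with open("output.mp3", "wb") as f:
#       f.write(audio_bytes)
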
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lrc-path",
        type=str,
        help="lyrics of target song",
    )  # lyrics of target song
    parser.add_argument(
        "--ref-prompt",
        type=str,
        help="reference prompt as style prompt for target song",
        required=False,
    )  # reference prompt as style prompt for target song
    parser.add_argument(
        "--ref-audio-path",
        type=str,
        help="reference audio as style prompt for target song",
        required=False,
    )  # reference audio as style prompt for target song
    parser.add_argument(
        "--chunked",
        action="store_true",
        help="whether to use chunked decoding",
    )  # whether to use chunked decoding
    parser.add_argument(
        "--audio-length",
        type=int,
        default=95,
        choices=[95, 285],
        help="length of generated song",
    )  # length of target song
    parser.add_argument(
        "--repo-id", type=str, default="ASLP-lab/DiffRhythm-base", help="target model"
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="infer/example/output",
        help="output directory of generated song",
    )  # output directory of target song
    parser.add_argument(
        "--edit",
        action="store_true",
        help="whether to open edit mode",
    )  # edit flag
    parser.add_argument(
        "--ref-song",
        type=str,
        required=False,
        help="reference song as latent prompt for editing",
    )  # reference song as latent prompt for editing
    parser.add_argument(
        "--edit-segments",
        type=str,
        required=False,
        help="edit segments of target song",
    )  # edit segments of target song
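
    # Example invocation (sketch; paths and prompt are illustrative, and the script
    # name depends on where this file lives in your checkout):
    #   python infer.py \
    #       --lrc-path path/to/lyrics.lrc \
    #       --ref-prompt "upbeat pop with female vocals" \
    #       --audio-length 95 \
    #       --chunked
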
    args = parser.parse_args()

    assert (
        args.ref_prompt or args.ref_audio_path
    ), "either ref_prompt or ref_audio_path should be provided"
    assert not (
        args.ref_prompt and args.ref_audio_path
    ), "only one of them should be provided"
    if args.edit:
        assert (
            args.ref_song and args.edit_segments
        ), "reference song and edit segments should be provided for editing"
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.mps.is_available():
        device = "mps"

    audio_length = args.audio_length
    if audio_length == 95:
        max_frames = 2048
    elif audio_length == 285:
        max_frames = 6144
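    # max_frames maps the requested duration to latent frames
    # (2048 / 95 s ≈ 6144 / 285 s ≈ 21.5 latent frames per second).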
    cfm, tokenizer, muq, vae, eval_model, eval_muq = prepare_model(max_frames, device, repo_id=args.repo_id)

    if args.lrc_path:
        with open(args.lrc_path, "r", encoding="utf-8") as f:
            lrc = f.read()
    else:
        lrc = ""
    lrc_prompt, start_time = get_lrc_token(max_frames, lrc, tokenizer, device)

    if args.ref_audio_path:
        style_prompt = get_style_prompt(muq, args.ref_audio_path)
    else:
        style_prompt = get_style_prompt(muq, prompt=args.ref_prompt)
    negative_style_prompt = get_negative_style_prompt(device)
    latent_prompt, pred_frames = get_reference_latent(device, max_frames, args.edit, args.edit_segments, args.ref_song, vae)

    s_t = time.time()
    # Sampling settings below (steps, cfg_strength, sway_sampling_coef, vocal_flag,
    # odeint_method, batch_infer_num) were not specified in the original call; the
    # values here are assumed defaults and may need adjusting.
    # With file_type="wav", inference() returns (44100, float32 ndarray [samples, channels]).
    _, generated_song = inference(
        cfm_model=cfm,
        vae_model=vae,
        eval_model=eval_model,
        eval_muq=eval_muq,
        cond=latent_prompt,
        text=lrc_prompt,
        duration=max_frames,
        style_prompt=style_prompt,
        negative_style_prompt=negative_style_prompt,
        steps=32,
        cfg_strength=4.0,
        sway_sampling_coef=None,
        start_time=start_time,
        file_type="wav",
        vocal_flag=False,
        odeint_method="euler",
        pred_frames=pred_frames,
        batch_infer_num=1,  # >1 generates several candidates and keeps the best via eval_song
        chunked=args.chunked,
    )

    # inference() already peak-normalizes and clips; convert the float32
    # [samples, channels] array to an int16 [channels, samples] tensor for saving.
    generated_song = (
        torch.from_numpy(generated_song.T.copy())
        .clamp(-1, 1)
        .mul(32767)
        .to(torch.int16)
    )
    e_t = time.time() - s_t
    print(f"inference cost {e_t:.2f} seconds")

    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    output_path = os.path.join(output_dir, "output.wav")
    torchaudio.save(output_path, generated_song, sample_rate=44100)
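
    # Optional sanity check of the written file (illustrative):
    #   waveform, sr = torchaudio.load(output_path)
    #   print(waveform.shape, sr)  # expect [2, n_samples] at 44100 Hz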