Spaces:
Runtime error
Runtime error
| import os | |
| import json | |
| import random | |
| import argparse | |
| import soundfile as sf | |
| import numpy as np | |
| import torch | |
| from diffusers import DDPMScheduler | |
| from pico_model import PicoDiffusion, build_pretrained_models | |
| from llm_preprocess import get_event, preprocess_gemini, preprocess_gpt | |
class dotdict(dict):
    """A ``dict`` whose items are also reachable as attributes.

    Behavior:
        * ``d.key`` reads ``d.get("key")``, so an absent key yields ``None``
          instead of raising ``AttributeError``.
        * ``d.key = v`` stores ``d["key"] = v``.
        * ``del d.key`` removes the item and raises ``KeyError`` if it is absent.
    """

    def __getattr__(self, name):
        # Only consulted after normal attribute lookup fails, so genuine dict
        # methods (keys, items, get, ...) still win over same-named items.
        return self.get(name)

    def __setattr__(self, name, value):
        self[name] = value

    def __delattr__(self, name):
        del self[name]
def parse_args(argv=None):
    """Parse command-line options for text-to-audio inference.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. This is useful when calling from a notebook,
            a web demo, or tests. ``None`` keeps the plain CLI behavior.

    Returns:
        ``argparse.Namespace`` with the parsed options plus two derived paths:
        ``original_args`` (``<exp_path>/summary.jsonl``) and
        ``diffusion_pt`` (``<exp_path>/diffusion.pt``).
    """
    parser = argparse.ArgumentParser(description="Inference for text to audio generation task.")
    parser.add_argument(
        "--text", "-t", type=str, default="spraying two times then gunshot three times.",
        help="free-text caption.",
    )
    # When omitted (None), main() derives the timestamp caption from --text via an LLM.
    # Example value:
    #   "spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031."
    parser.add_argument(
        "--timestamp_caption", "-c", type=str, default=None,
        help="timestamp caption, formatted as 'event1 at onset1-offset1_onset2-offset2 and event2 at onset1-offset1'.",
    )
    # NOTE(review): the two path defaults below are cluster-specific absolute
    # paths; override them on any other machine.
    parser.add_argument(
        "--exp_path", "-exp", type=str,
        default="/hpc_stor03/sjtu_home/zeyu.xie/workspace/controllable_audio_generation/huggingface/ckpts/pico_model",
        help="Path for experiment.",
    )
    parser.add_argument(
        "--freeze_text_encoder_ckpt", type=str,
        default="/hpc_stor03/sjtu_home/zeyu.xie/workspace/controllable_audio_generation/huggingface/ckpts/laion_clap/630k-audioset-best.pt",
        help="Path for clap.",
    )
    parser.add_argument(
        "--seed", type=int, default=0,
        help="seed.",
    )
    args = parser.parse_args(argv)
    # Derived file locations inside the experiment directory.
    args.original_args = os.path.join(args.exp_path, "summary.jsonl")
    args.diffusion_pt = os.path.join(args.exp_path, "diffusion.pt")
    return args
# Root directory under which generated audio is written (cluster-specific
# absolute path).
_OUTPUT_ROOT = "/hpc_stor03/sjtu_home/zeyu.xie/workspace/controllable_audio_generation/synthesized"
# Common filesystems cap one path component at 255 bytes; leave room for ".wav".
_MAX_STEM_BYTES = 255 - len(".wav")


def _load_train_args(summary_path):
    """Return the first JSON record of ``summary_path`` as a ``dotdict``.

    Raises:
        ValueError: if the first line is missing/blank or is not valid JSON.
    """
    # `with` closes the file; readline avoids reading the whole file for line 1.
    with open(summary_path, encoding="utf-8") as f:
        first_line = f.readline()
    if not first_line.strip():
        raise ValueError(f"{summary_path} is empty; expected a JSON record on its first line.")
    return dotdict(json.loads(first_line))


def _seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs; make cuDNN deterministic when CUDA is present."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def _caption_to_filename(caption, max_bytes=_MAX_STEM_BYTES):
    """Turn a caption into one safe path component (no directory separators).

    Surrounding whitespace (e.g. a trailing newline from an LLM reply) is
    stripped. Path separators and NUL are replaced by ``_``. The result is
    truncated to ``max_bytes`` UTF-8 bytes, and ``"output"`` is returned if
    nothing is left.
    """
    stem = caption.strip()
    for bad in (os.sep, os.altsep, "\0"):
        if bad:  # os.altsep is None on POSIX
            stem = stem.replace(bad, "_")
    # Truncate on a byte budget without splitting a multi-byte character.
    stem = stem.encode("utf-8")[:max_bytes].decode("utf-8", errors="ignore")
    return stem or "output"


def main():
    """Run one text-to-audio generation from the command line.

    Steps:
        1. Parse the CLI and load the training config.
        2. Seed the RNGs.
        3. Turn the free-text prompt into a timestamp caption via an LLM,
           unless one was given.
        4. Load the VAE and the PicoDiffusion model.
        5. Run diffusion, decode to a waveform, and write a 16 kHz PCM-16
           WAV file named after the caption.
    """
    args = parse_args()
    train_args = _load_train_args(args.original_args)
    _seed_everything(args.seed)

    # Step1: preprocess via llm
    if args.timestamp_caption is None:
        # preprocess_gpt(args.text) is the alternative preprocessing backend
        # imported alongside preprocess_gemini.
        args.timestamp_caption = preprocess_gemini(args.text)

    # Load Models #
    # Use the GPU when available instead of crashing outright on GPU-less hosts.
    # NOTE(review): PicoDiffusion / VAE internals may still assume CUDA — confirm.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("------Load model")
    name = "audioldm-s-full"
    # The STFT module returned alongside the VAE is not used here, so it is not
    # moved to the device (saves GPU memory).
    vae, _ = build_pretrained_models(name)
    vae = vae.to(device)
    model = PicoDiffusion(
        scheduler_name=train_args.scheduler_name,
        unet_model_config_path=train_args.unet_model_config,
        snr_gamma=train_args.snr_gamma,
        freeze_text_encoder_ckpt=args.freeze_text_encoder_ckpt,
        diffusion_pt=args.diffusion_pt,
    ).to(device).eval()
    scheduler = DDPMScheduler.from_pretrained(train_args.scheduler_name, subfolder="scheduler")

    # Generate #
    num_steps, guidance, num_samples = 200, 3.0, 1
    sample_rate = 16000
    audio_len = sample_rate * 10  # keep at most 10 seconds of samples
    output_dir = os.path.join(
        _OUTPUT_ROOT,
        f"huggingface_demo_steps-{num_steps}_guidance-{guidance}_samples-{num_samples}",
    )
    os.makedirs(output_dir, exist_ok=True)

    print("------Diffusion begin!")
    with torch.no_grad():
        latents = model.demo_inference(
            args.timestamp_caption, scheduler, num_steps, guidance, num_samples,
            disable_progress=True,
        )
        mel = vae.decode_first_stage(latents)
        wave = vae.decode_to_waveform(mel)

    # The caption may come from an LLM, so sanitise it before using it as a filename.
    wav_path = os.path.join(output_dir, f"{_caption_to_filename(args.timestamp_caption)}.wav")
    sf.write(wav_path, wave[0][:audio_len], samplerate=sample_rate, subtype="PCM_16")
    print(f"------Write to files to {wav_path}")
if __name__ == "__main__":
    # Run inference only when executed as a script, not on import.
    main()