import os import argparse import torch from pytorch_lightning import seed_everything from pytorch_lightning import seed_everything from src.tools.training_utils import get_restore_step from src.utilities.model.model_util import instantiate_from_config from src.tools.configuration import Configuration from src.tools.download_manager import get_checkpoint_path @torch.no_grad() def infer(prompt, configs, exp_group_name, exp_name, seed=0, n_cand=1, cfg_weight=3.5, ddim_steps=200): seed_everything(seed) use_ema = False if 'force_use_ema' in configs: use_ema = configs['force_use_ema'] if "precision" in configs['training'].keys(): torch.set_float32_matmul_precision( configs['training']["precision"] ) # highest, high, medium log_path = configs['logging']["log_directory"] config_reload_from_ckpt = configs.get("reload_from_ckpt", None) checkpoint_path = os.path.join(log_path, exp_group_name, exp_name, "checkpoints") if config_reload_from_ckpt is not None: resume_from_checkpoint = config_reload_from_ckpt print("Reload ckpt specified in the config file %s" % resume_from_checkpoint) elif len(os.listdir(checkpoint_path)) > 0: print("Load checkpoint from path: %s" % checkpoint_path) restore_step, n_step = get_restore_step(checkpoint_path) resume_from_checkpoint = os.path.join(checkpoint_path, restore_step) print("Resume from checkpoint", resume_from_checkpoint) else: raise f"[ERROR] no checkpoint was found at {checkpoint_path}, please provide a checkpoint" configs['model']['params']['ckpt_path'] = resume_from_checkpoint latent_diffusion = instantiate_from_config(configs["model"]) latent_diffusion.eval() latent_diffusion = latent_diffusion.cuda() saved_wav_path = latent_diffusion.text_to_audio( prompt=prompt, ddim_steps=ddim_steps, unconditional_guidance_scale=cfg_weight, n_gen=n_cand, use_ema=use_ema) print("[INFO] saved audio sample at:", saved_wav_path) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "-p", "--prompt", type=str, required=True, help="model name", ) parser.add_argument( "-m", "--model", type=str, default='genau-l-full-hq-data', required=False, help="path to config .yaml file", ) parser.add_argument( "-c", "--config_yaml", type=str, required=False, default=None, help="path to config .yaml file", ) parser.add_argument( "-s", "--seed", type=int, default=0 ) parser.add_argument( "-cfg", "--cfg_weight", type=float, default=4.0 ) parser.add_argument( "--n_cand", type=int, default=1, help="number of candidates for clap reranking" ) parser.add_argument( "--ddim_steps", type=int, default=100, help="number of ddim steps for sampling" ) parser.add_argument( "-ckpt", "--reload_from_ckpt", type=str, required=False, help="the checkpoint path for the model. If not provided, the most recent checkpoint from the log folder of the provided caption will be used", ) args = parser.parse_args() assert torch.cuda.is_available(), "CUDA is not available" if args.config_yaml is None: args.config_yaml = get_checkpoint_path(f"{args.model}_config") if args.reload_from_ckpt is None: args.reload_from_ckpt = get_checkpoint_path(args.model) config_yaml = args.config_yaml exp_name = os.path.basename(config_yaml.split(".")[0]) exp_group_name = os.path.basename(os.path.dirname(config_yaml)) exp_name = os.path.basename(config_yaml.split(".")[0]) exp_group_name = os.path.basename(os.path.dirname(config_yaml)) configuration = Configuration(config_yaml) config_yaml = configuration.get_config() if args.reload_from_ckpt != None: config_yaml["reload_from_ckpt"] = args.reload_from_ckpt infer(prompt=args.prompt, configs=config_yaml, exp_name=exp_name, exp_group_name=exp_group_name, seed=args.seed, n_cand=args.n_cand, ddim_steps=args.ddim_steps, cfg_weight=args.cfg_weight)