# genau-demo/GenAU/scripts/inference_file.py
# Author: Moayed — "add demo files" (commit cef9e84)
import argparse
import os
import shutil

import torch
from pytorch_lightning import seed_everything
from torch.utils.data import DataLoader

from src.tools.configuration import Configuration
from src.tools.download_manager import get_checkpoint_path
from src.tools.training_utils import build_dataset_json_from_list, get_restore_step
from src.utilities.data.videoaudio_dataset import VideoAudioDataset, custom_collate_fn
from src.utilities.model.model_util import instantiate_from_config
@torch.no_grad()
def infer(dataset_json, configs, config_yaml_path, exp_group_name, exp_name, seed=0,
          n_cand=1, cfg_weight=3.5, ddim_steps=200, strategy='wo_ema', batch_size=42):
    """Generate audio samples with a latent-diffusion model for the given captions.

    Args:
        dataset_json: Parsed dataset description (captions and optional filenames)
            passed straight to ``VideoAudioDataset``.
        configs: Loaded experiment configuration (dict-like, from ``Configuration``).
        config_yaml_path: Path to the config YAML; copied into the experiment log dir.
        exp_group_name: Experiment group directory name under the log directory.
        exp_name: Experiment directory name under the group directory.
        seed: Global RNG seed (forwarded to ``seed_everything``).
        n_cand: Number of candidate generations per caption (for CLAP reranking).
        cfg_weight: Classifier-free-guidance scale.
        ddim_steps: Number of DDIM sampling steps.
        strategy: Weight-selection strategy; EMA weights are used unless 'wo_ema'.
        batch_size: Inference dataloader batch size.

    Raises:
        ValueError: If no checkpoint is specified in the config and none exists
            in the experiment's checkpoint directory.
    """
    seed_everything(seed)

    # Optional matmul precision override: one of "highest", "high", "medium".
    if "precision" in configs['training'].keys():
        torch.set_float32_matmul_precision(configs['training']["precision"])

    log_path = configs['logging']["log_directory"]

    if "dataloader_add_ons" in configs["data"].keys():
        dataloader_add_ons = configs["data"]["dataloader_add_ons"]
    else:
        dataloader_add_ons = []

    # Text-only inference: audio/video loading disabled.
    val_dataset = VideoAudioDataset(
        config=configs,
        split='test',
        add_ons=dataloader_add_ons,
        dataset_json=dataset_json,
        load_audio=False,
        load_video=False
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        collate_fn=custom_collate_fn
    )

    config_reload_from_ckpt = configs.get("reload_from_ckpt", None)
    checkpoint_path = os.path.join(log_path, exp_group_name, exp_name, "checkpoints")
    wandb_path = os.path.join(log_path, exp_group_name, exp_name)
    os.makedirs(checkpoint_path, exist_ok=True)

    # Keep a copy of the config next to the run's logs for reproducibility.
    shutil.copy(config_yaml_path, wandb_path)

    # Checkpoint resolution: an explicit config entry wins; otherwise fall back
    # to the most recent checkpoint found in the experiment's checkpoint dir.
    if config_reload_from_ckpt is not None:
        resume_from_checkpoint = config_reload_from_ckpt
        print("Reload ckpt specified in the config file %s" % resume_from_checkpoint)
    elif len(os.listdir(checkpoint_path)) > 0:
        print("Load checkpoint from path: %s" % checkpoint_path)
        restore_step, _ = get_restore_step(checkpoint_path)
        resume_from_checkpoint = os.path.join(checkpoint_path, restore_step)
        print("Resume from checkpoint", resume_from_checkpoint)
    else:
        # Raising a bare string is a TypeError in Python 3; use a real exception.
        raise ValueError("Please specify a pre-trained checkpoint")

    configs['model']['params']['ckpt_path'] = resume_from_checkpoint
    latent_diffusion = instantiate_from_config(configs["model"])
    latent_diffusion.set_log_dir(log_path, exp_group_name, exp_name)
    latent_diffusion.eval()
    latent_diffusion = latent_diffusion.cuda()

    latent_diffusion.generate_sample(
        val_loader,
        unconditional_guidance_scale=cfg_weight,
        ddim_steps=ddim_steps,
        n_gen=n_cand,
        use_ema=(strategy != 'wo_ema')
    )
if __name__ == "__main__":
    # Command-line entry point: resolve config/checkpoint paths (downloading
    # defaults by model name when not given) and run inference on a caption list.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config_yaml",
        type=str,
        required=False,
        default=None,
        help="path to config .yaml file",
    )
    parser.add_argument(
        "-m",
        "--model",
        type=str,
        default='genau-l-full-hq-data',
        required=False,
        help="Model name",
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=0
    )
    parser.add_argument(
        "-cfg",
        "--cfg_weight",
        type=float,
        default=4.0
    )
    parser.add_argument(
        "--n_cand",
        type=int,
        default=3,
        help="number of candidates for clap reranking"
    )
    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=200,
        help="number of ddim steps for sampling"
    )
    parser.add_argument(
        "-l",
        "--list_inference",
        type=str,
        default=None,
        required=True,
        help="The filelist that contain captions (and optionally filenames)",
    )
    parser.add_argument(
        "-ckpt",
        "--reload_from_ckpt",
        type=str,
        required=False,
        help="the checkpoint path for the model. If not provided, the most recent checkpoint from the log folder of the provided caption will be used",
    )
    parser.add_argument(
        "--strategy",
        type=str,
        required=False,
        default='ema',
        help="The strategy of combining weights from different checkpoint: wo_ema, avg_ckpt, or ema",
    )
    args = parser.parse_args()

    # Fail loudly (even under `python -O`, where asserts are stripped).
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available")

    # Fall back to the released assets for the chosen model name.
    if args.config_yaml is None:
        args.config_yaml = get_checkpoint_path(f"{args.model}_config")
    if args.reload_from_ckpt is None:
        args.reload_from_ckpt = get_checkpoint_path(args.model)

    config_yaml = args.config_yaml

    # --list_inference is required, so it is always present here.
    dataset_json = build_dataset_json_from_list(args.list_inference)

    # Derive experiment naming from the config path:
    #   <...>/<exp_group_name>/<exp_name>.yaml
    # Use splitext on the basename (not split(".") on the full path) so dots in
    # directory names (e.g. ~/.cache/...) cannot corrupt the experiment name.
    exp_name = os.path.splitext(os.path.basename(config_yaml))[0]
    exp_group_name = os.path.basename(os.path.dirname(config_yaml))

    configuration = Configuration(config_yaml)
    configs = configuration.get_config()

    if args.reload_from_ckpt is not None:
        configs["reload_from_ckpt"] = args.reload_from_ckpt

    infer(dataset_json=dataset_json,
          configs=configs,
          config_yaml_path=args.config_yaml,
          exp_group_name=exp_group_name,
          exp_name=exp_name,
          seed=args.seed,
          n_cand=args.n_cand,
          ddim_steps=args.ddim_steps,
          cfg_weight=args.cfg_weight,
          strategy=args.strategy)