import argparse
import binascii
import os
import os.path as osp
from sys import argv

import imageio
import torch
import torchvision

__all__ = ['cache_video', 'cache_image', 'str2bool']


def get_arguments(args=argv[1:]):
    parser = get_argument_parser()
    args = parser.parse_args(args)

    # Fall back to launcher-provided environment variables when --local_rank
    # is not given explicitly (torchrun exports LOCAL_RANK, Slurm exports
    # SLURM_LOCALID).
    if getattr(args, "local_rank", -1) == -1:
        env_lr = os.environ.get("LOCAL_RANK") or os.environ.get("SLURM_LOCALID")
        if env_lr is not None:
            try:
                args.local_rank = int(env_lr)
            except ValueError:
                pass

    # CPU-only mode is not supported.
    args.no_cuda = False

    # Bind this process to its GPU as early as possible.
    if torch.cuda.is_available() and getattr(args, "local_rank", -1) >= 0:
        try:
            torch.cuda.set_device(args.local_rank % torch.cuda.device_count())
        except Exception:
            pass

    return args


def get_argument_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config-file",
                        type=str,
                        default="ovi/configs/inference/inference_fusion.yaml",
                        help="path to the inference config YAML")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local rank for distributed training on GPUs")
    return parser
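

# Illustrative entry-point usage (not executed by this module): `get_arguments()`
# parses sys.argv[1:] by default, so a launcher script might do
#
#     args = get_arguments()
#     cfg_path = args.config_file    # defaults to the fusion inference YAML above
#     local_rank = args.local_rank   # -1 unless set via --local_rank or LOCAL_RANK/SLURM_LOCALID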


def rand_name(length=8, suffix=''):
    # `length` is the number of random bytes; the hex-encoded name is twice
    # as many characters.
    name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
    if suffix:
        if not suffix.startswith('.'):
            suffix = '.' + suffix
        name += suffix
    return name
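

# For example, rand_name(suffix='mp4') returns something like
# '3f9c0a7b1d2e4c5a.mp4': 8 random bytes become 16 hex characters, and the
# leading dot is added automatically when the suffix lacks one.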


def cache_video(tensor,
                save_file=None,
                fps=30,
                suffix='.mp4',
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):
    # Pick the output path: an explicit save_file, or a random name in /tmp.
    cache_file = osp.join('/tmp', rand_name(
        suffix=suffix)) if save_file is None else save_file

    # Encode with a few retries; transient I/O or encoder errors are retried.
    error = None
    for _ in range(retry):
        try:
            # Clamp to the expected range, tile the batch into an nrow-wide
            # grid per frame, and convert to [F, H, W, C] uint8 for the encoder.
            tensor = tensor.clamp(min(value_range), max(value_range))
            tensor = torch.stack([
                torchvision.utils.make_grid(
                    u, nrow=nrow, normalize=normalize, value_range=value_range)
                for u in tensor.unbind(2)
            ],
                                 dim=1).permute(1, 2, 3, 0)
            tensor = (tensor * 255).type(torch.uint8).cpu()

            # Write the frames with imageio's ffmpeg backend.
            writer = imageio.get_writer(
                cache_file, fps=fps, codec='libx264', quality=8)
            for frame in tensor.numpy():
                writer.append_data(frame)
            writer.close()
            return cache_file
        except Exception as e:
            error = e
            continue
    else:
        # The for-else branch runs only when every retry failed.
        print(f'cache_video failed, error: {error}', flush=True)
        return None
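

# Usage sketch (hypothetical names; the `unbind(2)` above implies a batched
# video tensor shaped [B, C, F, H, W] with values in `value_range`):
#
#     path = cache_video(video, save_file='sample.mp4', fps=24)
#     if path is None:
#         print('video encoding failed')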


def cache_image(tensor,
                save_file,
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):
    # Normalize the extension; fall back to .png for unsupported formats.
    suffix = osp.splitext(save_file)[1]
    if suffix.lower() not in [
            '.jpg', '.jpeg', '.png', '.tiff', '.gif', '.webp'
    ]:
        suffix = '.png'
        save_file = osp.splitext(save_file)[0] + suffix

    # Save with a few retries, mirroring cache_video.
    error = None
    for _ in range(retry):
        try:
            tensor = tensor.clamp(min(value_range), max(value_range))
            torchvision.utils.save_image(
                tensor,
                save_file,
                nrow=nrow,
                normalize=normalize,
                value_range=value_range)
            return save_file
        except Exception as e:
            error = e
            continue
    else:
        # Runs only when every retry failed.
        print(f'cache_image failed, error: {error}', flush=True)
        return None
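

# Usage sketch (hypothetical names): `tensor` is whatever
# torchvision.utils.save_image accepts, i.e. a [C, H, W] image or a
# [B, C, H, W] batch tiled into an nrow-wide grid:
#
#     cache_image(image, save_file='sample.png')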


def str2bool(v):
    """
    Convert a string to a boolean.

    Supported true values: 'yes', 'true', 't', 'y', '1'
    Supported false values: 'no', 'false', 'f', 'n', '0'

    Args:
        v (str): String to convert.

    Returns:
        bool: Converted boolean value.

    Raises:
        argparse.ArgumentTypeError: If the value cannot be converted to a boolean.
    """
    if isinstance(v, bool):
        return v
    v_lower = v.lower()
    if v_lower in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v_lower in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected (True/False)')
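

# Illustrative argparse wiring (the flag name is hypothetical): passing
# str2bool as `type` turns CLI strings into real booleans instead of
# truthy non-empty strings.
#
#     parser.add_argument('--use_fp16', type=str2bool, default=False)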