# DvD / datasets /utils /general_utils.py
# hanquansanren's picture
# Add application file
# 05fb4ab
from PIL import Image
from pathlib import Path
from typing import *
from einops import rearrange
import numpy as np
import torch
import os
import shutil
import cv2
# from util.image import scale_image, random_tight_crop, tight_crop_image,random_tight_crop_imgonly
# from util.load import load_array, load_image,load_npz
# from util.mapping import scale_map, tight_crop_map, tight_crop_map_docaligner
# from inv3d_util.mask import scale_mask, tight_crop_mask
from torchvision.utils import save_image as tv_save_image
import torch.nn.functional as F
def training_init(args):
    """Set up NCCL distributed training for the current process.

    Reads the local rank from the ``LOCAL_RANK`` environment variable, joins
    the default process group through the ``env://`` rendezvous, seeds all
    CUDA generators with 40, makes the rank's GPU the current device and
    empties the CUDA cache.

    Args:
        args: Not used by this function.

    Returns:
        The local rank as an ``int``.

    Raises:
        KeyError: If ``LOCAL_RANK`` is missing from the environment.
    """
    rank = int(os.environ["LOCAL_RANK"])
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    # The device object is built but not used afterwards.
    _device = torch.device(f'cuda:{rank}')
    torch.cuda.manual_seed_all(40)
    # Bind this process to its GPU and release cached allocator memory.
    torch.cuda.set_device(rank)
    torch.cuda.empty_cache()
    return rank
# Lower-case file suffixes that are treated as images.
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')


def has_file_allowed_extension(filename, extensions):
    """Check whether *filename* ends with one of the given extensions.

    The file name is lower-cased before the comparison, so the entries of
    *extensions* are expected to be lower-case.

    Args:
        filename: File name or path as a string.
        extensions: A single suffix string, or any iterable of suffix strings
            (tuple, list, set, ...).

    Returns:
        bool: True if the lower-cased file name ends with one of the suffixes.
    """
    # str.endswith only accepts a str or a tuple of str; normalise other
    # iterables (e.g. lists) so they do not raise TypeError.
    if not isinstance(extensions, str):
        extensions = tuple(extensions)
    return filename.lower().endswith(extensions)
def is_image_file(filename):
    """Return True when *filename* carries one of the ``IMG_EXTENSIONS`` suffixes.

    Delegates to ``has_file_allowed_extension``.
    """
    allowed = IMG_EXTENSIONS
    return has_file_allowed_extension(filename, allowed)
def pil_loader(path):
    """Open the image at *path* with PIL and return it converted to RGB.

    The file is opened in binary mode and closed again once the converted
    image has been produced.
    """
    with open(path, 'rb') as stream:
        return Image.open(stream).convert('RGB')
def pil_loader_withHW(path):
    """Open the image at *path* with PIL and return it with its size.

    Returns:
        A ``(image, height, width)`` tuple, where ``image`` is the RGB
        conversion of the file and the size is read from PIL's ``size``
        attribute (which is ordered width, height).
    """
    with open(path, 'rb') as stream:
        picture = Image.open(stream)
        size = picture.size  # PIL reports (width, height)
        return picture.convert('RGB'), size[1], size[0]
def cv2_loader_withHW(path):
    """Read an image with OpenCV and return it in RGB order with its size.

    Args:
        path: Path of the image file, passed straight to ``cv2.imread``.

    Returns:
        A ``(image, height, width)`` tuple; ``image`` is the array returned
        by ``cv2.imread`` (colour mode) after BGR -> RGB conversion.

    Raises:
        OSError: If ``cv2.imread`` could not read or decode the file.
    """
    img = cv2.imread(path, flags=cv2.IMREAD_COLOR)
    # cv2.imread signals failure by returning None instead of raising; fail
    # here with a clear message rather than inside cvtColor.
    if img is None:
        raise OSError(f"cv2.imread could not read or decode image: {path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    height, width = img.shape[:2]
    return img, height, width
def collate_batch(batch_list):
    """Regroup a batch of samples into three parallel lists.

    Each sample is indexed at positions 0, 1 and 2; no stacking or tensor
    conversion is performed.

    Args:
        batch_list: Sequence of samples, each indexable at 0, 1 and 2.

    Returns:
        A tuple of three lists: all elements at index 0, all at index 1 and
        all at index 2, in batch order.
    """
    return tuple(
        [sample[field] for sample in batch_list] for field in range(3)
    )
# data1 = [item[0] for item in batch_list]
# labels = [item[1] for item in batch_list]
# return data1, labels
# # features = torch.stack([sample[0] for sample in batch])
# # labels = torch.stack([sample[1] for sample in batch])
# # return features,labels
def select_max_region(mask):
    """Keep only the largest 8-connected foreground component of *mask*.

    Args:
        mask: Image passed to ``cv2.connectedComponentsWithStats``
            (callers in this file pass a ``uint8`` array).

    Returns:
        An integer array shaped like the label image, holding 1 inside the
        component with the largest pixel area and 0 elsewhere.

    Raises:
        ValueError: If the mask contains no foreground component.
    """
    _, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
    # OpenCV always assigns label 0 to the background, so foreground stats
    # start at row 1. Guessing the background from "bbox starts at (0, 0)"
    # mislabels masks whose foreground touches the top-left corner.
    if stats.shape[0] < 2:
        raise ValueError("select_max_region: mask has no foreground component")
    largest_label = 1 + int(stats[1:, cv2.CC_STAT_AREA].argmax())
    return np.where(labels == largest_label, 1, 0)
def docreg_bm_norm(file: str, resolution: Optional[int]):
    """Load a backward map from an ``.npz`` archive and rescale it.

    The array is stored under the key equal to the file name without its
    last four characters; its first element is taken, mapped from
    ``(x + 1) / 2 * resolution``, resized with ``scale_map`` and returned
    channels-first.

    NOTE(review): ``scale_map`` is not imported in the visible part of this
    file (its import is commented out) - confirm it is in scope.

    Args:
        file: Path of the ``.npz`` archive.
        resolution: Target resolution; used as a multiplier, so it must not
            be None despite the ``Optional`` annotation.

    Returns:
        A float ``torch.Tensor`` laid out as (c, h, w).
    """
    file = Path(file)
    # Close the archive after reading; np.load keeps the file handle open
    # for .npz inputs until the NpzFile is closed.
    with np.load(file) as archive:
        bm = archive[file.name[:-4]][0]  # author's note: (4032, 3024, 2)
    bm = ((bm + 1) / 2) * resolution  # presumably [-1, 1] -> [0, resolution]
    bm = scale_map(bm, resolution)  # author's note: (288, 288, 2)
    bm = rearrange(bm, "h w c -> c h w")
    return torch.from_numpy(bm).float()
def prepare_image(
    image_file: Path, mask_file: Path, color_jitter: bool, **scale_settings
):
    """Load an image, crop it around its largest mask region and scale it.

    NOTE(review): ``load_image``, ``load_npz``, ``random_tight_crop_imgonly``
    and ``scale_image`` are not imported in the visible part of this file
    (their imports are commented out) - confirm they are in scope.

    Args:
        image_file: Path of the image to load.
        mask_file: Path of the ``.npz`` mask; only its first channel is used.
        color_jitter: Currently unused (the jitter step is disabled).
        **scale_settings: Forwarded unchanged to ``scale_image``.

    Returns:
        ``(image, t, b, l, r, H, W)`` where ``image`` is a float32 tensor in
        (c, h, w) layout divided by 255, ``t, b, l, r`` are the values
        returned by ``random_tight_crop_imgonly`` and ``H, W`` are the
        height and width of the image before cropping.

    Raises:
        AssertionError: If the crop has zero height or width.
    """
    image = load_image(image_file)
    mask = load_npz(mask_file)[..., :1].astype(np.uint8)
    mask = select_max_region(mask)
    H, W, _ = image.shape

    image, t, b, l, r = random_tight_crop_imgonly(mask, image, H, W)
    # Use a real message: `assert cond, print(...)` would attach None.
    assert image.shape[0] > 0, f"crop {image.shape}"
    assert image.shape[1] > 0, f"crop {image.shape}"

    image = scale_image(image, **scale_settings)
    image = rearrange(image, "h w c -> c h w")
    image = image.astype("float32") / 255
    image = torch.from_numpy(image)
    return image, t, b, l, r, H, W
def prepare_masked_image(
    image_file: Path, recon_file: Path, transfrom, uv_file: Path, color_jitter: bool, **scale_settings
):
    """Load an image and its reconstruction, crop both and transform them.

    NOTE(review): ``load_array``, ``load_image`` and ``random_tight_crop``
    are not imported in the visible part of this file (their imports are
    commented out) - confirm they are in scope.

    Args:
        image_file: Path of the input image.
        recon_file: Path of the reconstruction image.
        transfrom: Callable applied to each cropped image after conversion
            to a PIL image.
        uv_file: Path of the array whose first channel is used as the mask.
        color_jitter: Currently unused (the jitter step is disabled).
        **scale_settings: Currently unused (the scaling step is disabled).

    Returns:
        ``(image, recon, t, b, l, r)``: the two transformed images followed
        by the values returned by ``random_tight_crop``.

    Raises:
        AssertionError: If the crop has zero height or width.
    """
    mask = load_array(uv_file)[..., :1]
    image = load_image(image_file)
    recon = load_image(recon_file)

    image, recon, t, b, l, r = random_tight_crop(mask, image, recon)
    # Use a real message: `assert cond, print(...)` would attach None.
    assert image.shape[0] > 0, f"crop {image.shape}"
    assert image.shape[1] > 0, f"crop {image.shape}"

    img_pil = Image.fromarray(image)
    recon_pil = Image.fromarray(recon)
    image = transfrom(img_pil)
    recon = transfrom(recon_pil)
    return image, recon, t, b, l, r
def prepare_bm_inv3d(file: Path, resolution: Optional[int], t,b,l,r):
    """Load a backward map, crop it, resize it and return it as a tensor.

    NOTE(review): ``load_array``, ``tight_crop_map`` and ``scale_map`` are
    not imported in the visible part of this file - confirm they are in scope.

    Args:
        file: Path of the map file; its suffix must be .npz, .mat or .npy.
        resolution: Target resolution, also used as the final multiplier.
        t, b, l, r: Crop arguments forwarded to ``tight_crop_map``.

    Returns:
        A float ``torch.Tensor`` in (c, h, w) layout, multiplied by
        ``resolution``.
    """
    file = Path(file)
    assert file.suffix in (".npz", ".mat", ".npy")

    raw_map = load_array(file).astype("float32") * 1600.0  # author's note: (512, 512, 2), 0-1600
    cropped = tight_crop_map(raw_map, t, b, l, r)  # author's note: cropped and back in 0-1
    resized = scale_map(cropped, resolution)  # author's note: (288, 288, 2)
    # Rotate the last axis by one position; the original comment said
    # "switch back to x first, then y".
    swapped = np.roll(resized, shift=1, axis=-1)
    channels_first = rearrange(swapped, "h w c -> c h w")
    return torch.from_numpy(channels_first).float() * resolution  # author's note: 0-288
def prepare_bm_docregis(file: Path, resolution: Optional[int], t,b,l,r, H,W):
    """Load a backward map, scale it to pixels, crop, resize and tensorise it.

    NOTE(review): ``load_array``, ``tight_crop_map_docaligner`` and
    ``scale_map`` are not imported in the visible part of this file -
    confirm they are in scope.

    Args:
        file: Path of the map file; its suffix must be .npz, .mat or .npy.
        resolution: Target resolution, also used as the final multiplier.
        t, b, l, r: Crop arguments forwarded to ``tight_crop_map_docaligner``.
        H, W: Factors applied to channel 0 and channel 1 respectively
            before cropping, and forwarded to the crop helper.

    Returns:
        A float ``torch.Tensor`` in (c, h, w) layout, multiplied by
        ``resolution``.
    """
    file = Path(file)
    assert file.suffix in (".npz", ".mat", ".npy")

    raw_map = load_array(file).astype("float32")  # author's note: (512, 512, 2), values in (0, 1)
    # Scale the two channels in place by the original image size.
    raw_map[..., 0] *= H
    raw_map[..., 1] *= W
    cropped = tight_crop_map_docaligner(raw_map, t, b, l, r, H, W)  # author's note: cropped and back in 0-1
    resized = scale_map(cropped, resolution)  # author's note: (288, 288, 2)
    # Rotate the last axis by one position; the original comment said
    # "switch back to x first, then y".
    swapped = np.roll(resized, shift=1, axis=-1)
    channels_first = rearrange(swapped, "h w c -> c h w")
    return torch.from_numpy(channels_first).float() * resolution  # author's note: 0-288
class Averager(object):
    """Compute average for torch.Tensor, used for loss average.

    Two accumulation styles share the same ``sum`` / ``n_count`` state:

    * ``add(tensor)`` followed by ``val()``: element-wise mean over
      everything added so far.
    * ``update(value, n)`` followed by reading ``avg``: weighted running mean.

    ``avg`` is always defined (0 after ``reset``) and is kept in sync by
    both ``add`` and ``update``.
    """

    def __init__(self):
        self.reset()

    def add(self, v):
        """Accumulate every element of tensor *v* into the running sum."""
        self.n_count += v.data.numel()
        self.sum += v.data.sum()
        # Guard against an empty tensor being the only thing added so far.
        if self.n_count != 0:
            self.avg = self.sum / self.n_count

    def reset(self):
        """Clear the accumulated sum, count and average."""
        self.n_count = 0
        self.sum = 0
        # Initialise avg here so it exists before the first update() and
        # does not keep a stale value across resets.
        self.avg = 0

    def val(self):
        """Return sum / count, or 0 when nothing has been accumulated."""
        if self.n_count == 0:
            return 0
        return self.sum / float(self.n_count)

    def update(self, val, n=1):
        """Accumulate *val* with weight *n* and refresh ``avg``.

        NOTE(review): this stores the latest value in ``self.val``, which
        shadows the ``val()`` method on this instance; after calling
        ``update`` the instance's ``val`` is a plain value, not a callable.
        Kept as-is because callers may read ``.val`` as an attribute -
        do not mix ``update`` with ``val()`` on the same instance.
        """
        self.val = val
        self.sum += val * n
        self.n_count += n
        self.avg = self.sum / self.n_count
# class AverageMeter(object):
# """Computes and stores the average and current value"""
# def __init__(self):
# self.reset()
# def reset(self):
# self.val = 0
# self.avg = 0
# self.sum = 0
# self.count = 0
# def update(self, val, n=1):
# self.val = val
# self.sum += val * n
# self.count += n
# self.avg = self.sum / self.count
def save_checkpoint(state, is_best, epoch, checkpoint_name, filename='checkpoint.pth.tar'):
    """Write *state* under ``checkpoints/<checkpoint_name>/``.

    The state is always saved as ``<filename>_latest.pth.tar``. That file is
    additionally copied to ``<filename>_<epoch>.pth.tar`` when *epoch* is a
    multiple of 10, and to ``<filename>_best.pth.tar`` when *is_best* is true.

    Args:
        state: Object handed to ``torch.save``.
        is_best: Whether to also keep a ``_best`` copy.
        epoch: Current epoch number.
        checkpoint_name: Sub-directory name below ``checkpoints/``.
        filename: Prefix of the written file names.
    """
    target_dir = "checkpoints/{}".format(checkpoint_name)
    if not os.path.exists(target_dir):
        os.makedirs(target_dir, exist_ok=True)

    prefix = target_dir + '/' + filename
    latest_path = prefix + '_latest.pth.tar'
    torch.save(state, latest_path)

    # Periodic snapshot every tenth epoch (epoch 0 included).
    if epoch % 10 == 0:
        shutil.copyfile(latest_path, prefix + '_%d.pth.tar' % (epoch))
    if is_best:
        shutil.copyfile(latest_path, prefix + '_best.pth.tar')
class EarlyStopping:
    """Stop training when the validation loss stops improving.

    Call the instance once per validation round. Whenever the loss improves
    (by more than ``delta``), the model's ``state_dict`` is saved to
    ``<path>/checkpoint.pth`` and the patience counter is reset; otherwise
    the counter is incremented and ``early_stop`` becomes True once it
    reaches ``patience``.
    """

    def __init__(self, patience=7, verbose=False, delta=0):
        """
        Args:
            patience: Number of consecutive non-improving calls tolerated.
            verbose: If true, print a message each time a checkpoint is saved.
            delta: Margin added to the best score when testing for improvement.
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.Inf was removed in NumPy 2.0; np.inf works on every version.
        self.val_loss_min = np.inf
        self.delta = delta

    def __call__(self, val_loss, model, path, ret=None, opt=None):
        """Record *val_loss* and checkpoint *model* if it improved.

        ``ret`` and ``opt`` are only passed through to ``save_checkpoint``.
        """
        # Higher score is better, so negate the loss.
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model, path, ret, opt)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, path, ret, opt)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, path, ret, opt):
        """Save ``model.state_dict()`` to ``<path>/checkpoint.pth``.

        ``ret`` and ``opt`` are accepted but not used.
        """
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), path + '/' + 'checkpoint.pth')
        self.val_loss_min = val_loss