from PIL import Image
from pathlib import Path
from typing import Optional
from einops import rearrange
import numpy as np
import torch
import os
import shutil
import cv2
# Project-local helpers used throughout this module.
from util.image import scale_image, random_tight_crop, tight_crop_image, random_tight_crop_imgonly
from util.load import load_array, load_image, load_npz
from util.mapping import scale_map, tight_crop_map, tight_crop_map_docaligner
# from inv3d_util.mask import scale_mask, tight_crop_mask  # currently unused
from torchvision.utils import save_image as tv_save_image
import torch.nn.functional as F

def training_init(args):
    """Initialise NCCL distributed training and bind this process to its GPU."""
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.distributed.init_process_group('nccl', init_method='env://')
    device = torch.device(f'cuda:{local_rank}')
    torch.cuda.manual_seed_all(40)
    torch.cuda.set_device(local_rank)
    torch.cuda.empty_cache()
    return local_rank
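
# Example usage (a sketch): launch with `torchrun --nproc_per_node=4 train.py`
# and, inside train.py, wrap the model for DDP (MyModel is a placeholder):
#   local_rank = training_init(args)
#   model = MyModel().cuda(local_rank)
#   model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])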

IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')


def has_file_allowed_extension(filename, extensions):
    return filename.lower().endswith(extensions)


def is_image_file(filename):
    return has_file_allowed_extension(filename, IMG_EXTENSIONS)

def pil_loader(path):
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


def pil_loader_withHW(path):
    with open(path, 'rb') as f:
        img = Image.open(f)
        width, height = img.size
        return img.convert('RGB'), height, width


def cv2_loader_withHW(path):
    img = cv2.imread(path, flags=cv2.IMREAD_COLOR)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; convert to RGB
    height, width, channels = img.shape
    return img, height, width

def collate_batch(batch_list):
    """Collate variable-sized samples as plain lists instead of stacked tensors."""
    data1 = [item[0] for item in batch_list]
    data2 = [item[1] for item in batch_list]
    labels = [item[2] for item in batch_list]
    return data1, data2, labels
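
# Example (a sketch, assuming `dataset` yields 3-tuples of varying spatial size):
#   loader = torch.utils.data.DataLoader(dataset, batch_size=8,
#                                        collate_fn=collate_batch)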

def select_max_region(mask):
    """Keep only the largest connected component of a binary {0, 1} mask."""
    nums, labels, stats, centroids = cv2.connectedComponentsWithStats(mask, connectivity=8)
    # Assume the background component is the one whose bounding box starts at (0, 0).
    background = 0
    for row in range(stats.shape[0]):
        if stats[row, :][0] == 0 and stats[row, :][1] == 0:
            background = row
    stats_no_bg = np.delete(stats, background, axis=0)
    max_idx = stats_no_bg[:, 4].argmax()  # column 4 holds the component area
    max_region = np.where(labels == max_idx + 1, 1, 0)
    return max_region
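
# A minimal usage sketch; the mask path is hypothetical.
def _example_largest_region(mask_path='mask.png'):
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    mask = (mask > 127).astype(np.uint8)  # binarise to {0, 1} as expected above
    return select_max_region(mask)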

def docreg_bm_norm(file: str, resolution: Optional[int]):
    file = Path(file)
    bm = np.load(file)[file.name[:-4]][0]  # e.g. (4032, 3024, 2) array in [-1, 1]
    bm = ((bm + 1) / 2) * resolution       # rescale to [0, resolution]
    bm = scale_map(bm, resolution)         # (resolution, resolution, 2)
    bm = rearrange(bm, "h w c -> c h w")   # (2, resolution, resolution)
    bm = torch.from_numpy(bm).float()      # tensor, values in [0, resolution]
    return bm

def prepare_image(
    image_file: Path, mask_file: Path, color_jitter: bool, **scale_settings
):
    image = load_image(image_file)  # e.g. (1770, 1327, 3)
    mask = load_npz(mask_file)[..., :1].astype(np.uint8)
    mask = select_max_region(mask)
    H, W, _ = image.shape
    # Randomly tight-crop the image around the document mask.
    image, t, b, l, r = random_tight_crop_imgonly(mask, image, H, W)
    assert image.shape[0] > 0, f"empty crop: {image.shape}"
    assert image.shape[1] > 0, f"empty crop: {image.shape}"
    image = scale_image(image, **scale_settings)  # e.g. (288, 288, 3)
    # NOTE: color_jitter is currently unused; an earlier version applied
    # transforms.ColorJitter(0.2, 0.2, 0.2, 0.2) here when the flag was set.
    image = rearrange(image, "h w c -> c h w")
    image = image.astype("float32") / 255
    image = torch.from_numpy(image)
    return image, t, b, l, r, H, W

def prepare_masked_image(
    image_file: Path, recon_file: Path, transfrom, uv_file: Path, color_jitter: bool, **scale_settings
):
    mask = load_array(uv_file)[..., :1]  # e.g. (448, 448, 1)
    image = load_image(image_file)       # e.g. (448, 448, 3)
    recon = load_image(recon_file)       # e.g. (448, 448, 3)
    # Crop the image and its reconstruction jointly around the UV mask.
    image, recon, t, b, l, r = random_tight_crop(mask, image, recon)
    assert image.shape[0] > 0, f"empty crop: {image.shape}"
    assert image.shape[1] > 0, f"empty crop: {image.shape}"
    # NOTE: color_jitter is currently unused here as well.
    img_pil = Image.fromarray(image)
    recon_pil = Image.fromarray(recon)
    image = transfrom(img_pil)
    recon = transfrom(recon_pil)
    return image, recon, t, b, l, r

def prepare_bm_inv3d(file: Path, resolution: Optional[int], t, b, l, r):
    file = Path(file)
    assert file.suffix in [".npz", ".mat", ".npy"]
    bm = load_array(file).astype("float32") * 1600.0  # (512, 512, 2), values in [0, 1600]
    bm = tight_crop_map(bm, t, b, l, r)               # crop, renormalised to [0, 1]
    bm = scale_map(bm, resolution)                    # (resolution, resolution, 2)
    bm = np.roll(bm, shift=1, axis=-1)                # swap the x and y channels
    bm = rearrange(bm, "h w c -> c h w")
    bm = torch.from_numpy(bm).float() * resolution    # values in [0, resolution]
    return bm

def prepare_bm_docregis(file: Path, resolution: Optional[int], t, b, l, r, H, W):
    file = Path(file)
    assert file.suffix in [".npz", ".mat", ".npy"]
    bm = load_array(file).astype("float32")  # (512, 512, 2), values in [0, 1]
    bm[..., 0] *= H
    bm[..., 1] *= W
    bm = tight_crop_map_docaligner(bm, t, b, l, r, H, W)  # crop, renormalised to [0, 1]
    bm = scale_map(bm, resolution)                        # (resolution, resolution, 2)
    bm = np.roll(bm, shift=1, axis=-1)                    # swap the x and y channels
    bm = rearrange(bm, "h w c -> c h w")                  # (2, resolution, resolution)
    bm = torch.from_numpy(bm).float() * resolution        # values in [0, resolution]
    return bm
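
# Sketch of how a backward map produced above can unwarp an image with
# torch.nn.functional.grid_sample; `img` is a hypothetical (1, 3, H, W) tensor
# and `bm` a (2, H, W) map in [0, resolution] from prepare_bm_inv3d/_docregis.
def _example_unwarp(img, bm, resolution):
    grid = bm / resolution * 2 - 1              # grid_sample expects [-1, 1]
    grid = rearrange(grid, 'c h w -> 1 h w c')  # (1, H, W, 2), last dim (x, y)
    return F.grid_sample(img, grid, align_corners=False)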

class Averager(object):
    """Compute the average of torch.Tensor values, used for loss averaging."""

    def __init__(self):
        self.reset()

    def add(self, v):
        count = v.data.numel()
        v = v.data.sum()
        self.n_count += count
        self.sum += v

    def reset(self):
        self.n_count = 0
        self.sum = 0

    def val(self):
        res = 0
        if self.n_count != 0:
            res = self.sum / float(self.n_count)
        return res

class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
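
# Example (a sketch):
#   meter = AverageMeter()
#   for loss in [0.9, 0.7, 0.5]:  # placeholder per-batch losses
#       meter.update(loss)
#   print(meter.avg)              # running mean, here 0.7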

def save_checkpoint(state, is_best, epoch, checkpoint_name, filename='checkpoint.pth.tar'):
    ckpt_dir = 'checkpoints/{}'.format(checkpoint_name)
    os.makedirs(ckpt_dir, exist_ok=True)
    latest = os.path.join(ckpt_dir, filename + '_latest.pth.tar')
    torch.save(state, latest)
    if epoch % 10 == 0:
        # Keep a periodic snapshot every 10 epochs.
        shutil.copyfile(latest, os.path.join(ckpt_dir, filename + '_%d.pth.tar' % epoch))
    if is_best:
        shutil.copyfile(latest, os.path.join(ckpt_dir, filename + '_best.pth.tar'))
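
# Example (a sketch; the run name is hypothetical):
#   save_checkpoint({'epoch': epoch, 'state_dict': model.state_dict()},
#                   is_best=val_loss < best_loss, epoch=epoch,
#                   checkpoint_name='docreg_run1')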

class EarlyStopping:
    """Stop training when validation loss stops improving for `patience` calls."""

    def __init__(self, patience=7, verbose=False, delta=0):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf  # np.Inf was removed in NumPy 2.0
        self.delta = delta

    def __call__(self, val_loss, model, path, ret=None, opt=None):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model, path, ret, opt)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, path, ret, opt)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, path, ret, opt):
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
        torch.save(model.state_dict(), os.path.join(path, 'checkpoint.pth'))
        self.val_loss_min = val_loss
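
# Example (a sketch, assuming a standard train/validate loop with
# `train_one_epoch` and `validate` helpers returning the validation loss):
#   stopper = EarlyStopping(patience=7, verbose=True)
#   for epoch in range(max_epochs):
#       train_one_epoch(model, train_loader)
#       val_loss = validate(model, val_loader)
#       stopper(val_loss, model, save_dir)
#       if stopper.early_stop:
#           break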