# DvD/datasets/batch_processing.py
import torch
import torch.nn.functional as F
import torch.utils.data
from packaging import version
def pre_process_image_glunet(source_img, device, mean_vector=[0.485, 0.456, 0.406],
                             std_vector=[0.229, 0.224, 0.225]):
    """
    Images are in range [0, 255]. Rescales them to [0, 1] and creates an additional copy
    resized to 256x256. (The ImageNet mean/std normalisation is kept commented out below.)
    Args:
        source_img: torch tensor, bx3xHxW in range [0, 255], not normalized yet
        device: device on which to put the processed images
        mean_vector: ImageNet channel means (unused while the normalisation is commented out)
        std_vector: ImageNet channel stds (unused while the normalisation is commented out)
    Returns:
        images at original and 256x256 resolution, in range [0, 1]
    """
    # img has shape bx3xhxw
    b, _, h_scale, w_scale = source_img.shape
    source_img_copy = source_img.float().to(device).div(255.0)
    # mean = torch.as_tensor(mean_vector, dtype=source_img_copy.dtype, device=source_img_copy.device)
    # std = torch.as_tensor(std_vector, dtype=source_img_copy.dtype, device=source_img_copy.device)
    # source_img_copy.sub_(mean[:, None, None]).div_(std[:, None, None])
    # resolution 256x256; 'area' interpolation gives anti-aliased downscaling
    source_img_256 = F.interpolate(input=source_img.float().to(device).div(255.0),
                                   size=(256, 256), mode='area')
    # source_img_256.sub_(mean[:, None, None]).div_(std[:, None, None])
    return source_img_copy, source_img_256
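
# The commented-out calls to `unormalise_and_convert_mapping_to_flow` further down refer to a
# helper that is neither defined nor imported in this file. The following is a minimal sketch
# of what such a conversion typically does, assuming the correspondence map stores coordinates
# normalised to [-1, 1] in the grid_sample convention; the name and the convention are
# assumptions for illustration, not part of this repository.
def unormalise_and_convert_mapping_to_flow_sketch(mapping):
    """Sketch: convert a bx2xHxW correspondence map in [-1, 1] to a pixel flow field."""
    b, _, h, w = mapping.shape
    # un-normalise from [-1, 1] to pixel coordinates in [0, w-1] / [0, h-1]
    map_x = (mapping[:, 0] + 1.0) * float(w - 1) / 2.0  # bxHxW
    map_y = (mapping[:, 1] + 1.0) * float(h - 1) / 2.0  # bxHxW
    # coordinate grid of the target image, broadcast over the batch
    grid_x = torch.arange(w, dtype=mapping.dtype, device=mapping.device).view(1, 1, -1)
    grid_y = torch.arange(h, dtype=mapping.dtype, device=mapping.device).view(1, -1, 1)
    # flow = where each pixel maps to, minus where it sits
    return torch.stack((map_x - grid_x, map_y - grid_y), dim=1)  # bx2xHxW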

class CATsBatchPreprocessing:
    """ Class responsible for processing the mini-batch to create the desired training inputs for CATs / GLU-Net
    based networks. In particular, from the source and target images at original resolution, as well as the
    corresponding ground-truth flow field, it creates the source, target and flow at 256x256 resolution for
    training the L-Net.
    """
def __init__(self, settings, apply_mask=False, apply_mask_zero_borders=False, sparse_ground_truth=False,
mapping=False):
"""
Args:
settings: settings
apply_mask: apply ground-truth correspondence mask for loss computation?
apply_mask_zero_borders: apply mask zero borders (equal to 0 at black borders in target image) for loss
computation?
sparse_ground_truth: is ground-truth sparse? Important for downscaling/upscaling of the flow field
mapping: load correspondence map instead of flow field?
"""
self.apply_mask = apply_mask
self.apply_mask_zero_borders = apply_mask_zero_borders
self.sparse_ground_truth = sparse_ground_truth
self.device = getattr(settings, 'device', None)
if self.device is None:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.mapping = mapping
    def __call__(self, mini_batch, *args, **kwargs):
        """
        Args:
            mini_batch: The input data, should contain the following fields:
                'source_image', 'target_image', 'correspondence_mask',
                'flow_map' if self.mapping is False else 'correspondence_map_pyro',
                'mask_zero_borders' if self.apply_mask_zero_borders
        Returns:
            TensorDict: output data block with the following fields:
                'source_image', 'target_image', 'source_image_256', 'target_image_256', 'flow_map',
                'flow_map_256', 'mask', 'mask_256', 'correspondence_mask'
        """
source_image, source_image_256 = pre_process_image_glunet(mini_batch['source_image'], self.device)
target_image, target_image_256 = pre_process_image_glunet(mini_batch['target_image'], self.device)
# At original resolution
if self.sparse_ground_truth:
flow_gt_original = mini_batch['flow_map'][0].to(self.device)
flow_gt_256 = mini_batch['flow_map'][1].to(self.device)
if flow_gt_original.shape[1] != 2:
# shape is bxhxwx2
flow_gt_original = flow_gt_original.permute(0, 3, 1, 2)
if flow_gt_256.shape[1] != 2:
# shape is bxhxwx2
flow_gt_256 = flow_gt_256.permute(0, 3, 1, 2)
else:
            if self.mapping:
                mapping_gt_original = mini_batch['correspondence_map_pyro'].to(self.device)
                # NOTE: the mapping-to-flow conversion is commented out, so with mapping=True
                # flow_gt_original is undefined below; re-enable a conversion such as
                # unormalise_and_convert_mapping_to_flow_sketch above to use this path.
                # flow_gt_original = unormalise_and_convert_mapping_to_flow(mapping_gt_original.permute(0,3,1,2))
else:
flow_gt_original = mini_batch['flow_map'].to(self.device)
if flow_gt_original.shape[1] != 2:
# shape is bxhxwx2
flow_gt_original = flow_gt_original.permute(0, 3, 1, 2)
bs, _, h_original, w_original = flow_gt_original.shape
            # resize the dense flow to 256x256 and rescale the flow values accordingly (bx2x256x256)
            flow_gt_256 = F.interpolate(flow_gt_original, (256, 256), mode='bilinear', align_corners=False)
            flow_gt_256[:, 0, :, :] *= 256.0 / float(w_original)
            flow_gt_256[:, 1, :, :] *= 256.0 / float(h_original)
bs, _, h_original, w_original = flow_gt_original.shape
bs, _, h_256, w_256 = flow_gt_256.shape
# defines the mask to use during training
mask = None
mask_256 = None
if self.apply_mask_zero_borders:
            # build the mask that removes the black borders of the target image
            mask = mini_batch['mask_zero_borders'].to(self.device)  # bxhxw, torch.uint8
            mask_256 = F.interpolate(mask.unsqueeze(1).float(), (256, 256), mode='bilinear',
                                     align_corners=False).squeeze(1).byte()  # bx256x256; the byte cast truncates to {0, 1}
            mask_256 = mask_256.bool() if version.parse(torch.__version__) >= version.parse("1.1") else mask_256.byte()
elif self.apply_mask:
if self.sparse_ground_truth:
mask = mini_batch['correspondence_mask'][0].to(self.device)
mask_256 = mini_batch['correspondence_mask'][1].to(self.device)
else:
mask = mini_batch['correspondence_mask'].to(self.device) # bxhxw, torch.uint8
                mask_256 = F.interpolate(mask.unsqueeze(1).float(), (256, 256), mode='bilinear',
                                         align_corners=False).squeeze(1).byte()  # bx256x256; the byte cast truncates to {0, 1}
                mask_256 = mask_256.bool() if version.parse(torch.__version__) >= version.parse("1.1") else mask_256.byte()
mini_batch['source_image'] = source_image
mini_batch['target_image'] = target_image
mini_batch['source_image_256'] = source_image_256
mini_batch['target_image_256'] = target_image_256
mini_batch['flow_map'] = flow_gt_original
mini_batch['flow_map_256'] = flow_gt_256
mini_batch['mask'] = mask
mini_batch['mask_256'] = mask_256
if self.sparse_ground_truth:
mini_batch['correspondence_mask'][0] = mini_batch['correspondence_mask'][0].to(self.device)
mini_batch['correspondence_mask'][1] = mini_batch['correspondence_mask'][1].to(self.device)
else:
mini_batch['correspondence_mask'] = mini_batch['correspondence_mask'].to(self.device)
return mini_batch
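
# Minimal usage sketch for CATsBatchPreprocessing, assuming a plain namespace stands in
# for the real `settings` object and the batch is synthetic dense data; the helper name
# `_demo_cats_preprocessing` and the shapes are illustrative only.
def _demo_cats_preprocessing():
    from types import SimpleNamespace
    settings = SimpleNamespace(device=torch.device('cpu'))
    preprocess = CATsBatchPreprocessing(settings, apply_mask=True)
    b, h, w = 2, 512, 512
    mini_batch = {
        'source_image': torch.randint(0, 256, (b, 3, h, w)),
        'target_image': torch.randint(0, 256, (b, 3, h, w)),
        'flow_map': torch.randn(b, 2, h, w),
        'correspondence_mask': torch.ones(b, h, w, dtype=torch.uint8),
    }
    out = preprocess(mini_batch)
    # images are rescaled to [0, 1] and duplicated at 256x256; flow values are rescaled too
    assert out['source_image_256'].shape == (b, 3, 256, 256)
    assert out['flow_map_256'].shape == (b, 2, 256, 256)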

class GLUNetBatchPreprocessing:
    """ Class responsible for processing the mini-batch to create the desired training inputs for GLU-Net based
    networks. In particular, from the source and target images at original resolution, as well as the
    corresponding ground-truth flow field, it creates the source, target and flow at 256x256 resolution for
    training the L-Net.
    """
def __init__(self, settings, apply_mask=False, apply_mask_zero_borders=False, sparse_ground_truth=False,
mapping=False, megadepth=False):
"""
Args:
settings: settings
apply_mask: apply ground-truth correspondence mask for loss computation?
apply_mask_zero_borders: apply mask zero borders (equal to 0 at black borders in target image) for loss
computation?
sparse_ground_truth: is ground-truth sparse? Important for downscaling/upscaling of the flow field
mapping: load correspondence map instead of flow field?
"""
self.apply_mask = apply_mask
self.apply_mask_zero_borders = apply_mask_zero_borders
self.sparse_ground_truth = sparse_ground_truth
self.megadepth = megadepth
        self.device = getattr(settings, 'device', None)
        if self.device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print("Using device: {}".format(self.device))
self.mapping = mapping
    def __call__(self, mini_batch, *args, **kwargs):
        """
        Args:
            mini_batch: The input data, should contain the following fields:
                'source_image', 'target_image', 'correspondence_mask',
                'flow_map' if self.mapping is False else 'correspondence_map_pyro',
                'mask_zero_borders' if self.apply_mask_zero_borders
        Returns:
            TensorDict: output data block with the following fields:
                'source_image', 'target_image', 'source_image_256', 'target_image_256', 'flow_map',
                'flow_map_256', 'mask', 'mask_256', 'correspondence_mask'
        """
source_image, source_image_256 = pre_process_image_glunet(mini_batch['source_image'], self.device)
target_image, target_image_256 = pre_process_image_glunet(mini_batch['target_image'], self.device)
# At original resolution
if self.sparse_ground_truth:
flow_gt_original = mini_batch['flow_map'][0].to(self.device)
flow_gt_256 = mini_batch['flow_map'][1].to(self.device)
if flow_gt_original.shape[1] != 2:
# shape is bxhxwx2
flow_gt_original = flow_gt_original.permute(0, 3, 1, 2)
if flow_gt_256.shape[1] != 2:
# shape is bxhxwx2
flow_gt_256 = flow_gt_256.permute(0, 3, 1, 2)
else:
            if self.mapping:
                mapping_gt_original = mini_batch['correspondence_map_pyro'][-1].to(self.device)
                # NOTE: as above, the mapping-to-flow conversion is commented out, so with
                # mapping=True flow_gt_original is undefined below.
                # flow_gt_original = unormalise_and_convert_mapping_to_flow(mapping_gt_original.permute(0,3,1,2))
else:
if self.megadepth:
flow_gt_original = mini_batch['flow_map'][0].to(self.device)
else:
flow_gt_original = mini_batch['flow_map'].to(self.device)
if flow_gt_original.shape[1] != 2:
# shape is bxhxwx2
flow_gt_original = flow_gt_original.permute(0, 3, 1, 2)
bs, _, h_original, w_original = flow_gt_original.shape
            # resize the dense flow to 256x256 and rescale the flow values accordingly (bx2x256x256)
            flow_gt_256 = F.interpolate(flow_gt_original, (256, 256), mode='bilinear', align_corners=False)
            flow_gt_256[:, 0, :, :] *= 256.0 / float(w_original)
            flow_gt_256[:, 1, :, :] *= 256.0 / float(h_original)
bs, _, h_original, w_original = flow_gt_original.shape
bs, _, h_256, w_256 = flow_gt_256.shape
# defines the mask to use during training
mask = None
mask_256 = None
if self.apply_mask_zero_borders:
            # build the mask that removes the black borders of the target image
            mask = mini_batch['mask_zero_borders'].to(self.device)  # bxhxw, torch.uint8
            mask_256 = F.interpolate(mask.unsqueeze(1).float(), (256, 256), mode='bilinear',
                                     align_corners=False).squeeze(1).byte()  # bx256x256; the byte cast truncates to {0, 1}
            mask_256 = mask_256.bool() if version.parse(torch.__version__) >= version.parse("1.1") else mask_256.byte()
elif self.apply_mask:
if self.sparse_ground_truth:
mask = mini_batch['correspondence_mask'][0].to(self.device)
mask_256 = mini_batch['correspondence_mask'][1].to(self.device)
else:
if self.megadepth:
mask = mini_batch['correspondence_mask'][0].to(self.device) # bxhxw, torch.uint8
else:
mask = mini_batch['correspondence_mask'].to(self.device) # bxhxw, torch.uint8
                mask_256 = F.interpolate(mask.unsqueeze(1).float(), (256, 256), mode='bilinear',
                                         align_corners=False).squeeze(1).byte()  # bx256x256; the byte cast truncates to {0, 1}
                mask_256 = mask_256.bool() if version.parse(torch.__version__) >= version.parse("1.1") else mask_256.byte()
mini_batch['source_image'] = source_image
mini_batch['target_image'] = target_image
mini_batch['source_image_256'] = source_image_256
mini_batch['target_image_256'] = target_image_256
mini_batch['flow_map'] = flow_gt_original
mini_batch['flow_map_256'] = flow_gt_256
mini_batch['mask'] = mask
mini_batch['mask_256'] = mask_256
        # 'correspondence_mask' is provided as a list for MegaDepth-style data, as a tensor otherwise
        if self.megadepth:
            mini_batch['correspondence_mask'] = mini_batch['correspondence_mask'][0].to(self.device)
        else:
            mini_batch['correspondence_mask'] = mini_batch['correspondence_mask'].to(self.device)
return mini_batch
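
# Usage sketch for the sparse / MegaDepth-style convention (assumption inferred from the
# indexing above: with sparse_ground_truth=True the loader provides 'flow_map' and
# 'correspondence_mask' as two-element lists, at original and 256x256 resolution; the
# helper name and the synthetic shapes are illustrative only).
def _demo_glunet_sparse_preprocessing():
    from types import SimpleNamespace
    settings = SimpleNamespace(device=torch.device('cpu'))
    preprocess = GLUNetBatchPreprocessing(settings, apply_mask=True,
                                          sparse_ground_truth=True, megadepth=True)
    b, h, w = 2, 384, 512
    mini_batch = {
        'source_image': torch.randint(0, 256, (b, 3, h, w)),
        'target_image': torch.randint(0, 256, (b, 3, h, w)),
        'flow_map': [torch.randn(b, 2, h, w), torch.randn(b, 2, 256, 256)],
        'correspondence_mask': [torch.ones(b, h, w, dtype=torch.uint8),
                                torch.ones(b, 256, 256, dtype=torch.uint8)],
    }
    out = preprocess(mini_batch)
    # sparse GT is passed through at both resolutions rather than re-interpolated
    assert out['flow_map'].shape == (b, 2, h, w)
    assert out['mask_256'].shape == (b, 256, 256)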

class DocBatchPreprocessing:
    """ Class responsible for processing the mini-batch to create the desired training inputs for the document-image
    variant of GLU-Net based networks. In particular, from the source image at original resolution and the
    corresponding ground-truth flow field, it creates the source and flow at 256x256 resolution for training
    the L-Net. No target image is used in this setting.
    """
def __init__(self, settings, apply_mask=False, apply_mask_zero_borders=False, sparse_ground_truth=False,
mapping=False, megadepth=False):
"""
Args:
settings: settings
apply_mask: apply ground-truth correspondence mask for loss computation?
apply_mask_zero_borders: apply mask zero borders (equal to 0 at black borders in target image) for loss
computation?
sparse_ground_truth: is ground-truth sparse? Important for downscaling/upscaling of the flow field
mapping: load correspondence map instead of flow field?
"""
self.apply_mask = apply_mask
self.apply_mask_zero_borders = apply_mask_zero_borders
self.sparse_ground_truth = sparse_ground_truth
self.megadepth = megadepth
        self.device = getattr(settings, 'device', None)
        if self.device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print("Using device: {}".format(self.device))
self.mapping = mapping
def __call__(self, mini_batch, *args, **kwargs):
"""
args:
mini_batch: The input data, should contain the following fields:
'source_image', 'target_image', 'correspondence_mask'
'flow_map' if self.mapping is False else 'correspondence_map_pyro'
'mask_zero_borders' if self.apply_mask_zero_borders
returns:
TensorDict: output data block with following fields:
'source_image', 'target_image', 'source_image_256', 'target_image_256', flow_map',
'flow_map_256', 'mask', 'mask_256', 'correspondence_mask'
"""
        # e.g. [24, 3, 512, 512] and [24, 3, 256, 256]
        source_image, source_image_256 = pre_process_image_glunet(mini_batch['source_image'], self.device)
        # target_image, target_image_256 = pre_process_image_glunet(mini_batch['target_image'], self.device)
        target_image = None  # document unwarping uses no target image
# At original resolution
        if self.sparse_ground_truth:  # False in the document-unwarping setup
            flow_gt_original = mini_batch['flow_map'][0].to(self.device)
            flow_gt_256 = mini_batch['flow_map'][1].to(self.device)
            if flow_gt_original.shape[1] != 2:
                # shape is bxhxwx2
                flow_gt_original = flow_gt_original.permute(0, 3, 1, 2)
            if flow_gt_256.shape[1] != 2:
                # shape is bxhxwx2
                flow_gt_256 = flow_gt_256.permute(0, 3, 1, 2)
        else:
            if self.mapping:  # False in the document-unwarping setup
                mapping_gt_original = mini_batch['correspondence_map_pyro'][-1].to(self.device)
                # NOTE: as above, the mapping-to-flow conversion is commented out, so with
                # mapping=True flow_gt_original is undefined below.
                # flow_gt_original = unormalise_and_convert_mapping_to_flow(mapping_gt_original.permute(0,3,1,2))
            else:
                if self.megadepth:  # False in the document-unwarping setup
                    flow_gt_original = mini_batch['flow_map'][0].to(self.device)
                elif 'flow_map' in mini_batch:  # the usual path for document data
                    flow_gt_original = mini_batch['flow_map'].to(self.device)  # e.g. [24, 2, 512, 512]
                else:
                    # no ground-truth flow available: return the preprocessed images only
                    mini_batch['source_image'] = source_image
                    mini_batch['source_image_256'] = source_image_256
                    return mini_batch
if flow_gt_original.shape[1] != 2:
# shape is bxhxwx2
flow_gt_original = flow_gt_original.permute(0, 3, 1, 2)
bs, _, h_original, w_original = flow_gt_original.shape
            # resize the dense flow to 256x256 and rescale the flow values accordingly (bx2x256x256)
            flow_gt_256 = F.interpolate(flow_gt_original, (256, 256), mode='bilinear', align_corners=False)
            flow_gt_256[:, 0, :, :] *= 256.0 / float(w_original)
            flow_gt_256[:, 1, :, :] *= 256.0 / float(h_original)
bs, _, h_original, w_original = flow_gt_original.shape
bs, _, h_256, w_256 = flow_gt_256.shape
# defines the mask to use during training
mask = None
mask_256 = None
        if self.apply_mask_zero_borders:  # False in the document-unwarping setup
            # build the mask that removes the black borders of the target image
            mask = mini_batch['mask_zero_borders'].to(self.device)  # bxhxw, torch.uint8
            mask_256 = F.interpolate(mask.unsqueeze(1).float(), (256, 256), mode='bilinear',
                                     align_corners=False).squeeze(1).byte()  # bx256x256; the byte cast truncates to {0, 1}
            mask_256 = mask_256.bool() if version.parse(torch.__version__) >= version.parse("1.1") else mask_256.byte()
        elif self.apply_mask:  # False in the document-unwarping setup
            if self.sparse_ground_truth:
                mask = mini_batch['correspondence_mask'][0].to(self.device)
                mask_256 = mini_batch['correspondence_mask'][1].to(self.device)
            else:
                if self.megadepth:
                    mask = mini_batch['correspondence_mask'][0].to(self.device)  # bxhxw, torch.uint8
                else:
                    mask = mini_batch['correspondence_mask'].to(self.device)  # bxhxw, torch.uint8
                mask_256 = F.interpolate(mask.unsqueeze(1).float(), (256, 256), mode='bilinear',
                                         align_corners=False).squeeze(1).byte()  # bx256x256; the byte cast truncates to {0, 1}
                mask_256 = mask_256.bool() if version.parse(torch.__version__) >= version.parse("1.1") else mask_256.byte()
        mini_batch['source_image'] = source_image
        mini_batch['target_image'] = target_image  # None: no target image in this setting
        mini_batch['source_image_256'] = source_image_256
        mini_batch['target_image_256'] = None
mini_batch['flow_map'] = flow_gt_original
mini_batch['flow_map_256'] = flow_gt_256
mini_batch['mask'] = mask
mini_batch['mask_256'] = mask_256
        # 'correspondence_mask' is provided as a list for MegaDepth-style data; document
        # batches typically carry no 'correspondence_mask', in which case it is left untouched
        if self.megadepth:
            mini_batch['correspondence_mask'] = mini_batch['correspondence_mask'][0].to(self.device)
        elif self.sparse_ground_truth:
            mini_batch['correspondence_mask'] = mini_batch['correspondence_mask'].to(self.device)
return mini_batch
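
# Smoke-test sketch for the document pipeline (assumptions: CPU device and a synthetic
# batch; the shapes mirror the bx3x512x512 example noted in the comments above).
if __name__ == '__main__':
    from types import SimpleNamespace
    settings = SimpleNamespace(device=torch.device('cpu'))
    preprocess = DocBatchPreprocessing(settings)
    b, h, w = 2, 512, 512
    mini_batch = {
        'source_image': torch.randint(0, 256, (b, 3, h, w)),
        'flow_map': torch.randn(b, 2, h, w),
    }
    out = preprocess(mini_batch)
    print(out['source_image'].shape, out['source_image_256'].shape,
          out['flow_map'].shape, out['flow_map_256'].shape)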