Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| import torch.nn.functional as F | |
| import torch.utils.data | |
| from packaging import version | |
def pre_process_image_glunet(source_img, device, mean_vector=(0.485, 0.456, 0.406),
                             std_vector=(0.229, 0.224, 0.225), *, normalize=False):
    """Scale a batch of [0, 255] images to float and build a 256x256 copy.

    Args:
        source_img: torch tensor, b x 3 x H x W, values in [0, 255], not normalized yet.
        device: device the outputs are placed on (passed to ``Tensor.to``).
        mean_vector: per-channel mean (ImageNet statistics by default). Only used when ``normalize`` is True.
        std_vector: per-channel standard deviation (ImageNet statistics by default). Only used when
            ``normalize`` is True.
        normalize: if True, additionally subtract ``mean_vector`` and divide by ``std_vector`` per channel.
            Defaults to False: the images are only divided by 255.

    Returns:
        (image, image_256): float32 tensors on ``device``. The first is at the input resolution. The second
        is resized to 256x256 with ``mode='area'``.
    """
    # Move before casting so that a uint8 input is transferred as uint8 rather than as 4x larger floats.
    image = source_img.to(device).float().div(255.0)
    # 'area' interpolation averages the input over each output cell.
    image_256 = torch.nn.functional.interpolate(input=image, size=(256, 256), mode='area')
    if normalize:
        # (3,) -> (3, 1, 1) so the statistics broadcast over b x 3 x h x w.
        mean = torch.as_tensor(mean_vector, dtype=image.dtype, device=image.device)[:, None, None]
        std = torch.as_tensor(std_vector, dtype=image.dtype, device=image.device)[:, None, None]
        image = (image - mean) / std
        image_256 = (image_256 - mean) / std
    return image, image_256
class CATsBatchPreprocessing:
    """Process a mini-batch into the training inputs of GLU-Net-based networks.

    From the source and target images at original resolution and the corresponding ground-truth flow field,
    builds the source, target and flow at original and at 256x256 resolution (the latter for training the
    L-Net), plus the optional loss masks.
    """

    def __init__(self, settings, apply_mask=False, apply_mask_zero_borders=False, sparse_ground_truth=False,
                 mapping=False):
        """
        Args:
            settings: settings object. Its ``device`` attribute is used when present and not None; otherwise
                CUDA is used if available, else CPU.
            apply_mask: apply the ground-truth correspondence mask for loss computation?
            apply_mask_zero_borders: apply the zero-border mask (equal to 0 at black borders in the target
                image) for loss computation? Takes precedence over ``apply_mask``.
            sparse_ground_truth: is the ground truth sparse? If so, 'flow_map' (and 'correspondence_mask')
                must already hold an ``[original, 256x256]`` pair; no resizing is done.
            mapping: load a correspondence map instead of a flow field? Not supported with dense ground
                truth: ``__call__`` raises ``NotImplementedError`` in that case.
        """
        self.apply_mask = apply_mask
        self.apply_mask_zero_borders = apply_mask_zero_borders
        self.sparse_ground_truth = sparse_ground_truth
        self.device = getattr(settings, 'device', None)
        if self.device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.mapping = mapping

    @staticmethod
    def _as_channels_first(flow):
        """Return ``flow`` as b x 2 x h x w, permuting from b x h x w x 2 when dim 1 is not of size 2."""
        if flow.shape[1] != 2:
            flow = flow.permute(0, 3, 1, 2)
        return flow

    @staticmethod
    def _resize_flow_to_256(flow):
        """Bilinearly resize a b x 2 x h x w flow to 256x256 and rescale its x / y displacements accordingly."""
        _, _, height, width = flow.shape
        flow_256 = F.interpolate(flow, (256, 256), mode='bilinear', align_corners=False)
        flow_256[:, 0, :, :] *= 256.0 / float(width)
        flow_256[:, 1, :, :] *= 256.0 / float(height)
        return flow_256

    @staticmethod
    def _downsample_mask(mask):
        """Resize a b x h x w mask to b x 256 x 256 (bilinear).

        The float -> uint8 cast truncates toward zero (it does not round), so interpolated values below 1
        become 0. The result is converted to bool on torch >= 1.1, and left as uint8 otherwise.
        """
        mask_256 = F.interpolate(mask.unsqueeze(1).float(), (256, 256), mode='bilinear',
                                 align_corners=False).squeeze(1).byte()
        return mask_256.bool() if version.parse(torch.__version__) >= version.parse("1.1") else mask_256.byte()

    def __call__(self, mini_batch, *args, **kwargs):
        """Pre-process ``mini_batch`` in place and return it.

        Args:
            mini_batch: the input data, containing the fields:
                'source_image', 'target_image' (b x 3 x H x W, in [0, 255]), 'correspondence_mask',
                'flow_map' (b x 2 x H x W or b x H x W x 2; an [original, 256x256] pair if sparse),
                'mask_zero_borders' if ``apply_mask_zero_borders``.
        Returns:
            the same ``mini_batch`` with the fields 'source_image', 'target_image', 'source_image_256',
            'target_image_256' (float, divided by 255), 'flow_map', 'flow_map_256' (b x 2 x h x w), 'mask',
            'mask_256' (None when no mask is applied) and 'correspondence_mask', all on ``self.device``.
        Raises:
            NotImplementedError: if ``mapping`` is True and the ground truth is dense.
        """
        source_image, source_image_256 = pre_process_image_glunet(mini_batch['source_image'], self.device)
        target_image, target_image_256 = pre_process_image_glunet(mini_batch['target_image'], self.device)

        if self.sparse_ground_truth:
            # Both resolutions are provided by the batch; they are only moved and made channels-first.
            flow_gt_original = self._as_channels_first(mini_batch['flow_map'][0].to(self.device))
            flow_gt_256 = self._as_channels_first(mini_batch['flow_map'][1].to(self.device))
        else:
            if self.mapping:
                # The mapping -> flow conversion is disabled in this module, so no flow could be produced.
                raise NotImplementedError(
                    "mapping=True with dense ground truth is not supported: converting "
                    "'correspondence_map_pyro' to a flow field is not implemented.")
            flow_gt_original = self._as_channels_first(mini_batch['flow_map'].to(self.device))
            # Dense flow is defined everywhere, so it can be interpolated to b x 2 x 256 x 256.
            flow_gt_256 = self._resize_flow_to_256(flow_gt_original)

        # Masks used during training (None when no mask is requested).
        mask = None
        mask_256 = None
        if self.apply_mask_zero_borders:
            # Removes all black borders.
            mask = mini_batch['mask_zero_borders'].to(self.device)
            mask_256 = self._downsample_mask(mask)
        elif self.apply_mask:
            if self.sparse_ground_truth:
                mask = mini_batch['correspondence_mask'][0].to(self.device)
                mask_256 = mini_batch['correspondence_mask'][1].to(self.device)
            else:
                mask = mini_batch['correspondence_mask'].to(self.device)
                mask_256 = self._downsample_mask(mask)

        mini_batch['source_image'] = source_image
        mini_batch['target_image'] = target_image
        mini_batch['source_image_256'] = source_image_256
        mini_batch['target_image_256'] = target_image_256
        mini_batch['flow_map'] = flow_gt_original
        mini_batch['flow_map_256'] = flow_gt_256
        mini_batch['mask'] = mask
        mini_batch['mask_256'] = mask_256
        if self.sparse_ground_truth:
            # The pair is updated element-wise, in place.
            mini_batch['correspondence_mask'][0] = mini_batch['correspondence_mask'][0].to(self.device)
            mini_batch['correspondence_mask'][1] = mini_batch['correspondence_mask'][1].to(self.device)
        else:
            mini_batch['correspondence_mask'] = mini_batch['correspondence_mask'].to(self.device)
        return mini_batch
class GLUNetBatchPreprocessing:
    """Process a mini-batch into the training inputs of GLU-Net-based networks.

    From the source and target images at original resolution and the corresponding ground-truth flow field,
    builds the source, target and flow at original and at 256x256 resolution (the latter for training the
    L-Net), plus the optional loss masks.
    """

    def __init__(self, settings, apply_mask=False, apply_mask_zero_borders=False, sparse_ground_truth=False,
                 mapping=False, megadepth=False):
        """
        Args:
            settings: settings object; its ``device`` attribute (if any) is the target device.
                NOTE(review): without it ``self.device`` stays None (no CPU/CUDA fallback is applied) and
                tensors are passed to ``.to(None)`` -- confirm this is intended.
            apply_mask: apply the ground-truth correspondence mask for loss computation?
            apply_mask_zero_borders: apply the zero-border mask (equal to 0 at black borders in the target
                image) for loss computation? Takes precedence over ``apply_mask``.
            sparse_ground_truth: is the ground truth sparse? If so, 'flow_map' (and 'correspondence_mask')
                must already hold an ``[original, 256x256]`` pair; no resizing is done.
            mapping: load a correspondence map instead of a flow field? Not supported with dense ground
                truth: ``__call__`` raises ``NotImplementedError`` in that case.
            megadepth: if True, the dense 'flow_map' and the 'correspondence_mask' are read from element
                ``[0]`` of the corresponding batch field.
        """
        self.apply_mask = apply_mask
        self.apply_mask_zero_borders = apply_mask_zero_borders
        self.sparse_ground_truth = sparse_ground_truth
        self.megadepth = megadepth
        self.device = getattr(settings, 'device', None)
        print("Using device: {}".format(self.device))
        self.mapping = mapping

    @staticmethod
    def _as_channels_first(flow):
        """Return ``flow`` as b x 2 x h x w, permuting from b x h x w x 2 when dim 1 is not of size 2."""
        if flow.shape[1] != 2:
            flow = flow.permute(0, 3, 1, 2)
        return flow

    @staticmethod
    def _resize_flow_to_256(flow):
        """Bilinearly resize a b x 2 x h x w flow to 256x256 and rescale its x / y displacements accordingly."""
        _, _, height, width = flow.shape
        flow_256 = F.interpolate(flow, (256, 256), mode='bilinear', align_corners=False)
        flow_256[:, 0, :, :] *= 256.0 / float(width)
        flow_256[:, 1, :, :] *= 256.0 / float(height)
        return flow_256

    @staticmethod
    def _downsample_mask(mask):
        """Resize a b x h x w mask to b x 256 x 256 (bilinear).

        The float -> uint8 cast truncates toward zero (it does not round), so interpolated values below 1
        become 0. The result is converted to bool on torch >= 1.1, and left as uint8 otherwise.
        """
        mask_256 = F.interpolate(mask.unsqueeze(1).float(), (256, 256), mode='bilinear',
                                 align_corners=False).squeeze(1).byte()
        return mask_256.bool() if version.parse(torch.__version__) >= version.parse("1.1") else mask_256.byte()

    @staticmethod
    def _to_device(value, device):
        """Move a tensor, or every tensor of a list/tuple (e.g. a sparse [original, 256x256] pair), to ``device``."""
        if isinstance(value, (list, tuple)):
            return [item.to(device) for item in value]
        return value.to(device)

    def __call__(self, mini_batch, *args, **kwargs):
        """Pre-process ``mini_batch`` in place and return it.

        Args:
            mini_batch: the input data, containing the fields:
                'source_image', 'target_image' (b x 3 x H x W, in [0, 255]), 'correspondence_mask',
                'flow_map' (b x 2 x H x W or b x H x W x 2; an [original, 256x256] pair if sparse),
                'mask_zero_borders' if ``apply_mask_zero_borders``.
        Returns:
            the same ``mini_batch`` with the fields 'source_image', 'target_image', 'source_image_256',
            'target_image_256' (float, divided by 255), 'flow_map', 'flow_map_256' (b x 2 x h x w), 'mask',
            'mask_256' (None when no mask is applied) and 'correspondence_mask', all on ``self.device``.
        Raises:
            NotImplementedError: if ``mapping`` is True and the ground truth is dense.
        """
        source_image, source_image_256 = pre_process_image_glunet(mini_batch['source_image'], self.device)
        target_image, target_image_256 = pre_process_image_glunet(mini_batch['target_image'], self.device)

        if self.sparse_ground_truth:
            # Both resolutions are provided by the batch; they are only moved and made channels-first.
            flow_gt_original = self._as_channels_first(mini_batch['flow_map'][0].to(self.device))
            flow_gt_256 = self._as_channels_first(mini_batch['flow_map'][1].to(self.device))
        else:
            if self.mapping:
                # The mapping -> flow conversion is disabled in this module, so no flow could be produced.
                raise NotImplementedError(
                    "mapping=True with dense ground truth is not supported: converting "
                    "'correspondence_map_pyro' to a flow field is not implemented.")
            flow_map = mini_batch['flow_map'][0] if self.megadepth else mini_batch['flow_map']
            flow_gt_original = self._as_channels_first(flow_map.to(self.device))
            # Dense flow is defined everywhere, so it can be interpolated to b x 2 x 256 x 256.
            flow_gt_256 = self._resize_flow_to_256(flow_gt_original)

        # Masks used during training (None when no mask is requested).
        mask = None
        mask_256 = None
        if self.apply_mask_zero_borders:
            # Removes all black borders.
            mask = mini_batch['mask_zero_borders'].to(self.device)
            mask_256 = self._downsample_mask(mask)
        elif self.apply_mask:
            if self.sparse_ground_truth:
                mask = mini_batch['correspondence_mask'][0].to(self.device)
                mask_256 = mini_batch['correspondence_mask'][1].to(self.device)
            else:
                correspondence_mask = mini_batch['correspondence_mask']
                mask = (correspondence_mask[0] if self.megadepth else correspondence_mask).to(self.device)
                mask_256 = self._downsample_mask(mask)

        mini_batch['source_image'] = source_image
        mini_batch['target_image'] = target_image
        mini_batch['source_image_256'] = source_image_256
        mini_batch['target_image_256'] = target_image_256
        mini_batch['flow_map'] = flow_gt_original
        mini_batch['flow_map_256'] = flow_gt_256
        mini_batch['mask'] = mask
        mini_batch['mask_256'] = mask_256

        # Same handling for sparse and dense ground truth. A sparse [original, 256x256] pair is moved
        # element-wise, because a list has no ``.to``.
        correspondence_mask = mini_batch['correspondence_mask']
        if self.megadepth:
            correspondence_mask = correspondence_mask[0]
        mini_batch['correspondence_mask'] = self._to_device(correspondence_mask, self.device)
        return mini_batch
class DocBatchPreprocessing:
    """Process a mini-batch into training inputs, pre-processing only the source image.

    From the source image at original resolution and the corresponding ground-truth flow field, builds the
    source image and flow at original and at 256x256 resolution, plus the optional loss masks. The target
    image outputs are set to None.
    """

    def __init__(self, settings, apply_mask=False, apply_mask_zero_borders=False, sparse_ground_truth=False,
                 mapping=False, megadepth=False):
        """
        Args:
            settings: settings object; its ``device`` attribute (if any) is the target device.
                NOTE(review): without it ``self.device`` stays None (no CPU/CUDA fallback is applied) and
                tensors are passed to ``.to(None)`` -- confirm this is intended.
            apply_mask: apply the ground-truth correspondence mask for loss computation?
            apply_mask_zero_borders: apply the zero-border mask (equal to 0 at black borders in the target
                image) for loss computation? Takes precedence over ``apply_mask``.
            sparse_ground_truth: is the ground truth sparse? If so, 'flow_map' (and 'correspondence_mask')
                must already hold an ``[original, 256x256]`` pair; no resizing is done.
            mapping: load a correspondence map instead of a flow field? Not supported with dense ground
                truth: ``__call__`` raises ``NotImplementedError`` in that case.
            megadepth: if True, the dense 'flow_map' and the 'correspondence_mask' are read from element
                ``[0]`` of the corresponding batch field.
        """
        self.apply_mask = apply_mask
        self.apply_mask_zero_borders = apply_mask_zero_borders
        self.sparse_ground_truth = sparse_ground_truth
        self.megadepth = megadepth
        self.device = getattr(settings, 'device', None)
        print("Using device: {}".format(self.device))
        self.mapping = mapping

    @staticmethod
    def _as_channels_first(flow):
        """Return ``flow`` as b x 2 x h x w, permuting from b x h x w x 2 when dim 1 is not of size 2."""
        if flow.shape[1] != 2:
            flow = flow.permute(0, 3, 1, 2)
        return flow

    @staticmethod
    def _resize_flow_to_256(flow):
        """Bilinearly resize a b x 2 x h x w flow to 256x256 and rescale its x / y displacements accordingly."""
        _, _, height, width = flow.shape
        flow_256 = F.interpolate(flow, (256, 256), mode='bilinear', align_corners=False)
        flow_256[:, 0, :, :] *= 256.0 / float(width)
        flow_256[:, 1, :, :] *= 256.0 / float(height)
        return flow_256

    @staticmethod
    def _downsample_mask(mask):
        """Resize a b x h x w mask to b x 256 x 256 (bilinear).

        The float -> uint8 cast truncates toward zero (it does not round), so interpolated values below 1
        become 0. The result is converted to bool on torch >= 1.1, and left as uint8 otherwise.
        """
        mask_256 = F.interpolate(mask.unsqueeze(1).float(), (256, 256), mode='bilinear',
                                 align_corners=False).squeeze(1).byte()
        return mask_256.bool() if version.parse(torch.__version__) >= version.parse("1.1") else mask_256.byte()

    @staticmethod
    def _to_device(value, device):
        """Move a tensor, or every tensor of a list/tuple (e.g. a sparse [original, 256x256] pair), to ``device``."""
        if isinstance(value, (list, tuple)):
            return [item.to(device) for item in value]
        return value.to(device)

    def __call__(self, mini_batch, *args, **kwargs):
        """Pre-process ``mini_batch`` in place and return it.

        Args:
            mini_batch: the input data, containing the fields:
                'source_image' (b x 3 x H x W, in [0, 255]),
                'flow_map' (b x 2 x H x W or b x H x W x 2; an [original, 256x256] pair if sparse),
                'correspondence_mask' if ``apply_mask``, ``sparse_ground_truth`` or ``megadepth``,
                'mask_zero_borders' if ``apply_mask_zero_borders``.
        Returns:
            the same ``mini_batch``. If the dense path finds no 'flow_map' (and ``megadepth`` is False), only
            'source_image' and 'source_image_256' are updated. Otherwise the fields 'source_image',
            'source_image_256' (float, divided by 255), 'target_image' and 'target_image_256' (both None),
            'flow_map', 'flow_map_256' (b x 2 x h x w), 'mask' and 'mask_256' (None when no mask is applied)
            are set.
        Raises:
            NotImplementedError: if ``mapping`` is True and the ground truth is dense.
        """
        # Only the source image is pre-processed; the target image outputs are None.
        source_image, source_image_256 = pre_process_image_glunet(mini_batch['source_image'], self.device)

        if self.sparse_ground_truth:
            # Both resolutions are provided by the batch; they are only moved and made channels-first.
            flow_gt_original = self._as_channels_first(mini_batch['flow_map'][0].to(self.device))
            flow_gt_256 = self._as_channels_first(mini_batch['flow_map'][1].to(self.device))
        else:
            if self.mapping:
                # The mapping -> flow conversion is disabled in this module, so no flow could be produced.
                raise NotImplementedError(
                    "mapping=True with dense ground truth is not supported: converting "
                    "'correspondence_map_pyro' to a flow field is not implemented.")
            if self.megadepth:
                flow_map = mini_batch['flow_map'][0]
            elif 'flow_map' in mini_batch:
                flow_map = mini_batch['flow_map']
            else:
                # No ground-truth flow: only the source image is replaced.
                mini_batch['source_image'] = source_image
                mini_batch['source_image_256'] = source_image_256
                return mini_batch
            flow_gt_original = self._as_channels_first(flow_map.to(self.device))
            # Dense flow is defined everywhere, so it can be interpolated to b x 2 x 256 x 256.
            flow_gt_256 = self._resize_flow_to_256(flow_gt_original)

        # Masks used during training (None when no mask is requested).
        mask = None
        mask_256 = None
        if self.apply_mask_zero_borders:
            # Removes all black borders.
            mask = mini_batch['mask_zero_borders'].to(self.device)
            mask_256 = self._downsample_mask(mask)
        elif self.apply_mask:
            if self.sparse_ground_truth:
                mask = mini_batch['correspondence_mask'][0].to(self.device)
                mask_256 = mini_batch['correspondence_mask'][1].to(self.device)
            else:
                correspondence_mask = mini_batch['correspondence_mask']
                mask = (correspondence_mask[0] if self.megadepth else correspondence_mask).to(self.device)
                mask_256 = self._downsample_mask(mask)

        mini_batch['source_image'] = source_image
        mini_batch['target_image'] = None
        mini_batch['source_image_256'] = source_image_256
        mini_batch['target_image_256'] = None
        mini_batch['flow_map'] = flow_gt_original
        mini_batch['flow_map_256'] = flow_gt_256
        mini_batch['mask'] = mask
        mini_batch['mask_256'] = mask_256

        if self.megadepth:
            mini_batch['correspondence_mask'] = mini_batch['correspondence_mask'][0].to(self.device)
        elif self.sparse_ground_truth:
            # A sparse [original, 256x256] pair is moved element-wise, because a list has no ``.to``.
            mini_batch['correspondence_mask'] = self._to_device(mini_batch['correspondence_mask'], self.device)
        # Otherwise 'correspondence_mask' is left untouched.
        return mini_batch