Spaces: Running on Zero
| import numpy as np | |
| import cv2 | |
| import torch | |
| from torchvision import transforms | |
| from torchvision.transforms.functional import InterpolationMode | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| from scipy.io import savemat | |
| import os | |
| from torch import nn | |
# Select the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
class register_model2(nn.Module):
    """Thin wrapper that warps an (image, flow) pair with SpatialTransformer2."""

    def __init__(self, img_size=(64, 1024, 1024), mode='bilinear'):
        super(register_model2, self).__init__()
        # All resampling work is delegated to the spatial transformer.
        self.spatial_trans = SpatialTransformer2(img_size, mode)

    def forward(self, x):
        """Warp ``x[0]`` (image, e.g. [1, 3, 512, 512]) by ``x[1]`` (flow, e.g. [1, 2, 512, 512])."""
        image, flow = x[0], x[1]
        return self.spatial_trans(image, flow)
class SpatialTransformer2(nn.Module):
    """Spatial transformer that resamples an image with a pre-normalized flow.

    Unlike the classic VoxelMorph-style transformer, this variant expects the
    incoming ``flow`` to already hold absolute sampling coordinates normalized
    to [-1, 1] (the ``F.grid_sample`` convention): no identity grid is added
    and no normalization is performed here.
    """

    def __init__(self, size, mode='bilinear'):
        """
        Args:
            size: expected spatial size of the inputs. Kept only for interface
                compatibility — the identity-grid construction was removed, so
                it is unused.
            mode: interpolation mode passed to ``F.grid_sample``.
        """
        super().__init__()
        self.mode = mode

    def forward(self, src, flow):
        """Warp ``src`` by sampling it at the locations given in ``flow``.

        Args:
            src: source image, shape [B, C, H, W].
            flow: sampling grid, shape [B, 2, H, W], values in [-1, 1].
                After the permute below, the last dimension is read by
                ``grid_sample`` as (x, y), so channel 0 must be x (width)
                and channel 1 y (height).

        Returns:
            The resampled image, same shape as ``src``.
        """
        # grid_sample wants the coordinate channels last: [B, 2, H, W] -> [B, H, W, 2].
        grid = flow.permute(0, 2, 3, 1)
        return F.grid_sample(src, grid, align_corners=True, mode=self.mode,
                             padding_mode="zeros")
class SpatialTransformer(nn.Module):
    """Spatial transformer that resamples an image at absolute pixel coordinates.

    ``flow`` holds absolute sampling positions in pixel units; this module
    normalizes them to the [-1, 1] range expected by ``F.grid_sample`` and
    resamples ``src`` at those positions.
    """

    def __init__(self, size, mode='bilinear'):
        """
        Args:
            size: expected spatial size of the inputs. Kept only for interface
                compatibility — the identity-grid construction was removed, so
                it is unused.
            mode: interpolation mode passed to ``F.grid_sample``.
        """
        super().__init__()
        self.mode = mode

    def forward(self, src, flow):
        """Warp ``src`` by sampling it at the pixel locations in ``flow``.

        Args:
            src: source image, [B, C, H, W] (or [B, C, D, H, W] for 3-D).
            flow: absolute sampling coordinates in pixels,
                [B, 2, H, W] (or [B, 3, D, H, W]).

        Returns:
            The resampled image, same shape as ``src``.
        """
        shape = flow.shape[2:]

        # BUG FIX: the original did `new_locs = flow` and then wrote the
        # normalized values back in place, silently mutating the caller's
        # flow tensor. Clone first so the input is left untouched.
        new_locs = flow.clone()

        # Normalize each coordinate channel to [-1, 1] for grid_sample.
        # NOTE(review): channel i is divided by shape[i] - 1, while
        # grid_sample reads the last dimension as (x, y); the two only agree
        # for square inputs — confirm the channel order with the flow producer.
        for i in range(len(shape)):
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)

        # Move the coordinate channels to the last dimension.
        if len(shape) == 2:
            new_locs = new_locs.permute(0, 2, 3, 1)  # [B, 2, H, W] -> [B, H, W, 2]
        elif len(shape) == 3:
            new_locs = new_locs.permute(0, 2, 3, 4, 1)
            # grid_sample expects (x, y, z) last; reverse the channel order.
            new_locs = new_locs[..., [2, 1, 0]]

        return F.grid_sample(src, new_locs, align_corners=True, mode=self.mode,
                             padding_mode="zeros")
class register_model(nn.Module):
    """Thin wrapper that warps an (image, flow) pair with SpatialTransformer."""

    def __init__(self, img_size=(64, 1024, 1024), mode='bilinear'):
        super(register_model, self).__init__()
        # All resampling work is delegated to the spatial transformer.
        self.spatial_trans = SpatialTransformer(img_size, mode)

    def forward(self, x):
        """Warp ``x[0]`` (image, e.g. [1, 3, 512, 512]) by ``x[1]`` (flow, e.g. [1, 2, 512, 512])."""
        image, flow = x[0], x[1]
        return self.spatial_trans(image, flow)