import numpy as np
import cv2
import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt
from scipy.io import savemat
import os
from torch import nn
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class register_model2(nn.Module):
    """Thin wrapper that applies SpatialTransformer2 to an (image, flow) pair."""

    def __init__(self, img_size=(64, 1024, 1024), mode='bilinear'):
        super(register_model2, self).__init__()
        self.spatial_trans = SpatialTransformer2(img_size, mode)

    def forward(self, x):
        img = x[0]   # e.g. [1, 3, 512, 512]
        flow = x[1]  # e.g. [1, 2, 512, 512]
        out = self.spatial_trans(img, flow)
        return out
class SpatialTransformer2(nn.Module):
    """
    N-D spatial transformer that treats `flow` as a ready-made sampling grid:
    values are assumed to be already normalized to [-1, 1] and stored in
    grid_sample's (x, y) channel order, so no identity grid is added and no
    rescaling is performed here.
    """

    def __init__(self, size, mode='bilinear'):
        super().__init__()
        self.mode = mode
        # `size` is kept for API compatibility but is unused: the original
        # VoxelMorph-style implementation built an identity grid from it
        # (torch.meshgrid over `size`, stacked in column-major order) and
        # registered it as a buffer. Registering the grid cleanly moves it to
        # the GPU with the module, but it also lands in the state dict and
        # bloats saved weights; see:
        # https://discuss.pytorch.org/t/how-to-register-buffer-without-polluting-state-dict

    def forward(self, src, flow):
        # Move the channel dim last: [B, 2, H, W] -> [B, H, W, 2].
        flow = flow.permute(0, 2, 3, 1)
        return F.grid_sample(src, flow, align_corners=True, mode=self.mode,
                             padding_mode="zeros")
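# A minimal sketch, not part of the original file: one way to build the
# normalized identity grid that SpatialTransformer2 expects, via F.affine_grid.
# The helper name `identity_grid_flow` is an assumption, not an original API.
def identity_grid_flow(b, h, w, device=device):
    """Identity sampling grid as [B, 2, H, W], normalized to [-1, 1], (x, y) order."""
    theta = torch.eye(2, 3, device=device).unsqueeze(0).repeat(b, 1, 1)  # [B, 2, 3] identity affine
    grid = F.affine_grid(theta, size=(b, 1, h, w), align_corners=True)   # [B, H, W, 2]
    return grid.permute(0, 3, 1, 2)                                      # channels first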
class SpatialTransformer(nn.Module):
    """
    N-D spatial transformer that treats `flow` as absolute sampling
    coordinates in pixels: channel i is normalized against spatial dim i
    into [-1, 1] before being handed to grid_sample. Unlike the original
    VoxelMorph transformer, no identity grid is added to the flow first.
    """

    def __init__(self, size, mode='bilinear'):
        super().__init__()
        self.mode = mode
        # As in SpatialTransformer2, `size` is unused: the identity grid the
        # original implementation registered as a buffer here would end up in
        # the state dict and bloat saved weights; see:
        # https://discuss.pytorch.org/t/how-to-register-buffer-without-polluting-state-dict

    def forward(self, src, flow):
        # Clone so the in-place normalization below does not mutate the
        # caller's flow tensor.
        new_locs = flow.clone()   # [B, 2, H, W] (or [B, 3, D, H, W] in 3-D)
        shape = flow.shape[2:]    # spatial dims, e.g. (H, W)

        # Normalize grid values to [-1, 1] for grid_sample.
        for i in range(len(shape)):
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)

        # Move channels last. grid_sample reads the last dim in (x, y) /
        # (x, y, z) order; the 2-D channel reversal is left out (as in the
        # original), so 2-D flows must already store channels in (x, y) order,
        # which only matches the per-dim normalization above for square inputs.
        if len(shape) == 2:
            new_locs = new_locs.permute(0, 2, 3, 1)     # [B, 2, H, W] -> [B, H, W, 2]
        elif len(shape) == 3:
            new_locs = new_locs.permute(0, 2, 3, 4, 1)  # channels last
            new_locs = new_locs[..., [2, 1, 0]]         # (z, y, x) -> (x, y, z)

        return F.grid_sample(src, new_locs, align_corners=True, mode=self.mode,
                             padding_mode="zeros")
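# Sketch under stated assumptions (this helper is not original code): since
# SpatialTransformer wants absolute pixel coordinates, a VoxelMorph-style
# displacement field would need the identity pixel grid added first, with
# channel i indexing spatial dim i to match the normalization above.
def displacement_to_locations(disp):
    """disp: [B, 2, H, W] displacement in pixels -> absolute pixel locations."""
    b, _, h, w = disp.shape
    ys, xs = torch.meshgrid(
        torch.arange(h, dtype=disp.dtype, device=disp.device),
        torch.arange(w, dtype=disp.dtype, device=disp.device),
        indexing="ij")
    base = torch.stack([ys, xs]).unsqueeze(0)  # [1, 2, H, W], broadcast over batch
    return disp + base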
class register_model(nn.Module):
    """Thin wrapper that applies SpatialTransformer to an (image, flow) pair."""

    def __init__(self, img_size=(64, 1024, 1024), mode='bilinear'):
        super(register_model, self).__init__()
        self.spatial_trans = SpatialTransformer(img_size, mode)

    def forward(self, x):
        img = x[0]   # e.g. [1, 3, 512, 512]
        flow = x[1]  # e.g. [1, 2, 512, 512]
        out = self.spatial_trans(img, flow)
        return out
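if __name__ == "__main__":
    # Smoke test, a minimal sketch (shapes are assumptions): warping with an
    # identity grid should reproduce the input image.
    b, c, h, w = 1, 3, 64, 64
    img = torch.rand(b, c, h, w, device=device)
    flow = identity_grid_flow(b, h, w)                  # [B, 2, H, W], normalized (x, y)
    warp = register_model2(img_size=(h, w)).to(device)
    out = warp((img, flow))
    assert torch.allclose(out, img, atol=1e-5), "identity warp should be a no-op"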