import torch
from torch import nn
import torch.nn.functional as F

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class register_model2(nn.Module):
    """Warps an image with a pre-normalized sampling grid via SpatialTransformer2."""

    def __init__(self, img_size=(64, 1024, 1024), mode='bilinear'):
        super(register_model2, self).__init__()
        self.spatial_trans = SpatialTransformer2(img_size, mode)

    def forward(self, x):
        img = x[0]   # image to warp, e.g. [B, 3, H, W]
        flow = x[1]  # sampling grid, e.g. [B, 2, H, W], already normalized to [-1, 1]
        out = self.spatial_trans(img, flow)
        return out


class SpatialTransformer2(nn.Module):
    """
    2-D spatial transformer for grids that are already in grid_sample's format:
    absolute sampling locations normalized to [-1, 1], channels in (x, y) order.
    `size` is kept for interface compatibility with SpatialTransformer but is unused.
    """

    def __init__(self, size, mode='bilinear'):
        super().__init__()
        self.mode = mode

    def forward(self, src, flow):
        # grid_sample expects the coordinate channels last:
        # [B, 2, H, W] -> [B, H, W, 2]
        flow = flow.permute(0, 2, 3, 1)
        return F.grid_sample(src, flow, align_corners=True, mode=self.mode, padding_mode="zeros")
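

# Minimal usage sketch (added for illustration, not part of the original
# pipeline; shapes and the helper name `_identity_warp_demo` are assumptions):
# an identity grid from F.affine_grid is already normalized to [-1, 1] in
# (x, y) order, which is exactly what SpatialTransformer2 expects.
def _identity_warp_demo():
    b, c, h, w = 1, 3, 64, 64
    img = torch.rand(b, c, h, w)
    theta = torch.eye(2, 3).unsqueeze(0)                            # identity affine, [1, 2, 3]
    grid = F.affine_grid(theta, (b, c, h, w), align_corners=True)   # [B, H, W, 2], in [-1, 1]
    flow = grid.permute(0, 3, 1, 2)                                 # to the [B, 2, H, W] layout used here
    warped = register_model2(img_size=(h, w))((img, flow))
    assert torch.allclose(warped, img, atol=1e-4)                   # identity grid leaves the image unchanged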



class SpatialTransformer(nn.Module):
    """
    N-D spatial transformer.

    Expects `flow` to hold absolute sampling locations in pixel/voxel
    coordinates with channels in (x, y[, z]) order. The locations are
    normalized to [-1, 1] and moved channels-last before grid_sample.
    `size` is kept for interface compatibility but is unused.
    """

    def __init__(self, size, mode='bilinear'):
        super().__init__()
        self.mode = mode

    def forward(self, src, flow):
        # clone so the caller's flow tensor is not modified in place
        new_locs = flow.clone()

        # spatial extents per coordinate channel: flow channels are in
        # (x, y[, z]) order, so reverse the (H, W)- / (D, H, W)-ordered shape
        sizes = flow.shape[2:][::-1]

        # normalize absolute coordinates to [-1, 1] for grid_sample
        for i, s in enumerate(sizes):
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (s - 1) - 0.5)

        # move the channel dim last: grid_sample expects [B, H, W, 2] or [B, D, H, W, 3]
        if len(sizes) == 2:
            new_locs = new_locs.permute(0, 2, 3, 1)
        elif len(sizes) == 3:
            new_locs = new_locs.permute(0, 2, 3, 4, 1)

        return F.grid_sample(src, new_locs, align_corners=True, mode=self.mode, padding_mode="zeros")
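

# Hedged sketch (added for illustration; the helper names are assumptions,
# not part of the original file): SpatialTransformer consumes absolute pixel
# coordinates, so a displacement field would first be added to an identity
# pixel grid in the same (x, y) channel order before warping.
def _pixel_grid(h, w):
    ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
    return torch.stack([xs, ys]).unsqueeze(0).float()  # [1, 2, H, W], (x, y) channels

def _warp_with_displacement(img, disp):
    # disp: [B, 2, H, W] displacement in pixels, (x, y) channel order
    _, _, h, w = disp.shape
    locs = _pixel_grid(h, w).to(disp) + disp           # absolute sampling locations
    return SpatialTransformer(size=(h, w))(img, locs)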

class register_model(nn.Module):
    """Warps an image with an absolute pixel-coordinate grid via SpatialTransformer."""

    def __init__(self, img_size=(64, 1024, 1024), mode='bilinear'):
        super(register_model, self).__init__()
        self.spatial_trans = SpatialTransformer(img_size, mode)

    def forward(self, x):
        img = x[0]   # image to warp, e.g. [B, 3, H, W]
        flow = x[1]  # absolute sampling locations in pixels, e.g. [B, 2, H, W]
        out = self.spatial_trans(img, flow)
        return out
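

# Minimal smoke test (illustrative only: random tensors, assumed shapes):
# warping with the identity pixel grid should return the input unchanged.
if __name__ == '__main__':
    reg = register_model(img_size=(256, 256)).to(device)
    img = torch.rand(1, 3, 256, 256, device=device)
    flow = _pixel_grid(256, 256).to(img)   # zero displacement -> identity warp
    out = reg((img, flow))
    print(out.shape)                       # torch.Size([1, 3, 256, 256])
    assert torch.allclose(out, img, atol=1e-4)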