Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Copyright (c) Facebook, Inc. and its affiliates. | |
| This source code is licensed under the MIT license found in the | |
| LICENSE file in the root directory of this source tree. | |
| """ | |
| import math | |
| import os | |
| from typing import List, Optional, Tuple | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import fastmri | |
| from fastmri import transforms | |
| from models.unet import Unet | |
class NormUnet(nn.Module):
    """
    Normalized U-Net model.

    This is the same as a regular U-Net, but with normalization applied to the
    input before the U-Net. This keeps the values more numerically stable
    during training.

    The input and output of ``forward`` have shape ``(b, c, h, w, 2)``: ``c``
    complex channels whose real/imaginary parts live in the trailing axis.
    """

    def __init__(
        self,
        chans: int,
        num_pools: int,
        in_chans: int = 2,
        out_chans: int = 2,
        drop_prob: float = 0.0,
    ):
        """
        Initialize the NormUnet model.

        Parameters
        ----------
        chans : int
            Number of output channels of the first convolution layer.
        num_pools : int
            Number of down-sampling and up-sampling layers.
        in_chans : int, optional
            Number of channels in the input to the U-Net model. Must equal
            twice the number of complex input channels, because real and
            imaginary parts are folded into separate channels. Default is 2.
        out_chans : int, optional
            Number of channels in the output to the U-Net model (again twice
            the number of complex output channels). Default is 2.
        drop_prob : float, optional
            Dropout probability. Default is 0.0.
        """
        super().__init__()
        self.unet = Unet(
            in_chans=in_chans,
            out_chans=out_chans,
            chans=chans,
            num_pool_layers=num_pools,
            drop_prob=drop_prob,
        )

    def complex_to_chan_dim(self, x: torch.Tensor) -> torch.Tensor:
        """
        Fold the trailing real/imag axis into channels.

        ``(b, c, h, w, 2) -> (b, 2*c, h, w)``. The output channels hold all
        real parts first, then all imaginary parts.
        """
        b, c, h, w, two = x.shape
        assert two == 2
        return x.permute(0, 4, 1, 2, 3).reshape(b, 2 * c, h, w)

    def chan_complex_to_last_dim(self, x: torch.Tensor) -> torch.Tensor:
        """
        Inverse of :meth:`complex_to_chan_dim`: ``(b, 2*c, h, w) -> (b, c, h, w, 2)``.
        """
        b, c2, h, w = x.shape
        assert c2 % 2 == 0
        c = c2 // 2
        return x.view(b, 2, c, h, w).permute(0, 2, 3, 4, 1).contiguous()

    def norm(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Group-normalize ``x`` with two groups: real channels and imaginary channels.

        Parameters
        ----------
        x : torch.Tensor
            Tensor of shape ``(b, c, h, w)`` with ``c`` even, laid out as
            produced by :meth:`complex_to_chan_dim`.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
            The normalized tensor (same shape as ``x``) and the per-group
            ``mean`` and ``std``, each of shape ``(b, 2, 1, 1)``, to be passed
            to :meth:`unnorm`.
        """
        b, c, h, w = x.shape
        grouped = x.reshape(b, 2, c // 2 * h * w)
        mean = grouped.mean(dim=2).view(b, 2, 1, 1)
        std = grouped.std(dim=2).view(b, 2, 1, 1)
        # A constant group (e.g. an all-zero coil image) has std == 0, which
        # would turn the whole group into 0/0 = NaN. Replace exact zeros with
        # the smallest positive normal value so the output stays finite and
        # unnorm maps the group back to (approximately) its mean.
        std = std.masked_fill(std == 0, torch.finfo(std.dtype).tiny)
        # Apply the group statistics through a (b, 2, c//2, h, w) view so any
        # even channel count broadcasts correctly (not only c == 2).
        normed = (x.reshape(b, 2, c // 2, h, w) - mean.unsqueeze(2)) / std.unsqueeze(2)
        return normed.reshape(b, c, h, w), mean, std

    def unnorm(
        self, x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor
    ) -> torch.Tensor:
        """
        Undo :meth:`norm`: scale each real/imag channel group by ``std`` and add ``mean``.

        ``x`` has shape ``(b, c, h, w)`` with ``c`` even; ``c`` may differ from
        the channel count that was normalized (the U-Net's ``out_chans``).
        ``mean`` and ``std`` are the ``(b, 2, 1, 1)`` tensors from :meth:`norm`.
        """
        b, c, h, w = x.shape
        # reshape (not view): x is usually a non-contiguous slice from unpad
        grouped = x.reshape(b, 2, c // 2, h, w)
        out = grouped * std.unsqueeze(2) + mean.unsqueeze(2)
        return out.reshape(b, c, h, w)

    def pad(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Tuple[List[int], List[int], int, int]]:
        """
        Zero-pad height and width up to the next multiple of 16, splitting the
        padding as evenly as possible between both sides.

        Returns the padded tensor and ``(h_pad, w_pad, h_mult, w_mult)``, the
        arguments :meth:`unpad` needs to crop back to the original size.
        """
        _, _, h, w = x.shape
        # ((n - 1) | 15) + 1 rounds n up to a multiple of 16
        w_mult = ((w - 1) | 15) + 1
        h_mult = ((h - 1) | 15) + 1
        w_pad = [math.floor((w_mult - w) / 2), math.ceil((w_mult - w) / 2)]
        h_pad = [math.floor((h_mult - h) / 2), math.ceil((h_mult - h) / 2)]
        # F.pad lists padding starting from the last dimension:
        # (w_left, w_right, h_top, h_bottom). It accepts a list here even
        # though older PyTorch type stubs claimed otherwise.
        x = F.pad(x, w_pad + h_pad)
        return x, (h_pad, w_pad, h_mult, w_mult)

    def unpad(
        self,
        x: torch.Tensor,
        h_pad: List[int],
        w_pad: List[int],
        h_mult: int,
        w_mult: int,
    ) -> torch.Tensor:
        """Crop away the padding added by :meth:`pad` (returns a view/slice)."""
        return x[
            ..., h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1]
        ]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the normalized U-Net on a complex tensor of shape ``(b, c, h, w, 2)``.

        Raises
        ------
        ValueError
            If the last dimension of ``x`` is not 2.
        """
        if x.shape[-1] != 2:
            raise ValueError("Last dimension must be 2 for complex.")

        # get shapes for unet and normalize
        x = self.complex_to_chan_dim(x)
        x, mean, std = self.norm(x)
        x, pad_sizes = self.pad(x)

        x = self.unet(x)

        # get shapes back and unnormalize
        x = self.unpad(x, *pad_sizes)
        x = self.unnorm(x, mean, std)
        x = self.chan_complex_to_last_dim(x)

        return x
class SensitivityModel(nn.Module):
    """
    Model for learning sensitivity estimation from k-space data.

    This model applies an IFFT to multichannel k-space data and then a U-Net
    to the coil images to estimate coil sensitivities. It can be used with the
    end-to-end variational network.

    Input: multi-coil k-space data
    Output: multi-coil spatial domain sensitivity maps
    """

    def __init__(
        self,
        chans: int,
        num_pools: int,
        in_chans: int = 2,
        out_chans: int = 2,
        drop_prob: float = 0.0,
        mask_center: bool = True,
    ):
        """
        Parameters
        ----------
        chans : int
            Number of output channels of the first convolution layer.
        num_pools : int
            Number of down-sampling and up-sampling layers.
        in_chans : int, optional
            Number of channels in the input to the U-Net model. Default is 2.
        out_chans : int, optional
            Number of channels in the output to the U-Net model. Default is 2.
        drop_prob : float, optional
            Dropout probability. Default is 0.0.
        mask_center : bool, optional
            Whether to mask center of k-space for sensitivity map calculation.
            Default is True.
        """
        super().__init__()
        self.mask_center = mask_center
        self.norm_unet = NormUnet(
            chans,
            num_pools,
            in_chans=in_chans,
            out_chans=out_chans,
            drop_prob=drop_prob,
        )

    def chans_to_batch_dim(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
        """
        Move dim 1 into the batch: ``(b, c, h, w, comp) -> (b*c, 1, h, w, comp)``.

        Returns the reshaped tensor and the original batch size ``b``.
        """
        b, c, h, w, comp = x.shape
        return x.view(b * c, 1, h, w, comp), b

    def batch_chans_to_chan_dim(
        self,
        x: torch.Tensor,
        batch_size: int,
    ) -> torch.Tensor:
        """Inverse of :meth:`chans_to_batch_dim`: ``(b*c, 1, h, w, comp) -> (b, c, h, w, comp)``."""
        bc, _, h, w, comp = x.shape
        c = bc // batch_size
        return x.view(batch_size, c, h, w, comp)

    def divide_root_sum_of_squares(self, x: torch.Tensor) -> torch.Tensor:
        """Divide ``x`` by ``fastmri.rss_complex`` taken over dim 1 (broadcast back)."""
        return x / fastmri.rss_complex(x, dim=1).unsqueeze(-1).unsqueeze(1)

    def get_pad_and_num_low_freqs(
        self, mask: torch.Tensor, num_low_frequencies: Optional[int] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Locate the fully sampled center band along ``mask.shape[-2]`` per batch item.

        Parameters
        ----------
        mask : torch.Tensor
            Sampling mask; the per-item line pattern is read from
            ``mask[:, 0, 0, :, 0]``.
        num_low_frequencies : int, sequence or torch.Tensor, optional
            Width of the center band. A scalar is broadcast over the batch; a
            per-item sequence or tensor is also accepted. If ``None``, or if
            any entry is 0, the width is estimated from ``mask`` instead.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            ``(pad, num_low_freqs)``, both of dtype ``torch.long`` and one
            entry per batch item: the start index of the band and its width.
        """
        nlf = None
        if num_low_frequencies is not None:
            # as_tensor accepts a plain int (and 0-d tensors) as well as
            # per-item lists/tensors; iterating a bare int raised TypeError.
            # Placing it on mask's device also avoids CPU/GPU mismatches.
            nlf = torch.as_tensor(num_low_frequencies, device=mask.device)

        if nlf is None or bool((nlf == 0).any()):
            # get low frequency line locations and mask them out
            squeezed_mask = mask[:, 0, 0, :, 0].to(torch.int8)
            cent = squeezed_mask.shape[1] // 2
            # argmin returns the index of the first 0, i.e. the first
            # unsampled line when walking outward from the center
            left = torch.argmin(squeezed_mask[:, :cent].flip(1), dim=1)
            right = torch.argmin(squeezed_mask[:, cent:], dim=1)
            # force a symmetric center (twice the narrower side) unless 1
            num_low_frequencies_tensor = torch.max(
                2 * torch.min(left, right), torch.ones_like(left)
            )
        else:
            num_low_frequencies_tensor = nlf * torch.ones(
                mask.shape[0], dtype=mask.dtype, device=mask.device
            )

        pad = (mask.shape[-2] - num_low_frequencies_tensor + 1) // 2

        return pad.type(torch.long), num_low_frequencies_tensor.type(torch.long)

    def forward(
        self,
        masked_kspace: torch.Tensor,
        mask: torch.Tensor,
        num_low_frequencies: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Estimate coil sensitivity maps from (undersampled) multi-coil k-space.

        When ``mask_center`` is set, k-space is first restricted to the band
        ``[pad, pad + num_low_freqs)`` via ``transforms.batched_mask_center``.
        Each coil image is then refined independently by the NormUnet (coils
        are moved into the batch dimension). The result is divided by its
        root-sum-of-squares over the coil dimension.
        """
        if self.mask_center:
            pad, num_low_freqs = self.get_pad_and_num_low_freqs(
                mask, num_low_frequencies
            )
            masked_kspace = transforms.batched_mask_center(
                masked_kspace, pad, pad + num_low_freqs
            )

        # convert to image space
        images, batches = self.chans_to_batch_dim(fastmri.ifft2c(masked_kspace))

        # estimate sensitivities
        return self.divide_root_sum_of_squares(
            self.batch_chans_to_chan_dim(self.norm_unet(images), batches)
        )
class VarNet(nn.Module):
    """
    End-to-end variational network.

    A sensitivity-map estimator followed by a stack of ``VarNetBlock``
    cascades, each combining soft data consistency with a NormUnet
    regularizer. To use a regularizer other than a U-Net, assemble
    ``VarNetBlock`` instances directly.

    Input: masked multi-channel k-space data
    Output: coil-combined image. This is the RSS of the coil magnitudes, or,
    only in training mode with the ``MVUE`` environment variable set to
    "yes"/"1"/"true"/"True", the result of ``fastmri.mvue``.
    """

    def __init__(
        self,
        num_cascades: int = 12,
        sens_chans: int = 8,
        sens_pools: int = 4,
        chans: int = 18,
        pools: int = 4,
        mask_center: bool = True,
    ):
        """
        Parameters
        ----------
        num_cascades : int
            How many VarNetBlock cascades to stack.
        sens_chans : int
            Channel count of the sensitivity-map U-Net.
        sens_pools : int
            Down/up-sampling depth of the sensitivity-map U-Net.
        chans : int
            Channel count of each cascade's U-Net.
        pools : int
            Down/up-sampling depth of each cascade's U-Net.
        mask_center : bool
            Forwarded to SensitivityModel: whether to keep only the k-space
            center when estimating sensitivity maps.
        """
        super().__init__()
        # Registered before the cascades so submodule/state_dict order is unchanged.
        self.sens_net = SensitivityModel(
            chans=sens_chans, num_pools=sens_pools, mask_center=mask_center
        )
        self.cascades = nn.ModuleList(
            VarNetBlock(NormUnet(chans, pools)) for _ in range(num_cascades)
        )

    def forward(
        self,
        masked_kspace: torch.Tensor,
        mask: torch.Tensor,
        num_low_frequencies: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Reconstruct a coil-combined image from undersampled k-space.

        Sensitivity maps are estimated once and shared by all cascades. Each
        cascade refines the running k-space estimate against the measured
        ``masked_kspace``.
        """
        sens_maps = self.sens_net(masked_kspace, mask, num_low_frequencies)

        estimate = masked_kspace.clone()
        for block in self.cascades:
            estimate = block(estimate, masked_kspace, mask, sens_maps)

        coil_images = fastmri.ifft2c(estimate)

        # FIXME: MVUE combination is an opt-in, training-only switch read
        # from the environment on every call.
        mvue_requested = self.training and os.getenv("MVUE") in ("yes", "1", "true", "True")
        if mvue_requested:
            return fastmri.mvue(coil_images, sens_maps, dim=1)
        return fastmri.rss(fastmri.complex_abs(coil_images), dim=1)
class VarNetBlock(nn.Module):
    """
    Model block for the end-to-end variational network (refinement module).

    This model applies a combination of soft data consistency with the input
    model as a regularizer. A series of these blocks can be stacked to form
    the full variational network.

    Input: multi-channel k-space data
    Output: multi-channel k-space data
    """

    def __init__(self, model: nn.Module):
        """
        Parameters
        ----------
        model : nn.Module
            Module for "regularization" component of variational network.
        """
        super().__init__()

        self.model = model
        # learnable weight of the soft data-consistency term, initialized to 1
        self.dc_weight = nn.Parameter(torch.ones(1))

    def sens_expand(
        self, x: torch.Tensor, sens_maps: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute fft2c(x * sens_maps).

        The image ``x`` is multiplied (complex, element-wise) by the
        sensitivity maps, and the result is transformed to k-space with
        ``fastmri.fft2c``.
        """
        return fastmri.fft2c(fastmri.complex_mul(x, sens_maps))

    def sens_reduce(
        self, x: torch.Tensor, sens_maps: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute sum over dim 1 of ifft2c(x) * conj(sens_maps).

        ``x`` is transformed to image space, multiplied element-wise by the
        complex conjugate of the sensitivity maps, and summed over dim 1.
        That dim is kept as a singleton.
        """
        return fastmri.complex_mul(
            fastmri.ifft2c(x), fastmri.complex_conj(sens_maps)
        ).sum(dim=1, keepdim=True)

    def forward(
        self,
        current_kspace: torch.Tensor,
        ref_kspace: torch.Tensor,
        mask: torch.Tensor,
        sens_maps: torch.Tensor,
    ) -> torch.Tensor:
        """
        Parameters
        ----------
        current_kspace : torch.Tensor
            The current k-space estimate being refined by the network.
        ref_kspace : torch.Tensor
            The measured k-space data used for data consistency.
        mask : torch.Tensor
            Sampling mask, used as the condition of ``torch.where``. It should
            therefore be a boolean tensor broadcastable to ``current_kspace``.
            Data consistency is applied where it is True.
        sens_maps : torch.Tensor
            Coil sensitivity maps.

        Returns
        -------
        torch.Tensor
            ``current_kspace - dc_weight * mask * (current_kspace - ref_kspace)
            - model_term``.
        """
        # Regularizer term: coil-combine to an image, apply the model, then
        # expand back to multi-coil k-space.
        model_term = self.sens_expand(
            self.model(self.sens_reduce(current_kspace, sens_maps)), sens_maps
        )

        # Soft data consistency: pull sampled locations toward the
        # measurements, scaled by the learnable dc_weight.
        zero = torch.zeros(1, 1, 1, 1, 1).to(current_kspace)
        soft_dc = (
            torch.where(mask, current_kspace - ref_kspace, zero)
            * self.dc_weight
        )

        return current_kspace - soft_dc - model_term