import math
from typing import List, Literal, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

import fastmri
from fastmri import transforms

from models.udno import UDNO
def sens_expand(x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
    """
    Computes F(x * sens_maps): expands a single-channel image to multi-coil
    k-space via the coil sensitivity maps.

    Parameters
    ----------
    x : torch.Tensor
        Single-channel image of shape (..., H, W, 2)
    sens_maps : torch.Tensor
        Sensitivity maps (image space)

    Returns
    -------
    torch.Tensor
        Result of the operation F(x * sens_maps)
    """
    return fastmri.fft2c(fastmri.complex_mul(x, sens_maps))
def sens_reduce(k: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
    """
    Computes F^{-1}(k) * conj(sens_maps), summed over coils: reduces
    multi-coil k-space to a single-channel image. conj(sens_maps) is the
    element-wise complex conjugate.

    Parameters
    ----------
    k : torch.Tensor
        Multi-channel k-space of shape (B, C, H, W, 2)
    sens_maps : torch.Tensor
        Sensitivity maps (image space)

    Returns
    -------
    torch.Tensor
        Single-channel image of shape (B, 1, H, W, 2)
    """
    return fastmri.complex_mul(fastmri.ifft2c(k), fastmri.complex_conj(sens_maps)).sum(
        dim=1, keepdim=True
    )
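# Shape sketch for the two coil operators above (a minimal example assuming
# 15 coils and fastMRI's trailing real/imag dimension of size 2);
# sens_reduce(sens_expand(x)) realizes the usual S^H F^{-1} F S pairing:
#
#   >>> x = torch.randn(1, 1, 64, 64, 2)      # single-channel complex image
#   >>> maps = torch.randn(1, 15, 64, 64, 2)  # per-coil sensitivity maps
#   >>> sens_expand(x, maps).shape            # multi-coil k-space
#   torch.Size([1, 15, 64, 64, 2])
#   >>> sens_reduce(sens_expand(x, maps), maps).shape
#   torch.Size([1, 1, 64, 64, 2])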
def chans_to_batch_dim(x: torch.Tensor) -> Tuple[torch.Tensor, int]:
    """Reshapes batched multi-channel samples into multiple single-channel samples.

    Parameters
    ----------
    x : torch.Tensor
        x has shape (b, c, h, w, 2)

    Returns
    -------
    Tuple[torch.Tensor, int]
        tensor of shape (b * c, 1, h, w, 2), b
    """
    b, c, h, w, comp = x.shape
    return x.view(b * c, 1, h, w, comp), b
def batch_chans_to_chan_dim(x: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Reshapes batched independent samples into the original multi-channel samples.

    Parameters
    ----------
    x : torch.Tensor
        tensor of shape (b * c, 1, h, w, 2)
    batch_size : int
        batch size

    Returns
    -------
    torch.Tensor
        original multi-channel tensor of shape (b, c, h, w, 2)
    """
    bc, _, h, w, comp = x.shape
    c = bc // batch_size
    return x.view(batch_size, c, h, w, comp)
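# Round-trip sketch for the two reshaping helpers above: coils are folded into
# the batch dimension so each coil image is processed as an independent
# single-channel sample, then unfolded again (shapes are illustrative):
#
#   >>> x = torch.randn(2, 15, 64, 64, 2)
#   >>> batched, b = chans_to_batch_dim(x)
#   >>> batched.shape
#   torch.Size([30, 1, 64, 64, 2])
#   >>> torch.equal(batch_chans_to_chan_dim(batched, b), x)
#   True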
class NormUDNO(nn.Module):
    """
    Normalized UDNO model.

    Inputs are normalized before the UDNO for numerically stable training.
    """

    def __init__(
        self,
        chans: int,
        num_pool_layers: int,
        radius_cutoff: float,
        in_shape: Tuple[int, int],
        kernel_shape: Tuple[int, int],
        in_chans: int = 2,
        out_chans: int = 2,
        drop_prob: float = 0.0,
    ):
        """
        Initialize the NormUDNO model.

        Parameters
        ----------
        chans : int
            Number of output channels of the first convolution layer.
        num_pool_layers : int
            Number of down-sampling and up-sampling layers.
        radius_cutoff : float
            Radius cutoff for the UDNO integral kernels.
        in_shape : Tuple[int, int]
            Spatial shape (H, W) of the input.
        kernel_shape : Tuple[int, int]
            Shape of the UDNO kernels.
        in_chans : int, optional
            Number of channels in the input to the UDNO model. Default is 2.
        out_chans : int, optional
            Number of channels in the output of the UDNO model. Default is 2.
        drop_prob : float, optional
            Dropout probability. Default is 0.0.
        """
        super().__init__()
        self.udno = UDNO(
            in_chans=in_chans,
            out_chans=out_chans,
            radius_cutoff=radius_cutoff,
            chans=chans,
            num_pool_layers=num_pool_layers,
            drop_prob=drop_prob,
            in_shape=in_shape,
            kernel_shape=kernel_shape,
        )
    def complex_to_chan_dim(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w, two = x.shape
        assert two == 2
        return x.permute(0, 4, 1, 2, 3).reshape(b, 2 * c, h, w)

    def chan_complex_to_last_dim(self, x: torch.Tensor) -> torch.Tensor:
        b, c2, h, w = x.shape
        assert c2 % 2 == 0
        c = c2 // 2
        return x.view(b, 2, c, h, w).permute(0, 2, 3, 4, 1).contiguous()

    def norm(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # group norm over two groups (real channels and imaginary channels)
        b, c, h, w = x.shape
        x = x.view(b, 2, c // 2 * h * w)
        mean = x.mean(dim=2).view(b, 2, 1, 1)
        std = x.std(dim=2).view(b, 2, 1, 1)
        x = x.view(b, c, h, w)
        return (x - mean) / std, mean, std
    def norm_new(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # group norm with a configurable number of groups
        b, c, h, w = x.shape
        num_groups = 2
        assert (
            c % num_groups == 0
        ), f"Number of channels ({c}) must be divisible by number of groups ({num_groups})."
        x = x.view(b, num_groups, c // num_groups * h * w)
        mean = x.mean(dim=2).view(b, num_groups, 1, 1)
        std = x.std(dim=2).view(b, num_groups, 1, 1)
        x = x.view(b, c, h, w)
        # Expand the per-group statistics to per-channel statistics. Channels
        # of the same group are contiguous, so repeat_interleave (rather than
        # a tiled repeat()) keeps each channel paired with its own group.
        mean = mean.repeat_interleave(c // num_groups, dim=1)  # (b, c, 1, 1)
        std = std.repeat_interleave(c // num_groups, dim=1)  # (b, c, 1, 1)
        return (x - mean) / std, mean, std
    def unnorm(
        self, x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor
    ) -> torch.Tensor:
        return x * std + mean
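    # Round-trip sketch for norm/unnorm (a hypothetical instance `m`; assumes
    # c == 2 so the (b, 2, 1, 1) statistics broadcast against (b, c, h, w)):
    #
    #   >>> x = torch.randn(4, 2, 64, 64)
    #   >>> y, mean, std = m.norm(x)
    #   >>> torch.allclose(m.unnorm(y, mean, std), x, atol=1e-5)
    #   True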
    def pad(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Tuple[List[int], List[int], int, int]]:
        _, _, h, w = x.shape
        # round H and W up to the next multiple of 16 so the down-sampling
        # layers divide the spatial dimensions evenly
        w_mult = ((w - 1) | 15) + 1
        h_mult = ((h - 1) | 15) + 1
        w_pad = [math.floor((w_mult - w) / 2), math.ceil((w_mult - w) / 2)]
        h_pad = [math.floor((h_mult - h) / 2), math.ceil((h_mult - h) / 2)]
        # TODO: fix this type when PyTorch fixes theirs
        # the documentation lies - this actually takes a list
        # https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L3457
        # https://github.com/pytorch/pytorch/pull/16949
        x = F.pad(x, w_pad + h_pad)

        return x, (h_pad, w_pad, h_mult, w_mult)

    def unpad(
        self,
        x: torch.Tensor,
        h_pad: List[int],
        w_pad: List[int],
        h_mult: int,
        w_mult: int,
    ) -> torch.Tensor:
        return x[..., h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1]]
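    # pad/unpad round H and W up to the next multiple of 16 (the bit trick
    # ((n - 1) | 15) + 1) and crop back afterwards. Sketch with a
    # hypothetical instance `m`:
    #
    #   >>> x = torch.randn(1, 2, 320, 310)
    #   >>> padded, pad_sizes = m.pad(x)
    #   >>> padded.shape
    #   torch.Size([1, 2, 320, 320])
    #   >>> m.unpad(padded, *pad_sizes).shape
    #   torch.Size([1, 2, 320, 310])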
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.shape[-1] != 2:
            raise ValueError("Last dimension must be 2 for complex.")

        chans = x.shape[1]
        if chans == 2:
            # FIXME: hard-coded skip of norm/pad temporarily to avoid the
            # group norm bug for two-channel complex inputs
            x = self.complex_to_chan_dim(x)
            x = self.udno(x)
            return self.chan_complex_to_last_dim(x)

        # get shapes for the UDNO and normalize
        x = self.complex_to_chan_dim(x)
        x, mean, std = self.norm(x)
        x, pad_sizes = self.pad(x)

        x = self.udno(x)

        # get shapes back and unnormalize
        x = self.unpad(x, *pad_sizes)
        x = self.unnorm(x, mean, std)
        x = self.chan_complex_to_last_dim(x)

        return x
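# End-to-end sketch of NormUDNO.forward (hyperparameters here are
# illustrative, not tuned values; assumes the underlying UDNO preserves the
# spatial size; the trailing dimension of size 2 holds real/imaginary parts):
#
#   >>> net = NormUDNO(chans=8, num_pool_layers=2, radius_cutoff=0.02,
#   ...                in_shape=(64, 64), kernel_shape=(3, 4))
#   >>> net(torch.randn(1, 1, 64, 64, 2)).shape
#   torch.Size([1, 1, 64, 64, 2])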
class SensitivityModel(nn.Module):
    """
    Learns coil sensitivity maps from the masked k-space.
    """

    def __init__(
        self,
        chans: int,
        num_pools: int,
        radius_cutoff: float,
        in_shape: Tuple[int, int],
        kernel_shape: Tuple[int, int],
        in_chans: int = 2,
        out_chans: int = 2,
        drop_prob: float = 0.0,
        mask_center: bool = True,
    ):
        """
        Parameters
        ----------
        chans : int
            Number of output channels of the first convolution layer.
        num_pools : int
            Number of down-sampling and up-sampling layers.
        radius_cutoff : float
            Radius cutoff for the UDNO integral kernels.
        in_shape : Tuple[int, int]
            Spatial shape (H, W) of the input.
        kernel_shape : Tuple[int, int]
            Shape of the UDNO kernels.
        in_chans : int, optional
            Number of channels in the input to the UDNO model. Default is 2.
        out_chans : int, optional
            Number of channels in the output of the UDNO model. Default is 2.
        drop_prob : float, optional
            Dropout probability. Default is 0.0.
        mask_center : bool, optional
            Whether to mask center of k-space for sensitivity map calculation.
            Default is True.
        """
        super().__init__()
        self.mask_center = mask_center
        self.norm_udno = NormUDNO(
            chans,
            num_pools,
            radius_cutoff,
            in_shape,
            kernel_shape,
            in_chans=in_chans,
            out_chans=out_chans,
            drop_prob=drop_prob,
        )
    def divide_root_sum_of_squares(self, x: torch.Tensor) -> torch.Tensor:
        return x / fastmri.rss_complex(x, dim=1).unsqueeze(-1).unsqueeze(1)

    def get_pad_and_num_low_freqs(
        self, mask: torch.Tensor, num_low_frequencies: Optional[int] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if num_low_frequencies is None or any(
            torch.any(t == 0) for t in num_low_frequencies
        ):
            # get low frequency line locations and mask them out
            squeezed_mask = mask[:, 0, 0, :, 0].to(torch.int8)
            cent = squeezed_mask.shape[1] // 2
            # argmin returns the first zero, i.e. the number of contiguous
            # sampled lines on each side of the center
            left = torch.argmin(squeezed_mask[:, :cent].flip(1), dim=1)
            right = torch.argmin(squeezed_mask[:, cent:], dim=1)
            num_low_frequencies_tensor = torch.max(
                2 * torch.min(left, right), torch.ones_like(left)
            )  # force a symmetric center unless 1
        else:
            num_low_frequencies_tensor = num_low_frequencies * torch.ones(
                mask.shape[0], dtype=mask.dtype, device=mask.device
            )

        pad = (mask.shape[-2] - num_low_frequencies_tensor + 1) // 2
        return pad.type(torch.long), num_low_frequencies_tensor.type(torch.long)
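    # Sketch of the automatic low-frequency detection above for a single
    # 10-line mask whose center 4 lines are sampled (1 = sampled), using a
    # hypothetical instance `sens_net`:
    #
    #   >>> mask = torch.tensor([0, 0, 0, 1, 1, 1, 1, 0, 0, 0]).view(1, 1, 1, 10, 1)
    #   >>> sens_net.get_pad_and_num_low_freqs(mask)
    #   (tensor([3]), tensor([4]))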
    def forward(
        self,
        masked_kspace: torch.Tensor,
        mask: torch.Tensor,
        num_low_frequencies: Optional[int] = None,
    ) -> torch.Tensor:
        if self.mask_center:
            pad, num_low_freqs = self.get_pad_and_num_low_freqs(
                mask, num_low_frequencies
            )
            masked_kspace = transforms.batched_mask_center(
                masked_kspace, pad, pad + num_low_freqs
            )

        # convert to image space
        images, batches = chans_to_batch_dim(fastmri.ifft2c(masked_kspace))

        # estimate sensitivities
        return self.divide_root_sum_of_squares(
            batch_chans_to_chan_dim(self.norm_udno(images), batches)
        )
class VarNetBlock(nn.Module):
    """
    Model block for iterative refinement of k-space data.

    This model applies a combination of soft data consistency with the input
    model as a regularizer. A series of these blocks can be stacked to form
    the full variational network.

    Corresponds to the refinement module in Fig 1 of the E2E-VarNet paper.
    """

    def __init__(self, model: nn.Module):
        """
        Parameters
        ----------
        model : nn.Module
            Module for the "regularization" component of the variational
            network.
        """
        super().__init__()
        self.model = model
        self.dc_weight = nn.Parameter(torch.ones(1))
    def forward(
        self,
        current_kspace: torch.Tensor,
        ref_kspace: torch.Tensor,
        mask: torch.Tensor,
        sens_maps: torch.Tensor,
        use_dc_term: bool = True,
    ) -> torch.Tensor:
        """
        Parameters
        ----------
        current_kspace : torch.Tensor
            The current k-space data (frequency domain data) being processed
            by the network.
        ref_kspace : torch.Tensor
            Original subsampled k-space data from which we are reconstructing
            the image (reference k-space).
        mask : torch.Tensor
            A binary mask indicating the locations in k-space where data
            consistency should be enforced.
        sens_maps : torch.Tensor
            Sensitivity maps for the different coils in parallel imaging.
        use_dc_term : bool, optional
            Whether to apply the soft data consistency term. Default is True.
        """
        # model term: see the orange box of Fig 1 in the E2E-VarNet paper!
        # multi-channel k-space -> single-channel image space
        b, c, h, w, _ = current_kspace.shape
        if c == 30:
            # concat skip connection: the first 15 channels are the measured
            # k-space, the last 15 the inpainted k-space (15-coil data)
            kspace = current_kspace[:, :15, :, :, :]
            in_kspace = current_kspace[:, 15:, :, :, :]
            # convert to image space
            image = sens_reduce(kspace, sens_maps)
            in_image = sens_reduce(in_kspace, sens_maps)
            # concatenate both onto each other
            reduced_image = torch.cat([image, in_image], dim=1)
        else:
            reduced_image = sens_reduce(current_kspace, sens_maps)

        # single-channel image space
        refined_image = self.model(reduced_image)
        # single-channel image space -> multi-channel k-space
        model_term = sens_expand(refined_image, sens_maps)

        # only use first 15 channels (masked_kspace) in the update
        # current_kspace = current_kspace[:, :15, :, :, :]
        if not use_dc_term:
            return current_kspace - model_term

        # Soft data consistency term:
        # - computes the difference between the current k-space and the
        #   reference k-space where the mask is true
        # - multiplies this difference by the data consistency weight
        # dc term: see the green box of Fig 1 in the E2E-VarNet paper!
        zero = torch.zeros(1, 1, 1, 1, 1).to(current_kspace)
        soft_dc = torch.where(mask, current_kspace - ref_kspace, zero) * self.dc_weight

        return current_kspace - soft_dc - model_term
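    # In symbols, one block computes (cf. Fig 1 of the E2E-VarNet paper):
    #   k_{t+1} = k_t - w * M * (k_t - k_ref) - F S R(S^H F^{-1} k_t)
    # where M is the sampling mask, w the learned dc_weight, S the
    # sensitivity maps, and R the learned regularizer (here a NormUDNO).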
class NOVarnet(nn.Module):
    """
    Neural Operator model for MRI reconstruction.

    Uses a variational architecture (iterative updates) with a learned
    sensitivity model. All operations are resolution invariant, employing
    neural operator modules (GNO, UDNO).
    """

    def __init__(
        self,
        num_cascades: int = 12,
        sens_chans: int = 8,
        sens_pools: int = 4,
        chans: int = 18,
        pools: int = 4,
        gno_chans: int = 16,
        gno_pools: int = 4,
        gno_radius_cutoff: float = 0.02,
        gno_kernel_shape: Tuple[int, int] = (6, 7),
        radius_cutoff: float = 0.01,
        kernel_shape: Tuple[int, int] = (3, 4),
        in_shape: Tuple[int, int] = (640, 320),
        mask_center: bool = True,
        use_dc_term: bool = True,
        reduction_method: Literal["batch", "rss"] = "rss",
        skip_method: Literal["replace", "add", "add_inv", "concat"] = "add",
    ):
        """
        Parameters
        ----------
        num_cascades : int
            Number of cascades (i.e., layers) for the variational network.
        sens_chans : int
            Number of channels for the sensitivity map UDNO.
        sens_pools : int
            Number of downsampling and upsampling layers for the sensitivity
            map UDNO.
        chans : int
            Number of channels for the cascade UDNOs.
        pools : int
            Number of downsampling and upsampling layers for the cascade
            UDNOs.
        gno_chans : int
            Number of channels for the GNO (k-space inpainting) module.
        gno_pools : int
            Number of downsampling and upsampling layers for the GNO module.
        gno_radius_cutoff : float
            Radius cutoff intended for the GNO module (currently unused; see
            the commented-out arguments in __init__).
        gno_kernel_shape : Tuple[int, int]
            Kernel shape intended for the GNO module (currently unused; see
            the commented-out arguments in __init__).
        radius_cutoff : float
            Radius cutoff for the UDNO integral kernels.
        kernel_shape : Tuple[int, int]
            Shape of the UDNO kernels.
        in_shape : Tuple[int, int]
            Spatial shape (H, W) of the input.
        mask_center : bool
            Whether to mask center of k-space for sensitivity map calculation.
        use_dc_term : bool
            Whether to use the data consistency term.
        reduction_method : "batch" or "rss"
            Method for reducing the multi-coil k-space to a single channel.
            "batch" folds the coil channels into the batch dimension.
            "rss" combines coils via the (root-sum-of-squares normalized)
            sensitivity maps.
        skip_method : "replace", "add", "add_inv", or "concat"
            "replace" replaces the input with the output of the GNO.
            "add" adds the output of the GNO to the input.
            "add_inv" adds the output of the GNO to the input (only where
            samples are missing).
            "concat" concatenates the output of the GNO to the input.
        """
        super().__init__()

        self.sens_net = SensitivityModel(
            sens_chans,
            sens_pools,
            radius_cutoff,
            in_shape,
            kernel_shape,
            mask_center=mask_center,
        )
        self.gno = NormUDNO(
            gno_chans,
            gno_pools,
            in_shape=in_shape,
            radius_cutoff=radius_cutoff,
            kernel_shape=kernel_shape,
            # radius_cutoff=gno_radius_cutoff,
            # kernel_shape=gno_kernel_shape,
            in_chans=2,
            out_chans=2,
        )
        self.cascades = nn.ModuleList(
            [
                VarNetBlock(
                    NormUDNO(
                        chans,
                        pools,
                        radius_cutoff,
                        in_shape,
                        kernel_shape,
                        in_chans=(
                            4 if skip_method == "concat" and cascade_idx == 0 else 2
                        ),
                        out_chans=2,
                    )
                )
                for cascade_idx in range(num_cascades)
            ]
        )
        self.use_dc_term = use_dc_term
        self.reduction_method = reduction_method
        self.skip_method = skip_method
    def forward(
        self,
        masked_kspace: torch.Tensor,
        mask: torch.Tensor,
        num_low_frequencies: Optional[int] = None,
    ) -> torch.Tensor:
        # (B, C, H, W, 2)
        sens_maps = self.sens_net(masked_kspace, mask, num_low_frequencies)

        # reduce before inpainting
        if self.reduction_method == "rss":
            # (B, 1, H, W, 2) single-channel image space
            x_reduced = sens_reduce(masked_kspace, sens_maps)
            # (B, 1, H, W, 2)
            k_reduced = fastmri.fft2c(x_reduced)
        elif self.reduction_method == "batch":
            k_reduced, b = chans_to_batch_dim(masked_kspace)

        # inpainting
        if self.skip_method == "replace":
            kspace_pred = self.gno(k_reduced)
        elif self.skip_method == "add_inv":
            # FIXME: this is not correct (mask has shape (B, 1, H, W, 2) and
            # self.gno(k_reduced) has shape (B*C, 1, H, W, 2))
            kspace_pred = k_reduced.clone() + (~mask * self.gno(k_reduced))
        elif self.skip_method == "add":
            kspace_pred = k_reduced.clone() + self.gno(k_reduced)
        elif self.skip_method == "concat":
            kspace_pred = torch.cat([k_reduced.clone(), self.gno(k_reduced)], dim=1)
        else:
            raise NotImplementedError("skip_method not implemented")

        # expand after inpainting
        if self.reduction_method == "rss":
            if self.skip_method == "concat":
                # kspace_pred is (B, 2, H, W, 2)
                kspace = kspace_pred[:, :1, :, :, :]
                in_kspace = kspace_pred[:, 1:, :, :, :]
                # (B, 2C, H, W, 2)
                kspace_pred = torch.cat(
                    [sens_expand(kspace, sens_maps), sens_expand(in_kspace, sens_maps)],
                    dim=1,
                )
            else:
                # (B, 1, H, W, 2) -> (B, C, H, W, 2) multi-coil k-space
                kspace_pred = sens_expand(kspace_pred, sens_maps)
        elif self.reduction_method == "batch":
            # (B, C, H, W, 2) multi-coil k-space
            if self.skip_method == "concat":
                kspace = kspace_pred[:, :1, :, :, :]
                in_kspace = kspace_pred[:, 1:, :, :, :]
                # (B, 2C, H, W, 2)
                kspace_pred = torch.cat(
                    [
                        batch_chans_to_chan_dim(kspace, b),
                        batch_chans_to_chan_dim(in_kspace, b),
                    ],
                    dim=1,
                )
            else:
                kspace_pred = batch_chans_to_chan_dim(kspace_pred, b)

        # iterative update
        for cascade in self.cascades:
            kspace_pred = cascade(
                kspace_pred, masked_kspace, mask, sens_maps, self.use_dc_term
            )

        spatial_pred = fastmri.ifft2c(kspace_pred)
        spatial_pred_abs = fastmri.complex_abs(spatial_pred)
        combined_spatial = fastmri.rss(spatial_pred_abs, dim=1)

        return combined_spatial
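if __name__ == "__main__":
    # Minimal smoke-test sketch. The 15-coil shape, 64x64 resolution, and
    # hyperparameters below are illustrative assumptions, not a trained
    # configuration; this assumes models.udno.UDNO runs at this resolution.
    model = NOVarnet(
        num_cascades=2,
        sens_chans=4,
        sens_pools=2,
        chans=8,
        pools=2,
        in_shape=(64, 64),
    )
    kspace = torch.randn(1, 15, 64, 64, 2)
    mask = torch.zeros(1, 1, 1, 64, 1, dtype=torch.bool)
    mask[..., 24:40, :] = True  # fully sampled center band
    masked_kspace = kspace * mask
    with torch.no_grad():
        out = model(masked_kspace, mask)
    print(out.shape)  # expected: torch.Size([1, 64, 64])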