Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Copyright (c) Facebook, Inc. and its affiliates. | |
| This source code is licensed under the MIT license found in the | |
| LICENSE file in the root directory of this source tree. | |
| """ | |
| import random | |
| from typing import Dict, List, NamedTuple, Optional, Sequence, Tuple, Union | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import torch | |
| import fastmri | |
| from .subsample import MaskFunc | |
def to_tensor(data: np.ndarray) -> torch.Tensor:
    """
    Wrap a numpy array as a PyTorch tensor.

    Complex input is represented in the real-valued "trailing dimension of
    size 2" convention: the real part is placed at index 0 and the imaginary
    part at index 1 of a new last dimension. Real input is wrapped as-is.

    Args:
        data: Array to convert.

    Returns:
        A tensor created with ``torch.from_numpy``.
    """
    if not np.iscomplexobj(data):
        return torch.from_numpy(data)
    real_imag = np.stack([data.real, data.imag], axis=-1)
    return torch.from_numpy(real_imag)
def tensor_to_complex_np(data: torch.Tensor) -> np.ndarray:
    """
    Turn a real tensor whose last dimension holds (real, imag) pairs into a
    complex numpy array.

    Args:
        data: Real-valued tensor with a trailing dimension of size 2.

    Returns:
        Complex numpy array with the trailing dimension removed.
    """
    complex_view = torch.view_as_complex(data)
    return complex_view.numpy()
def apply_mask(
    data: torch.Tensor,
    mask_func: MaskFunc,
    offset: Optional[int] = None,
    seed: Optional[Union[int, Tuple[int, ...]]] = None,
    padding: Optional[Sequence[int]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, int]:
    """
    Subsample k-space by multiplying it with a generated mask.

    Args:
        data: k-space tensor with at least 3 dimensions; dimensions -3 and
            -2 are the spatial axes and the last dimension holds the
            (real, imag) pair.
        mask_func: Callable invoked as ``mask_func(shape, offset, seed)``
            that returns a ``(mask, num_low_frequencies)`` pair.
        offset: Forwarded unchanged to ``mask_func`` (presumably the
            starting offset for equispaced masks — confirm in
            ``subsample.MaskFunc``).
        seed: Forwarded unchanged to ``mask_func`` as its random seed.
        padding: Optional ``(left, right)`` column bounds. Mask columns
            before ``left`` and from ``right`` onwards are zeroed.

    Returns:
        tuple containing:
            masked data: ``data * mask`` (with negative zeros normalised).
            mask: The generated (and possibly padded) mask.
            num_low_frequencies: The value returned by ``mask_func``.
    """
    # Leading (batch/coil) axes collapse to 1 so the mask broadcasts.
    leading_ones = (1,) * len(data.shape[:-3])
    mask_shape = leading_ones + tuple(data.shape[-3:])
    mask, num_low_frequencies = mask_func(mask_shape, offset, seed)

    if padding is not None:
        left_bound, right_bound = padding[0], padding[1]
        mask[..., :left_bound, :] = 0
        # the right bound is inclusive on the zeroed side
        mask[..., right_bound:, :] = 0

    # Adding 0.0 turns any -0.0 produced by the product into +0.0.
    masked_data = data * mask + 0.0
    return masked_data, mask, num_low_frequencies
def mask_center(x: torch.Tensor, mask_from: int, mask_to: int) -> torch.Tensor:
    """
    Keep a band of columns of ``x`` and zero everything else.

    Args:
        x: Tensor of at least 4 dimensions; the band is taken along
            dimension 3.
        mask_from: First column of the band (inclusive).
        mask_to: End column of the band (exclusive).

    Returns:
        A tensor shaped like ``x`` that equals ``x`` inside the band and is
        zero outside it.
    """
    band = slice(mask_from, mask_to)
    result = torch.zeros_like(x)
    result[:, :, :, band] = x[:, :, :, band]
    return result
def batched_mask_center(
    x: torch.Tensor, mask_from: torch.Tensor, mask_to: torch.Tensor
) -> torch.Tensor:
    """
    Keep a band of columns of ``x``, optionally with a different band per
    batch element.

    Args:
        x: Tensor indexed as ``x[batch, :, :, column, ...]``; the band is
            taken along dimension 3.
        mask_from: 1D tensor of band start columns (inclusive). It has
            either a single element shared by the whole batch, or one
            element per batch entry.
        mask_to: 1D tensor of band end columns (exclusive), with the same
            shape as ``mask_from``.

    Returns:
        A tensor shaped like ``x`` that equals ``x`` inside each band and is
        zero elsewhere.

    Raises:
        ValueError: If ``mask_from`` and ``mask_to`` differ in shape or are
            not 1D, or if they hold more than one element but not one per
            batch entry.
    """
    if mask_from.shape != mask_to.shape:
        raise ValueError("mask_from and mask_to must match shapes.")
    if mask_from.ndim != 1:
        raise ValueError("mask_from and mask_to must have 1 dimension.")

    shared_band = mask_from.shape[0] == 1
    if shared_band:
        # One band applies to every batch element.
        return mask_center(x, int(mask_from), int(mask_to))

    batch_size = x.shape[0]
    if batch_size != mask_from.shape[0] or batch_size != mask_to.shape[0]:
        raise ValueError(
            "mask_from and mask_to must have batch_size length."
        )

    result = torch.zeros_like(x)
    for idx, (lo, hi) in enumerate(zip(mask_from, mask_to)):
        result[idx, :, :, lo:hi] = x[idx, :, :, lo:hi]
    return result
def center_crop(data: torch.Tensor, shape: Tuple[int, int]) -> torch.Tensor:
    """
    Center-crop the last two dimensions of a real image or image batch.

    Args:
        data: Tensor with at least 2 dimensions; the crop is applied to
            dimensions -2 and -1.
        shape: Target ``(rows, cols)``. Each entry must be positive and no
            larger than the matching dimension of ``data``.

    Returns:
        The cropped tensor. When the size difference is odd, the extra row
        or column is dropped from the far side.

    Raises:
        ValueError: If ``shape`` is out of range.
    """
    out_rows, out_cols = shape[0], shape[1]
    in_rows, in_cols = data.shape[-2], data.shape[-1]
    if not (0 < out_rows <= in_rows and 0 < out_cols <= in_cols):
        raise ValueError("Invalid shapes.")

    top = (in_rows - out_rows) // 2
    left = (in_cols - out_cols) // 2
    return data[..., top : top + out_rows, left : left + out_cols]
def complex_center_crop(
    data: torch.Tensor, shape: Tuple[int, int]
) -> torch.Tensor:
    """
    Center-crop a complex image (or batch) stored with a trailing (real,
    imag) dimension.

    Args:
        data: Tensor with at least 3 dimensions; the crop is applied to
            dimensions -3 and -2 and the last dimension is kept intact.
        shape: Target ``(rows, cols)``. Each entry must be positive and no
            larger than the matching dimension of ``data``.

    Returns:
        The cropped tensor.

    Raises:
        ValueError: If ``shape`` is out of range.
    """
    out_rows, out_cols = shape[0], shape[1]
    in_rows, in_cols = data.shape[-3], data.shape[-2]
    if not (0 < out_rows <= in_rows and 0 < out_cols <= in_cols):
        raise ValueError("Invalid shapes.")

    top = (in_rows - out_rows) // 2
    left = (in_cols - out_cols) // 2
    return data[..., top : top + out_rows, left : left + out_cols, :]
def center_crop_to_smallest(
    x: torch.Tensor, y: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Center-crop two images to their common smallest size.

    The minimum is taken independently over dim=-2 and dim=-1. If ``x`` is
    smaller along one axis and ``y`` along the other, the result size mixes
    the two.

    Args:
        x: The first image.
        y: The second image.

    Returns:
        ``(x, y)``, each center-cropped to the minimum size.
    """
    rows = min(x.shape[-2], y.shape[-2])
    cols = min(x.shape[-1], y.shape[-1])
    common = (rows, cols)
    return center_crop(x, common), center_crop(y, common)
def normalize(
    data: torch.Tensor,
    mean: Union[float, torch.Tensor],
    stddev: Union[float, torch.Tensor],
    eps: Union[float, torch.Tensor] = 0.0,
) -> torch.Tensor:
    """
    Shift and scale a tensor.

    Computes ``(data - mean) / (stddev + eps)``.

    Args:
        data: Values to normalize.
        mean: Value subtracted from ``data``.
        stddev: Value that the centered data is divided by.
        eps: Added to ``stddev`` to guard against division by zero.

    Returns:
        The normalized tensor.
    """
    centered = data - mean
    scale = stddev + eps
    return centered / scale
def normalize_instance(
    data: torch.Tensor, eps: Union[float, torch.Tensor] = 0.0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Normalize a tensor by its own statistics (instance normalization).

    Computes ``(data - mean) / (stddev + eps)``, where ``mean`` and
    ``stddev`` are the scalar mean and standard deviation over all
    elements of ``data``.

    Args:
        data: Values to normalize.
        eps: Added to the standard deviation to guard against division by
            zero.

    Returns:
        tuple of (normalized tensor, mean, standard deviation).
    """
    mu = data.mean()
    sigma = data.std()
    normalized = normalize(data, mu, sigma, eps)
    return normalized, mu, sigma
class UnetSample(NamedTuple):
    """
    A subsampled image for U-Net reconstruction.

    Args:
        image: Zero-filled magnitude image (inverse FFT of the masked
            k-space), center-cropped, normalized and clamped.
        target: The target image (if applicable), normalized with the same
            statistics as ``image``.
        mean: Mean used for normalization. ``UnetDataTransform`` computes a
            single scalar over the whole input image, not per channel.
        std: Standard deviation used for normalization (also a whole-image
            scalar in ``UnetDataTransform``).
        fname: File name.
        slice_num: The slice index.
        max_value: Maximum image value.
    """

    image: torch.Tensor
    target: torch.Tensor
    mean: torch.Tensor
    std: torch.Tensor
    fname: str
    slice_num: int
    max_value: float
class UnetDataTransform:
    """
    Data Transformer for training U-Net models.

    Turns one k-space slice into a normalized, center-cropped magnitude
    image, plus a matching normalized target, packaged as a ``UnetSample``.
    """

    def __init__(
        self,
        which_challenge: str,
        mask_func: Optional[MaskFunc] = None,
        use_seed: bool = True,
    ):
        """
        Args:
            which_challenge: Challenge from ("singlecoil", "multicoil").
            mask_func: Optional; A function that can create a mask of
                appropriate shape.
            use_seed: If true, this class computes a pseudo random number
                generator seed from the filename. This ensures that the same
                mask is used for all the slices of a given volume every time.

        Raises:
            ValueError: If ``which_challenge`` is not a supported value.
        """
        if which_challenge not in ("singlecoil", "multicoil"):
            raise ValueError(
                "Challenge should either be 'singlecoil' or 'multicoil'"
            )
        self.mask_func = mask_func
        self.which_challenge = which_challenge
        self.use_seed = use_seed

    def __call__(
        self,
        kspace: np.ndarray,
        mask: np.ndarray,
        target: Optional[np.ndarray],
        attrs: Dict,
        fname: str,
        slice_num: int,
    ) -> UnetSample:
        """
        Args:
            kspace: Input k-space of shape (num_coils, rows, cols) for
                multi-coil data or (rows, cols) for single coil data.
            mask: Mask from the test dataset (not used by this transform).
            target: Target image, or None when unavailable.
            attrs: Acquisition related information stored in the HDF5 object.
                Reads "max" (optional) and "recon_size" (when ``target`` is
                None).
            fname: File name.
            slice_num: Serial number of the slice.

        Returns:
            A ``UnetSample`` (a NamedTuple, so it still unpacks like a
            tuple). It contains the zero-filled input image, the
            reconstruction target, the mean and standard deviation used for
            normalization, the filename, the slice number and the maximum
            value.
        """
        kspace_torch = to_tensor(kspace)

        # A missing "max" attribute falls back to 0.0.
        max_value = attrs.get("max", 0.0)

        if self.mask_func is not None:
            # Seeding from the file name gives every slice of a volume the
            # same mask.
            seed = tuple(map(ord, fname)) if self.use_seed else None
            masked_kspace, _, _ = apply_mask(
                kspace_torch, self.mask_func, seed=seed
            )
        else:
            masked_kspace = kspace_torch

        # inverse Fourier transform to get zero filled solution
        image = fastmri.ifft2c(masked_kspace)

        # crop input to the target size (or the stored recon size)
        if target is not None:
            crop_size = (target.shape[-2], target.shape[-1])
        else:
            crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])

        # check for FLAIR 203: when the image has fewer columns than the
        # requested crop width, fall back to a square crop of that width.
        if image.shape[-2] < crop_size[1]:
            crop_size = (image.shape[-2], image.shape[-2])

        image = complex_center_crop(image, crop_size)
        image = fastmri.complex_abs(image)

        # apply Root-Sum-of-Squares if multicoil data
        if self.which_challenge == "multicoil":
            image = fastmri.rss(image)

        # normalize input; values are then in std units, so clamp to +-6 std
        image, mean, std = normalize_instance(image, eps=1e-11)
        image = image.clamp(-6, 6)

        # normalize target with the input's statistics
        if target is not None:
            target_torch = to_tensor(target)
            target_torch = center_crop(target_torch, crop_size)
            target_torch = normalize(target_torch, mean, std, eps=1e-11)
            target_torch = target_torch.clamp(-6, 6)
        else:
            target_torch = torch.Tensor([0])

        return UnetSample(
            image=image,
            target=target_torch,
            mean=mean,
            std=std,
            fname=fname,
            slice_num=slice_num,
            max_value=max_value,
        )
class VarNetSample(NamedTuple):
    """
    Masked k-space plus metadata for variational-network reconstruction.

    Args:
        masked_kspace: k-space with the sampling mask applied.
        mask: The sampling mask that was applied.
        num_low_frequencies: How many lines make up the densely-sampled
            center region.
        target: The target image, when one is available.
        fname: Name of the source file.
        slice_num: Index of the slice within its volume.
        max_value: Maximum image value.
        crop_size: Size to which the final image is cropped.
    """

    masked_kspace: torch.Tensor
    mask: torch.Tensor
    num_low_frequencies: Optional[int]
    target: torch.Tensor
    fname: str
    slice_num: int
    max_value: float
    crop_size: Tuple[int, int]
class VarNetDataTransform:
    """
    Data Transformer for training VarNet models.
    """

    def __init__(
        self,
        mask_func: Optional[MaskFunc] = None,
        use_seed: bool = True,
        batch_crop: bool = False,
    ):
        """
        Args:
            mask_func: Optional; A function that can create a mask of
                appropriate shape. Defaults to None, in which case the mask
                passed to ``__call__`` (from the dataset) is used.
            use_seed: If True, this class computes a pseudo random number
                generator seed from the filename. This ensures that the same
                mask is used for all the slices of a given volume every time.
            batch_crop: If True, crop each sample's k-space and mask to a
                fixed 320-column size so samples can be stacked into a batch.
                Defaults to False, which matches the previous hard-coded
                behavior.
        """
        self.mask_func = mask_func
        self.use_seed = use_seed
        self.batch_crop = batch_crop

    @staticmethod
    def _mask_from_dataset(
        mask: np.ndarray, kspace_shape: Sequence[int], acq_start, acq_end
    ) -> torch.Tensor:
        """Reshape a dataset column mask to broadcast over k-space and zero the padded columns."""
        mask_shape = [1] * len(kspace_shape)
        mask_shape[-2] = kspace_shape[-2]
        # astype() copies, so zeroing below does not touch the caller's array
        mask_torch = torch.from_numpy(
            mask.reshape(*mask_shape).astype(np.float32)
        )
        # Zero along the column axis (dim -2), as apply_mask does; this works
        # for any k-space rank, not only (coils, rows, cols, 2).
        mask_torch[..., :acq_start, :] = 0
        mask_torch[..., acq_end:, :] = 0
        return mask_torch

    @staticmethod
    def _crop_for_batching(
        masked_kspace: torch.Tensor,
        mask_torch: torch.Tensor,
        recon_size: Tuple[int, int],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Crop k-space (via the image domain) and its mask to 320 columns for batching."""
        size = 320
        cropped_kspace = fastmri.fft2c(
            complex_center_crop(fastmri.ifft2c(masked_kspace), recon_size)
        )
        cropped_kspace = complex_center_crop(cropped_kspace, (size, size))
        # NOTE(review): the mask is cropped around the centre of the original
        # width, while k-space was resampled by the image-domain crop; this
        # mirrors the previous behavior — confirm it is the intended pairing.
        col_start = (mask_torch.shape[-2] - size) // 2
        cropped_mask = mask_torch[..., col_start : col_start + size, :]
        return cropped_kspace, cropped_mask

    def __call__(
        self,
        kspace: np.ndarray,
        mask: np.ndarray,
        target: Optional[np.ndarray],
        attrs: Dict,
        fname: str,
        slice_num: int,
    ) -> VarNetSample:
        """
        Args:
            kspace: Input k-space of shape (num_coils, rows, cols) for
                multi-coil data.
            mask: Mask from the test dataset. Only used when no
                ``mask_func`` was configured.
            target: Target image.
            attrs: Acquisition related information stored in the HDF5 object.
                Reads "padding_left", "padding_right", "recon_size", and
                "max" (when ``target`` is given).
            fname: File name.
            slice_num: Serial number of the slice.

        Returns:
            A VarNetSample with the masked k-space, sampling mask, target
            image, the filename, the slice number, the maximum image value
            (from target), the target crop size, and the number of low
            frequency lines sampled (0 when the dataset mask is used or when
            batch cropping is enabled).
        """
        if target is not None:
            target_torch = to_tensor(target)
            max_value = attrs["max"]
        else:
            target_torch = torch.tensor(0)
            max_value = 0.0

        kspace_torch = to_tensor(kspace)
        seed = None if not self.use_seed else tuple(map(ord, fname))
        acq_start = attrs["padding_left"]
        acq_end = attrs["padding_right"]
        crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])

        if self.mask_func is not None:
            masked_kspace, mask_torch, num_low_frequencies = apply_mask(
                kspace_torch,
                self.mask_func,
                seed=seed,
                padding=(acq_start, acq_end),
            )
        else:
            masked_kspace = kspace_torch
            mask_torch = self._mask_from_dataset(
                mask, kspace_torch.shape, acq_start, acq_end
            )
            num_low_frequencies = 0

        if self.batch_crop:
            masked_kspace, mask_torch = self._crop_for_batching(
                masked_kspace, mask_torch, crop_size
            )
            num_low_frequencies = 0

        return VarNetSample(
            masked_kspace=masked_kspace,
            mask=mask_torch.to(torch.bool),
            num_low_frequencies=num_low_frequencies,
            target=target_torch,
            fname=fname,
            slice_num=slice_num,
            max_value=max_value,
            crop_size=crop_size,
        )
class EnhancedVarNetDataTransform(VarNetDataTransform):
    """
    Enhanced Data Transformer for training VarNet models with additional
    functionality.

    - allows for training on multiple patterns: every call picks one of the
      configured mask functions at random and applies it to the k-space. The
      pick uses Python's global ``random`` module, independently of
      ``use_seed``.
    - when no mask functions are configured, the call is delegated to
      ``VarNetDataTransform.__call__``, which uses the mask stored with the
      dataset (the path intended for test data).
    """

    def __init__(
        self,
        mask_funcs: Optional[List[MaskFunc]] = None,
        use_seed: bool = True,
        batch_crop: bool = False,
    ):
        """
        Args:
            mask_funcs: Mask functions to choose from on each call. If None
                or empty, the dataset mask is used instead.
            use_seed: If True, the mask seed is derived from the filename so
                a given mask function produces the same mask for all slices
                of a volume.
            batch_crop: If True, crop k-space and mask for batch processing.
                Defaults to False (the previous hard-coded behavior).
        """
        # Initialise the base with no mask function so that its __call__,
        # used for the fallback path, reads the dataset mask.
        super().__init__(mask_func=None, use_seed=use_seed)
        self.mask_funcs = mask_funcs
        self.use_seed = use_seed
        self.batch_crop = batch_crop

    def __call__(
        self,
        kspace: np.ndarray,
        mask: np.ndarray,
        target: Optional[np.ndarray],
        attrs: Dict,
        fname: str,
        slice_num: int,
    ) -> VarNetSample:
        """
        Args:
            kspace: Input k-space of shape (num_coils, rows, cols) for
                multi-coil data.
            mask: Mask from the test dataset. Only used when no mask
                functions are configured, via the base class.
            target: Target image.
            attrs: Acquisition related information stored in the HDF5 object.
            fname: File name.
            slice_num: Serial number of the slice.

        Returns:
            A VarNetSample with the masked k-space, sampling mask, target
            image, the filename, the slice number, the maximum image value
            (from target), the target crop size, and the number of low
            frequency lines sampled (0 when batch cropping is enabled).
        """
        if not self.mask_funcs:
            # Previously random.choice() crashed here on None / [].
            return super().__call__(kspace, mask, target, attrs, fname, slice_num)

        if target is not None:
            target_torch = to_tensor(target)
            max_value = attrs["max"]
        else:
            target_torch = torch.tensor(0)
            max_value = 0.0

        kspace_torch = to_tensor(kspace)
        seed = None if not self.use_seed else tuple(map(ord, fname))
        acq_start = attrs["padding_left"]
        acq_end = attrs["padding_right"]
        crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])

        # choose one of the masking functions provided randomly
        mask_func = random.choice(self.mask_funcs)
        masked_kspace, mask_torch, num_low_frequencies = apply_mask(
            kspace_torch,
            mask_func,
            seed=seed,
            padding=(acq_start, acq_end),
        )

        if self.batch_crop:
            # Crop in the image domain to recon_size, then return to k-space.
            masked_kspace = fastmri.fft2c(
                complex_center_crop(fastmri.ifft2c(masked_kspace), crop_size)
            )
            # NOTE(review): the mask is cropped to the central 320 columns of
            # its original width, mirroring the previous code — confirm this
            # matches the resampled k-space.
            col_start = (mask_torch.shape[-2] - 320) // 2
            mask_torch = mask_torch[..., col_start : col_start + 320, :]
            num_low_frequencies = 0

        return VarNetSample(
            masked_kspace=masked_kspace,
            mask=mask_torch.to(torch.bool),
            num_low_frequencies=num_low_frequencies,
            target=target_torch,
            fname=fname,
            slice_num=slice_num,
            max_value=max_value,
            crop_size=crop_size,
        )
class MiniCoilSample(NamedTuple):
    """
    A sample of masked coil-compressed k-space for reconstruction.

    Unlike ``VarNetSample`` this tuple has no ``num_low_frequencies`` field.

    Args:
        kspace: the coil-compressed k-space before masking.
        masked_kspace: k-space after applying sampling mask.
        mask: The applied sampling mask (may be None when neither a mask
            function nor a dataset mask was available).
        target: The target image (if applicable).
        fname: File name.
        slice_num: The slice index.
        max_value: Maximum image value.
        crop_size: The size to crop the final image. ``MiniCoilTransform``
            stores this as a 2-element tensor.
    """

    kspace: torch.Tensor
    masked_kspace: torch.Tensor
    mask: torch.Tensor
    target: torch.Tensor
    fname: str
    slice_num: int
    max_value: float
    crop_size: Tuple[int, int]
class MiniCoilTransform:
    """
    Multi-coil compressed transform, for faster prototyping.

    Crops k-space to a small image size and builds an RSS target from all
    coils. It then compresses the coils with an SVD and finally masks the
    compressed k-space.
    """

    def __init__(
        self,
        mask_func: Optional[MaskFunc] = None,
        use_seed: Optional[bool] = True,
        crop_size: Optional[tuple] = None,
        num_compressed_coils: Optional[int] = None,
    ):
        """
        Args:
            mask_func: Optional; A function that can create a mask of
                appropriate shape. Defaults to None.
            use_seed: If True, this class computes a pseudo random number
                generator seed from the filename. This ensures that the same
                mask is used for all the slices of a given volume every time.
            crop_size: Image dimensions for mini MR images. It may be None,
                an int (square crop), or a 2-element tuple/list; a None
                entry means "use attrs['recon_size']".
            num_compressed_coils: Number of coils to output from coil
                compression. Defaults to the input coil count.
        """
        self.mask_func = mask_func
        self.use_seed = use_seed
        self.crop_size = crop_size
        self.num_compressed_coils = num_compressed_coils

    def _resolve_crop_size(self, attrs: Dict) -> torch.Tensor:
        """Return the crop size as a 2-element tensor, falling back to attrs['recon_size']."""
        if self.crop_size is None:
            return torch.tensor([attrs["recon_size"][0], attrs["recon_size"][1]])
        if isinstance(self.crop_size, (tuple, list)):
            assert len(self.crop_size) == 2
            if self.crop_size[0] is None or self.crop_size[1] is None:
                return torch.tensor(
                    [attrs["recon_size"][0], attrs["recon_size"][1]]
                )
            return torch.tensor(self.crop_size)
        if isinstance(self.crop_size, int):
            return torch.tensor((self.crop_size, self.crop_size))
        raise ValueError(
            "`crop_size` should be None, tuple, list, or int, not:"
            f" {type(self.crop_size)}"
        )

    def __call__(self, kspace, mask, target, attrs, fname, slice_num):
        """
        Args:
            kspace: Input k-space of shape (num_coils, rows, cols) for
                multi-coil data.
            mask: Mask from the test dataset. Not used if mask_func is defined.
            target: Ignored. The target is rebuilt from the cropped k-space
                (RSS of all coils, before compression).
            attrs: Acquisition related information stored in the HDF5 object.
                Reads "recon_size".
            fname: File name.
            slice_num: Serial number of the slice.

        Returns:
            MiniCoilSample containing:
                kspace: compressed k-space (used for active acquisition only).
                masked_kspace: k-space after applying sampling mask. If there
                    is no mask or mask_func, returns same as kspace.
                mask: The applied sampling mask (uint8), or None.
                target: RSS image of all coils, pre-compression.
                fname: File name.
                slice_num: The slice index.
                max_value: Maximum value of the rebuilt target.
                crop_size: The crop size as a 2-element tensor.
        """
        crop_size = self._resolve_crop_size(attrs)

        if self.num_compressed_coils is None:
            num_compressed_coils = kspace.shape[0]
        else:
            num_compressed_coils = self.num_compressed_coils

        seed = None if not self.use_seed else tuple(map(ord, fname))
        # After cropping, the whole width is the acquired region.
        acq_start = 0
        acq_end = crop_size[1]

        # Crop to recon_size in the image domain, then crop k-space itself.
        square_crop = (attrs["recon_size"][0], attrs["recon_size"][1])
        kspace = fastmri.fft2c(
            complex_center_crop(fastmri.ifft2c(to_tensor(kspace)), square_crop)
        ).numpy()
        kspace = complex_center_crop(kspace, crop_size)

        # we calculate the target before coil compression. This causes the mini
        # simulation to be one where we have a 15-coil, low-resolution image
        # and our reconstructor has an SVD coil approximation. This is a little
        # bit more realistic than doing the target after SVD compression
        target = fastmri.rss_complex(fastmri.ifft2c(to_tensor(kspace)))
        max_value = target.max()

        # apply coil compression: project coils onto the leading left
        # singular vectors (conjugate transpose, as np.matrix.H did).
        new_shape = (num_compressed_coils,) + kspace.shape[1:]
        flat = np.reshape(kspace, (kspace.shape[0], -1))
        left_vec, _, _ = np.linalg.svd(
            flat, compute_uv=True, full_matrices=False
        )
        basis = left_vec[:, :num_compressed_coils]
        kspace = to_tensor(np.reshape(basis.conj().T @ flat, new_shape))

        # Mask kspace
        if self.mask_func is not None:
            # Keywords matter here: the third positional parameter of
            # apply_mask is `offset`, not `seed`.
            masked_kspace, mask, _ = apply_mask(
                kspace,
                self.mask_func,
                seed=seed,
                padding=(acq_start, acq_end),
            )
            mask = mask.byte()
        elif mask is not None:
            masked_kspace = kspace
            mask_shape = [1] * kspace.ndim
            mask_shape[-2] = kspace.shape[-2]
            # NOTE(review): a dataset mask has the original column count,
            # which may no longer match the cropped width — confirm.
            mask = torch.from_numpy(
                mask.reshape(*mask_shape).astype(np.float32)
            ).byte()
        else:
            masked_kspace = kspace

        return MiniCoilSample(
            kspace=kspace,
            masked_kspace=masked_kspace,
            mask=mask,
            target=target,
            fname=fname,
            slice_num=slice_num,
            max_value=max_value,
            crop_size=crop_size,
        )
| """ | |
| sens maps & feature transformations | |
| - expand | |
| - reduce | |
| - batch -> chan | |
| - chan -> batch | |
| """ | |
def sens_expand(x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
    """
    Calculates F (x sens_maps): weight an image by the coil sensitivity maps
    and transform the result to k-space.

    Parameters
    ----------
    x : torch.Tensor
        Single-channel image of shape (..., H, W, 2), with a trailing
        (real, imag) dimension.
    sens_maps : torch.Tensor
        Sensitivity maps (image space), broadcastable against ``x``.

    Returns
    -------
    torch.Tensor
        Result of the operation F (x sens_maps)
    """
    coil_images = fastmri.complex_mul(x, sens_maps)
    return fastmri.fft2c(coil_images)
def sens_reduce(k: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
    """
    Calculates F^{-1}(k) * conj(sens_maps), summed over the coil dimension.

    conj(sens_maps) is the element-wise complex conjugate of the maps.

    Parameters
    ----------
    k : torch.Tensor
        Multi-channel k-space of shape (B, C, H, W, 2)
    sens_maps : torch.Tensor
        Sensitivity maps (image space)

    Returns
    -------
    torch.Tensor
        Coil-combined image of shape (B, 1, H, W, 2).
    """
    coil_images = fastmri.ifft2c(k)
    conj_maps = fastmri.complex_conj(sens_maps)
    weighted = fastmri.complex_mul(coil_images, conj_maps)
    return weighted.sum(dim=1, keepdim=True)
def chans_to_batch_dim(x: torch.Tensor) -> Tuple[torch.Tensor, int]:
    """Fold the channel dimension into the batch dimension.

    Parameters
    ----------
    x : torch.Tensor
        Tensor of shape (b, c, h, w, 2).

    Returns
    -------
    Tuple[torch.Tensor, int]
        A view of shape (b * c, 1, h, w, 2) and the original batch size b.
    """
    batch, chans, height, width, comp = x.shape
    folded = x.view(batch * chans, 1, height, width, comp)
    return folded, batch
def batch_chans_to_chan_dim(x: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Unfold independent single-channel samples back into multi-channel ones.

    Parameters
    ----------
    x : torch.Tensor
        Tensor of shape (b * c, 1, h, w, 2).
    batch_size : int
        The batch size b.

    Returns
    -------
    torch.Tensor
        A view of shape (b, c, h, w, 2).
    """
    total, _, height, width, comp = x.shape
    chans = total // batch_size
    return x.view(batch_size, chans, height, width, comp)