Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Copyright (c) Facebook, Inc. and its affiliates. | |
| This source code is licensed under the MIT license found in the | |
| LICENSE file in the root directory of this source tree. | |
| """ | |
| from argparse import ArgumentParser | |
| import torch | |
| import fastmri | |
| from fastmri import transforms | |
| from ..varnet import VarNet | |
| import wandb | |
| from .mri_module import MriModule | |
class VarNetModule(MriModule):
    """
    VarNet training module.

    This can be used to train variational networks from the paper:

    A. Sriram et al. End-to-end variational networks for accelerated MRI
    reconstruction. In International Conference on Medical Image Computing and
    Computer-Assisted Intervention, 2020.

    which was inspired by the earlier paper:

    K. Hammernik et al. Learning a variational network for reconstruction of
    accelerated MRI data. Magnetic Resonance in Medicine, 79(6):3055–3071, 2018.
    """

    def __init__(
        self,
        num_cascades: int = 12,
        pools: int = 4,
        chans: int = 18,
        sens_pools: int = 4,
        sens_chans: int = 8,
        lr: float = 0.0003,
        lr_step_size: int = 40,
        lr_gamma: float = 0.1,
        weight_decay: float = 0.0,
        **kwargs,
    ):
        """
        Parameters
        ----------
        num_cascades : int
            Number of cascades (i.e., layers) for the variational network.
        pools : int
            Number of downsampling and upsampling layers for the cascade U-Net.
        chans : int
            Number of channels for the cascade U-Net.
        sens_pools : int
            Number of downsampling and upsampling layers for the sensitivity map U-Net.
        sens_chans : int
            Number of channels for the sensitivity map U-Net.
        lr : float
            Learning rate for the Adam optimizer.
        lr_step_size : int
            Period of learning-rate decay, passed as ``step_size`` to ``StepLR``.
        lr_gamma : float
            Multiplicative learning-rate decay factor, passed as ``gamma`` to ``StepLR``.
        weight_decay : float
            Weight-decay coefficient passed to the Adam optimizer.
        **kwargs
            Forwarded unchanged to ``MriModule.__init__``.
        """
        super().__init__(**kwargs)
        # Record constructor arguments so they are stored with checkpoints.
        self.save_hyperparameters()

        self.num_cascades = num_cascades
        self.pools = pools
        self.chans = chans
        self.sens_pools = sens_pools
        self.sens_chans = sens_chans
        self.lr = lr
        self.lr_step_size = lr_step_size
        self.lr_gamma = lr_gamma
        self.weight_decay = weight_decay

        self.varnet = VarNet(
            num_cascades=self.num_cascades,
            sens_chans=self.sens_chans,
            sens_pools=self.sens_pools,
            chans=self.chans,
            pools=self.pools,
        )
        self.criterion = fastmri.SSIMLoss()
        # Total parameter count (includes every registered submodule).
        self.num_params = sum(p.numel() for p in self.parameters())

    def forward(self, masked_kspace, mask, num_low_frequencies):
        """
        Run the wrapped VarNet on undersampled k-space.

        The output is cropped against the target image in the training and
        validation steps, so it is presumably the reconstructed image (verify
        against ``VarNet.forward``).
        """
        return self.varnet(masked_kspace, mask, num_low_frequencies)

    def training_step(self, batch, batch_idx):
        """
        Compute the SSIM training loss for one batch and log it.

        Output and target are center-cropped to their smallest common size.
        ``unsqueeze(1)`` adds a channel axis before calling ``SSIMLoss``.
        """
        output = self.forward(
            batch.masked_kspace, batch.mask, batch.num_low_frequencies
        )
        target, output = transforms.center_crop_to_smallest(batch.target, output)
        loss = self.criterion(
            output.unsqueeze(1), target.unsqueeze(1), data_range=batch.max_value
        )

        self.log("train_loss", loss, on_step=True, on_epoch=True)
        self.log("epoch", int(self.current_epoch), on_step=True, on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """
        Reconstruct one validation batch and return everything needed for metrics.

        Returns
        -------
        dict
            Keys ``slug``, ``fname``, ``slice_num``, ``max_value``, ``output``,
            ``target`` and ``val_loss`` (the SSIM loss for this batch).
        """
        # NOTE(review): assumes trainer.val_dataloaders is a dict (e.g. one
        # dataloader per dataset slug) whose insertion order matches
        # dataloader_idx — confirm against the data module.
        dataloaders = self.trainer.val_dataloaders
        slug = list(dataloaders.keys())[dataloader_idx]

        output = self.forward(
            batch.masked_kspace, batch.mask, batch.num_low_frequencies
        )
        target, output = transforms.center_crop_to_smallest(batch.target, output)
        loss = self.criterion(
            output.unsqueeze(1),
            target.unsqueeze(1),
            data_range=batch.max_value,
        )

        return {
            "slug": slug,
            "fname": batch.fname,
            "slice_num": batch.slice_num,
            "max_value": batch.max_value,
            "output": output,
            "target": target,
            "val_loss": loss,
        }

    def configure_optimizers(self):
        """Create an Adam optimizer with a StepLR learning-rate schedule."""
        optim = torch.optim.Adam(
            self.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )
        scheduler = torch.optim.lr_scheduler.StepLR(
            optim, self.lr_step_size, self.lr_gamma
        )

        return [optim], [scheduler]

    @staticmethod
    def add_model_specific_args(parent_parser):  # pragma: no-cover
        """
        Define parameters that only apply to this model.

        Parameters
        ----------
        parent_parser : argparse.ArgumentParser
            Parser whose arguments are inherited (``parents=[parent_parser]``).

        Returns
        -------
        argparse.ArgumentParser
            New parser extended with the ``MriModule`` and VarNet options.
        """
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser = MriModule.add_model_specific_args(parser)

        # network params
        parser.add_argument(
            "--num_cascades",
            default=12,
            type=int,
            help="Number of VarNet cascades",
        )
        parser.add_argument(
            "--pools",
            default=4,
            type=int,
            help="Number of U-Net pooling layers in VarNet blocks",
        )
        parser.add_argument(
            "--chans",
            default=18,
            type=int,
            help="Number of channels for U-Net in VarNet blocks",
        )
        parser.add_argument(
            "--sens_pools",
            default=4,
            type=int,
            help="Number of pooling layers for sense map estimation U-Net in VarNet",
        )
        # A channel count must be an integer. type=float would turn
        # "--sens_chans 8" into 8.0, which is then passed as a U-Net width.
        parser.add_argument(
            "--sens_chans",
            default=8,
            type=int,
            help="Number of channels for sense map estimation U-Net in VarNet",
        )

        # training params (opt)
        parser.add_argument(
            "--lr", default=0.0003, type=float, help="Adam learning rate"
        )
        parser.add_argument(
            "--lr_step_size",
            default=40,
            type=int,
            help="Epoch at which to decrease step size",
        )
        parser.add_argument(
            "--lr_gamma",
            default=0.1,
            type=float,
            help="Extent to which step size should be decreased",
        )
        parser.add_argument(
            "--weight_decay",
            default=0.0,
            type=float,
            help="Strength of weight decay regularization",
        )

        return parser