Code updates

Files changed:
- .DS_Store  +0 -0
- __pycache__/brain2vec.cpython-310.pyc  +0 -0
- model.py → inference_brain2vec.py  +140 -11
- requirements.txt  +8 -4
- brain2vec.py → train_brain2vec.py  +21 -114
.DS_Store
CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
__pycache__/brain2vec.cpython-310.pyc
DELETED
Binary file (18.8 kB)
model.py → inference_brain2vec.py
RENAMED

@@ -1,9 +1,34 @@
-
-
-
+#!/usr/bin/env python3
+
+"""
+inference_brain2vec.py
+
+Loads a pretrained Brain2vec VAE (AutoencoderKL) model and performs inference
+on one or more MRI images, generating reconstructions and latent parameters
+(z_mu, z_sigma).
 
+Example usage:
+
+    # 1) Multiple file paths
+    python inference_brain2vec.py \
+        --checkpoint_path /path/to/autoencoder_checkpoint.pth \
+        --input_images /path/to/img1.nii.gz /path/to/img2.nii.gz \
+        --output_dir ./vae_inference_outputs \
+        --device cuda
+
+    # 2) Use a CSV containing image paths
+    python inference_brain2vec.py \
+        --checkpoint_path /path/to/autoencoder_checkpoint.pth \
+        --csv_input /path/to/images.csv \
+        --output_dir ./vae_inference_outputs
+"""
+
+import os
+import argparse
+import numpy as np
 import torch
 import torch.nn as nn
+from typing import Optional
 from monai.transforms import (
     Compose,
     CopyItemsD,

@@ -14,12 +39,12 @@ from monai.transforms import (
     ScaleIntensityD,
 )
 from generative.networks.nets import AutoencoderKL
+import pandas as pd
+
 
-# Constants for your typical config
 RESOLUTION = 2
 INPUT_SHAPE_AE = (80, 96, 80)
 
-# Define the exact transform pipeline for input MRI
 transforms_fn = Compose([
     CopyItemsD(keys={'image_path'}, names=['image']),
     LoadImageD(image_only=True, keys=['image']),

@@ -29,15 +54,23 @@ transforms_fn = Compose([
     ScaleIntensityD(minv=0, maxv=1, keys=['image']),
 ])
 
+
 def preprocess_mri(image_path: str, device: str = "cpu") -> torch.Tensor:
     """
     Preprocess an MRI using MONAI transforms to produce
-    a 5D tensor (batch=1,
+    a 5D tensor (batch=1, channel=1, D, H, W) for inference.
+
+    Args:
+        image_path (str): Path to the MRI (e.g. .nii.gz).
+        device (str): Device to place the tensor on.
+
+    Returns:
+        torch.Tensor: Shape (1, 1, D, H, W).
     """
     data_dict = {"image_path": image_path}
     output_dict = transforms_fn(data_dict)
     image_tensor = output_dict["image"]  # shape: (1, D, H, W)
-    image_tensor = image_tensor.unsqueeze(0)  # => (
+    image_tensor = image_tensor.unsqueeze(0)  # => (1, 1, D, H, W)
     return image_tensor.to(device)
 
 

@@ -63,11 +96,11 @@ class Brain2vec(AutoencoderKL):
         Otherwise, return an uninitialized model.
 
         Args:
-            checkpoint_path (Optional[str]):
+            checkpoint_path (Optional[str]): Path to a .pth checkpoint file.
             device (str): "cpu", "cuda", "mps", etc.
 
         Returns:
-            nn.Module:
+            nn.Module: The loaded Brain2vec model on the chosen device.
         """
         model = Brain2vec(
             spatial_dims=3,

@@ -90,5 +123,101 @@ class Brain2vec(AutoencoderKL):
         model.load_state_dict(state_dict)
 
         model.to(device)
-        model.eval()
-        return model
+        model.eval()
+        return model
+
+
+def main() -> None:
+    """
+    Main function to parse command-line arguments and run inference
+    with a pretrained Brain2vec model.
+    """
+    parser = argparse.ArgumentParser(
+        description="Inference script for a Brain2vec (VAE) model."
+    )
+    parser.add_argument(
+        "--checkpoint_path", type=str, required=True,
+        help="Path to the .pth checkpoint of the pretrained Brain2vec model."
+    )
+    parser.add_argument(
+        "--output_dir", type=str, default="./vae_inference_outputs",
+        help="Directory to save reconstructions and latent parameters."
+    )
+    parser.add_argument(
+        "--device", type=str, default="cpu",
+        help="Device to run inference on ('cpu', 'cuda', etc.)."
+    )
+    # Two ways to supply images: multiple file paths or a CSV
+    parser.add_argument(
+        "--input_images", type=str, nargs="*",
+        help="One or more MRI file paths (e.g. .nii.gz)."
+    )
+    parser.add_argument(
+        "--csv_input", type=str,
+        help="Path to a CSV file with an 'image_path' column."
+    )
+    args = parser.parse_args()
+
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    # Load the model
+    model = Brain2vec.from_pretrained(
+        checkpoint_path=args.checkpoint_path,
+        device=args.device
+    )
+
+    # Gather image paths
+    if args.csv_input:
+        df = pd.read_csv(args.csv_input)
+        if "image_path" not in df.columns:
+            raise ValueError("CSV must contain a column named 'image_path'.")
+        image_paths = df["image_path"].tolist()
+    else:
+        if not args.input_images:
+            raise ValueError("Must provide either --csv_input or --input_images.")
+        image_paths = args.input_images
+
+    # Lists for stacking latent parameters later
+    all_z_mu = []
+    all_z_sigma = []
+
+    # Inference on each image
+    for i, img_path in enumerate(image_paths):
+        if not os.path.exists(img_path):
+            raise FileNotFoundError(f"Image not found: {img_path}")
+
+        print(f"[INFO] Processing image {i}: {img_path}")
+        img_tensor = preprocess_mri(img_path, device=args.device)
+
+        with torch.no_grad():
+            recon, z_mu, z_sigma = model.forward(img_tensor)
+
+        # Convert to NumPy
+        recon_np = recon.detach().cpu().numpy()  # shape: (1, 1, D, H, W)
+        z_mu_np = z_mu.detach().cpu().numpy()    # shape: (1, latent_channels, ...)
+        z_sigma_np = z_sigma.detach().cpu().numpy()
+
+        # Save each reconstruction (per image) as .npy
+        recon_path = os.path.join(args.output_dir, f"reconstruction_{i}.npy")
+        np.save(recon_path, recon_np)
+        print(f"[INFO] Saved reconstruction to {recon_path}")
+
+        # Store latent parameters for optional combined saving
+        all_z_mu.append(z_mu_np)
+        all_z_sigma.append(z_sigma_np)
+
+    # Combine latent parameters from all images and save
+    stacked_mu = np.concatenate(all_z_mu, axis=0)        # e.g., shape (N, latent_channels, ...)
+    stacked_sigma = np.concatenate(all_z_sigma, axis=0)  # e.g., shape (N, latent_channels, ...)
+
+    mu_path = os.path.join(args.output_dir, "all_z_mu.npy")
+    sigma_path = os.path.join(args.output_dir, "all_z_sigma.npy")
+    np.save(mu_path, stacked_mu)
+    np.save(sigma_path, stacked_sigma)
+
+    print(f"[INFO] Saved z_mu of shape {stacked_mu.shape} to {mu_path}")
+    print(f"[INFO] Saved z_sigma of shape {stacked_sigma.shape} to {sigma_path}")
+
+
+if __name__ == "__main__":
+    main()
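The renamed module keeps `Brain2vec` and `preprocess_mri` importable, so the model can also be driven programmatically rather than through the CLI. A minimal sketch, assuming the script is importable from the working directory and using hypothetical paths for the checkpoint and image:

```python
import torch
from inference_brain2vec import Brain2vec, preprocess_mri

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load pretrained weights (hypothetical checkpoint path).
model = Brain2vec.from_pretrained(
    checkpoint_path="autoencoder_checkpoint.pth",
    device=device,
)

# Preprocess one MRI (hypothetical path) into a (1, 1, 80, 96, 80) tensor.
img = preprocess_mri("subject_0001.nii.gz", device=device)

# AutoencoderKL-style forward returns (reconstruction, z_mu, z_sigma).
with torch.no_grad():
    recon, z_mu, z_sigma = model(img)

print(recon.shape, z_mu.shape, z_sigma.shape)
```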
requirements.txt
CHANGED

@@ -1,12 +1,15 @@
 # requirements.txt
 
-# PyTorch (CUDA or CPU version).
+# PyTorch (CUDA or CPU version).
 torch>=1.12
 
-# MONAI
-monai-weekly
+# Install MONAI Generative first
 monai-generative
 
+# Now force reinstall MONAI Weekly so its (newer) MONAI version takes precedence
+--force-reinstall
+monai-weekly
+
 # For perceptual losses in MONAI's generative module.
 lpips
 
@@ -17,4 +20,5 @@ nibabel
 tqdm
 tensorboard
 matplotlib
-datasets
+datasets
+scikit-learn
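The reordered requirements encode an install-order workaround: `monai-generative` is installed first, then `monai-weekly` is reinstalled so its newer MONAI build takes precedence. Note that pip honors only a limited set of options inside requirements files, so the bare `--force-reinstall` line may need to be passed on the `pip install` command line instead. A small sanity check after installation, assuming both packages resolved:

```python
# Confirm which MONAI build was resolved after installation.
import monai
import generative  # provided by the monai-generative package

print("MONAI version:", monai.__version__)
print("generative package location:", generative.__file__)
```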
brain2vec.py → train_brain2vec.py
RENAMED

@@ -1,35 +1,20 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.
-
-# Forked from: https://github.com/LemuelPuglisi/BrLP
-
-# @inproceedings{puglisi2024enhancing,
-#   title={Enhancing spatiotemporal disease progression models via latent diffusion and prior knowledge},
-#   author={Puglisi, Lemuel and Alexander, Daniel C and Rav{\`\i}, Daniele},
-#   booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
-#   pages={173--183},
-#   year={2024},
-#   organization={Springer}
-# }
+#!/usr/bin/env python3
+
+"""
+train_brain2vec.py
+
+Trains a 3D VAE-based Brain2Vec model using MONAI. This script implements
+autoencoder training with adversarial loss (via a patch discriminator),
+a perceptual loss, and KL divergence regularization for robust latent
+representations.
+
+Example usage:
+    python train_brain2vec.py train \
+        --dataset_csv /path/to/dataset.csv \
+        --cache_dir /path/to/cache \
+        --output_dir /path/to/output_dir \
+        --n_epochs 10
+"""
 
 import os
 os.environ["PYTORCH_WEIGHTS_ONLY"] = "False"

@@ -37,7 +22,6 @@ from typing import Optional, Union
 import pandas as pd
 import argparse
 import numpy as np
-
 import warnings
 import torch
 import torch.nn as nn

@@ -47,7 +31,6 @@ from torch.nn import L1Loss
 from torch.utils.data import DataLoader
 from torch.amp import autocast
 from torch.amp import GradScaler
-
 from generative.networks.nets import (
     AutoencoderKL,
     PatchDiscriminator,

@@ -65,13 +48,11 @@ torch.serialization.add_safe_globals([_reconstruct])
 torch.serialization.add_safe_globals([MetaTensor])
 torch.serialization.add_safe_globals([ndarray])
 torch.serialization.add_safe_globals([dtype])
-
 from tqdm import tqdm
 import matplotlib.pyplot as plt
-
 from torch.utils.tensorboard import SummaryWriter
 
-#
+# voxel resolution
 RESOLUTION = 2
 
 # shape of the MNI152 (1mm^3) template

@@ -101,10 +82,7 @@ def load_if(checkpoints_path: Optional[str], network: nn.Module) -> nn.Module:
     """
     if checkpoints_path is not None:
         assert os.path.exists(checkpoints_path), 'Invalid path'
-        # Using context manager to allow MetaTensor
-        #with torch.serialization.safe_globals([MetaTensor]):
         network.load_state_dict(torch.load(checkpoints_path))
-        #network.load_state_dict(torch.load(checkpoints_path, map_location='cpu'))
     return network
 
 

@@ -140,7 +118,7 @@ def init_patch_discriminator(checkpoints_path: Optional[str] = None) -> nn.Modul
         checkpoints_path (Optional[str], optional): path of the checkpoints. Defaults to None.
 
     Returns:
-        nn.Module: the
+        nn.Module: the patch discriminator
     """
     patch_discriminator = PatchDiscriminator(spatial_dims=3,
                                              num_layers_d=3,

@@ -387,22 +365,6 @@ def train(
     train_df = dataset_df[dataset_df.split == 'train']
     trainset = get_dataset_from_pd(train_df, transforms_fn, cache_dir)
 
-    print(f"[DEBUG] Using cache_dir={cache_dir}")
-    print(f"[DEBUG] trainset length={len(trainset)}")
-
-    try:
-        sample_debug = trainset[0]  # Force a transform on the first record
-        print("[DEBUG] Successfully loaded sample 0 from trainset.")
-    except Exception as e:
-        print("[DEBUG] Error loading sample 0:", e)
-
-    import glob
-
-    hashfiles = glob.glob(os.path.join(cache_dir, "*.pt"))
-    print(f"[DEBUG] Found {len(hashfiles)} cached .pt files in {cache_dir}")
-    if hashfiles:
-        print("[DEBUG] Example cache file:", hashfiles[0])
-
     train_loader = DataLoader(
         dataset=trainset,
         num_workers=num_workers,

@@ -523,60 +485,11 @@ def train(
     print("Training completed and models saved.")
 
 
-def inference(
-    dataset_csv: str,
-    aekl_ckpt: str,
-    output_dir: str,
-    device: str = ('cuda' if torch.cuda.is_available() else
-                   'cpu'),
-) -> None:
-    """
-    Perform inference to encode images into latent space.
-
-    Args:
-        dataset_csv (str): Path to the dataset CSV file.
-        aekl_ckpt (str): Path to the autoencoder checkpoint.
-        output_dir (str): Directory to save latent representations.
-        device (str, optional): Device to run the inference on. Defaults to 'cuda' if available.
-    """
-    DEVICE = device
-
-    autoencoder = init_autoencoder(aekl_ckpt).to(DEVICE).eval()
-
-    transforms_fn = transforms.Compose([
-        transforms.CopyItemsD(keys={'image_path'}, names=['image']),
-        transforms.LoadImageD(image_only=True, keys=['image']),
-        transforms.EnsureChannelFirstD(keys=['image']),
-        transforms.SpacingD(pixdim=RESOLUTION, keys=['image']),
-        transforms.ResizeWithPadOrCropD(spatial_size=INPUT_SHAPE_AE, mode='minimum', keys=['image']),
-        transforms.ScaleIntensityD(minv=0, maxv=1, keys=['image'])
-    ])
-
-    df = pd.read_csv(dataset_csv)
-
-    os.makedirs(output_dir, exist_ok=True)
-
-    with torch.no_grad():
-        for image_path in tqdm(df.image_path, total=len(df)):
-            destpath = os.path.join(
-                output_dir,
-                os.path.basename(image_path).replace('.nii.gz', '_embeddings.npz').replace('.nii', '_embeddings.npz')
-            )
-            if os.path.exists(destpath):
-                continue
-            mri_tensor = transforms_fn({'image_path': image_path})['image'].to(DEVICE)
-            mri_latent, _ = autoencoder.encode(mri_tensor.unsqueeze(0))
-            mri_latent = mri_latent.cpu().squeeze(0).numpy()
-            np.savez_compressed(destpath, data=mri_latent)
-
-    print("Inference completed and latent representations saved.")
-
-
 def main():
     """
-    Main function to parse command-line arguments and execute training
+    Main function to parse command-line arguments and execute training.
     """
-    parser = argparse.ArgumentParser(description="brain2vec Training
+    parser = argparse.ArgumentParser(description="brain2vec Training Script")
 
     subparsers = parser.add_subparsers(dest='command', required=True, help='Sub-commands: train or infer')
 

@@ -594,12 +507,6 @@ def main():
     train_parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate.')
     train_parser.add_argument('--aug_p', type=float, default=0.8, help='Augmentation probability.')
 
-    # Inference Subparser
-    infer_parser = subparsers.add_parser('inference', help='Run inference to encode images.')
-    infer_parser.add_argument('--dataset_csv', type=str, required=True, help='Path to the dataset CSV file.')
-    infer_parser.add_argument('--aekl_ckpt', type=str, required=True, help='Path to the autoencoder checkpoint.')
-    infer_parser.add_argument('--output_dir', type=str, required=True, help='Directory to save latent representations.')
-
     args = parser.parse_args()
 
     if args.command == 'train':