In this guide, we build a complete 3D medical image segmentation pipeline with MONAI to segment the spleen in CT scans from the Medical Segmentation Decathlon Task09 dataset. We preprocess the volumetric images by reorienting them to a common orientation, resampling them to a uniform voxel spacing, rescaling intensities, cropping to the foreground region, and sampling training patches. We then train a 3D UNet configured for binary (background vs. spleen) segmentation. Along the way, we use mixed-precision training for efficiency, a combined Dice and cross-entropy (DiceCE) loss, sliding-window inference to predict on full volumes, Dice-score evaluation during validation, and visual comparisons of the model's predictions against the expert annotations. The result is a complete train–validate–visualize workflow that starts from raw imaging volumes.
```python
!pip install -q "monai[nibabel,tqdm,matplotlib]==1.5.2" 2>/dev/null
import os, time, glob, tempfile, warnings
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.amp import autocast, GradScaler
from monai.apps import DecathlonDataset
from monai.data import DataLoader, decollate_batch
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from monai.inferers import sliding_window_inference
from monai.utils import set_determinism
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, EnsureTyped, Orientationd,
    Spacingd, ScaleIntensityRanged, CropForegroundd, RandCropByPosNegLabeld,
    RandFlipd, RandRotate90d, RandShiftIntensityd, AsDiscrete,
)
warnings.filterwarnings("ignore")
```
First, we install MONAI together with its nibabel, tqdm, and matplotlib extras (for reading NIfTI volumes, showing progress bars, and plotting). Next, we import PyTorch, NumPy, Matplotlib, and the MONAI components for datasets, transforms, the network, losses, metrics, and inference. We also suppress warnings to keep the output readable while we build the segmentation pipeline.
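```python
QUICK_RUN = True  # short demo run; set to False for the longer schedule
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
root_dir = tempfile.mkdtemp()
roi_size = (96, 96, 96)  # patch size for training crops and sliding-window inference
num_samples = 4          # random patches drawn per training volume
batch_size = 2
max_epochs = 15 if QUICK_RUN else 200
val_every = 3
train_cache = 8 if QUICK_RUN else 24
val_cache = 2 if QUICK_RUN else 6
set_determinism(seed=0)
print(f"Device: {device} | epochs: {max_epochs} | data dir: {root_dir}")

# Shared CT preprocessing. The intensity window (-57..164 HU) and the target voxel
# spacing are ASSUMED values taken from MONAI's spleen tutorial, not stated in this article.
common = [
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
    ScaleIntensityRanged(keys=["image"], a_min=-57, a_max=164,
                         b_min=0.0, b_max=1.0, clip=True),
    CropForegroundd(keys=["image", "label"], source_key="image"),
]
train_transforms = Compose(common + [
    # Sample num_samples patches of roi_size per volume, centered on spleen (pos)
    # or background (neg) voxels; the 1:1 pos/neg ratio is an assumed value.
    RandCropByPosNegLabeld(keys=["image", "label"], label_key="label",
                           spatial_size=roi_size, pos=1, neg=1, num_samples=num_samples,
                           image_key="image", image_threshold=0),
    RandFlipd(keys=["image", "label"], prob=0.2, spatial_axis=0),
    RandFlipd(keys=["image", "label"], prob=0.2, spatial_axis=1),
    RandFlipd(keys=["image", "label"], prob=0.2, spatial_axis=2),
    RandRotate90d(keys=["image", "label"], prob=0.2, max_k=3),
    RandShiftIntensityd(keys=["image"], offsets=0.10, prob=0.5),
    EnsureTyped(keys=["image", "label"]),
])
val_transforms = Compose(common + [EnsureTyped(keys=["image", "label"])])
```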
Here, we set the main tutorial parameters: the compute device (CUDA if available), a temporary data directory, the region-of-interest (patch) size, the batch size, the number of training epochs, and how many volumes to cache in memory. Next, we define the shared CT preprocessing: loading the volumes, reorienting them to a standard orientation, resampling them to a consistent voxel spacing, rescaling intensities, and cropping to the foreground region. The training and validation pipelines both start from these shared steps. The training pipeline additionally samples four random 96×96×96 patches from each volume (set by num_samples and roi_size), centered on either spleen (foreground) or background voxels, and applies random flips, 90° rotations, and intensity shifts to make the model more robust.
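ScaleIntensityRanged with clip=True maps intensities linearly from a window [a_min, a_max] onto [b_min, b_max] and clips everything outside that range. The stdlib sketch below shows this arithmetic on a few Hounsfield-unit (HU) values. The default window of -57 to 164 HU is an assumption (it is the soft-tissue window used in MONAI's spleen tutorial; the article does not state its own values), and `window_and_scale` is a hypothetical helper, not a MONAI function.

```python
def window_and_scale(hu_values, a_min=-57.0, a_max=164.0, b_min=0.0, b_max=1.0):
    """Hypothetical helper: map [a_min, a_max] linearly to [b_min, b_max], then clip.

    Follows the same arithmetic as an intensity-range rescale with clipping;
    the default window (-57..164 HU) is an ASSUMED soft-tissue window.
    """
    out = []
    for v in hu_values:
        y = (v - a_min) / (a_max - a_min) * (b_max - b_min) + b_min
        out.append(min(max(y, b_min), b_max))  # clip to the output range
    return out

# Air, fat, lower window edge, window midpoint, upper window edge, dense bone (HU)
hu = [-1000.0, -100.0, -57.0, 53.5, 164.0, 700.0]
print(window_and_scale(hu))  # [0.0, 0.0, 0.0, 0.5, 1.0, 1.0]
```

Values outside the window collapse onto the endpoints, so air and dense bone no longer dominate the intensity range the network sees, while soft tissue such as the spleen keeps its contrast.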
```python
train_ds = DecathlonDataset(
    root_dir=root_dir, task="Task09_Spleen", section="training",
    transform=train_transforms, download=True, val_frac=0.2,
    cache_num=train_cache, num_workers=2, seed=0)
val_ds = DecathlonDataset(
    root_dir=root_dir, task="Task09_Spleen", section="validation",
    transform=val_transforms, download=False, val_frac=0.2,
    cache_num=val_cache, num_workers=2, seed=0)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                          num_workers=2, pin_memory=torch.cuda.is_available())
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False,
                        num_workers=1, pin_memory=torch.cuda.is_available())
print(f"Train volumes: {len(train_ds)} | Val volumes: {len(val_ds)}")

# 3D UNet: 1 input channel (CT), 2 output channels (background, spleen).
# channels / strides / num_res_units / norm are ASSUMED values (from MONAI's spleen tutorial).
model = UNet(
    spatial_dims=3, in_channels=1, out_channels=2,
    channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2),
    num_res_units=2, norm=Norm.BATCH,
).to(device)

loss_fn = DiceCELoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)
scaler = GradScaler("cuda", enabled=torch.cuda.is_available())
dice_metric = DiceMetric(include_background=False, reduction="mean")
post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])  # logits -> one-hot prediction
post_label = Compose([AsDiscrete(to_onehot=2)])             # label map -> one-hot
```
We load the Task09 Spleen dataset from the Medical Segmentation Decathlon using MONAI's DecathlonDataset class, which downloads the data and splits the labeled cases into training and validation subsets (val_frac=0.2, with the same seed=0 for both subsets so they do not overlap). Each subset is processed with its own transform pipeline and wrapped in a DataLoader for batching. Next, we initialize the 3D UNet, define the DiceCE loss (a Dice loss plus a cross-entropy loss, weighted equally by default), configure the AdamW optimizer with weight decay, set up a cosine-annealing learning-rate scheduler, create the gradient scaler for mixed-precision training, initialize the Dice metric for evaluation, and define post-processing transforms that convert the model outputs (via argmax) and the labels into one-hot form so they can be compared.
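To make the loss concrete, here is a minimal stdlib sketch of a two-class DiceCE loss on a handful of voxels. It is meant to roughly follow what DiceCELoss(to_onehot_y=True, softmax=True) computes with its default settings as we understand them (softmax over the two output channels, one-hot labels, background channel included in the Dice term, a small smoothing constant, Dice and cross-entropy weighted equally). It is an illustration under those assumptions, not MONAI's implementation, and both function names are hypothetical.

```python
import math

def softmax_pair(z_bg, z_fg):
    """Numerically stable softmax over two logits (background, foreground)."""
    m = max(z_bg, z_fg)
    e_bg, e_fg = math.exp(z_bg - m), math.exp(z_fg - m)
    total = e_bg + e_fg
    return (e_bg / total, e_fg / total)

def dice_ce_loss(logits, labels, smooth=1e-5):
    """Hypothetical two-class DiceCE: soft Dice loss + cross-entropy, equal weights.

    logits: list of (background_logit, foreground_logit) pairs, one per voxel
    labels: list of 0/1 ground-truth classes, one per voxel
    Returns (total, dice_term, ce_term).
    """
    probs = [softmax_pair(zb, zf) for zb, zf in logits]
    # Cross-entropy: mean negative log-probability assigned to the true class
    ce = -sum(math.log(p[y]) for p, y in zip(probs, labels)) / len(labels)
    # Soft Dice loss per channel against one-hot labels, averaged over both channels
    dice_terms = []
    for c in (0, 1):
        target = [1.0 if y == c else 0.0 for y in labels]
        pred = [p[c] for p in probs]
        inter = sum(a * b for a, b in zip(pred, target))
        dice_terms.append(1.0 - (2.0 * inter + smooth) / (sum(pred) + sum(target) + smooth))
    dice = sum(dice_terms) / len(dice_terms)
    return dice + ce, dice, ce

labels = [0, 0, 1, 1]
confident_right = [(5.0, -5.0), (5.0, -5.0), (-5.0, 5.0), (-5.0, 5.0)]
confident_wrong = [(-5.0, 5.0), (-5.0, 5.0), (5.0, -5.0), (5.0, -5.0)]
print(dice_ce_loss(confident_right, labels))  # all three terms close to 0
print(dice_ce_loss(confident_wrong, labels))  # Dice term close to 1, CE close to 10
```

The Dice term rewards overlap regardless of how small the spleen is relative to the volume, while the cross-entropy term gives smooth per-voxel gradients; summing them is the usual motivation for a combined loss in organ segmentation.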
```python
best_dice, best_epoch = -1.0, -1
loss_hist, dice_hist, dice_epochs = [], [], []
best_path = os.path.join(root_dir,
```
We then run the full training loop. In each epoch, we train the 3D UNet on cropped volumetric patches sampled from the spleen training volumes, using automatic mixed precision to reduce memory use and speed up computation on supported GPUs. Every val_every (here, 3) epochs, we evaluate the model on the validation volumes with sliding-window inference, compute the mean Dice score, and save the weights whenever that score improves on the best value so far.
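The call sliding_window_inference(img, roi_size, 4, model, overlap=0.5) used for validation tiles the full volume with roi_size windows (the 4 is the sw_batch_size, the number of windows pushed through the model at once) and, with the default constant blending mode, averages the predictions wherever windows overlap. The stdlib sketch below shows that idea along a single axis. It is a simplified illustration, not MONAI's implementation: `window_starts` and `sliding_window_1d` are hypothetical names, and real implementations also pad inputs smaller than the window and can weight window centers more heavily.

```python
import math

def window_starts(length, roi, overlap):
    """Start indices of windows of size `roi` covering [0, length) along one axis."""
    if length <= roi:
        return [0]  # a single window; a real implementation would pad the input here
    step = max(int(roi * (1 - overlap)), 1)
    n = math.ceil((length - roi) / step) + 1
    # Clamp the last window so that it ends exactly at the volume border
    return [min(i * step, length - roi) for i in range(n)]

def sliding_window_1d(signal, roi, overlap, predictor):
    """Apply `predictor` to each window and average the overlapping outputs (uniform weights)."""
    acc = [0.0] * len(signal)
    hits = [0] * len(signal)
    for s in window_starts(len(signal), roi, overlap):
        for k, v in enumerate(predictor(signal[s:s + roi])):
            acc[s + k] += v
            hits[s + k] += 1
    return [a / h for a, h in zip(acc, hits)]

print(window_starts(226, 96, 0.5))  # [0, 48, 96, 130]
```

For a 3D volume, the same start positions are computed independently per spatial axis and combined as a Cartesian product, so in this sketch a hypothetical 226 × 226 × 96 volume would be covered by 4 × 4 × 1 = 16 windows of 96³ voxels.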
```python
# Training curves: loss per epoch and validation Dice at each evaluation epoch
fig, ax = plt.subplots(1, 2, figsize=(12, 4))
ax[0].plot(range(1, len(loss_hist)+1), loss_hist, "-o", ms=3)
ax[0].set(title="Training loss", xlabel="epoch", ylabel="DiceCE loss")
ax[1].plot(dice_epochs, dice_hist, "-o", color="seagreen", ms=4)
ax[1].set(title="Validation mean Dice", xlabel="epoch", ylabel="Dice"); ax[1].set_ylim(0, 1)
plt.tight_layout(); plt.show()

# Reload the best checkpoint and predict on one validation volume
model.load_state_dict(torch.load(best_path, map_location=device)); model.eval()
with torch.no_grad():
    sample = next(iter(val_loader))
    img = sample["image"].to(device)
    with autocast("cuda", enabled=torch.cuda.is_available()):
        # 4 = sw_batch_size: number of roi_size windows run through the model at once
        pred = sliding_window_inference(img, roi_size, 4, model, overlap=0.5)
    pred = torch.argmax(pred, dim=1).cpu().numpy()[0]  # class index per voxel
img_np, lab_np = img.cpu().numpy()[0, 0], sample["label"].numpy()[0, 0]
# Slice index (last spatial axis) containing the most ground-truth spleen voxels
z = int(np.argmax(lab_np.sum(axis=(0, 1))))
fig, ax = plt.subplots(1, 3, figsize=(13, 5))
ax[0].imshow(img_np[:, :, z], cmap="gray"); ax[0].set_title("CT slice")
ax[1].imshow(lab_np[:, :, z], cmap="viridis"); ax[1].set_title("Ground truth")
ax[2].imshow(pred[:, :, z], cmap="viridis"); ax[2].set_title("Prediction")
for a in ax: a.axis("off")
plt.tight_layout(); plt.show()
```
First, we plot the training loss and the validation Dice score to follow the model's progress during training. We then load the best checkpoint (the one with the highest validation Dice) and run sliding-window inference on a validation volume. Displaying the CT slice, the ground-truth mask, and the predicted segmentation side by side, for the slice that contains the most ground-truth spleen voxels, lets us judge the segmentation quality visually.
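The visualization code chooses its slice with np.argmax(lab_np.sum(axis=(0, 1))), i.e. the index along the last spatial axis where the ground-truth mask has the most spleen voxels. The stdlib sketch below reproduces that choice on nested lists and adds a per-slice Dice score, which attaches a number to the same side-by-side comparison. Both helper names are hypothetical, the tiny masks are made up for illustration, and this per-slice score is not the volume-level DiceMetric used during validation.

```python
def busiest_slice(mask):
    """Index z (last axis) with the most foreground voxels in a 0/1 mask[x][y][z].

    Like np.argmax, ties resolve to the first (lowest) index.
    """
    depth = len(mask[0][0])
    counts = [sum(row[z] for plane in mask for row in plane) for z in range(depth)]
    return max(range(depth), key=lambda z: counts[z])

def slice_dice(pred, gt, z):
    """Dice overlap 2|P∩G| / (|P| + |G|) between two binary masks on slice z."""
    p = [row[z] for plane in pred for row in plane]
    g = [row[z] for plane in gt for row in plane]
    inter = sum(1 for a, b in zip(p, g) if a and b)
    total = sum(p) + sum(g)
    # Convention chosen for this sketch: two empty slices count as perfect agreement
    return 1.0 if total == 0 else 2.0 * inter / total

# Made-up 2 x 2 x 3 masks, indexed [x][y][z]
gt = [[[0, 1, 0], [0, 1, 1]], [[0, 1, 0], [0, 0, 0]]]
pred = [[[0, 1, 0], [0, 1, 1]], [[0, 0, 0], [0, 0, 0]]]
z = busiest_slice(gt)
print(z, slice_dice(pred, gt, z))  # 1 0.8
```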
To summarize, we implemented an end-to-end MONAI pipeline for 3D spleen segmentation with a 3D UNet. This involved loading the Medical Segmentation Decathlon data, applying preprocessing and augmentation transforms to the CT volumes, training the model with DiceCE loss, validating with sliding-window inference, and tracking both the loss and the Dice score. We also inspected the results visually by comparing the CT data, the ground-truth labels, and the model's predictions. The exercise shows how MONAI supports each stage of a medical imaging workflow, from data handling and preprocessing through training, evaluation, model saving, and qualitative review.
The post A Coding Implementation on MONAI for End-to-End 3D Spleen Segmentation Using UNet on Medical CT Volumes appeared first on MarkTechPost.



