A Comprehensive Guide to 3D Models for Medical Image Segmentation

This article introduces 3D segmentation, partitioning volumetric data into labeled regions for applications in medical imaging, robotics, and more. Focusing on 3D semantic segmentation, it uses the Swin UNETR architecture for brain tumor segmentation as an example. The article covers core concepts, training on the BraTS dataset including MRI normalization, input/output processing, computational challenges, and adapting Swin UNETR for 3D image classification.

Vasanth Durvasula
Editor

Introduction to 3D Segmentation

3D segmentation, the process of partitioning volumetric data into distinct, labeled regions or objects, finds wide application across numerous fields. This involves analyzing data representing a three-dimensional space, such as medical imaging datasets (CT or MRI), point clouds acquired from LiDAR sensors, or 3D models generated through computer-aided design (CAD). The objective is to assign a semantic label to each individual data point (voxel, point, or mesh element), delineating boundaries and identifying constituent parts. This article explores the core concepts behind 3D segmentation, its applications, and how Swin UNETR, a U-shaped network that pairs a Swin Transformer encoder with a CNN decoder, supports accurate and robust segmentation of brain tumors in multi-modal MRI.

Where is 3D Voxel Segmentation Applied?

The applications of 3D voxel segmentation are surprisingly diverse, spanning industries and disciplines in ways you might not have initially imagined. Consider just a few examples: In robotics, it enables robots to perceive and interact with objects and supports product inspection through defect analysis. For example, robots can use 3D segmentation to identify and grasp parts on a conveyor belt or to detect flaws in manufactured goods. In geospatial and environmental analysis, 3D segmentation contributes to detailed 3D mapping for urban planning and environmental studies, as well as the analysis of geological structures for mining and oil exploration. Think of city planners using segmented 3D models to analyze building heights and shadows, or environmental scientists mapping forest canopies to assess biomass. This same core technology, enabling precise 3D understanding, extends to medical applications, where it plays a critical role in analyzing organ structures from medical scans, aiding in diagnosis and treatment planning.

3D Segmentation Categories

3D segmentation techniques fall into three main categories:

  • Semantic Segmentation:  Labels each 3D element (like a voxel or mesh vertex) with a predefined class.  For example, in medical imaging, this could mean classifying each point in a scan as either "tumor" or "background."  It answers the question: What is here?
  • Instance Segmentation:  Builds upon semantic segmentation by distinguishing individual instances of the same class. So, instead of just identifying "tumor" regions, it would differentiate between separate, distinct tumors. It answers the question: Where are the individual objects of this type?
  • Part Segmentation:  Focuses on segmenting the different sub-components of a single object.  For instance, it could involve breaking down an anatomical structure into its constituent parts. It answers the question: What are the parts of this object?

This article specifically explores 3D semantic segmentation for medical imaging applications.

Types of 3D Segmentation

Core Concepts in 3D Semantic Segmentation

Before getting into the specifics of 3D segmentation, it helps to be familiar with two concepts: voxels, which are the building blocks of 3D data, and loss functions, which guide how neural networks learn during training.

Voxels

A voxel, short for "volume pixel," is the 3D equivalent of a pixel, representing a value in a discretized 3D space. In medical imaging, voxels correspond to volumetric elements in MRI or CT scans. By stacking 2D slices along a third, depth dimension, we obtain a volumetric dataset composed of these voxels. 

This structured, three-dimensional representation lends itself well to analysis, particularly through the application of convolutional operations. A convolution, at its core, is a way of combining two functions to produce a third. In the context of 3D voxels, we use a small 3D cube of numbers, called a kernel or filter, which acts as a feature detector. This kernel slides across the 3D volume, and at each location, it performs element-wise multiplication with the corresponding voxels. These products are then summed to produce a single value in the output, and repeating this at every position creates a feature map.

The values within the kernel determine what kind of 3D features it is sensitive to, such as shapes, boundaries, or patterns. Just as 2D convolutions extract features from images, 3D convolutions extend this principle into the volumetric realm, enabling the detection and analysis of complex 3D structures within medical scans.
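To make the sliding-kernel description concrete, here is a minimal NumPy sketch of a single-channel 3D convolution with stride 1 and no padding. It is written for clarity rather than speed and, like the convolution layers in deep-learning frameworks, it does not flip the kernel (strictly speaking, it computes a cross-correlation). The function name conv3d_valid and the toy boundary-detecting kernel are illustrative choices, not part of any library.

```python
import numpy as np

def conv3d_valid(volume: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """Slide a 3D kernel over a single-channel volume (stride 1, no padding)."""
    kd, kh, kw = kernel.shape
    out_shape = (
        volume.shape[0] - kd + 1,
        volume.shape[1] - kh + 1,
        volume.shape[2] - kw + 1,
    )
    out = np.zeros(out_shape, dtype=np.float64)
    for z in range(out_shape[0]):
        for y in range(out_shape[1]):
            for x in range(out_shape[2]):
                # Multiply the kernel element-wise with the voxels under it, then sum.
                patch = volume[z:z + kd, y:y + kh, x:x + kw]
                out[z, y, x] = np.sum(patch * kernel)
    return out


# Example: a kernel that responds to intensity changes along the first axis.
volume = np.zeros((6, 6, 6))
volume[3:, :, :] = 1.0                        # bright region from slice 3 onwards
kernel = np.zeros((2, 1, 1))
kernel[0, 0, 0], kernel[1, 0, 0] = -1.0, 1.0  # difference of neighbouring slices
feature_map = conv3d_valid(volume, kernel)
print(feature_map.shape)     # (5, 6, 6)
print(feature_map[:, 0, 0])  # nonzero (1.0) only at index 2, the slice-2/slice-3 boundary
```

A real network learns many such kernels over multi-channel inputs at once; in PyTorch this is what torch.nn.Conv3d provides.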

Convolutional 3D Voxel Visualization

Loss Functions

In 3D segmentation, accurately measuring the difference between the predicted segmentation and the ground truth is crucial for training effective models. To quantify this difference, or error, during the training process, various loss functions are employed. Among these, Dice Loss and Intersection over Union (IoU) Loss are two commonly used choices:

Dice Loss

  • The Dice Loss is derived from the Dice Coefficient, a measure of similarity between two sets (A and B). The Dice Coefficient ranges from 0 (no overlap) to 1 (identical sets). The Dice Loss is simply calculated as 1 minus the Dice Coefficient.
Dice Coefficient Formula and Visualization
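For reference, the standard definition of the Dice coefficient for two sets A and B (here, the sets of foreground voxels in the prediction and in the ground truth) is:

```latex
\mathrm{DSC}(A, B) = \frac{2\,\lvert A \cap B \rvert}{\lvert A \rvert + \lvert B \rvert}
```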

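In our case, A and B are the foreground voxels of the predicted segmentation and of its ground truth, so the coefficient measures their voxel-wise agreement. Because the Dice Coefficient is defined on binary masks, training uses a "soft" Dice Loss that substitutes predicted probabilities for the hard prediction, which makes the loss differentiable: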

Dice Loss Formula
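One common way to write the soft Dice Loss with the symbols defined below is the following; implementations differ slightly in where ϵ appears, and some square the terms in the denominator:

```latex
\mathcal{L}_{\mathrm{Dice}} = 1 - \frac{2 \sum_{i=1}^{N} \mathrm{pred}_i \, y_i + \epsilon}{\sum_{i=1}^{N} \mathrm{pred}_i + \sum_{i=1}^{N} y_i + \epsilon}
```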

Where:

  • N is the total number of voxels.
  • pred_i is the predicted probability that voxel i belongs to the foreground class.
  • y_i is the ground-truth label for voxel i (1 for foreground, 0 for background).
  • ϵ is a smoothing factor (e.g. 1e-6) that prevents division by zero when both the prediction and the ground truth are empty (all zeros).
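A minimal NumPy sketch of the soft Dice Loss for a single foreground class. It computes 1 - (2·Σ pred_i·y_i + ϵ) / (Σ pred_i + Σ y_i + ϵ), adding ϵ to both numerator and denominator (an assumption; implementations vary). In a training loop the same expression would be written with tensor operations, for example in PyTorch, so that gradients can flow back through pred.

```python
import numpy as np

def soft_dice_loss(pred: np.ndarray, target: np.ndarray, eps: float = 1e-6) -> float:
    """Soft Dice loss: 1 - (2*sum(pred*y) + eps) / (sum(pred) + sum(y) + eps).

    pred:   predicted foreground probabilities in [0, 1], any shape
    target: binary ground truth (0 or 1), same shape as pred
    """
    p = pred.astype(np.float64).ravel()
    y = target.astype(np.float64).ravel()
    intersection = np.sum(p * y)            # soft overlap
    denominator = np.sum(p) + np.sum(y)     # soft "|A| + |B|"
    dice = (2.0 * intersection + eps) / (denominator + eps)
    return float(1.0 - dice)


target = np.array([1, 1, 0, 0])
print(soft_dice_loss(np.array([1.0, 1.0, 0.0, 0.0]), target))  # perfect prediction -> 0.0
print(soft_dice_loss(np.array([0.5, 0.5, 0.5, 0.5]), target))  # hedging at 0.5 everywhere -> about 0.5
```

Because ϵ appears in both numerator and denominator, two empty masks give a loss of 0 rather than an undefined value. MONAI's DiceLoss (with options such as sigmoid=True) is a tested alternative to hand-rolling this.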

Intersection over Union (IoU)

  • The IoU Loss is derived from the Intersection over Union (IoU) metric, also known as the Jaccard Index. IoU measures the overlap between two sets – in our case, a predicted segmentation and its corresponding ground truth.
  • IoU Score: The IoU score ranges from 0 to 1.
    • 0 indicates no overlap between the predicted and ground truth segmentation.
    • 1 indicates perfect overlap, meaning the prediction is identical to the ground truth.
Intersection over Union Formula and Visualization
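IoU is defined as |A ∩ B| / |A ∪ B|: the number of voxels in both masks divided by the number in either. Below is a minimal NumPy sketch for binary masks; the small eps is an assumption that makes two empty masks score 1. For a single pair of masks, IoU and Dice can be converted into each other via IoU = Dice / (2 - Dice), which the example checks.

```python
import numpy as np

def iou_score(pred_mask: np.ndarray, target_mask: np.ndarray, eps: float = 1e-6) -> float:
    """IoU (Jaccard index) = |A ∩ B| / |A ∪ B| for binary masks."""
    a = pred_mask.astype(bool)
    b = target_mask.astype(bool)
    intersection = np.logical_and(a, b).sum()
    union = np.logical_or(a, b).sum()
    return float((intersection + eps) / (union + eps))


pred = np.array([1, 1, 0, 0])
truth = np.array([1, 0, 1, 0])
iou = iou_score(pred, truth)
dice = 2 * 1 / (2 + 2)   # 1 shared voxel, 2 foreground voxels in each mask
print(iou)               # 1 shared voxel out of 3 in the union -> about 0.333
print(dice / (2 - dice)) # same value recovered from Dice
```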

In practice, Dice Loss, computed on predicted probabilities so that it is differentiable, is the more common training objective, whereas IoU is more often used as a metric to evaluate the model during training or validation. As such, we focus on Dice Loss.

Brain Tumor Segmentation Using Swin UNETR

Swin UNETR is a Transformer-based architecture designed for the 3D semantic segmentation of brain tumors in multi-modal MRI images.  It was proposed by a team from NVIDIA and Vanderbilt University in 2022.  The architecture takes its inspiration from the U-Net architecture, which is known for its effectiveness in medical image segmentation. The Swin UNETR architecture comprises an encoder and a decoder.  The encoder is based on the Swin Transformer, which employs a shifted windowing technique for efficient computation of attention across the input data.

Comparison between ground truth and prediction using Swin UNETR

The decoder is a CNN-based structure that upsamples the encoded features to generate the final segmentation output. The Swin UNETR model demonstrated competitive results in the BraTS 2021 challenge, achieving top-ranking performance in the validation phase. It has been shown to be effective in segmenting the different sub-regions of brain tumors: enhancing tumor (ET), whole tumor (WT), and tumor core (TC).

Swin UNETR's Architecture (Source: Tsang, 2023)

Achieving Scan Invariance with Swin UNETR

In medical imaging, scan invariance is crucial for robust performance. It refers to a model's ability to handle images from different scanners, despite variations in intensity ranges, noise patterns, and resolutions. Swin UNETR possesses several architectural advantages that make it well-suited for achieving scan invariance:  

  1. Global Context with Swin Transformer: Through shifted windows and hierarchical patch merging, the Swin Transformer encoder models long-range dependencies across the volume, which helps keep segmentation consistent across scanners.
  2. Patch-Based Processing: The network processes input volumes as fixed-size, non-overlapping patches. This approach simplifies handling varying image resolutions or sizes, and importantly, reduces computational demands without significantly impacting performance.
  3. Skip Connections for Fine-Grained Detail: Skip connections pass encoder features at several resolutions directly to the decoder, preserving high-resolution anatomical features and fine-grained details.

Although Swin UNETR is not inherently immune to variations between scanners, its design makes it readily adaptable to achieve scan invariance. The combination of its Transformer encoder, patch-based processing, and U-Net-like architecture enables the effective use of preprocessing and training techniques to enhance its performance across diverse scanning environments.

Input Processing

Swin UNETR's processing pipeline, from input volume to segmentation output, is structured as follows:

  1. Input Acquisition: The system ingests a 3D MRI volume, typically visualized in three standard anatomical planes: axial, sagittal, and coronal.
  2. Patch Generation: This 3D volume is systematically decomposed into non-overlapping patches.
  3. Swin Transformer Encoding: These patches are then passed to the Swin Transformer encoder.
    • The encoder analyzes these patches at multiple scales, producing feature representations at progressively lower resolutions.
    • It computes self-attention within local windows, shifted between successive layers, which keeps the cost manageable for large 3D volumes.
  4. Skip Connections: To ensure the preservation of fine-grained spatial information, skip connections relay intermediate feature maps from the encoder to the decoder.
  5. CNN-Based Decoding:
    • The decoder progressively upsamples the encoded features back to the original resolution. At each stage, convolutional blocks extract local features, which are merged with the encoder features of matching resolution delivered through the skip connections, refining the voxel-level predictions. Transformers could also be used in the decoder, but CNNs are well suited to capturing local patterns and are computationally efficient in this context.
Planes of the MRI
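The patch-generation step (step 2) can be sketched in NumPy as a reshape and transpose that cuts a (C, H, W, D) volume into non-overlapping cubes and flattens each cube into a vector. This is an illustration under assumptions: a patch size of 2 is chosen for the example, and the function name partition_into_patches is hypothetical. In Swin UNETR itself the flattened patches are then projected by a learned linear embedding and grouped into attention windows, which this sketch does not show.

```python
import numpy as np

def partition_into_patches(volume: np.ndarray, patch_size: int = 2) -> np.ndarray:
    """Split a (C, H, W, D) volume into non-overlapping cubes of side patch_size.

    Returns an array of shape (num_patches, C * patch_size**3): one flattened
    vector per patch, with the D grid index varying fastest.
    """
    c, h, w, d = volume.shape
    p = patch_size
    if h % p or w % p or d % p:
        raise ValueError("spatial dimensions must be divisible by patch_size")
    # (C, H/p, p, W/p, p, D/p, p): split each spatial axis into (grid, within-patch).
    x = volume.reshape(c, h // p, p, w // p, p, d // p, p)
    # Move the patch-grid axes first and the channel + within-patch axes last.
    x = x.transpose(1, 3, 5, 0, 2, 4, 6)
    return x.reshape((h // p) * (w // p) * (d // p), c * p ** 3)


volume = np.zeros((4, 96, 96, 96), dtype=np.float32)  # one 4-modality ROI
patches = partition_into_patches(volume, patch_size=2)
print(patches.shape)  # (110592, 32): a 48 x 48 x 48 grid of patches, 4 * 2**3 values each
```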

Training SwinUNETR on the BraTS Dataset

To effectively train Swin UNETR for accurate brain tumor segmentation, we leverage the Brain Tumor Segmentation (BraTS) dataset, which focuses on segmenting brain tumors from MRI scans into three primary regions of interest:

  • Edema (ED): Surrounding swelling caused by the tumor.  
  • Enhancing Tumor (ET): The actively growing part of the tumor, typically showing increased vascularity and appearing bright on contrast-enhanced T1 (T1Gd) scans.
  • Non-Enhancing Tumor (NET): The tumor portion with less blood flow, often indicating areas of necrosis or less active growth.  

Each class is assigned its own label, corresponding to an output channel of the model: 0 for ED, 1 for NET, and 2 for ET.

Raw MRI intensities, which typically range from 0 to about 2000, must be normalized before being fed into the model. Unlike Computed Tomography (CT), where voxel intensities correspond to a standardized physical unit (Hounsfield units), MRI intensities lack a consistent scale. This inconsistency arises because MRI signal strength is affected by:

  • Scanner Differences: Different MRI machines may produce varying intensity ranges due to hardware discrepancies.
  • Acquisition Parameters: Settings like echo time (TE), repetition time (TR), and flip angle can alter signal intensities.
  • Patient Factors: Individual patient characteristics can influence the signal received.

Consequently, the same tissue type might exhibit different intensity values across scans, even within the same patient or using the same scanner. Normalization is crucial for several reasons:

  • Comparative Analysis: When comparing images from different sessions or patients, consistent intensity values are essential for accurate assessment.
  • Automated Processing: Many image analysis algorithms, including segmentation and registration tools, assume a standardized intensity range. Variations can lead to errors or reduced performance.
  • Quantitative Studies: Research involving quantitative measurements relies on consistent intensity values to draw valid conclusions.

Several methods have been developed to standardize MRI intensities:

Z-Score Normalization

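Rescales each image so that its intensities have zero mean and unit standard deviation: every voxel value x becomes (x - μ) / σ, where μ and σ are the mean and standard deviation of the image's intensities (for brain MRI, often computed over brain voxels only).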

Z-Score Normalisation Visualization

Python code snippet for Z-Score Normalization:
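A minimal NumPy sketch, assuming that zero-valued voxels are background (BraTS volumes are skull-stripped, so the brain is surrounded by zeros) and that each MRI modality is normalized separately. The function names are illustrative.

```python
import numpy as np
from typing import Optional

def zscore_normalize(volume: np.ndarray, mask: Optional[np.ndarray] = None,
                     eps: float = 1e-8) -> np.ndarray:
    """Z-score normalize one MRI modality: (x - mean) / std.

    By default the statistics are computed over nonzero voxels only (assumed to
    be brain); background voxels are left at 0.
    """
    volume = volume.astype(np.float32)
    if mask is None:
        mask = volume != 0
    values = volume[mask]
    mean, std = values.mean(), values.std()
    out = np.zeros_like(volume)
    out[mask] = (values - mean) / (std + eps)  # eps guards against a constant image
    return out


def zscore_channels(image: np.ndarray) -> np.ndarray:
    """Apply zscore_normalize separately to each modality of a (C, H, W, D) image."""
    return np.stack([zscore_normalize(channel) for channel in image])


vol = np.zeros((4, 4, 4), dtype=np.float32)
vol[1:3, 1:3, 1:3] = np.arange(1, 9, dtype=np.float32).reshape(2, 2, 2)
normalized = zscore_normalize(vol)
brain = vol != 0
print(normalized[brain].mean(), normalized[brain].std())  # approximately 0 and 1
```

MONAI provides a comparable transform, NormalizeIntensity with nonzero=True and channel_wise=True.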

Histogram Matching

Aligns the histogram of an image to a reference histogram, ensuring a consistent intensity distribution.

Histogram Matching Normalisation Visualization

Python code snippet for Histogram Matching:
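A minimal NumPy sketch of histogram matching: it computes the empirical cumulative distribution (CDF) of each image's intensities and maps every source intensity to the reference intensity found at the same quantile. This follows the same idea as scikit-image's exposure.match_histograms; the choice of reference volume, and whether to restrict matching to brain voxels, are left as assumptions for the user.

```python
import numpy as np

def match_histogram(source: np.ndarray, reference: np.ndarray) -> np.ndarray:
    """Map source intensities so their cumulative distribution matches the reference's."""
    src_values, src_inverse, src_counts = np.unique(
        source.ravel(), return_inverse=True, return_counts=True)
    ref_values, ref_counts = np.unique(reference.ravel(), return_counts=True)
    # Empirical CDFs (quantiles) of each image's distinct intensity values.
    src_cdf = np.cumsum(src_counts) / source.size
    ref_cdf = np.cumsum(ref_counts) / reference.size
    # For each source quantile, look up the reference intensity at that quantile.
    matched = np.interp(src_cdf, ref_cdf, ref_values)
    # Scatter the matched values back to every voxel of the source image.
    return matched[src_inverse].reshape(source.shape)


source = np.array([[1, 2], [3, 4]])
reference = np.array([10, 20, 30, 40])
print(match_histogram(source, reference))  # the values 1..4 map to 10..40
```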

Training Logic

During training, an Adam optimizer with a learning rate of 0.0001 was used. Visualization tools such as 3D Slicer or ParaView help display the volumes alongside their ground-truth labels and predictions, making it easier to assess the model's performance and identify areas for improvement.

Input

The input to the Swin UNETR model is a 5D tensor representing a batch of volumetric medical images. Here's the breakdown:

  • Shape: (B, C, H, W, D)
    • B: Batch size (input parameter, determined by batch_size).
    • C: Number of input channels (here 4, corresponding to the 4 MRI modalities: FLAIR, T1w, T1Gd, and T2w).
    • H, W, D: Spatial dimensions of the 3D image (in this case (96, 96, 96), as set by the input parameter roi_size).

The input images are cropped to a fixed region-of-interest (ROI) size to focus on the most relevant areas and manage computational resources. The network is designed to work with this fixed input size, and the spatial dimensions typically must be multiples of 32 (96 = 3 × 32) so that the successive downsampling and upsampling stages line up.

Example Input Shape: For a batch containing a single volume: (1, 4, 96, 96, 96)

Context: The model processes volumetric data from 3D MRI scans, where each voxel has intensity values for 4 different modalities, providing rich spatial and anatomical information for segmentation.
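A minimal NumPy sketch of how an input of this shape can be produced: randomly crop a 96 × 96 × 96 ROI from a full BraTS case (4 modalities at 240 × 240 × 155 voxels), crop the label with the same window, and add a batch axis. The function name random_roi_crop is hypothetical, and the multiple-of-32 check encodes the constraint described above as an assumption; MONAI offers transforms such as RandSpatialCropd for the same purpose.

```python
import numpy as np
from typing import Optional, Sequence, Tuple

def random_roi_crop(image: np.ndarray, label: np.ndarray,
                    roi_size: Sequence[int] = (96, 96, 96),
                    rng: Optional[np.random.Generator] = None
                    ) -> Tuple[np.ndarray, np.ndarray]:
    """Crop the same random ROI from a (C, H, W, D) image and its (K, H, W, D) label."""
    rng = np.random.default_rng() if rng is None else rng
    spatial = image.shape[1:]
    if label.shape[1:] != spatial:
        raise ValueError("image and label must share spatial dimensions")
    if any(s < r for s, r in zip(spatial, roi_size)):
        raise ValueError(f"ROI {tuple(roi_size)} does not fit in volume {spatial}")
    if any(r % 32 for r in roi_size):
        # Assumed constraint: the network's downsampling path needs multiples of 32.
        raise ValueError("each ROI dimension should be a multiple of 32")
    # Pick a random corner so that the whole ROI lies inside the volume.
    starts = [int(rng.integers(0, s - r + 1)) for s, r in zip(spatial, roi_size)]
    index = (slice(None),) + tuple(slice(st, st + r) for st, r in zip(starts, roi_size))
    return image[index], label[index]


# A BraTS-sized case: 4 modalities and 3 label channels, each 240 x 240 x 155 voxels.
image = np.zeros((4, 240, 240, 155), dtype=np.float32)
label = np.zeros((3, 240, 240, 155), dtype=np.uint8)
image_roi, label_roi = random_roi_crop(image, label, rng=np.random.default_rng(0))
batch = image_roi[np.newaxis]  # add the batch axis B
print(batch.shape)      # (1, 4, 96, 96, 96)
print(label_roi.shape)  # (3, 96, 96, 96)
```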

Pre-processing

Before training the model, each image undergoes a series of preprocessing steps to standardize intensities and enhance the model's generalizability. First, the images are z-normalized to ensure that intensity values across all BraTS images have a consistent scale. This standardization is crucial for mitigating variations in image acquisition and ensuring that the model learns from a unified intensity distribution.

In addition to normalization, various data augmentations are applied to improve the model's robustness and prevent overfitting. These augmentations introduce variability in the training data, helping the model generalize better to unseen examples. You can refer to a more comprehensive list of them here.

SwinUNETR lacks intrinsic rotation and scale invariance due to its window-based attention and patch tokenization. Hence these augmentations help the model generalize by exposing it to diverse orientations and feature distributions.
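A minimal NumPy sketch of such augmentations, assuming random flips along each spatial axis, random 90-degree rotations in one plane, and small per-modality intensity scaling and shifting. These are common choices (MONAI provides RandFlipd, RandRotate90d, RandScaleIntensityd, and RandShiftIntensityd), but the probabilities and ranges below are illustrative assumptions rather than our training configuration.

```python
import numpy as np

def augment(image: np.ndarray, label: np.ndarray, rng: np.random.Generator,
            flip_prob: float = 0.5, rotate_prob: float = 0.5,
            scale: float = 0.1, shift: float = 0.1):
    """Randomly flip/rotate image and label together, then jitter image intensities.

    image: (C, H, W, D) normalized MRI; label: (K, H, W, D) masks.
    Assumes H == W so that 90-degree rotations keep the shape.
    """
    for axis in (1, 2, 3):  # the three spatial axes
        if rng.random() < flip_prob:
            image = np.flip(image, axis=axis)
            label = np.flip(label, axis=axis)
    if rng.random() < rotate_prob:
        k = int(rng.integers(1, 4))  # 90, 180 or 270 degrees in the H-W plane
        image = np.rot90(image, k=k, axes=(1, 2))
        label = np.rot90(label, k=k, axes=(1, 2))
    # Per-modality intensity jitter (image only; labels are never rescaled).
    channels = image.shape[0]
    factors = 1.0 + rng.uniform(-scale, scale, size=(channels, 1, 1, 1))
    offsets = rng.uniform(-shift, shift, size=(channels, 1, 1, 1))
    image = image * factors + offsets
    return np.ascontiguousarray(image), np.ascontiguousarray(label)


rng = np.random.default_rng(0)
img = rng.random((4, 8, 8, 8))
lab = (img[:3] > 0.5).astype(np.uint8)
aug_img, aug_lab = augment(img, lab, np.random.default_rng(1))
print(aug_img.shape, aug_lab.shape)  # (4, 8, 8, 8) (3, 8, 8, 8)
```

Geometric transforms are applied identically to image and label so that every voxel keeps its correct annotation, while intensity jitter touches only the image.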

Output

The output tensor from the SwinUNETR model is a 5D tensor representing the segmentation map of the input image. Here's the breakdown:

  • Shape: (B, C_out, H, W, D)
    • B: Batch size.
    • C_out: Number of output channels, in this case 3, corresponding to the 3 tumor classes: edema, non-enhancing tumor, and enhancing tumor.
    • H, W, D: Spatial dimensions of the segmentation map (same as the input, (96, 96, 96), ensuring alignment with the input image). 

The spatial dimension of the segmentation map is the same as the input image as the goal of segmentation is to assign a label or probability to each voxel in the input image. By keeping the spatial dimensions identical, every voxel in the original volume has a corresponding prediction in the segmentation map. 

Example Output Shape: For a batch containing a single volume: (1, 3, 96, 96, 96)

Context:

  • Each channel in the output corresponds to a specific tumor class:
    • Channel 0: Edema segmentation probability map.
    • Channel 1: Non-Enhancing Tumor segmentation probability map.
    • Channel 2: Enhancing Tumor (ET) segmentation probability map.
  • After applying a sigmoid activation function, each voxel value in the output tensor lies in [0, 1], representing the probability that the voxel belongs to the corresponding class.

Post-Processing

  1. Thresholding: The raw model outputs (logits) are passed through a sigmoid and converted into binary masks using a threshold of 0.5:

import torch
hard_outputs = (torch.sigmoid(outputs) > 0.5).float()  # outputs: raw logits, shape (B, 3, H, W, D)

After this step, each voxel will have a value of 1 (belonging to the tumor class) or 0 (background).

  2. Per-Class Metrics: The hard_outputs are used to calculate the Dice score for each tumor class.
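For reference, a NumPy sketch of both post-processing steps: an analogue of the thresholding line above, and a per-class Dice score on (B, C, H, W, D) binary masks averaged over the batch. The function names are hypothetical, and the eps term is an assumption that scores a class as 1.0 when both prediction and ground truth are empty; libraries differ in how they treat that case.

```python
import numpy as np

def threshold_logits(logits: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """NumPy analogue of (torch.sigmoid(outputs) > 0.5).float().

    With threshold 0.5 this is equivalent to logits > 0.
    """
    probs = 1.0 / (1.0 + np.exp(-logits))
    return (probs > threshold).astype(np.float32)


def per_class_dice(pred: np.ndarray, target: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Dice per class for binary (B, C, H, W, D) arrays, averaged over the batch."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    spatial_axes = (2, 3, 4)
    intersection = np.logical_and(pred, target).sum(axis=spatial_axes)
    sizes = pred.sum(axis=spatial_axes) + target.sum(axis=spatial_axes)
    dice = (2.0 * intersection + eps) / (sizes + eps)  # shape (B, C)
    return dice.mean(axis=0)                           # shape (C,)


pred = np.zeros((1, 3, 2, 2, 2))
target = np.zeros((1, 3, 2, 2, 2))
pred[0, 0] = 1; target[0, 0] = 1     # channel 0: perfect match
pred[0, 1, 0] = 1; target[0, 1] = 1  # channel 1: half of the true voxels predicted
pred[0, 2] = 1                       # channel 2: false positives only
print(per_class_dice(pred, target))  # approximately [1.0, 0.667, 0.0]
```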

Summary

  • Input Shape: (1, 4, 96, 96, 96) — a single 3D image with 4 modalities.
  • Output Shape: (1, 3, 96, 96, 96) — segmentation maps for 3 tumor classes.
  • Interpretation:
    • Input provides volumetric spatial and anatomical data across modalities.
    • Output provides voxel-wise segmentation probabilities for the tumor regions.

Training Results


The model demonstrates effective learning: both training and validation mean Dice scores increase over the epochs, and within just 10 epochs the model reaches a training mean Dice score of 0.82 and a validation mean Dice score of 0.78. The scores plateau in the later epochs, suggesting that the model is approaching its best performance under this setup, and the close alignment between the training and validation curves suggests good generalization to unseen data without significant overfitting. These scores are lower than the 0.964 reported in the original Swin UNETR paper, most likely because of our short 10-epoch training run; we also did not perform hyperparameter tuning or cross-validation, which affects results as well. Future work will explore these areas to better match the reported performance.

Performance and Computation Analysis

Training 3D models like Swin UNETR is computationally expensive because they process volumetric data (height × width × depth), unlike 2D models that handle single slices. This impacts memory usage in several ways:

  • Increased Input Size: 3D volumes contain many slices, adding an extra dimension to the input data. Even moderate-sized volumes have significantly more voxels (3D pixels) than a 2D image.
  • More Features to Extract: The network needs to learn patterns in three dimensions, requiring more complex architectures with increased parameters and memory usage.
  • GPU Memory Limits: GPUs with limited VRAM can struggle to store the large 3D feature maps and intermediate activations generated during processing.
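Some simple arithmetic makes the input-size point concrete. The sketch below computes the memory of a single dense tensor; the 48-channel full-resolution feature map is an illustrative assumption, and real training also stores many intermediate activations, their gradients, and optimizer state, so actual usage is far higher.

```python
def tensor_mib(shape, bytes_per_element=4):
    """Memory for one dense tensor of the given shape, in MiB (FP32 by default)."""
    elements = 1
    for dim in shape:
        elements *= dim
    return elements * bytes_per_element / 2**20


print(tensor_mib((1, 4, 96, 96)))          # one 2D slice, 4 modalities: 0.140625
print(tensor_mib((1, 4, 96, 96, 96)))      # one 3D ROI, 4 modalities: 13.5
print(tensor_mib((1, 48, 96, 96, 96)))     # one 48-channel full-resolution feature map: 162.0
print(tensor_mib((1, 48, 96, 96, 96), 2))  # the same map in FP16: 81.0
```

The 3D input is 96 times larger than a single 96 × 96 slice, and every full-resolution feature map scales the same way, which is why batch size and ROI size are the first things to shrink when memory runs out.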

These factors can constrain batch size and volume size. Smaller batch sizes can impact training stability, while downsampling volumes can limit the model's ability to capture broader anatomical context.

Potential Solutions:

  • Patch-Based Training: Processing smaller sub-volumes, or patches, reduces memory usage while still exposing the model to diverse image regions.
  • Gradient Checkpointing: Instead of keeping every intermediate activation in memory for the backward pass, the model stores only some of them and recomputes the rest when gradients are needed, trading extra computation for lower memory use.
  • Mixed Precision Training: Performing most calculations in lower precision, such as FP16 or BF16 instead of FP32, reduces memory usage, typically with little loss in accuracy.
  • Distributed Training: Distributing the data and model across multiple GPUs increases the effective memory capacity.
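When a model is trained on 96³ patches, full volumes are usually processed at inference time by tiling them with overlapping windows and blending the per-window predictions; MONAI's sliding_window_inference implements this. The sketch below only computes where the windows go, assuming 50% overlap and a final window placed flush with the far edge; the function names are hypothetical.

```python
from itertools import product
from typing import List, Sequence, Tuple

def window_starts(size: int, roi: int, overlap: float = 0.5) -> List[int]:
    """Start indices of length-roi windows covering [0, size) with the given overlap."""
    if roi > size:
        raise ValueError("roi larger than volume")
    step = max(1, int(roi * (1.0 - overlap)))
    starts = list(range(0, size - roi + 1, step))
    if starts[-1] + roi < size:
        starts.append(size - roi)  # last window flush with the far edge
    return starts


def sliding_windows(shape: Sequence[int], roi: Sequence[int] = (96, 96, 96),
                    overlap: float = 0.5) -> List[Tuple[int, ...]]:
    """All (h, w, d) start corners needed to tile a volume with ROI-sized patches."""
    per_axis = [window_starts(s, r, overlap) for s, r in zip(shape, roi)]
    return list(product(*per_axis))


corners = sliding_windows((240, 240, 155))
print(window_starts(155, 96))  # [0, 48, 59]
print(len(corners))            # 4 * 4 * 3 = 48 patches
```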

Adapting Swin UNETR for 3D Image Classification

Although Swin UNETR was designed for the challenging task of 3D medical image segmentation, its architecture and training approach offer a solid foundation for other vision applications. In this article, we explore this potential by adapting Swin UNETR to the distinct demands of 3D image classification.

Architectural Changes:

  1. Remove the Decoder: The U-Net-like decoder, designed for upsampling and generating segmentation maps, is not needed for classification.
  2. Global Average Pooling: Apply global average pooling to the feature maps generated by the Swin Transformer encoder. This aggregates the spatial information into a single feature vector.
  3. Fully Connected Layer: Add a fully connected layer after the global pooling. This layer maps the aggregated feature vector to one score (logit) per class; a softmax then turns these scores into class probabilities, producing the final classification output.
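The three changes above can be sketched in NumPy as a classification head on top of the encoder output. The shapes are assumptions for illustration: a hypothetical deepest encoder feature map of 768 channels at 3 × 3 × 3 (as might arise from a 96³ input), 10 classes, and random weights standing in for a trained layer.

```python
import numpy as np

def classification_head(features: np.ndarray, weight: np.ndarray, bias: np.ndarray):
    """Global average pooling + fully connected layer + softmax.

    features: encoder output of shape (B, C, h, w, d)
    weight:   (C, num_classes); bias: (num_classes,)
    Returns (logits, probabilities), each of shape (B, num_classes).
    """
    pooled = features.mean(axis=(2, 3, 4))                # (B, C): one vector per volume
    logits = pooled @ weight + bias                       # (B, num_classes)
    shifted = logits - logits.max(axis=1, keepdims=True)  # for numerical stability
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
    return logits, probs


rng = np.random.default_rng(0)
features = rng.standard_normal((2, 768, 3, 3, 3))  # hypothetical encoder output
weight = rng.standard_normal((768, 10)) * 0.01
bias = np.zeros(10)
logits, probs = classification_head(features, weight, bias)
print(logits.shape, probs.shape)  # (2, 10) (2, 10); each row of probs sums to 1
```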

Training Adaptations:

  • Dataset: Use a classification dataset instead of a segmentation dataset. For example, 3D MNIST, a dataset of volumetric digits (0-9), could be used.
  • Loss Function and Metrics: Use a cross-entropy loss function, which is suitable for classification tasks. Evaluate performance with metrics like accuracy, precision, recall, and F1-score.
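A minimal NumPy sketch of the cross-entropy loss computed from raw logits, plus accuracy, for integer class labels. The function names are illustrative; in PyTorch, torch.nn.CrossEntropyLoss takes raw logits in the same way, and scikit-learn's precision_recall_fscore_support covers precision, recall, and F1.

```python
import numpy as np

def cross_entropy(logits: np.ndarray, labels: np.ndarray) -> float:
    """Mean cross-entropy between raw logits (B, K) and integer labels (B,)."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerically stable log-softmax
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(len(labels)), labels].mean())


def accuracy(logits: np.ndarray, labels: np.ndarray) -> float:
    """Fraction of samples whose highest-scoring class equals the label."""
    return float((logits.argmax(axis=1) == labels).mean())


logits = np.array([[2.0, 0.5, 0.1],
                   [0.2, 0.3, 1.5]])
labels = np.array([0, 1])
print(cross_entropy(logits, labels))
print(accuracy(logits, labels))  # 0.5: first sample correct, second predicted as class 2
```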

By implementing these modifications, Swin UNETR can be effectively repurposed for 3D image classification. In our experiments, training and validation accuracy rose steadily over the epochs and appeared to plateau at around 40 epochs. After 38 epochs, the model reached a training accuracy of 96% and a validation accuracy of 80% on the 3D MNIST dataset.

Image Classification Training Results

What’s Next?

Datature Nexus provides a robust platform for developing and deploying proprietary AI models for 3D medical image analysis. Its intuitive interface and comprehensive features enable efficient and precise analysis, improving diagnostic accuracy and streamlining workflows.

Datature's Nexus 3D Image Annotator

Continuous Innovation in Medical Imaging

Datature is committed to advancing medical imaging technology. Our R&D focuses on integrating cutting-edge models to maximize platform accuracy and throughput. We are exploring next-generation models for 3D image segmentation and their application to diverse diagnostic needs. The upcoming SegResNet integration will further enhance Datature Nexus capabilities. Details will be available in our next publication.

Partner with Datature

Explore how Datature Nexus can empower your medical imaging initiatives. Contact us here to develop custom AI models tailored to your requirements.

Our Developer’s Roadmap

If you have questions, feel free to join our Community Slack to post your questions or contact us if you wish to learn more about training a SwinUNETR model on Datature Nexus. For more detailed information about the SwinUNETR architecture, customization options, or answers to any common questions you might have, read more on our Developer Portal.
