A Comprehensive Guide to 3D Models for Medical Image Segmentation

This article introduces 3D segmentation, partitioning volumetric data into labeled regions for applications in medical imaging, robotics, and more. Focusing on 3D semantic segmentation, it uses the Swin UNETR architecture for brain tumor segmentation as an example. The article covers core concepts, training on the BraTS dataset including MRI normalization, input/output processing, computational challenges, and adapting Swin UNETR for 3D image classification.

Marcus Neo
Editor

Introduction to 3D Segmentation

3D segmentation, the process of partitioning volumetric data into distinct, labeled regions or objects, finds wide application across numerous fields. This involves analyzing data representing a three-dimensional space, such as medical imaging datasets (CT or MRI), point clouds acquired from LiDAR sensors, or 3D models generated through computer-aided design (CAD). The objective is to assign a semantic label to each individual data point (voxel, point, or mesh element), delineating boundaries and identifying constituent parts. In medical imaging and healthcare, this enables tumor identification and surgical planning, along with precise measurement and monitoring of anatomical structures. 

Where is 3D Segmentation Applied?

Robotics and industrial automation benefit by empowering robots to perceive and interact with objects and by facilitating product inspection through defect analysis. For autonomous vehicles, segmenting LIDAR data is crucial for distinguishing pedestrians, vehicles, and obstacles, enabling accurate path planning by identifying drivable areas and hazards. Finally, in geospatial and environmental analysis, 3D segmentation contributes to detailed 3D mapping for urban planning and environmental studies, as well as the analysis of geological structures for mining and oil exploration.

3D Segmentation Categories

3D segmentation techniques fall into three main categories:

  • Semantic Segmentation:  Labels each 3D element (like a voxel or mesh vertex) with a predefined class.  For example, in medical imaging, this could mean classifying each point in a scan as either "tumor" or "background."  It answers the question: What is this?
  • Instance Segmentation:  Builds upon semantic segmentation by distinguishing individual instances of the same class. So, instead of just identifying "tumor" regions, it would differentiate between separate, distinct tumors. It answers the question: Where are the individual objects of this type?
  • Part Segmentation:  Focuses on segmenting the different sub-components of a single object.  For instance, it could involve breaking down an anatomical structure into its constituent parts. It answers the question: What are the parts of this object?

This article specifically explores 3D semantic segmentation for medical imaging applications.

Core Concepts in 3D Semantic Segmentation

Voxels

A voxel (short for "volume pixel") is the 3D equivalent of a pixel, representing a value in a discretized 3D space. In medical imaging, voxels correspond to volumetric elements in MRI or CT scans. By stacking 2D slices along a third dimension (depth), we obtain a volumetric dataset composed of these voxels.
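To make this concrete, here is a minimal, illustrative NumPy sketch of how stacking 2D slices yields a voxel volume. The 240 × 240 × 155 shape mirrors a typical BraTS MRI volume, but the data here is synthetic:

import numpy as np

# Illustrative only: stack 155 synthetic axial slices of 240 x 240 pixels
# into a single volumetric array of voxels.
slices = [np.random.rand(240, 240).astype(np.float32) for _ in range(155)]
volume = np.stack(slices, axis=-1)  # shape: (240, 240, 155)

# Each element volume[i, j, k] is one voxel: an intensity value at a
# discrete (row, column, slice) location in 3D space.
print(volume.shape, volume[120, 120, 77])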

Voxel Visualization

Loss Functions

Dice Loss and Intersection over Union (IoU) Loss are two commonly used losses in 3D segmentation:

  • Dice Coefficient
    • Measures the similarity between two sets A and B. The coefficient ranges from 0 to 1, where 1 indicates that the two sets are identical and 0 indicates that they have no overlap. Dice coefficient = 2 * |A ∩ B| / (|A| + |B|).
    • So in our case, it can be used to compare the pixel-wise agreement between a predicted segmentation and its corresponding ground truth. 
  • Intersection over Union (IoU)
    • Quantifies the overlap between the segmented region and the ground truth
    • IoU = |A ∩ B| / |A ∪ B|.

In practice, Dice Loss is used as the training loss because its soft formulation is differentiable, whereas IoU computed on hard (thresholded) masks is not.
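As a rough illustration of why the soft Dice formulation is trainable, here is a minimal PyTorch sketch of a binary (single-class) soft Dice loss; production code would typically rely on a library implementation such as MONAI's DiceLoss:

import torch

def soft_dice_loss(pred_logits, target, eps=1e-6):
    """Differentiable (soft) Dice loss for a binary segmentation map.

    pred_logits and target have shape (B, H, W, D). Probabilities come from
    a sigmoid rather than a hard threshold, so gradients can flow through
    the overlap terms.
    """
    probs = torch.sigmoid(pred_logits)
    intersection = (probs * target).sum(dim=(1, 2, 3))
    denominator = probs.sum(dim=(1, 2, 3)) + target.sum(dim=(1, 2, 3))
    dice = (2.0 * intersection + eps) / (denominator + eps)
    return 1.0 - dice.mean()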

Dice Loss Visualization

Brain Tumor Segmentation Using Swin UNETR

Swin UNETR is a Transformer-based architecture designed for the 3D semantic segmentation of brain tumors in multi-modal MRI images. It was proposed by a team from NVIDIA and Vanderbilt University in 2022. The architecture takes its inspiration from the U-Net architecture, which is known for its effectiveness in medical image segmentation. Swin UNETR comprises an encoder and a decoder. The encoder is based on the Swin Transformer, which leverages a shifted windowing technique to capture both local and global contexts within the input data. The decoder is a CNN-based structure that upsamples the encoded features to generate the final segmentation output. The Swin UNETR model has demonstrated competitive results in the BraTS 2021 challenge, achieving top-ranking performance in the validation phase. It has been shown to be effective in segmenting the different sub-regions of brain tumors: the enhancing tumor (ET), whole tumor (WT), and tumor core (TC).

Swin UNETR's Architecture

Achieving Scan Invariance with Swin UNETR

In medical imaging, scan invariance is crucial for robust performance. It refers to a model's ability to handle images from different scanners, despite variations in intensity ranges, noise patterns, and resolutions. Swin UNETR possesses several architectural advantages that make it well-suited for achieving scan invariance:

  1. Global Context with Swin Transformer: The Swin Transformer encoder captures global dependencies across the entire image, reducing reliance on localized features that may vary across scanners.  
  2. Patch-Based Processing: The network processes input volumes as non-overlapping patches, making it easier to handle inconsistencies in image resolution or size.  
  3. Skip Connections for Fine-Grained Detail: Skip connections preserve high-resolution features, ensuring that subtle anatomical structures are retained despite variations in intensity or noise levels between scanners.  
  4. Easy Integration with Preprocessing Pipelines: Swin UNETR's modular design allows for seamless integration with preprocessing steps like intensity normalization and denoising, further enhancing scan invariance.  

Although Swin UNETR is not inherently immune to variations between scanners, its design makes it readily adaptable to achieve scan invariance. The combination of its Transformer backbone, patch-based processing, and U-Net-like architecture enables the effective use of preprocessing and training techniques to enhance its performance across diverse scanning environments.

Input Processing

Swin UNETR's robust input processing pipeline is structured as follows:

  1. Input Acquisition: The system ingests a 3D MRI volume, typically visualized in three standard anatomical planes: axial, sagittal, and coronal.
  2. Patch Generation: This 3D volume is systematically decomposed into non-overlapping patches.
  3. Swin Transformer Encoding: These patches are then passed to the Swin Transformer encoder.
    • The encoder performs multi-scale analysis on these patches.
    • It employs window-based attention mechanisms to efficiently manage the large 3D volume.
    • This process yields hierarchical feature representations.
  4. Skip Connections: To ensure the preservation of fine-grained spatial information, skip connections relay intermediate feature maps from the encoder to the decoder.
  5. CNN-Based Decoding:
    • The decoder upsamples the feature representations back to their original resolution.
    • It then strategically merges these upsampled features with the encoder features (conveyed through the skip connections) to refine the voxel-level predictions.
Source: Avots et al., 2024
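As a minimal sketch of this pipeline, the model can be instantiated and run on a dummy volume using MONAI's SwinUNETR implementation (keyword arguments may differ slightly between MONAI versions):

import torch
from monai.networks.nets import SwinUNETR  # assumes MONAI is installed

# 4 input channels (MRI modalities), 3 output channels (tumor sub-regions),
# operating on 96 x 96 x 96 patches.
model = SwinUNETR(
    img_size=(96, 96, 96),
    in_channels=4,
    out_channels=3,
    feature_size=48,
)

dummy_input = torch.randn(1, 4, 96, 96, 96)  # (B, C, H, W, D)
with torch.no_grad():
    logits = model(dummy_input)
print(logits.shape)  # expected: torch.Size([1, 3, 96, 96, 96])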

Training on BraTS Dataset

To effectively train Swin UNETR for accurate brain tumor segmentation, the Brain Tumor Segmentation (BraTS) dataset serves as a benchmark. This dataset focuses on segmenting brain tumors from MRI scans into three primary regions of interest:

  • Edema (ED): Surrounding swelling caused by the tumor.  
  • Enhancing Tumor (ET): The core tumor mass, typically showing active growth and increased vascularity.  
  • Non-Enhancing Tumor (NET): The tumor portion with less blood flow, often indicating areas of necrosis or less active growth.  

Each class is assigned a specific label for training purposes: 0 for background, 1 for ED, 2 for ET, and 3 for NET.  However, MRI intensities, which typically range from 0 to ~2000, need to be normalized before being fed into the model. Unlike modalities like Computed Tomography (CT) where voxel intensities correspond to standardized physical units (Hounsfield units), MRI intensities lack a consistent scale. This inconsistency arises because MRI signal strength is affected by:  

  • Scanner Differences: Different MRI machines may produce varying intensity ranges due to hardware discrepancies.
  • Acquisition Parameters: Settings like echo time (TE), repetition time (TR), and flip angle can alter signal intensities.
  • Patient Factors: Individual patient characteristics can influence the signal received.

Consequently, the same tissue type might exhibit different intensity values across scans, even within the same patient or using the same scanner. Normalization is crucial for several reasons:

  • Comparative Analysis: When comparing images from different sessions or patients, consistent intensity values are essential for accurate assessment.
  • Automated Processing: Many image analysis algorithms, including segmentation and registration tools, assume a standardized intensity range. Variations can lead to errors or reduced performance.
  • Quantitative Studies: Research involving quantitative measurements relies on consistent intensity values to draw valid conclusions.

Several methods have been developed to standardize MRI intensities:

  • Z-Score Normalization: Adjusts the image based on the mean and standard deviation of its intensities.
  • Histogram Matching: Aligns the histogram of an image to a reference histogram, ensuring a consistent intensity distribution.
  • WhiteStripe Method: Identifies and uses normal-appearing white matter regions to standardize intensities.
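As an example, a minimal Z-score normalization of a single MRI volume might look like the following sketch. The foreground mask is an assumption; in practice a brain mask is often used so that background voxels do not skew the statistics:

import numpy as np

def zscore_normalize(volume, mask=None):
    """Rescale an MRI volume to zero mean and unit standard deviation.

    If a mask is given, the statistics are computed only over the masked
    (foreground) voxels, then applied to the whole volume.
    """
    values = volume[mask] if mask is not None else volume
    mean, std = values.mean(), values.std()
    return (volume - mean) / (std + 1e-8)

# Illustrative usage on a synthetic volume with raw intensities up to ~2000.
raw = np.random.uniform(0, 2000, size=(240, 240, 155)).astype(np.float32)
normalized = zscore_normalize(raw, mask=raw > 0)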

During the training process, an Adam optimizer with a learning rate of 0.0001 is commonly used. Visualization tools like 3D Slicer or ParaView aid in visualizing the volumes and ground-truth labels, helping researchers assess the model's performance and identify areas for improvement.
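A brief sketch of this training setup, assuming the model defined in the earlier sketch and a data loader that yields normalized image/label pairs, could look like this (MONAI's DiceLoss is used here as one reasonable choice of soft Dice implementation):

import torch
from monai.losses import DiceLoss

loss_fn = DiceLoss(sigmoid=True)  # soft Dice computed on raw logits
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# One training step; images: (B, 4, 96, 96, 96), labels: (B, 3, 96, 96, 96).
optimizer.zero_grad()
outputs = model(images)
loss = loss_fn(outputs, labels)
loss.backward()
optimizer.step()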

Training Logic

Input

The input tensor to the Swin UNETR model is a 5D tensor (batch, channel, and three spatial dimensions) representing volumetric medical images. Here's the breakdown:

  • Shape: (B, C, H, W, D)
    • B: Batch size (determined by batch_size, which is 1 in this case).
    • C: Number of input channels (here 4, corresponding to the 4 MRI modalities: FLAIR, T1w, T1Gd, and T2w).
    • H, W, D: Spatial dimensions of the 3D image (in this case (96, 96, 96), as set by roi_size).

Example Input Shape: For a single batch: (1, 4, 96, 96, 96)

Context: The model processes volumetric data from 3D MRI scans, where each voxel has intensity values for 4 different modalities, providing rich spatial and anatomical information for segmentation.

Output: The output tensor from the Swin UNETR model is a 5D tensor representing the segmentation map of the input image. Here's the breakdown:

  • Shape: (B, C_out, H, W, D)
    • B: Batch size (same as the input batch size, here 1).
    • C_out: Number of output channels (here 3, corresponding to the 3 tumor classes: edema, non-enhancing tumor and enhancing tumor)
    • H, W, D: Spatial dimensions of the segmentation map (same as the input, (96, 96, 96)).

Example Output Shape: For a single batch: (1, 3, 96, 96, 96)

Context:

  • Each channel in the output corresponds to a specific tumor class:
    • Channel 0: Edema segmentation probability map.
    • Channel 1: Non-Enhancing Tumor segmentation probability map.
    • Channel 2: Enhancing Tumor (ET) segmentation probability map.
  • The values in each voxel of the output tensor range between [0, 1] (after applying a sigmoid activation function), representing the probability that the voxel belongs to the corresponding class.

Post-Processing

  1. Thresholding: The output probabilities are converted into binary masks using a threshold of 0.5. This is done using:

hard_outputs = (torch.sigmoid(outputs) > 0.5).float()

After this step, each voxel will have a value of 1 (belonging to the tumor class) or 0 (background).

  2. Per-Class Metrics: The hard_outputs are used to calculate metrics like the Dice score for each tumor class, as sketched below.
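A brief sketch of this evaluation step, assuming outputs and labels are the (1, 3, 96, 96, 96) tensors described above and using MONAI's DiceMetric:

import torch
from monai.metrics import DiceMetric

dice_metric = DiceMetric(include_background=True, reduction="mean_batch")

hard_outputs = (torch.sigmoid(outputs) > 0.5).float()  # binarize each channel
dice_metric(y_pred=hard_outputs, y=labels)
per_class_dice = dice_metric.aggregate()  # one score per tumor class
dice_metric.reset()
print(per_class_dice)  # e.g. [dice_edema, dice_net, dice_et]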

Summary

  • Input Shape: (1, 4, 96, 96, 96) — a single 3D image with 4 modalities.
  • Output Shape: (1, 3, 96, 96, 96) — segmentation maps for 3 tumor classes.
  • Interpretation:
    • Input provides volumetric spatial and anatomical data across modalities.
    • Output provides voxel-wise segmentation probabilities for the tumor regions.

Training Results

The model demonstrates effective learning, as evidenced by the increasing trend in both training and validation mean Dice scores over the epochs. Within just 10 epochs, the model achieves a training mean Dice score of 0.82 and a validation mean Dice score of 0.78. The scores plateau in later epochs, suggesting that the model is approaching its optimal performance. The close alignment between the training and validation curves indicates good generalization to unseen data without significant overfitting.

Performance and Computation Analysis

Training 3D models like Swin UNETR is computationally expensive because they process volumetric data (height × width × depth), unlike 2D models that handle single slices. This impacts memory usage in several ways:

  • Increased Input Size: 3D volumes contain many slices, adding an extra dimension to the input data. Even moderate-sized volumes have significantly more voxels (3D pixels) than a 2D image.
  • More Features to Extract: The network needs to learn patterns in three dimensions, requiring more complex architectures with increased parameters and memory usage.
  • GPU Memory Limits: GPUs with limited VRAM (e.g., 12GB on an RTX 4070) can struggle to store the large 3D feature maps and intermediate activations generated during processing.

These factors can constrain batch size and volume size. Smaller batch sizes can impact training stability, while downsampling volumes can limit the model's ability to capture broader anatomical context.

Potential Solutions:

  • Patch-Based Training: Processing smaller sub-volumes (patches) reduces memory usage while still exposing the model to diverse image regions.
  • Gradient Checkpointing: This technique recomputes activations during training, reducing the need to store them all in memory.
  • Mixed Precision Training: Using lower precision for calculations (FP16/BFloat16) decreases memory usage without significant accuracy loss.
  • Distributed Training: Distributing the data and model across multiple GPUs increases the effective memory capacity.
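Two of these techniques are easy to sketch together: gradient checkpointing (exposed as use_checkpoint=True in MONAI's SwinUNETR) and mixed precision training with torch.cuda.amp. The loop below assumes the loss_fn and data loader from the earlier training sketch and a CUDA-capable GPU:

import torch
from monai.networks.nets import SwinUNETR

model = SwinUNETR(
    img_size=(96, 96, 96), in_channels=4, out_channels=3,
    feature_size=48, use_checkpoint=True,  # recompute activations during the backward pass
).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()

for images, labels in train_loader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():  # run the forward pass in reduced precision
        loss = loss_fn(model(images.cuda()), labels.cuda())
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()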

Adapting Swin UNETR for 3D Image Classification

Swin UNETR can be adapted for classification tasks by modifying its architecture and training process:

Architectural Changes:

  1. Remove the Decoder: The U-Net-like decoder, designed for upsampling and generating segmentation maps, is not needed for classification.
  2. Global Average Pooling: Apply global average pooling to the feature maps generated by the Swin Transformer encoder. This aggregates the spatial information into a single feature vector.
  3. Fully Connected Layer: Add a fully connected layer after the global pooling. This layer maps the aggregated features to class probabilities, producing the final classification output.

Training Adaptations:

  • Dataset: Use a classification dataset instead of a segmentation dataset. For example, 3D MNIST, a dataset of volumetric digits (0-9), could be used.
  • Loss Function and Metrics: Use a cross-entropy loss function, which is suitable for classification tasks. Evaluate performance with metrics like accuracy, precision, recall, and F1-score.
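A hypothetical sketch of such a classification head is shown below. It assumes MONAI's SwinUNETR implementation, which exposes the Swin Transformer encoder as the swinViT attribute, and assumes the 3D MNIST digits are resampled to a 64³ volume; the attribute name and required input size may differ between versions:

import torch
import torch.nn as nn
from monai.networks.nets import SwinUNETR

class SwinClassifier(nn.Module):
    """Hypothetical classifier: Swin encoder -> global average pooling -> FC layer."""

    def __init__(self, num_classes=10, feature_size=48):
        super().__init__()
        # Build the full SwinUNETR, then keep only its Swin Transformer encoder.
        backbone = SwinUNETR(
            img_size=(64, 64, 64), in_channels=1,
            out_channels=num_classes, feature_size=feature_size,
        )
        self.encoder = backbone.swinViT
        # The deepest encoder stage has feature_size * 16 channels.
        self.head = nn.Linear(feature_size * 16, num_classes)

    def forward(self, x):
        deepest = self.encoder(x)[-1]          # deepest hierarchical feature map
        pooled = deepest.mean(dim=(2, 3, 4))   # global average pooling over H, W, D
        return self.head(pooled)               # class logits

# Illustrative usage on a dummy single-channel 64^3 volume.
logits = SwinClassifier()(torch.randn(1, 1, 64, 64, 64))
print(logits.shape)  # torch.Size([1, 10])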

By implementing these modifications, Swin UNETR can be effectively repurposed for 3D image classification tasks. In our training results, both training and validation accuracy increased steadily over the epochs and begin to plateau after around 40 epochs. By epoch 38, the model achieved a training accuracy of 96% and a validation accuracy of 80% on the 3D MNIST data.

Image Classification Training Results

What’s Next?

Datature is constantly looking to incorporate state-of-the-art models that can boost accuracy and inference throughput. We are actively looking for new models that push the frontiers of 3D image segmentation, as well as exploring how these models can be tweaked for other tasks. In fact, we are incorporating another model named SegResNet. Refer to our next blog for more information.

Our Developer’s Roadmap

If you have questions, feel free to join our Community Slack to post your questions or contact us if you wish to learn more about training a SwinUNETR model on Datature Nexus. For more detailed information about the SwinUNETR architecture, customization options, or answers to any common questions you might have, read more on our Developer Portal.
