
pytorch-3dunet

PyTorch implementation of 3D U-Net and its variants.

The code allows for training the U-Net for both semantic segmentation (binary and multi-class) and regression problems (e.g. de-noising, learning deconvolutions).

2D U-Net

Training the standard 2D U-Net is also possible, see 2DUnet_dsb2018 for an example configuration. Just make sure to keep the singleton z-dimension in your H5 dataset, i.e. use shape (1, Y, X) instead of (Y, X), because data loading / data augmentation always requires tensors of rank 3.
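For example, a 2D image can be given the required singleton z-dimension with numpy before it is written to HDF5 (a minimal sketch; the file name and the internal dataset name 'raw' are illustrative placeholders):

```python
import h5py
import numpy as np

# a 2D image of shape (Y, X); the random array is a placeholder for real data
img = np.random.rand(256, 256).astype(np.float32)

# add the singleton z-dimension -> shape (1, Y, X)
img_3d = np.expand_dims(img, axis=0)

with h5py.File('my_2d_dataset.h5', 'w') as f:
    f.create_dataset('raw', data=img_3d)
```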

Prerequisites

  • Linux
  • NVIDIA GPU
  • CUDA + cuDNN

Running on Windows

The package has not been tested on Windows, however some users have reported running it on Windows. One thing to keep in mind: when training with CrossEntropyLoss, the label type in the config file should be changed from long to int64, otherwise you will get the error: RuntimeError: Expected object of scalar type Long but got scalar type Int for argument #2 'target'.

Supported Loss Functions

Semantic Segmentation

  • BCEWithLogitsLoss (binary cross-entropy)
  • DiceLoss (standard DiceLoss defined as 1 - DiceCoefficient, used for binary semantic segmentation; when more than 2 classes are present in the ground truth, it computes the DiceLoss per channel and averages the values)
  • BCEDiceLoss (linear combination of BCE and Dice losses, i.e. alpha * BCE + beta * Dice; alpha, beta can be specified in the loss section of the config)
  • CrossEntropyLoss (one can specify class weights via weight: [w_1, ..., w_k] in the loss section of the config)
  • PixelWiseCrossEntropyLoss (one can specify not only class weights but also per-pixel weights in order to give more gradient to important (or under-represented) regions in the ground truth)
  • WeightedCrossEntropyLoss (see 'Weighted cross-entropy (WCE)' in the paper below for a detailed explanation; one can specify class weights via weight: [w_1, ..., w_k] in the loss section of the config)
  • GeneralizedDiceLoss (see 'Generalized Dice Loss (GDL)' in the paper below for a detailed explanation; one can specify class weights via weight: [w_1, ..., w_k] in the loss section of the config). Note: use this loss function only if the labels in the training dataset are very imbalanced, e.g. one class having at least 3 orders of magnitude more voxels than the others. Otherwise use the standard DiceLoss.

For a detailed explanation of some of the supported loss functions see: Generalised Dice overlap as a deep learning loss function for highly unbalanced segmentations, by Carole H. Sudre, Wenqi Li, Tom Vercauteren, Sebastien Ourselin, M. Jorge Cardoso.

IMPORTANT: if one wants to use their own loss function, bear in mind that the current model implementation always outputs logits, and it is up to the implementation of the loss to normalize them correctly, e.g. by applying Sigmoid or Softmax.
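As a rough illustration of that contract (a sketch, not the package's actual implementation), a custom per-channel Dice loss would apply the normalization itself before comparing against the target:

```python
import torch
import torch.nn as nn

class MyDiceLoss(nn.Module):
    """Per-channel Dice loss sketch: the network outputs logits, so the loss
    applies Sigmoid itself before computing 1 - DiceCoefficient."""

    def __init__(self, epsilon=1e-6):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, logits, target):
        probs = torch.sigmoid(logits)  # normalize the raw logits
        # flatten everything except the channel dim: (N, C, D, H, W) -> (C, N*D*H*W)
        probs = probs.transpose(0, 1).flatten(1)
        target = target.transpose(0, 1).flatten(1).float()
        intersection = (probs * target).sum(-1)
        denominator = probs.sum(-1) + target.sum(-1)
        dice_per_channel = (2 * intersection + self.epsilon) / (denominator + self.epsilon)
        return 1.0 - dice_per_channel.mean()  # average over channels
```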

Regression

  • MSELoss
  • L1Loss
  • SmoothL1Loss
  • WeightedSmoothL1Loss - extension of the SmoothL1Loss which allows weighting the voxel values above (or below) a given threshold differently (see the sketch after this list)
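A minimal sketch of the idea behind such threshold-based weighting (the class and parameter names here are illustrative, not the package's exact API):

```python
import torch
import torch.nn as nn

class ThresholdWeightedSmoothL1(nn.Module):
    """SmoothL1 where voxels whose target value lies below a threshold
    get a different (e.g. smaller) weight."""

    def __init__(self, threshold=0.0, below_weight=0.1):
        super().__init__()
        self.threshold = threshold
        self.below_weight = below_weight
        self.smooth_l1 = nn.SmoothL1Loss(reduction='none')

    def forward(self, prediction, target):
        loss = self.smooth_l1(prediction, target)
        # weight of 1.0 above the threshold, below_weight below it
        weights = torch.where(target < self.threshold,
                              torch.full_like(loss, self.below_weight),
                              torch.ones_like(loss))
        return (loss * weights).mean()
```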

Supported Evaluation Metrics

Semantic Segmentation

  • MeanIoU - mean intersection over union
  • DiceCoefficient - Dice coefficient (computes the per-channel Dice coefficient and returns the average)

If a 3D U-Net was trained to predict cell boundaries, one can use the following semantic instance segmentation metrics (they are computed by running connected components on the thresholded boundary map and comparing the resulting instances to the ground truth instance segmentation; see the sketch after this list):

  • BoundaryAveragePrecision - Average Precision applied to the boundary probability maps: thresholds the boundary maps given by the network, runs connected components to get the segmentation and computes AP between the resulting segmentation and the ground truth
  • AdaptedRandError - Adapted Rand Error (see http://brainiac2.mit.edu/SNEMI3D/evaluation for a detailed explanation)
  • AveragePrecision - see https://www.kaggle.com/stkbailey/step-by-step-explanation-of-scoring-metric
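The common first step of these boundary-based metrics can be sketched as follows (the threshold value and function name are illustrative):

```python
import numpy as np
from scipy import ndimage

def instances_from_boundary_map(boundary_pmap, threshold=0.4):
    """Convert a boundary probability map into an instance segmentation:
    voxels below the threshold are treated as cell interior, and each
    connected component of the interior becomes one instance."""
    interior = boundary_pmap < threshold
    instances, num_instances = ndimage.label(interior)
    return instances

# the resulting label volume is then compared against the ground truth
# instance segmentation, e.g. with Adapted Rand Error or Average Precision
```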

If not specified, MeanIoU will be used by default.

Regression

  • PSNR - peak signal to noise ratio

Installation

  • The easiest way to install the pytorch-3dunet package is via conda:

    conda create -n 3dunet -c conda-forge -c awolny pytorch-3dunet
    conda activate 3dunet

    After installation the following commands are accessible within the conda environment: train3dunet for training the network and predict3dunet for prediction (see below).

  • One can also install directly from source:

    python setup.py install

Installation tips

Make sure that the installed pytorch is compatible with your CUDA version, otherwise the training/prediction will fail to run on the GPU. You can re-install pytorch compatible with your CUDA in the 3dunet env by:

conda install -c pytorch torchvision cudatoolkit=<YOUR_CUDA_VERSION> pytorch

Train

Given that the pytorch-3dunet package was installed via conda as described above, one can train the network by simply invoking:

train3dunet --config <CONFIG>

where <CONFIG> is the path to a YAML configuration file, which specifies all aspects of the training procedure.

In order to train on your own data just provide the paths to your HDF5 training and validation datasets in the config.

The HDF5 files should contain the raw/label data sets in the following axis order: DHW (in case of 3D) and CDHW (in case of 4D).
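For instance, a training file with a single-channel raw volume and a matching label volume could be written like this (a sketch; the internal dataset names 'raw' and 'label' must match the paths used in your config, and the random arrays stand in for real data):

```python
import h5py
import numpy as np

# random stand-ins for real data
raw = np.random.rand(64, 128, 128).astype(np.float32)               # DHW order
label = np.random.randint(0, 2, size=(64, 128, 128)).astype(np.uint8)

with h5py.File('train_sample.h5', 'w') as f:
    f.create_dataset('raw', data=raw)      # 3D raw volume, DHW
    f.create_dataset('label', data=label)  # ground truth, same spatial shape
# a multi-channel volume would instead have shape (C, D, H, W), i.e. CDHW
```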

One can monitor the training progress with Tensorboard:

tensorboard --logdir <checkpoint_dir>/logs/

(you need tensorflow installed in your conda env), where <checkpoint_dir> is the path to the checkpoint directory specified in the config.

Training tips

  1. When training with binary-based losses, i.e. BCEWithLogitsLoss, DiceLoss, BCEDiceLoss, GeneralizedDiceLoss: the target data has to be 4D (one target binary mask per channel). If you have 3D binary data (foreground/background), you can just change the ToTensor transform for the label to contain expand_dims: true, see e.g. train_config_dice.yaml. When training with WeightedCrossEntropyLoss, CrossEntropyLoss or PixelWiseCrossEntropyLoss the target dataset has to be 3D; see also the pytorch documentation for the CE loss: https://pytorch.org/docs/master/generated/torch.nn.CrossEntropyLoss.html
  2. final_sigmoid in the model config section applies only at inference time: when training with cross entropy based losses (WeightedCrossEntropyLoss, CrossEntropyLoss, PixelWiseCrossEntropyLoss) set final_sigmoid=False so that Softmax normalization is applied to the output; when training with BCEWithLogitsLoss, DiceLoss, BCEDiceLoss or GeneralizedDiceLoss set final_sigmoid=True (see the sketch below).
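Conceptually, final_sigmoid just selects which normalization is applied to the network's logits at inference time (a hedged sketch, not the package's exact code):

```python
import torch

def normalize_logits(logits, final_sigmoid):
    """Sketch of what final_sigmoid controls when producing predictions."""
    if final_sigmoid:
        # independent per-channel probabilities (binary-style losses)
        return torch.sigmoid(logits)
    # mutually exclusive classes (cross entropy based losses)
    return torch.softmax(logits, dim=1)

# example: logits of shape (N, C, D, H, W)
probs = normalize_logits(torch.randn(1, 3, 32, 64, 64), final_sigmoid=False)
```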

Prediction

Given that the pytorch-3dunet package was installed via conda as described above, one can run the prediction via:

predict3dunet --config <CONFIG>

In order to predict on your own data, just provide the path to your model as well as paths to your HDF5 test files (see test_config_dice.yaml).

Prediction tips

In order to avoid checkerboard artifacts in the output prediction masks, the patch predictions are averaged, so make sure that the patch/stride params lead to overlapping blocks, e.g. patch: [64 128 128] stride: [32 96 96] will give you a 'halo' of 32 voxels in each direction.
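In other words, the overlap between adjacent patches in each dimension is simply patch minus stride:

```python
# adjacent patches overlap by (patch - stride) voxels in each dimension
patch = [64, 128, 128]
stride = [32, 96, 96]
overlap = [p - s for p, s in zip(patch, stride)]
print(overlap)  # [32, 32, 32] -> the 32-voxel 'halo' mentioned above
```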

Data Parallelism

By default, if multiple GPUs are available, training/prediction will be run on all the GPUs using DataParallel. If training/prediction on all available GPUs is not desirable, restrict the number of GPUs using CUDA_VISIBLE_DEVICES, e.g.

CUDA_VISIBLE_DEVICES=0,1 train3dunet --config <CONFIG>

or

CUDA_VISIBLE_DEVICES=0,1 predict3dunet --config <CONFIG>

Examples

Cell boundary predictions for lightsheet images of Arabidopsis thaliana lateral root

The data can be downloaded from the following OSF project:

  • training set: https://osf.io/9x3g2/
  • validation set: https://osf.io/vs6gb/
  • test set: https://osf.io/tn4xj/

Training and inference configs can be found in 3DUnet_lightsheet_boundary.

Sample z-slice predictions on the test set (top: raw input, bottom: boundary predictions):

Cell boundary predictions for confocal images of Arabidopsis thaliana ovules

The data can be downloaded from the following OSF project:

  • training set: https://osf.io/x9yns/
  • validation set: https://osf.io/xp5uf/
  • test set: https://osf.io/8jz7e/

Training and inference configs can be found in 3DUnet_confocal_boundary.

Sample z-slice predictions on the test set (top: raw input, bottom: boundary predictions):

Nuclei predictions for lightsheet images of Arabidopsis thaliana lateral root

The training and validation sets can be downloaded from the following OSF project: https://osf.io/thxzn/

Training and inference configs can be found in 3DUnet_lightsheet_nuclei.

Sample z-slice predictions on the test set (top: raw input, bottom: nuclei predictions):

2D nuclei predictions for Kaggle DSB2018

The data can be downloaded from: https://www.kaggle.com/c/data-science-bowl-2018/data

Training and inference configs can be found in 2DUnet_dsb2018.

Sample predictions on the test image (top: raw input, bottom: nuclei predictions):

Contribute

If you want to contribute back, please make a pull request.

Cite

If you use this code for your research, please cite as:

@article {10.7554/eLife.57613,
article_type = {journal},
title = {Accurate and versatile 3D segmentation of plant tissues at cellular resolution},
author = {Wolny, Adrian and Cerrone, Lorenzo and Vijayan, Athul and Tofanelli, Rachele and Barro, Amaya Vilches and Louveaux, Marion and Wenzl, Christian and Strauss, Sören and Wilson-Sánchez, David and Lymbouridou, Rena and Steigleder, Susanne S and Pape, Constantin and Bailoni, Alberto and Duran-Nebreda, Salva and Bassel, George W and Lohmann, Jan U and Tsiantis, Miltos and Hamprecht, Fred A and Schneitz, Kay and Maizel, Alexis and Kreshuk, Anna},
editor = {Hardtke, Christian S and Bergmann, Dominique C and Bergmann, Dominique C and Graeff, Moritz},
volume = 9,
year = 2020,
month = {jul},
pub_date = {2020-07-29},
pages = {e57613},
citation = {eLife 2020;9:e57613},
doi = {10.7554/eLife.57613},
url = {https://doi.org/10.7554/eLife.57613},
abstract = {Quantitative analysis of plant and animal morphogenesis requires accurate segmentation of individual cells in volumetric images of growing organs. In the last years, deep learning has provided robust automated algorithms that approach human performance, with applications to bio-image analysis now starting to emerge. Here, we present PlantSeg, a pipeline for volumetric segmentation of plant tissues into cells. PlantSeg employs a convolutional neural network to predict cell boundaries and graph partitioning to segment cells based on the neural network predictions. PlantSeg was trained on fixed and live plant organs imaged with confocal and light sheet microscopes. PlantSeg delivers accurate results and generalizes well across different tissues, scales, acquisition settings even on non plant samples. We present results of PlantSeg applications in diverse developmental contexts. PlantSeg is free and open-source, with both a command line and a user-friendly graphical interface.},
keywords = {instance segmentation, cell segmentation, deep learning, image analysis},
journal = {eLife},
issn = {2050-084X},
publisher = {eLife Sciences Publications, Ltd},
}
