pytorch-syncbn

by tamakoji


Synchronized Multi-GPU Batch Normalization

MIT License



Tamaki Kojima

Announcement

PyTorch 1.0 support

Overview

This is an alternative implementation of "Synchronized Multi-GPU Batch Normalization", which computes global statistics across GPUs instead of statistics computed locally on each GPU. SyncBN becomes important when input images are large and multiple GPUs must be used to increase the mini-batch size for training: each GPU then holds only a few samples, so statistics computed locally would be noisy.
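A small pure-Python sketch illustrates the difference for a single channel. It assumes plain lists stand in for the per-GPU slices of a batch, and `local_vs_global_stats` is a hypothetical helper, not part of this repository: with two GPUs holding two values each, the per-GPU variances differ sharply from the variance of the whole batch.

```python
def mean(xs):
    return sum(xs) / len(xs)


def local_vs_global_stats(chunks):
    """Compare per-GPU BN statistics with synchronized (global) ones.

    chunks: one list of values per simulated GPU, for a single channel.
    Hypothetical helper for illustration only.
    """
    # Unsynchronized BN: every GPU uses only the samples it holds
    local = [(mean(c), mean([v * v for v in c]) - mean(c) ** 2) for c in chunks]
    # Synchronized BN: mean/variance over all samples on all GPUs
    everything = [v for c in chunks for v in c]
    mu = mean(everything)
    var = mean([v * v for v in everything]) - mu ** 2
    return local, (mu, var)


local, global_stats = local_vs_global_stats([[1.0, 3.0], [5.0, 7.0]])
print(local)         # [(2.0, 1.0), (6.0, 1.0)]
print(global_stats)  # (4.0, 5.0)
```

Each GPU sees a variance of 1.0, while the batch as a whole has a variance of 5.0; synchronized BN normalizes with the latter.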

The code was inspired by Pytorch-Encoding and Inplace-ABN.

Remarks

  • Unlike Pytorch-Encoding, you don't need a custom nn.DataParallel
  • Unlike Inplace-ABN, you can simply replace your nn.BatchNorm2d with this module implementation, since it does not mark tensors for in-place operation
  • You can plug it into arbitrary modules written in PyTorch to enable synchronized BatchNorm
  • Backward computation is rewritten and tested against the behavior of nn.BatchNorm2d

Requirements

For PyTorch, please refer to https://pytorch.org/

NOTE: The code has been tested only with PyTorch v1.0.0 and CUDA 10 / cuDNN 7.4.2 on Ubuntu 18.04.

It uses PyTorch's JIT mechanism for C++/CUDA extensions to compile seamlessly with ninja, so please install ninja-build before use.

sudo apt-get install ninja-build

Also install all Python dependencies. With pip, run:

pip install -U -r requirements.txt

Build

There is no need to build: just run the code and the JIT will take care of compilation. JIT and C++ extensions are supported from PyTorch 0.4 onward; however, PyTorch 1.0 or later is highly recommended due to major design changes.

Usage

Please refer to test.py for testing the difference between nn.BatchNorm2d and modules.nn.BatchNorm2d.
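import torch
import torch.nn as nn
from modules import nn as NN  # this repository's synchronized BatchNorm2d

num_gpu = torch.cuda.device_count()
model = nn.Sequential(
    nn.Conv2d(3, 3, 1, 1, bias=False),
    NN.BatchNorm2d(3),  # drop-in replacement for nn.BatchNorm2d
    nn.ReLU(inplace=True),
    nn.Conv2d(3, 3, 1, 1, bias=False),
    NN.BatchNorm2d(3),
).cuda()
# plain nn.DataParallel; no custom wrapper is needed
model = nn.DataParallel(model, device_ids=list(range(num_gpu)))
# batch of num_gpu samples, so each GPU receives one sample
x = torch.rand(num_gpu, 3, 2, 2).cuda()
z = model(x)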

Math

Forward

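  1. compute $\sum_i x_i$ and $\sum_i x_i^2$ in each GPU
  2. gather all $\sum_i x_i$ and $\sum_i x_i^2$ from workers to master and compute

    $\mu = \frac{1}{N}\sum_{i=1}^{N} x_i$

    and

    $\sigma^2 = \frac{1}{N}\sum_{i=1}^{N} x_i^2 - \mu^2$

    where N is the total number of elements for each channel across all GPUs. The global stats are then shared with all GPUs, and running_mean and running_var are updated by moving average using the global stats.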

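  3. forward batchnorm using the global stats by

    $\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$

    and then

    $y_i = \gamma \hat{x}_i + \beta$

    where $\gamma$ is the weight parameter, $\beta$ is the bias parameter and $\epsilon$ is a small constant for numerical stability.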

  4. save $x$, $\gamma$, $\beta$, $\mu$ and $\sigma^2$ for backward
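The forward steps can be simulated for one channel in pure Python, with plain lists standing in for the per-GPU chunks of a tensor. This is a sketch, not the repository's CUDA implementation: `sync_bn_forward` and `update_running_stats` are hypothetical names, and the running-stats update assumes nn.BatchNorm2d's convention (momentum 0.1, unbiased variance for running_var), which this module may not follow exactly.

```python
import math


def sync_bn_forward(chunks, gamma, beta, eps=1e-5):
    """Synchronized BN forward for ONE channel, simulated in pure Python.

    chunks: one list of values per (simulated) GPU. Hypothetical helper,
    not this repository's API.
    """
    # Step 1: each GPU computes its local sum, sum of squares and count
    partial = [(sum(c), sum(v * v for v in c), len(c)) for c in chunks]

    # Step 2: the master gathers the partial sums and forms global stats
    n = sum(p[2] for p in partial)
    mu = sum(p[0] for p in partial) / n
    # biased variance via E[x^2] - mu^2 (fine for a sketch, less stable
    # than a two-pass formula for large-magnitude inputs)
    var = sum(p[1] for p in partial) / n - mu * mu

    # Step 3: every GPU normalizes its own chunk with the shared global stats
    inv_std = 1.0 / math.sqrt(var + eps)
    outputs = [[gamma * (v - mu) * inv_std + beta for v in c] for c in chunks]

    # Step 4: keep the global stats for backward
    return outputs, (mu, var, n)


def update_running_stats(running_mean, running_var, mu, var, n, momentum=0.1):
    """Moving-average update, ASSUMING nn.BatchNorm2d's convention:
    new = (1 - momentum) * old + momentum * batch_stat, with the unbiased
    variance var * n / (n - 1) used for running_var."""
    unbiased_var = var * n / (n - 1)
    new_mean = (1 - momentum) * running_mean + momentum * mu
    new_var = (1 - momentum) * running_var + momentum * unbiased_var
    return new_mean, new_var


# Usage: two simulated GPUs holding two elements each
outputs, (mu, var, n) = sync_bn_forward([[1.0, 3.0], [5.0, 7.0]], 1.0, 0.0, eps=0.0)
print(mu, var)  # 4.0 5.0
```

Because only the partial sums and counts cross GPUs, the result matches running ordinary batch norm over the concatenated batch.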

Backward

  1. Restore the saved $x$, $\gamma$, $\beta$, $\mu$ and $\sigma^2$

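  2. Compute the following sums on each GPU

    $\sum_{i=1}^{N_j} \frac{\partial J}{\partial y_i}$

    and

    $\sum_{i=1}^{N_j} \frac{\partial J}{\partial y_i}\hat{x}_i$

    where J is the loss and $N_j$ is the number of elements of the channel held by GPU $j$,

    then gather them at the master node to form the global sums, and normalize with N, where N is the total number of elements for each channel. The global sums are then shared among all GPUs.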

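  3. compute gradients using the global stats

    $\frac{\partial J}{\partial \beta} = \sum_{i=1}^{N} \frac{\partial J}{\partial y_i}$

    and

    $\frac{\partial J}{\partial \gamma} = \sum_{i=1}^{N} \frac{\partial J}{\partial y_i}\hat{x}_i$

    and finally,

    $\frac{\partial J}{\partial x_i} = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}\left(\frac{\partial J}{\partial y_i} - \frac{1}{N}\frac{\partial J}{\partial \beta} - \frac{\hat{x}_i}{N}\frac{\partial J}{\partial \gamma}\right)$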
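The backward steps can likewise be simulated for one channel in pure Python. This sketch assumes lists stand in for per-GPU tensors and that the global mean and variance come from the forward pass; `sync_bn_backward` is a hypothetical name, not the repository's API. It divides the two global sums by N right after gathering them, then applies the standard batch-norm input-gradient formula.

```python
import math


def sync_bn_backward(chunks, grad_chunks, gamma, mu, var, eps=1e-5):
    """Synchronized BN backward for ONE channel, simulated in pure Python.

    chunks / grad_chunks: per-GPU lists of inputs x_i and upstream gradients
    dJ/dy_i; mu and var are the global stats saved by the forward pass.
    Hypothetical helper, not this repository's API.
    """
    inv_std = 1.0 / math.sqrt(var + eps)
    # Step 1: restore the saved stats and recompute x_hat locally
    xhat = [[(v - mu) * inv_std for v in c] for c in chunks]

    # Step 2: per-GPU sums of dJ/dy and dJ/dy * x_hat ...
    partial = [(sum(g), sum(gi * xi for gi, xi in zip(g, xh)))
               for g, xh in zip(grad_chunks, xhat)]
    # ... summed globally on the master, then normalized by N
    n = sum(len(c) for c in chunks)
    grad_beta = sum(p[0] for p in partial)
    grad_gamma = sum(p[1] for p in partial)
    mean_g, mean_gx = grad_beta / n, grad_gamma / n

    # Step 3: every GPU computes dJ/dx_i from the shared global sums
    grad_x = [[gamma * inv_std * (gi - mean_g - xi * mean_gx)
               for gi, xi in zip(g, xh)]
              for g, xh in zip(grad_chunks, xhat)]
    return grad_x, grad_gamma, grad_beta
```

In this sketch only two scalars per channel have to cross GPUs in step 2; everything else is computed locally from them.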

Note that the implementation performs the normalization with N at step (2), so the above equations and the implementation are not written exactly the same way, but they are mathematically equivalent.
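With the division by N moved into step (2), the input gradient can be written as below. This is a sketch in standard batch-norm notation ($\hat{x}_i$ the normalized input, $\gamma$ the weight, $\sigma^2$ the global variance, $\epsilon$ the stability constant, sums over all N elements of a channel across GPUs); $\bar{g}$ and $\overline{g\hat{x}}$ are shorthand introduced here for the two normalized global sums.

```latex
\bar{g} = \frac{1}{N}\sum_{i=1}^{N}\frac{\partial J}{\partial y_i},
\qquad
\overline{g\hat{x}} = \frac{1}{N}\sum_{i=1}^{N}\frac{\partial J}{\partial y_i}\,\hat{x}_i,
\qquad
\frac{\partial J}{\partial x_i} = \frac{\gamma}{\sqrt{\sigma^2+\epsilon}}
  \left(\frac{\partial J}{\partial y_i} - \bar{g} - \hat{x}_i\,\overline{g\hat{x}}\right)
```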

For a deeper treatment of the above explanation, see Kevin Zakka's Blog.
