# Training neural networks with local error signals

This repo contains PyTorch code for training neural networks without global backprop. The experiments were performed by Arild Nøkland and Lars Hiller Eidnes.

A more detailed description of the experiments is available on arXiv here: https://arxiv.org/abs/1901.06656

Supervised training of neural networks for classification is typically performed with a global loss function. The loss function provides a gradient for the output layer, and this gradient is back-propagated to hidden layers to dictate an update direction for the weights. An alternative approach is to train the network with layer-wise loss functions. In this paper we demonstrate, for the first time, that layer-wise training can approach the state-of-the-art on a variety of image datasets. We use single-layer sub-networks and two different supervised loss functions to generate local error signals for the hidden layers, and we show that the combination of these losses helps with optimization in the context of local learning. Using local errors could be a step towards more biologically plausible deep learning because the global error does not have to be transported back to hidden layers.
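As an illustration of the approach (not the repo's actual classes; the module and variable names below are made up), a hidden block can carry its own auxiliary classifier and be trained purely from its local error signal, with the graph detached before the activation is passed on:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LocalBlock(nn.Module):
    """One hidden block trained with its own local loss (illustrative sketch)."""
    def __init__(self, in_dim, out_dim, num_classes):
        super().__init__()
        self.layer = nn.Sequential(nn.Linear(in_dim, out_dim), nn.LeakyReLU())
        self.decoder = nn.Linear(out_dim, num_classes)  # local classifier for the 'pred' loss

    def forward(self, x, y):
        h = self.layer(x)
        local_loss = F.cross_entropy(self.decoder(h), y)
        # Detach before handing h to the next block: no gradient flows back past this layer.
        return h.detach(), local_loss

# Each block has its own optimizer; the loss stays local to the block.
block = LocalBlock(784, 1024, 10)
opt = torch.optim.Adam(block.parameters(), lr=5e-4)
x, y = torch.randn(32, 784), torch.randint(0, 10, (32,))
h, loss = block(x, y)
opt.zero_grad()
loss.backward()
opt.step()
```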

In the tables below, 'pred' indicates a layer-wise cross-entropy loss, 'sim' indicates a layer-wise similarity matching loss, and 'predsim' indicates a combination of these losses. For the local losses, the computational graph is detached after each hidden layer.
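The 'sim' loss matches the pairwise similarity structure of the hidden activations in a mini-batch to that of the one-hot labels. A rough sketch of the idea (hypothetical helper names; the actual implementation in train.py differs in detail):

```python
import torch
import torch.nn.functional as F

def similarity_matrix(z):
    # Pairwise cosine similarities between the examples in a mini-batch.
    z = F.normalize(z.flatten(1), dim=1)
    return z @ z.t()

def sim_loss(h, y, num_classes):
    # Match the similarity structure of the activations to that of the one-hot labels.
    y_onehot = F.one_hot(y, num_classes).float()
    return F.mse_loss(similarity_matrix(h), similarity_matrix(y_onehot))

h = torch.randn(32, 1024)        # hidden activations for a batch
y = torch.randint(0, 10, (32,))  # targets
loss = sim_loss(h, y, num_classes=10)
# 'predsim' combines this with the local cross-entropy ('pred') loss,
# e.g. a weighted sum of the two terms.
```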

## Experiments

All numbers below are test error rates in percent.

Results on MNIST with 2 pixel jittering:

| Network | #Params | Global loss | Local loss 'pred' | Local loss 'sim' | Local loss 'predsim' |
| :--- | :--- | :--- | :--- | :--- | :--- |
| mlp | 2.9M | 0.75 | 0.68 | 0.80 | 0.62 |
| vgg8b | 7.3M | 0.26 | 0.40 | 0.65 | 0.31 |
| vgg8b + cutout | 7.3M | - | - | - | 0.26 |

Results on Fashion-MNIST with 2 pixel jittering and horizontal flipping:

| Network | #Params | Global loss | Local loss 'pred' | Local loss 'sim' | Local loss 'predsim' |
| :--- | :--- | :--- | :--- | :--- | :--- |
| mlp | 2.9M | 8.37 | 8.60 | 9.70 | 8.54 |
| vgg8b | 7.3M | 4.53 | 5.66 | 5.12 | 4.65 |
| vgg8b (2x) | 28.2M | 4.55 | 5.11 | 4.92 | 4.33 |
| vgg8b (2x) + cutout | 28.2M | - | - | - | 4.14 |

Results on Kuzushiji-MNIST with no data augmentation:

| Network | #Params | Global loss | Local loss 'pred' | Local loss 'sim' | Local loss 'predsim' |
| :--- | :--- | :--- | :--- | :--- | :--- |
| mlp | 2.9M | 5.99 | 7.26 | 9.80 | 7.33 |
| vgg8b | 7.3M | 1.53 | 2.22 | 2.19 | 1.36 |
| vgg8b + cutout | 7.3M | - | - | - | 0.99 |

Results on CIFAR-10 with data augmentation:

| Network | #Params | Global loss | Local loss 'pred' | Local loss 'sim' | Local loss 'predsim' |
| :--- | :--- | :--- | :--- | :--- | :--- |
| mlp | 27.3M | 33.56 | 32.33 | 33.48 | 30.93 |
| vgg8b | 8.9M | 5.99 | 8.40 | 7.16 | 5.58 |
| vgg11b | 11.6M | 5.56 | 8.39 | 6.70 | 5.30 |
| vgg11b (2x) | 42.0M | 4.91 | 7.30 | 6.66 | 4.42 |
| vgg11b (3x) | 91.3M | 5.02 | 7.37 | 9.34 | 3.97 |
| vgg11b (3x) + cutout | 91.3M | - | - | - | 3.60 |

Results on CIFAR-100 with data augmentation:

| Network | #Params | Global loss | Local loss 'pred' | Local loss 'sim' | Local loss 'predsim' |
| :--- | :--- | :--- | :--- | :--- | :--- |
| mlp | 27.3M | 62.57 | 58.87 | 62.46 | 56.88 |
| vgg8b | 9.0M | 26.24 | 29.32 | 32.64 | 24.07 |
| vgg11b | 11.7M | 25.18 | 29.58 | 30.82 | 24.05 |
| vgg11b (2x) | 42.1M | 23.44 | 26.91 | 28.03 | 21.20 |
| vgg11b (3x) | 91.4M | 23.69 | 25.90 | 28.01 | 20.13 |

Results on SVHN with extra training data, but no augmentation:

| Network | #Params | Global loss | Local loss 'pred' | Local loss 'sim' | Local loss 'predsim' |
| :--- | :--- | :--- | :--- | :--- | :--- |
| vgg8b | 8.9M | 2.29 | 2.12 | 1.89 | 1.74 |
| vgg8b + cutout | 8.9M | - | - | - | 1.65 |

Results on STL-10 with no data augmentation:

| Network | #Params | Global loss | Local loss 'pred' | Local loss 'sim' | Local loss 'predsim' |
| :--- | :--- | :--- | :--- | :--- | :--- |
| vgg8b | 11.5M | 33.08 | 26.83 | 23.15 | 20.51 |
| vgg8b + cutout | 11.5M | - | - | - | 19.25 |

## Training recipes

To replicate training of MLP on MNIST with local loss 'predsim':

```bash
python train.py --model mlp --dataset MNIST --dropout 0.1 --lr 5e-4 --num-layers 3 --epochs 100 --lr-decay-milestones 50 75 89 94 --nonlin leakyrelu
```

To replicate training of VGG8b on MNIST with local loss 'predsim':

```bash
python train.py --model vgg8b --dataset MNIST --dropout 0.2 --lr 5e-4 --epochs 100 --lr-decay-milestones 50 75 89 94 --nonlin leakyrelu --dim-in-decoder 1024
```

To replicate training of MLP on CIFAR10 with local loss 'predsim':

```bash
python train.py --model mlp --dataset CIFAR10 --dropout 0.1 --lr 5e-4 --num-layers 3 --num-hidden 3000 --nonlin leakyrelu
```

To replicate training of VGG8b on CIFAR10 with local loss 'predsim':

```bash
python train.py --model vgg8b --dataset CIFAR10 --dropout 0.2 --lr 5e-4 --nonlin leakyrelu --dim-in-decoder 2048
```

To replicate training of VGG11b (3x) on CIFAR10 with local loss 'predsim':

```bash
python train.py --model vgg11b --dataset CIFAR10 --dropout 0.3 --lr 3e-4 --feat-mult 3 --nonlin leakyrelu
```

For all the above recipes, to train with the local cross-entropy loss, add the argument:

```bash
--loss-sup pred
```

For all the above recipes, to train with the local similarity matching loss, add the argument:

```bash
--loss-sup sim
```

For all the above recipes, to train with the global loss, add the argument:

```bash
--backprop
```

For all the above recipes, to train with a more biologically plausible version of the local loss, add the argument:

```bash
--bio
```

To add cutout regularization with a cutout hole size of 14, add the arguments:

```bash
--cutout --length 14
```
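For reference, cutout (DeVries & Taylor, 2017) masks out a random square patch of each training image. A minimal illustrative sketch (not the repo's implementation):

```python
import torch

def cutout(img, length=14):
    # img: (C, H, W) float tensor; zero out one random length x length square.
    _, h, w = img.shape
    cy = torch.randint(h, (1,)).item()
    cx = torch.randint(w, (1,)).item()
    y1, y2 = max(0, cy - length // 2), min(h, cy + length // 2)
    x1, x2 = max(0, cx - length // 2), min(w, cx + length // 2)
    out = img.clone()
    out[:, y1:y2, x1:x2] = 0.0
    return out
```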

To replicate all the above experiments, run

```bash
./run_experiments.sh
```
