Lottery Ticket Hypothesis in Pytorch


This repository contains a Pytorch implementation of the paper The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks by Jonathan Frankle and Michael Carbin that can be easily adapted to any model/dataset.
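At a high level, the code implements the paper's iterative magnitude pruning with weight rewinding: train the network, prune a fixed percentage of the smallest-magnitude weights, reset the surviving weights to their original initialization, and repeat. The snippet below is only a minimal sketch of that loop to convey the idea; the function and helper names (`find_winning_ticket`, `train_fn`) are placeholders and not this repository's API, whose actual implementation lives in `main.py`.

```python
# Minimal sketch of iterative magnitude pruning with weight rewinding (the lottery ticket procedure).
# Placeholder code for illustration only -- see main.py for the repository's real loop.
import copy

import numpy as np
import torch
import torch.nn as nn


def find_winning_ticket(model: nn.Module, train_fn, prune_percent=10, prune_iterations=35):
    initial_state = copy.deepcopy(model.state_dict())  # theta_0, kept for rewinding
    masks = {name: torch.ones_like(p) for name, p in model.named_parameters() if "weight" in name}

    for _ in range(prune_iterations):
        # 1. Train the masked network; train_fn is assumed to keep pruned weights at zero.
        train_fn(model, masks)

        # 2. Prune prune_percent% of the smallest-magnitude surviving weights in each layer.
        for name, param in model.named_parameters():
            if name not in masks:
                continue
            alive = param.data[masks[name].bool()].abs().cpu().numpy()
            threshold = np.percentile(alive, prune_percent)
            masks[name][param.data.abs() < threshold] = 0.0

        # 3. Rewind the surviving weights to their original initialization (the "winning ticket").
        model.load_state_dict(initial_state)
        with torch.no_grad():
            for name, param in model.named_parameters():
                if name in masks:
                    param.mul_(masks[name])

    return masks
```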

Requirements

pip3 install -r requirements.txt

How to run the code?

Using datasets/architectures included with this repository:

python3 main.py --prune_type=lt --arch_type=fc1 --dataset=mnist --prune_percent=10 --prune_iterations=35
- `--prune_type` : Type of pruning
  - Options: `lt` - Lottery Ticket Hypothesis, `reinit` - Random reinitialization
  - Default: `lt`
- `--arch_type` : Type of architecture
  - Options: `fc1` - Simple fully connected network, `lenet5` - LeNet5, `AlexNet` - AlexNet, `resnet18` - Resnet18, `vgg16` - VGG16
  - Default: `fc1`
- `--dataset` : Choice of dataset
  - Options: `mnist`, `fashionmnist`, `cifar10`, `cifar100`
  - Default: `mnist`
- `--prune_percent` : Percentage of weights to be pruned after each cycle
  - Default: `10`
- `--prune_iterations` : Number of pruning cycles to be performed
  - Default: `35`
- `--lr` : Learning rate
  - Default: `1.2e-3`
- `--batch_size` : Batch size
  - Default: `60`
- `--end_iter` : Number of epochs
  - Default: `100`
- `--print_freq` : Frequency for printing accuracy and loss
  - Default: `1`
- `--valid_freq` : Frequency for validation
  - Default: `1`
- `--gpu` : Which GPU the program should use
  - Default: `0`
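For example, combining several of the flags above, a random-reinitialization baseline (`reinit`) on CIFAR-10 with ResNet-18, leaving the remaining options at their defaults, would be launched as:

python3 main.py --prune_type=reinit --arch_type=resnet18 --dataset=cifar10 --prune_percent=10 --prune_iterations=35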
Using datasets/architectures that are not included with this repository:
- Adding a new architecture:
  - For example, suppose you want to add an architecture named `new_model` with `mnist` dataset compatibility.
    - Go to the `/archs/mnist/` directory and create a file `new_model.py`.
    - Paste your PyTorch-compatible model inside `new_model.py`.
    - IMPORTANT: Make sure the input size, number of classes, number of channels, and batch size in your `new_model.py` match the corresponding dataset that you are adding (in this case, `mnist`).
    - Now open `main.py`, go to line 36, and look for the comment `# Data Loader`. Find your corresponding dataset (in this case, `mnist`) and add `new_model` at the end of the line `from archs.mnist import AlexNet, LeNet5, fc1, vgg, resnet`.
    - Now go to line 82 and add the following:

      elif args.arch_type == "new_model":
          model = new_model.new_model_name().to(device)

      Here, `new_model_name()` is the name of the model class that you defined inside `new_model.py` (a minimal sketch of such a file is shown after this list).
- Adding a new dataset:
  - For example, suppose you want to add a dataset named `new_dataset` with `fc1` architecture compatibility.
    - Go to `/archs` and create a directory named `new_dataset`.
    - Now go to `/archs/new_dataset/` and add a file named `fc1.py`, or copy-paste it from an existing dataset folder.
    - IMPORTANT: Make sure the input size, number of classes, number of channels, and batch size in your `fc1.py` match the corresponding dataset that you are adding (in this case, `new_dataset`).
    - Now open `main.py`, go to line 58, and add the following:

      elif args.dataset == "new_dataset":
          traindataset = datasets.new_dataset('../data', train=True, download=True, transform=transform)
          testdataset = datasets.new_dataset('../data', train=False, transform=transform)
          from archs.new_dataset import fc1

      Note that, as of now, you can only add datasets that are natively available in PyTorch.
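To make the architecture step above concrete, here is a minimal sketch of what `/archs/mnist/new_model.py` could look like. The class name `new_model_name` and the layer sizes are illustrative placeholders (not code from this repository); all that matters is that they match the target dataset, here MNIST's 1x28x28 inputs and 10 classes.

```python
# /archs/mnist/new_model.py -- illustrative sketch only; names are placeholders
import torch.nn as nn


class new_model_name(nn.Module):
    """Small fully connected network sized for MNIST (1x28x28 inputs, 10 classes)."""

    def __init__(self, num_classes=10):
        super(new_model_name, self).__init__()
        self.classifier = nn.Sequential(
            nn.Linear(28 * 28, 300),   # input size must match the dataset
            nn.ReLU(inplace=True),
            nn.Linear(300, num_classes),
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)      # flatten (batch, 1, 28, 28) -> (batch, 784)
        return self.classifier(x)
```

With such a file in place, `model = new_model.new_model_name().to(device)` in `main.py` (as in the snippet above) will instantiate it.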

How to combine the plots of various `prune_type`?

- Go to `combine_plots.py` and add/remove the datasets/archs whose combined plot you want to generate (assuming that you have already executed the `main.py` code for those datasets/archs and produced the weights).
- Run `python3 combine_plots.py`.
- Go to `/plots/lt/combined_plots/` to see the graphs.

Kindly raise an issue if you have any problem with the instructions.

Datasets and Architectures that were already tested

|              |        fc1         |       LeNet5       |      AlexNet       |       VGG16        |      Resnet18      |
|--------------|:------------------:|:------------------:|:------------------:|:------------------:|:------------------:|
| MNIST        | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| CIFAR10      | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| FashionMNIST | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| CIFAR100     | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |

Repository Structure

Lottery-Ticket-Hypothesis-in-Pytorch
├── archs
│   ├── cifar10
│   │   ├── AlexNet.py
│   │   ├── densenet.py
│   │   ├── fc1.py
│   │   ├── LeNet5.py
│   │   ├── resnet.py
│   │   └── vgg.py
│   ├── cifar100
│   │   ├── AlexNet.py
│   │   ├── fc1.py
│   │   ├── LeNet5.py
│   │   ├── resnet.py
│   │   └── vgg.py
│   └── mnist
│       ├── AlexNet.py
│       ├── fc1.py
│       ├── LeNet5.py
│       ├── resnet.py
│       └── vgg.py
├── combine_plots.py
├── dumps
├── main.py
├── plots
├── README.md
├── requirements.txt
├── saves
└── utils.py

Interesting papers related to the Lottery Ticket Hypothesis that I enjoyed

Acknowledgement

Parts of the code were borrowed from ktkth5.

Issue / Want to Contribute?

Open a new issue or submit a pull request if you face any difficulty with the code base or want to contribute to it.


Buy Me A Coffee
