Brevitas: quantization-aware training in PyTorch
Brevitas is a PyTorch library for quantization-aware training.
Brevitas is currently under active development and on a rolling release. It should be considered in beta stage. Minor API changes are still planned. Documentation, tests, examples, and pretrained models will be progressively released.
You can install the latest master directly from GitHub:
```bash
pip install git+https://github.com/Xilinx/brevitas.git
```
Brevitas implements a set of building blocks at different levels of abstraction to model a reduced precision hardware data-path at training time.
Brevitas provides a platform both for researchers interested in implementing new quantization-aware training techniques and for practitioners interested in applying current techniques to their models.
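At the layer level of abstraction, for instance, a quantized layer can be used as a drop-in replacement for its torch.nn counterpart. The snippet below is a minimal sketch using default settings for scaling; the 4-bit weight width is just an illustrative choice:

```python
import torch
from brevitas.nn import QuantConv2d

# Drop-in replacement for torch.nn.Conv2d with 4-bit weights,
# using Brevitas' default settings for scaling
conv = QuantConv2d(3, 6, kernel_size=5, weight_bit_width=4)
y = conv(torch.randn(1, 3, 32, 32))  # forward pass with quantized weights
```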
Here's what a simple LeNet with 4-bit weights and 8-bit activations looks like, using default settings for scaling:
```python
from torch.nn import Module
import torch.nn.functional as F

from brevitas.nn import QuantIdentity, QuantConv2d, QuantReLU, QuantLinear


class QuantLeNet(Module):
    def __init__(self):
        super(QuantLeNet, self).__init__()
        self.quant_inp = QuantIdentity(bit_width=8)
        self.conv1 = QuantConv2d(3, 6, 5, weight_bit_width=4)
        self.relu1 = QuantReLU(bit_width=8)
        self.conv2 = QuantConv2d(6, 16, 5, weight_bit_width=4)
        self.relu2 = QuantReLU(bit_width=8)
        self.fc1 = QuantLinear(16 * 5 * 5, 120, bias=True, weight_bit_width=4)
        self.relu3 = QuantReLU(bit_width=8)
        self.fc2 = QuantLinear(120, 84, bias=True, weight_bit_width=4)
        self.relu4 = QuantReLU(bit_width=8)
        self.fc3 = QuantLinear(84, 10, bias=False, weight_bit_width=4)

    def forward(self, x):
        out = self.quant_inp(x)
        out = self.relu1(self.conv1(out))
        out = F.max_pool2d(out, 2)
        out = self.relu2(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = self.relu3(self.fc1(out))
        out = self.relu4(self.fc2(out))
        out = self.fc3(out)
        return out
```
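As a quick sanity check (an illustrative snippet, not part of the example above), the model can be run on a random input; the 16 * 5 * 5 flatten implies a 3x32x32 input, as in CIFAR-10:

```python
import torch

model = QuantLeNet()
out = model(torch.randn(1, 3, 32, 32))  # random 3x32x32 input batch
print(out.shape)  # torch.Size([1, 10])
```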
Alessandro Pappalardo (@volcacius) @ Xilinx Research Labs.