Brevitas: quantization-aware training in PyTorch
Brevitas is a PyTorch research library for quantization-aware training (QAT).
Note from the author: I'll be presenting a poster about Brevitas at the 2021 PyTorch Ecosystem Day. Find me during the morning session at virtual poster B4.
Brevitas is currently under active development. Documentation, examples, and pretrained models will be progressively released.
Please note that Brevitas is a research project and not an official Xilinx product.
If you like this project, please consider starring (⭐) this repo, as it is the simplest and best way to support it.
If you have issues or comments, or are just looking for advice on training quantized neural networks, you can open an issue or a discussion, or chat with us in our gitter channel.
You can install the latest release from PyPI:
```bash
pip install brevitas
```
To get the very latest version, you can install directly from GitHub:
```bash
pip install git+https://github.com/Xilinx/brevitas.git
```
Brevitas implements a set of building blocks at different levels of abstraction to model a reduced-precision hardware data-path at training time. It provides a platform both for researchers interested in implementing new quantization-aware training techniques and for practitioners interested in applying current techniques to their models, with the aim of bridging the gap between research and industry around quantization.
Brevitas supports a super-set of quantization schemes implemented across various frameworks and compilers under a single unified API.
For certain combinations of layers and types of quantization, inference acceleration is supported by exporting to FINN, onnxruntime, PyTorch's own quantized inference operators, TVM (through the PyTorch export flow), and PyXIR.
Brevitas has been successfully adopted both in various research projects as well as in large-scale commercial deployments targeting custom accelerators running on Xilinx FPGAs. The general quantization style implemented is affine quantization, with a focus on uniform quantization. Non-uniform quantization is currently not supported out-of-the-box.
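As a reference point for the terminology, here is a minimal pure-Python sketch of uniform affine quantization of a single value: a real range is mapped onto an evenly spaced integer grid through a scale and an integer zero-point. The function names are hypothetical, and this is not Brevitas's implementation, which works on tensors and supports many ways of deriving or learning the scale.

```python
def affine_qparams(x_min, x_max, bit_width=8):
    """Scale and zero-point mapping [x_min, x_max] onto the unsigned grid [0, 2**bit_width - 1]."""
    qmin, qmax = 0, 2 ** bit_width - 1
    # Extend the range to include 0.0 so that it stays exactly representable.
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    scale = (x_max - x_min) / (qmax - qmin)
    zero_point = int(round(qmin - x_min / scale))
    return scale, zero_point, qmin, qmax


def quantize(x, scale, zero_point, qmin, qmax):
    # Round to the nearest grid point (Python rounds ties to even), then clamp.
    return max(qmin, min(qmax, round(x / scale) + zero_point))


def dequantize(q, scale, zero_point):
    return scale * (q - zero_point)


scale, zp, qmin, qmax = affine_qparams(-1.0, 3.0)
print(scale, zp)
print(dequantize(quantize(0.0, scale, zp, qmin, qmax), scale, zp))  # 0.0 is exact
```

Within the representable range, the round-trip error of any value is at most half a quantization step (`scale / 2`).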
Brevitas serves various types of users and end goals. To showcase some of Brevitas's features, we consider different scenarios for the quantization of a classic neural network, LeNet-5.
Let's say we are interested in assessing how well the model does with 3-bit weights on CIFAR10 classification. For the purpose of this tutorial we will skip any detail about how to perform training, as training a neural network with Brevitas is no different from training any other neural network in PyTorch.
`brevitas.nn` provides quantized layers that can be used in place of and/or mixed with traditional `torch.nn` layers. In this case, then, we import `brevitas.nn.QuantConv2d` and `brevitas.nn.QuantLinear` in place of their PyTorch variants, and we specify `weight_bit_width=3`. For ReLU and max-pool, we leverage the usual `torch.nn.ReLU` and `torch.nn.functional.max_pool2d`.
The result is the following:
```python
from torch import nn
from torch.nn import Module
import torch.nn.functional as F

import brevitas.nn as qnn


class QuantWeightLeNet(Module):
    def __init__(self):
        super(QuantWeightLeNet, self).__init__()
        self.conv1 = qnn.QuantConv2d(3, 6, 5, weight_bit_width=3)
        self.relu1 = nn.ReLU()
        self.conv2 = qnn.QuantConv2d(6, 16, 5, weight_bit_width=3)
        self.relu2 = nn.ReLU()
        # 16 channels x 5 x 5 spatial positions for 32x32 CIFAR10 inputs
        self.fc1 = qnn.QuantLinear(16 * 5 * 5, 120, bias=True, weight_bit_width=3)
        self.relu3 = nn.ReLU()
        self.fc2 = qnn.QuantLinear(120, 84, bias=True, weight_bit_width=3)
        self.relu4 = nn.ReLU()
        self.fc3 = qnn.QuantLinear(84, 10, bias=False, weight_bit_width=3)

    def forward(self, x):
        out = self.relu1(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = self.relu2(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = out.reshape(out.shape[0], -1)
        out = self.relu3(self.fc1(out))
        out = self.relu4(self.fc2(out))
        out = self.fc3(out)
        return out


quant_weight_lenet = QuantWeightLeNet()

# ... training ...
```
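To get a feel for what `weight_bit_width=3` means, the following sketch fake-quantizes (quantizes, then dequantizes) a list of weights. It assumes a symmetric, per-tensor, narrow-range scheme whose scale comes from the largest absolute weight. That is a common choice, but Brevitas's default weight quantizer may derive or learn its scale differently, and `fake_quant_weights` is a hypothetical helper.

```python
import random


def fake_quant_weights(weights, bit_width=3):
    """Symmetric per-tensor quantize-dequantize with a narrow integer range."""
    qmax = 2 ** (bit_width - 1) - 1  # 3 for 3-bit weights
    qmin = -qmax                     # narrow range: [-3, 3], i.e. 7 levels
    scale = max(abs(w) for w in weights) / qmax
    return [scale * max(qmin, min(qmax, round(w / scale))) for w in weights], scale


random.seed(0)
weights = [random.gauss(0.0, 0.1) for _ in range(1000)]
q_weights, scale = fake_quant_weights(weights)
print(len(set(q_weights)), "distinct values")  # at most 7
```

However many weights a layer has, with 3 bits each one is snapped to one of a handful of levels; training with quantization in the loop lets the network adapt to that constraint.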
At the end of training the model is going to have a certain train and test accuracy. For users interested in simply evaluating how well their models do with quantization in the loop, without actually deploying them, that might be the end of it.
For those users who are instead interested in deploying their quantized models, the idea would obviously be to gain some kind of advantage from quantization. In the case of weight quantization, the advantage would be a smaller model size. However, if we saved the model state with `torch.save(quant_weight_lenet.state_dict(), 'qw_lenet.pt')`, we would notice that it consumes the same amount of memory as its floating-point variant. That is because Brevitas is not concerned with deploying quantized models efficiently on its own. In order to deploy the model efficiently, we have to export it to an inference framework/toolchain first.
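The saving that a deployment toolchain could realize is easy to estimate. The sketch below counts the parameters of this LeNet-5 variant, assuming 3x32x32 CIFAR10 inputs (so `fc1` sees 16 * 5 * 5 = 400 features), and compares float32 storage against an idealized layout where 3-bit weights are tightly bit-packed and biases stay in float32. Scale factors and any container or metadata overhead are ignored; the layer table is our own reconstruction.

```python
import math

# (input channels/features, output channels/features, kernel elements, has_bias)
LAYERS = {
    "conv1": (3, 6, 5 * 5, True),
    "conv2": (6, 16, 5 * 5, True),
    "fc1": (16 * 5 * 5, 120, 1, True),
    "fc2": (120, 84, 1, True),
    "fc3": (84, 10, 1, False),
}


def storage_bytes(weight_bits, bias_bits=32):
    """Bytes needed to store all weights and biases at the given bit widths."""
    n_weights = sum(i * o * k for i, o, k, _ in LAYERS.values())
    n_biases = sum(o for _, o, _, has_bias in LAYERS.values() if has_bias)
    return math.ceil((n_weights * weight_bits + n_biases * bias_bits) / 8)


float_bytes = storage_bytes(weight_bits=32)
packed_bytes = storage_bytes(weight_bits=3)
print(float_bytes, packed_bytes)  # 247984 24068, roughly a 10x reduction
```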
Being a research training library that informs the development of inference toolchains, Brevitas supports more quantization schemes than can currently be accelerated efficiently by the supported inference frameworks. A neural network with 3-bit weights and floating-point activations is one of those scenarios that is currently hard to take advantage of in practice. In order to make it practical, we want to quantize activations and biases too.
We decide to quantize activations to 4 bits and biases to 8 bits. In order to do so, we replace `torch.nn.ReLU` with `brevitas.nn.QuantReLU`, specifying `bit_width=4`. For bias quantization, we import the 8-bit bias quantizer `Int8Bias` from `brevitas.quant` and set it appropriately. Additionally, in order to quantize the very first input, we introduce a `brevitas.nn.QuantIdentity` at the beginning of the network. The end result is the following:
```python
from torch.nn import Module
import torch.nn.functional as F

import brevitas.nn as qnn
from brevitas.quant import Int8Bias as BiasQuant


class LowPrecisionLeNet(Module):
    def __init__(self):
        super(LowPrecisionLeNet, self).__init__()
        self.quant_inp = qnn.QuantIdentity(
            bit_width=4, return_quant_tensor=True)
        self.conv1 = qnn.QuantConv2d(
            3, 6, 5, weight_bit_width=3, bias_quant=BiasQuant, return_quant_tensor=True)
        self.relu1 = qnn.QuantReLU(
            bit_width=4, return_quant_tensor=True)
        self.conv2 = qnn.QuantConv2d(
            6, 16, 5, weight_bit_width=3, bias_quant=BiasQuant, return_quant_tensor=True)
        self.relu2 = qnn.QuantReLU(
            bit_width=4, return_quant_tensor=True)
        self.fc1 = qnn.QuantLinear(
            16 * 5 * 5, 120, bias=True, weight_bit_width=3, bias_quant=BiasQuant, return_quant_tensor=True)
        self.relu3 = qnn.QuantReLU(
            bit_width=4, return_quant_tensor=True)
        self.fc2 = qnn.QuantLinear(
            120, 84, bias=True, weight_bit_width=3, bias_quant=BiasQuant, return_quant_tensor=True)
        self.relu4 = qnn.QuantReLU(
            bit_width=4, return_quant_tensor=True)
        self.fc3 = qnn.QuantLinear(
            84, 10, bias=False, weight_bit_width=3)

    def forward(self, x):
        out = self.quant_inp(x)
        out = self.relu1(self.conv1(out))
        out = F.max_pool2d(out, 2)
        out = self.relu2(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = out.reshape(out.shape[0], -1)
        out = self.relu3(self.fc1(out))
        out = self.relu4(self.fc2(out))
        out = self.fc3(out)
        return out
```
Note a few things:
- Compared to the previous scenario, we now set `return_quant_tensor=True` in every quantized layer except the last one, to propagate a `QuantTensor` across them. This informs each receiving layer of how activations have been quantized at the output of its predecessor, which in turn enables more functionality, such as the kind of bias quantization implemented here.
- `torch` operations that are algorithmically invariant to quantization, such as max-pool, can propagate a `QuantTensor` through them without extra changes. This is supported in PyTorch 1.5.0 and later versions.
- By default `QuantReLU` is stateful, so there is a difference between instantiating one `QuantReLU` that is called multiple times, and instantiating multiple `QuantReLU` that are each called once.
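Why bias quantization benefits from a `QuantTensor` can be sketched in a few lines. Assuming (as is common for integer-only datapaths, and as we assume here for `Int8Bias`) that the bias scale is the product of the input scale and the weight scale, the integer bias can be added directly to the integer accumulator of a dot product, which is why the receiving layer needs to know how its input was quantized. All names and values below are hypothetical.

```python
def quantize_bias(bias, input_scale, weight_scale, bit_width=8):
    """Quantize a bias onto the accumulator scale input_scale * weight_scale."""
    scale = input_scale * weight_scale
    qmin, qmax = -(2 ** (bit_width - 1)), 2 ** (bit_width - 1) - 1
    return max(qmin, min(qmax, round(bias / scale))), scale


def integer_dot_with_bias(q_inputs, q_weights, q_bias):
    """Integer-only accumulation: products and bias already share one scale."""
    return sum(qx * qw for qx, qw in zip(q_inputs, q_weights)) + q_bias


input_scale, weight_scale = 0.5, 0.25
q_inputs, q_weights = [3, 7, 1], [2, -1, 3]  # 4-bit unsigned inputs, 3-bit signed weights
q_bias, acc_scale = quantize_bias(0.6, input_scale, weight_scale)

# Rescale the integer result once, at the end.
y_int = acc_scale * integer_dot_with_bias(q_inputs, q_weights, q_bias)
# Floating-point reference computed from the dequantized inputs and weights.
y_ref = sum(input_scale * qx * weight_scale * qw for qx, qw in zip(q_inputs, q_weights)) + 0.6
print(q_bias, y_int, y_ref)
```

The only extra error comes from rounding the bias, which is at most half an accumulator step (`acc_scale / 2`).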
The network defined above can be mapped to a low-precision integer-only dataflow accelerator implemented on a Xilinx FPGA by exporting it to FINN through a custom ONNX-based representation. We can invoke the FINN export manager to do so:
```python
from brevitas.export import FINNManager

low_precision_lenet = LowPrecisionLeNet()

# ... training ...

FINNManager.export(low_precision_lenet, input_shape=(1, 3, 32, 32), export_path='finn_lenet.onnx')
```
Brevitas also supports targeting other inference frameworks that support a mixture of floating-point and quantized layers, such as onnxruntime and PyTorch itself. In this case, `return_quant_tensor` also tells the export manager whether the output of a layer should be dequantized to floating-point or not. Additionally, since low-precision acceleration is not yet supported on those target platforms, we target 7-bit and 8-bit quantization:
```python
from torch import nn
import torch.nn.functional as F

import brevitas.nn as qnn
from brevitas.quant import Int8WeightPerTensorFloat as SignedWeightQuant
from brevitas.quant import ShiftedUint8WeightPerTensorFloat as UnsignedWeightQuant
from brevitas.quant import ShiftedUint8ActPerTensorFloat as ActQuant
from brevitas.quant import Int8Bias as BiasQuant


class ReducedRangeActQuant(ActQuant):
    bit_width = 7


class MixedFloatQuantLeNet(nn.Module):
    def __init__(self, bias_quant=True, reduced_act_quant=False, weight_signed=False):
        super(MixedFloatQuantLeNet, self).__init__()
        bias_quant = BiasQuant if bias_quant else None
        act_quant = ReducedRangeActQuant if reduced_act_quant else ActQuant
        weight_quant = SignedWeightQuant if weight_signed else UnsignedWeightQuant
        self.conv1 = qnn.QuantConv2d(
            3, 6, 5, input_quant=act_quant, weight_quant=weight_quant,
            output_quant=act_quant, bias_quant=bias_quant, return_quant_tensor=True)
        self.relu1 = nn.ReLU()
        self.conv2 = qnn.QuantConv2d(
            6, 16, 5, weight_quant=weight_quant, output_quant=act_quant,
            bias_quant=bias_quant, return_quant_tensor=True)
        self.relu2 = nn.ReLU()
        # 16 channels x 5 x 5 spatial positions for 32x32 inputs
        self.fc1 = qnn.QuantLinear(
            16 * 5 * 5, 120, bias=True, weight_quant=weight_quant,
            bias_quant=bias_quant, output_quant=act_quant)
        self.relu3 = nn.ReLU()
        # The last two layers stay in floating-point
        self.fc2 = nn.Linear(120, 84, bias=True)
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(84, 10, bias=False)

    def forward(self, x):
        out = self.relu1(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = self.relu2(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = out.reshape(out.shape[0], -1)
        out = self.relu3(self.fc1(out))
        out = self.relu4(self.fc2(out))
        out = self.fc3(out)
        return out
```
Compared to the previous case, there are a few differences:
- While in the previous example the default scaled integer quantizer was adopted for weights and activations, and only a bit-width was specified, here the network is completely parametrized by standalone quantizers taken from `brevitas.quant`. This is to match the quantization schemes supported by the inference frameworks we are targeting.
- We define a 7-bit activation quantizer by inheriting from an existing one and setting `bit_width = 7`. This is an alternative but equivalent syntax to setting the attribute as a keyword argument.
- In this scenario activations are quantized before ReLU by setting output quantizers on `QuantConv2d` and `QuantLinear`. Again, this matches how the frameworks we are targeting work. Because of this, we revert to using the standard `torch.nn.ReLU`.
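Quantizing before the ReLU works well as long as 0.0 is exactly representable, which affine quantizers with an integer zero-point (such as the shifted unsigned ones used here) provide by construction. Under that assumption, applying ReLU to the dequantized value is the same as clamping the integer value at the zero-point, so the ReLU can run in the integer domain. The pure-Python check below uses arbitrary, hypothetical quantization parameters.

```python
def dequantize(q, scale, zero_point):
    return scale * (q - zero_point)


def integer_relu(q, zero_point):
    # Codes below the zero-point encode negative reals: clamp them to the code for 0.0.
    return max(q, zero_point)


scale, zero_point = 0.02, 64  # arbitrary uint8 quantization parameters
same = all(
    max(0.0, dequantize(q, scale, zero_point))
    == dequantize(integer_relu(q, zero_point), scale, zero_point)
    for q in range(256)
)
print(same)  # True
```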
After training, the above network can then be exported to an ONNX representation that complies with the standard opset:
```python
from brevitas.export import StdONNXManager

onnx_lenet = MixedFloatQuantLeNet()

# ... training ...

StdONNXManager.export(onnx_lenet, input_shape=(1, 3, 32, 32), export_path='onnx_lenet.onnx')
```
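In the standard ONNX opset, quantization is expressed with operators such as `QuantizeLinear` and `DequantizeLinear`. As a reference for what a compliant runtime computes, here is a pure-Python sketch of their per-tensor semantics on scalars for 8-bit types, following the ONNX operator specification: divide by the scale, round half to even, add the zero-point, and saturate; dequantization is the inverse affine map. Which exact operators Brevitas emits in this flow is not covered by this sketch.

```python
def quantize_linear(x, scale, zero_point, signed=False):
    """Per-tensor QuantizeLinear on a scalar: saturate(round(x / scale) + zero_point)."""
    lo, hi = (-128, 127) if signed else (0, 255)
    # Python's round() rounds ties to even, matching the rounding mode in the spec.
    return max(lo, min(hi, round(x / scale) + zero_point))


def dequantize_linear(q, scale, zero_point):
    """Per-tensor DequantizeLinear on a scalar: (q - zero_point) * scale."""
    return (q - zero_point) * scale


print(quantize_linear(2.5, 1.0, 0), quantize_linear(3.5, 1.0, 0))     # 2 4 (ties to even)
print(quantize_linear(300.0, 1.0, 0), quantize_linear(-5.0, 1.0, 0))  # 255 0 (saturation)
print(dequantize_linear(130, 0.5, 128))                                # 1.0
```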
The generated output model can then be accelerated through any ONNX-compliant inference framework, such as onnxruntime:
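```python
import numpy as np
import onnxruntime as rt

sess = rt.InferenceSession('onnx_lenet.onnx')
input_name = sess.get_inputs()[0].name
# The model expects float32 inputs, while np.random.randn returns float64
inp = np.random.randn(1, 3, 32, 32).astype(np.float32)
pred_onx = sess.run(None, {input_name: inp})[0]
```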
With the same network definition it's also possible to target PyTorch's own quantized inference operators:
```python
from brevitas.export import PytorchQuantManager

pt_lenet = MixedFloatQuantLeNet(bias_quant=False, reduced_act_quant=True, weight_signed=True)

# ... training ...

traced_pt_lenet = PytorchQuantManager.export(pt_lenet, input_shape=(1, 3, 32, 32))
```
Note how the network was parametrized to reflect a few of the differences between PyTorch quantized inference operators and the standard ONNX opset:
- PyTorch doesn't support explicit bias quantization, while standard ONNX does.
- We pick an 8-bit signed symmetric weight quantizer for PyTorch (the one used by default for weight quantization in Brevitas), while for ONNX we go for an unsigned asymmetric one, since support for it in onnxruntime is more mature.
- With the FBGEMM x86 backend (which is enabled by default), PyTorch recommends using 7-bit activations to avoid overflow.
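The overflow concern in the last point can be checked with a bit of arithmetic. Our assumption about the mechanism (not stated above): on x86, FBGEMM's 8-bit kernels rely on an instruction (AVX2 `vpmaddubsw`) that multiplies unsigned 8-bit activations by signed 8-bit weights and sums adjacent pairs of products into a saturating 16-bit integer. Under that assumption, the worst-case pair sum overflows int16 with 8-bit activations but not with 7-bit ones:

```python
INT16_MAX = 2 ** 15 - 1  # 32767


def worst_case_pair_sum(act_bits, weight_bits=8):
    """Largest magnitude of a0 * w0 + a1 * w1 for unsigned activations and signed weights."""
    a_max = 2 ** act_bits - 1       # largest unsigned activation code
    w_max = 2 ** (weight_bits - 1)  # |-128| for int8 weights
    return 2 * a_max * w_max


for bits in (8, 7):
    total = worst_case_pair_sum(bits)
    print(bits, total, total <= INT16_MAX)
# 8 65280 False
# 7 32512 True
```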
The PyTorch export flow generates a TorchScript model, which means that the network can also easily be passed to any external toolchain that supports TorchScript, such as TVM:
```python
from tvm import relay

input_name = "input"
# The input shape must match the one used when exporting the model
input_shapes = [(input_name, (1, 3, 32, 32))]
mod, params = relay.frontend.from_pytorch(traced_pt_lenet, input_shapes)
```
Thanks to their flexibility, Xilinx FPGAs support a variety of neural network hardware implementations. DPUs are a family of fixed-point neural network accelerators officially supported as part of the Vitis-AI toolchain. Currently, Brevitas supports training for DPUv1 and DPUv2 by leveraging 8-bit fixed-point quantizers and a custom ONNX-based export flow that targets PyXIR:
```python
from torch import nn
import torch.nn.functional as F

import brevitas.nn as qnn
from brevitas.quant import Int8WeightPerTensorFixedPoint as WeightQuant
from brevitas.quant import Int8ActPerTensorFixedPoint as ActQuant
from brevitas.quant import Int8BiasPerTensorFixedPointInternalScaling as BiasQuant
from brevitas.export import DPUv1Manager, DPUv2Manager


class DPULeNet(nn.Module):
    def __init__(self):
        super(DPULeNet, self).__init__()
        self.conv1 = qnn.QuantConv2d(
            3, 6, 5, input_quant=ActQuant, weight_quant=WeightQuant,
            output_quant=ActQuant, bias_quant=BiasQuant, return_quant_tensor=True)
        self.relu1 = nn.ReLU()
        self.conv2 = qnn.QuantConv2d(
            6, 16, 5, weight_quant=WeightQuant, output_quant=ActQuant,
            bias_quant=BiasQuant, return_quant_tensor=True)
        self.relu2 = nn.ReLU()
        # 16 channels x 5 x 5 spatial positions for 32x32 inputs
        self.fc1 = qnn.QuantLinear(
            16 * 5 * 5, 120, bias=True, weight_quant=WeightQuant, bias_quant=BiasQuant,
            output_quant=ActQuant, return_quant_tensor=True)
        self.relu3 = nn.ReLU()
        self.fc2 = qnn.QuantLinear(
            120, 84, bias=True, weight_quant=WeightQuant, bias_quant=BiasQuant,
            output_quant=ActQuant, return_quant_tensor=True)
        self.relu4 = nn.ReLU()
        self.fc3 = qnn.QuantLinear(
            84, 10, bias=False, weight_quant=WeightQuant, output_quant=ActQuant)

    def forward(self, x):
        out = self.relu1(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = self.relu2(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = out.reshape(out.shape[0], -1)
        out = self.relu3(self.fc1(out))
        out = self.relu4(self.fc2(out))
        out = self.fc3(out)
        return out


dpu_lenet = DPULeNet()

# ... training ...

DPUv1Manager.export(dpu_lenet, input_shape=(1, 3, 32, 32), export_path='dpuv1_lenet.onnx')
DPUv2Manager.export(dpu_lenet, input_shape=(1, 3, 32, 32), export_path='dpuv2_lenet.onnx')
```
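What distinguishes fixed-point quantizers is that scale factors are restricted to powers of two, so rescaling reduces to bit shifts in hardware. Below is a minimal sketch of one way to pick such a scale for signed 8-bit values: choose the largest number of fractional bits that still fits the largest magnitude. The helper is hypothetical, and Brevitas's fixed-point quantizers may derive the shift differently (for example, from a statistics-based or learned scale).

```python
import math


def fixed_point_quant(values, bit_width=8):
    """Signed quantization with a power-of-two scale of 2**-frac_bits."""
    qmin, qmax = -(2 ** (bit_width - 1)), 2 ** (bit_width - 1) - 1
    max_abs = max(abs(v) for v in values)
    # Largest frac_bits such that max_abs * 2**frac_bits still fits in qmax.
    frac_bits = math.floor(math.log2(qmax / max_abs))
    q = [max(qmin, min(qmax, round(v * 2 ** frac_bits))) for v in values]
    return q, frac_bits


q, frac_bits = fixed_point_quant([1.7, -0.3, 0.05])
print(q, frac_bits)  # [109, -19, 3] 6, i.e. scale = 2**-6
```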
Documentation is currently a work-in-progress.
A series of tutorials is being added to the notebooks folder. They are designed to walk users through some of the fundamentals of Brevitas, and as such they are meant to be followed in order.
A general description of how Brevitas works can be found under the ARCHITECTURE.md file.
Brevitas exposes a few settings that can be toggled through environment variables.
`BREVITAS_JIT=1` (default: `0`): Enables compilation of the available built-in quantizers through the TorchScript just-in-time compiler, together with a small native .cpp extension for the straight-through estimator functions. This can provide a speed-up and/or memory savings at training time. Please note that under certain circumstances this has been shown to produce results that diverge from those with `BREVITAS_JIT=0`. Use at your own risk.
`BREVITAS_VERBOSE=1` (default: `0`): Enables verbose compilation of the native extension for the straight-through estimator functions.
`BREVITAS_IGNORE_MISSING_KEYS=1` (default: `0`): Ignores errors related to missing `state_dict` values when loading a pre-trained model on top of a Brevitas model. This is typically enabled when re-training from a floating-point checkpoint.
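Since these are plain environment variables, they can be exported in the shell that launches training or set from Python. Here is a minimal sketch of the latter. It assumes (our assumption, not stated above) that Brevitas reads the variables once when it is first imported, so they must be set before that import.

```python
import os

# Set the flags before the first `import brevitas` anywhere in the process.
os.environ["BREVITAS_JIT"] = "1"
os.environ["BREVITAS_IGNORE_MISSING_KEYS"] = "1"

# import brevitas  # import only after the flags above have been set
```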
Q: PyTorch supports quantization-aware training. Why should I use Brevitas?
A: Quantization in PyTorch is designed to target two specific CPU backends (FBGEMM and qnnpack). Export of quantized operators to standard ONNX is not supported (only to a custom ONNX-based format supported by Caffe2).
Brevitas is designed as a platform to implement novel quantization algorithms to target a variety of hardware backends adhering to a loose set of assumptions (i.e. uniform affine quantization).
Q: How can I train X/Y and run it on hardware W/Z? I can't find any documentation.
A: Brevitas is still sparsely documented. Until the situation improves, feel free to open an issue or ask on our gitter channel.
Q: Training with Brevitas is slow and/or I can't fit the same batch size as with floating-point training. Why? What can I do?
A: Quantization-aware training involves a lot of element-wise operations, which have low arithmetic intensity and make the computational graph more involved during backpropagation. As such, it typically ends up being slower and more resource-intensive than standard floating-point training.
Brevitas in particular is biased towards greater flexibility, at the cost of some training-time efficiency. The general principle is to trade off more complexity at training time for more efficiency at inference time.
To somewhat mitigate the slow-down, try enabling `BREVITAS_JIT` as described in the Settings section.
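To see where the extra element-wise work comes from, consider what fake quantization adds to every quantized tensor: a divide, round, clamp and multiply in the forward pass, and in the backward pass a gradient estimate for `round()`, whose true derivative is zero almost everywhere. The sketch below shows, on scalars, a common clipped straight-through estimator (STE) that treats rounding as identity inside the representable range. It is a generic illustration; the STE variants Brevitas implements may differ in detail.

```python
def fake_quant(x, scale, qmin, qmax):
    """Forward: quantize-dequantize, i.e. divide, round, clamp, multiply."""
    return scale * max(qmin, min(qmax, round(x / scale)))


def fake_quant_grad(x, scale, qmin, qmax, upstream_grad=1.0):
    """Backward with a clipped STE: pass the gradient through inside the
    representable range [qmin * scale, qmax * scale], zero it outside."""
    return upstream_grad if qmin * scale <= x <= qmax * scale else 0.0


scale, qmin, qmax = 0.1, -8, 7  # a signed 4-bit grid
print(fake_quant(0.26, scale, qmin, qmax))       # close to 0.3
print(fake_quant_grad(0.26, scale, qmin, qmax))  # 1.0
print(fake_quant_grad(5.0, scale, qmin, qmax))   # 0.0 (clipped)
```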
Q: Inference with Brevitas is slow. I thought the point of QAT was to make my model faster at inference time. What am I doing wrong?
A: Brevitas is concerned with modelling a reduced precision data-path, it does not provide inference-time acceleration on its own. To achieve acceleration, you should export your Brevitas model to a downstream toolchain / backend.
Brevitas can currently export to:
- FINN, for dataflow acceleration on Xilinx FPGAs.
- PyXIR (experimental), for DPU acceleration on Xilinx FPGAs.
- Standard ONNX (experimental), for acceleration with e.g. onnxruntime, or any other ONNX-compliant toolchain.
- PyTorch's quantized.functional operators (experimental), for acceleration through PyTorch itself, or any additional downstream toolchains supported by PyTorch (e.g. TVM).
Because Brevitas implements a super-set of the layers and datatypes supported by the various downstream toolchains and hardware platforms, each export flow supports only a certain subset of features, in ways that are not necessarily obvious. More examples and documentation will be released to illustrate the restrictions imposed by each target platform. As a general note, though, FINN is currently the only toolchain that supports acceleration of low bit-width datatypes.
Q: My (C/G/T)PU supports float16 / bfloat16 / bfloat19 training. Can I use it to train with Brevitas?
A: Datatypes other than float32 have not been tested at training time. That includes training on TPU / PyTorch/XLA. Work out which reduced-precision integers can reasonably fit in a reduced-precision floating-point format at training time, and use at your own risk.
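The math suggested above can be done as follows. A binary floating-point format with a p-bit significand (counting the implicit leading bit) represents every integer n with |n| <= 2**p exactly, so the question is whether the largest integer of a given bit width stays within that bound. This only checks storage of the integer values themselves, not that accumulations over many of them stay exact; the helper names are hypothetical.

```python
import struct

# Significand precision in bits, including the implicit leading bit.
SIGNIFICAND_BITS = {"float32": 24, "float16": 11, "bfloat16": 8}


def fits_exactly(int_bits, fmt, signed=True):
    """True if every integer of the given width is exactly representable in fmt
    (exponent-range limits are not a concern for these small widths)."""
    largest = 2 ** (int_bits - 1) if signed else 2 ** int_bits - 1
    return largest <= 2 ** SIGNIFICAND_BITS[fmt]


def to_float16(x):
    """Round-trip a Python float through IEEE 754 half precision."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


print(fits_exactly(8, "bfloat16"), fits_exactly(16, "float16"))  # True False
print(to_float16(2047.0), to_float16(2049.0))  # 2049 is not representable in float16
```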
Alessandro Pappalardo (@volcacius) @ Xilinx Research Labs. For private communications, you can reach me at alessand at nameofmy_employer dot com.
```bibtex
@software{brevitas,
  author = {Alessandro Pappalardo},
  title = {Xilinx/brevitas},
  publisher = {Zenodo},
  doi = {10.5281/zenodo.3333552},
  url = {https://doi.org/10.5281/zenodo.3333552}
}
```