# Disentanglement library for PyTorch
PyTorch implementation of disentanglement algorithms for Variational Autoencoders. This library was developed as a contribution to the Disentanglement Challenge of NeurIPS 2019.
If the library helped your research, consider citing the corresponding submission of the NeurIPS 2019 Disentanglement Challenge:
```bibtex
@article{abdiDisentanglementPytorch,
  author  = {Amir H. Abdi and Purang Abolmaesumi and Sidney Fels},
  title   = {Variational Learning with Disentanglement-PyTorch},
  journal = {arXiv preprint arXiv:1912.05184},
  year    = {2019}
}
```
The following algorithms are implemented:
- VAE
- β-VAE (Understanding disentangling in β-VAE)
- Info-VAE (InfoVAE: Information Maximizing Variational Autoencoders)
- Beta-TCVAE (Isolating Sources of Disentanglement in Variational Autoencoders)
- DIP-VAE I & II (Variational Inference of Disentangled Latent Concepts from Unlabeled Observations)
- Factor-VAE (Disentangling by Factorising)
- CVAE (Learning Structured Output Representation using Deep Conditional Generative Models)
- IFCVAE (Adversarial Information Factorization)
Note: everything is modular; you can mix and match neural architectures and algorithms. Multiple loss terms can also be included via the `--loss_terms` argument, each with its respective weight, which makes it possible to combine a set of disentanglement algorithms for representation learning.
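For instance, a run that mixes two of the loss terms listed further below on top of BetaVAE might look like the following sketch (the dataset, path, and the particular combination of loss terms are illustrative assumptions, not a recommended configuration):

```bash
# Illustrative only: combine the FactorVAE and Beta-TCVAE loss terms on a BetaVAE backbone.
# The dataset name and path are placeholders; see the flag descriptions below for details.
python main.py \
    --alg=BetaVAE \
    --loss_terms FACTORVAE BetaTCVAE \
    --dset_name=dsprites \
    --dset_dir=/path/to/datasets
```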
Install the requirements:
```bash
pip install -r requirements.txt
```

or build a conda environment:

```bash
conda env create -f environment.yml
```
The library visualizes the reconstructed images and the traversed latent spaces and saves them as static frames as well as animated GIFs. It also extensively uses the web-based Weights & Biases toolkit for logging and visualization purposes.
To train a model, run:

```bash
python main.py [[--ARG ARG_VALUE] ...]
```

or

```bash
bash scripts/SCRIPT_NAME
```

Important flags:
- `--alg`: The main formulation for training. \
**Values**: AE (AutoEncoder), VAE (Variational AutoEncoder), BetaVAE, CVAE (Conditional VAE), IFCVAE (Information Factorization CVAE)
- `--loss_terms`: Extensions to the VAE algorithm are implemented as plug-ins to the original formulation. As a result, if the loss terms of two learning algorithms (e.g., A and B) are found to be compatible, they can be included in the objective function simultaneously by setting the flag to `--loss_terms A B`. The `--loss_terms` flag can be used with the VAE, BetaVAE, CVAE, and IFCVAE algorithms. \
**Values**: FACTORVAE, DIPVAEI, DIPVAEII, BetaTCVAE, INFOVAE
- `--evaluate_metric`: Metric(s) to use for disentanglement evaluation (see `scripts/aicrowd_challenge`). \
**Values**: mig, sap_score, irs, factor_vae_metric, dci, beta_vae_sklearn
For the complete list of arguments, please check the source.
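As a rough sketch of how these flags compose on the command line (all values below are placeholders chosen for illustration, not defaults of the library):

```bash
# Hypothetical invocation tying the documented flags together; values are placeholders.
python main.py \
    --alg=VAE \
    --loss_terms DIPVAEI \
    --evaluate_metric mig dci \
    --dset_name=dsprites \
    --dset_dir=/path/to/datasets
```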
To run the scripts:

1- Set the `--dset_dir` flag or the `$DISENTANGLEMENT_LIB_DATA` environment variable to the directory holding all the datasets (the former is given priority).

2- Set the `--dset_name` flag or the `$DATASET_NAME` environment variable to the name of the dataset (the former is given priority). The supported datasets are: celebA, dsprites (and DeepMind's variants: color, noisy, scream, introduced here), smallnorb, cars3d, mpi3d_toy, mpi3d_realistic, and mpi3d_real.
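A minimal sketch of the two ways to point the library at a dataset (the path is a placeholder):

```bash
# Option 1: environment variables (placeholder path).
export DISENTANGLEMENT_LIB_DATA=/path/to/datasets
export DATASET_NAME=dsprites
python main.py --alg=VAE

# Option 2: flags, which take priority over the environment variables.
python main.py --alg=VAE --dset_dir=/path/to/datasets --dset_name=dsprites
```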
Please check the repository of the mpi3d datasets for their license agreements and consider citing their work.
Currently, there are two dataloaders in place:

- One handles labels for semi-supervised and conditional (class-aware) training (e.g., CVAE, IFCVAE), but only supports the celebA and dsprites_full datasets for now.
- The other leverages Google's implementations in disentanglement_lib and is based on the starter kit of the Disentanglement Challenge of NeurIPS 2019, hosted by AIcrowd.
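As an example of a run that would go through the label-aware dataloader (a hedged sketch; the flag values are placeholders, and conditioning specifics are configured through other arguments in the source):

```bash
# Conditional (class-aware) training uses the label-handling dataloader,
# which currently supports celebA and dsprites_full; values below are illustrative.
python main.py --alg=CVAE --dset_name=celebA --dset_dir=/path/to/datasets
```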
To use this code in the NeurIPS 2019 Disentanglement Challenge:

- For local experiments, run `source train_environ.sh NAME_OF_DATASET_TO_TEST`.
- Set `--aicrowd_challenge=true` in your bash file.
- Use `--evaluate_metric mig sap_score irs factor_vae_metric dci` to assess the progression of disentanglement metrics during training.
- Point `run.sh` to your highest performing configuration.
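Putting those steps together, a local challenge run could look roughly like the following sketch (the dataset choice and placing the flags directly on the command line rather than in a bash file are assumptions):

```bash
# Sketch of a local NeurIPS 2019 challenge run; mpi3d_toy is only an example dataset.
source train_environ.sh mpi3d_toy
python main.py \
    --alg=VAE \
    --aicrowd_challenge=true \
    --evaluate_metric mig sap_score irs factor_vae_metric dci
```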
| Method | Latent traversal visualization |
| ----- | ----- |
| VAE | *(GIF omitted)* |
| FactorVAE | *(GIF omitted)* |
| CVAE (conditioned on shape) | Right-most item is traversing the condition *(GIF omitted)* |
| IFCVAE (factorized on shape) | Right-most factor is enforced to encode the shape *(GIF omitted)* |
| BetaTCVAE | *(GIF omitted)* |
Any contributions, especially around implementing more disentanglement algorithms, are welcome. Feel free to submit bug reports, feature requests, or questions as issues, or contact me directly via email at [email protected]