vegans

by unit8co

unit8co /vegans

A library providing various existing GANs in PyTorch.

413 Stars 30 Forks Last release: Not found MIT License 103 Commits 0 Releases

Available items

No Items, yet!

The developer of this repository has not created any items for sale yet. Need a bug fixed? Help with integration? A different license? Create a request here:

VeGANs

A library to easily train various existing GANs (Generative Adversarial Networks) in PyTorch.

This library targets mainly GAN users, who want to use existing GAN training techniques with their own generators/discriminators. However researchers may also find the GAN base class useful for quicker implementation of new GAN training techniques.

The focus is on simplicity and providing reasonable defaults.

How to install

You need python 3.5 or above. Then:

pip install vegans

How to use

The basic idea is that the user provides discriminator and generator networks, and the library takes care of training them in a selected GAN setting: ``` from vegans import WGAN from vegans.utils import plotlosses, plotimage_samples

netD = ### Your discriminator/critic (torch.nn.Module) netG = ### Your generator (torch.nn.Module) dataloader = ### Your dataloader (torch.utils.data.DataLoader)

Build a Wasserstein GAN

gan = WGAN(netG, netD, dataloader, nr_epochs=20)

train it

gan.train()

vizualise results

imglist, Dlosses, Glosses = gan.gettrainingresults() plotlosses(Glosses, Dlosses) plotimagesamples(img_list, 50) ```

You can currently use the following GANs: *

MMGAN
: Classic minimax GAN, in its non-saturated version *
WGAN
: Wasserstein GAN *
WGANGP
: Wasserstein GAN with gradient penalty *
BEGAN
: Boundary Equilibrium enforcing GAN

Slightly More Details:

All of these GAN objects inherit from a

GAN
base class. When building any such GAN, you must give in argument a generator and discriminator networks (some
torch.nn.Module
), as well as a
torch.utils.data.DataLoader
. In addition, you can specify some parameters supported by all GAN implementations: *
optimizer_D
and
optimizer_G
: some PyTorch optimizers (from
torch.optim
) for the discriminator and generator networks. By defaults those are set with default optimization parameters suggested in the original papers. *
nr_epochs
: the number of epochs (default: 5) *
nz
: size of the noise vector (input of the generator) - by default
nz=100
. *
save_every
: VeGANs will store some samples produced by the generator every
save_every
iteration. Default: 500 *
fixed_noise_size
: The number of samples to save (from fixed noise vectors) *
print_every
: The number of iterations between printing training progress. Default: 50

Finally, when calling

train()
you can specify some parameters specific to each GAN. For example, for the Wasserstein GAN we can do:
gan = WGAN(netG, netD, dataloader)
gan.train(clip_value=0.1)
This will train a Wasserstein GAN with clipping values of
0.1
(instead of the default
0.01
).

If you are researching new GAN training algorithms, you may find it useful to inherit from the

GAN
base class.

Learn more:

Currently the best way to learn more about how to use VeGANs is to have a look at the example notebooks. You can start with this simple example showing how to sample from a univariate Gaussian using a GAN. Alternatively, can run example scripts.

Contribute

PRs and suggestions are welcome. Look here for more details on the setup.

Credits

Some of the code has been inspired by some existing GAN implementations: * https://github.com/eriklindernoren/PyTorch-GAN * https://github.com/martinarjovsky/WassersteinGAN * https://pytorch.org/tutorials/beginner/dcganfacestutorial.html

We use cookies. If you continue to browse the site, you agree to the use of cookies. For more information on our use of cookies please see our Privacy Policy.