mxnet-gan

by tqchen


Multi-GPU-enabled image generative models (GAN and DCGAN)

License: Apache License 2.0


MXNet GAN

An MXNet module implementation of multi-GPU-compatible generative models.
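As a rough illustration of the data-parallel scheme that multi-GPU training generally relies on: the batch is split along its first axis, each device computes gradients on its own shard, and the shard gradients are summed before the parameter update. MXNet's `mx.mod.Module` works this way when its `context` is a list of devices; whether `GANModule` accepts such a list is an assumption here, since the usage example passes a single `mx.gpu()`. The sketch is plain NumPy with a least-squares model standing in for the networks, and `data_parallel_grad` is a hypothetical helper, not part of mxgan.

```python
import numpy as np


def linear_grad(w, x, y):
    """Gradient of the summed squared error 0.5 * sum((x @ w - y) ** 2) w.r.t. w."""
    return x.T @ (x @ w - y)


def data_parallel_grad(w, x, y, num_devices):
    """Hypothetical helper: split the batch along axis 0 (one shard per device),
    compute each shard's gradient independently, then sum the results."""
    x_shards = np.array_split(x, num_devices, axis=0)
    y_shards = np.array_split(y, num_devices, axis=0)
    shard_grads = [linear_grad(w, xs, ys) for xs, ys in zip(x_shards, y_shards)]
    return np.sum(shard_grads, axis=0)


rng = np.random.default_rng(0)
x = rng.normal(size=(100, 8))  # batch of 100 samples, 8 features
y = rng.normal(size=(100, 1))
w = rng.normal(size=(8, 1))

full = linear_grad(w, x, y)                         # single-device gradient
split = data_parallel_grad(w, x, y, num_devices=4)  # simulated 4-device gradient
print(np.allclose(full, split))  # True
```

The two agree because the loss is a sum over samples, so the full-batch gradient equals the sum of the per-shard gradients; splitting the batch across devices changes where the work happens, not the update.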

List of Methods

  • Unsupervised Training
  • Semi-supervised Training
  • Minibatch discrimination
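Minibatch discrimination, introduced in Salimans et al., "Improved Techniques for Training GANs" (2016), lets the discriminator compare each sample with the rest of its batch, which helps it notice a generator that collapses to a few outputs. A minimal NumPy sketch of that paper's formulation: per-sample features f(x_i) of length A are projected through a learned tensor T of shape (A, B, C) into a matrix M_i of shape (B, C); for each of the B rows, the similarities exp(-||M_i,b - M_j,b||_1) are summed over the batch and appended to the features. How mxgan itself implements this layer is not shown in this README, so the function name and shapes below are assumptions.

```python
import numpy as np


def minibatch_discrimination(features, T):
    """Hypothetical layer. features: (N, A); T: (A, B, C).
    Returns (N, A + B): the input features with B batch-similarity statistics appended."""
    n, a = features.shape
    _, b, c = T.shape
    # Project every sample to a (B, C) matrix: M has shape (N, B, C).
    M = (features @ T.reshape(a, b * c)).reshape(n, b, c)
    # Pairwise L1 distances between samples, per row b: shape (N, N, B).
    l1 = np.abs(M[:, None, :, :] - M[None, :, :, :]).sum(axis=3)
    # Similarity exp(-L1), summed over the other samples j (including j == i).
    o = np.exp(-l1).sum(axis=1)  # (N, B)
    return np.concatenate([features, o], axis=1)


rng = np.random.default_rng(0)
feats = rng.normal(size=(16, 32))               # batch of 16, 32 features each
T = rng.normal(scale=0.1, size=(32, 5, 3))      # B = 5 kernels, C = 3
out = minibatch_discrimination(feats, T)
print(out.shape)  # (16, 37)
```

Since each sample is also compared with itself, every appended statistic is at least 1, and a batch of N identical samples yields exactly N for each kernel; that is the signal a collapsed generator gives away.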

Usage

import logging
import numpy as np
import mxnet as mx

from mxgan import module, generator, encoder, viz

def ferr(label, pred):
    # Discriminator error rate: predictions are thresholded at 0.5 and compared with the labels
    pred = pred.ravel()
    label = label.ravel()
    return np.abs(label - (pred > 0.5)).sum() / label.shape[0]

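lr = 0.0005
beta1 = 0.5
batch_size = 100
rand_shape = (batch_size, 100)        # shape of the random code (latent input) for the generator
num_epoch = 100
data_shape = (batch_size, 1, 28, 28)  # MNIST images: 1 channel, 28x28 pixels
context = mx.gpu()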

logging.basicConfig(level=logging.DEBUG, format='%(asctime)-15s %(message)s')
sym_gen = generator.dcgan28x28(oshape=data_shape, ngf=32, final_act="sigmoid")

gmod = module.GANModule(
    sym_gen,
    symbol_encoder=encoder.lenet(),
    context=context,
    data_shape=data_shape,
    code_shape=rand_shape)

gmod.init_params(mx.init.Xavier(factor_type="in", magnitude=2.34))

gmod.init_optimizer(
    optimizer="adam",
    optimizer_params={"learning_rate": lr, "wd": 0., "beta1": beta1})

data_dir = './../../mxnet/example/image-classification/mnist/'
train = mx.io.MNISTIter(
    image=data_dir + "train-images-idx3-ubyte",
    label=data_dir + "train-labels-idx1-ubyte",
    input_shape=data_shape[1:],
    batch_size=batch_size,
    shuffle=True)

metric_acc = mx.metric.CustomMetric(ferr)

for epoch in range(num_epoch):
    train.reset()
    metric_acc.reset()
    for t, batch in enumerate(train):
        gmod.update(batch)
        # Score generated samples against label 0 and real samples against label 1
        gmod.temp_label[:] = 0.0
        metric_acc.update([gmod.temp_label], gmod.outputs_fake)
        gmod.temp_label[:] = 1.0
        metric_acc.update([gmod.temp_label], gmod.outputs_real)

        if t % 100 == 0:
            logging.info("epoch: %d, iter %d, metric=%s", epoch, t, metric_acc.get())
            viz.imshow("gout", gmod.temp_outG[0].asnumpy(), 2)
            diff = gmod.temp_diffD[0].asnumpy()
            # Standardize to unit variance, then shift to a mean of 0.5 for display
            diff = (diff - diff.mean()) / diff.std() + 0.5
            viz.imshow("diff", diff)
            viz.imshow("data", batch.data[0].asnumpy(), 2)
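Because `metric_acc` is updated once with the generated outputs (label 0) and once with the real outputs (label 1), the logged metric should be the discriminator's error averaged over both halves, assuming `mx.metric.CustomMetric` averages its per-update results. The NumPy sketch below reproduces that calculation on made-up scores; `discriminator_error` is a hypothetical helper, not part of mxgan. A value near 0.5 means the discriminator is doing no better than chance at telling generated images from real ones.

```python
import numpy as np


def discriminator_error(pred_fake, pred_real, threshold=0.5):
    """Hypothetical helper: misclassification rate on one generated batch (target 0)
    and one real batch (target 1), with the two batches weighted equally."""
    err_fake = np.mean(np.ravel(pred_fake) > threshold)   # generated images judged real
    err_real = np.mean(np.ravel(pred_real) <= threshold)  # real images judged fake
    return (err_fake + err_real) / 2.0


pred_fake = np.array([0.1, 0.7, 0.3, 0.2])  # one generated image scored above 0.5
pred_real = np.array([0.9, 0.6, 0.4, 0.8])  # one real image scored at or below 0.5
print(discriminator_error(pred_fake, pred_real))  # 0.25
```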
