Elegy is a framework-agnostic Trainer interface for the Jax ecosystem.
For more information take a look at the Documentation.
Install Elegy using pip:
```bash
pip install elegy
```
For Windows users we recommend the Windows Subsystem for Linux 2 (WSL2), since jax does not support Windows natively yet.
Elegy's high-level API provides a very simple interface you can use by implementing the following steps:
**1.** Define the architecture inside a `Module`. We will use Flax Linen for this example:

```python
import flax.linen as nn
import jax


class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(300)(x)
        x = jax.nn.relu(x)
        x = nn.Dense(10)(x)
        return x
```
**2.** Create a `Model` from this module and specify additional things like losses, metrics, and optimizers:

```python
import elegy, optax

model = elegy.Model(
    module=MLP(),
    loss=[
        elegy.losses.SparseCategoricalCrossentropy(from_logits=True),
        elegy.regularizers.GlobalL2(l=1e-5),
    ],
    metrics=elegy.metrics.SparseCategoricalAccuracy(),
    optimizer=optax.rmsprop(1e-3),
)
```
**3.** Train the model using the `fit` method:

```python
model.fit(
    x=X_train,
    y=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[elegy.callbacks.TensorBoard("summaries")],
)
```
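Once trained, the same Keras-style interface can be used for evaluation and inference. The snippet below is a minimal sketch assuming `evaluate` and `predict` behave like their Keras counterparts; treat the exact argument names as assumptions:

```python
# a minimal sketch, assuming Keras-style `evaluate` and `predict`
# (argument names follow the `fit` call above)
logs = model.evaluate(x=X_test, y=y_test, batch_size=64)  # dict of loss/metric logs
print(logs)

y_pred = model.predict(x=X_test, batch_size=64)  # raw model outputs (logits here)
```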
Elegy's low-level API lets you define exactly what goes on during training, testing, and inference. Let's define the `test_step` to implement a linear classifier in pure jax:
**1.** Calculate our loss, logs, and states:

```python
import jax
import jax.numpy as jnp
import numpy as np
import elegy


class LinearClassifier(elegy.Model):
    # request parameters by name via dependency injection.
    # available names: x, y_true, sample_weight, class_weight, states, initializing
    def test_step(
        self,
        x,  # inputs
        y_true,  # labels
        states: elegy.States,  # model state
        initializing: bool,  # if True we should initialize our parameters
    ):
        rng: elegy.RNGSeq = states.rng

        # flatten + scale
        x = jnp.reshape(x, (x.shape[0], -1)) / 255

        # initialize or use existing parameters
        if initializing:
            w = jax.random.uniform(rng.next(), shape=[np.prod(x.shape[1:]), 10])
            b = jax.random.uniform(rng.next(), shape=[1])
        else:
            w, b = states.net_params

        # model
        logits = jnp.dot(x, w) + b

        # categorical crossentropy loss
        labels = jax.nn.one_hot(y_true, 10)
        loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)

        # metrics
        logs = dict(accuracy=accuracy, loss=loss)

        return loss, logs, states.update(net_params=(w, b))
```
**2.** Instantiate our `LinearClassifier` with an optimizer:

```python
model = LinearClassifier(
    optimizer=optax.rmsprop(1e-3),
)
```

**3.** Train the model using the `fit` method:

```python
model.fit(
    x=X_train,
    y=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[elegy.callbacks.TensorBoard("summaries")],
)
```
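The low-level API also covers inference, so a prediction step can be defined in the same style. The following is a minimal sketch that assumes a `pred_step` method receives the same dependency-injected arguments and returns a `(y_pred, states)` pair; treat the exact signature and return value as assumptions, not Elegy's documented contract:

```python
# a minimal sketch of inference via the low-level API; the exact
# `pred_step` signature and return value are assumptions
class LinearClassifier(elegy.Model):
    # ... test_step as defined above ...

    def pred_step(self, x, states: elegy.States):
        w, b = states.net_params  # parameters learned during training
        x = jnp.reshape(x, (x.shape[0], -1)) / 255
        logits = jnp.dot(x, w) + b
        return jnp.argmax(logits, axis=-1), states
```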
It is straightforward to integrate other functional JAX libraries with this low-level API:
```python
class LinearClassifier(elegy.Model):
    def test_step(self, x, y_true, states: elegy.States, initializing: bool):
        rng: elegy.RNGSeq = states.rng
        x = jnp.reshape(x, (x.shape[0], -1)) / 255

        # use Flax's functional API through `self.module`
        if initializing:
            logits, variables = self.module.init_with_output(
                {"params": rng.next(), "dropout": rng.next()}, x
            )
        else:
            variables = dict(params=states.net_params, **states.net_states)
            logits, variables = self.module.apply(
                variables, x, rngs={"dropout": rng.next()}, mutable=True
            )

        net_states, net_params = variables.pop("params")

        labels = jax.nn.one_hot(y_true, 10)
        loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)
        logs = dict(accuracy=accuracy, loss=loss)

        return loss, logs, states.update(net_params=net_params, net_states=net_states)
```
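For instance, a Flax module that uses dropout could be plugged into this `Model` via the `module` argument shown in the high-level example. The `DropoutMLP` below is a hypothetical illustration, not part of Elegy:

```python
# hypothetical Flax module to illustrate the integration; the "dropout"
# RNG stream matches the rngs passed inside test_step above
import flax.linen as nn


class DropoutMLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(300)(x)
        x = jax.nn.relu(x)
        x = nn.Dropout(rate=0.2, deterministic=False)(x)  # consumes the "dropout" rng
        return nn.Dense(10)(x)


model = LinearClassifier(
    module=DropoutMLP(),
    optimizer=optax.rmsprop(1e-3),
)
```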
To run the examples first install some required packages:

```bash
pip install -r examples/requirements.txt
```

Now run the example:

```bash
python examples/flax_mnist_vae.py
```
Deep Learning is evolving at an incredible pace; there is so much to do and so few hands. If you wish to contribute anything from a loss or metric to a new awesome feature for Elegy, just open an issue or send a PR! For more information check out our Contributing Guide.
We are some friends passionate about ML.
Apache
To cite this project:
BibTeX
```bibtex
@software{elegy2020repository,
    author = {PoetsAI},
    title = {Elegy: A framework-agnostic Trainer interface for the Jax ecosystem},
    url = {https://github.com/poets-ai/elegy},
    version = {0.6.0},
    year = {2020},
}
```
The current version may be retrieved either from the `Release` tag or from the file `elegy/__init__.py`, and the year corresponds to the project's release year.
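As a convenience, the installed version can be checked directly, assuming `elegy/__init__.py` exposes a `__version__` attribute as most Python packages do:

```python
# assumes elegy exposes `__version__` in elegy/__init__.py
# (a common convention, not verified here)
import elegy

print(elegy.__version__)
```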