Elegy is a framework-agnostic Trainer interface for the Jax ecosystem.

Main Features

  • Easy-to-use: Elegy provides a Keras-like high-level API that makes common tasks very easy.
  • Flexible: Elegy provides a functional PyTorch Lightning-like low-level API that offers maximal flexibility when needed.
  • Agnostic: Elegy supports a variety of frameworks including Flax, Haiku, and Optax on the high-level API, and it is 100% framework-agnostic on the low-level API.
  • Compatible: Elegy can consume a wide variety of common data sources including TensorFlow Datasets, PyTorch DataLoaders, Python generators, and NumPy pytrees.
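To illustrate the last point, the sketch below builds a plain Python generator that yields `(x, y)` batches of NumPy arrays, the kind of data source the high-level API is described as accepting. The helper name `batch_generator` and the `(x, y)` tuple layout are assumptions for illustration, not part of Elegy's API.

```python
import numpy as np


def batch_generator(x, y, batch_size, seed=0):
    """Hypothetical helper: yield shuffled (x, y) mini-batches forever.

    x and y are NumPy arrays (a simple pytree: a tuple of arrays) that share
    the same leading (sample) dimension.
    """
    assert len(x) == len(y), "x and y must have the same number of samples"
    rng = np.random.default_rng(seed)
    n = len(x)
    while True:
        # reshuffle the sample order at the start of every pass over the data
        order = rng.permutation(n)
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            yield x[idx], y[idx]


# toy data: 10 "images" of 28x28 pixels with integer labels in [0, 10)
x = np.zeros((10, 28, 28), dtype=np.float32)
y = np.arange(10)
gen = batch_generator(x, y, batch_size=4)
xb, yb = next(gen)
print(xb.shape, yb.shape)  # (4, 28, 28) (4,)
```

Because an infinite generator like this one never signals the end of an epoch by itself, it pairs naturally with a `steps_per_epoch` argument to `fit`.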

For more information, take a look at the Documentation.


Install Elegy using pip:

pip install elegy

For Windows users, we recommend the Windows Subsystem for Linux 2 (WSL2), since JAX does not support Windows natively yet.

Quick Start: High-level API

Elegy's high-level API provides a very simple interface you can use by following these steps:

1. Define the architecture inside a `Module`. We will use Flax Linen for this example:

```python
import flax.linen as nn
import jax


class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(300)(x)
        x = jax.nn.relu(x)
        x = nn.Dense(10)(x)
        return x
```

2. Create a `Model` from this module and specify additional things like losses, metrics, and optimizers:

```python
import elegy
import optax

model = elegy.Model(
    module=MLP(),
    loss=[
        elegy.losses.SparseCategoricalCrossentropy(from_logits=True),
        elegy.regularizers.GlobalL2(l=1e-5),
    ],
    metrics=elegy.metrics.SparseCategoricalAccuracy(),
    optimizer=optax.rmsprop(1e-3),
)
```

3. Train the model using the `fit` method:

```python
model.fit(
    x=X_train,
    y=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[elegy.callbacks.TensorBoard("summaries")],
)
```
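To make the configuration in steps 1 and 2 concrete, the sketch below computes, in plain NumPy, the quantity such a model is set up to minimize on one batch: the MLP's logits, a sparse categorical cross-entropy applied directly to those logits (the meaning of `from_logits=True`), and an L2 penalty with `l=1e-5`. It is a stand-in, not Elegy's code: the initialization scheme, the 784-feature input (flattened 28x28 images), and the exact form of `GlobalL2` (assumed here to be `l` times the sum of all squared parameters, biases included) are assumptions.

```python
import numpy as np


def init_mlp(rng, in_features, hidden=300, out=10):
    """Hypothetical NumPy stand-in for the Flax MLP: Dense(300) -> relu -> Dense(10)."""
    return {
        "w1": rng.normal(0.0, 0.05, size=(in_features, hidden)),
        "b1": np.zeros(hidden),
        "w2": rng.normal(0.0, 0.05, size=(hidden, out)),
        "b2": np.zeros(out),
    }


def mlp_apply(params, x):
    h = np.maximum(x @ params["w1"] + params["b1"], 0.0)  # Dense(300) + relu
    return h @ params["w2"] + params["b2"]  # Dense(10): raw logits, no softmax


def sparse_ce_from_logits(logits, labels):
    """Mean sparse categorical cross-entropy, treating inputs as logits."""
    # numerically stable log-softmax
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # pick the log-probability of the integer label for each sample
    return -log_probs[np.arange(len(labels)), labels].mean()


def global_l2(params, l=1e-5):
    """Assumed form of the L2 regularizer: l * sum of squared parameters."""
    return l * sum(np.sum(p ** 2) for p in params.values())


rng = np.random.default_rng(0)
x = rng.random((64, 784))         # one batch (batch_size=64) of flattened inputs
y = rng.integers(0, 10, size=64)  # integer class labels
params = init_mlp(rng, in_features=784)
logits = mlp_apply(params, x)
total_loss = sparse_ce_from_logits(logits, y) + global_l2(params)
print(logits.shape, float(total_loss))
```

As a sanity check on the `fit` arguments, and assuming each step consumes one batch, `steps_per_epoch=200` with `batch_size=64` means 200 × 64 = 12,800 samples per epoch, or 1,280,000 samples over 100 epochs.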

Quick Start: Low-level API

Elegy's low-level API lets you define exactly what goes on during training, testing, and inference. Let's define the `test_step` to implement a linear classifier in pure JAX:

1. Calculate our loss, logs, and states:

```python
import jax
import jax.numpy as jnp
import elegy


class LinearClassifier(elegy.Model):
    # request parameters by name via dependency injection.
    # names: x, y_true, sample_weight, class_weight, states, initializing
    def test_step(
        self,
        x,  # inputs
        y_true,  # labels
        states: elegy.States,  # model state
        initializing: bool,  # if True we should initialize our parameters
    ):
        rng: elegy.RNGSeq = states.rng
        # flatten + scale
        x = jnp.reshape(x, (x.shape[0], -1)) / 255
        # initialize or use existing parameters
        if initializing:
            w = jax.random.uniform(rng.next(), shape=[x.shape[1], 10])
            b = jax.random.uniform(rng.next(), shape=[1])
        else:
            w, b = states.net_params
        # model
        logits = jnp.dot(x, w) + b
        # categorical crossentropy loss
        labels = jax.nn.one_hot(y_true, 10)
        loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)
        # metrics
        logs = dict(accuracy=accuracy, loss=loss)
        return loss, logs, states.update(net_params=(w, b))
```

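2. Instantiate our `LinearClassifier` with an optimizer:

```python
import optax

model = LinearClassifier(
    optimizer=optax.adam(1e-3),  # example choice; any Optax optimizer can be used
)
```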
3. Train the model using the `fit` method:

```python
model.fit(
    x=X_train,
    y=y_train,
    epochs=100,
    batch_size=64,
    validation_data=(X_test, y_test),
)
```
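The steps above hinge on one pattern: `test_step` either creates the parameters (when `initializing` is True) or reads them back from `states`, and it always returns an updated `states`. The sketch below mirrors that control flow in plain NumPy so it can run without Elegy or JAX. The plain dict standing in for `elegy.States`, the NumPy random generator standing in for `elegy.RNGSeq`, and the function name `linear_test_step` are assumptions for illustration.

```python
import numpy as np


def linear_test_step(x, y_true, states, initializing):
    """NumPy mirror of the LinearClassifier.test_step logic (hypothetical, no Elegy).

    `states` is a plain dict standing in for elegy.States; it holds a NumPy
    Generator under "rng" and, after initialization, the (w, b) tuple under
    "net_params".
    """
    rng = states["rng"]
    # flatten + scale, as in the JAX version
    x = x.reshape(x.shape[0], -1) / 255
    if initializing:
        w = rng.uniform(size=(x.shape[1], 10))
        b = rng.uniform(size=(1,))
    else:
        w, b = states["net_params"]
    logits = x @ w + b
    # categorical cross-entropy via one-hot labels and a stable log-softmax
    labels = np.eye(10)[y_true]
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_softmax = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    loss = np.mean(-np.sum(labels * log_softmax, axis=-1))
    accuracy = np.mean(np.argmax(logits, axis=-1) == y_true)
    logs = dict(accuracy=accuracy, loss=loss)
    # return a new state dict instead of mutating the input one
    return loss, logs, {**states, "net_params": (w, b)}


rng = np.random.default_rng(0)
x = rng.integers(0, 256, size=(8, 28, 28)).astype(np.float32)
y = rng.integers(0, 10, size=8)
states = {"rng": rng}
loss0, logs0, states = linear_test_step(x, y, states, initializing=True)
loss1, logs1, states = linear_test_step(x, y, states, initializing=False)
print(loss0 == loss1)  # True: the second call reuses the stored parameters
```

Returning a new state instead of mutating the old one keeps the step a pure function, which is the style JAX transformations such as `jax.jit` expect.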

Using Jax Frameworks

It is straightforward to integrate other functional JAX libraries with this low-level API. Here is an example using a Flax module (accessed as `self.module`):

```python
import jax
import jax.numpy as jnp
import elegy


class LinearClassifier(elegy.Model):
    def test_step(
        self, x, y_true, states: elegy.States, initializing: bool
    ):
        rng: elegy.RNGSeq = states.rng
        x = jnp.reshape(x, (x.shape[0], -1)) / 255

        if initializing:
            # create the Flax variables and run the module once
            logits, variables = self.module.init_with_output(
                {"params": rng.next(), "dropout": rng.next()}, x
            )
        else:
            # rebuild the Flax variables from the stored Elegy states
            variables = dict(params=states.net_params, **states.net_states)
            logits, variables = self.module.apply(
                variables, x, rngs={"dropout": rng.next()}, mutable=True
            )
        # FrozenDict.pop returns (remaining collections, removed value)
        net_states, net_params = variables.pop("params")

        labels = jax.nn.one_hot(y_true, 10)
        loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)

        logs = dict(accuracy=accuracy, loss=loss)
        return loss, logs, states.update(net_params=net_params, net_states=net_states)
```
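In the Flax example, `variables.pop("params")` separates the trainable `"params"` collection from every other collection Flax may return (for instance `"batch_stats"` when a module uses batch normalization), and the `else` branch merges them back before calling `apply`. The sketch below reproduces that split and merge with plain dicts; the helper names `split_params` and `merge_params` are hypothetical, and the collection contents are toy values.

```python
def split_params(variables):
    """Hypothetical helper: return (other_collections, params), like FrozenDict.pop("params")."""
    rest = {k: v for k, v in variables.items() if k != "params"}
    return rest, variables["params"]


def merge_params(net_params, net_states):
    """Rebuild the variables dict passed to module.apply, as in the else branch."""
    return dict(params=net_params, **net_states)


variables = {
    "params": {"Dense_0": {"kernel": [[0.1, 0.2]], "bias": [0.0, 0.0]}},
    "batch_stats": {"BatchNorm_0": {"mean": [0.0], "var": [1.0]}},
}
net_states, net_params = split_params(variables)
print(sorted(net_states))  # ['batch_stats']
assert merge_params(net_params, net_states) == variables  # the round trip is lossless
```

Note that on a plain Python dict, `dict.pop("params")` returns only the removed value, so the two-value unpacking in the Flax example relies on the `FrozenDict.pop` behavior.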

More Info


To run the examples, first install some required packages:

pip install -r examples/requirements.txt
Now run the example:
python examples/ 


Deep Learning is evolving at an incredible pace; there is so much to do and so few hands. If you wish to contribute anything, from a loss or metric to a new awesome feature for Elegy, just open an issue or send a PR! For more information, check out our Contributing Guide.

About Us

We are a group of friends passionate about ML.



Citing Elegy

To cite this project:


author = {PoetsAI},
title = {Elegy: A framework-agnostic Trainer interface for the Jax ecosystem},
url = {},
version = {0.6.0},
year = {2020},

Here, the current version may be retrieved either from the release tag or from the file elegy/, and the year corresponds to the project's release year.

