
License: Apache License 2.0


Elegy is a framework-agnostic Trainer interface for the Jax ecosystem.



Main Features

  • Easy-to-use: Elegy provides a Keras-like high-level API that makes it very easy to do common tasks.
  • Flexible: Elegy provides a functional Pytorch Lightning-like low-level API that provides maximal flexibility when needed.
  • Agnostic: Elegy supports a variety of frameworks including Flax, Haiku, and Optax on the high-level API, and it is 100% framework-agnostic on the low-level API.
  • Compatible: Elegy can consume a wide variety of common data sources including TensorFlow Datasets, Pytorch DataLoaders, Python generators, and Numpy pytrees.
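To illustrate the "Python generators" and "Numpy pytrees" data sources, here is a minimal sketch of a generator that yields shuffled `(x_batch, y_batch)` mini-batches of NumPy arrays. The `batch_generator` helper is hypothetical and not part of Elegy, and the exact way Elegy consumes such a generator is not shown here.

```python
import numpy as np


def batch_generator(x, y, batch_size, shuffle=True, seed=0):
    """Yield (x_batch, y_batch) tuples of NumPy arrays, looping forever.

    Hypothetical helper for illustration only; not part of Elegy.
    """
    assert len(x) == len(y), "x and y must have the same number of samples"
    rng = np.random.default_rng(seed)
    n = len(x)
    while True:
        # draw a fresh permutation on every pass over the data
        order = rng.permutation(n) if shuffle else np.arange(n)
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            yield x[idx], y[idx]


# 10 samples with 4 features, batches of 4: sizes 4, 4, 2, then a new pass starts
x = np.arange(40, dtype=np.float32).reshape(10, 4)
y = np.arange(10)
gen = batch_generator(x, y, batch_size=4)
sizes = [len(next(gen)[0]) for _ in range(4)]
print(sizes)  # [4, 4, 2, 4]
```

Because this generator never terminates, whatever consumes it needs some other way to know where an epoch ends, such as a fixed number of steps.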

For more information, take a look at the Documentation.


Install Elegy using pip:

pip install elegy

For Windows users, we recommend the Windows Subsystem for Linux 2 (WSL2), since Jax does not support Windows yet.

Quick Start: High-level API

Elegy's high-level API provides a very simple interface that you can use by following these steps:

1. Define the architecture inside a `Module`. We will use Flax Linen for this example:

```python
import flax.linen as nn
import jax


class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(300)(x)
        x = jax.nn.relu(x)
        x = nn.Dense(10)(x)
        return x
```

2. Create a `Model` from this module and specify additional things like losses, metrics, and optimizers:

```python
import elegy, optax

model = elegy.Model(
    module=MLP(),
    loss=[
        elegy.losses.SparseCategoricalCrossentropy(from_logits=True),
        elegy.regularizers.GlobalL2(l=1e-5),
    ],
    metrics=elegy.metrics.SparseCategoricalAccuracy(),
    optimizer=optax.rmsprop(1e-3),
)
```

3. Train the model using the `fit` method:

```python
model.fit(
    x=X_train,
    y=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[elegy.callbacks.TensorBoard("summaries")],
)
```
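To make the shapes in the `MLP` from step 1 concrete, here is a NumPy-only sketch of the same forward pass (Dense(300), ReLU, Dense(10)) with explicitly created weights. It does not use Flax. The scaled-normal initialization and the 784-feature input (a flattened 28x28 image) are assumptions chosen for illustration.

```python
import numpy as np


def dense_init(rng, n_in, n_out):
    # scaled-normal init chosen for illustration; Flax's default initializer differs
    w = rng.normal(0.0, 1.0 / np.sqrt(n_in), size=(n_in, n_out)).astype(np.float32)
    b = np.zeros(n_out, dtype=np.float32)
    return w, b


def mlp_forward(params, x):
    """Dense(300) -> ReLU -> Dense(10), applied to the last axis of x."""
    (w1, b1), (w2, b2) = params
    h = np.maximum(x @ w1 + b1, 0.0)  # ReLU
    return h @ w2 + b2                # unnormalized logits, one per class


rng = np.random.default_rng(0)
params = [dense_init(rng, 784, 300), dense_init(rng, 300, 10)]
x = rng.random((64, 784), dtype=np.float32)  # a batch of 64 flattened inputs
logits = mlp_forward(params, x)
print(logits.shape)  # (64, 10)
```

With 784 inputs this network has 784 * 300 + 300 + 300 * 10 + 10 = 238,510 parameters.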

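The `loss` list in step 2 combines a cross-entropy term with a global L2 penalty. The sketch below shows how such a total could be computed in NumPy, assuming (a) sparse categorical cross-entropy on logits is the mean negative log-softmax probability of the true class and (b) `GlobalL2(l=1e-5)` adds `l` times the sum of squared parameter values (the Keras-style convention; Elegy's exact form is an assumption here). All function names are hypothetical.

```python
import numpy as np


def log_softmax(logits):
    # subtract the row max for numerical stability
    z = logits - logits.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def sparse_ce_from_logits(y_true, logits):
    # y_true holds integer class ids; pick the log-probability of each true class
    logp = log_softmax(logits)
    return -np.mean(logp[np.arange(len(y_true)), y_true])


def global_l2(params, l=1e-5):
    # assumed form: l * sum of squares over every parameter array
    return l * sum(np.sum(np.square(p)) for p in params)


def total_loss(y_true, logits, params, l=1e-5):
    # hypothetical combination: the two terms of the loss list are summed
    return sparse_ce_from_logits(y_true, logits) + global_l2(params, l)


# Uniform logits over 10 classes give a cross-entropy of ln(10), about 2.3026
logits = np.zeros((4, 10))
y_true = np.array([0, 3, 7, 9])
params = [np.ones((3, 2)), np.ones(2)]  # 8 parameters, each equal to 1
print(total_loss(y_true, logits, params))  # ln(10) + 1e-5 * 8
```

The regularization term is tiny compared to the cross-entropy here, which is the usual intent of a small `l`.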
Quick Start: Low-level API

Elegy's low-level API lets you define exactly what goes on during training, testing, and inference. Let's define the `test_step` to implement a linear classifier in pure jax:

1. Calculate our loss, logs, and states:

```python
import jax
import jax.numpy as jnp
import elegy


class LinearClassifier(elegy.Model):
    # request parameters by name via dependency injection.
    # names: x, y_true, sample_weight, class_weight, states, initializing
    def test_step(
        self,
        x,  # inputs
        y_true,  # labels
        states: elegy.States,  # model state
        initializing: bool,  # if True we should initialize our parameters
    ):
        # assumed API: rng.next() returns a fresh PRNG key on each call
        rng: elegy.RNGSeq = states.rng
        # flatten + scale
        x = jnp.reshape(x, (x.shape[0], -1)) / 255

        # initialize or use existing parameters
        if initializing:
            w = jax.random.uniform(rng.next(), shape=[x.shape[1], 10])
            b = jax.random.uniform(rng.next(), shape=[10])  # one bias per class
        else:
            w, b = states.net_params

        # model
        logits = jnp.dot(x, w) + b

        # categorical crossentropy loss
        labels = jax.nn.one_hot(y_true, 10)
        loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)

        # metrics
        logs = dict(accuracy=accuracy, loss=loss)

        return loss, logs, states.update(net_params=(w, b))
```

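2. Instantiate our `LinearClassifier` with an optimizer:

```python
import optax

model = LinearClassifier(
    optimizer=optax.rmsprop(1e-3),  # illustrative choice, mirroring the high-level example
)
```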
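3. Train the model using the `fit` method:

```python
model.fit(
    x=X_train,  # X_train, y_train, X_test, y_test: your data, assumed already loaded
    y=y_train,
    validation_data=(X_test, y_test),
)
```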
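Conceptually, training repeatedly evaluates a loss like the one in `test_step`, differentiates it with respect to the parameters, and applies an update. The sketch below does this for the linear classifier in plain NumPy, with a hand-derived gradient (softmax probabilities minus one-hot labels) and plain gradient descent. It is not Elegy's implementation, which relies on Jax for gradients and on the optimizer you pass in; the toy data, learning rate, and step count are made up for illustration.

```python
import numpy as np


def forward(params, x):
    """Flatten + scale, then a linear layer, as in the test_step above."""
    w, b = params
    x = x.reshape(x.shape[0], -1) / 255.0
    return x, x @ w + b


def loss_and_grads(params, x, y_true, n_classes=10):
    x, logits = forward(params, x)
    z = logits - logits.max(axis=-1, keepdims=True)
    logp = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))  # log-softmax
    probs = np.exp(logp)
    labels = np.eye(n_classes)[y_true]  # one-hot
    loss = -np.mean(np.sum(labels * logp, axis=-1))
    accuracy = np.mean(np.argmax(logits, axis=-1) == y_true)
    # d(mean cross-entropy)/d(logits) = (softmax - one_hot) / batch_size
    dlogits = (probs - labels) / x.shape[0]
    grads = (x.T @ dlogits, dlogits.sum(axis=0))
    return loss, dict(accuracy=accuracy, loss=loss), grads


def train(params, x, y, lr=0.5, steps=200):
    history = []
    for _ in range(steps):
        loss, logs, grads = loss_and_grads(params, x, y)
        # plain gradient descent; Elegy would apply the optimizer passed to the model
        params = tuple(p - lr * g for p, g in zip(params, grads))
        history.append(logs)
    return params, history


# Toy data: 200 tiny 2x5 "images" where the pixel at index y is bright
rng = np.random.default_rng(0)
y = rng.integers(0, 10, size=200)
flat = rng.uniform(0, 50, size=(200, 10))  # dim background pixels
flat[np.arange(200), y] = 255.0
x = flat.reshape(200, 2, 5)

params = (np.zeros((10, 10)), np.zeros(10))
params, history = train(params, x, y)
print(history[0]["loss"], history[-1]["loss"], history[-1]["accuracy"])
```

Starting from zero weights, every class is equally likely, so the first recorded loss is ln(10); it then falls as the diagonal weights grow.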

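The comment at the top of `test_step` says its arguments are requested by name via dependency injection. One way such a mechanism can work is to inspect the function's signature and pass only the arguments it declares. The sketch below uses Python's `inspect` module; `call_with_injection` is a hypothetical helper, not Elegy's implementation.

```python
import inspect


def call_with_injection(fn, **available):
    """Call fn with only the keyword arguments its signature asks for.

    Hypothetical helper; `available` maps names such as "x", "y_true",
    "states" or "initializing" to their values.
    """
    params = inspect.signature(fn).parameters
    if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return fn(**available)  # fn accepts **kwargs, so pass everything
    wanted = {name: available[name] for name in params if name in available}
    return fn(**wanted)


def step(x, y_true, initializing):
    # this step never asks for sample_weight or class_weight, so it won't receive them
    return f"x={x}, y_true={y_true}, initializing={initializing}"


available = dict(x=1, y_true=2, sample_weight=None, class_weight=None,
                 states={}, initializing=True)
print(call_with_injection(step, **available))
```

Because `inspect.signature` on a bound method already omits `self`, the same idea would also apply to a method such as `test_step`.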
Using Jax Frameworks

It is straightforward to integrate other functional JAX libraries with this low-level API:

```python
import jax
import jax.numpy as jnp
import elegy


class LinearClassifier(elegy.Model):
    def test_step(
        self, x, y_true, states: elegy.States, initializing: bool
    ):
        rng: elegy.RNGSeq = states.rng  # assumed: rng.next() yields a fresh PRNG key
        x = jnp.reshape(x, (x.shape[0], -1)) / 255

        if initializing:
            # self.module: the Flax Linen module passed to the model
            logits, variables = self.module.init_with_output(
                {"params": rng.next(), "dropout": rng.next()}, x
            )
        else:
            variables = dict(params=states.net_params, **states.net_states)
            logits, variables = self.module.apply(
                variables, x, rngs={"dropout": rng.next()}, mutable=True
            )
        # FrozenDict.pop returns (remaining_collections, popped_value)
        net_states, net_params = variables.pop("params")

        labels = jax.nn.one_hot(y_true, 10)
        loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)

        logs = dict(accuracy=accuracy, loss=loss)
        return loss, logs, states.update(net_params=net_params, net_states=net_states)
```
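The example relies on `variables.pop("params")` returning a pair: the remaining variable collections (for example batch statistics) and the parameters. That is how `flax.core.FrozenDict.pop` behaves; on a plain Python `dict`, `pop` returns only the value, so with Flax versions that return plain dicts the unpacking would need adjusting. A dependency-free sketch of the same split, using a hypothetical `split_params` helper and made-up variable names:

```python
from typing import Any, Dict, Mapping, Tuple


def split_params(variables: Mapping[str, Any]) -> Tuple[Dict[str, Any], Any]:
    """Return (other_collections, params) without mutating `variables`.

    Hypothetical helper mirroring the (rest, popped) order of FrozenDict.pop.
    """
    rest = {k: v for k, v in variables.items() if k != "params"}
    return rest, variables["params"]


variables = {
    "params": {"Dense_0": {"kernel": [[0.1, 0.2]], "bias": [0.0, 0.0]}},
    "batch_stats": {"BatchNorm_0": {"mean": [0.0], "var": [1.0]}},
}
net_states, net_params = split_params(variables)
print(sorted(net_states))  # ['batch_stats']
print(sorted(net_params))  # ['Dense_0']

# the inverse, as in the `else` branch: dict(params=net_params, **net_states)
rebuilt = dict(params=net_params, **net_states)
print(rebuilt == variables)  # True
```

Keeping parameters separate from the other collections is what lets `states.update` store them as `net_params` and `net_states` independently.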

More Info


To run the examples, first install some required packages:

pip install -r examples/requirements.txt
Now run the example:
python examples/ 


Deep Learning is evolving at an incredible pace; there is so much to do and so few hands. If you wish to contribute anything from a loss or metric to a new awesome feature for Elegy, just open an issue or send a PR! For more information, check out our Contributing Guide.

About Us

We are some friends passionate about ML.



Citing Elegy

To cite this project:


author = {PoetsAI},
title = {Elegy: A framework-agnostic Trainer interface for the Jax ecosystem},
url = {},
version = {0.6.0},
year = {2020},

Where the current version may be retrieved either from the release tag or from the file elegy/ and the year corresponds to the project's release year.
