Elegy


Elegy is a framework-agnostic Trainer interface for the Jax ecosystem.

Main Features

  • Easy-to-use: Elegy provides a Keras-like high-level API that makes it very easy to do common tasks.
  • Flexible: Elegy provides a functional PyTorch Lightning-like low-level API that maximizes flexibility when needed.
  • Agnostic: Elegy supports various frameworks, including Flax, Haiku, and Optax on the high-level API, and it is 100% framework-agnostic on the low-level API.
  • Compatible: Elegy can consume many familiar data sources, including TensorFlow Datasets, PyTorch DataLoaders, Python generators, and NumPy pytrees.
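To illustrate the "Python generators" data source, the sketch below builds an infinite generator of (x, y) mini-batches from NumPy arrays. The helper name batch_generator is hypothetical, and the assumption that fit accepts such a generator as x together with steps_per_epoch should be checked against the Elegy documentation.

```python
import numpy as np


def batch_generator(x, y, batch_size, seed=0):
    """Hypothetical helper: yield shuffled (x, y) mini-batches forever.

    Each pass over the data uses a fresh permutation; the final partial
    batch of a pass is dropped so every batch has exactly batch_size rows.
    """
    rng = np.random.default_rng(seed)
    n = len(x)
    while True:
        perm = rng.permutation(n)
        for start in range(0, n - batch_size + 1, batch_size):
            idx = perm[start:start + batch_size]
            yield x[idx], y[idx]


# Toy data: 100 samples of 28x28 "images" and integer labels in [0, 10).
X = np.random.rand(100, 28, 28).astype(np.float32)
Y = np.random.randint(0, 10, size=(100,))
gen = batch_generator(X, Y, batch_size=32)
xb, yb = next(gen)
print(xb.shape, yb.shape)  # (32, 28, 28) (32,)
```

Under that assumption, a call would look like model.fit(x=gen, epochs=10, steps_per_epoch=3); because the generator never ends, steps_per_epoch is what bounds each epoch.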

For more information, take a look at the Documentation.

Installation

Install Elegy using pip:

pip install elegy

For Windows users, we recommend Windows Subsystem for Linux 2 (WSL2), since JAX does not support Windows natively yet.

Quick Start: High-level API

Elegy's high-level API provides a straightforward interface you can use by implementing the following steps:

1. Define the architecture inside a Module. We will use Flax Linen for this example:
import flax.linen as nn
import jax

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(300)(x)
        x = jax.nn.relu(x)
        x = nn.Dense(10)(x)
        return x

2. Create a Model from this module and specify additional things like losses, metrics, and optimizers:
import elegy
import optax

model = elegy.Model(
    module=MLP(),
    loss=[
        elegy.losses.SparseCategoricalCrossentropy(from_logits=True),
        elegy.regularizers.GlobalL2(l=1e-5),
    ],
    metrics=elegy.metrics.SparseCategoricalAccuracy(),
    optimizer=optax.rmsprop(1e-3),
)

3. Train the model using the fit method:
model.fit(
    x=X_train,
    y=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[elegy.callbacks.TensorBoard("summaries")]
)
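For readers who want to see what the model above computes, here is a framework-free sketch in plain JAX of the same Dense(300), relu, Dense(10) stack plus the two loss terms. The initializer (scaled normal weights, zero biases) is a hypothetical choice, not Flax's default, and the regularizer is written assuming GlobalL2 adds l times the sum of squared entries of every parameter; check the Elegy docs for the exact scaling.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, in_dim, hidden=300, out_dim=10):
    """Hypothetical initializer: scaled normal weights, zero biases."""
    k1, k2 = jax.random.split(key)
    return {
        "dense1": {"kernel": jax.random.normal(k1, (in_dim, hidden)) / jnp.sqrt(in_dim),
                   "bias": jnp.zeros(hidden)},
        "dense2": {"kernel": jax.random.normal(k2, (hidden, out_dim)) / jnp.sqrt(hidden),
                   "bias": jnp.zeros(out_dim)},
    }


def mlp_apply(params, x):
    # Same layer stack as the MLP module: Dense(300) -> relu -> Dense(10).
    x = x @ params["dense1"]["kernel"] + params["dense1"]["bias"]
    x = jax.nn.relu(x)
    return x @ params["dense2"]["kernel"] + params["dense2"]["bias"]


def total_loss(params, x, y, l2=1e-5):
    logits = mlp_apply(params, x)
    # Sparse categorical crossentropy from logits: y holds integer class ids.
    log_probs = jax.nn.log_softmax(logits)
    ce = -jnp.mean(jnp.take_along_axis(log_probs, y[:, None], axis=-1))
    # Assumed GlobalL2 form: l2 * sum of squared entries over every parameter.
    reg = l2 * sum(jnp.sum(p ** 2) for p in jax.tree_util.tree_leaves(params))
    return ce + reg


params = init_mlp(jax.random.PRNGKey(0), in_dim=784)
x = jnp.ones((8, 784))
y = jnp.arange(8) % 10
print(mlp_apply(params, x).shape)  # (8, 10)
print(total_loss(params, x, y))
```

With all-zero parameters the logits are uniform, so the crossentropy term equals log(10) and the penalty vanishes, which makes a handy sanity check.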

Quick Start: Low-level API

Elegy's low-level API lets you explicitly define what goes on during training, testing, and inference. Let's define the test_step method to implement a linear classifier in pure JAX:

1. Calculate our loss, logs, and states:

import jax
import jax.numpy as jnp
import numpy as np
import elegy

class LinearClassifier(elegy.Model):
    # Request parameters by name via dependency injection.
    # Available names: x, y_true, sample_weight, class_weight, states, initializing
    def test_step(
        self,
        x,  # inputs
        y_true,  # labels
        states: elegy.States,  # model state
        initializing: bool,  # if True, we should initialize our parameters
    ):
        rng: elegy.RNGSeq = states.rng
        # flatten + scale
        x = jnp.reshape(x, (x.shape[0], -1)) / 255
        # initialize or use existing parameters
        if initializing:
            w = jax.random.uniform(
                rng.next(), shape=[np.prod(x.shape[1:]), 10]
            )
            b = jax.random.uniform(rng.next(), shape=[1])
        else:
            w, b = states.net_params
        # model
        logits = jnp.dot(x, w) + b
        # categorical crossentropy loss
        labels = jax.nn.one_hot(y_true, 10)
        loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)
        # metrics
        logs = dict(
            accuracy=accuracy,
            loss=loss,
        )
        return loss, logs, states.update(net_params=(w, b))

2. Instantiate our LinearClassifier with an optimizer:
import optax

model = LinearClassifier(
    optimizer=optax.rmsprop(1e-3),
)

3. Train the model using the fit method:
model.fit(
    x=X_train,
    y=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[elegy.callbacks.TensorBoard("summaries")]
)
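The test_step above relies on two ideas: its arguments are requested by name (dependency injection), and the initializing flag decides whether parameters are created or read back from states. The sketch below imitates both in plain Python and JAX so they can be run in isolation. The helper call_with_injection and the function toy_test_step are hypothetical, a plain dict stands in for elegy.States, and zero-initialized weights replace the random ones to keep the output deterministic; this illustrates the pattern, not Elegy's internal implementation.

```python
import inspect

import jax
import jax.numpy as jnp


def call_with_injection(fn, **available):
    """Hypothetical helper: call fn with only the keyword arguments it declares."""
    names = inspect.signature(fn).parameters
    return fn(**{k: v for k, v in available.items() if k in names})


def toy_test_step(x, y_true, states, initializing):
    # Flatten and scale exactly like LinearClassifier.test_step.
    x = jnp.reshape(x, (x.shape[0], -1)) / 255
    if initializing:
        w = jnp.zeros((x.shape[1], 10))  # zeros keep the example deterministic
        b = jnp.zeros((1,))
    else:
        w, b = states["net_params"]
    logits = jnp.dot(x, w) + b
    labels = jax.nn.one_hot(y_true, 10)
    loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)
    return loss, dict(loss=loss, accuracy=accuracy), {**states, "net_params": (w, b)}


x = jnp.ones((4, 28, 28)) * 255
y = jnp.array([0, 1, 2, 0])
states = {}
# First call creates the parameters; sample_weight is offered but not requested.
loss, logs, states = call_with_injection(
    toy_test_step, x=x, y_true=y, sample_weight=None, states=states, initializing=True
)
# Later calls reuse the parameters stored in states.
loss, logs, states = call_with_injection(
    toy_test_step, x=x, y_true=y, sample_weight=None, states=states, initializing=False
)
print(float(loss), float(logs["accuracy"]))  # about 2.3026 (= log 10) and 0.5
```

Filtering by signature means a step function only has to declare the inputs it actually uses, which is what lets test_step above omit sample_weight and class_weight.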

Using Jax Frameworks

It is straightforward to integrate other functional JAX libraries with this low-level API. Here a Flax Linen module, passed to the model as module=, is initialized and applied inside test_step:

class LinearClassifier(elegy.Model):
    def test_step(
        self, x, y_true, states: elegy.States, initializing: bool
    ):
        rng: elegy.RNGSeq = states.rng
        x = jnp.reshape(x, (x.shape[0], -1)) / 255
        if initializing:
            # create the parameters (and any other collections) from scratch
            logits, variables = self.module.init_with_output(
                {"params": rng.next(), "dropout": rng.next()}, x
            )
        else:
            # rebuild the variables dict from the stored state and run the module
            variables = dict(params=states.net_params, **states.net_states)
            logits, variables = self.module.apply(
                variables, x, rngs={"dropout": rng.next()}, mutable=True
            )
        # split the trainable parameters from the remaining collections
        net_states, net_params = variables.pop("params")

        labels = jax.nn.one_hot(y_true, 10)
        loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)

        logs = dict(accuracy=accuracy, loss=loss)
        return loss, logs, states.update(net_params=net_params, net_states=net_states)
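In the snippet above, variables.pop("params") separates the trainable parameters from every other Flax collection (for example batch_stats), which is why net_params and net_states are stored separately. Flax's FrozenDict.pop returns a pair (remaining_dict, popped_value). The sketch below reproduces that split, and the merge used in the else branch, on plain dicts with hypothetical helpers split_params and merge_params, so it runs without Flax installed.

```python
def split_params(variables):
    """Hypothetical stand-in for FrozenDict.pop("params") on a plain dict.

    Returns (other_collections, params) without mutating the input.
    """
    rest = {k: v for k, v in variables.items() if k != "params"}
    return rest, variables["params"]


def merge_params(net_params, net_states):
    # Mirrors dict(params=states.net_params, **states.net_states) in test_step.
    return dict(params=net_params, **net_states)


variables = {
    "params": {"Dense_0": {"kernel": [[0.1, 0.2]], "bias": [0.0, 0.0]}},
    "batch_stats": {"BatchNorm_0": {"mean": [0.0], "var": [1.0]}},
}
net_states, net_params = split_params(variables)
print(sorted(net_states))  # ['batch_stats']
print(merge_params(net_params, net_states) == variables)  # True
```

Keeping the split non-mutating matches the functional style of the low-level API: the old variables stay valid while the new state is returned.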

More Info

Examples

To run the examples, first install some required packages:

pip install -r examples/requirements.txt

Now run the example:

python examples/flax_mnist_vae.py

Contributing

Deep Learning is evolving at an incredible pace, and there is so much to do and so few hands. If you wish to contribute anything from a loss or metric to a new awesome feature for Elegy, open an issue or send a PR! For more information, check out our Contributing Guide.

About Us

We are a group of friends passionate about ML.

License

This project is licensed under the Apache v2.0 License.

Citing Elegy

To cite this project:

BibTeX

@software{elegy2020repository,
author = {PoetsAI},
title = {Elegy: A framework-agnostic Trainer interface for the Jax ecosystem},
url = {https://github.com/poets-ai/elegy},
version = {0.7.4},
year = {2020},
}

The current version may be retrieved either from the Release tag or from the file elegy/__init__.py, and the year corresponds to the project's release year.
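When elegy is installed, importlib.metadata.version("elegy") returns the installed distribution's version. For a source checkout, the sketch below parses the version from elegy/__init__.py instead; it assumes that file contains a line of the form __version__ = "x.y.z", which this README does not state, and the helper names are hypothetical.

```python
import re
from importlib import metadata
from pathlib import Path


def version_from_init(text):
    """Return the string assigned to __version__ in module source, or None."""
    match = re.search(r'^__version__\s*=\s*["\']([^"\']+)["\']', text, re.MULTILINE)
    return match.group(1) if match else None


def elegy_version(repo_root="."):
    # Prefer installed package metadata; fall back to the source checkout.
    try:
        return metadata.version("elegy")
    except metadata.PackageNotFoundError:
        init = Path(repo_root) / "elegy" / "__init__.py"
        return version_from_init(init.read_text()) if init.exists() else None


print(version_from_init('__version__ = "0.7.4"\n'))  # 0.7.4
```

The value found this way is what belongs in the version field of the BibTeX entry above.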
