
About the developer

poets-ai
170 Stars 11 Forks Apache License 2.0 285 Commits 20 Opened issues

Description

Elegy is a framework-agnostic Trainer interface for the Jax ecosystem.

Elegy




Main Features

  • Easy-to-use: Elegy provides a Keras-like high-level API that makes common tasks very easy.
  • Flexible: Elegy provides a functional, PyTorch Lightning-like low-level API that offers maximal flexibility when needed.
  • Agnostic: Elegy supports a variety of frameworks, including Flax, Haiku, and Optax, in the high-level API, and it is 100% framework-agnostic in the low-level API.
  • Compatible: Elegy can consume a wide variety of common data sources, including TensorFlow Datasets, PyTorch DataLoaders, Python generators, and NumPy pytrees.

For more information, take a look at the Documentation.

Installation

Install Elegy using pip:

```bash
pip install elegy
```

For Windows users, we recommend the Windows Subsystem for Linux 2 (WSL2), since JAX does not support Windows yet.

Quick Start: High-level API

Elegy's high-level API provides a very simple interface that you can use by following these steps:

1. Define the architecture inside a `Module`. We will use Flax Linen for this example:

```python
import flax.linen as nn
import jax

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(300)(x)
        x = jax.nn.relu(x)
        x = nn.Dense(10)(x)
        return x
```
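To make the shapes concrete, here is a NumPy sketch of the computation `MLP` performs on a batch. The input size of 784 (flattened 28x28 images) and the random initialization are assumptions for illustration only; Flax's `nn.Dense` infers its input size and initializes its own parameters.

```python
import numpy as np

def mlp_forward(params, x):
    """Dense(300) -> ReLU -> Dense(10), mirroring the MLP module above."""
    h = x @ params["w1"] + params["b1"]      # (batch, 300)
    h = np.maximum(h, 0.0)                   # ReLU, like jax.nn.relu
    return h @ params["w2"] + params["b2"]   # (batch, 10) logits

rng = np.random.default_rng(0)
in_features = 784  # assumed: flattened 28x28 images
params = {
    "w1": rng.normal(0.0, 0.01, (in_features, 300)),
    "b1": np.zeros(300),
    "w2": rng.normal(0.0, 0.01, (300, 10)),
    "b2": np.zeros(10),
}
x = rng.random((64, in_features))
logits = mlp_forward(params, x)
print(logits.shape)  # (64, 10)
```

The final layer has 10 units, so every input row is mapped to 10 unnormalized class scores (logits), which is what the losses in the next step consume.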

2. Create a `Model` from this module and specify additional things like losses, metrics, and optimizers:

```python
import elegy
import optax

model = elegy.Model(
    module=MLP(),
    loss=[
        elegy.losses.SparseCategoricalCrossentropy(from_logits=True),
        elegy.regularizers.GlobalL2(l=1e-5),
    ],
    metrics=elegy.metrics.SparseCategoricalAccuracy(),
    optimizer=optax.rmsprop(1e-3),
)
```

3. Train the model using the `fit` method:

```python
model.fit(
    x=X_train,
    y=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[elegy.callbacks.TensorBoard("summaries")],
)
```
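The two entries in the `loss` list of the high-level example combine into a single training objective. The NumPy sketch below assumes that Elegy sums the listed losses and that `GlobalL2(l=1e-5)` adds `l` times the sum of squared parameter values. Both are assumptions about the library's behavior, and the helper names are hypothetical.

```python
import numpy as np

def log_softmax(logits):
    # subtract the row maximum for numerical stability
    z = logits - logits.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def sparse_categorical_crossentropy(labels, logits):
    """Mean cross-entropy for integer labels and raw logits (from_logits=True)."""
    log_probs = log_softmax(logits)
    return -np.mean(log_probs[np.arange(len(labels)), labels])

def global_l2(params, l=1e-5):
    """Assumed regularizer: l * sum of squares over every parameter array."""
    return l * sum(np.sum(p ** 2) for p in params)

def total_loss(params, labels, logits, l=1e-5):
    """Assumed combination: the listed losses are simply added together."""
    return sparse_categorical_crossentropy(labels, logits) + global_l2(params, l)

# Uniform logits over 10 classes give a cross-entropy of log(10), about 2.3026
labels = np.array([0, 3])
logits = np.zeros((2, 10))
params = [np.ones((2, 2))]  # sum of squares = 4
print(total_loss(params, labels, logits))  # log(10) + 4e-5
```

With `l=1e-5` the penalty is tiny compared with the cross-entropy term, so it only nudges the weights toward smaller values.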

Quick Start: Low-level API

Elegy's low-level API lets you define exactly what happens during training, testing, and inference. Let's define the `test_step` method to implement a linear classifier in pure JAX:

1. Calculate our loss, logs, and states:

```python
import elegy
import jax
import jax.numpy as jnp
import numpy as np

class LinearClassifier(elegy.Model):
    # request parameters by name via dependency injection.
    # available names: x, y_true, sample_weight, class_weight, states, initializing
    def test_step(
        self,
        x,  # inputs
        y_true,  # labels
        states: elegy.States,  # model state
        initializing: bool,  # if True we should initialize our parameters
    ):
        rng: elegy.RNGSeq = states.rng
        # flatten + scale
        x = jnp.reshape(x, (x.shape[0], -1)) / 255
        # initialize or use existing parameters
        if initializing:
            w = jax.random.uniform(
                rng.next(), shape=[np.prod(x.shape[1:]), 10]
            )
            b = jax.random.uniform(rng.next(), shape=[1])
        else:
            w, b = states.net_params
        # model
        logits = jnp.dot(x, w) + b
        # categorical crossentropy loss
        labels = jax.nn.one_hot(y_true, 10)
        loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)
        # metrics
        logs = dict(
            accuracy=accuracy,
            loss=loss,
        )
        return loss, logs, states.update(net_params=(w, b))
```

2. Instantiate our `LinearClassifier` with an optimizer:

```python
import optax

model = LinearClassifier(
    optimizer=optax.rmsprop(1e-3),
)
```
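Conceptually, training adjusts `w` and `b` to reduce the `loss` returned by `test_step`, and the optimizer decides how each gradient becomes an update. The NumPy sketch below illustrates this under stated assumptions and is not Elegy's training loop. It replaces JAX autodiff with the closed-form softmax cross-entropy gradient `(softmax(logits) - one_hot(y)) / N`. It also applies the textbook RMSProp rule with an assumed decay of 0.9 and epsilon of 1e-8, though Optax's exact defaults and epsilon placement may differ. The toy data sizes are hypothetical.

```python
import numpy as np

def log_softmax(logits):
    z = logits - logits.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def loss_and_grads(w, b, x, y, num_classes=10):
    """Mean softmax cross-entropy of the linear classifier and its gradients."""
    logits = x @ w + b
    log_probs = log_softmax(logits)
    labels = np.eye(num_classes)[y]                  # one-hot, like jax.nn.one_hot
    loss = -np.mean(np.sum(labels * log_probs, axis=-1))
    dlogits = (np.exp(log_probs) - labels) / len(x)  # d loss / d logits
    # b has shape (1,) as in test_step; it shifts every logit equally,
    # so under softmax its gradient is (numerically) zero
    return loss, x.T @ dlogits, np.array([dlogits.sum()])

def rmsprop_update(param, grad, nu, lr=1e-3, decay=0.9, eps=1e-8):
    """Textbook RMSProp (assumed hyperparameters): scale by a running RMS of grads."""
    nu = decay * nu + (1.0 - decay) * grad ** 2
    return param - lr * grad / (np.sqrt(nu) + eps), nu

rng = np.random.default_rng(0)
x = rng.random((64, 20))              # hypothetical toy data: 64 samples, 20 features
y = rng.integers(0, 10, size=64)
w = rng.uniform(size=(20, 10))        # uniform init, as in test_step
b = rng.uniform(size=(1,))
nu_w, nu_b = np.zeros_like(w), np.zeros_like(b)

losses = []
for step in range(100):
    loss, dw, db = loss_and_grads(w, b, x, y)
    losses.append(loss)
    w, nu_w = rmsprop_update(w, dw, nu_w)
    b, nu_b = rmsprop_update(b, db, nu_b)

print(losses[0], losses[-1])  # the loss shrinks over the 100 updates
```

Because each update divides the gradient by its running RMS, no single step moves a weight by more than `lr / sqrt(1 - decay)` (about 0.0032 here), whatever the raw gradient scale.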
3. Train the model using the `fit` method:

```python
model.fit(
    x=X_train,
    y=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[elegy.callbacks.TensorBoard("summaries")],
)
```
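Both `fit` calls in this README use the same settings. The arithmetic below assumes Keras-style semantics, in which `steps_per_epoch` sets how many batches make up one epoch. `shuffled_batches` is a hypothetical NumPy sketch of what `shuffle=True` and `batch_size` imply for one pass over array data, not Elegy's data pipeline.

```python
import numpy as np

# Settings used in the fit calls above
epochs, steps_per_epoch, batch_size = 100, 200, 64

samples_per_epoch = steps_per_epoch * batch_size  # 12,800 samples per epoch
total_steps = epochs * steps_per_epoch            # 20,000 optimizer updates

def shuffled_batches(x, y, batch_size, rng):
    """Yield (x, y) minibatches from a fresh random permutation of the data."""
    order = rng.permutation(len(x))
    for start in range(0, len(x), batch_size):
        idx = order[start:start + batch_size]
        yield x[idx], y[idx]

# Toy data: 10 samples, so a batch size of 4 gives batches of 4, 4, and 2
x = np.arange(10)
y = 2 * x
batches = list(shuffled_batches(x, y, batch_size=4, rng=np.random.default_rng(0)))
print([len(xb) for xb, _ in batches])  # [4, 4, 2]
```

Shuffling indices rather than the arrays themselves keeps each `x` row paired with its `y` label while still changing the batch composition every epoch.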

Using JAX Frameworks

It is straightforward to integrate other functional JAX libraries with this low-level API. For example, with a Flax module:

```python
import elegy
import jax
import jax.numpy as jnp

class LinearClassifier(elegy.Model):
    def test_step(
        self, x, y_true, states: elegy.States, initializing: bool
    ):
        rng: elegy.RNGSeq = states.rng
        x = jnp.reshape(x, (x.shape[0], -1)) / 255
        # self.module is assumed to be a Flax Linen module passed as `module=`
        if initializing:
            logits, variables = self.module.init_with_output(
                {"params": rng.next(), "dropout": rng.next()}, x
            )
        else:
            variables = dict(params=states.net_params, **states.net_states)
            logits, variables = self.module.apply(
                variables, x, rngs={"dropout": rng.next()}, mutable=True
            )
        # split trainable parameters from the remaining collections
        net_states, net_params = variables.pop("params")

        labels = jax.nn.one_hot(y_true, 10)
        loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)

        logs = dict(accuracy=accuracy, loss=loss)
        return loss, logs, states.update(net_params=net_params, net_states=net_states)
```
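In the Flax example, `variables` holds every Flax collection: trainable weights under `"params"` and any mutable state (for example BatchNorm statistics) under other keys. On a Flax `FrozenDict`, `variables.pop("params")` returns the remaining collections first and the popped value second. That is why the tuple is unpacked as `net_states, net_params`. The plain-dict sketch below mirrors that split and the `dict(params=..., **net_states)` merge. The helper names and the `batch_stats` contents are illustrative assumptions.

```python
def split_params(variables):
    """Return (other collections, params), like FrozenDict.pop("params")."""
    net_states = {k: v for k, v in variables.items() if k != "params"}
    return net_states, variables["params"]

def merge_params(net_params, net_states):
    """Rebuild the variables dict passed to module.apply."""
    return dict(params=net_params, **net_states)

variables = {
    "params": {"Dense_0": {"kernel": [[0.1, 0.2]], "bias": [0.0, 0.0]}},
    "batch_stats": {"mean": [0.0, 0.0], "var": [1.0, 1.0]},  # hypothetical state
}
net_states, net_params = split_params(variables)
print(sorted(net_states))  # ['batch_stats']
```

Keeping parameters and state apart lets the optimizer update only `net_params`, while `net_states` is carried forward unchanged between steps.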

More Info

Examples

To run the examples, first install the required packages:

```bash
pip install -r examples/requirements.txt
```

Then run an example, for instance:

```bash
python examples/flax_mnist_vae.py
```

Contributing

Deep learning is evolving at an incredible pace; there is so much to do and so few hands. If you wish to contribute anything from a loss or metric to a new awesome feature for Elegy, just open an issue or send a PR! For more information, check out our Contributing Guide.

About Us

We are a group of friends passionate about ML.

License

Elegy is released under the Apache License 2.0.

Citing Elegy

To cite this project:

```bibtex
@software{elegy2020repository,
    author = {PoetsAI},
    title = {Elegy: A framework-agnostic Trainer interface for the Jax ecosystem},
    url = {https://github.com/poets-ai/elegy},
    version = {0.6.0},
    year = {2020},
}
```

Here, the current version may be retrieved either from the Release tag or from the file `elegy/__init__.py`, and the year corresponds to the project's release year.
