When thinking of Machine Learning, the first frameworks that come to mind are TensorFlow and PyTorch, which are currently the state-of-the-art frameworks for working with Deep Neural Networks. Technology is changing rapidly and more flexibility is needed, so Google researchers are developing a new high-performance framework for the open source community: Flax.

Flax uses JAX, another Google Research project, as the basis for its calculations instead of NumPy. One of the biggest advantages of JAX is its use of XLA (Accelerated Linear Algebra), a domain-specific compiler for linear algebra that also enables execution on GPUs and TPUs.

For those who do not know: a TPU (Tensor Processing Unit) is a chip Google designed specifically for Machine Learning. JAX reimplements large parts of the NumPy API so that your functions can run on a GPU/TPU.
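To make the "JAX reimplements NumPy" point concrete, here is a small sketch (not taken from the article's notebook) that runs the same computation with NumPy and with jax.numpy. Which device JAX picks depends on your installation; on a machine without accelerators it is the CPU. The variable names are illustrative only.

```python
import numpy as onp
import jax
import jax.numpy as jnp

# The same computation, once with NumPy and once with JAX's NumPy API.
x_np = onp.arange(6, dtype=onp.float32).reshape(2, 3)
x_jax = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)

np_result = onp.tanh(x_np) @ x_np.T
jax_result = jnp.tanh(x_jax) @ x_jax.T  # executed on JAX's default device

print(jax.devices())  # the devices JAX found, e.g. a CPU device without accelerators
print(onp.allclose(np_result, onp.asarray(jax_result), atol=1e-4))

# One difference: JAX arrays are immutable, so "in-place" updates return a new array.
updated = x_jax.at[0, 0].set(10.0)
print(x_jax[0, 0], updated[0, 0])
```

The API is close enough that most NumPy code ports by swapping the import, but mutation has to go through the .at[...] syntax.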

Flax focuses on key points like:

  • easy to read code
  • prefers duplication over bad abstractions or bloated functions
  • helpful error messages (it seems they learned from the TensorFlow error messages)
  • easy expandability of basic implementations

Enough praise, let’s start coding.

Because the MNIST example gets boring, I will build an image classifier for the Simpsons family. Unfortunately, Maggie is missing from the dataset 🙁.

Sample Images of the Dataset

First, we install the necessary libraries and unzip our dataset. Unfortunately, you still need TensorFlow at this point, because Flax lacks a good data input pipeline.

pip install -q --upgrade`nvcc -V | sed -En "s/.* release ([0-9]*)\.([0-9]*),.*/cuda\1\2/p"`/jaxlib-0.1.42-`python3 -V | sed -En "s/Python ([0-9]*)\.([0-9]*).*/cp\1\2/p"`-none-linux_x86_64.whl jax
pip install -q git+
pip install tensorflow
pip install tensorflow_datasets

Now we import the libraries. You can see that we have two “versions” of NumPy: the regular NumPy library and the part of its API that JAX reimplements. The print statement prints CPU, GPU or TPU, depending on the available hardware.

from jax.lib import xla_bridge
import jax
import flax

import numpy as onp
import jax.numpy as jnp
import csv
import tensorflow as tf
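import tensorflow_datasets as tfds

# prints 'cpu', 'gpu' or 'tpu' depending on the available hardware
print(xla_bridge.get_backend().platform)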


For training and evaluation, we first have to create two TensorFlow datasets and convert them to NumPy/JAX arrays, because Flax does not accept TF data types. This is currently a bit hacky, because the evaluation method does not take batches.

I had to create one large batch for the eval step and turn it into a feature dictionary of NumPy arrays, which can then be fed to our eval step after each epoch.

def train():

  train_ds = create_dataset(tf.estimator.ModeKeys.TRAIN)
  test_ds = create_dataset(tf.estimator.ModeKeys.EVAL)
  test_ds = test_ds.prefetch(tf.data.experimental.AUTOTUNE)
  # test_ds is one giant batch (assumes the eval split has at most 1000 images)
  test_ds = test_ds.batch(1000)
  # pull that single batch out of the dataset as an (image, label) tuple of tensors
  test_ds = tf.data.experimental.get_single_element(test_ds)
  # convert the tensors to NumPy arrays and build a feature dictionary
  test_ds = tfds.as_numpy(test_ds)
  test_ds = {'image': test_ds[0].astype(jnp.float32), 'label': test_ds[1].astype(jnp.int32)}

  # initialize the parameters from a dummy input of shape (batch, height, width, channels)
  _, initial_params = CNN.init_by_shape(jax.random.PRNGKey(0), [((1, 160, 120, 3), jnp.float32)])

  model = flax.nn.Model(CNN, initial_params)

  optimizer = flax.optim.Momentum(learning_rate=0.01, beta=0.9, weight_decay=0.0005).create(model)

  for epoch in range(50):
    for batch in tfds.as_numpy(train_ds):
      optimizer = train_step(optimizer, batch)

    # optimizer.target holds the current model
    metrics = eval(optimizer.target, test_ds)

    print('eval epoch: %d, loss: %.4f, accuracy: %.2f' % (epoch+1, metrics['loss'], metrics['accuracy'] * 100))
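Building the eval dictionary from a single batch(1000) call only works if the whole eval split fits into that one batch. A minimal NumPy sketch of an alternative, assuming the eval dataset yields (image, label) tuples of NumPy arrays (as the batch[0]/batch[1] indexing in this article suggests): collect all mini-batches and concatenate them into the same {'image', 'label'} dictionary. The helper to_eval_dict and the random fake batches are hypothetical stand-ins for the real tf.data pipeline.

```python
import numpy as onp

def to_eval_dict(batches):
    """Hypothetical helper: merge (image, label) mini-batches into one eval dict."""
    images, labels = zip(*batches)  # split the list of tuples into two tuples
    return {
        'image': onp.concatenate(images).astype(onp.float32),
        'label': onp.concatenate(labels).astype(onp.int32),
    }

# Stand-in for tfds.as_numpy(test_ds.batch(4)): three batches of four 160x120 RGB images.
rng = onp.random.default_rng(0)
fake_batches = [
    (rng.random((4, 160, 120, 3)), rng.integers(0, 4, size=4))
    for _ in range(3)
]

eval_ds = to_eval_dict(fake_batches)
print(eval_ds['image'].shape, eval_ds['label'].shape)  # (12, 160, 120, 3) (12,)
```

The whole eval split still has to fit into memory, so this only removes the fixed batch-size limit, not the "one giant batch" requirement of the eval step.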

The Model

The CNN class contains our convolutional neural network. If you are familiar with TensorFlow/PyTorch, you will see it is pretty straightforward. Every call of flax.nn.Conv defines a learnable kernel.

I started from the MNIST example and extended it with some additional layers. At the end, we have our Dense layer with four output neurons, because we have a four-class problem.

class CNN(flax.nn.Module):
  def apply(self, x):
    x = flax.nn.Conv(x, features=128, kernel_size=(3, 3))
    x = flax.nn.relu(x)
    x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = flax.nn.Conv(x, features=128, kernel_size=(3, 3))
    x = flax.nn.relu(x)
    x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = flax.nn.Conv(x, features=64, kernel_size=(3, 3))
    x = flax.nn.relu(x)
    x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = flax.nn.Conv(x, features=32, kernel_size=(3, 3))
    x = flax.nn.relu(x)
    x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = flax.nn.Conv(x, features=16, kernel_size=(3, 3))
    x = flax.nn.relu(x)
    x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))
    x = flax.nn.Dense(x, features=256)
    x = flax.nn.relu(x)
    x = flax.nn.Dense(x, features=64)
    x = flax.nn.relu(x)
    x = flax.nn.Dense(x, features=4)
    x = flax.nn.softmax(x)
    return x
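It helps to know how many features reach the first Dense layer after the five conv/pool blocks. A small arithmetic sketch, assuming (as I understand Flax's defaults) that flax.nn.Conv uses 'SAME' padding, so the convolutions keep height and width, and that flax.nn.avg_pool uses 'VALID' padding, so each 2x2 pool with stride 2 halves the size and rounds down. The function names are illustrative.

```python
def pooled_size(n, window=2, stride=2):
    # Output length along one axis for a pooling window with 'VALID' padding.
    return (n - window) // stride + 1

def flattened_features(height=160, width=120, num_pool_layers=5, last_conv_features=16):
    # Assumes the convolutions use 'SAME' padding, so only the pooling layers shrink the image.
    for _ in range(num_pool_layers):
        height, width = pooled_size(height), pooled_size(width)
    return height, width, height * width * last_conv_features

print(flattened_features())  # (5, 3, 240)
```

Under these assumptions, a 160x120 input shrinks to 5x3 with 16 channels, so the reshape produces 240 features per image before Dense(256).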

Unlike in TensorFlow, the activation function is called explicitly, which makes it very easy to test new or self-written activation functions. Flax is based on the module abstraction, and both initializing and calling the network are done with the apply function.
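Because an activation is just a function applied to an array, a self-written one only needs jax.numpy. A minimal sketch of a hand-written leaky ReLU (illustrative only; JAX also ships jax.nn.leaky_relu), with its gradient computed by JAX:

```python
import jax
import jax.numpy as jnp

def leaky_relu(x, negative_slope=0.1):
    # Hand-written activation: identity for x >= 0, a small slope for x < 0.
    return jnp.where(x >= 0, x, negative_slope * x)

x = jnp.array([-2.0, -0.5, 0.0, 1.5])
y = leaky_relu(x)
# Being a plain JAX function, it can be differentiated (and jitted) like any other.
g = jax.grad(lambda v: leaky_relu(v).sum())(x)
print(y, g)
```

Inside CNN.apply, such a function could replace a flax.nn.relu(x) call one-for-one, since both simply map an array to an array of the same shape.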

Metrics in Flax

Of course, we want to measure how good our network gets, so we compute metrics like loss and accuracy. The accuracy is computed with jax.numpy instead of NumPy, because JAX can run on a GPU/TPU.

def compute_metrics(logits, labels):
  loss = jnp.mean(cross_entropy_loss(logits, labels))
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  return {'loss': loss, 'accuracy': accuracy}
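The accuracy line in compute_metrics compares the index of the highest output with the label. A tiny worked example on made-up softmax outputs for three images and four classes:

```python
import jax.numpy as jnp

# Made-up softmax outputs for a toy batch of 3 images and 4 classes.
probs = jnp.array([
    [0.7, 0.1, 0.1, 0.1],
    [0.2, 0.5, 0.2, 0.1],
    [0.1, 0.2, 0.3, 0.4],
])
labels = jnp.array([0, 1, 2])

predictions = jnp.argmax(probs, -1)         # index of the highest probability per row
accuracy = jnp.mean(predictions == labels)  # booleans are averaged as 0/1
print(predictions, accuracy)
```

The predictions are 0, 1 and 3, so two of three images are correct and the accuracy is 2/3.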

To measure our loss, we use the cross entropy loss. Unlike in TensorFlow, we compute it ourselves, because Flax does not offer ready-made loss objects yet. As you can see, we use @jax.vmap as a function decorator for our loss function. This vectorizes our code so it runs efficiently on batches.

@jax.vmap
def cross_entropy_loss(logits, label):
  # logits are the softmax outputs for ONE example, label is its class index
  return -jnp.log(logits[label])

How does jax.vmap work? It takes both arrays, logits and label, and applies our cross_entropy_loss to each pair, which allows the whole batch to be computed in parallel. The cross entropy formula for a single example is:

$$L(y, \hat{y}) = -\sum_{i=1}^{4} y_i \log(\hat{y}_i)$$

Our ground truth y is one-hot: y_i is 1 for the correct one of the four output neurons and 0 for the others. Therefore, we do not need the sum in our code, because every term except the one for the correct label vanishes, and we just compute -log(ŷ) for the correct label. The mean in our loss calculation is taken because we work with batches.
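Both claims can be checked numerically: jax.vmap turns the single-example loss into a per-example batch loss, and indexing the correct label gives the same value as the full one-hot sum. A small sketch on made-up probabilities for two examples (jax.nn.one_hot is used only to build y for the comparison):

```python
import jax
import jax.numpy as jnp

@jax.vmap
def cross_entropy_loss(probs, label):
    # Per-example loss: probs has shape (4,), label is a scalar class index.
    return -jnp.log(probs[label])

# Made-up softmax outputs for a batch of 2 examples and 4 classes.
probs = jnp.array([[0.7, 0.1, 0.1, 0.1],
                   [0.2, 0.5, 0.2, 0.1]])
labels = jnp.array([0, 1])

# vmap maps the per-example function over axis 0 of both arguments.
per_example = cross_entropy_loss(probs, labels)

# Full formula: -sum_i y_i * log(y_hat_i) with a one-hot ground truth y.
one_hot = jax.nn.one_hot(labels, 4)
full_formula = -jnp.sum(one_hot * jnp.log(probs), axis=-1)

print(per_example.shape)  # (2,)
print(jnp.allclose(per_example, full_formula))
print(jnp.mean(per_example))  # the batch loss, as computed in the train step
```

Only the -log(0.7) and -log(0.5) terms survive the one-hot multiplication, which is exactly what the indexed version computes.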


In our train step, we again use a function decorator, @jax.jit, to speed up our function: it compiles the function with XLA. This works very similarly to tf.function in TensorFlow. Keep in mind that batch[0] is our image data and batch[1] our label.
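A minimal sketch of what jax.jit does, outside of Flax: the first call traces the function and compiles it with XLA, later calls with the same shapes reuse the compiled code. The predict function and the shapes are hypothetical, and the measured time depends entirely on your hardware.

```python
import time
import jax
import jax.numpy as jnp

def predict(w, x):
    # A small dense layer followed by softmax, written with jax.numpy.
    return jax.nn.softmax(x @ w)

fast_predict = jax.jit(predict)  # equivalent to decorating predict with @jax.jit

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.normal(key_w, (256, 4))
x = jax.random.normal(key_x, (128, 256))

fast_predict(w, x).block_until_ready()  # first call: trace and compile with XLA
start = time.perf_counter()
out = fast_predict(w, x).block_until_ready()  # later calls reuse the compiled code
print(f"jitted call took {time.perf_counter() - start:.6f}s")
print(jnp.allclose(out, predict(w, x), atol=1e-5))
```

block_until_ready() is needed for timing because JAX dispatches work asynchronously.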

@jax.jit
def train_step(optimizer, batch):
  def loss_fn(model):
    logits = model(batch[0])
    loss = jnp.mean(cross_entropy_loss(
        logits, batch[1]))
    return loss
  # differentiate the loss with respect to the current model
  grad = jax.grad(loss_fn)(optimizer.target)
  optimizer = optimizer.apply_gradient(grad)
  return optimizer

The loss function loss_fn returns the loss for our current model, optimizer.target, and jax.grad calculates its gradient. After the calculation, we apply the gradient like in TensorFlow.
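The same grad-then-update pattern can be sketched in pure JAX, without the flax.optim API used in this article. Assumptions: a softmax-regression model stands in for the CNN, the data is random, and the update is plain SGD with a hypothetical learning rate, not the Momentum rule of flax.optim.Momentum.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical value for this toy problem

def loss_fn(params, images, labels):
    # Softmax regression stands in for the CNN; params is a dict, i.e. a pytree.
    probs = jax.nn.softmax(images @ params['w'] + params['b'])
    return -jnp.mean(jnp.log(probs[jnp.arange(labels.shape[0]), labels]))

@jax.jit
def sgd_step(params, images, labels):
    # jax.grad differentiates w.r.t. the first argument and returns a dict shaped like params.
    grads = jax.grad(loss_fn)(params, images, labels)
    # Plain SGD update, not the Momentum rule used by flax.optim.Momentum.
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)

images = jax.random.normal(jax.random.PRNGKey(0), (32, 10))
labels = jnp.arange(32) % 4
params = {'w': jnp.zeros((10, 4)), 'b': jnp.zeros(4)}

before = loss_fn(params, images, labels)
for _ in range(20):
    params = sgd_step(params, images, labels)
after = loss_fn(params, images, labels)
print(before, after)
```

With all-zero parameters every class gets probability 0.25, so the starting loss is log(4); after a few steps it should be lower. In the article, optimizer.apply_gradient plays roughly the role of the tree_map update line.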

The eval step is very simple and minimalistic in Flax. Please note that the complete evaluation dataset is passed to this function.

def eval(model, eval_ds):
  logits = model(eval_ds['image'])
  return compute_metrics(logits, eval_ds['label'])

After 50 epochs, we reach a very high accuracy. Of course, we can continue to tweak the model and optimize the hyperparameters.

For this experiment, I used Google Colab, so if you want to test it yourself, create a new environment with a GPU/TPU and import my notebook from GitHub. Please note that Flax does not work on Windows at the moment.


It is important to note that Flax is currently still in alpha and is not an official Google product.

The work so far gives hope for a fast, lightweight and highly customizable ML framework. What is completely missing so far is a data input pipeline, so TensorFlow still has to be used.

The current set of optimizers is unfortunately limited to Adam and SGD with Momentum. I especially liked how straightforward this framework is to use, and its high flexibility.

My next plans are to develop some activation functions that are not yet available. A speed comparison between TensorFlow, PyTorch and Flax would also be very interesting.
