Thinking of Machine Learning, the first frameworks that come to mind are Tensorflow and PyTorch, which are currently the state-of-the-art frameworks for working with Deep Neural Networks. Technology is changing rapidly and more flexibility is needed, so Google researchers are developing a new high-performance framework for the open source community: Flax.

Flax uses JAX, which is also a Google Research project, instead of NumPy as the basis for its calculations. One of the biggest advantages of JAX is its use of XLA (Accelerated Linear Algebra), a domain-specific compiler for linear algebra that **enables execution on GPUs and TPUs as well**.

For those who do not know: a TPU (Tensor Processing Unit) is a chip designed specifically for Machine Learning workloads. JAX reimplements parts of the NumPy API so that your functions can run on a GPU/TPU.
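To make this concrete, here is a minimal sketch that runs the same small computation once with regular NumPy and once with `jax.numpy`. The array values are arbitrary illustration data; the point is only that the two APIs look the same while JAX dispatches the work to whatever device it finds.

```python
import numpy as onp
import jax
import jax.numpy as jnp

# The same computation, once with NumPy and once with jax.numpy.
x_np = onp.arange(6, dtype=onp.float32).reshape(2, 3)
x_jnp = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)

np_result = onp.sum(onp.tanh(x_np) * x_np, axis=1)     # computed by NumPy on the CPU
jnp_result = jnp.sum(jnp.tanh(x_jnp) * x_jnp, axis=1)  # computed by XLA on the default JAX device

print(jax.devices())  # the devices JAX found (CPU, GPU or TPU)
print(onp.allclose(np_result, onp.asarray(jnp_result), atol=1e-5))
```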

Flax focuses on key points like:

- **easy** to read code
- **prefers duplication** instead of bad abstraction or inflated functions
- **helpful error messages**; it seems they learned from the Tensorflow error messages
- easy expandability of basic implementations

Enough praise, now let’s start coding.

Because the MNIST example gets boring, I will build an image classifier for the Simpsons family. Unfortunately, Maggie is missing from the dataset 🙁.

**Sample Images of the Dataset**

First, we install the necessary libraries and unzip our dataset. Unfortunately, you still need Tensorflow at this point because Flax lacks a good data input pipeline.

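```
# Install a jaxlib wheel matching the local CUDA and Python versions, plus JAX
pip install -q --upgrade https://storage.googleapis.com/jax-releases/`nvcc -V | sed -En "s/.* release ([0-9]*)\.([0-9]*),.*/cuda\1\2/p"`/jaxlib-0.1.42-`python3 -V | sed -En "s/Python ([0-9]*)\.([0-9]*).*/cp\1\2/p"`-none-linux_x86_64.whl jax
# Install Flax from GitHub
pip install -q git+https://github.com/google/flax.git@dev-setup
# Tensorflow and TFDS are only needed for the input pipeline
pip install tensorflow
pip install tensorflow_datasets
# Extract the Simpsons faces dataset
unzip simpsons_faces.zip
```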

Now we import the libraries. As you can see, we have two “versions” of NumPy: the regular NumPy library, imported as `onp`, and the part of its API that JAX reimplements, imported as `jnp`. The print statement outputs `cpu`, `gpu` or `tpu`, depending on the available hardware.

```
from jax.lib import xla_bridge
import jax
import flax
import numpy as onp
import jax.numpy as jnp
import csv
import tensorflow as tf
import tensorflow_datasets as tfds
print(xla_bridge.get_backend().platform)
```

For training and evaluation, we first create two Tensorflow datasets and convert them to NumPy/JAX arrays, because Flax doesn’t accept TF data types. This is currently a bit hacky, because the evaluation function doesn’t take batches.

I had to create one large batch for the eval step and turn it into a feature dictionary of NumPy arrays (`'image'` and `'label'`), which can then be fed to our eval step after each epoch.

```
def train():
    # Input pipelines for training and evaluation (create_dataset is not shown in this post)
    train_ds = create_dataset(tf.estimator.ModeKeys.TRAIN)
    test_ds = create_dataset(tf.estimator.ModeKeys.EVAL)
    test_ds = test_ds.prefetch(tf.data.experimental.AUTOTUNE)
    # test_ds becomes one giant batch (assumes at most 1000 test images)
    test_ds = test_ds.batch(1000)
    # Pull out that single batch and convert it to NumPy arrays
    test_ds = tf.compat.v1.data.experimental.get_single_element(test_ds)
    test_ds = tfds.as_numpy(test_ds)
    # Build the feature dictionary expected by our eval step
    test_ds = {'image': test_ds[0].astype(jnp.float32), 'label': test_ds[1].astype(jnp.int32)}
    # Initialize the parameters from a dummy input of shape (batch, height, width, channels)
    _, initial_params = CNN.init_by_shape(jax.random.PRNGKey(0), [((1, 160, 120, 3), jnp.float32)])
    model = flax.nn.Model(CNN, initial_params)
    optimizer = flax.optim.Momentum(learning_rate=0.01, beta=0.9, weight_decay=0.0005).create(model)
    for epoch in range(50):
        for batch in tfds.as_numpy(train_ds):
            optimizer = train_step(optimizer, batch)
        # Evaluate on the full test set after each epoch
        metrics = eval(optimizer.target, test_ds)
        print('eval epoch: %d, loss: %.4f, accuracy: %.2f' % (epoch + 1, metrics['loss'], metrics['accuracy'] * 100))
```

## The Model

The `CNN` class contains our convolutional neural network. If you are familiar with Tensorflow/PyTorch, you will see that it’s pretty straightforward. Every call of `flax.nn.Conv` defines a convolutional layer with its own learnable kernel.

I took the MNIST example and extended it with some additional layers. At the end, we have a dense layer with four output neurons, because we have a four-class problem.

```
class CNN(flax.nn.Module):
    def apply(self, x):
        x = flax.nn.Conv(x, features=128, kernel_size=(3, 3))
        x = flax.nn.relu(x)
        x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = flax.nn.Conv(x, features=128, kernel_size=(3, 3))
        x = flax.nn.relu(x)
        x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = flax.nn.Conv(x, features=64, kernel_size=(3, 3))
        x = flax.nn.relu(x)
        x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = flax.nn.Conv(x, features=32, kernel_size=(3, 3))
        x = flax.nn.relu(x)
        x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = flax.nn.Conv(x, features=16, kernel_size=(3, 3))
        x = flax.nn.relu(x)
        x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))
        x = flax.nn.Dense(x, features=256)
        x = flax.nn.relu(x)
        x = flax.nn.Dense(x, features=64)
        x = flax.nn.relu(x)
        x = flax.nn.Dense(x, features=4)
        x = flax.nn.softmax(x)
        return x
```
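As a sanity check on the architecture, the sketch below traces the 160×120 input through the five conv/pool stages to find how many features reach the first dense layer. It assumes that `flax.nn.Conv` keeps the spatial size ('SAME' padding, which I believe is its default) and that `flax.nn.avg_pool` uses 'VALID' padding; `conv_pool_output_shape` is a hypothetical helper, not part of Flax.

```python
def conv_pool_output_shape(height, width, channels_per_stage, pool=2):
    """Trace the spatial size through conv + avg_pool stages.

    Assumes every convolution uses 'SAME' padding (size unchanged) and every
    2x2 average pool with stride 2 uses 'VALID' padding.
    """
    for _ in channels_per_stage:
        # 'SAME' convolution: height and width stay as they are.
        # 'VALID' pooling: (n - window) // stride + 1, i.e. floor(n / 2) here.
        height = (height - pool) // pool + 1
        width = (width - pool) // pool + 1
    return height, width, channels_per_stage[-1]

# Feature counts of the five conv layers in the CNN above.
shape = conv_pool_output_shape(160, 120, [128, 128, 64, 32, 16])
flat_features = shape[0] * shape[1] * shape[2]
print(shape, flat_features)  # (5, 3, 16) 240
```

Under these assumptions, the reshape hands a 240-dimensional vector to `flax.nn.Dense(x, features=256)`; the odd width of 15 after three pools is where a column gets dropped.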

Unlike Tensorflow’s Keras layers, where the activation is usually passed as a layer argument, FLAX calls the activation function explicitly. This makes it very easy to test new or self-written activation functions. FLAX is based on the module abstraction, and both initializing and calling the network are done with the `apply` function.
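Because an activation is just a function on arrays, a self-written one needs nothing more than `jax.numpy`. The following is a minimal sketch; `swish` is a hypothetical example activation (x · sigmoid(x)) chosen for illustration, not something the model above uses.

```python
import jax
import jax.numpy as jnp

def swish(x):
    """Hypothetical self-written activation: x * sigmoid(x)."""
    return x * jax.nn.sigmoid(x)

x = jnp.array([-2.0, 0.0, 2.0])
print(swish(x))  # applied element-wise, like flax.nn.relu

# Plain jax.numpy code is differentiable out of the box.
dswish = jax.grad(swish)  # jax.grad needs a scalar output, so call it on a scalar
print(dswish(0.0))  # sigmoid(0) + 0 * sigmoid'(0) = 0.5
```

Inside `CNN.apply`, trying it would mean replacing a line such as `x = flax.nn.relu(x)` with `x = swish(x)`.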

## Metrics in FLAX

Of course, we want to measure how well our network performs, so we compute metrics such as loss and accuracy. Both are computed with `jax.numpy` instead of NumPy, because JAX code can run on the GPU/TPU.

```
def compute_metrics(logits, labels):
    loss = jnp.mean(cross_entropy_loss(logits, labels))
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return {'loss': loss, 'accuracy': accuracy}
```
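To see what the accuracy line does, here is a small worked example with made-up softmax outputs for three images; it reuses only the `jnp.argmax`/`jnp.mean` expression from `compute_metrics`.

```python
import jax.numpy as jnp

# Made-up softmax outputs for a batch of 3 images and 4 classes.
probs = jnp.array([
    [0.70, 0.10, 0.10, 0.10],  # predicted class 0
    [0.20, 0.50, 0.20, 0.10],  # predicted class 1
    [0.10, 0.20, 0.60, 0.10],  # predicted class 2
])
labels = jnp.array([0, 1, 3])  # the last prediction is wrong

predicted = jnp.argmax(probs, -1)         # index of the largest value per row
accuracy = jnp.mean(predicted == labels)  # mean of [True, True, False] = 2/3
print(predicted, accuracy)
```

Comparing with `==` gives a boolean array, and its mean is the fraction of correct predictions.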

We use `@jax.vmap` as a function decorator for our loss function. This vectorizes our code so that it runs efficiently on batches.

```
@jax.vmap
def cross_entropy_loss(logits, label):
    # logits: softmax probabilities of one example, label: its class index
    return -jnp.log(logits[label])
```

How does `cross_entropy_loss` work? `@jax.vmap` takes both arrays, logits and label, and applies our `cross_entropy_loss` to each pair, thus allowing the parallel calculation of a batch. The cross-entropy formula for a single example with four classes is:

$$H(y, \hat{y}) = -\sum_{i=1}^{4} y_i \log(\hat{y}_i)$$

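Our ground truth $y$ is one-hot encoded: it is 1 for the output neuron of the correct class and 0 for the other three. Therefore, we do not need the sum in our code: every term except the one for the correct label vanishes, so we just calculate $-\log(\hat{y}_i)$ for the correct class $i$. The mean in our loss calculation averages the per-example losses over the batch.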
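The argument above can be checked numerically. This sketch, with hypothetical function names and made-up probabilities, compares the full one-hot sum with simply indexing the correct class; both versions are vectorized over a batch of two examples with `jax.vmap`.

```python
import jax
import jax.numpy as jnp

@jax.vmap
def indexed_loss(probs, label):
    # Same idea as cross_entropy_loss: pick the probability of the correct class.
    return -jnp.log(probs[label])

@jax.vmap
def full_cross_entropy(probs, label):
    # Textbook form: -sum_i y_i * log(y_hat_i) with a one-hot y.
    y = jax.nn.one_hot(label, probs.shape[-1])
    return -jnp.sum(y * jnp.log(probs))

probs = jnp.array([[0.7, 0.1, 0.1, 0.1],
                   [0.2, 0.5, 0.2, 0.1]])
labels = jnp.array([0, 1])

a = indexed_loss(probs, labels)        # one loss per example, shape (2,)
b = full_cross_entropy(probs, labels)  # same values via the one-hot sum
print(a, b, jnp.allclose(a, b))
```

Inside each vmapped function, `probs` has shape `(4,)` and `label` is a scalar; `jax.vmap` takes care of mapping over the leading batch axis.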

## Training

For the training step we use the decorator `@jax.jit`, which compiles our function with XLA to speed it up. This works very similarly to `tf.function` in Tensorflow. Keep in mind that `batch[0]` is our image data and `batch[1]` our label.

```
@jax.jit
def train_step(optimizer, batch):
    def loss_fn(model):
        logits = model(batch[0])
        loss = jnp.mean(cross_entropy_loss(
            logits, batch[1]))
        return loss
    grad = jax.grad(loss_fn)(optimizer.target)
    optimizer = optimizer.apply_gradient(grad)
    return optimizer
```

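Our `loss_fn` receives the current model, `optimizer.target`, and `jax.grad()` calculates the gradient of the loss with respect to it. After the calculation, we apply the gradient with `optimizer.apply_gradient`, much like in Tensorflow.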

The eval step is very simple and minimalistic in Flax. Please note that the complete evaluation dataset is passed to this function.

```
@jax.jit
def eval(model, eval_ds):
    logits = model(eval_ds['image'])
    return compute_metrics(logits, eval_ds['label'])
```

After 50 epochs, we reach a very high accuracy. Of course, we can continue to tweak the model and optimize the hyperparameters.

## Conclusions

It is important to note that **FLAX is currently still in alpha** and is not an official Google product.

The work so far gives hope for a **fast, lightweight and highly customizable ML framework**. What is completely missing so far is a data-input pipeline, so Tensorflow still has to be used.

The current set of optimizers is unfortunately limited to Adam and SGD with Momentum. I especially liked how straightforward this framework is to use, and its high flexibility.

My next plans are to develop some activation functions that are not yet available. A speed comparison between Tensorflow, PyTorch and FLAX would also be very interesting.