Learning JAX
Tutorial link
Aim of this task: train a Flax model on MNIST to understand the basic Flax structure.
Task 1: MNIST
- flax.nnx: similar to torch.nn
- nnx.Module acts in the same way as torch.nn.Module
- Utilise __call__ for the forward pass rather than forward (see the module sketch after the example below)
- In the tutorial, they use non-Flax functions (e.g. from functools); these have to be JAX-compatible and must not contain parameters which require gradients. There are, however, JAX versions provided, which I will likely use in future to avoid confusion.
- Define the loss function, training step, etc. as plain functions and then JIT-compile them (the tutorial uses nnx.jit); a loss_fn sketch follows the example below.
- Very functional… example:
```python
@nnx.jit  # JIT-compile the whole training step
def train_step(model: CNN, optimiser: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
    # Differentiate loss_fn w.r.t. the model; has_aux=True means it also returns logits.
    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(model, batch)
    metrics.update(loss=loss, logits=logits, labels=batch['label'])  # update metrics in place
    optimiser.update(grads)  # apply the gradient update in place
```
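
For reference, a minimal nnx.Module sketch with the forward pass in __call__ rather than forward. This is my own illustration, not the tutorial's model: the layer sizes and pooling choices are assumptions.

```python
from flax import nnx

class CNN(nnx.Module):
    """Tiny CNN for MNIST; the layer sizes here are illustrative assumptions."""

    def __init__(self, *, rngs: nnx.Rngs):
        self.conv = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
        self.linear = nnx.Linear(32 * 14 * 14, 10, rngs=rngs)

    def __call__(self, x):  # forward pass lives in __call__, not forward()
        x = nnx.relu(self.conv(x))                                 # (N, 28, 28, 32)
        x = nnx.avg_pool(x, window_shape=(2, 2), strides=(2, 2))   # (N, 14, 14, 32)
        x = x.reshape(x.shape[0], -1)                              # flatten to (N, 6272)
        return self.linear(x)                                      # logits, (N, 10)
```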
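
train_step above closes over a loss_fn that my notes don't define. A sketch consistent with the has_aux=True pattern (loss first, logits as the auxiliary output), using optax's integer-label cross-entropy; the batch keys 'image'/'label' are assumed to match the tutorial's data loader:

```python
import optax

def loss_fn(model: CNN, batch):
    logits = model(batch['image'])
    # Mean cross-entropy against integer class labels.
    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=batch['label']
    ).mean()
    return loss, logits  # (loss, aux): the logits feed metrics.update
```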
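
Finally, a hedged end-to-end usage sketch. The optimiser and metrics constructors follow the tutorial's pattern, but the learning rate and the dummy batch are placeholders, and the nnx.Optimizer API has shifted between Flax versions, so check against the version the tutorial pins:

```python
import jax.numpy as jnp
import optax
from flax import nnx

model = CNN(rngs=nnx.Rngs(0))
optimiser = nnx.Optimizer(model, optax.adam(1e-3))  # placeholder learning rate
metrics = nnx.MultiMetric(
    accuracy=nnx.metrics.Accuracy(),
    loss=nnx.metrics.Average('loss'),
)

# Dummy batch standing in for a real MNIST data loader.
batch = {
    'image': jnp.zeros((32, 28, 28, 1)),
    'label': jnp.zeros((32,), dtype=jnp.int32),
}

train_step(model, optimiser, metrics, batch)
print(metrics.compute())  # {'accuracy': ..., 'loss': ...}
```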