Training Axon Models With Nvidia GPUs

Published April 29, 2023 by Toran Billups

A few months back I started a deep dive into machine learning. With all the excitement around Nx, I spent the first few weeks building a toy example that solves FizzBuzz, first with Axon and later with Nx. After getting more familiar with Axon, I started tinkering with the BERT fine-tuning example and found that the feedback loop was over 45 minutes per run.

I didn't have a huge budget, but I knew a previous-generation Nvidia card like the RTX 3060 would shorten that turnaround time and let me train models more quickly. After looking at some benchmarks and considering a few alternatives, I decided to order the 12GB model and take it for a spin.

I started by looking at Nvidia support for Linux and decided to install Pop!_OS. Next, I installed Elixir with asdf to get a working dev environment before attempting to optimize it further. On a vanilla install of Pop!_OS I found that no Nvidia drivers were installed by default, so I had to list the available drivers, install the latest stable one, and then add the CUDA and cuDNN packages.

    sudo ubuntu-drivers list
    sudo ubuntu-drivers install nvidia-driver-525
    sudo apt install system76-cuda-11.2 system76-cudnn-11.2

Finally, I exported two environment variables: XLA_TARGET selects the CUDA build of XLA that Nx will use, and XLA_FLAGS tells XLA where the installed CUDA toolkit lives.

    export XLA_TARGET=cuda111
    export XLA_FLAGS=--xla_gpu_cuda_data_dir=/usr/lib/cuda-11.2
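Because a typo in either variable tends to surface only later, when Nx tries to reach the GPU, a quick check before starting a training session can save some time. The sketch below is a hypothetical helper (the name check_xla_env is illustrative, not part of Nx or XLA). It assumes XLA_FLAGS holds exactly the single --xla_gpu_cuda_data_dir=<path> flag shown above, and it only verifies that XLA_TARGET is set and that the CUDA path exists. It does not confirm that the driver or GPU actually works.

```shell
#!/usr/bin/env bash
# Hypothetical helper: sanity-check XLA_TARGET and XLA_FLAGS before launching Elixir.
# Assumes XLA_FLAGS contains only "--xla_gpu_cuda_data_dir=<path>".
check_xla_env() {
  # XLA_TARGET must be set (e.g. cuda111 for the CUDA 11.2 install above)
  if [ -z "${XLA_TARGET:-}" ]; then
    echo "XLA_TARGET is not set" >&2
    return 1
  fi

  # Strip the flag prefix to recover the CUDA directory path
  local cuda_dir="${XLA_FLAGS#--xla_gpu_cuda_data_dir=}"

  # If nothing was stripped, the flag is missing; otherwise the directory must exist
  if [ "$cuda_dir" = "${XLA_FLAGS:-}" ] || [ ! -d "$cuda_dir" ]; then
    echo "XLA_FLAGS does not point at an existing CUDA directory" >&2
    return 1
  fi

  echo "XLA_TARGET=$XLA_TARGET, CUDA dir=$cuda_dir"
}
```

Usage: source the file, export the two variables as above, and run check_xla_env. A non-zero exit status means one of the variables needs fixing before starting iex.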

With this Elixir workstation, you can generate training data, train your models, and even serve them with Nx!

Despite all the promise and obvious speed improvements this GPU has to offer, I found that the fine-tuning example I started my journey with throws out-of-memory errors during the second epoch, because memory usage jumps to 32GB with this specific BERT model.

I was, however, able to complete the fine-tuning example with a slightly smaller BERT variant. Here is the full source for that Elixir module for those interested.
