BERT from scratch with Nx

Published June 22, 2025 by Toran Billups

The more time I spend in machine learning, the more my curiosity grows. That curiosity turned into action last fall when I encountered a truly personal need: I wanted better tools to find what I was looking for in Scripture.

So often I find myself searching for half-remembered phrases or themes—"number our days" when I meant "number my days," or hunting for passages about a topic without knowing the exact words. And while this might seem small, this deeply personal problem set the stage for a nearly year-long journey.

The Search Problem

I typically flip between the YouVersion Bible app and Dwell because each has its strengths and weaknesses. I was shocked to learn that Dwell in particular failed spectacularly at search, returning zero results for the vast majority of my queries. At that point I realized that the gap between what I was searching for and what the software could find represented a fascinating technical problem: how do you build search that understands both keywords and meaning?

What started as a search problem quickly turned into the most technically challenging adventure of my software career. I didn't know it then, but I would go on to build a BERT-like encoder from scratch—not just to solve my immediate need, but to learn how embeddings work at a more fundamental level. While the end result was fascinating, getting there was a masterclass in making every possible mistake.

The Vision

My goal was straightforward: build vector search with my home-grown encoder. I wanted to understand embeddings from the ground up and build a truly useful search experience for combing through the New Testament.

The constraint that made this interesting was my decision to use only two dependencies beyond Nx: Polaris and Tokenizers. Everything else would be a vanilla implementation, forcing me to understand each component intimately.

I started by creating a dataset from a few different translations of the Bible, chunked into 100-token examples. I used masked language modeling, BERT's pre-training task, to learn embeddings from my biblical text corpus. This training process revolves around having the model guess masked tokens in a sequence and updating the weights after calculating cross-entropy loss.
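
In sketch form, the masking step looks something like this. The 15% rate and the [MASK] token id of 103 come from the standard BERT setup, and I'm skipping BERT's 80/10/10 refinement where some selected tokens are left unchanged or swapped for random ones.

```elixir
defmodule MLM do
  import Nx.Defn

  @mask_token_id 103   # [MASK] in the BERT WordPiece vocab
  @mask_prob 0.15

  # Pick ~15% of positions, replace them with [MASK] in the inputs, and keep
  # the original ids as labels so the loss is only computed where we masked.
  defn mask_batch(token_ids, key) do
    {u, key} = Nx.Random.uniform(key, shape: Nx.shape(token_ids))
    masked = Nx.less(u, @mask_prob)

    inputs = Nx.select(masked, @mask_token_id, token_ids)
    labels = Nx.select(masked, token_ids, -100)   # -100 = ignore in the loss
    {inputs, labels, key}
  end
end
```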

With the original BERT paper in view, I was equally inspired and determined to reproduce their success on a narrow and much smaller text corpus. I used the BERT tokenizer and kept much of that breakthrough research, including 768 dimensions for the output vectors. One notable difference was my training sequence length of 104 tokens (think context length). This worked well for my search use case because the longest New Testament verse runs roughly 100-110 tokens, and the average is much lower.
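
For reference, the headline hyperparameters look roughly like this (the layer and head counts are covered later in the post, and 30,522 is the size of the BERT WordPiece vocabulary):

```elixir
config = %{
  vocab_size: 30_522,   # BERT WordPiece vocabulary
  hidden_size: 768,     # output embedding dimensionality, same as BERT base
  max_seq_len: 104,     # comfortably covers the longest New Testament verse
  num_layers: 11,
  num_heads: 12,
  dropout: 0.143
}
```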

Fail Fast

I documented the journey as I went, and looking back I realize it was a series of mistakes and lessons. But as with all of my adventures, this was the most interesting part of the story, so here are the more memorable notes and what I've learned since.

Start Small

The biggest trap I fell into early on was assuming the compiler would find bugs for me. I jumped straight into training runs with complex architectures, convinced that more layers meant better results. I was so optimistic that this would be like any other software hobby project I've done over the years. Unfortunately, that familiar mindset was shaken by a reality that lives in the details.

I should have started by trying to memorize 100 tokens over 10 epochs to validate basic assumptions instead of sprinting ahead to add another layer, more attention heads, and so on. There are no shortcuts—you need to verify your math step-by-step on a tiny dataset before scaling up.

I spent weeks debugging issues that would have been obvious with a simple setup. Start with one layer, confirm the model can memorize and overfit a tiny batch, then slowly add capacity. Generalization comes later; this burden of proof gives you a much more stable foundation to build on.
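
Here's the check I wish I had run from day one, in sketch form—`train_step/3`, `params`, `opt_state`, and `batches` are stand-ins for your own model and optimizer plumbing:

```elixir
# Grab one tiny batch and make sure the model can memorize it before you
# add layers or heads.
tiny_batch = hd(batches)

Enum.reduce(1..10, {params, opt_state}, fn epoch, {params, opt_state} ->
  {params, opt_state, loss} = train_step.(params, opt_state, tiny_batch)
  IO.puts("epoch #{epoch}: loss #{Nx.to_number(loss)}")   # should head toward ~0
  {params, opt_state}
end)
```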

Variable Length

After I had simple memorization working, I quickly fell into a performance issue that tormented me for weeks. What I wish I had known then is that variable-length sequences require you to add pad tokens. The trouble wasn't the variability but that I didn't update my attention mechanism to skip these low-value positions.

The attention mechanism was literally attending to padding tokens, learning patterns from nonsense data. This single masking issue was hiding behind what appeared to be architecture problems, leading me down countless wrong paths. Once I properly implemented padding masks, performance jumped dramatically.
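
A simplified sketch of the fix in Nx (single head, projections omitted): the pad mask pushes the scores for padded key positions to a large negative number before the softmax, so they end up with effectively zero attention weight.

```elixir
defmodule Attention do
  import Nx.Defn

  # q, k, v: {batch, seq, head_dim}; pad_mask: {batch, seq} with 1 for real
  # tokens and 0 for padding.
  defn scaled_dot_product(q, k, v, pad_mask) do
    d_k = Nx.axis_size(k, -1)

    scores =
      q
      |> Nx.dot([2], [0], k, [2], [0])
      |> Nx.divide(Nx.sqrt(d_k))

    # Real tokens add 0 to their scores; padded keys add -1.0e9, which
    # softmax turns into ~0 weight.
    mask = Nx.new_axis(pad_mask, 1)   # {batch, 1, seq}
    scores = scores + (mask - 1) * 1.0e9

    weights = softmax(scores)
    Nx.dot(weights, [2], [0], v, [1], [0])
  end

  defnp softmax(t) do
    max = Nx.reduce_max(t, axes: [-1], keep_axes: true)
    e = Nx.exp(t - max)
    e / Nx.sum(e, axes: [-1], keep_axes: true)
  end
end
```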

Gradient Clipping

For the longest time, I never computed or printed my gradients to evaluate them. I was hypnotized by my training and validation loss curves, completely ignoring the underlying mathematics.

I spent some time digging into the Polaris library so I could compute gradients, and with that additional information I was finally able to see that I had very large gradients during training, which gave me erratic and inconsistent results at best. From there, I read about gradient clipping and discovered that the original BERT team clipped the global gradient norm at 1.0—meaning the gradients are simply scaled down whenever their combined norm exceeds that threshold.

This single change transformed my erratic training into steady, predictable progress. The learning: your gradients are telling you a story. Listen to them.
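
Here's the idea spelled out by hand, assuming a flat map of gradient tensors (mine were nested per layer, which just means an extra traversal):

```elixir
defmodule Clip do
  # Global-norm clipping: if the combined L2 norm of all gradients exceeds
  # max_norm, scale every gradient down by the same factor.
  def by_global_norm(grads, max_norm \\ 1.0) do
    global_norm =
      grads
      |> Map.values()
      |> Enum.map(fn g -> g |> Nx.multiply(g) |> Nx.sum() end)
      |> Enum.reduce(&Nx.add/2)
      |> Nx.sqrt()

    scale = Nx.min(1.0, Nx.divide(max_norm, global_norm))
    Map.new(grads, fn {name, g} -> {name, Nx.multiply(g, scale)} end)
  end
end
```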

Small Wins Compound

Several smaller changes had outsized impacts on performance:

Positional Encoding

Early on, I had a hyperparameter nightmare trying to optimize learned positional embeddings. I later discovered I could use the simpler sinusoidal approach, which improved performance yet again while eliminating dozens of hyperparameters.
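
The table itself is only a few lines of Nx, computed once and added to the token embeddings. This sketch concatenates the sin and cos halves rather than interleaving them, a common simplification:

```elixir
defmodule Positional do
  # PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
  # PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
  def sinusoidal(seq_len \\ 104, d_model \\ 768) do
    pos = Nx.iota({seq_len, 1}, type: :f32)
    i = Nx.iota({1, div(d_model, 2)}, type: :f32)

    angles = Nx.divide(pos, Nx.pow(10_000.0, Nx.divide(Nx.multiply(2.0, i), d_model)))

    Nx.concatenate([Nx.sin(angles), Nx.cos(angles)], axis: 1)
  end
end
```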

GeLU

At some point along the way, I made the switch from ReLU to GeLU, which gave a noticeable performance increase. I didn't find examples of this implementation in the wild with Nx, so I had to hammer it out with help from Gemini. The smooth gradients of GeLU proved much more effective for my use case.
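
For anyone else hunting for an Nx version, the exact error-function form is tiny—this is a sketch of the idea rather than a drop-in from my codebase:

```elixir
defmodule Activations do
  import Nx.Defn

  # gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
  defn gelu(x) do
    0.5 * x * (1.0 + Nx.erf(x / Nx.sqrt(2.0)))
  end
end
```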

Learning Rate

I spent countless hours doing training runs toward the end to find the sweet spot for learning rate. This depends so much on the size and diversity of your dataset, but it also played a crucial role as I scaled up the number of layers and attention heads.

Pro tip: stay patient. You will need to see a large number of epochs to properly validate your best learning rate. The interaction between learning rate and model complexity is non-obvious, and rushing this optimization cost me more GPU hours than any other single factor.

Memory Leaks

As I started to scale up my encoder, I found my RTX 4090 would run out of VRAM by epoch 8 or 10. So I put together a Docker container and took it to the cloud with RunPod using my runpod-cuda12 setup.

I trained for 16 hours, but even the 80GB H100 ran out of VRAM—a clear signal that I was doing something wrong. That's when I realized I had failed to free up memory consistently. With each epoch, the first batch was held forever, which explains why epochs 8-10 fell over during my local training runs.

After this quick change, I was able to cut my cloud spending and go back to training at home with just one more trick. I cracked open the Nx types.ex file in the deps directory and changed the default float from f32 to bf16. This allowed me to run without further compromise in terms of layers, dimensionality, or attention heads.
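
In sketch form, the memory fix comes down to releasing tensors explicitly once a batch is done instead of waiting on the garbage collector (`train_step/3` and `batches` are placeholders here). And if you'd rather not patch deps, you can usually request bf16 explicitly when you build your parameters:

```elixir
{params, opt_state} =
  Enum.reduce(batches, {params, opt_state}, fn batch, {params, opt_state} ->
    {params, opt_state, _loss} = train_step.(params, opt_state, batch)

    # Free the batch's device memory now instead of holding a reference
    # until the end of training.
    Nx.backend_deallocate(batch)

    {params, opt_state}
  end)

# bf16 halves the memory per parameter without touching deps:
{embedding, key} = Nx.Random.normal(key, 0.0, 0.02, shape: {30_522, 768}, type: :bf16)
```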

Adam

Early in my journey, switching from SGD to the Adam optimizer made the biggest single performance improvement I experienced. I was completely stuck with a validation loss hovering around 2.0, but after making this shift I was eventually able to reach 1.29. This change was so dramatic it helped encourage me to keep going when progress felt impossible. The adaptive learning rates that Adam provides were perfect for the complex parameter space I was navigating.

Shout out to Sean Moriarity for the Polaris library! It provided all the plumbing so I could make the switch with ease.
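
In rough sketch form, the plumbing looks something like this—the learning rates are illustrative, and `mlm_loss/2`, `params`, and `batch` are placeholders for your own forward pass and data:

```elixir
# Before: plain SGD
# {init_fn, update_fn} = Polaris.Optimizers.sgd(learning_rate: 1.0e-3)

# After: Adam, same {init_fn, update_fn} contract
{init_fn, update_fn} = Polaris.Optimizers.adam(learning_rate: 1.0e-4)

opt_state = init_fn.(params)

# One training step: gradients -> updates -> new params.
{loss, grads} = Nx.Defn.value_and_grad(params, fn p -> mlm_loss.(p, batch) end)
{updates, opt_state} = update_fn.(grads, opt_state, params)
params = Polaris.Updates.apply_updates(params, updates)
```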

Warmup

Without proper warmup and linear decay, I hit another significant performance wall. The original BERT team used a 10,000 step warmup, and understanding why this matters was crucial to unlocking my final performance tier.

Warmup allows your model to ease into the optimization landscape rather than taking massive, destabilizing steps from the beginning. Combined with linear decay, this created the stable training dynamics I needed for consistent convergence.
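
The schedule itself is simple arithmetic: ramp the learning rate linearly to its peak over the warmup steps, then decay it linearly to zero. The 10,000-step warmup is the figure from the paper; the peak rate and total step count below are placeholders:

```elixir
defmodule Schedule do
  # Linear warmup to peak_lr over `warmup` steps, then linear decay to zero
  # by `total` steps.
  def warmup_linear_decay(step, peak_lr \\ 1.0e-4, warmup \\ 10_000, total \\ 1_000_000) do
    cond do
      step < warmup -> peak_lr * step / warmup
      step >= total -> 0.0
      true -> peak_lr * (total - step) / (total - warmup)
    end
  end
end
```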

Dropout

I spent extensive time optimizing dropout for both my dataset and the complexity of my encoder. In my final training run I did find the sweet spot, where validation loss and training loss were almost identical as the training loss plateaued.

Finding that 0.143 dropout rate required systematic experimentation, but the satisfaction of seeing perfect convergence made every hour worthwhile. Dropout needs retuning as your architecture complexity changes or you add/subtract from your dataset—there's no universal value.
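
For reference, inverted dropout is only a few lines of Nx—the time all goes into sweeping the rate. A sketch, with my final rate as the default:

```elixir
defmodule Dropout do
  import Nx.Defn

  # Inverted dropout: zero a `rate` fraction of activations during training
  # and scale the survivors by 1 / (1 - rate). Skip this entirely at inference.
  defn dropout(x, key, opts \\ []) do
    opts = keyword!(opts, rate: 0.143)
    {u, key} = Nx.Random.uniform(key, shape: Nx.shape(x))
    keep = Nx.greater_equal(u, opts[:rate])
    {Nx.select(keep, x / (1.0 - opts[:rate]), 0.0), key}
  end
end
```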

Bigger Picture

Dataset

More than anything, I spent months building, curating, and tuning my dataset. I knew I couldn't achieve the same size and diversity the original BERT team had, but I was pleasantly surprised with what I could accomplish with just under 400,000 unique training examples.

The iterative process of dataset refinement taught me that data quality trumps quantity at smaller scales. Every hour spent cleaning and curating paid dividends in final model performance. If I had to synthesize what I would do differently, it comes down to this: look at the data more. I underestimated the time involved in data work, which has become a recurring theme throughout my machine learning journey.

Architecture

Like all other hyperparameters, tuning the number of layers and attention heads takes considerable patience. You need to complete extensive training runs with fewer layers than you need and more layers than you need to find the sweet spot for both your dataset and the linguistic complexity you're trying to capture.

I systematically tested different configurations, learning that my 11-layer, 12-head architecture hit the sweet spot between underfitting simple patterns and overfitting my biblical text domain. The 280 epochs would normally take 48-50 hours in total training time depending on my dataset, which meant each architectural experiment required significant commitment.

The Payoff

At the end of my journey, I have an encoder that produces embeddings for biblical text and works exceptionally well for my search use case. I combined this with a BM25 implementation to produce a search experience that is truly useful for spelunking around the New Testament.
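
The combination is simple at heart: embed the query, score it against the precomputed verse embeddings with cosine similarity, and blend that with the BM25 score. A sketch—the 0.5 blend weight is illustrative, not what I shipped:

```elixir
defmodule HybridSearch do
  import Nx.Defn

  # Cosine similarity between one query vector {d} and verse embeddings {n, d}.
  defn cosine(query, corpus) do
    q = query / Nx.sqrt(Nx.sum(query * query))
    c = corpus / Nx.sqrt(Nx.sum(corpus * corpus, axes: [1], keep_axes: true))
    Nx.dot(c, q)
  end

  # Blend dense and lexical signals; assumes the BM25 scores have been
  # normalized into a comparable range.
  def blend(dense_scores, bm25_scores, alpha \\ 0.5) do
    Nx.add(Nx.multiply(alpha, dense_scores), Nx.multiply(1.0 - alpha, bm25_scores))
  end
end
```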

While I owe much of my understanding to the original BERT paper and various online resources, the depth of understanding that came from implementing everything myself proved invaluable. Each mistake along the way unlocked deeper learning, and I'm grateful in hindsight.

The journey taught me that transformers aren't magic—they're sophisticated but understandable mathematical constructs. And sometimes, the best way to truly understand something is to build it yourself, one mistake at a time.

