Writing the BERT Encoder with Nx

Published June 22, 2025 by Toran Billups

I typically flip between the YouVersion and Dwell Bible apps, each with its own strengths. Dwell's audio experience is beautiful and immersive—well worth the price—but its search feature leaves me stranded whenever I can't remember the exact wording of a verse. YouVersion, on the other hand, nails it almost every time, even though it's free.

One evening I was searching for the verse about God giving us a new heart. I typed "give you a new heart" into Dwell and got nothing useful. The verse I wanted—Ezekiel 36:26—didn't even appear. YouVersion found it instantly.

I assumed improving Dwell's search would be a quick fix. It wasn't. What started as a simple search tweak turned into the most technically demanding project of my career—one that eventually led me to build a BERT encoder from scratch just to understand why semantic search is so difficult.

The Vision

My goal was straightforward: build a hybrid search with a BERT-like encoder. I wanted to understand embeddings from the ground up and build a truly useful search experience for combing through the New Testament.

I would use the Masked Language Modeling pre-training task to train embeddings from my home-grown biblical text corpus. Training revolves around having the model guess masked tokens in a sequence, then updating the weights based on the cross-entropy loss of those guesses.
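
Here is a minimal sketch of that objective for a single masked position, assuming the standard BERT [MASK] token id (103) and a logits vector coming out of the encoder; the helper names are mine, not lifted from the real codebase:

# A minimal sketch of masked language modeling for one position, assuming the
# standard BERT [MASK] token id (103). `logits_at_mask` is the encoder's output
# logits over the vocabulary at the masked position.
defmodule MLMSketch do
  @mask_token_id 103

  # Replace the token at `mask_pos` with [MASK]; the model has to recover it.
  def mask_input(input_ids, mask_pos) do
    Nx.indexed_put(input_ids, Nx.tensor([[mask_pos]]), Nx.tensor([@mask_token_id]))
  end

  # Cross-entropy between the logits at the masked position and the true token:
  # a numerically stable log-softmax, then the negated log-probability of target_id.
  def mlm_loss(logits_at_mask, target_id) do
    shifted = Nx.subtract(logits_at_mask, Nx.reduce_max(logits_at_mask))
    log_probs = Nx.subtract(shifted, Nx.log(Nx.sum(Nx.exp(shifted))))
    Nx.negate(log_probs[target_id])
  end
end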

With the original BERT paper in view, I was both inspired and determined to reproduce their success on a narrow, much smaller text corpus. I used the BERT tokenizer and followed much of the original research, including 768 dimensions for the output vectors. One notable difference was my context length: training sequences of 104 tokens. This worked well for my search use case because the longest New Testament verse lands somewhere between 100 and 110 tokens, and the average is much lower.
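
For reference, the headline settings as a config map (a sketch; the key names are my own shorthand, and vocab_size assumes the standard bert-base-uncased WordPiece vocabulary):

# A rough snapshot of the headline settings (key names are my own shorthand;
# vocab_size assumes the bert-base-uncased WordPiece vocabulary).
config = %{
  hidden_size: 768,     # output vector dimensions, matching the original BERT
  max_seq_len: 104,     # context length sized to the longest New Testament verse
  num_layers: 11,       # encoder blocks, discussed later in the post
  vocab_size: 30_522    # standard BERT tokenizer vocabulary
}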

What followed was months of debugging, dead ends, and hard-won lessons. Like all of my adventures, the mistakes were the most interesting part of the story.

The First Wall: Padding Masks

The biggest trap I fell into early on was assuming the compiler would find bugs for me. I jumped straight into training runs with complex architectures, convinced that more layers meant better results. I should have started by trying to memorize 100 tokens over 10 epochs to validate simple base assumptions.

For a while I did see training loss improve, but validation loss was inconsistent. I didn't train long enough to see the plateau, which gave me a false impression that my attention implementation was working. Eventually I found that learning a single verse was possible, but when I tried to learn 10 verses—with an overfitting, repetitive dataset just to prove it could work—everything fell apart.

That's when I pulled back to inspect the math behind each component individually.

I used IO.inspect to print out the vectors for a shorter verse. Because the verse was short, the last 20 tokens were visibly padding token IDs. From there I printed each component of attention (the query, key, and value projections) and found that I hadn't excluded padding positions from the attention calculation.

The attention mechanism was literally attending to padding tokens, learning patterns from nonsense data. This single masking issue was hiding behind what appeared to be architecture problems, leading me down countless wrong paths.

The fix was an additive mask that sets padding positions to negative infinity before softmax, effectively zeroing them out:

defn self_attention(input, w_query, w_key, w_value, w_out, attention_mask) do
  # ... projection code ...

  attention_scores = Nx.dot(q, [3], [0, 1], k, [3], [0, 1])
  scaling_divisor = Nx.sqrt(head_dim)
  scaled_attention_scores = Nx.divide(attention_scores, scaling_divisor)

  # The fix: -1.0e8 makes padding positions effectively zero after softmax
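  # attention_mask is assumed to be 1 for real tokens and 0 for padding,
  # shaped so it broadcasts against the {batch, heads, seq, seq} scores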
  additive_mask = Nx.select(attention_mask, 0.0, -1.0e8)
  masked_scores = Nx.add(scaled_attention_scores, additive_mask)

  attention_weights = softmax(masked_scores)
  # ... rest of attention ...
end

Once I properly implemented padding masks, performance jumped dramatically. The learning: there are no shortcuts—you need to verify your math step-by-step on a tiny dataset before scaling up.
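
If I were starting over, the first script I'd write is something like this: overfit a single masked verse and watch the loss collapse toward zero. The helper names below are illustrative stand-ins, not my actual functions:

# A sketch of the sanity check I should have started with: overfit one masked
# verse and confirm the loss collapses. prepare_masked_example/1, init_params/0,
# and train_step/4 stand in for whatever your implementation calls them.
verse = "I will give you a new heart and put a new spirit in you."
{input_ids, target_id, mask_pos} = prepare_masked_example(verse)

Enum.reduce(1..200, init_params(), fn step, params ->
  {loss, new_params} = train_step(params, input_ids, target_id, mask_pos)

  if rem(step, 50) == 0 do
    IO.inspect(Nx.to_number(loss), label: "step #{step} loss")
  end

  new_params
end)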

The Second Wall: Gradient Chaos

For the longest time, I never computed or printed my gradients to evaluate them. I was hypnotized by my training and validation loss curves, completely ignoring the underlying mathematics.

The symptom was that validation loss didn't have a nice curve. It would improve, then jump around, never settling into the steady descent I expected. By this point I had written a lot of code, and much of the math behind it had never been properly validated.

I spent some time digging into the Polaris library so I could compute gradients, and with that additional information I was finally able to see the problem: gradient explosion.

def diagnose_learning_issues(model_params, examples, base_key) do
  preprocessed = Enum.take(examples, @batch_size)
  blocksize = get_blocksize(preprocessed)

  gradient_norms = Enum.map(preprocessed, fn {batch_id_list, tar_id, mask_pos} ->
    input_ids = pad_tokens([batch_id_list], blocksize, @pad_token_id)
    input_tensor = Nx.tensor(input_ids)
    # ... tensor setup ...

    gradient = get_gradients(model_params, input_tensor, target_tensor, mask_pos_tensor, base_key)
    total_norm = calculate_gradient_norm(gradient) |> Nx.to_number()

    Nx.backend_deallocate(gradient)
    total_norm
  end)

  %{
    mean_gradient_norm: Enum.sum(gradient_norms) / length(gradient_norms),
    gradient_norms: gradient_norms
  }
end

Before epoch 6 or 8, the gradient norm would be something like 12. But as training continued I would see 18, then 20, then 33—and it never came down. I consulted with Gemini about this pattern, and you might say I took the LLM's word as gospel: this was a bad signal indicating erratic learning.

From here, I read about gradient clipping and learned that the original BERT team clipped gradients at 1.0. The fix was straightforward with Polaris:

{init_fn, update_fn} =
  Polaris.Updates.clip_by_global_norm(max_norm: 1.0)
  |> Polaris.Updates.scale_by_adam()
  |> Polaris.Updates.add_decayed_weights(decay: 0.01)
  |> Polaris.Updates.scale_by_schedule(schedule_fn)

init_optimizer_state = init_fn.(initial_params)
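
For context, here's roughly how that pair gets used in a single training step. This is a sketch: the gradients come from Nx.Defn.grad over the loss, and the loop bookkeeping is omitted.

# A sketch of one optimizer step with the composed update_fn. The gradients come
# from Nx.Defn.grad over the loss function; variable names here are illustrative.
{updates, optimizer_state} = update_fn.(gradients, init_optimizer_state, params)
params = Polaris.Updates.apply_updates(params, updates)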

This single change transformed my erratic training into steady, predictable progress. The learning: your gradients are telling you a story. Listen to them.

The Third Wall: Memory Leaks

As I started to scale up my encoder, I found my RTX 4090 would run out of VRAM by epoch 8 or 10: CUDA out of memory. So I put together a Docker container and took it to the cloud with RunPod using my runpod-cuda12 setup.

I trained for 16 hours, but even the 80GB H100 ran out of VRAM, a clear signal that I was doing something wrong, not just hitting hardware limits.

I worked with Claude and Gemini to inspect my training loop with fresh eyes. Gemini signaled that I was likely holding VRAM between epochs. You might say I took the LLM's word as gospel again.

The problem was that I had Nx.backend_deallocate calls, but not in the innermost loop. The first batch of inputs from each epoch would just accumulate without release. With each epoch, memory grew until everything fell over around epoch 8-10 locally.
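
The shape of the fix, with illustrative names, looked roughly like this:

# A sketch of the fix (helper and field names are illustrative): deallocate each
# batch's tensors inside the innermost loop instead of once per epoch.
Enum.reduce(batches, params, fn batch, params ->
  input_tensor = Nx.tensor(batch.input_ids)
  {loss, new_params} = train_step(params, input_tensor)

  # Release device memory for this batch before the next iteration; otherwise
  # these tensors stay resident on the GPU until the epoch finishes.
  Nx.backend_deallocate(input_tensor)
  Nx.backend_deallocate(loss)

  new_params
end)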

After moving deallocation into the inner loop, I was able to cut my cloud spending and go back to training at home with just one more trick: I cracked open the Nx types.ex file in my deps directory and changed the default float from f32 to bf16. This was a direct hack because the configuration doesn't yet exist in Nx, but it allowed me to run without compromise in terms of layers, dimensionality, or attention heads.

The Architecture Shift: Pre-Layer Normalization

While training loss was improving, validation loss wasn't moving with my bigger dataset. After fixing the padding mask issue, I set my sights on this problem along with trying to dial in the number of layers and attention heads.

I originally had post-layer normalization—the pattern from the original BERT paper where you normalize after the residual connection. After some chatting with Gemini about my architecture, it recommended I try pre-layer normalization instead.

The difference is subtle but important. Post-LN (original BERT):

# Post-LN: normalize AFTER residual add
attn_out = self_attention(x, ...)
add = residual_connection(x, attn_out)
norm = layer_norm(add)

Pre-LN (what I switched to):

# Pre-LN: normalize BEFORE sublayer
norm_attn = layer_norm(x)
attn_out = self_attention(norm_attn, ...)
add = residual_connection(x, attn_out)

Why does this matter? In post-LN, gradients must flow through the layer normalization after the residual connection. As you stack layers—11 in my case—this creates a compounding effect where gradients can either explode or vanish by the time they reach the early layers.

Pre-LN creates what I think of as a "gradient highway." By normalizing before each sublayer, the residual connection becomes a clean, unobstructed path from output back to input. The gradient can flow directly through the x + sublayer_output addition without passing through normalization.

Think of it like this: in post-LN, every layer is a toll booth that modifies the gradient. In pre-LN, the main highway (residual path) is toll-free, and only the side roads (attention and FFN paths) go through normalization.

The original BERT team used post-LN with 12 layers, but they had massive compute for hyperparameter tuning and careful initialization. For someone training on a 4090 in nights and weekends, pre-LN is more forgiving—you don't need to nail the learning rate and initialization as precisely. While using pre-LN seemed to work, I'll admit I still don't fully understand all the dynamics at play here.
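
Putting the pieces together, one full pre-LN encoder block looks roughly like this. It's a sketch reusing the helper names from the snippets above; feed_forward and the params map are assumptions.

# A sketch of one full pre-LN encoder block, reusing the helper names from the
# snippets above. Note the residual path (both residual_connection calls) never
# passes through layer_norm.
def encoder_block(x, params, attention_mask) do
  # Attention sublayer: normalize, attend, add back onto the residual path.
  norm_attn = layer_norm(x)
  attn_out =
    self_attention(norm_attn, params.w_query, params.w_key, params.w_value, params.w_out, attention_mask)
  x = residual_connection(x, attn_out)

  # Feed-forward sublayer follows the same pattern, keeping the residual path clean.
  norm_ffn = layer_norm(x)
  ffn_out = feed_forward(norm_ffn, params)
  residual_connection(x, ffn_out)
end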

The Turning Point: Adam

Early in my journey, I was completely stuck with a validation loss hovering around 2.0. I worked only a few hours a week, on nights and weekends, so "weeks of frustration" translated to maybe 20-30 actual hours of debugging spread over a month.

An LLM suggested I try switching from SGD to the Adam optimizer. This was the single biggest performance improvement I experienced. After making the switch, I was eventually able to reach a validation loss of 1.29.

This change was so dramatic it helped encourage me to keep going when progress felt impossible. The adaptive learning rates that Adam provides were perfect for the complex parameter space I was navigating.

Shout out to Sean Moriarity for the Polaris library. It provided all the plumbing code so I could switch optimizers with ease.

Several smaller changes also had outsized impacts: switching from ReLU to GeLU activation, using sinusoidal positional encoding instead of learned embeddings (eliminating an entire table of learned position parameters), and implementing proper warmup with linear decay. The original BERT team used a 10,000 step warmup, and understanding why this matters was crucial: it allows the model to ease into the optimization landscape rather than taking massive, destabilizing steps from the beginning.
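
A schedule_fn along those lines can be written directly with Nx. This is a sketch: the peak learning rate and step counts are illustrative, and the value is negated on the assumption that the pipeline has no separate negation step and applies updates additively via Polaris.Updates.apply_updates.

# A sketch of a warmup-plus-linear-decay schedule_fn for scale_by_schedule.
# The peak learning rate and step counts are illustrative; the result is
# negated because the scaled updates are added directly to the parameters.
warmup_steps = 10_000
total_steps = 100_000
peak_lr = 1.0e-4

schedule_fn = fn count ->
  # Ramp linearly from 0 to peak_lr over warmup_steps...
  warmup = Nx.multiply(peak_lr, Nx.divide(count, warmup_steps))

  # ...then decay linearly back toward 0 over the remaining steps.
  decay =
    Nx.multiply(
      peak_lr,
      Nx.divide(Nx.max(Nx.subtract(total_steps, count), 0), total_steps - warmup_steps)
    )

  Nx.negate(Nx.select(Nx.less(count, warmup_steps), warmup, decay))
end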

The Data Reckoning

More than anything, I spent months building, curating, and tuning my dataset. I knew I couldn't achieve the same size and diversity the original BERT team had, but I was pleasantly surprised with what I could accomplish with just under 400,000 unique training examples.

The trouble was I made several mistakes that cost me significant time:

False diversity. I wanted a diverse set of Bible text and found myself using different translations—NLT, NIV, ESV, NET. This seemed like a good idea, but I didn't actually look at the dataset closely. While subtle differences exist between translations, they share similar tokens and themes. They aren't radically different. I later learned I should have included genuinely different text like Mere Christianity—rich theological writing that isn't the original Bible text.

Quote noise. Biblical text has a lot of quotation marks, and I found this was distracting to the masked token prediction with my simplistic encoder. The model was learning patterns around quote boundaries rather than semantic meaning.
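
The cleanup I eventually did was mundane, something along these lines (a sketch; the exact characters I normalized may have differed):

# A sketch of the quote cleanup: normalize curly quotes and strip quotation
# marks before tokenizing. The exact character set I handled may have differed.
def strip_quote_noise(text) do
  text
  |> String.replace(~r/[“”„«»]/u, "")
  |> String.replace(~r/[‘’]/u, "'")
  |> String.replace("\"", "")
end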

Lazy first pass. While I expanded to books over time, I didn't do my best work on the first pass. A lot of that text data was less than ideal until I circled back to do real cleaning—which only happened by actually reading the text in those contextualized chunks.

If I had to synthesize what I would do differently, it comes down to this: look at the data more. I underestimated the time involved in data work, which has become a recurring theme throughout my machine learning journey.

The Payoff

Here's the twist I didn't expect: BM25 was surprisingly powerful on its own.

That original search—"give you a new heart"—landed great Ezekiel results with just BM25. No embeddings required. The benchmark results were striking:

================================================================================
BENCHMARK RESULTS COMPARISON
================================================================================
Metric                    | BM25            | ILIKE
--------------------------------------------------------------------------------
NDCG@10 (relevance)       | 0.2940          | 0.1105
Precision@10              | 0.2806          | 0.1747
Recall % (found relevant) | 66.7            | 28.1
================================================================================

BM25 provided 166% better relevance ranking than simple text matching. For most of my search use case, this was enough.
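
For anyone unfamiliar, BM25 is just term-frequency scoring with document-length normalization. A compact sketch of the standard per-verse score (k1 and b below are the conventional defaults, not values I tuned):

# A compact sketch of the standard BM25 score for a query against one verse.
# k1 and b are the conventional defaults; idf is a precomputed map of term -> IDF.
def bm25_score(query_terms, verse_terms, idf, avg_len, k1 \\ 1.2, b \\ 0.75) do
  doc_len = length(verse_terms)
  freqs = Enum.frequencies(verse_terms)

  query_terms
  |> Enum.map(fn term ->
    tf = Map.get(freqs, term, 0)
    Map.get(idf, term, 0.0) * tf * (k1 + 1) / (tf + k1 * (1 - b + b * doc_len / avg_len))
  end)
  |> Enum.sum()
end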

But the encoder still adds value. For thematic queries—"salvation by grace theme" or "fruit of the spirit"—the hybrid approach with embeddings helps surface verses that don't contain the exact keywords. I built a comprehensive evaluation covering keyword retrieval, thematic clustering, and distractor tests to measure this. The encoder wasn't fully optimized, but it moved the needle on complex queries.
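
The hybrid scoring itself is simple once you have both signals. A sketch of the idea: the alpha blend is illustrative, and the BM25 score needs to be normalized into a comparable range first.

# A sketch of hybrid scoring: blend a normalized BM25 score with the cosine
# similarity between the query and verse embeddings. The alpha weight is
# illustrative, not the value I shipped.
def hybrid_score(bm25_normalized, query_embedding, verse_embedding, alpha \\ 0.7) do
  cosine =
    Nx.dot(query_embedding, verse_embedding)
    |> Nx.divide(Nx.multiply(Nx.LinAlg.norm(query_embedding), Nx.LinAlg.norm(verse_embedding)))
    |> Nx.to_number()

  alpha * bm25_normalized + (1 - alpha) * cosine
end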

I set out to build semantic search and discovered that sometimes the "simpler" approach wins for most cases. The hybrid still helps at the margins.

Build It Yourself

While I owe much of my understanding to the original BERT paper and various online resources, the deep learning that came from implementing everything myself proved invaluable. Each mistake along the way unlocked deeper understanding, and I'm grateful in hindsight.

The journey taught me that transformers aren't magic—they're sophisticated but understandable mathematical constructs. And sometimes, the best way to truly understand something is to build it yourself, one mistake at a time.

You can find the source code from my adventure on GitHub.

