Deep Tech · 9 min read
Writing Pallas Kernels for JAX: Stepping Outside the XLA Safety Net
When XLA's heuristics fail for custom attention mechanisms, you can't just hope for a compiler update. Here is how you write Triton-like kernels directly in Python using JAX Pallas.

The promise of JAX is beautiful. You write clean, functional NumPy code in Python. You wrap it in a @jax.jit decorator. You press enter, and the XLA compiler magically transforms your high-level mathematical expressions into aggressively fused, heavily optimized machine code that runs at the speed of light on Google Cloud TPUs or NVIDIA GPUs.
For 95 percent of deep learning workloads, this illusion holds up perfectly. XLA is a marvel of software engineering. It analyzes your computation graph, figures out that you are multiplying a matrix, adding a bias, and applying a ReLU activation, and it fuses all of those operations into a single, highly efficient kernel loop. It prevents the GPU from wasting precious cycles writing intermediate results back to main memory.
But then you hit the 5 percent.
You decide to implement a novel sparse attention mechanism you read about in a paper published at 3 AM on arXiv. You write the masking logic in standard jax.numpy. You jit the function. You deploy it to your GKE cluster of highly expensive accelerators. And then you open the Vertex AI TensorBoard profiler and you see what’s going on.
The throughput has cratered. The memory utilization has spiked to 99 percent. XLA did not understand the intricate block-sparsity of your custom attention map. Instead of fusing the operations gracefully, the compiler panicked. It fell back to materializing a massive, dense matrix filled mostly with zeros, saturating the HBM (High Bandwidth Memory) bus and turning your multi-million dollar supercomputer into a very expensive space heater.
When XLA’s heuristics fail, you cannot just submit a GitHub issue and hope Google engineers patch the compiler next month. You have a model to ship now. You have to step outside the safety net. You have to write the kernel yourself.
The Hardware Reality - Compute vs. Memory
Before we write custom kernels, we have to look at the hardware and understand how it actually works.
Whether you are using an NVIDIA GPU or a Google Cloud TPU, the fundamental architectural constraint is the same: the processors doing the math are starving.
The ALUs (Arithmetic Logic Units) that perform matrix multiplications are incredibly fast. The main memory where your large tensors live (the HBM) is relatively slow. The connection between them is the bottleneck. If your kernel tries to compute a value, but it has to wait 200 clock cycles to fetch a tensor slice from HBM, the ALUs sit idle.
To solve this, hardware engineers created a tiny, ultra-fast layer of memory physically located right next to the ALUs. On NVIDIA GPUs, this is called Shared Memory. On Google TPUs, it is called VMEM (Vector Memory). This fast memory can only hold a few megabytes of data.
Writing a high-performance custom kernel is universally an exercise in logistics, not mathematics. You are managing a shipping port. You have to carefully load small, precise blocks of data from the massive container ship (HBM) onto the docks (SRAM). You do the math on the docks. Then you ship the finished results back to the container ship.
If you do not explicitly manage this block-transfer process, you encounter the exact HBM bandwidth starvation that XLA’s failed heuristics caused.
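A quick back-of-the-envelope roofline calculation makes the starvation concrete. The peak numbers below are illustrative, roughly in the ballpark of an A100-class accelerator; the exact figures vary by chip.

```python
# Back-of-envelope roofline check: is an op compute-bound or memory-bound?
# Illustrative accelerator numbers (roughly A100-class; check your own chip).
PEAK_FLOPS = 312e12      # ~312 TFLOP/s of bf16 matmul throughput
HBM_BANDWIDTH = 2.0e12   # ~2 TB/s of HBM bandwidth
RIDGE = PEAK_FLOPS / HBM_BANDWIDTH  # FLOPs needed per byte to keep ALUs busy

def arithmetic_intensity_matmul(m, n, k, bytes_per_elem=2):
    """FLOPs per byte moved for an (m, k) @ (k, n) matmul in bf16."""
    flops = 2 * m * n * k
    bytes_moved = bytes_per_elem * (m * k + k * n + m * n)
    return flops / bytes_moved

# A big square matmul easily clears the ridge point (~156 FLOPs/byte here)...
print(arithmetic_intensity_matmul(4096, 4096, 4096))  # ~1365 FLOPs/byte
# ...while an elementwise op does ~1 FLOP per 2 bytes moved: hopelessly
# memory-bound, which is why fusion (or a custom kernel) matters so much.
```

If an operation's arithmetic intensity falls below the ridge point, no amount of ALU horsepower helps; the kernel is waiting on HBM.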
Enter Pallas: Triton for the JAX Ecosystem
Historically, writing these block-level kernels meant writing raw CUDA C++ for GPUs or dropping down to low-level compiler dialects for TPUs. It was brutal, error-prone work that required weeks of debugging segmentation faults.
OpenAI’s Triton changed the game for the PyTorch ecosystem by allowing engineers to write block-level GPU kernels in Python. Pallas is JAX’s answer, and it is built natively into the JAX library (jax.experimental.pallas).
Pallas allows you to write grid-based, block-level kernels in Python that compile down to PTX for NVIDIA GPUs or LLO/Mosaic for Google Cloud TPUs. It forces you to think explicitly about loading, computing, and storing blocks of memory.
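Before diving into attention, it helps to see the minimal shape of a Pallas kernel. The sketch below is the canonical vector-add example, run in interpret mode so it executes on CPU without an accelerator:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # Refs point at blocks staged in fast memory; [...] reads the whole block.
    o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # run on CPU so this sketch is portable
    )(x, y)

print(add(jnp.arange(8.0), jnp.ones(8)))  # [1. 2. 3. 4. 5. 6. 7. 8.]
```

With no grid or BlockSpecs, Pallas treats the entire array as a single block; the interesting machinery starts when the tensors no longer fit.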
Let’s look at what happens when we implement a simplified, custom Block-Sparse Attention kernel.
Designing the Block Specification
The first conceptual leap in Pallas is that your kernel function does not operate on the entire tensor. It operates on exactly one small block.
Before we write the logic, we have to define how the massive HBM tensors are going to be diced up and fed to our kernel instances. We do this using BlockSpec.
Imagine our Attention Query tensor Q has the shape (batch, seq_len, head_dim). Our hardware SRAM is too tiny to hold the entire sequence. We decide to process it in chunks of 128 tokens at a time.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
# We want to process Query blocks of size (batch_size, 128, head_dim).
# BlockSpec maps each point on the kernel grid to a slice of the tensor.
q_block_spec = pl.BlockSpec(
    block_shape=(BATCH_SIZE, BLOCK_Q, HEAD_DIM),
    # This lambda maps the grid indices (i, j) to the block indices for Q
    index_map=lambda i, j: (0, i, 0),
)
The index_map is the tricky part. Pallas will launch a 2D grid of kernel instances (conceptually iterating over i and j). The index_map tells Pallas: “When executing the kernel at grid coordinate (i, j), fetch the chunk of Q starting at batch 0, sequence index i * BLOCK_Q, and head dimension 0.”
We define similar BlockSpecs for our Keys (K), Values (V), and our sparse Mask matrix.
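For concreteness, here is one plausible set of index maps for those operands (the layout is an assumption, including the choice of a 2D, block-level mask tensor); each would be wrapped in its own pl.BlockSpec exactly like the Q spec:

```python
# Hypothetical index maps for the remaining operands (layout is an assumption).
# Q advances with grid axis i, K and V advance with grid axis j, and the
# block-level sparsity mask needs both: entry (i, j) gates Q block i vs K block j.

def q_index_map(i, j):
    return (0, i, 0)   # batch block 0, Q block i, full head_dim

def kv_index_map(i, j):
    return (0, j, 0)   # batch block 0, K/V block j, full head_dim

def mask_index_map(i, j):
    return (i, j)      # 2D grid of (block_q, block_k) mask tiles

print(kv_index_map(2, 7))  # (0, 7, 0): grid cell (2, 7) reads the 8th K block
```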
Writing the Kernel Logic: Load, Compute, Store
Now we write the actual Python function that executes on the accelerator core. This function runs across many kernel instances, potentially in parallel. If you write bad logic here, you don’t get a Python stack trace. You get a silent failure or a catastrophic NaN explosion.
Remember the shipping dock analogy. Pallas functions do not accept JAX arrays as inputs. They accept pallas Ref objects. A Ref is essentially a typed pointer to a block of memory that Pallas has staged for you according to your BlockSpec.
You cannot do math on a Ref directly. You read from it (with ref[...] or pl.load) to get a concrete array in SRAM, do your math using standard jax.numpy operations, and then write the result back (with ref[...] = value or pl.store).
def sparse_attention_kernel(
    q_ref, k_ref, v_ref, mask_ref, o_ref,  # Refs into the staged blocks
    *, block_q, block_k, head_dim
):
    # 0. INIT: output buffers are not zero-initialized, so zero the output
    # block the first time we visit it (grid axis j == 0) before accumulating.
    @pl.when(pl.program_id(1) == 0)
    def _():
        o_ref[...] = jnp.zeros_like(o_ref)

    # 1. LOAD: reading a Ref pulls the block from slow HBM into fast SRAM
    q_block = q_ref[...]
    k_block = k_ref[...]
    mask_block = mask_ref[...]

    # 2. COMPUTE: now we are operating entirely in SRAM.
    # Calculate Q @ K^T (scaled dot product)
    scale = 1.0 / jnp.sqrt(head_dim)
    scores = jnp.einsum('bmd,bnd->bmn', q_block, k_block) * scale

    # The sparsification step that XLA failed to fuse earlier:
    # where the mask is 0, push the score to negative infinity before softmax.
    scores = jnp.where(mask_block, scores, -jnp.inf)

    # NOTE: a per-block softmax accumulated across K blocks is a simplification;
    # a full implementation needs an online softmax, as in FlashAttention.
    attention_weights = jax.nn.softmax(scores, axis=-1)

    # Calculate Attention @ V
    v_block = v_ref[...]
    output_block = jnp.einsum('bmn,bnd->bmd', attention_weights, v_block)

    # 3. STORE: accumulate into the output Ref. Notice we don't return
    # anything. We mutate state. (On TPU the grid executes sequentially, so
    # this read-modify-write is safe; on GPU it would race across instances.)
    o_ref[...] += output_block
Look closely at that code block. This is not declarative functional programming anymore. We are mutating state through the output Ref. We are orchestrating memory movement. We have stepped out of the mathematical abstraction and into the physical reality of transistors and electrical pathways.
The jnp.where masking logic here runs instantaneously because we only paid the HBM memory bandwidth cost to load the mask_block once per chunk. We did not materialize a 32,000 x 32,000 matrix of zeros in HBM. We only materialized the small 128x128 block currently residing in our fast SRAM.
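To put numbers on that claim (assuming bf16, two bytes per element, and the sequence length from the text):

```python
# Cost of materializing the dense mask/scores in HBM vs. one tile in SRAM.
SEQ_LEN, BLOCK = 32_000, 128
BYTES = 2  # bf16

dense_bytes = SEQ_LEN * SEQ_LEN * BYTES  # full seq_len x seq_len materialized
block_bytes = BLOCK * BLOCK * BYTES      # one tile resident in SRAM at a time

print(dense_bytes / 1e9)   # ~2.0 GB round-tripped through the HBM bus
print(block_bytes / 1024)  # 32 KB, comfortably inside a few-MB SRAM budget
```

A roughly 2 GB round trip versus a 32 KB working set is the entire difference between a saturated HBM bus and busy ALUs.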
Deploying via JAX Pallas Call
To use this kernel in our standard JAX training loop running on GKE, we invoke pl.pallas_call. This takes our JAX arrays, dices them up according to our BlockSpec definitions, and schedules the grid of kernels across the TPU circuitry.
import functools

# The grid defines how many instances of our kernel we launch:
# one instance for every (Q block, K block) pair.
grid = (SEQ_LEN // BLOCK_Q, SEQ_LEN // BLOCK_K)

# We wrap the kernel in a pallas_call to bridge standard JAX with our custom
# code. The static block sizes are keyword-only arguments of the kernel
# function, so we bind them with functools.partial.
custom_sparse_attn_fn = pl.pallas_call(
    functools.partial(
        sparse_attention_kernel,
        block_q=BLOCK_Q,
        block_k=BLOCK_K,
        head_dim=HEAD_DIM,
    ),
    grid=grid,
    in_specs=[q_block_spec, k_block_spec, v_block_spec, mask_block_spec],
    out_specs=o_block_spec,
    out_shape=jax.ShapeDtypeStruct((BATCH_SIZE, SEQ_LEN, HEAD_DIM), jnp.bfloat16),
)

# You can now use this inside your standard @jax.jit training step
o = custom_sparse_attn_fn(q, k, v, mask)
You can inject this pallas_call directly into the middle of your standard ResNet or Transformer architecture. When XLA encounters it during compilation, it stops trying to optimize the inside of the function. It trusts that you know what you are doing. It compiles the surrounding jax.numpy operations as usual, and when it reaches the Pallas node, it simply drops in the highly optimized PTX or LLO binary block you just defined.
The Empathy of the Void
Writing Pallas kernels is an incredibly powerful skill to have in your infrastructure utility belt. When you are deploying custom models and you notice your TPU instances are severely under-utilized during a specific forward pass, Pallas is your escape hatch.
But it comes with a profound trade-off.
When you use the high-level XLA compiler, it acts as your safety net. It ensures mathematical correctness, tracks gradient flow perfectly for backpropagation, and handles the nasty architectural nuances of memory banks and network topologies.
When you drop down to Pallas, you are working with the metal. If you define a BlockSpec incorrectly, your kernel will fetch the wrong memory address. You won’t get a helpful error message explaining that your indices are out of bounds. You will get a silent data corruption that causes your model to converge to a garbage result three hours later, or you will hard-crash the TPU firmware and bring down the Kubernetes pod.
You have to write your own gradient definitions using jax.custom_vjp. You have to exhaustively write unit tests against a naive standard jax.numpy implementation to ensure your block-wise math results in the exact same floating-point outcomes.
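A minimal sketch of such a reference harness, written in plain NumPy so it runs anywhere; the call to custom_sparse_attn_fn (the pallas_call wrapper from earlier) is left commented, since it depends on your kernel and hardware:

```python
import numpy as np

def reference_attention(q, k, v, mask):
    """Naive dense masked attention: the ground truth the kernel must match."""
    scores = np.einsum('bmd,bnd->bmn', q, k) / np.sqrt(q.shape[-1])
    scores = np.where(mask, scores, -np.inf)
    # Numerically stable softmax over the key axis
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return np.einsum('bmn,bnd->bmd', weights, v)

# Tiny random inputs with a trivial all-True mask, just to show the harness.
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((1, 8, 4)) for _ in range(3))
mask = np.ones((1, 8, 8), dtype=bool)
expected = reference_attention(q, k, v, mask)

# pallas_out = custom_sparse_attn_fn(q, k, v, mask)
# assert np.allclose(pallas_out, expected, atol=1e-2)  # loose bf16 tolerance
```

Run it across every block size, sequence length, and mask pattern you intend to ship; block-boundary bugs hide exactly where the happy-path test never looks.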
You are stepping into the void. And you should only do it when profiling data proves that you absolutely must.
But when you finally trace that 5ms latency spike to a poorly fused masking operation, write a clean 50-line Pallas kernel, deploy it to your cluster, and watch the idle time on your accelerators drop to zero while throughput triples - there is no better feeling in modern software engineering.