Deep Tech · 3 min read

JAX XLA: Why Your GPU is Idle 40% of the Time

Recompilation is the silent killer of training throughput. If you see 'Jit' in your profiler, you are losing money. We dive into XLA internals.

The “Jit” Trap

You switched to JAX. You decorated your training step with @jax.jit. You feel like a performance wizard. Then you open your profiler, and you see it: Gaps. Massive, 500ms gaps between kernel executions where the GPU utilization drops to 0%.

Your expensive H100 cluster is sitting idle for 40% of the training run. The culprit isn’t the network. It’s Recompilation.

How XLA Actually Works (The Internals)

JAX doesn’t execute code; it traces it. When you call a jit function, JAX lowers your Python operations into an HLO (High-Level Optimizer) graph. You can actually see this graph.

Set this environment variable before running your script:

export XLA_FLAGS="--xla_dump_to=/tmp/xla_dump"

You will get a folder full of module_0001.before_optimizations.txt files. If you see hundreds of these files, you have a recompilation bug.
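If you would rather not dig through dump files, you can also inspect the lowered program directly from Python. A minimal sketch (the function f here is just a stand-in):

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) @ x.T

# Lower f for one concrete input shape and print the resulting IR text.
lowered = jax.jit(f).lower(jnp.ones((8, 8)))
print(lowered.as_text())   # the HLO/StableHLO text that XLA will optimize

Every distinct input shape produces its own lowering, which is exactly why that dump folder fills up.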

The “Python Control Flow” Trap

This is the most common mistake senior engineers make:

# ❌ BAD: Evaluated at TRACE time
@jax.jit
def train_step(x):           # pad() and model() defined elsewhere
    if x.shape[0] < 32:      # Python branch
        return pad(x)        # Triggers a NEW compilation for every shape!
    return model(x)

Because if is ordinary Python control flow, JAX evaluates it during tracing, not inside the compiled graph. When x.shape[0] changes, the trace itself changes, so JAX assumes the graph structure has changed and recompiles everything.
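You can watch this happen with jax.make_jaxpr. In the toy sketch below (the multiplications stand in for pad and model), each input shape traces to a program containing only the branch that was taken:

import jax
import jax.numpy as jnp

def train_step(x):
    if x.shape[0] < 32:    # plain Python, resolved while tracing
        return x * 0.0     # stand-in for pad(x)
    return x * 2.0         # stand-in for model(x)

print(jax.make_jaxpr(train_step)(jnp.ones((16, 8))))  # only the "small" path
print(jax.make_jaxpr(train_step)(jnp.ones((64, 8))))  # only the other path

Two shapes, two different traces, two compilations.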

The Fix: jax.lax.cond

To branch without recompiling, you must keep the logic inside the XLA graph using primitives:

# ✅ GOOD: Evaluated at RUN time
@jax.jit
def train_step(x):
    # Note: both branches must return values of the same shape and dtype.
    return jax.lax.cond(
        x.shape[0] < 32,          # Predicate
        lambda s: pad(s),         # True Branch
        lambda s: model(s),       # False Branch
        x                         # Operand
    )

Now XLA compiles one graph that contains both branches and selects between them at run time. The graph topology is static, so there is no recompilation, provided the input shapes themselves stay fixed. (Note that jax.lax.cond requires both branches to return values of the same shape and dtype.)
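As a sanity check, here is a minimal runnable sketch with toy stand-in branches that satisfy that constraint; the predicate is resolved at run time, so one compilation covers both paths:

import jax
import jax.numpy as jnp

def damp(x):                 # toy stand-in for the "small batch" path
    return x * 0.5

def full(x):                 # toy stand-in for the normal path
    return x * 2.0

@jax.jit
def step(x, use_damp):
    return jax.lax.cond(use_damp, damp, full, x)

x = jnp.ones((8, 4))
step(x, True)    # compiles once; both branches live in the same XLA graph
step(x, False)   # no recompilation, only the runtime predicate changed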

The Shape Polymorphism Killer

XLA kernels are specialized for specific input shapes. If your input batch size changes, even by a single element, the cached kernel no longer matches and XLA compiles a new one.

# The Silent Killer
@jax.jit
def train_step(batch):
    # If the last batch is smaller (the remainder),
    # JAX sees a new shape and triggers a full recompilation!
    ...

In a dataset of 10,005 examples with batch size 100, the final batch has size 5. JAX sees a new shape. It triggers a compile. If you are doing dynamic padding or variable-length sequences without careful bucketing, you might be recompiling every single step.
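You can reproduce this in about ten lines. In the sketch below the train_step body is a toy stand-in; the point is that the remainder batch shape forces a second compilation, visible as a second slow call:

import time
import jax
import jax.numpy as jnp

@jax.jit
def train_step(batch):
    return jnp.tanh(batch).sum()    # toy body, stands in for a real step

for size in (100, 100, 5):          # 5 = the remainder batch
    x = jnp.ones((size, 128))
    start = time.perf_counter()
    train_step(x).block_until_ready()
    print(size, round(time.perf_counter() - start, 4))

# The first (100, 128) call compiles; the second reuses the kernel;
# the (5, 128) remainder compiles again from scratch.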

Diagnosing the Idle Time

Don’t guess. Use jax.profiler.

jax.profiler.start_trace("/tmp/tensorboard")
# Run your step
jax.profiler.stop_trace()

Open the trace in TensorBoard. Look for “XLA Compile” or “Jit” blocks on the CPU thread. If they appear after the first few steps, you have a bug.
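For a lighter-weight check than a full trace, you can ask JAX to log every compilation it performs; any log line after warm-up means something is forcing a retrace:

import jax

# Log a message every time a jitted function is (re)compiled.
jax.config.update("jax_log_compiles", True)

# ... run your training loop and watch for repeated "Compiling ..." messages ...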

The Fix: Padding and Bucketing

  1. Pad to Fixed Shapes: Always pad your last batch up to the full batch size and use a mask to zero out the loss on the padding rows (see the sketch after this list).
  2. Bucket aggressively: If handling variable sequence lengths, bucket them into powers of 2 (128, 256, 512). Compile 3 kernels instead of 500.
  3. AOT Compilation: For production inference, use JAX's ahead-of-time (AOT) lowering, jax.jit(f).lower(sample_inputs).compile(), so the exact shapes are frozen and nothing compiles at runtime.
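Here is a minimal padding-plus-mask sketch for point 1. The helper below is hypothetical (names and shapes are illustrative) and is meant to run on the host, before the batch enters the jitted step, so every step sees the same fixed shape:

import jax.numpy as jnp

FULL_BATCH = 100

def pad_batch(batch, labels):
    # Pad a remainder batch up to FULL_BATCH and return a validity mask.
    n = batch.shape[0]
    pad = FULL_BATCH - n
    batch = jnp.pad(batch, ((0, pad), (0, 0)))
    labels = jnp.pad(labels, ((0, pad),))
    mask = (jnp.arange(FULL_BATCH) < n).astype(batch.dtype)
    return batch, labels, mask

# In the loss: multiply per-example losses by mask and divide by mask.sum(),
# so the padding rows contribute nothing.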

A GPU is a thoroughbred racehorse. Don’t make it wait for the jockey to read the map.
