
Writing Pallas Kernels for JAX: Stepping Outside the XLA Safety Net
When XLA's heuristics fail for custom attention mechanisms, you can't just hope for a compiler update. Here is how you write Triton-like kernels directly in Python using JAX Pallas.

Recompilation is the silent killer of training throughput. If JIT compilation keeps showing up in your profiler after the warm-up steps, you are losing money on time that trains nothing. We dive into XLA internals.