
Writing Pallas Kernels for JAX: Stepping Outside the XLA Safety Net
When XLA's heuristics fail for custom attention mechanisms, you can't just hope for a compiler update. Here is how you write Triton-like kernels directly in Python using JAX Pallas.
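Here is a minimal sketch of what "Triton-like kernels in Python" looks like in practice. It tiles an elementwise add over a 1-D launch grid using `pl.pallas_call` and `pl.BlockSpec`. The sketch assumes a recent JAX release that ships `jax.experimental.pallas`. The names `add_kernel`, `tiled_add` and `BLOCK`, and the block size of 128, are illustrative. `interpret=True` runs the kernel on CPU so you can check it without an accelerator.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

BLOCK = 128  # illustrative tile size (assumption, not tuned for any hardware)


def add_kernel(x_ref, y_ref, o_ref):
    # Each program instance sees one BLOCK-sized tile of x, y and the output.
    # Reading ref[...] loads the tile; assigning to ref[...] stores the result.
    o_ref[...] = x_ref[...] + y_ref[...]


def tiled_add(x, y):
    n = x.shape[0]
    assert n % BLOCK == 0, "sketch assumes the length is a multiple of BLOCK"
    # index_map turns grid index i into a *block* index: tile i of the array.
    # Keyword arguments keep this working across BlockSpec signature changes.
    spec = pl.BlockSpec(block_shape=(BLOCK,), index_map=lambda i: (i,))
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(n // BLOCK,),  # one program instance per tile, like a Triton grid
        in_specs=[spec, spec],
        out_specs=spec,
        interpret=True,  # CPU interpreter for testing; drop on a real GPU/TPU
    )(x, y)


x = jnp.arange(1024, dtype=jnp.float32)
y = jnp.ones(1024, dtype=jnp.float32)
out = jax.jit(tiled_add)(x, y)
```

The kernel body only ever receives refs to tiles. The `grid` and the `BlockSpec`s decide which tile each program instance gets. That is the same split between kernel logic and launch geometry that Triton uses. A fused attention kernel would typically follow this pattern too, with 2-D blocks and a loop over key/value tiles inside the kernel.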

