
Writing Pallas Kernels for JAX: Stepping Outside the XLA Safety Net
When XLA's heuristics fail for custom attention mechanisms, you can't just hope for a compiler update. Here is how you write Triton-like kernels directly in Python using JAX Pallas.

When XLA's heuristics fail for custom attention mechanisms, you can't just hope for a compiler update. Here is how you write Triton-like kernels directly in Python using JAX Pallas.