JAX feels close to achieving a sort of high-level GPGPU ecosystem. It's still fledgling, but I keep finding more little libraries that build on JAX (and, because of that, compose cleanly with other JAX libraries).
The only problem is that heavy compositional usage leads to large traced programs, and therefore long XLA compile times.
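To make the compositionality point concrete, here's a minimal sketch (function names like `loss` are mine, not from any particular library): `grad`, `vmap`, and `jit` stack freely, but the whole stack is traced into a single XLA program, and that program is compiled on the first call — which is where the compile-time cost shows up.

```python
import time

import jax
import jax.numpy as jnp

# A toy loss; grad/vmap/jit are each JAX transformations, so they
# compose -- but the composed result is one big XLA program.
def loss(w, x):
    return jnp.sum(jnp.tanh(x @ w) ** 2)

grad_loss = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0)))

w = jnp.ones((4, 4))
xs = jnp.ones((8, 4))

t0 = time.perf_counter()
grad_loss(w, xs).block_until_ready()  # first call: trace + XLA compile
compile_time = time.perf_counter() - t0

t0 = time.perf_counter()
grad_loss(w, xs).block_until_ready()  # second call: cached compiled kernel
run_time = time.perf_counter() - t0

print(f"first call {compile_time:.3f}s, second call {run_time:.5f}s")
```

The gap between the first and second call is the XLA compilation the parent comment is complaining about; the more library code you compose under `jit`, the bigger that first-call cost gets.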
https://github.com/patrick-kidger/equinox?tab=readme-ov-file...
I've enjoyed using Equinox and Diffrax for ODE simulations. To my knowledge, the only peer library with similar capabilities is Julia's DifferentialEquations.jl package.
Any of them will do. I used to work with Flax; now I work mostly with Equinox. When choosing between Flax, Equinox, and Haiku, ask:
Will it try to bind me to other technologies?
Does it work out of the box on ${GPU}?
Is it well supported?
Will it continue to be supported?
Sadly, I cannot get JAX to work with the built-in GPU on my M1 MacBook Air. In theory it's supposed to work:
https://developer.apple.com/metal/jax/
But it crashes Python when I try to run a compiled function. And that's only after discovering I need a specific older version of jax-metal, because newer versions apparently no longer work with the M1 (only M2/M3) -- they don't even report the GPU as existing. And even if you do get it running, it's missing support for complex numbers.
I'm not clear whether it's Google or Apple building and maintaining support for Apple M-series chips, though.
JAX works perfectly in CPU mode on my MBA, though, so at least I can use it for development and debugging.
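If anyone else is stuck in the same place, a minimal sketch of pinning JAX to the CPU backend (so a broken Metal plugin never gets a chance to load) — the `JAX_PLATFORMS` variable just has to be set before JAX is imported:

```python
import os

# Force the CPU backend; must be set before `import jax`.
os.environ["JAX_PLATFORMS"] = "cpu"

import jax

# Confirm which devices JAX ended up with.
print(jax.devices())  # e.g. [CpuDevice(id=0)]
```

Setting `JAX_PLATFORMS=cpu` in the shell before launching Python works the same way, and is handy for flipping between backends without touching code.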