If your work is well-served by existing libraries, great! There's no need to compete against something that's already working well. But that's frequently not the case for modeling, simulation, differential equations, and SciML.
But I think a reasonably competent Python/JAX programmer can roll their own version of whatever they need relatively easily (especially if they want to use the GPU). I do miss Tullio, though.
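To give a sense of what rolling your own looks like, here is a minimal sketch of a fixed-step RK4 integrator in JAX. The names (rk4_step, odeint_rk4, decay) are hypothetical and not from any library. The sketch assumes a non-stiff problem and has none of the adaptive stepping, error control, or stiff solvers a mature package provides.

```python
import functools

import jax
import jax.numpy as jnp


def rk4_step(f, y, t, dt, params):
    """One classical fourth-order Runge-Kutta step for dy/dt = f(y, t, params)."""
    k1 = f(y, t, params)
    k2 = f(y + 0.5 * dt * k1, t + 0.5 * dt, params)
    k3 = f(y + 0.5 * dt * k2, t + 0.5 * dt, params)
    k4 = f(y + dt * k3, t + dt, params)
    return y + (dt / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)


# f and num_steps are static: f is a Python callable, num_steps fixes the loop length.
@functools.partial(jax.jit, static_argnums=(0, 4))
def odeint_rk4(f, y0, t0, t1, num_steps, params):
    """Integrate from t0 to t1 with num_steps fixed RK4 steps; returns y(t1)."""
    dt = (t1 - t0) / num_steps

    def body(y, i):
        t = t0 + i * dt
        return rk4_step(f, y, t, dt, params), None

    y_final, _ = jax.lax.scan(body, y0, jnp.arange(num_steps))
    return y_final


def decay(y, t, k):
    # Exponential decay dy/dt = -k * y; exact solution y(t) = y0 * exp(-k * t).
    return -k * y


y0 = jnp.array([1.0])
y1 = odeint_rk4(decay, y0, 0.0, 1.0, 100, 1.0)

# Because the solver is plain JAX, jax.grad differentiates straight through it.
dy1_dk = jax.grad(lambda k: odeint_rk4(decay, y0, 0.0, 1.0, 100, k)[0])(1.0)
print(y1, dy1_dk)
```

A fixed step count keeps the loop length static, so the solve compiles to a single lax.scan that jax.grad can differentiate in reverse mode. An adaptive solver would need lax.while_loop, which reverse-mode differentiation does not support directly.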
Another example: it's frustrating that Flax had to implement its own "lifted" transformations instead of just using JAX transformations, which makes it impossible to simply slot a Flax model into a JAX library that integrates ODEs. Equinox might be better on this front, but that means all the models now need to be reimplemented in Equinox. The fragmentation and churn in the Python ecosystem is outrageous; the only reason it doesn't collapse under its own weight is the sheer amount of funding and manpower ML stakeholders are able to pour into it.
Given how much the ecosystem depends on that sponsored effort, the popular frameworks will likely prioritize ML applications, and secondary use cases will be second-class citizens whenever design tradeoffs arise. E.g., framework overheads matter less when one is using large NN models than when one is using small models or other parametric approaches.
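A rough way to see the overhead argument is to time the same jitted layer at two sizes. For a tiny layer, the fixed per-call cost (Python dispatch plus launching the compiled computation) dominates; for a large one, it is amortized over real work. This is an illustrative sketch, not a rigorous benchmark: the layer function, sizes, and repeat count are arbitrary choices, and the numbers will vary by machine and backend.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def layer(w, x):
    # One dense layer with a tanh nonlinearity.
    return jnp.tanh(x @ w)


def seconds_per_call(fn, *args, repeats=20):
    """Average wall-clock time per call, blocking on each result."""
    fn(*args).block_until_ready()  # first call triggers compilation; exclude it
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / repeats


kw, kx = jax.random.split(jax.random.PRNGKey(0))
timings = {}
for n in (4, 512):
    w = jax.random.normal(kw, (n, n))
    x = jax.random.normal(kx, (n, n))
    t = seconds_per_call(layer, w, x)
    timings[n] = t
    # The matmul costs about 2 * n**3 floating-point operations (tanh ignored).
    print(f"n={n:4d}: {t * 1e6:9.1f} us/call, {2 * n**3 / t / 1e9:8.3f} GFLOP/s")
```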
Also, IIRC, it’s not terribly difficult to use Flax with Equinox. It’s just a matter of storing the weight dict and the model function in an Equinox module. Equinox’s filter_jit will correctly recognize the weights as dynamic and the Flax model as static.
You mean in terms of the ODE stuff Julia provides?
For simulations, JAX will choke on very “branchy” computations. But honestly, I’ve had very little success differentiating through those computations in the first place, and they don’t run well on the GPU. So I’m generally inclined to use wrappers around C++ (or ideally Rust) for those purposes (my use case is usually some kind of rigid-body dynamics simulation).
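A minimal sketch of the branching problem, assuming a toy bouncing-ball model (the constants and function names are illustrative, not from the comment). A Python if on a traced value fails under jax.jit, so the contact branch has to be rewritten with jax.lax.cond. When the predicate is batched, for example under jax.vmap over many balls, lax.cond is turned into a select that evaluates both branches, which is part of why branch-heavy simulations map poorly onto GPUs.

```python
import jax
import jax.numpy as jnp

GRAVITY = 9.81
RESTITUTION = 0.9
DT = 0.01


def step_with_python_if(state):
    y, v = state
    if y < 0.0:  # Python `if` on a traced value: raises an error under jit
        v = -v
    return y + DT * v, v


def step(state, _):
    """One semi-implicit Euler step of a ball bouncing on the floor y = 0."""
    y, v = state
    v = v - GRAVITY * DT
    y = y + DT * v
    # The data-dependent branch has to be expressed with lax.cond (or jnp.where).
    y, v = jax.lax.cond(
        y < 0.0,
        lambda yv: (-yv[0], -RESTITUTION * yv[1]),  # contact: reflect and damp
        lambda yv: yv,  # no contact: unchanged
        (y, v),
    )
    return (y, v), y


@jax.jit
def simulate(y0, v0):
    _, heights = jax.lax.scan(step, (y0, v0), None, length=500)
    return heights


heights = simulate(jnp.float32(1.0), jnp.float32(0.0))

try:
    jax.jit(step_with_python_if)((jnp.float32(-1.0), jnp.float32(1.0)))
    python_if_failed = False
except jax.errors.ConcretizationTypeError:
    python_if_failed = True

print(heights[:5], python_if_failed)
```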