JAX compatibility
Recent versions of JAX (0.4.28+) have made some changes to:
- Hashing of tracers;
- Tree-map'ing over Nones;
- Callbacks;
- Pretty-printing.
With this update, we should now be compatible with both old and new versions of JAX. This fixes some new crashes and some new warnings. (719, 724, 753, 758, thanks jakevdp, hawkinsp!)
Better errors
- The error messages from `eqx.error_if` are now substantially more informative: they include traceback information (including the stack) and mention the availability of the `EQX_ON_ERROR` environment variable. We also do a much better job of hiding the large, unhelpful printouts that XLA gives by default. (785, 803)
- The default value of `EQX_ON_ERROR_BREAKPOINT_FRAMES` is now `1`. (777) As a result, using `eqx.error_if` alongside `EQX_ON_ERROR=breakpoint` will now:
    - reliably open a debugger, rather than sometimes crashing at trace time due to upstream JAX bug [16732](https://github.com/google/jax/issues/16732);
    - however, by default, no longer make any additional stack frames above the debugger available (these are accessed via `u`);
    - print an informative message explaining much of the above before the debugger opens.
Bugfixes
- `eqx.filter_{jacfwd, jacrev}` now apply filtering only to their inputs, not to their outputs. Previously, filtering the outputs as well was problematic: there was no way to represent static-input-by-static-output pieces in the returned Jacobian, so they were silently dropped. (734, thanks lockwo!)
- `eqx.tree_at` can now be used to replace empty tuples. (715, 717, 722, thanks lockwo!)
- `eqx.filter_custom_jvp` no longer crashes at trace time in some scenarios in which its `**kwargs` were erroneously counted as having tangents. (https://github.com/patrick-kidger/equinox/issues/745#issuecomment-2148560546, 749)
- Fixed a trace-time crash when using a particular combination of vmap + autodiff + checkpointed while loops. This occurred when using `optimistix.BFGS` around `diffrax.diffeqsolve`. (777)
- Fixed a trace-time crash when:
    - using a checkpointed while loop...
    - ...with a body function that has a closed-over tracer...
    - ...and that closed-over tracer is differentiated...
    - ...and there are no other closed-over tracers that are differentiated...
    - ...and the dependency on that tracer is only linear.
    - (https://github.com/patrick-kidger/diffrax/pull/387#issuecomment-2132472392, 752, thanks dkweiss31!)
- Fixed a trace-time crash when composing the grad of vmap of `lineax.linear_solve`. (https://github.com/patrick-kidger/lineax/issues/101, 795, thanks rhacking!)
- `eqx.nn.RMSNorm` now uses at least 32-bit precision for numerical stability. (723, thanks AakashKumarNain!)
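The idea behind computing an RMS norm in at least 32-bit precision can be illustrated with a plain-JAX sketch. This is not Equinox's source: `rms_norm` is a hypothetical helper that omits the learned weight and bias and only shows the upcast, compute, downcast pattern. With `float16` inputs of magnitude 300, squaring in `float16` can overflow (300² = 90000 exceeds the `float16` maximum of 65504), whereas the upcast version normalises cleanly.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, eps=1e-5):
    # Hypothetical helper, not Equinox's implementation (no weight/bias).
    # Promote to at least float32: float16/bfloat16 -> float32, float64 stays float64.
    compute_dtype = jnp.promote_types(x.dtype, jnp.float32)
    xc = x.astype(compute_dtype)
    # Root-mean-square over the last axis, computed in the wider dtype.
    inv_rms = jax.lax.rsqrt(jnp.mean(xc**2, axis=-1, keepdims=True) + eps)
    # Cast back so callers still get their original dtype out.
    return (xc * inv_rms).astype(x.dtype)


x = jnp.full((4,), 300.0, dtype=jnp.float16)
y = rms_norm(x)
print(y, y.dtype)  # each entry is approximately 1.0, dtype float16
```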
New features
- `eqx.nn.{Linear,Conv,GRUCell,LSTMCell}` now support complex dtypes. (765, thanks ChenAo-Phys!)
- Added `eqx.nn.RotaryEmbedding(..., theta=...)`. (735, thanks Artur-Galstyan!)
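For intuition about what `theta` controls: in standard rotary positional embeddings, channel pair `i` of a `dim`-wide embedding is rotated by the angle `position * theta ** (-2i / dim)`, so `theta` sets the base of the frequency spectrum (commonly 10,000). Below is a plain-JAX sketch of that standard formulation in the "rotate half" layout. `rope` is a hypothetical helper, not Equinox's implementation, and it assumes `x` has shape `(seq_len, dim)` with even `dim`.

```python
import jax.numpy as jnp


def rope(x, theta=10_000.0):
    # Hypothetical helper: standard RoPE, "rotate half" layout.
    # x has shape (seq_len, dim) with dim even.
    seq_len, dim = x.shape
    # One frequency per channel pair: theta ** (-2i / dim) for i = 0, ..., dim/2 - 1.
    freqs = 1.0 / theta ** (jnp.arange(0, dim, 2) / dim)
    # Rotation angle for every (position, pair): shape (seq_len, dim // 2).
    angles = jnp.arange(seq_len)[:, None] * freqs[None, :]
    cos = jnp.concatenate([jnp.cos(angles), jnp.cos(angles)], axis=-1)
    sin = jnp.concatenate([jnp.sin(angles), jnp.sin(angles)], axis=-1)
    # Pair channel j with channel j + dim // 2 and rotate each pair by its angle.
    x1, x2 = x[:, : dim // 2], x[:, dim // 2 :]
    rotated_half = jnp.concatenate([-x2, x1], axis=-1)
    return x * cos + rotated_half * sin


x = jnp.arange(32.0).reshape(4, 8)
out = rope(x, theta=500_000.0)
# Rotations preserve each position's norm, and position 0 is left unchanged.
print(jnp.linalg.norm(out, axis=-1), jnp.linalg.norm(x, axis=-1))
```

A larger `theta` makes the high-index channel pairs rotate more slowly with position.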
Other changes
- Several doc fixes. (708, 731, 733, 747, 750, 757 + several other PRs, thanks Artur-Galstyan, matteoguarrera, lockwo, nasyxx!)
- Several internal test fixes, as downstream libraries have changed slightly. (740, 742 + several other PRs, big thanks to GaetanLepage for reporting many of these!)
- There is now a Mistral 7B implementation using JAX+Equinox available over in [AakashKumarNain/mistral_jax](https://github.com/AakashKumarNain/mistral_jax)!
New Contributors
* nasyxx made their first contribution in https://github.com/patrick-kidger/equinox/pull/708
* jakevdp made their first contribution in https://github.com/patrick-kidger/equinox/pull/724
* matteoguarrera made their first contribution in https://github.com/patrick-kidger/equinox/pull/739
**Full Changelog**: https://github.com/patrick-kidger/equinox/compare/v0.11.4...v0.11.5