This is a big update. The highlight here is the new `equinox.internal` namespace, which contains a slew of advanced features.
These are only "semi-public": they are deliberately left out of the main documentation, and exist primarily for the benefit of downstream libraries like [Diffrax](http://github.com/patrick-kidger/diffrax). But you may still have fun playing with them.
Features
- `equinox.internal.`
  - Autodiff:
    - `nondifferentiable`: will raise an error at trace-time if you attempt to differentiate it.
    - `nondifferentiable_backward`: will raise an error at trace-time if you attempt to reverse-mode differentiate it.
  - Debug tooling:
    - `announce_jaxpr`: will call a custom callback whenever it is traced/transformed in a jaxpr. The default callback is `print(<transform stack>)`.
  - Runtime errors:
    - `error_if`: can raise runtime errors. (Works on CPU; doesn't work on TPU. GPU support may be flaky.)
    - `branched_error_if`: can raise one of multiple errors, depending on a traced value.
  - Floating point manipulation:
    - `nextafter`: returns the next floating point number. Unlike `jnp.nextafter`, it is differentiable.
    - `prevbefore`: returns the previous floating point number. Also differentiable.
  - MLIR sub-graphs:
    - `noinline`: marks a subcomputation to be placed in a separate computation graph, e.g. to avoid compiling the same thing multiple times when it is called repeatedly. It can also be used to iteratively recompile just part of a computation graph, if the sub- or super-graph is the only thing that changes.
  - Omega:
    - `ω`: nice syntax for tree arithmetic. For example `(x**ω + y**ω).ω == tree_map(operator.add, x, y)`. Like [tree-math](https://github.com/google/tree-math) but with nicer syntax.
  - Custom primitives:
    - `filter_primitive_{def, jvp, transpose, batching, bind}`: define rules for custom primitives that accept arbitrary PyTrees, not just JAX arrays.
    - `create_vprim`: automatically defines batching rules for higher-order primitives, according to `transform(vmap(prim)) == vmap(transform(prim))`.
  - String handling:
    - `str2jax`: turns a string into a JAX'able object.
  - Unvmap'ing:
    - `unvmap_{any, all, max}`: apply reductions whilst ignoring the batch dimension.
- New filtered transformations: `eqx.{filter_jvp,filter_custom_jvp}`
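For readers new to the ω notation, here is a minimal pure-Python sketch of how such syntax can work. It is not Equinox's implementation: `tree_map`, `_Wrapped` and `_OmegaMarker` are hypothetical stand-ins, and this version only handles dicts, lists, tuples and scalar leaves rather than arbitrary PyTrees or JAX arrays.

```python
import operator


def tree_map(f, tree, *rest):
    # Hypothetical stand-in for jax.tree_util.tree_map: recurses through
    # dicts, lists and tuples, and applies f to the matching leaves.
    if isinstance(tree, dict):
        return {k: tree_map(f, tree[k], *(r[k] for r in rest)) for k in tree}
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map(f, *children) for children in zip(tree, *rest))
    return f(tree, *rest)


class _Wrapped:
    # Holds a tree and applies arithmetic to it leaf-by-leaf.
    def __init__(self, tree):
        self.tree = tree

    @property
    def ω(self):
        # `.ω` unwraps back to a plain tree.
        return self.tree

    def _binop(self, other, op):
        if isinstance(other, _Wrapped):
            # tree (op) tree: combine matching leaves.
            return _Wrapped(tree_map(op, self.tree, other.tree))
        # tree (op) scalar: apply the scalar to every leaf.
        return _Wrapped(tree_map(lambda leaf: op(leaf, other), self.tree))

    def __add__(self, other):
        return self._binop(other, operator.add)

    def __sub__(self, other):
        return self._binop(other, operator.sub)

    def __mul__(self, other):
        return self._binop(other, operator.mul)

    def __rmul__(self, other):
        # scalar * tree
        return self._binop(other, lambda leaf, scalar: scalar * leaf)


class _OmegaMarker:
    # `tree ** ω` falls back to ω.__rpow__(tree), because dicts, tuples
    # and floats do not know how to raise themselves to this object.
    def __rpow__(self, tree):
        return _Wrapped(tree)


ω = _OmegaMarker()

x = {"a": 1.0, "b": (2.0, 3.0)}
y = {"a": 10.0, "b": (20.0, 30.0)}

print((x**ω + y**ω).ω)  # {'a': 11.0, 'b': (22.0, 33.0)}
print((2 * x**ω).ω)     # {'a': 2.0, 'b': (4.0, 6.0)}
```

The design trick is that `**` binds more tightly than `+` and `*`, so `x**ω` wraps each tree before any arithmetic happens, and `.ω` unwraps the result at the end.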
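To make "next" and "previous" floating point number concrete, here is a small stdlib illustration using `math.nextafter` (Python 3.9+) on 64-bit Python floats. The helper names `next_float` and `prev_float` are hypothetical and are not the Equinox functions; the sketch only shows the values involved and does not demonstrate differentiability, which is the part the Equinox versions add over `jnp.nextafter`.

```python
import math
import sys


def next_float(x: float) -> float:
    # Smallest float strictly greater than x.
    return math.nextafter(x, math.inf)


def prev_float(x: float) -> float:
    # Largest float strictly less than x.
    return math.nextafter(x, -math.inf)


# Just above 1.0 the gap between floats is machine epsilon (2**-52)...
print(next_float(1.0) - 1.0 == sys.float_info.epsilon)  # True
# ...but just below 1.0 the gap is half that, since floats in [0.5, 1)
# are spaced twice as densely.
print(1.0 - prev_float(1.0) == 2.0**-53)  # True
```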
Bugfixes / backward incompatibilities
- `eqx.nn.GRUCell` will now use its bias term. (Previously it was never adding this.)
- `eqx.filter_eval_shape` will no longer promote array-likes to arrays, in either its input or its output.
- `eqx.tree_equal` now treats a JAX array and a NumPy array as equal when they hold the same values.
Misc
- Improved compilation speed of `eqx.filter_vmap`.
New Contributors
* jondeaton made their first contribution in https://github.com/patrick-kidger/equinox/pull/204
* IsaacBreen made their first contribution in https://github.com/patrick-kidger/equinox/pull/215
**Full Changelog**: https://github.com/patrick-kidger/equinox/compare/v0.8.0...v0.9.0