This is a big update.

**Exciting new features!**

- Added `filter_vmap`.
    - This can be used to create ensembles of models.
    - (Closes #65.)
- Added `filter_pmap`.
    - (Closes #65.)
- Added pooling layers:
    - `eqx.nn.Pool`
    - `eqx.nn.AvgPool1d`
    - `eqx.nn.AvgPool2d`
    - `eqx.nn.AvgPool3d`
    - `eqx.nn.MaxPool1d`
    - `eqx.nn.MaxPool2d`
    - `eqx.nn.MaxPool3d`
    - (Closes #59.)
    - (Thanks to Benjamin-Walker for implementing this.)
- Added `tree_serialise_leaves` and `tree_deserialise_leaves`.
    - These can be used to save models to file and load them back.
    - (Closes #46.)
    - (Thanks to Jaschau for helpful discussions on this.)
- Added `tree_inference`, a convenience for toggling every inference flag throughout a model.

**Refactoring for nicer APIs**

- `filter_{jit,grad,value_and_grad}` now have an easier-to-use API for specifying how each argument should be treated.
    - Instead of specifying `(args, kwargs)` as a single PyTree, you can now specify `default`, `args` and `kwargs` separately. In particular, this avoids awkward constructions like `filter_spec=((...), {})` when there are no keyword arguments.
    - You no longer have to match up the filter specification for `args` and `kwargs` against their runtime values: both the runtime values and the filter specification are matched against the function signature. For example, `filter_jit(lambda x: x, kwargs=dict(x=True))(1)` works, even though `x` is specified by keyword when JIT-wrapping and passed positionally at call time.
    - Currying is available: both `filter_jit(fun)` and `filter_jit(default=...)(fun)` work.
    - The old API is still available for backward compatibility.
    - (Closes #48.)
- `tree_at` can now replace subtrees, and not just leaves.
    - (Closes #47.)
- `filter` and `partition` now support an `is_leaf` argument.
    - (Closes #68.)
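
The rule that both the filter specification and the runtime values are matched against the function signature can be illustrated in plain Python. This is not Equinox's implementation: `resolve_spec` and `toy_filter_jit` are hypothetical names, and the sketch only shows how both sets of arguments can be bound to parameter names with `inspect.signature`, together with the curried decorator pattern.

```python
import functools
import inspect


def resolve_spec(fun, default, args, kwargs, call_args, call_kwargs):
    """Pair each parameter of `fun` with (filter spec, runtime value)."""
    sig = inspect.signature(fun)
    # The filter spec may be given positionally or by keyword: bind it by name.
    spec = sig.bind_partial(*args, **kwargs).arguments
    # Runtime values are bound by name too, so the two need not use the
    # same positional/keyword style.
    bound = sig.bind(*call_args, **call_kwargs)
    bound.apply_defaults()
    # Parameters without an explicit spec fall back to `default`.
    return {name: (spec.get(name, default), value) for name, value in bound.arguments.items()}


def toy_filter_jit(fun=None, *, default=False, args=(), kwargs=None):
    """Hypothetical stand-in showing the curried calling convention only."""
    if fun is None:
        # Called as toy_filter_jit(default=...): return a decorator awaiting `fun`.
        return functools.partial(toy_filter_jit, default=default, args=args, kwargs=kwargs)
    kwargs = {} if kwargs is None else kwargs

    @functools.wraps(fun)
    def wrapper(*call_args, **call_kwargs):
        # A real filter_jit would use this to split arguments into traced and
        # static parts; this sketch only records it and calls `fun` directly.
        wrapper.last_spec = resolve_spec(fun, default, args, kwargs, call_args, call_kwargs)
        return fun(*call_args, **call_kwargs)

    return wrapper


# Spec given by keyword when wrapping, argument passed positionally when calling.
f = toy_filter_jit(lambda x: x, kwargs=dict(x=True))
f(1)

# Curried form: configure first, then apply to the function.
g = toy_filter_jit(default=True)(lambda x, y=5: x + y)
g(1)
```
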

**Miscellaneous**

- Calling `filter_jit(filter_grad(fun))` twice no longer triggers an unnecessary recompilation. `filter_grad(fun)` now returns a PyTree, and the second instance has the same structure and leaves as the first, so the cached compilation is reused.
    - This is actually an improvement over standard JAX! See https://github.com/google/jax/discussions/10284.

**Full Changelog**: https://github.com/patrick-kidger/equinox/compare/v0.4.0...v0.5.0