Inox

Latest version: v0.6.3

0.10.2

This release has lots of examples and bugfixes from several new contributors!

**Features**

- `eqx.nn.{Linear, MLP}` now support the string `"scalar"` for their input and output sizes, to produce an array of shape `()` rather than an array of shape `(1,)`.
- Added `equinox.internal.scan` for a checkpointed scan implementation. (It'd be interesting to see this used for an optimally-checkpointed [scan-over-layers](https://docs.kidger.site/equinox/tricks/#improve-compilation-speed-with-scan-over-layers) in an LLM?)

**Documentation**

- Much nicer examples! Big thanks to:
  - Artur-Galstyan for contributing a CNN-on-MNIST example;
  - j5b for contributing a BERT example;
  - Benjamin-Walker for contributing a U-Net example.

**Bugfixes**

- `eqx.filter_closure_convert` and `eqx.internal.while_loop` now work with tree-math.
- Improved numerical stability of `MultiheadAttention`, and fixed it producing NaNs in the fully-masked case. (Thanks j5b!)
- Fixed (the deprecated, but still) `deterministic=True` being ignored in `MultiheadAttention`. (Thanks mk-0!)
- `__new__` can now be overridden in subclasses of `eqx.Module`. (Thanks ASEM000!)

**Misc**

- Now using ruff and pyright. (No longer using flake8 or isort.)
- Modules are now private-by-default, e.g. `equinox._jit` instead of `equinox.jit`. If you're broken by this change then you should make sure to import from the public interface: e.g. `equinox.filter_jit` instead of `equinox._jit.filter_jit`.
- `equinox.internal.while_loop(..., kind="checkpointed")` now supports readable buffers.
- `eqx.filter_vmap` now supports all-`None`s in `in_axes`. (Thanks RaderJason!)

**Full Changelog**: https://github.com/patrick-kidger/equinox/compare/v0.10.1...v0.10.2

0.10.1

The usual post-release hotfix.

See the [v0.10.0 release notes](https://github.com/patrick-kidger/equinox/releases/tag/v0.10.0) for the interesting recent changes.

**Changes in this release**

- Fixed a number of English typos in strings and error messages.
- Fixed a couple of type annotations. (Thanks dhirschfeld in #262.)
- Removed spurious use of `typing_extensions`.
- Fixed `eqx.filter_{vmap, pmap}(in_axes=dict(...), ...)` crashing when used alongside default arguments.

**Full Changelog**: https://github.com/patrick-kidger/equinox/compare/v0.10.0...v0.10.1

0.10.0

Highlights

1. A dramatically simplified API for `equinox.{filter_jit, filter_grad, filter_value_and_grad, filter_vmap, filter_pmap}`. This is a backward-incompatible change.

2. `equinox.internal.while_loop`, which is a reverse-mode autodifferentiable while loop, using recursive checkpointing.

Full change list

New features

Some relatively minor new features are available in this release.

- Added support for donating buffers when using `eqx.{filter_jit, filter_pmap}`. (Thanks uuirs in #235!)
- Added `eqx.nn.PReLU`. (Thanks enver1323 in #249!)
- Added `eqx.tree_pprint`.
- Added `eqx.module_update_wrapper`.
- `eqx.filter_custom_jvp` now supports keyword arguments (which are always treated as nondifferentiable).

New `internal` features

Introducing a slew of new features for the advanced JAX user.

These are all available in the `equinox.internal` namespace. Note that these come without stability guarantees, as they often depend on functionality that JAX doesn't make fully public.

- `eqxi.abstractattribute`, for marking abstract instance attributes of abstract Equinox modules.
- `eqxi.tree_pp`, for producing a pretty-print doc of an object. (This is what is then formatted to a particular width in e.g. `eqx.tree_pformat`.) In addition, classes can now customise their pretty-printing behaviour when used with `eqx.{tree_pp, tree_pformat, tree_pprint}`, by defining a `__tree_pp__` method.
- `eqxi.if_mapped`, as an alternative to the usual `eqx.if_array` passed to `eqx.{filter_vmap, filter_pmap}(out_axes=...)`.
- `eqxi.{finalise_jaxpr, finalise_fn}` for tracing through custom primitives' `impl` rules (so that the custom primitive no longer appears in the jaxpr). This is useful for replacing such custom primitives before handing a jaxpr off to some other IR, e.g. via `jax2tf`.
- `eqxi.{nonbatchable, nondifferentiable, nondifferentiable_backward, nontraceable}` for asserting that an operation is never batched, differentiated, or subject to any transform at all.
- `eqxi.to_onnx` for exporting to ONNX.
- `eqxi.while_loop` for reverse-mode autodifferentiable while loops; in particular making use of recursive checkpointing. (A la treeverse.)

Backward-incompatible changes

- The API for `equinox.{filter_jit, filter_grad, filter_value_and_grad, filter_vmap, filter_pmap}` has been dramatically simplified. If you were using the extra arguments to these functions (i.e. not just calling `eqx.filter_jit` etc. directly) then this is a backward-incompatible change; see the discussion below for more details.
- Removed `equinox.nn.{AvgPool1D, AvgPool2D, AvgPool3D, MaxPool1D, MaxPool2D, MaxPool3D}`. Use `AvgPool1d` etc. (lower-case "d") instead. (These were backward-compatibility stubs that have now been removed.)
- Removed `equinox.Module.{tree_flatten, tree_unflatten}`. These were never technically public API; use `jax.tree_util.{tree_flatten, tree_unflatten}` instead.
- `equinox.filter_closure_convert` now asserts that you call it with arguments compatible with those it was closure-converted with.
- Dropped support for Python 3.7.

Other

- The Python overhead when crossing a `filter_jit` or `filter_pmap` boundary should now be much reduced.
- `eqx.tree_inference` now runs faster. (Thanks uuirs in #233!)
- Lots of documentation improvements; in particular a new "Tricks" section for some advanced notes. (Thanks carlosgmartin in #239!)

Filtered transformation API changes (AKA: "my code isn't working any more?")

These APIs have been simplified and made much easier to understand. No functionality has been lost, things might just need tweaking.

`filter_jit`

This previously took `default`, `args`, `kwargs`, `out`, `fn` arguments, for controlling what should be traced and what should be held static.

In practice all JAX arrays and NumPy arrays always had to be traced, and everything that wasn't a JAXable type (JAX array, NumPy array, `bool`, `int`, `float`, `complex`) had to be held static. So these arguments just weren't that useful: pretty much the only thing you could do with them was to specify that you'd like to trace a `bool`/`int`/`float`/`complex`.

This minor use-case wasn't worth complicating such an important API for, which is why these arguments have been removed.

If after this change you still want to trace with respect to `bool`/`int`/`float`/`complex`, then do so simply by wrapping them into JAX arrays or NumPy arrays first: `np.asarray(x)`.

`filter_grad` and `filter_value_and_grad`

These previously took an `arg` argument, for controlling what parts of the first argument should be differentiated.

This was occasionally useful (e.g. when freezing parts of a layer), but in practice it still wasn't used that often. As such, this argument has been removed for the sake of simplicity.

If after this change you want to replicate the previous behaviour, then it is simple to do so using `partition` and `combine`:
```python
# Before
@eqx.filter_grad(arg=foo)
def loss(first_arg, ...):
    ...

loss(bar, ...)

# After
@eqx.filter_grad
def loss(diff_first_arg, static_first_arg, ...):
    first_arg = eqx.combine(diff_first_arg, static_first_arg)
    ...

diff_bar, static_bar = eqx.partition(bar, foo)
loss(diff_bar, static_bar, ...)
```

See also the updated [frozen layer](https://docs.kidger.site/equinox/examples/frozen_layer/) example for a demonstration.

`filter_vmap`

This previously took `default`, `args`, `kwargs`, `out`, `fn` arguments, for controlling what axes should be vectorised over.

In practice this API was just a bit more complicated than it really needed to be. The only useful feature relative to `jax.vmap` was `kwargs`, for easily specifying just a few named arguments that should behave differently.

The new API instead accepts `in_axes` and `out_axes` arguments, just like `jax.vmap`. To replace `kwargs`, one extra feature is supported: `in_axes` may be a dictionary of named arguments, e.g.
```python
@eqx.filter_vmap(in_axes=dict(bar=None))
def fn(foo, bar):
    ...
```

All arguments not named in the `in_axes` dictionary get the default value `eqx.if_array(0)`, which maps each leaf `x` to `0 if eqx.is_array(x) else None`.

On which note, a new `eqx.if_array(i)` now exists, to make it easier to specify values for `in_axes` and `out_axes`.

If you were using the old `fn` argument, then this can be replicated by instead decorating a function that accepts the callable:
```python
# Before
eqx.filter_vmap(foo, fn=bar)(x, y)

# After
@eqx.filter_vmap(in_axes=dict(fn=bar))
def accepts_foo(fn, x, y):
    return fn(x, y)

accepts_foo(foo, x, y)
```

`filter_pmap`

This previously took `default`, `args`, `kwargs`, `out`, `fn` arguments, for controlling what axes should be parallelised over, and which arguments should be traced vs static.

This was a fiendishly complicated API merging together both the `filter_jit` and `filter_vmap` APIs.

The JIT part of it is now handled automatically, as with `filter_jit`: all arrays are traced, everything else is static.

The vmap part of it is now handled in the same way as `filter_vmap`, using `in_axes` and `out_axes` arguments.

New Contributors
* carlosgmartin made their first contribution in https://github.com/patrick-kidger/equinox/pull/239
* enver1323 made their first contribution in https://github.com/patrick-kidger/equinox/pull/249

**Full Changelog**: https://github.com/patrick-kidger/equinox/compare/v0.9.2...v0.10.0

0.9.2

Autogenerated release notes as follows:

What's Changed
* Minor doc fixes by patrick-kidger in https://github.com/patrick-kidger/equinox/pull/228
* Allow passing file-like objects to eqx.serialise/deserialise by jatentaki in https://github.com/patrick-kidger/equinox/pull/229
* Fixed broken `filter_closure_convert` (and new JAX breaking Equinox's experimental stateful operations) by patrick-kidger in https://github.com/patrick-kidger/equinox/pull/232


**Full Changelog**: https://github.com/patrick-kidger/equinox/compare/v0.9.1...v0.9.2

0.9.1

New features

These are all pretty self-explanatory!
- `equinox.filter_make_jaxpr`
- `equinox.filter_vjp`
- `equinox.filter_closure_convert`
- `equinox.filter_pure_callback`

Also:
- `equinox.internal.debug_backward_nan(x)` will print out the primal and cotangent for `x`, and if the cotangent has a NaN then the computation is halted.

Bugfixes

- `equinox.{is_array, is_array_like, is_inexact_array, is_inexact_array_like}` all now recognise NumPy scalars as being array types.
- `equinox.internal.{error_if, branched_error_if}` are now compatible with `jax.ensure_compile_time_eval`.
- `equinox.internal.noinline` will no longer throw an assertion error during tracing under certain edge-case conditions. (In particular, when used inside one branch of a `vmap`'d `lax.cond` with a batched predicate.)
- `equinox.tree_pformat` now prints out `jax.tree_util.Partial`s, and dataclass types (not instances) correctly.

Tweaks:

- `equinox.internal.noinline` is now compatible with `jax.jit`, i.e. a `noinline`-wrapped function can be passed across a jit API boundary. (Previously `equinox.filter_jit` was required.)
- `equinox.internal.announce_jaxpr` has been renamed to `equinox.internal.announce_transform`.
- `equinox.internal.{nondifferentiable, nondifferentiable_backward}` now take a `msg` argument for overriding their error messages.

**Full Changelog**: https://github.com/patrick-kidger/equinox/compare/v0.9.0...v0.9.1

0.9.0

This is a big update. The highlight here is the new `equinox.internal` namespace, which contains a slew of advanced features.

These are only "semi-public": they are deliberately not in the main documentation, and exist primarily for the benefit of downstream libraries like [Diffrax](http://github.com/patrick-kidger/diffrax). But you may still have fun playing with them.

Features
- `equinox.internal`:
  - Autodiff:
    - `nondifferentiable`: will raise an error at trace time if you attempt to differentiate it.
    - `nondifferentiable_backward`: will raise an error at trace time if you attempt to reverse-mode differentiate it.
  - Debug tooling:
    - `announce_jaxpr`: will call a custom callback whenever it is traced/transformed in a jaxpr. `print(<transform stack>)` is the default callback.
  - Runtime errors:
    - `error_if`: can raise runtime errors. (Works on CPU; doesn't work on TPU. GPU support may be flaky.)
    - `branched_error_if`: can raise one of multiple errors, depending on a traced value.
  - Floating point manipulation:
    - `nextafter`: returns the next floating point number. Unlike `jnp.nextafter`, it is differentiable.
    - `prevbefore`: returns the previous floating point number. Also differentiable.
  - MLIR sub-graphs:
    - `noinline`: marks that a subcomputation should be placed in a separate computation graph, e.g. to avoid compiling the same thing multiple times if it is called repeatedly. Can also be used to iteratively recompile just part of a computation graph, if that sub/super-graph is the only thing that changes.
  - Omega:
    - `ω`: nice syntax for tree arithmetic. For example `(x**ω + y**ω).ω == tree_map(operator.add, x, y)`. Like [tree-math](https://github.com/google/tree-math) but with nicer syntax.
  - Custom primitives:
    - `filter_primitive_{def, jvp, transpose, batching, bind}`: define rules for custom primitives that accept arbitrary PyTrees, not just JAX arrays.
    - `create_vprim`: autodefines batching rules for higher-order primitives, according to `transform(vmap(prim)) == vmap(transform(prim))`.
  - String handling:
    - `str2jax`: turns a string into a JAX'able object.
  - Unvmap'ing:
    - `unvmap_{any, all, max}`: apply reductions whilst ignoring the batch dimension.
- New filtered transformations: `eqx.{filter_jvp,filter_custom_jvp}`

Bugfixes / backward incompatibilities
- `eqx.nn.GRUCell` will now use its bias term. (Previously it was never adding this.)
- `eqx.filter_eval_shape` will no longer promote array-likes to arrays, in either its input or its output.
- `eqx.tree_equal` now treats JAX arrays and NumPy arrays as equal.

Misc
- Improved compilation speed of `eqx.filter_vmap`.

New Contributors
* jondeaton made their first contribution in https://github.com/patrick-kidger/equinox/pull/204
* IsaacBreen made their first contribution in https://github.com/patrick-kidger/equinox/pull/215

**Full Changelog**: https://github.com/patrick-kidger/equinox/compare/v0.8.0...v0.9.0

