Inox

Latest version: v0.7.1

0.11.0

Better errors

Equinox now includes several additional checks to guard against various bugs. If you have a new error, then this is probably an indication that your code always had a silent bug, and should be updated.

- `eqx.nn.LayerNorm` now correctly validates the shape of its input. This was a common cause of silent bugs. (Thanks dlwh for pointing this one out!)
- Equinox now prints out a warning if you supply both `__init__` and `__post_init__` -- the former actually overwrites the latter. (This is normal Python dataclass behaviour, but probably unexpected.)
- Equinox now prevents you from assigning Module attributes with a bound method of your current instance, e.g.
    ```python
    class Model(eqx.Module):
        foo: Callable

        def __init__(self):
            self.foo = self.bar

        def bar(self):
            ...
    ```
    Otherwise, you end up with two different copies of your model! One at `self`, the other at `self.foo.__self__`. (The latter being in the bound method.)
- `eqx.tree_at` now gives a better error message if you use it to try and update something that isn't a PyTree leaf. (Thanks LouisDesdoigts!)
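Two of these checks rest on plain Python behaviour that can be seen without Equinox at all. The following is a minimal sketch using only the standard library; `Config` and `Model` are hypothetical names, and `Model` is an ordinary class standing in for an `eqx.Module`.

```python
from dataclasses import dataclass


@dataclass
class Config:
    width: int

    def __init__(self, width):
        # A hand-written __init__ is kept in place of the generated one...
        self.width = width
        self.post_init_ran = False

    def __post_init__(self):
        # ...so nothing ever calls __post_init__.
        self.post_init_ran = True


class Model:
    def __init__(self):
        self.foo = self.bar  # stores a bound method on the instance

    def bar(self):
        return "bar"


config = Config(width=3)
model = Model()
print(config.post_init_ran)         # False
print(model.foo.__self__ is model)  # True
```

The second result is why the bound-method assignment is now rejected: the attribute carries a reference back to the whole instance, so a PyTree built from it would contain the model twice.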

API changes

These should all be very minor.

- **Breaking change:** `eqx.nn.StateIndex` now takes the initial value, rather than a function that returns the initial value.
- **Breaking change:** If using `eqx.field(converter=...)`, then conversion now happens before `__post_init__`, rather than after it.
- Prefer `eqx.nn.make_with_state` over `eqx.nn.State`. The latter will continue to work, but the former is more memory-efficient. (It deletes the original copy of the initial state.)
- Prefer `eqx.nn.inference_mode` over `eqx.tree_inference`. The latter will continue to exist for backward compatibility. These are the same function; this is really just a matter of moving it into the `eqx.nn` namespace, where it always belonged.

Sharing layers

Equinox now supports sharing a layer between multiple parts of your model! This has probably been our longest-requested feature -- in large part because of how intractable it seemed. Equinox models are Py*Trees*, not Py*DAGs*, so how exactly are we supposed to have two different parts of our model point at the same layer?

The answer turned out to be the following -- in this example, we're reusing the embedding weight matrix between the initial embedding layer, and the final readout layer, of a language model.
```python
class LanguageModel(eqx.Module):
    shared: eqx.nn.Shared

    def __init__(self):
        embedding = eqx.nn.Embedding(...)
        linear = eqx.nn.Linear(...)
        # These two weights will now be tied together.
        where = lambda embed_and_lin: embed_and_lin[1].weight
        get = lambda embed_and_lin: embed_and_lin[0].weight
        self.shared = eqx.nn.Shared((embedding, linear), where, get)

    def __call__(self, tokens):
        # Expand back out so we can evaluate these layers.
        embedding, linear = self.shared()
        assert embedding.weight is linear.weight  # same parameter!
        # Now go ahead and evaluate your language model.
        ...
```

Here, `eqx.nn.Shared(...)` simply removes all of the nodes at `where`, so that we don't have two separate copies. Then when it is called at `self.shared()`, it puts them back again. Note that this isn't a copy and doesn't incur any additional memory overhead; this all happens at the Python level, not the XLA level.

(The curious may like to take a look at the implementation in `equinox/nn/_shared.py`, which turned out to be very simple.)

_On a meta level, I'd like to comment that I'm quite proud of having gotten this one in! It means that Equinox now supports both stateful layers and shared layers, which have always been the two pieces that seemed out of reach when using something as simple as PyTrees to represent models. But it turns out that PyTrees really are all you need. :D_

Other changes

Documentation

- Many documentation fixes courtesy of colehaus and Artur-Galstyan!
- Added two new examples to the documentation. Thank you to ahmed-alllam for both of them!
    - Deep convolutional GAN
    - Vision Transformer
- Added an FAQ entry on comparisons between Equinox and PyTorch/Keras/Julia/Flax. It's a common enough question that should probably have had an answer before now.
- Added an FAQ entry on debugging recompilation.

Features

- Added `eqx.filter_checkpoint`, which as you might expect is a filtered version of `jax.checkpoint`. (Thanks dlwh!)
- Added `eqx.Module.__check_init__`. This is run in a similar fashion to `__post_init__`; see the documentation. This can be used to check that invariants of your module hold after initialisation.
- Added support for vmap'ing stateful layers, by adding `eqx.nn.State.{substate, update}`. This offers a way to subset or update a `State` object, so that only the parts of it that need to be vmap'd are passed in. See the stateful documentation for an example of how to do this.
- Runtime errors should now produce much more readable results, without any of the terrifying `INTERNAL: Generated function failed: CpuCallback error` stuff! This clean-up of the runtime error message is done by `eqx.filter_jit`, so that will need to be your top-level way of JIT'ing your computation.
- Added `eqx.nn.StatefulLayer` -- this is (only!) used with `eqx.nn.Sequential`, to indicate that the layer should be called with `x, state`, and not just `x`. If you would like a custom stateful layer to be compatible with `Sequential` then go ahead and subclass this, and potentially implement the `is_stateful` method. (Thanks paganpasta!)
- The forward pass of each `eqx.nn.*` layer is now wrapped in a `jax.named_scope`, for better debugging experience. (Thanks ahmed-alllam!)
- `eqx.module_update_wrapper` no longer requires a second argument; it will look at the `__wrapped__` attribute of its first argument.
- Added `eqx.internal.closure_to_pytree`, for... you guessed it, turning function closures into PyTrees. The closed-over variables are treated as the subnodes in the PyTree. This will operate recursively so that closed-over closures will themselves become PyTrees, etc. Note that closed-over global variables are not included.

Bugfixes

- `eqx.tree_{serialise,deserialise}_leaves` now correctly handle unusual NumPy scalars, like `bfloat16`. (Thanks colehaus!)
- `eqx.field(metadata=...)` arguments no longer result in the `static`/`converter` arguments being ignored. (Thanks mjo22!)
- `eqx.filter_custom_vjp` now supports residuals that are not arrays. (The residuals are the pytree that is passed between the forward and backward pass.)
- `eqx.{AbstractVar,AbstractClassVar}` should now support overridden generics in subclasses. That is, something like this:
    ```python
    class Foo(eqx.Module):
        x: eqx.AbstractVar[list[str]]

    class Bar(Foo):
        x: list[str]
    ```
    should no longer raise spurious errors under certain conditions.
- `eqx.internal.while_loop` now supports using custom (non-Equinox) pytrees in the state.
- `eqx.tree_check` no longer raises some false positives.
- Equinox modules now support `__init_subclass__` with additional class creation kwargs. (Thanks ASEM000, Roger-luo!)

New Contributors
* homerjed made their first contribution in https://github.com/patrick-kidger/equinox/pull/445
* LouisDesdoigts made their first contribution in https://github.com/patrick-kidger/equinox/pull/460
* knyazer made their first contribution in https://github.com/patrick-kidger/equinox/pull/474

**Full Changelog**: https://github.com/patrick-kidger/equinox/compare/v0.10.11...v0.11.0

0.10.11

New features

- Equinox now offers true runtime errors! This is available as `equinox.error_if`. This is something new under the JAX sun: these are raised eagerly during the execution, they work on TPU, and if you set the environment variable `EQX_ON_ERROR=breakpoint`, then they'll even drop you into a debugger as soon as you hit an error. (These are basically a strict improvement over `jax.experimental.checkify`, which doesn't offer many of these advantages.)

- Added a suite of debugging tools:
    - `equinox.debug.announce_transform`: prints to stdout when it is transformed via jvp/vmap etc; very useful for keeping track of how many times a particular operation is getting transformed or compiled, when trying to minimise your compilation times.
    - `equinox.debug.backward_nan`: for debugging NaNs that only arise on the backward pass.
    - `equinox.debug.breakpoint_if`: opens a breakpoint if a condition is satisfied.
    - `equinox.debug.{store_dce, inspect_dce}`: used for checking whether certain variables are removed via the dead-code-elimination pass of the XLA compiler.

- `equinox.filter_jvp` now supports keyword arguments (which are treated as not differentiated).

Bugfixes

- Nested `filter_jvp`s will no longer materialise symbolic zero tangents. (422)

Documentation

- The marvellous [Levanter](https://github.com/stanford-crfm/levanter) library is now linked to in the documentation!

**Full Changelog**: https://github.com/patrick-kidger/equinox/compare/v0.10.10...v0.10.11

0.10.10

**Performance improvements**

These are the real highlight of this release.

- `equinox.internal.{while_loop, scan}` now use new symbolic zero functionality, which may result in runtime speedups (and slight increases in compile times) as they can now skip calculating gradients for some quantities.
- `equinox.internal.{while_loop, scan}(..., buffers=...)` now do their best to work around an XLA bug (https://github.com/google/jax/issues/10197). This can reduce computational cost from quadratic scaling to linear scaling.
- `equinox.internal.{while_loop, scan}` now include several optimisations for the common case in which every step is checkpointed. (415)

**Features**

- `equinox.filter_custom_{jvp,vjp}` now support symbolic zeros.

Previously, `None` was passed to represent symbolic zero tangent/cotangents for anything that wasn't a floating-point array -- but all floating-point arrays always had materialised tangent/cotangents.

With this release, `None` may also sometimes be passed as the tangent of floating-point arrays. In this case it represents a zero tangent/cotangent, and moreover this zero is "symbolic" -- that is to say it is known to be zero at compile time, which may allow you to write more-efficient custom JVP/VJP rules. (The canonical example is the inverse function theorem -- this involves a linear solve, parts of which you can skip if you know parts of it are zero.)

In addition, `filter_custom_vjp` now takes another argument, `perturbed`, indicating whether a value actually needs cotangents calculated for it. You can skip calculating cotangents for anything that is not perturbed.

For more information see `jax.custom_jvp.defjvp(..., symbolic_zeros=True)` and `jax.custom_vjp.defvjp(..., symbolic_zeros=True)`, which provide the underlying behaviour that is being forwarded.

Note that this is provided through a new API: `filter_custom_jvp.def_jvp` instead of `filter_custom_jvp.defjvp`, and `filter_custom_vjp.{def_fwd, def_bwd}` instead of `filter_custom_vjp.defvjp`. The old API will continue to exhibit the previous behaviour, for backward compatibility.

**Misc**

- Apply functools.wraps to Module methods to preserve docstrings (Thanks bowlingmh! https://github.com/patrick-kidger/equinox/pull/409)
- Enumerations now perform their checks at compile time if possible. This sometimes makes it possible to get more efficient code, by special-casing on these values or eliding branches. (417)

**New Contributors**

- bowlingmh made their first contribution in https://github.com/patrick-kidger/equinox/pull/409

**Full Changelog**: https://github.com/patrick-kidger/equinox/compare/v0.10.6...v0.10.10

(Why no v0.10.{7,8,9}? We had a bit of a rocky release this time around, and these got yanked for having bugs. Thanks to everyone who reported issues so quickly! Things look like they're stable now...)

0.10.6

Features
* Added `eqx.field`: this supports `converter=...` and `static=...`. The former is an extension to dataclasses that applies that conversion function when the field is assigned. The latter supersedes the old `eqx.static_field`. (390)
* Added `eqx.Enumeration`, which are JAX-compatible Enums. (Moved from `eqx.internal.Enumeration`.) (392)
* Added `eqx.clear_caches` to clear internal caches and reduce memory usage. (380)
* Added `eqx.nn.BatchNorm(..., dtype=...)` (Thanks Benjamin-Walker! 384)
* Inside `eqx.internal.while_loop`: buffers now support `buffer.at[index].add(...)` etc. (Thanks packquickly! 395)

Changes
* Updated `typing->collections.abc` where appropriate; `Tuple->tuple` etc. (385)

Bugfixes
* `eqx.module_update_wrapper` no longer assigns `__wrapped__`. (381)

**Full Changelog**: https://github.com/patrick-kidger/equinox/compare/v0.10.5...v0.10.6

0.10.5

Quite a small release.

**Bugfixes**

- Fixed modules initialising twice (369; this bug was introduced in the last couple of Equinox versions.)

**Documentation**

- Fix docstring typos in `MLP.__init__`. (Thanks schmrlng! 366)
- Added example for serialisation of hyperparameters (Thanks bytbox! 374)

**Misc**

* Add `equinox.internal.eval_full` (like `equinox.internal.eval_{zeros, empty}`) (Thanks RaderJason! 367)
* Added JAX-compatible enums: `equinox.internal.Enumeration` (375)
* The minimum supported Python version has been bumped to 3.9 (379)

New Contributors
* schmrlng made their first contribution in https://github.com/patrick-kidger/equinox/pull/366
* bytbox made their first contribution in https://github.com/patrick-kidger/equinox/pull/374

**Full Changelog**: https://github.com/patrick-kidger/equinox/compare/v0.10.4...v0.10.5

0.10.4

Features

- `eqx.nn.{LayerNorm, GroupNorm}` can now accept a call-time `state` argument that they thread through unchanged. This means that they have the same API as `eqx.nn.BatchNorm`, so that they may be used interchangeably.
- `eqx.Module`s now work with the new `jax.tree_util.tree_flatten_with_path` API. (363)
- `eqx.nn.MLP` now supports `use_bias` and `use_final_bias` arguments. (Thanks jlperla! 358)
- Added `eqx.tree_check` to assert that a pytree does not contain duplicate elements, and does not contain any reference cycles. This may be useful to call on your models prior to training, to check that they are well-formed. (355)
- Added `eqx.tree_flatten_one_level` to flatten a pytree by one level only. (355)

Internal (semi-undocumented / unstable) Features

- `eqx.internal.{error_if, branched_error_if, debug_backward_nans}` now have TPU support! This means that they now support all backends, and are (to my knowledge) the single best option for adding runtime checks to JAX programs. In addition they will now eagerly raise errors at trace-time if the predicate is a raw Python `True`. (351)
- `eqx.internal.scan` now supports `buffers` and `checkpoints` arguments for finer-grained control over its autodiff. (349)
- Added `eqx.internal.scan_trick`, which can be used to minimise compilation time by wrapping nearby function invocations into a single scan. [See this PR against Diffrax](https://github.com/patrick-kidger/diffrax/pull/253) for an example.

Bugfixes

- Remove implicit rank promotion in `eqx.nn.ConvTranspose` (Thanks khdlr! 335)
- `eqx.static_field()`s were sometimes being put in leaves; this is now fixed. (This issue existed in v0.10.3 only.) (338)
- `eqx.filter_custom_jvp` will no longer raise the occasional spurious leaked tracer error. (When using traced non-floating arrays.) (349)
- Fixed crash when using zero-sized arrays inside `eqxi.while_loop(... kind='checkpointed')` (331)

Other

- Now using `pyproject.toml` to handle everything (no more `setup.py`, `.flake8` etc!)
- Added example docs for autoparallel APIs ([link](https://docs.kidger.site/equinox/examples/parallelism/))
- `eqx.internal.while_loop` should now have a slightly faster compile time. (353)


New Contributors
* khdlr made their first contribution in https://github.com/patrick-kidger/equinox/pull/335
* jlperla made their first contribution in https://github.com/patrick-kidger/equinox/pull/358

**Full Changelog**: https://github.com/patrick-kidger/equinox/compare/v0.10.3...v0.10.4
