Inox

Latest version: v0.7.1

0.11.6

This is primarily a bug fix release.

- Runtime error messages (those from `eqx.error_if`, in particular when wrapped with `eqx.filter_jit`) should now be compatible with PyCharm's debugger, and with certain multithreaded contexts. (Thanks adam-hartshorne, dlwh! 828, 844, 849)

- Marking a `jax.Array` or `np.ndarray` as an `eqx.field(static=True)` will now raise a warning. This was *technically* okay as long as you use it in certain very narrow contexts (e.g. to smuggle it into a JIT'd region without being traced), but in practice it was nearly always just a common new-user footgun. (Thanks lockwo! 800)

- Replacing empty tuples with `eqx.tree_at` has been improved. (Thanks danielward27! 818, 819)

- `eqx.nn.RotaryEmbedding` no longer promotes input dtypes to at least float32. (Thanks knyazer! 836)

- Mypy now understands that `eqx.Module`s are dataclasses. (Pyright always did, but mypy needed a slightly different approach to appreciate this fact.) (Thanks NeilGirdhar! 822)

- Multiple `eqx.Module`s participating in co-operative multiple inheritance (at least 5 inheriting from each other seem to be necessary?), with some of them overriding the `__post_init__`s of others, should now follow their expected resolution order. (Thanks NeilGirdhar! 832, 834)

- We now have a `.editorconfig` file. (Thanks NeilGirdhar! 821)

- Doc improvements. (Thanks garymm, ColCarroll! 804, 805)

New Contributors
* garymm made their first contribution in https://github.com/patrick-kidger/equinox/pull/804
* ColCarroll made their first contribution in https://github.com/patrick-kidger/equinox/pull/805
* NeilGirdhar made their first contribution in https://github.com/patrick-kidger/equinox/pull/823

**Full Changelog**: https://github.com/patrick-kidger/equinox/compare/v0.11.5...v0.11.6

0.11.5

JAX compatibility

Recent versions of JAX (0.4.28+) have made some changes to:

- Hashing of tracers;
- Tree-map'ing over Nones;
- Callbacks;
- Pretty-printing.

With this update, we should now be compatible with both old and new versions of JAX: this fixes some new crashes and some new warnings. (719, 724, 753, 758, thanks jakevdp, hawkinsp!)

Better errors

- The error messages from `eqx.error_if` are now substantially more informative: they include traceback information including the stack, and mention the availability of the `EQX_ON_ERROR` variable. We also do a much better job hiding the large unhelpful printouts that XLA gives by default. (785, 803)

- The default value of `EQX_ON_ERROR_BREAKPOINT_FRAMES` is now `1`. (777) The impact of this is that using `eqx.error_if` alongside `EQX_ON_ERROR=breakpoint` will now:
  - reliably open a debugger, rather than sometimes crashing at trace-time due to upstream JAX bug [16732](https://github.com/google/jax/issues/16732);
  - however, by default the debugger will no longer include any additional stack frames above it (accessed via `u`);
  - print an informative message explaining much of the above before the debugger opens.

Bugfixes

- `eqx.filter_{jacfwd, jacrev}` now only apply filtering to their inputs but not their outputs. Previously this was problematic as there was no way to represent static-input-by-static-output in the returned Jacobian, so pieces were silently dropped. (734, thanks lockwo!)

- `eqx.tree_at` can now be used to replace empty tuples. (715, 717, 722, thanks lockwo!)

- `eqx.filter_custom_jvp` no longer raises a trace-time crash in some scenarios in which its `**kwargs` were erroneously counted as having tangents. (https://github.com/patrick-kidger/equinox/issues/745#issuecomment-2148560546, 749)

- Fixed a trace-time crash arising from a particular combination of vmap, autodiff, and checkpointed while loops. This occurred when using `optimistix.BFGS` around `diffrax.diffeqsolve`. (777)

- Fixed a trace-time crash when:
  - using a checkpointed while loop...
  - ...with a body function that has a closed-over tracer...
  - ...and that closed-over tracer is differentiated...
  - ...and there are no other closed-over tracers that are differentiated...
  - ...and the dependency on that tracer is only linear.

  (https://github.com/patrick-kidger/diffrax/pull/387#issuecomment-2132472392, 752, thanks dkweiss31!)

- Fixed a trace-time crash when composing the grad of vmap of `lineax.linear_solve`. (https://github.com/patrick-kidger/lineax/issues/101, 795, thanks rhacking!)

- `eqx.nn.RMSNorm` now uses at least 32-bit precision for numerical stability (723, thanks AakashKumarNain!)

New features

- `eqx.nn.{Linear,Conv,GRUCell,LSTMCell}` now support complex dtypes (765, thanks ChenAo-Phys!)

- Added `eqx.nn.RotaryEmbedding(..., theta=...)`. (735, thanks Artur-Galstyan!)

Other changes

- Several doc fixes. (708, 731, 733, 747, 750, 757 + several other PRs, thanks Artur-Galstyan, matteoguarrera, lockwo, nasyxx!)

- Several internal test fixes as downstream libraries have changed slightly. (740, 742 + several other PRs, big thanks to GaetanLepage for reporting many of these!)

- There is now a Mistral 7B implementation using JAX+Equinox available over in [AakashKumarNain/mistral_jax](https://github.com/AakashKumarNain/mistral_jax)!


New Contributors
* nasyxx made their first contribution in https://github.com/patrick-kidger/equinox/pull/708
* jakevdp made their first contribution in https://github.com/patrick-kidger/equinox/pull/724
* matteoguarrera made their first contribution in https://github.com/patrick-kidger/equinox/pull/739

**Full Changelog**: https://github.com/patrick-kidger/equinox/compare/v0.11.4...v0.11.5

0.11.4

Features

- Added `eqx.filter_shard`. This lowers to `jax.lax.with_sharding_constraint`, and offers a single way to transfer or reshard data, both inside and outside of JIT! (No more `jax.device_put`.) In addition, the parallelism example has been updated to use this simpler new functionality. (Thanks homerjed and dlwh! 688, 691)

- Added `eqx.filter_{jacfwd,jacrev,hessian}`. These do what you expect! (Thanks lockwo! 677)

- Added `eqx.nn.RotaryPositionalEmbedding`. This is designed to be used in conjunction with the existing `eqx.nn.MultiheadAttention`. (Thanks Artur-Galstyan! 568)

- Added support for `padding='VALID'`, `padding='SAME'`, `padding='SAME_LOWER'` to the convolutional layers: `eqx.nn.{Conv, ...}`. (Thanks ChenAo-Phys! 658)

- Added support for `padding_mode='ZEROS'`, `padding_mode='REFLECT'`, `padding_mode='REPLICATE'`, `padding_mode='CIRCULAR'` to the convolutional layers: `eqx.nn.{Conv, ...}`. (Thanks ChenAo-Phys! 658)

- Added a `dtype` argument to `eqx.nn.{MultiheadAttention, Linear, Conv, ...}` for specifying the dtype of their parameters. In addition, `eqx.nn.BatchNorm` will now also use its `dtype` argument to determine the dtype of its weights and bias, not just the dtype of its moving statistics. (Thanks Artur-Galstyan and AakashKumarNain! 680, 689)

Compatibility

- `eqx.error_if` is now compatible with JAX 0.4.26, which changed JAX's own reporting of error messages slightly. (Thanks hawkinsp! 670)

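- Added a warning that checks for doing something like:
```python
from typing import Callable

import equinox as eqx
import jax

class MyModule(eqx.Module):
    fn: Callable

    def __init__(self, some_fn):
        self.fn = jax.vmap(some_fn)  # this pattern now triggers a warning
```
This is an easy source of bugs: the vmap'd function is not a PyTree, so it does not propagate anything in the PyTree structure of `some_fn` (for example its parameters, if `some_fn` is itself an `eqx.Module`).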

Technical internal stuff

- `eqx.internal.while_loop(..., kind="checkpointed")` will now only propagate forward JVP tracers for those outputs which are perturbed due to the input to the loop being perturbed. (Rather than all of them.) This change just means that later calls to a nondifferentiable operation, like `jax.pure_callback` or `eqx.internal.nondifferentiable`, will no longer crash at trace time. (See https://github.com/patrick-kidger/diffrax/issues/396.)
- `eqx.internal.while_loop(..., kind="bounded")` will now handle certain vmap+grad combinations without crashing. (It seems like JAX is adding some spurious batch tracers.) (See https://github.com/patrick-kidger/optimistix/issues/48#issuecomment-2009221739)

- The transpose rule for `eqx.internal.create_vprim` now understands symbolic zeros, fixing a crash for `grad-of-vmap-of-<lineax.linear_solve that we only use some outputs from>`. (See https://github.com/patrick-kidger/optimistix/issues/48.)

- The type annotation for the input of any converter function used in `eqx.field(converter=...)` will now be used as the type annotation in any `dataclass`-autogenerated `__init__` functions. In particular this should mean such functions are now compatible with runtime type checkers like beartype. (jaxtyping users, you were already covered: this checks the assigned annotations instead.)

New Contributors
* ChenAo-Phys made their first contribution in https://github.com/patrick-kidger/equinox/pull/658
* hawkinsp made their first contribution in https://github.com/patrick-kidger/equinox/pull/670
* AakashKumarNain made their first contribution in https://github.com/patrick-kidger/equinox/pull/680
* imilas made their first contribution in https://github.com/patrick-kidger/equinox/pull/699

**Full Changelog**: https://github.com/patrick-kidger/equinox/compare/v0.11.3...v0.11.4

0.11.3

Features

- Added `equinox.nn.RMSNorm`.
- Added `equinox.nn.WeightNorm`.
- `equinox.tree_deserialise_leaves` now treats `jax.ShapeDtypeStruct`s in the same way as arrays. This makes it possible to avoid instantiating the initial model parameters only to throw them away again, by using `equinox.filter_eval_shape`:
```python
# `Model` and `load_path` are placeholders; pass Model's hyperparameters in place of `...`.
model = eqx.filter_eval_shape(Model, ...)
model = eqx.tree_deserialise_leaves(load_path, model)
```
(259)

Bugfixes

- `equinox.internal.noinline` no longer initialises the JAX backend on use.
- `equinox.filter_jit(...).lower(..., some_kwarg=...)` no longer crashes (625, 627)
- The state of `equinox.nn.BatchNorm` now uses the default floating point dtype, rather than always using `float32`.
- `equinox.nn.MultiheadAttention` should now perform the softmax in `float32` even when the input is of lower dtype. (This is important for numerical stability.)

Refactor

- All the layers in `equinox.nn.{Linear, MLP, ...}` now standardise on accepting extra `**kwargs` and not calling `super().__init__`. The intention is that these layers be treated as final, i.e. not subclassable. (Previously things were inconsistent: some did this and some did not.)
- Should now be compatible with `JAX_NUMPY_DTYPE_PROMOTION=strict` and `JAX_NUMPY_RANK_PROMOTION=raise`, and this is checked in tests.
- Better error message when no kwargs passed to `filter_grad` (Thanks knyazer! 589)

Internal features
_These are undocumented internal features, that may be changed at any time._

- Added `EQX_GETKEY_SEED` for use with `equinox.internal.GetKey`.
- `equinox.internal.while_loop` now has its runtime errors removed. This should help with compatibility with TPUs. (628)


New Contributors
* haydn-jones made their first contribution in https://github.com/patrick-kidger/equinox/pull/608

**Full Changelog**: https://github.com/patrick-kidger/equinox/compare/v0.11.2...v0.11.3

0.11.2

**Features**

- Added `eqx.filter_jit(..., donate="all-except-first")` and `eqx.filter_jit(..., donate="warn-except-first")`. This offers a way to donate all arguments *except* the first one. (If you have multiple such arguments then just pack them together into a tuple in the first argument.) This aims to be a low-overhead easy way to handle buffer donation.
- Added `eqx.debug.{assert_max_traces, get_num_traces}`, which aim to provide a friendly way of asserting that a JIT'd function is not recompiled -- and if it is, which argument changed to cause the recompilation.
- `eqx.tree_pprint` and `eqx.tree_pformat` now handle PyTorch tensors and `jax.ShapeDtypeStruct`s.
- `eqx.tree_equal` now has new arguments:
  - `typematch=True`: this requires every leaf to have precisely the same type as its counterpart. Without this flag the requirement is essentially `leaf == leaf2`; with it, the requirement becomes `type(leaf) == type(leaf2) and leaf == leaf2`.
  - `rtol` and `atol`: setting these to nonzero values checks that inexact (floating or complex) arrays are allclose, rather than exactly equal.
  - The expectation is that these will be useful in unit tests, e.g. to write checks of the form `assert eqx.tree_equal(output, expected_output, typematch=True, rtol=1e-5, atol=1e-5)`.

**Bugfixes**

- Previously, a learnt activation function for `eqx.nn.MLP` would use the exact same learnt weights for every neuron in every layer. Now, a separate copy of the activation function is used in each location.
- Subclasses of `eqx.Module` should now have their `__init__` signatures correctly reported by downstream tooling, e.g. automated doc generators, some IDEs. (Thanks danielward27! 573)

**Typing**

- `eqx.filter_value_and_grad` now declares that it preserves the return type of its function (Thanks ConnorBaker! 557)

**Documentation**

- Fixed a missing index argument in the docstring example for `StateIndex`. (Thanks edwardwli! 556)
- Fixed a broken link in the `eqx.Enumeration` docstrings. (Thanks LouisDesdoigts! 579)
- Fixed a missing shape specification in one of the tricks. (Thanks homerjed! 582)

**Other**

- Improved a few IPython tracebacks with appropriate `__tracebackhide__ = True` assignments.
- Subclassed `eqx.Enumeration`s can now override the message associated with their parent Enumeration: this now produces a warning rather than an error.
- Documented the `EQX_ON_ERROR_BREAKPOINT_FRAMES` config variable, which is used to work around a JAX bug when setting `EQX_ON_ERROR=breakpoint`.
- Can now monkey-patch the methods of an `eqx.Module`, e.g.
```python
import equinox as eqx


class Foo(eqx.Module):
    def f(self): ...

Foo.f = some_transform(Foo.f)  # `some_transform`: placeholder for any function wrapper
```
The anticipated use-case for this is to make it easier for typecheckers; see 584.
- `eqx.debug.store_dce` now supports non-arrays in its argument.
- `eqx.Enumeration.where(traced_pred, x, x)` will now statically return `x` without tracing. This is occasionally useful to better propagate information at compile time.

**Internal features (not officially supported, advanced use only)**

- Added `eqx.internal.GetKey`. This generates a random JAX PRNG key when called, and crucially has a nice `__repr__` reporting what the seed value is. This should not be used in normal JAX code! This is intended as a convenience for tests, so that the random seed appears in the debug printout of a failed test.
- Added `eqx.internal.MaybeBuffer` to indicate that an argument of an `eqx.internal.{while_loop,scan}` might be wrapped in a buffer.
- Added `eqx.internal.buffer_at_set` to support `buffer.at[...].set(..., pred=...)` whilst being agnostic to whether `buffer` is a JAX array or one of our while loop buffers.

New Contributors
* edwardwli made their first contribution in https://github.com/patrick-kidger/equinox/pull/556
* ConnorBaker made their first contribution in https://github.com/patrick-kidger/equinox/pull/557
* danielward27 made their first contribution in https://github.com/patrick-kidger/equinox/pull/573

**Full Changelog**: https://github.com/patrick-kidger/equinox/compare/v0.11.1...v0.11.2

0.11.1

This is a minor bugfix release.

**Bugfixes**
* Checkpointed while loops (`eqx.internal.while_loop(..., kind="checkpointed")`) now perform a more careful analysis of which arguments need to be differentiated. (548) This fix is the primary reason for this release -- it unlocks some efficiency improvements when solving SDEs in Diffrax: https://github.com/patrick-kidger/diffrax/pull/320
* Fixed `Abstract{Class,}Var` misbehaving around multiple inheritance. (544)
* Better compatibility with the beartype library. In a few cases this was throwing some spurious errors to do with forward references. (543)

**Documentation**
* Fix scan-over-layers example in docs. (Thanks mcbal! 542)

**Other**
* Static type checkers should now use Equinox's type hints correctly. (Specifically, we now have the `py.typed` marker file. Thanks vidhanio! 547)
* Added the `EQX_ON_ERROR_BREAKPOINT_FRAMES` environment variable, to work around JAX bug https://github.com/google/jax/issues/16732 when using `EQX_ON_ERROR=breakpoint`. This new variable sets the number of stack frames you can access via the `u` debugger command, when the on-error debugger is triggered. Set this to a small enough number, e.g. `EQX_ON_ERROR_BREAKPOINT_FRAMES=1`, and it should fix unusual trace-time errors when using `EQX_ON_ERROR=breakpoint`.

New Contributors
* mcbal made their first contribution in https://github.com/patrick-kidger/equinox/pull/542
* vidhanio made their first contribution in https://github.com/patrick-kidger/equinox/pull/547

**Full Changelog**: https://github.com/patrick-kidger/equinox/compare/v0.11.0...v0.11.1

© 2025 Safety CLI Cybersecurity Inc. All Rights Reserved.