Inox

Latest version: v0.6.3

0.10.11

New features

- Equinox now offers true runtime errors! This is available as `equinox.error_if`. This is something new under the JAX sun: these are raised eagerly during the execution, they work on TPU, and if you set the environment variable `EQX_ON_ERROR=breakpoint`, then they'll even drop you into a debugger as soon as you hit an error. (These are basically a strict improvement over `jax.experimental.checkify`, which doesn't offer many of these advantages.)

- Added a suite of debugging tools:
  - `equinox.debug.announce_transform`: prints to stdout whenever the value it wraps is transformed (via jvp, vmap, etc.); very useful for tracking how many times a particular operation is being transformed or compiled when trying to minimise your compilation times.
  - `equinox.debug.backward_nan`: for debugging NaNs that only arise on the backward pass.
  - `equinox.debug.breakpoint_if`: opens a breakpoint if a condition is satisfied.
  - `equinox.debug.{store_dce, inspect_dce}`: used for checking whether certain variables are removed by the XLA compiler's dead-code-elimination pass.

- `equinox.filter_jvp` now supports keyword arguments (which are treated as not differentiated).

Bugfixes

- Nested `filter_jvp`s no longer materialise symbolic zero tangents. (422)

Documentation

- The marvellous [Levanter](https://github.com/stanford-crfm/levanter) library is now linked to in the documentation!

**Full Changelog**: https://github.com/patrick-kidger/equinox/compare/v0.10.10...v0.10.11

0.10.10

**Performance improvements**

These are the real highlight of this release.

- `equinox.internal.{while_loop, scan}` now use new symbolic zero functionality, which may result in runtime speedups (and slight increases in compile times) as they can now skip calculating gradients for some quantities.
- `equinox.internal.{while_loop, scan}(..., buffers=...)` now do their best to work around an XLA bug (https://github.com/google/jax/issues/10197). This can reduce computational cost from quadratic scaling to linear scaling.
- `equinox.internal.{while_loop, scan}` now include several optimisations for the common case in which every step is checkpointed. (415)

**Features**

- `equinox.filter_custom_{jvp,vjp}` now support symbolic zeros.

Previously, `None` was passed to represent symbolic zero tangents/cotangents for anything that wasn't a floating-point array, but floating-point arrays always had materialised tangents/cotangents.

With this release, `None` may also sometimes be passed as the tangent of floating-point arrays. In this case it represents a zero tangent/cotangent, and moreover this zero is "symbolic" -- that is to say it is known to be zero at compile time, which may allow you to write more-efficient custom JVP/VJP rules. (The canonical example is the inverse function theorem -- this involves a linear solve, parts of which you can skip if you know parts of it are zero.)

In addition, `filter_custom_vjp` now takes another argument, `perturbed`, indicating whether a value actually needs cotangents calculated for it. You can skip calculating cotangents for anything that is not perturbed.

For more information see `jax.custom_jvp.defjvp(..., symbolic_zeros=True)` and `jax.custom_vjp.defvjp(..., symbolic_zeros=True)`, which provide the underlying behaviour that is being forwarded.

Note that this is provided through a new API: `filter_custom_jvp.def_jvp` instead of `filter_custom_jvp.defjvp`, and `filter_custom_vjp.{def_fwd, def_bwd}` instead of `filter_custom_vjp.defvjp`. The old API will continue to exhibit the previous behaviour, for backward compatibility.

**Misc**

- Applied `functools.wraps` to Module methods to preserve docstrings. (Thanks bowlingmh! https://github.com/patrick-kidger/equinox/pull/409)
- Enumerations now perform their checks at compile time if possible. This sometimes makes it possible to get more efficient code, by special-casing on these values or eliding branches. (417)

**New Contributors**

- bowlingmh made their first contribution in https://github.com/patrick-kidger/equinox/pull/409

**Full Changelog**: https://github.com/patrick-kidger/equinox/compare/v0.10.6...v0.10.10

(Why no v0.10.{7,8,9}? We had a bit of a rocky release this time around, and these got yanked for having bugs. Thanks to everyone who reported issues so quickly! Things look like they're stable now...)

0.10.6

Features
* Added `eqx.field`: this supports `converter=...` and `static=...`. The former is an extension to dataclasses that applies that conversion function when the field is assigned. The latter supersedes the old `eqx.static_field`. (390)
* Added `eqx.Enumeration`, which are JAX-compatible Enums. (Moved from `eqx.internal.Enumeration`.) (392)
* Added `eqx.clear_caches` to clear internal caches and reduce memory usage. (380)
* Added `eqx.nn.BatchNorm(..., dtype=...)` (Thanks Benjamin-Walker! 384)
* Inside `eqx.internal.while_loop`: buffers now support `buffer.at[index].add(...)` etc. (Thanks packquickly! 395)

Changes
* Updated `typing->collections.abc` where appropriate; `Tuple->tuple` etc. (385)

Bugfixes
* `eqx.module_update_wrapper` no longer assigns `__wrapped__`. (381)

**Full Changelog**: https://github.com/patrick-kidger/equinox/compare/v0.10.5...v0.10.6

0.10.5

Quite a small release.

**Bugfixes**

- Fixed modules initialising twice (369; this bug was introduced in the last couple of Equinox versions.)

**Documentation**

- Fix docstring typos in `MLP.__init__`. (Thanks schmrlng! 366)
- Added an example of serialising hyperparameters. (Thanks bytbox! 374)

**Misc**

* Add `equinox.internal.eval_full` (like `equinox.internal.eval_{zeros, empty}`) (Thanks RaderJason! 367)
* Added JAX-compatible enums: `equinox.internal.Enumeration` (375)
* The minimum supported Python version has been bumped to 3.9 (379)

New Contributors
* schmrlng made their first contribution in https://github.com/patrick-kidger/equinox/pull/366
* bytbox made their first contribution in https://github.com/patrick-kidger/equinox/pull/374

**Full Changelog**: https://github.com/patrick-kidger/equinox/compare/v0.10.4...v0.10.5

0.10.4

Features

- `eqx.nn.{LayerNorm, GroupNorm}` can now accept a call-time `state` argument that they thread through unchanged. This means that they have the same API as `eqx.nn.BatchNorm`, so that they may be used interchangeably.
- `eqx.Module`s now work with the new `jax.tree_util.tree_flatten_with_path` API. (363)
- `eqx.nn.MLP` now supports `use_bias` and `use_final_bias` arguments. (Thanks jlperla! 358)
- Added `eqx.tree_check` to assert that a pytree does not contain duplicate elements, and does not contain any reference cycles. This may be useful to call on your models prior to training, to check that they are well-formed. (355)
- Added `eqx.tree_flatten_one_level` to flatten a pytree by one level only. (355)

Internal (semi-undocumented / unstable) Features

- `eqx.internal.{error_if, branched_error_if, debug_backward_nans}` now have TPU support! This means that they now support all backends, and are (to my knowledge) the single best option for adding runtime checks to JAX programs. In addition, they now raise errors eagerly at trace time if the predicate is a raw Python `True`. (351)
- `eqx.internal.scan` now supports `buffers` and `checkpoints` arguments for finer-grained control over its autodiff. (349)
- Added `eqx.internal.scan_trick`, which can be used to minimise compilation time by wrapping nearby function invocations into a single scan. [See this PR against Diffrax](https://github.com/patrick-kidger/diffrax/pull/253) for an example.

Bugfixes

- Remove implicit rank promotion in `eqx.nn.ConvTranspose` (Thanks khdlr! 335)
- `eqx.static_field()`s were sometimes being put in leaves; this is now fixed. (This issue existed in v0.10.3 only.) (338)
- `eqx.filter_custom_jvp` will no longer raise the occasional spurious leaked tracer error. (When using traced non-floating arrays.) (349)
- Fixed crash when using zero-sized arrays inside `eqxi.while_loop(... kind='checkpointed')` (331)

Other

- Now using `pyproject.toml` to handle everything (no more `setup.py`, `.flake8` etc!)
- Added example docs for autoparallel APIs ([link](https://docs.kidger.site/equinox/examples/parallelism/))
- `eqx.internal.while_loop` should now have a slightly faster compile time. (353)


New Contributors
* khdlr made their first contribution in https://github.com/patrick-kidger/equinox/pull/335
* jlperla made their first contribution in https://github.com/patrick-kidger/equinox/pull/358

**Full Changelog**: https://github.com/patrick-kidger/equinox/compare/v0.10.3...v0.10.4

0.10.3

Features

- Added `equinox.nn.{State, StateIndex}`. This has been one of the longest-requested features for Equinox: we now have proper stateful operations! (In a carefully-controlled way -- see the new [stateful docs](https://docs.kidger.site/equinox/api/nn/stateful/) and the new [stateful example](https://docs.kidger.site/equinox/examples/stateful/).)

- As an application of these new stateful operations: added `equinox.nn.{BatchNorm, SpectralNorm}`, which have graduated from experimental! Note that these have a slightly different API to their previous experimental versions.

- Added `equinox.Partial`, which is a tidied-up version of `jax.tree_util.Partial`.

- `equinox.filter_{jit, pmap}` are now compatible with ahead-of-time compilation. (325)

- `equinox.nn.LayerNorm` now supports `use_weight` and `use_bias` arguments to disable each individually. This reflects the fact that many modern transformer architectures now use layer normalisation without bias. (310; thanks lockwo!)

- Added `equinox.internal.{AbstractVar, AbstractClassVar}` to denote abstract instance attributes and abstract class attributes respectively. (Analogous to `abc.abstractmethod` denoting abstract methods.) The downstream scientific ecosystem is making heavy use of abstract base classes (e.g. [all the ABCs in Diffrax](https://docs.kidger.site/diffrax/usage/extending/)) and these have turned out to be a really useful feature. See [this docstring](https://github.com/patrick-kidger/equinox/blob/83a1aacb4b25f5bbf6faa81c9bc7e2ef2b76f700/equinox/_better_abstract.py#L1) for more details. Right now these are an undocumented internal-only feature, but we could plausibly spin these out into their own library.

Tweaks

- `equinox.nn.Conv` should now be compatible with disabled rank promotion (308; thanks lockwo!)
- `equinox.internal.loop` should now be compatible with `jax.experimental.xmap` (https://github.com/patrick-kidger/diffrax/issues/246)
- Normalisation layers should now be tolerant of floating-point inaccuracies that occasionally produce negative variances. (314; thanks anh-tong!)
- The BERT example now has fixed dropout behaviour (316; thanks j5b!)
- Some doc fixes (303; thanks RaderJason!)

Removed

- Everything in `equinox.experimental.*` has been removed. See the new stateful functionality described above.

New Contributors
* RaderJason made their first contribution in https://github.com/patrick-kidger/equinox/pull/303
* lockwo made their first contribution in https://github.com/patrick-kidger/equinox/pull/311
* anh-tong made their first contribution in https://github.com/patrick-kidger/equinox/pull/314

**Full Changelog**: https://github.com/patrick-kidger/equinox/compare/v0.10.2...v0.10.3
