Jax

Latest version: v0.5.2


0.5.1

* New Features
* Added an experimental {func}`jax.experimental.custom_dce.custom_dce`
decorator to support customizing the behavior of opaque functions under
JAX-level dead code elimination (DCE). See {jax-issue}`25956` for more
details.
* Added low-level reduction APIs in {mod}`jax.lax`: {func}`jax.lax.reduce_sum`,
{func}`jax.lax.reduce_prod`, {func}`jax.lax.reduce_max`, {func}`jax.lax.reduce_min`,
{func}`jax.lax.reduce_and`, {func}`jax.lax.reduce_or`, and {func}`jax.lax.reduce_xor`.
* {func}`jax.lax.linalg.qr` and {func}`jax.scipy.linalg.qr` now support
column pivoting on CPU and GPU. See {jax-issue}`20282` and
{jax-issue}`25955` for more details.

* Changes
* `JAX_CPU_COLLECTIVES_IMPLEMENTATION` and `JAX_NUM_CPU_DEVICES` can now be set
as environment variables. Previously they could only be specified via `jax.config` or flags.
* `JAX_CPU_COLLECTIVES_IMPLEMENTATION` now defaults to `'gloo'`, meaning
multi-process CPU communication works out-of-the-box.
* The `jax[tpu]` TPU extra no longer depends on the `libtpu-nightly` package.
This package may safely be removed if it is present on your machine; JAX now
uses `libtpu` instead.

* Deprecations
* The internal function `linear_util.wrap_init` and the constructor
`core.Jaxpr` now must take a non-empty `core.DebugInfo` kwarg. For
a limited time, a `DeprecationWarning` is printed if
`jax.extend.linear_util.wrap_init` is used without debugging info.
As a downstream effect, several other internal functions now need debug
info. This change does not affect public APIs.
See https://github.com/jax-ml/jax/issues/26480 for more detail.
* In {func}`jax.numpy.ndim`, {func}`jax.numpy.shape`, and {func}`jax.numpy.size`,
non-arraylike inputs (such as lists, tuples, etc.) are now deprecated.
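A minimal migration sketch for the `ndim`/`shape`/`size` deprecation. It converts the Python list to an array explicitly instead of passing the list directly.

```python
import jax.numpy as jnp

data = [[1, 2, 3], [4, 5, 6]]  # plain Python list, not an array

# Deprecated: jnp.ndim(data), jnp.shape(data), jnp.size(data).
# Convert to an array first instead:
arr = jnp.asarray(data)
ndim, shape, size = jnp.ndim(arr), jnp.shape(arr), jnp.size(arr)
```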

* Bug fixes
* TPU runtime startup and shutdown time should be significantly improved on
TPU v5e and newer (from around 17s to around 8s). If not already set, you may
need to enable transparent hugepages in your VM image
(`sudo sh -c 'echo always > /sys/kernel/mm/transparent_hugepage/enabled'`).
We hope to improve this further in future releases.
* The persistent compilation cache no longer writes an access-time file if
`JAX_COMPILATION_CACHE_MAX_SIZE` is unset or set to -1, i.e., if the LRU
eviction policy isn't enabled. This should improve performance when using
the cache with large-scale network storage.

0.5.0

As of this release, JAX uses
[effort-based versioning](https://jax.readthedocs.io/en/latest/jep/25516-effver.html).
Since this release makes a breaking change to PRNG key semantics that
may require users to update their code, we are bumping the "meso" version of JAX
to signify this.

* Breaking changes
* Enable `jax_threefry_partitionable` by default (see
[the update note](https://github.com/jax-ml/jax/discussions/18480)).

* This release drops support for Mac x86 wheels. Mac ARM of course remains
supported. For a recent discussion, see
https://github.com/jax-ml/jax/discussions/22936.

Two key factors motivated this decision:
* The Mac x86 build (only) has a number of test failures and crashes. We
would rather ship no release than a broken release.
* Mac x86 hardware is end-of-life and cannot be easily obtained for
developers at this point. So it is difficult for us to fix this kind of
problem even if we wanted to.

We are open to re-adding support for Mac x86 if the community is willing
to help support that platform: in particular, we would need the JAX test
suite to pass cleanly on Mac x86 before we could ship releases again.
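If existing code depends on the exact random values produced by earlier releases, the update note on `jax_threefry_partitionable` describes opting back out of the new default. A minimal sketch, assuming the flag is still available in your installed JAX version:

```python
import jax

# Restore the pre-0.5.0 (non-partitionable) threefry behavior.
# Set it before generating any random values you depend on.
jax.config.update("jax_threefry_partitionable", False)

key = jax.random.key(0)
samples = jax.random.normal(key, (3,))
```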

* Changes:
* The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum
supported version until June 2025.
* The minimum SciPy version is now 1.11. SciPy 1.11 will remain the minimum
supported version until June 2025.
* {func}`jax.numpy.einsum` now defaults to `optimize='auto'` rather than
`optimize='optimal'`. This avoids exponentially-scaling trace-time in
the case of many arguments ({jax-issue}`25214`).
* {func}`jax.numpy.linalg.solve` no longer supports batched 1D arguments
on the right-hand side. To recover the previous behavior in these cases,
use `solve(a, b[..., None]).squeeze(-1)`.
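A minimal sketch of both changes. It passes `optimize="optimal"` explicitly where the old einsum path search is still wanted, and applies the `solve` workaround above to a small batch of diagonal systems chosen so the answer is easy to check.

```python
import jax.numpy as jnp

# einsum: explicitly request the previous exhaustive path search.
a = jnp.ones((2, 3))
b = jnp.ones((3, 4))
c = jnp.ones((4, 5))
out = jnp.einsum("ij,jk,kl->il", a, b, c, optimize="optimal")  # every entry is 3 * 4 = 12

# solve: a batch of matrices with a batched 1D right-hand side.
mats = jnp.stack([2.0 * jnp.eye(3), 4.0 * jnp.eye(3)])  # shape (2, 3, 3)
rhs = jnp.ones((2, 3))                                  # shape (2, 3)
# Add a trailing axis so each rhs is a column vector, solve, then drop the axis.
x = jnp.linalg.solve(mats, rhs[..., None]).squeeze(-1)  # shape (2, 3)
```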

* New Features
* {func}`jax.numpy.fft.fftn`, {func}`jax.numpy.fft.rfftn`,
{func}`jax.numpy.fft.ifftn`, and {func}`jax.numpy.fft.irfftn` now support
transforms in more than 3 dimensions, which was previously the limit. See
{jax-issue}`25606` for more details.
* Added support for user-defined state in the FFI via the new
{func}`jax.ffi.register_ffi_type_id` function.
* The AOT lowering `.as_text()` method now supports the `debug_info` option
to include debugging information, e.g., source location, in the output.

* Deprecations
* From {mod}`jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings`
are now deprecated, having been replaced by symbols of the same name
in {mod}`jax.core`.
* {func}`jax.scipy.special.lpmn` and {func}`jax.scipy.special.lpmn_values`
are deprecated, following their deprecation in SciPy v1.15.0. There are
no plans to replace these deprecated functions with new APIs.
* The {mod}`jax.extend.ffi` submodule was moved to {mod}`jax.ffi`, and the
previous import path is deprecated.

* Deletions
* The `jax_enable_memories` flag has been deleted; the behavior it enabled
is now on by default.
* From `jax.lib.xla_client`, the previously-deprecated `Device` and
`XlaRuntimeError` symbols have been removed; instead use `jax.Device`
and `jax.errors.JaxRuntimeError` respectively.
* The `jax.experimental.array_api` module has been removed after being
deprecated in JAX v0.4.32. Since that release, {mod}`jax.numpy` supports
the array API directly.

0.4.38

* Breaking Changes
* `XlaExecutable.cost_analysis` now returns a `dict[str, float]` (instead of a
single-element `list[dict[str, float]]`).

* Changes:
* `jax.tree.flatten_with_path` and `jax.tree.map_with_path` have been added
as shortcuts for the corresponding `tree_util` functions.

* Deprecations
* A number of APIs in the internal `jax.core` namespace have been deprecated.
Most were no-ops, were little-used, or can be replaced by APIs of the same
name in {mod}`jax.extend.core`; see the documentation for {mod}`jax.extend`
for information on the compatibility guarantees of these semi-public extensions.
* Several previously-deprecated APIs have been removed, including:
* from {mod}`jax.core`: `check_eqn`, `check_type`, `check_valid_jaxtype`, and
`non_negative_dim`.
* from {mod}`jax.lib.xla_bridge`: `xla_client` and `default_backend`.
* from {mod}`jax.lib.xla_client`: `_xla` and `bfloat16`.
* from {mod}`jax.numpy`: `round_`.

* New Features
* {func}`jax.export.export` can be used for device-polymorphic export with
shardings constructed with {func}`jax.sharding.AbstractMesh`.
See the [jax.export documentation](https://jax.readthedocs.io/en/latest/export/export.html#device-polymorphic-export).
* Added {func}`jax.lax.split`. This is a primitive version of
{func}`jax.numpy.split`, added because it yields a more compact
transpose during automatic differentiation.

0.4.37

* Bug fixes
* Fixed a bug where `jit` would error if an argument was named `f` ({jax-issue}`25329`).
* Fixed a bug that would throw an `index out of range` error in
{func}`jax.lax.while_loop` if the user registered a pytree node class with
different aux data for `flatten` and `flatten_with_path`.
* Pinned a new libtpu release (0.0.6) that fixes a compiler bug on TPU v6e.

0.4.36

* Breaking Changes
* This release lands "stackless", an internal change to JAX's tracing
machinery. We made trace dispatch purely a function of context rather than a
function of both context and data. This let us delete a lot of machinery for
managing data-dependent tracing: levels, sublevels, `post_process_call`,
`new_base_main`, `custom_bind`, and so on. The change should only affect
users that use JAX internals.

If you do use JAX internals then you may need to
update your code (see
https://github.com/jax-ml/jax/commit/c36e1f7c1ad4782060cbc8e8c596d85dfb83986f
for clues about how to do this). There might also be version skew
issues with JAX libraries that do this. If you find this change breaks your
non-JAX-internals-using code then try the
`config.jax_data_dependent_tracing_fallback` flag as a workaround, and if
you need help updating your code then please file a bug.
* {func}`jax.experimental.jax2tf.convert` with `native_serialization=False`
or with `enable_xla=False` has been deprecated since July 2024 (JAX
v0.4.31), and support for these use cases has now been removed. `jax2tf`
with native serialization will still be supported.
* In `jax.interpreters.xla`, the `xb`, `xc`, and `xe` symbols have been removed
after being deprecated in JAX v0.4.31. Instead use `xb = jax.lib.xla_bridge`,
`xc = jax.lib.xla_client`, and `xe = jax.lib.xla_extension`.
* The deprecated module `jax.experimental.export` has been removed. It was replaced
by {mod}`jax.export` in JAX v0.4.30. See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export)
for information on migrating to the new API.
* The `initial` argument to {func}`jax.nn.softmax` and {func}`jax.nn.log_softmax`
has been removed, after being deprecated in v0.4.27.
* Calling `np.asarray` on typed PRNG keys (i.e. keys produced by {func}`jax.random.key`)
now raises an error. Previously, this returned a scalar object array.
* The following deprecated methods and functions in {mod}`jax.export` have
been removed:
* `jax.export.DisabledSafetyCheck.shape_assertions`: it had no effect
already.
* `jax.export.Exported.lowering_platforms`: use `platforms`.
* `jax.export.Exported.mlir_module_serialization_version`:
use `calling_convention_version`.
* `jax.export.Exported.uses_shape_polymorphism`:
use `uses_global_constants`.
* the `lowering_platforms` kwarg for {func}`jax.export.export`: use
`platforms` instead.
* The kwargs `symbolic_scope` and `symbolic_constraints` from
{func}`jax.export.symbolic_args_specs` have been removed. They were
deprecated in June 2024. Use `scope` and `constraints` instead.
* Hashing of tracers, which has been deprecated since version 0.4.30, now
results in a `TypeError`.
* Refactor: the JAX build CLI (`build/build.py`) now uses a subcommand structure and
replaces the previous `build.py` usage. Run `python build/build.py --help` for
more details. Brief overview of the new subcommand options:
* `build`: Builds JAX wheel packages. For example, `python build/build.py build --wheels=jaxlib,jax-cuda-pjrt`
* `requirements_update`: Updates requirements_lock.txt files.
* {func}`jax.scipy.linalg.toeplitz` now does implicit batching on multi-dimensional
inputs. To recover the previous behavior, you can call {func}`jax.numpy.ravel`
on the function inputs.
* {func}`jax.scipy.special.gamma` and {func}`jax.scipy.special.gammasgn` now
return NaN for negative integer inputs, to match the behavior of SciPy from
https://github.com/scipy/scipy/pull/21827.
* `jax.clear_backends` was removed after being deprecated in v0.4.26.
* We removed the custom call "__gpu$xla.gpu.triton" from the list of custom
calls for which we guarantee export stability. This is because this custom call
relies on Triton IR, which is not guaranteed to be stable. If you need
to export code that uses this custom call, you can use the `disabled_checks`
parameter. See more details in the [documentation](https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees-for-custom-calls).

* New Features
* {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for
passing compilation options to XLA. For the moment it's undocumented and
may be in flux.
* {func}`jax.tree_util.register_dataclass` now allows metadata fields to be
declared inline via {func}`dataclasses.field`. See the function documentation
for examples.
* Added {func}`jax.numpy.put_along_axis`.
* {func}`jax.lax.linalg.eig` and the related `jax.numpy` functions
({func}`jax.numpy.linalg.eig` and {func}`jax.numpy.linalg.eigvals`) are now
supported on GPU. See {jax-issue}`24663` for more details.
* Added two new configuration flags, `jax_exec_time_optimization_effort` and `jax_memory_fitting_effort`, to control the amount of effort the compiler spends minimizing execution time and memory usage, respectively. Valid values are between -1.0 and 1.0; the default is 0.0.

* Bug fixes
* Fixed a bug where the GPU implementations of LU and QR decomposition would
result in an indexing overflow for batch sizes close to int32 max. See
{jax-issue}`24843` for more details.

* Deprecations
* `jax.lib.xla_extension.ArrayImpl` and `jax.lib.xla_client.ArrayImpl` are deprecated;
use `jax.Array` instead.
* `jax.lib.xla_extension.XlaRuntimeError` is deprecated; use `jax.errors.JaxRuntimeError`
instead.
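A minimal sketch of the replacements in type checks and exception handling. `run_safely` is a hypothetical helper, shown only to illustrate where `jax.errors.JaxRuntimeError` would be caught.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(3)
# Use jax.Array instead of ArrayImpl in isinstance checks.
is_array = isinstance(x, jax.Array)

def run_safely(f, *args):  # hypothetical helper
    try:
        return f(*args)
    except jax.errors.JaxRuntimeError as err:
        # Runtime failures reported by the backend are caught here.
        print(f"backend error: {err}")
        return None

result = run_safely(jnp.sum, x)
```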

0.4.35

* Breaking Changes
* {func}`jax.numpy.isscalar` now returns True for any array-like object with
zero dimensions. Previously it only returned True for zero-dimensional
array-like objects with a weak dtype.
* `jax.experimental.host_callback`, deprecated since March 2024 (JAX
version 0.4.26), has now been removed.
See {jax-issue}`20385` for a discussion of alternatives.

* Changes:
* `jax.lax.FftType` was introduced as a public name for the enum of FFT
operations. The semi-public API `jax.lib.xla_client.FftType` has been
deprecated.
* TPU: JAX now installs TPU support from the `libtpu` package rather than
`libtpu-nightly`. For the next few releases JAX will pin an empty version of
`libtpu-nightly` as well as `libtpu` to ease the transition; that dependency
will be removed in Q1 2025.

* Deprecations:
* The semi-public API `jax.lib.xla_client.PaddingType` has been deprecated.
No JAX APIs consume this type, so there is no replacement.
* The default behavior of {func}`jax.pure_callback` and
{func}`jax.extend.ffi.ffi_call` under `vmap` has been deprecated and so has
the `vectorized` parameter to those functions. The `vmap_method` parameter
should be used instead for better defined behavior. See the discussion in
{jax-issue}`23881` for more details.
* The semi-public API `jax.lib.xla_client.register_custom_call_target` has
been deprecated. Use the JAX FFI instead.
* The semi-public APIs `jax.lib.xla_client.dtype_to_etype`,
`jax.lib.xla_client.ops`,
`jax.lib.xla_client.shape_from_pyval`, `jax.lib.xla_client.PrimitiveType`,
`jax.lib.xla_client.Shape`, `jax.lib.xla_client.XlaBuilder`, and
`jax.lib.xla_client.XlaComputation` have been deprecated. Use StableHLO
instead.

