Jax

Latest version: v0.5.2

0.4.22

* Deprecations
  * The `device_buffer` and `device_buffers` properties of JAX arrays are deprecated.
    Explicit buffers have been replaced by the more flexible array sharding interface,
    but the previous outputs can be recovered this way:
    * `arr.device_buffer` becomes `arr.addressable_data(0)`
    * `arr.device_buffers` becomes `[x.data for x in arr.addressable_shards]`
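A minimal sketch of the migration above, assuming a recent JAX release and an array placed on a single device (so it has exactly one addressable shard); `arr` is a hypothetical example array:

```python
import jax.numpy as jnp
import numpy as np

# Hypothetical example array; by default it is placed on a single device.
arr = jnp.arange(8.0)

# Replacement for the deprecated `arr.device_buffer`:
# the data held by the first addressable shard.
buf = arr.addressable_data(0)

# Replacement for the deprecated `arr.device_buffers`:
# one single-device array per addressable shard.
bufs = [x.data for x in arr.addressable_shards]

print(buf.shape, len(bufs))
```

For a sharded array spread over several devices, `bufs` would contain one entry per locally addressable shard rather than a single buffer.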

jaxlib 0.4.22 (Dec 13, 2023)

0.4.21

* New Features
  * Added {obj}`jax.nn.squareplus`.

* Changes
  * The minimum jaxlib version is now 0.4.19.
  * Released wheels are now built with clang instead of gcc.
  * JAX now enforces that the device backend has not been initialized before
    `jax.distributed.initialize()` is called.
  * Arguments to `jax.distributed.initialize()` are now determined automatically
    in Cloud TPU environments.

* Deprecations
  * The previously deprecated `sym_pos` argument has been removed from
    {func}`jax.scipy.linalg.solve`. Use `assume_a='pos'` instead.
  * Passing `None` to {func}`jax.numpy.array` or {func}`jax.numpy.asarray`, either directly or
    within a list or tuple, is deprecated and now raises a {obj}`FutureWarning`.
    It is currently converted to NaN, and in the future will raise a {obj}`TypeError`.
  * Passing the `condition`, `x`, and `y` parameters to `jax.numpy.where` as
    keyword arguments is deprecated, to match `numpy.where`.
  * Passing arguments to {func}`jax.numpy.array_equal` and {func}`jax.numpy.array_equiv`
    that cannot be converted to a JAX array is deprecated and now raises a
    {obj}`DeprecationWarning`. Currently the functions return `False`; in the future
    they will raise an exception.
  * The `device()` method of JAX arrays is deprecated. Depending on the context, it may
    be replaced with one of the following:
    - {meth}`jax.Array.devices` returns the set of all devices used by the array.
    - {attr}`jax.Array.sharding` gives the sharding configuration used by the array.
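The sketch below exercises several of the 0.4.21 items above in their non-deprecated forms, assuming a JAX release in which these changes have landed; the matrices and values are hypothetical example inputs:

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg

x = jnp.linspace(-3.0, 3.0, 7)

# jax.nn.squareplus computes (x + sqrt(x**2 + b)) / 2, with b=4 by default.
sp = jax.nn.squareplus(x)

# `sym_pos=True` is gone from jax.scipy.linalg.solve; use assume_a='pos'
# for a symmetric positive-definite system.
a = jnp.array([[4.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])
sol = jax.scipy.linalg.solve(a, b, assume_a='pos')

# Write missing values as an explicit NaN instead of passing None.
vals = jnp.array([1.0, jnp.nan, 3.0])

# Pass condition, x, and y to jnp.where positionally, as numpy.where does.
cleaned = jnp.where(jnp.isnan(vals), 0.0, vals)

# Instead of the deprecated vals.device(), query devices() or sharding.
print(vals.devices())  # set of devices holding the array
print(vals.sharding)   # the array's sharding configuration
```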

jaxlib 0.4.21 (Dec 4, 2023)

* Changes
  * In preparation for adding distributed CPU support, JAX now treats CPU
    devices identically to GPU and TPU devices, as follows:

    * `jax.devices()` includes all devices present in a distributed job, even
      those not local to the current process. `jax.local_devices()` still only
      includes devices local to the current process, so if the change to
      `jax.devices()` breaks your code, you most likely want to use
      `jax.local_devices()` instead.
    * CPU devices now receive a globally unique ID number within a distributed
      job; previously CPU devices would receive a process-local ID number.
    * The `process_index` of each CPU device will now match any GPU or TPU
      devices within the same process; previously the `process_index` of a CPU
      device was always 0.

  * On NVIDIA GPU, JAX now prefers a Jacobi SVD solver for matrices up to
    1024x1024. The Jacobi solver appears to be faster than the non-Jacobi version.

* Bug fixes
  * Fixed an error/hang when an array with non-finite values is passed to a
    non-symmetric eigendecomposition (#18226). Arrays with non-finite values now
    produce arrays full of NaNs as outputs.
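To illustrate the distinction described above between `jax.devices()` and `jax.local_devices()`, here is a small sketch that runs in a single process (where both lists coincide); in a multi-process job only the local list is restricted to the current process:

```python
import jax

all_devices = jax.devices()          # every device in the (possibly distributed) job
local_devices = jax.local_devices()  # only devices attached to this process

# Code that must touch only this process's devices should use local_devices().
for d in local_devices:
    # d.id is unique within the job; d.process_index identifies the owning
    # process, which for a local device is the current process.
    print(d.platform, d.id, d.process_index)

print(len(all_devices), len(local_devices), jax.process_index())
```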

0.4.20

jaxlib 0.4.20 (Nov 2, 2023)

* Bug fixes
  * Fixed some type confusion between E4M3 and E5M2 float8 types.
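As background on why confusing the two float8 formats matters, this sketch (not a reproduction of the fixed bug) compares their ranges. It assumes a JAX/jaxlib release that exposes `jnp.float8_e4m3fn` and `jnp.float8_e5m2`; the choice of the `fn` variant of E4M3 is an assumption made for illustration:

```python
import jax.numpy as jnp

# E4M3 spends more bits on the mantissa (more precision, less range);
# E5M2 spends more bits on the exponent (more range, less precision).
e4m3 = jnp.float8_e4m3fn
e5m2 = jnp.float8_e5m2

print(jnp.finfo(e4m3).max)  # largest finite E4M3 value
print(jnp.finfo(e5m2).max)  # largest finite E5M2 value

# The same float32 input rounds to different values in each format,
# so mixing the types up silently changes results.
x = jnp.array([300.0], dtype=jnp.float32)
print(x.astype(e4m3), x.astype(e5m2))
```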

0.4.19

* New Features
  * Added {obj}`jax.typing.DTypeLike`, which can be used to annotate objects that
    are convertible to JAX dtypes.
  * Added `jax.numpy.fill_diagonal`.

* Changes
  * JAX now requires SciPy 1.9 or newer.

* Bug fixes
  * Only process 0 in a multi-controller distributed JAX program will write
    persistent compilation cache entries. This fixes write contention if the
    cache is placed on a network file system such as GCS.
  * The version check for cuSOLVER and cuFFT no longer considers the patch
    versions when determining whether the installed versions of these libraries
    are at least as new as the versions against which JAX was built.
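A minimal sketch combining the two new 0.4.19 features; `filled_diagonal_matrix` is a hypothetical helper, not a JAX API. Because JAX arrays are immutable, `jnp.fill_diagonal` must be called with `inplace=False` and returns a new array:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.typing import DTypeLike


def filled_diagonal_matrix(n: int, value: float, dtype: DTypeLike = jnp.float32) -> jax.Array:
    """Hypothetical helper: an n x n zero matrix with `value` on the diagonal."""
    base = jnp.zeros((n, n), dtype=dtype)
    # inplace=False is required; the result is a new array.
    return jnp.fill_diagonal(base, value, inplace=False)


m32 = filled_diagonal_matrix(3, 5.0)                   # default float32
m16 = filled_diagonal_matrix(3, 5.0, dtype="float16")  # a string is a valid DTypeLike
mi = filled_diagonal_matrix(3, 5, dtype=np.int32)      # so is a NumPy scalar type
```

Annotating the parameter as `DTypeLike` lets static type checkers accept strings, NumPy types, and JAX dtypes alike.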

jaxlib 0.4.19 (Oct 19, 2023)

* Changes
  * jaxlib will now always prefer pip-installed NVIDIA CUDA libraries
    (nvidia-... packages) over any other CUDA installation whenever they are
    installed, including installations named in `LD_LIBRARY_PATH`. If this
    causes problems and the intent is to use a system-installed CUDA, the fix is
    to remove the pip-installed CUDA library packages.

0.4.18

jaxlib 0.4.18 (Oct 6, 2023)

* Changes
  * CUDA jaxlibs now depend on the user to install a compatible NCCL version.
    If using the recommended `cuda12_pip` installation, NCCL should be installed
    automatically. Currently, NCCL 2.16 or newer is required.
  * We now provide Linux aarch64 wheels, both with and without NVIDIA GPU
    support.
  * {meth}`jax.Array.item` now supports optional index arguments.

* Deprecations
  * A number of internal utilities and inadvertent exports in {mod}`jax.lax` have
    been deprecated and will be removed in a future release.
    * `jax.lax.dtypes`: use `jax.dtypes` instead.
    * `jax.lax.itertools`: use `itertools` instead.
    * `naryop`, `naryop_dtype_rule`, `standard_abstract_eval`, `standard_naryop`,
      `standard_primitive`, `standard_unop`, `unop`, and `unop_dtype_rule` are
      internal utilities, now deprecated without replacement.

* Bug fixes
  * Fixed a Cloud TPU regression where compilation would run out of memory (OOM)
    due to SMEM.
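A short sketch of two of the 0.4.18 items above, assuming a JAX release at or after this one. The index arguments to `jax.Array.item` follow `numpy.ndarray.item` (a flat index or one index per dimension), and `jax.dtypes` replaces the deprecated `jax.lax.dtypes` alias:

```python
import jax
import jax.numpy as jnp

arr = jnp.arange(6).reshape(2, 3)

# Array.item with optional indices: a flat index into the array ...
flat = arr.item(4)
# ... or one index per dimension.
by_coords = arr.item(1, 2)

# Use jax.dtypes rather than the deprecated jax.lax.dtypes alias.
is_int = jax.dtypes.issubdtype(arr.dtype, jnp.integer)

print(flat, by_coords, is_int)
```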

0.4.17

* New Features
  * Added {func}`jax.numpy.bitwise_count`, matching the API of the similar
    function recently added to NumPy.
* Deprecations
  * Removed the deprecated module `jax.abstract_arrays` and all its contents.
  * Named key constructors in {mod}`jax.random` are deprecated. Pass the `impl` argument
    to {func}`jax.random.PRNGKey` or {func}`jax.random.key` instead:
    * `random.threefry2x32_key(seed)` becomes `random.PRNGKey(seed, impl='threefry2x32')`
    * `random.rbg_key(seed)` becomes `random.PRNGKey(seed, impl='rbg')`
    * `random.unsafe_rbg_key(seed)` becomes `random.PRNGKey(seed, impl='unsafe_rbg')`
* Changes
  * CUDA: JAX now verifies that the CUDA libraries it finds are at least as new
    as the CUDA libraries that JAX was built against. If older libraries are
    found, JAX raises an exception, since that is preferable to mysterious
    failures and crashes.
  * Removed the "No GPU/TPU found" warning. Instead, JAX warns if, on Linux, an
    NVIDIA GPU or a Google TPU is found but not used and `--jax_platforms` was
    not specified.
  * {func}`jax.scipy.stats.mode` now returns a 0 count if the mode is taken
    across a size-0 axis, matching the behavior of `scipy.stats.mode` in SciPy
    1.11.
  * Most `jax.numpy` functions and attributes now have fully-defined type stubs.
    Previously many of these were treated as `Any` by static type checkers like
    `mypy` and `pytype`.
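A minimal sketch of the 0.4.17 `jax.random` and `jax.numpy` changes above, assuming a JAX release where the `impl` argument to `jax.random.PRNGKey` is available. Only the threefry key is used for sampling here, since a raw key from another implementation must be used consistently with that implementation:

```python
import jax
import jax.numpy as jnp

# bitwise_count counts the set bits in each element (population count).
counts = jnp.bitwise_count(jnp.array([0, 1, 7, 255], dtype=jnp.uint8))

# Named key constructors are replaced by an `impl` argument.
k_threefry = jax.random.PRNGKey(0, impl='threefry2x32')  # was random.threefry2x32_key(0)
k_rbg = jax.random.PRNGKey(0, impl='rbg')                # was random.rbg_key(0)

# Sample with the threefry key (the default implementation).
samples = jax.random.normal(k_threefry, (3,))

print(counts, k_threefry.shape, k_rbg.shape, samples.shape)
```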

jaxlib 0.4.17 (Oct 3, 2023)

* Changes
  * Python 3.12 wheels were added in this release.
  * The CUDA 12 wheels now require CUDA 12.2 or newer and cuDNN 8.9.4 or newer.

* Bug fixes
  * Fixed log spam from ABSL when the JAX CPU backend was initialized.

