Jax

Latest version: v0.5.0

0.4.33

This is a patch release on top of jax 0.4.32 that fixes two bugs found in that
release.

A TPU-only data corruption bug was found in the version of libtpu pinned by

0.4.32

Note: This release was yanked from PyPI because of a data corruption bug on TPU.
See the 0.4.33 release notes for more details.

* New Functionality
* Added {func}`jax.extend.ffi.ffi_call` and {func}`jax.extend.ffi.ffi_lowering`
to support the use of the new {ref}`ffi-tutorial` to interface with custom
C++ and CUDA code from JAX.

* Changes
* The `jax_enable_memories` flag is now set to `True` by default.
* {mod}`jax.numpy` now supports v2023.12 of the Python Array API Standard.
See {ref}`python-array-api` for more information.
* Computations on the CPU backend may now be dispatched asynchronously in
more cases. Previously, non-parallel computations were always dispatched
synchronously. You can recover the old behavior by setting
`jax.config.update('jax_cpu_enable_async_dispatch', False)`.
* Added new {func}`jax.process_indices` function to replace the
`jax.host_ids()` function that was deprecated in JAX v0.2.13.
* To align with the behavior of `numpy.fabs`, `jax.numpy.fabs` no longer
supports complex dtypes.
* ``jax.tree_util.register_dataclass`` now checks, when ``nodetype`` is a
dataclass, that ``data_fields`` and ``meta_fields`` together list exactly the
dataclass fields with ``init=True``.
* Several {mod}`jax.numpy` functions now have full {class}`~jax.numpy.ufunc`
interfaces, including {obj}`~jax.numpy.add`, {obj}`~jax.numpy.multiply`,
{obj}`~jax.numpy.bitwise_and`, {obj}`~jax.numpy.bitwise_or`,
{obj}`~jax.numpy.bitwise_xor`, {obj}`~jax.numpy.logical_and`,
{obj}`~jax.numpy.logical_and`, and {obj}`~jax.numpy.logical_and`.
In future releases we plan to expand these to other ufuncs.
* Added {func}`jax.lax.optimization_barrier`, which allows users to prevent
compiler optimizations such as common-subexpression elimination and to
control scheduling.

* Breaking changes
* The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the
`stablehlo` dialect instead.
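Code that used to inspect MHLO can usually get the equivalent StableHLO through the ahead-of-time lowering API. A minimal sketch, assuming a recent JAX release; `f` is a hypothetical example function:

```python
import jax
import jax.numpy as jnp


def f(x):  # hypothetical example function
    return jnp.sin(x) * 2.0


lowered = jax.jit(f).lower(jnp.ones(3, dtype=jnp.float32))

# as_text() renders the lowered module in the StableHLO dialect by default.
stablehlo_text = lowered.as_text()

# compiler_ir("stablehlo") returns the MLIR module object instead of text.
module = lowered.compiler_ir("stablehlo")
print(stablehlo_text)
```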

* Deprecations
* Complex inputs to {func}`jax.numpy.clip` and {func}`jax.numpy.hypot` are
no longer allowed, after being deprecated since JAX v0.4.27.
* Deprecated the following APIs:
* `jax.lib.xla_bridge.xla_client`: use {mod}`jax.lib.xla_client` directly.
* `jax.lib.xla_bridge.get_backend`: use {func}`jax.extend.backend.get_backend`.
* `jax.lib.xla_bridge.default_backend`: use {func}`jax.extend.backend.default_backend`.
* The `jax.experimental.array_api` module is deprecated, and importing it is no
longer required to use the Array API. `jax.numpy` supports the array API
directly; see {ref}`python-array-api` for more information.
* The internal utilities `jax.core.check_eqn`, `jax.core.check_type`, and
`jax.core.check_valid_jaxtype` are now deprecated, and will be removed in
the future.
* `jax.numpy.round_` has been deprecated, following removal of the corresponding
API in NumPy 2.0. Use {func}`jax.numpy.round` instead.
* Passing a DLPack capsule to {func}`jax.dlpack.from_dlpack` is deprecated.
The argument to {func}`jax.dlpack.from_dlpack` should be an array from
another framework that implements the ``__dlpack__`` protocol.
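A minimal sketch of the DLPack and `round` migrations, assuming NumPy 1.23 or newer so that NumPy arrays implement `__dlpack__` and `__dlpack_device__`; NumPy stands in here for any other framework that implements the protocol.

```python
import jax
import jax.dlpack
import jax.numpy as jnp
import numpy as np

# Pass the foreign array object itself (it implements __dlpack__), not a capsule.
host = np.arange(6, dtype=np.float32)
x = jax.dlpack.from_dlpack(host)

# The reverse direction: NumPy consumes a JAX array through the same protocol.
back = np.from_dlpack(x)

# jnp.round replaces the deprecated jnp.round_ (ties round to the nearest even value).
rounded = jnp.round(jnp.array([0.4, 1.6, 2.5]))
```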

jaxlib 0.4.32 (September 11, 2024)

Note: This release was yanked from PyPI because of a data corruption bug on TPU.
See the 0.4.33 release notes for more details.

* Breaking changes
* This release of jaxlib switched to a new version of the CPU backend, which
should compile faster and leverage parallelism better. If you experience
any problems due to this change, you can temporarily enable the old CPU
backend by setting the environment variable
`XLA_FLAGS=--xla_cpu_use_thunk_runtime=false`. If you need to do this,
please file a JAX bug with instructions to reproduce.
* Added hermetic CUDA support.
Hermetic CUDA uses a specific downloadable version of CUDA instead of the
user’s locally installed CUDA. Bazel will download CUDA, cuDNN and NCCL
distributions, and then use CUDA libraries and tools as dependencies in
various Bazel targets. This enables more reproducible builds for JAX and its
supported CUDA versions.

* Changes
* Added SparseCore profiling.
* JAX now supports profiling [SparseCore](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#sparsecore) on TPUv5p chips. These traces are viewable in TensorBoard Profiler's [TraceViewer](https://www.tensorflow.org/guide/profiler#trace_viewer).
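Traces are captured from Python with `jax.profiler.trace`. The sketch below assumes only a writable temporary directory; on a CPU-only machine it records host and CPU activity, and SparseCore tracks would appear only on TPU v5p hardware. Viewing the output in TensorBoard additionally requires the TensorBoard profiler plugin.

```python
import os
import tempfile

import jax
import jax.numpy as jnp
import jax.profiler

log_dir = tempfile.mkdtemp(prefix="jax-trace-")

# Everything executed inside the context manager is recorded to log_dir.
with jax.profiler.trace(log_dir):
    x = jnp.ones((256, 256))
    y = (x @ x).block_until_ready()  # wait so the work finishes inside the trace

# Collect whatever files the profiler wrote under log_dir.
trace_files = [
    os.path.join(root, name)
    for root, _, names in os.walk(log_dir)
    for name in names
]
print(trace_files)
```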

0.4.31

* Deletion
* xmap has been deleted. Please use {func}`shard_map` as the replacement.
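As a rough guide to the replacement API, here is a minimal `shard_map` sketch that sums an array split across a one-axis device mesh. The import location depends on the JAX version: 0.4.x releases provide `jax.experimental.shard_map.shard_map`, and the sketch assumes newer releases may also expose `jax.shard_map`, so it tries both. `block_sum` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # assumed top-level location in newer releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # JAX 0.4.x location

# One mesh axis named "x" spanning every available device (1 on a plain CPU).
mesh = Mesh(np.array(jax.devices()), ("x",))


def block_sum(block):  # hypothetical per-device function
    # Each device sums its local block; psum adds the partial sums across "x".
    return jax.lax.psum(jnp.sum(block), "x")


global_sum = jax.jit(
    shard_map(block_sum, mesh=mesh, in_specs=P("x"), out_specs=P())
)

data = jnp.arange(8.0)  # length must divide evenly across the mesh devices
print(global_sum(data))
```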

* Changes
* The minimum cuDNN version is v9.1. This was true in previous releases also,
but we now declare this version constraint formally.
* The minimum Python version is now 3.10. 3.10 will remain the minimum
supported version until July 2025.
* The minimum NumPy version is now 1.24. NumPy 1.24 will remain the minimum
supported version until December 2024.
* The minimum SciPy version is now 1.10. SciPy 1.10 will remain the minimum
supported version until January 2025.
* {func}`jax.numpy.ceil`, {func}`jax.numpy.floor` and {func}`jax.numpy.trunc` now return output
of the same dtype as the input; that is, they no longer upcast integer or boolean inputs to floating point.
* `libdevice.10.bc` is no longer bundled with CUDA wheels. It must be
installed either as a part of local CUDA installation, or via NVIDIA's CUDA
pip wheels.
* {class}`jax.experimental.pallas.BlockSpec` now expects `block_shape` to
be passed *before* `index_map`. The old argument order is deprecated and
will be removed in a future release.
* Updated the repr of GPU devices to be more consistent
with TPUs/CPUs. For example, `cuda(id=0)` is now `CudaDevice(id=0)`.
* Added the `device` property and `to_device` method to {class}`jax.Array`, as
part of JAX's [Array API](https://data-apis.org/array-api) support.
* Deprecations
* Removed a number of previously-deprecated internal APIs related to
polymorphic shapes. From {mod}`jax.core`: removed `canonicalize_shape`,
`dimension_as_value`, `definitely_equal`, and `symbolic_equal_dim`.
* HLO lowering rules should no longer wrap singleton ir.Values in tuples.
Instead, return singleton ir.Values unwrapped. Support for wrapped values
will be removed in a future version of JAX.
* {func}`jax.experimental.jax2tf.convert` with `native_serialization=False`
or `enable_xla=False` is now deprecated and this support will be removed in
a future version.
Native serialization has been the default since JAX 0.4.16 (September 2023).
* The previously-deprecated function `jax.random.shuffle` has been removed;
instead use `jax.random.permutation` with `independent=True`.

jaxlib 0.4.31 (July 29, 2024)

* Bug fixes
* Fixed a bug where negative `static_argnums` passed to `jax.jit` were mishandled
by the jit dispatch fast path.
* Fixed a bug where triangular solves of batches of singular matrices
produced nonsensical finite values instead of inf or nan (3589, 15429).

0.4.30

* Changes
* JAX supports ml_dtypes >= 0.2. In the 0.4.29 release, the ml_dtypes version was
bumped to 0.4.0, but this has been rolled back in this release to give users
of both TensorFlow and JAX more time to migrate to a newer TensorFlow
release.
* `jax.experimental.mesh_utils` can now create an efficient mesh for TPU v5e.
* jax now depends on jaxlib directly. This change was enabled by the CUDA
plugin switch: there are no longer multiple jaxlib variants. You can install
a CPU-only jax with `pip install jax`, no extras required.
* Added an API for exporting and serializing JAX functions. This used
to exist in `jax.experimental.export` (which is being deprecated),
and will now live in `jax.export`.
See the [documentation](https://jax.readthedocs.io/en/latest/export/index.html).
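A minimal round trip through the new `jax.export` module, assuming JAX 0.4.30 or newer; `f` is a hypothetical function, and the input shape is fixed when exporting.

```python
import jax
import jax.numpy as jnp
from jax import export


def f(x):  # hypothetical example function
    return jnp.tanh(x) + 1.0


# Export for a concrete input shape/dtype, serialize to bytes, then restore.
exported = export.export(jax.jit(f))(jax.ShapeDtypeStruct((4,), jnp.float32))
blob = exported.serialize()  # a bytearray that could be written to disk
restored = export.deserialize(blob)

x = jnp.linspace(-1.0, 1.0, 4, dtype=jnp.float32)
y = restored.call(x)
```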

* Deprecations
* Internal pretty-printing tools `jax.core.pp_*` are deprecated, and will be removed
in a future release.
* Hashing of tracers is deprecated and will lead to a `TypeError` in a future JAX
release. Tracers were already unhashable in earlier releases, but an inadvertent
regression in the last several JAX releases made hashing them possible again.
* `jax.experimental.export` is deprecated. Use {mod}`jax.export` instead.
See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export).
* Passing an array in place of a dtype is now deprecated in most cases; e.g. for arrays
`x` and `y`, `x.astype(y)` will raise a warning. To silence it use `x.astype(y.dtype)`.
* `jax.xla_computation` is deprecated and will be removed in a future release.
Please use the AOT APIs to get the same functionality as `jax.xla_computation`.
* `jax.xla_computation(fn)(*args, **kwargs)` can be replaced with
`jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')`.
* You can also use the `.out_info` property of `jax.stages.Lowered` to get the
output information (like tree structure, shape and dtype).
* For cross-backend lowering, you can replace
`jax.xla_computation(fn, backend='tpu')(*args, **kwargs)` with
`jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')`.
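The migration steps above can be sketched as follows, assuming a recent JAX release; `fn` is a hypothetical function that also shows the `astype(y.dtype)` form recommended for the dtype deprecation.

```python
import jax
import jax.numpy as jnp


def fn(x, y):  # hypothetical example function
    # Pass y.dtype, not the array y itself, to astype.
    return (x * 2).astype(y.dtype) + y


x = jnp.arange(3, dtype=jnp.int32)
y = jnp.ones(3, dtype=jnp.float32)

lowered = jax.jit(fn).lower(x, y)
hlo = lowered.compiler_ir("hlo")  # stands in for jax.xla_computation(fn)(x, y)
info = lowered.out_info           # output shape and dtype, without compiling
print(info.shape, info.dtype)
```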


jaxlib 0.4.30 (June 18, 2024)

* Support for monolithic CUDA jaxlibs has been dropped. You must use the
plugin-based installation (`pip install jax[cuda12]` or
`pip install jax[cuda12_local]`).

0.4.29

* Changes
* We anticipate that this will be the last release of JAX and jaxlib
supporting a monolithic CUDA jaxlib. Future releases will use the CUDA
plugin jaxlib (e.g. `pip install jax[cuda12]`).
* JAX now requires ml_dtypes version 0.4.0 or newer.
* Removed backwards-compatibility support for old usage of the
`jax.experimental.export` API. It is not possible anymore to use
`from jax.experimental.export import export`, and instead you should use
`from jax.experimental import export`.
The removed functionality has been deprecated since 0.4.24.
* Added `is_leaf` argument to {func}`jax.tree.all` and {func}`jax.tree_util.tree_all`.
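With `is_leaf`, `jax.tree.all` can treat nodes such as `None`, which are otherwise empty subtrees, as leaves. A minimal sketch with a hypothetical `tree`:

```python
import jax

tree = {"a": True, "b": (True, None)}  # hypothetical example tree

# By default None is an empty subtree, so it contributes no leaves.
print(jax.tree.all(tree))  # True

# With is_leaf, None itself becomes a (falsy) leaf.
print(jax.tree.all(tree, is_leaf=lambda node: node is None))  # False
```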

* Deprecations
* `jax.sharding.XLACompatibleSharding` is deprecated. Please use
`jax.sharding.Sharding`.
* `jax.experimental.Exported.in_shardings` has been renamed to
`jax.experimental.Exported.in_shardings_hlo`. Same for `out_shardings`.
The old names will be removed after 3 months.
* Removed a number of previously-deprecated APIs:
* from {mod}`jax.core`: `non_negative_dim`, `DimSize`, `Shape`
* from {mod}`jax.lax`: `tie_in`
* from {mod}`jax.nn`: `normalize`
* from {mod}`jax.interpreters.xla`: `backend_specific_translations`,
`translations`, `register_translation`, `xla_destructure`,
`TranslationRule`, `TranslationContext`, `XlaOp`.
* The ``tol`` argument of {func}`jax.numpy.linalg.matrix_rank` is being
deprecated and will soon be removed. Use `rtol` instead.
* The ``rcond`` argument of {func}`jax.numpy.linalg.pinv` is being
deprecated and will soon be removed. Use `rtol` instead.
* The deprecated `jax.config` submodule has been removed. To configure JAX
use `import jax` and then reference the config object via `jax.config`.
* {mod}`jax.random` APIs no longer accept batched keys, where previously
some did unintentionally. Going forward, we recommend explicit use of
{func}`jax.vmap` in such cases.
* In {func}`jax.scipy.special.beta`, the `x` and `y` parameters have been
renamed to `a` and `b` for consistency with other `beta` APIs.
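A minimal sketch of two of the migrations above, assuming a recent JAX release: mapping a single-key sampler over a batch of keys with `jax.vmap`, and passing `rtol` to `matrix_rank` and `pinv` instead of the deprecated `tol` and `rcond`. The matrix `m` is a hypothetical rank-1 example.

```python
import jax
import jax.numpy as jnp

# Batched keys: split one key, then map a single-key sampler over the batch.
keys = jax.random.split(jax.random.key(0), 3)
samples = jax.vmap(lambda k: jax.random.normal(k, (2,)))(keys)  # shape (3, 2)

# rtol replaces the deprecated tol (matrix_rank) and rcond (pinv) arguments.
m = jnp.array([[1.0, 2.0], [2.0, 4.0]])  # hypothetical rank-1 matrix
rank = jnp.linalg.matrix_rank(m, rtol=1e-6)
m_pinv = jnp.linalg.pinv(m, rtol=1e-6)
```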

* New Functionality
* Added {func}`jax.experimental.Exported.in_shardings_jax` to construct
shardings that can be used with the JAX APIs from the HloShardings
that are stored in the `Exported` objects.

jaxlib 0.4.29 (June 10, 2024)

* Bug fixes
* Fixed a bug where XLA sharded some concatenation operations incorrectly,
which manifested as an incorrect output for cumulative reductions (21403).
* Fixed a bug where XLA:CPU miscompiled certain matmul fusions
(https://github.com/openxla/xla/pull/13301).
* Fixed a compiler crash on GPU (https://github.com/jax-ml/jax/issues/21396).

* Deprecations
* `jax.tree.map(f, None, non-None)` now emits a `DeprecationWarning`, and will
raise an error in a future version of jax. `None` is only a tree-prefix of
itself. To preserve the current behavior, you can ask `jax.tree.map` to
treat `None` as a leaf value by writing:
`jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)`.
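The replacement expression above can be checked on two small hypothetical trees, where `f` stands in for whatever binary function was being mapped:

```python
import jax


def f(x, y):  # hypothetical binary function being mapped
    return x + y


a = {"w": 1.0, "b": None}
b = {"w": 2.0, "b": 3.0}

# Treat None as a leaf so both trees share a structure, and pass None through.
result = jax.tree.map(
    lambda x, y: None if x is None else f(x, y),
    a,
    b,
    is_leaf=lambda x: x is None,
)
print(result)
```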

0.4.28

* Bug fixes
* Reverted a change to `make_jaxpr` that was breaking Equinox (21116).

* Deprecations & removals
* The ``kind`` argument to {func}`jax.numpy.sort` and {func}`jax.numpy.argsort`
is now removed. Use `stable=True` or `stable=False` instead.
* Removed ``get_compute_capability`` from the ``jax.experimental.pallas.gpu``
module. Use the ``compute_capability`` attribute of a GPU device, returned
by {func}`jax.devices` or {func}`jax.local_devices`, instead.
* The ``newshape`` argument to {func}`jax.numpy.reshape` is being deprecated
and will soon be removed. Use `shape` instead.
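A minimal sketch of the replacement keywords, assuming JAX 0.4.28 or newer; the arrays are hypothetical examples.

```python
import jax.numpy as jnp

x = jnp.array([3, 1, 2, 1])  # hypothetical input with a repeated value

# `stable=True` replaces the removed kind="stable".
order = jnp.argsort(x, stable=True)  # equal values keep their original order
sorted_x = jnp.sort(x, stable=True)

# `shape` replaces the deprecated `newshape` keyword.
grid = jnp.reshape(jnp.arange(6), shape=(2, 3))
```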

* Changes
* The minimum jaxlib version of this release is 0.4.27.

jaxlib 0.4.28 (May 9, 2024)

* Bug fixes
* Fixed a memory corruption bug in the type name of Array and JIT Python
objects in Python 3.10 or earlier.
* Fixed a warning `'+ptx84' is not a recognized feature for this target`
under CUDA 12.4.
* Fixed a slow compilation problem on CPU.

* Changes
* The Windows build is now built with Clang instead of MSVC.

© 2025 Safety CLI Cybersecurity Inc. All Rights Reserved.