Note: This release was yanked from PyPI because of a data corruption bug on TPU.
See the 0.4.33 release notes for more details.
* New Functionality
* Added {func}`jax.extend.ffi.ffi_call` and {func}`jax.extend.ffi.ffi_lowering`
to support the use of the new {ref}`ffi-tutorial` to interface with custom
C++ and CUDA code from JAX; see the sketch after this list.
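A minimal sketch of the new calling convention, assuming a custom call target
named `rms_norm` has already been registered with XLA from C++; the target
name and the `eps` attribute are illustrative, following the {ref}`ffi-tutorial`:

```python
import numpy as np
import jax
import jax.extend as jex

def rms_norm(x, eps=1e-5):
    # Invokes the (hypothetical) pre-registered "rms_norm" FFI target.
    # The ShapeDtypeStruct declares the output shape/dtype up front.
    return jex.ffi.ffi_call(
        "rms_norm",
        jax.ShapeDtypeStruct(x.shape, x.dtype),
        x,
        eps=np.float32(eps),
        vectorized=True,
    )
```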
* Changes
* The `jax_enable_memories` flag is now set to `True` by default.
* {mod}`jax.numpy` now supports v2023.12 of the Python Array API Standard.
See {ref}`python-array-api` for more information.
* Computations on the CPU backend may now be dispatched asynchronously in
more cases. Previously, non-parallel computations were always dispatched
synchronously. You can recover the old behavior by setting
`jax.config.update('jax_cpu_enable_async_dispatch', False)`; see the example
after this list.
* Added new {func}`jax.process_indices` function to replace the
`jax.host_ids()` function that was deprecated in JAX v0.2.13; see the example
after this list.
* To align with the behavior of `numpy.fabs`, `jax.numpy.fabs` no longer
supports complex dtypes.
* ``jax.tree_util.register_dataclass`` now checks that ``data_fields``
and ``meta_fields`` together include exactly the dataclass fields declared
with ``init=True``, if ``nodetype`` is a dataclass; see the sketch after
this list.
* Several {mod}`jax.numpy` functions now have full {class}`~jax.numpy.ufunc`
interfaces, including {obj}`~jax.numpy.add`, {obj}`~jax.numpy.multiply`,
{obj}`~jax.numpy.bitwise_and`, {obj}`~jax.numpy.bitwise_or`,
{obj}`~jax.numpy.bitwise_xor`, {obj}`~jax.numpy.logical_and`,
{obj}`~jax.numpy.logical_or`, and {obj}`~jax.numpy.logical_xor`; see the
example after this list. In future releases we plan to expand these to
other ufuncs.
* Added {func}`jax.lax.optimization_barrier`, which allows users to prevent
compiler optimizations such as common-subexpression elimination and to
control scheduling; see the sketch after this list.
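For the asynchronous CPU dispatch change above, a small sketch of what the
new behavior means in practice (the array shape is arbitrary):

```python
import jax
import jax.numpy as jnp

x = jnp.ones((2000, 2000))
y = jnp.dot(x, x)       # on CPU this may now return before the work finishes
y.block_until_ready()   # explicitly wait for the computation to complete

# To recover the old synchronous behavior:
jax.config.update('jax_cpu_enable_async_dispatch', False)
```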
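For {func}`jax.process_indices`, a minimal illustration of the replacement
for `jax.host_ids()`:

```python
import jax

# Drop-in replacement for the deprecated jax.host_ids().
# In a single-process setup this prints [0].
print(jax.process_indices())
```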
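A minimal sketch of the stricter ``jax.tree_util.register_dataclass`` check;
the `Params` dataclass is illustrative. Every ``init=True`` field must appear
in exactly one of ``data_fields`` or ``meta_fields``:

```python
import dataclasses
import jax

@dataclasses.dataclass
class Params:
    weights: jax.Array   # traced leaf -> data field
    name: str            # static auxiliary data -> meta field

# Omitting an init=True field from both lists now raises an error.
jax.tree_util.register_dataclass(
    Params, data_fields=['weights'], meta_fields=['name'])
```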
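A few of the new {class}`~jax.numpy.ufunc` methods in action (the expected
results are shown in comments):

```python
import jax.numpy as jnp

x = jnp.arange(4)
jnp.add.reduce(x)         # 6, like x.sum()
jnp.add.accumulate(x)     # [0, 1, 3, 6], like jnp.cumsum(x)
jnp.multiply.outer(x, x)  # 4x4 outer-product table
```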
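A minimal sketch of {func}`jax.lax.optimization_barrier`; the function `f`
is illustrative:

```python
import jax

def f(x):
    a = x * 2.0
    # The barrier pins `a` in place: XLA will not apply common-subexpression
    # elimination or reorder computations across this point. Values pass
    # through unchanged.
    a = jax.lax.optimization_barrier(a)
    return a + x
```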
* Breaking changes
* The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the
`stablehlo` dialect instead.
* Deprecations
* Complex inputs to {func}`jax.numpy.clip` and {func}`jax.numpy.hypot` are
no longer allowed, after being deprecated since JAX v0.4.27.
* Deprecated the following APIs:
* `jax.lib.xla_bridge.xla_client`: use {mod}`jax.lib.xla_client` directly.
* `jax.lib.xla_bridge.get_backend`: use {func}`jax.extend.backend.get_backend`.
* `jax.lib.xla_bridge.default_backend`: use {func}`jax.extend.backend.default_backend`.
* The `jax.experimental.array_api` module is deprecated, and importing it is no
longer required to use the Array API. `jax.numpy` supports the Array API
directly; see {ref}`python-array-api` for more information.
* The internal utilities `jax.core.check_eqn`, `jax.core.check_type`, and
`jax.core.check_valid_jaxtype` are now deprecated, and will be removed in
the future.
* `jax.numpy.round_` has been deprecated, following removal of the corresponding
API in NumPy 2.0. Use {func}`jax.numpy.round` instead.
* Passing a DLPack capsule to {func}`jax.dlpack.from_dlpack` is deprecated.
The argument to {func}`jax.dlpack.from_dlpack` should be an array from
another framework that implements the ``__dlpack__`` protocol; see the
sketch after this list.
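A sketch of the preferred `from_dlpack` usage, assuming PyTorch as the
producing framework (any framework implementing ``__dlpack__`` works the
same way):

```python
import torch
import jax.dlpack

t = torch.arange(4)
# Pass the tensor itself, not a capsule obtained from t.__dlpack__():
x = jax.dlpack.from_dlpack(t)
```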
jaxlib 0.4.32 (September 11, 2024)
Note: This release was yanked from PyPI because of a data corruption bug on TPU.
See the 0.4.33 release notes for more details.
* Breaking changes
* This release of jaxlib switched to a new version of the CPU backend, which
should compile faster and leverage parallelism better. If you experience
any problems due to this change, you can temporarily enable the old CPU
backend by setting the environment variable
`XLA_FLAGS=--xla_cpu_use_thunk_runtime=false` (see the sketch after this
list). If you need to do this, please file a JAX bug with instructions to
reproduce.
* Hermetic CUDA support is added.
Hermetic CUDA uses a specific downloadable version of CUDA instead of the
user’s locally installed CUDA. Bazel will download CUDA, CUDNN and NCCL
distributions, and then use CUDA libraries and tools as dependencies in
various Bazel targets. This enables more reproducible builds for JAX and its
supported CUDA versions.
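If you need the opt-out, one way to set the flag from Python (a sketch; the
environment variable must be set before `jax` is first imported in the
process):

```python
import os

# Temporarily fall back to the old CPU runtime.
os.environ['XLA_FLAGS'] = '--xla_cpu_use_thunk_runtime=false'

import jax  # the flag above is picked up when the backend initializes
```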
* Changes
* SparseCore profiling is added.
* JAX now supports profiling [SparseCore](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#sparsecore) on TPU v5p chips. These traces are viewable in TensorBoard Profiler's [TraceViewer](https://www.tensorflow.org/guide/profiler#trace_viewer); see the sketch below.
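Capturing a trace that includes SparseCore activity works through the
existing profiler API (a sketch; the log directory is arbitrary):

```python
import jax
import jax.numpy as jnp

# Traces written here are viewable in TensorBoard Profiler's TraceViewer.
with jax.profiler.trace('/tmp/jax-trace'):
    x = jnp.ones((1024, 1024))
    jnp.dot(x, x).block_until_ready()
```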