* New Functionality
  * Added {func}`jax.numpy.unstack` and {func}`jax.numpy.cumulative_sum`,
    following their addition in the array API 2023 standard, soon to be
    adopted by NumPy.
  * Added a new config option `jax_cpu_collectives_implementation` to select the
    implementation of cross-process collective operations used by the CPU backend.
    Available choices are `'none'` (default), `'gloo'`, and `'mpi'` (requires jaxlib 0.4.26).
    If set to `'none'`, cross-process collective operations are disabled.

* Changes
  * {func}`jax.pure_callback`, {func}`jax.experimental.io_callback`
    and {func}`jax.debug.callback` now use {class}`jax.Array` instead
    of {class}`np.ndarray`. You can recover the old behavior by transforming
    the arguments via `jax.tree.map(np.asarray, args)` before passing them
    to the callback.
  * `complex_arr.astype(bool)` now follows the same semantics as NumPy, returning
    False where `complex_arr` is equal to `0 + 0j`, and True otherwise.
  * `core.Token` is now a non-trivial class that wraps a `jax.Array`. It can
    be created and threaded into and out of computations to build up dependencies.
    The singleton object `core.token` has been removed; users should now create
    and use fresh `core.Token` objects instead.
  * On GPU, the Threefry PRNG implementation no longer lowers to a kernel call
    by default. This choice can improve runtime memory usage at a compile-time
    cost. The prior behavior, which produces a kernel call, can be recovered with
    `jax.config.update('jax_threefry_gpu_kernel_lowering', True)`. If the new
    default causes issues, please file a bug. Otherwise, we intend to remove
    this flag in a future release.

* Deprecations & Removals
  * Pallas now exclusively uses XLA for compiling kernels on GPU. The old
    lowering pass via Triton Python APIs has been removed and the
    `JAX_TRITON_COMPILE_VIA_XLA` environment variable no longer has any effect.
  * {func}`jax.numpy.clip` has a new argument signature: `a`, `a_min`, and
    `a_max` are deprecated in favor of `x` (positional only), `min`, and
    `max` ({jax-issue}`20550`).
  * The `device()` method of JAX arrays has been removed after being deprecated
    since JAX v0.4.21. Use `arr.devices()` instead.
  * The `initial` argument to {func}`jax.nn.softmax` and {func}`jax.nn.log_softmax`
    is deprecated; empty inputs to softmax are now supported without setting it.
  * In {func}`jax.jit`, passing invalid `static_argnums` or `static_argnames`
    now leads to an error rather than a warning.
  * The minimum jaxlib version is now 0.4.23.
  * The {func}`jax.numpy.hypot` function now issues a deprecation warning when
    passed complex-valued inputs. This will become an error once the
    deprecation is complete.
  * Scalar arguments to {func}`jax.numpy.nonzero`, {func}`jax.numpy.where`, and
    related functions now raise an error, following a similar change in NumPy.
  * The config option `jax_cpu_enable_gloo_collectives` is deprecated.
    Use `jax.config.update('jax_cpu_collectives_implementation', 'gloo')` instead.
  * The `jax.Array.device_buffer` and `jax.Array.device_buffers` methods have
    been removed after being deprecated in JAX v0.4.22. Use
    {attr}`jax.Array.addressable_shards` and {meth}`jax.Array.addressable_data` instead.
  * The `condition`, `x`, and `y` parameters of `jax.numpy.where` are now
    positional-only, following deprecation of the keywords in JAX v0.4.21.
  * Non-array arguments to functions in {mod}`jax.lax.linalg` must now be
    specified by keyword. Previously, passing them positionally raised a
    DeprecationWarning.
  * Array-like arguments are now required in several {mod}`jax.numpy` APIs,
    including {func}`~jax.numpy.apply_along_axis`,
    {func}`~jax.numpy.apply_over_axes`, {func}`~jax.numpy.inner`,
    {func}`~jax.numpy.outer`, {func}`~jax.numpy.cross`,
    {func}`~jax.numpy.kron`, and {func}`~jax.numpy.lexsort`.

* Bug fixes
  * {func}`jax.numpy.astype` will now always return a copy when `copy=True`.
    Previously, no copy would be made when the output array would have the same
    dtype as the input array. This may result in some increased memory usage.
    The default value is set to `copy=False` to preserve backwards compatibility.

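
As a rough illustration of the new {func}`jax.numpy.unstack` and {func}`jax.numpy.cumulative_sum` entries above, the sketch below assumes jax 0.4.27 or later and the array API style keyword arguments `axis` and `include_initial`.

```python
import jax.numpy as jnp

x = jnp.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

# unstack splits an array along an axis into a tuple of arrays,
# the inverse of jnp.stack along that axis.
rows = jnp.unstack(x, axis=0)  # two arrays of shape (3,)
cols = jnp.unstack(x, axis=1)  # three arrays of shape (2,)

# cumulative_sum is the array API name for a running sum; for inputs with
# more than one dimension the axis is passed explicitly.
running = jnp.cumulative_sum(x, axis=1)

# include_initial=True prepends the additive identity (zero) to the result.
with_zero = jnp.cumulative_sum(jnp.array([1, 2, 3]), include_initial=True)
```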
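
For the callback change above, one way to recover the old {class}`np.ndarray` inputs is to convert the arguments with `jax.tree.map(np.asarray, args)` before they reach the callback body. The sketch below does this in a small decorator; `numpy_args` and `host_cumsum` are hypothetical names used only for illustration, not part of the JAX API.

```python
import jax
import jax.numpy as jnp
import numpy as np


def numpy_args(fn):
    """Hypothetical helper: convert every callback argument to np.ndarray."""
    def wrapped(*args):
        # jax.tree.map also handles nested pytrees of arrays, not just tuples.
        return fn(*jax.tree.map(np.asarray, args))
    return wrapped


@numpy_args
def host_cumsum(x):
    # x is a plain np.ndarray here, so NumPy-only code keeps working.
    return np.cumsum(x).astype(x.dtype)


@jax.jit
def f(x):
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_cumsum, out_type, x)


result = f(jnp.arange(4.0))
```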
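
To illustrate the updated `complex_arr.astype(bool)` semantics described above, the short check below compares JAX against NumPy; under the new behavior an element maps to False only when both its real and imaginary parts are zero.

```python
import jax.numpy as jnp
import numpy as np

values = [0j, 1j, 2 + 0j, 3 - 4j]

# Purely imaginary 1j is nonzero, so it now maps to True, as in NumPy.
mask = jnp.array(values).astype(bool)
reference = np.array(values).astype(bool)
```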
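
A minimal sketch of calling {func}`jax.numpy.clip` with the new signature from the deprecation entry above: the array is passed positionally and the bounds by the `min` and `max` keywords, either of which may be left out.

```python
import jax.numpy as jnp

x = jnp.array([-2.0, 0.5, 3.0])

# New style: positional array, keyword bounds (replacing a_min / a_max).
clipped = jnp.clip(x, min=0.0, max=1.0)

# Omitting max clips only from below.
lower_only = jnp.clip(x, min=0.0)
```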
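
For code that used the removed `device()`, `device_buffer`, and `device_buffers` members, the sketch below shows the replacements named above. It assumes an unsharded array on the default device, so there is a single addressable shard; the `per_device` dictionary is just one illustrative way to collect shard data.

```python
import jax.numpy as jnp

x = jnp.arange(8.0)

# Instead of x.device(): the set of devices that hold x.
devs = x.devices()

# Instead of x.device_buffers: one jax.Shard per addressable device.
shards = x.addressable_shards
per_device = {shard.device: shard.data for shard in shards}

# Instead of x.device_buffer: the data of a single addressable shard.
first = x.addressable_data(0)
```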
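
As an example of the keyword requirement for {mod}`jax.lax.linalg` above, the sketch below uses `triangular_solve`, whose options such as `left_side` and `lower` are non-array arguments and therefore passed by keyword. The matrices are arbitrary illustrative values.

```python
import jax.numpy as jnp
from jax.lax import linalg

a = jnp.array([[2.0, 0.0],
               [1.0, 1.0]])  # lower-triangular
b = jnp.array([[2.0],
               [3.0]])

# Solves a @ x = b; the flags are keyword arguments, not positional ones.
x = linalg.triangular_solve(a, b, left_side=True, lower=True)
```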

## jaxlib 0.4.27 (May 7, 2024)