* New Features
* Added {obj}`jax.nn.squareplus`.
* Changes
* The minimum jaxlib version is now 0.4.19.
* Released wheels are now built with clang instead of gcc.
* `jax.distributed.initialize()` now enforces that the device backend has not been initialized before it is called.
* Arguments to `jax.distributed.initialize()` are now determined automatically in cloud TPU environments.
* Deprecations
* The previously-deprecated `sym_pos` argument has been removed from
{func}`jax.scipy.linalg.solve`. Use `assume_a='pos'` instead.
* Passing `None` to {func}`jax.numpy.array` or {func}`jax.numpy.asarray`, either directly or
within a list or tuple, is deprecated and now raises a {obj}`FutureWarning`.
`None` is currently converted to NaN, and in the future will raise a {obj}`TypeError`.
* Passing the `condition`, `x`, and `y` parameters to `jax.numpy.where` as
keyword arguments is deprecated, to match `numpy.where`.
* Passing arguments to {func}`jax.numpy.array_equal` and {func}`jax.numpy.array_equiv`
that cannot be converted to a JAX array is deprecated and now raises a
{obj}`DeprecationWarning`. Currently the functions return False; in the future this
will raise an exception.
* The `device()` method of JAX arrays is deprecated. Depending on the context, it may
be replaced with one of the following:
- {meth}`jax.Array.devices` returns the set of all devices used by the array.
- {attr}`jax.Array.sharding` gives the sharding configuration used by the array.
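
As a rough migration sketch for the deprecations listed above (not an official recipe), the snippet below uses only the replacement spellings. The matrix `a`, vector `b`, the `values` list, and the `0.1` threshold are made-up illustrative inputs, and mapping `None` to NaN by hand is just one way to keep the current behaviour explicit.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve

# Hypothetical symmetric positive-definite system, used only for illustration.
a = jnp.array([[4.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])

# `assume_a='pos'` replaces the removed `sym_pos=True` argument.
x = solve(a, b, assume_a='pos')

# Pass `condition`, `x`, and `y` to `jnp.where` positionally, not by keyword.
clipped = jnp.where(x > 0.1, x, 0.0)

# Rather than relying on `None` being converted to NaN, convert it explicitly.
values = [1.0, None, 3.0]
arr = jnp.array([jnp.nan if v is None else v for v in values])

# Replacements for the deprecated `x.device()` method.
devices = x.devices()   # set of all devices used by the array
sharding = x.sharding   # sharding configuration used by the array
```

Which of `devices()` or `sharding` is the right substitute depends on whether the calling code needs the concrete devices or the layout of the array across them.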
jaxlib 0.4.21 (Dec 4 2023)
* Changes
* In preparation for adding distributed CPU support, JAX now treats CPU
devices identically to GPU and TPU devices, that is:
* `jax.devices()` includes all devices present in a distributed job, even
those not local to the current process. `jax.local_devices()` still only
includes devices local to the current process, so if the change to
`jax.devices()` breaks you, you most likely want to use
`jax.local_devices()` instead.
* CPU devices now receive a globally unique ID number within a distributed
job; previously CPU devices would receive a process-local ID number.
* The `process_index` of each CPU device will now match any GPU or TPU
devices within the same process; previously the `process_index` of a CPU
device was always 0.
* On NVIDIA GPU, JAX now prefers a Jacobi SVD solver for matrices up to
1024x1024. The Jacobi solver appears faster than the non-Jacobi version.
* Bug fixes
* Fixed error/hang when an array with non-finite values is passed to a
non-symmetric eigendecomposition (18226). Arrays with non-finite values now
produce arrays full of NaNs as outputs.
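
The CPU device changes described under "Changes" above can be checked with a short sketch like the following. It assumes nothing beyond the public `jax.devices()`, `jax.local_devices()`, and `jax.process_index()` functions; in a single-process run the two device lists will simply coincide.

```python
import jax

# Every device in the job, including those owned by other processes.
all_devices = jax.devices()

# Only the devices attached to the current process.
local_devices = jax.local_devices()

# Each local device reports the index of the process that owns it.
this_process = jax.process_index()
local_ok = all(d.process_index == this_process for d in local_devices)

# Device IDs are expected to be unique across the whole job.
ids_unique = len({d.id for d in all_devices}) == len(all_devices)
```

Code that previously iterated over `jax.devices()` on CPU and assumed every device was local is the most likely to need `jax.local_devices()` instead.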