* Changes
* Added {class}`jax.numpy.ufunc`, as well as {func}`jax.numpy.frompyfunc`, which can convert
any scalar-valued function into a {func}`numpy.ufunc`-like object, with methods such as
{meth}`~jax.numpy.ufunc.outer`, {meth}`~jax.numpy.ufunc.reduce`,
{meth}`~jax.numpy.ufunc.accumulate`, {meth}`~jax.numpy.ufunc.at`, and
{meth}`~jax.numpy.ufunc.reduceat` ({jax-issue}`17054`).
* Added {func}`jax.scipy.integrate.trapezoid`.
* When not running under IPython: when an exception is raised, JAX now filters out the
entirety of its internal frames from tracebacks. (Without the "unfiltered stack trace"
that previously appeared.) This should produce much friendlier-looking tracebacks. See
[here](https://github.com/jax-ml/jax/pull/16949) for an example.
This behavior can be changed by setting `JAX_TRACEBACK_FILTERING=remove_frames` (for two
separate unfiltered/filtered tracebacks, which was the old behavior) or
`JAX_TRACEBACK_FILTERING=off` (for one unfiltered traceback).
* jax2tf default serialization version is now 7, which introduces new shape
[safety assertions](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism).
* Devices passed to `jax.sharding.Mesh` should be hashable. This specifically
applies to mock devices or user created devices. `jax.devices()` are
already hashable.
* Breaking changes:
* jax2tf now uses native serialization by default. See
the [jax2tf documentation](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md)
for details and for mechanisms to override the default.
* The option `--jax_coordination_service` has been removed. It is now always
`True`.
* `jax.jaxpr_util` has been removed from the public JAX namespace.
* `JAX_USE_PJRT_C_API_ON_TPU` no longer has an effect (i.e. it always defaults to true).
* The backwards compatibility flag `--jax_host_callback_ad_transforms`
introduced in December 2021, has been removed.
* Deprecations:
* Several `jax.numpy` APIs have been deprecated following
[NumPy NEP-52](https://numpy.org/neps/nep-0052-python-api-cleanup.html):
* `jax.numpy.NINF` has been deprecated. Use `-jax.numpy.inf` instead.
* `jax.numpy.PZERO` has been deprecated. Use `0.0` instead.
* `jax.numpy.NZERO` has been deprecated. Use `-0.0` instead.
* `jax.numpy.issubsctype(x, t)` has been deprecated. Use `jax.numpy.issubdtype(x.dtype, t)`.
* `jax.numpy.row_stack` has been deprecated. Use `jax.numpy.vstack` instead.
* `jax.numpy.in1d` has been deprecated. Use `jax.numpy.isin` instead.
* `jax.numpy.trapz` has been deprecated. Use `jax.scipy.integrate.trapezoid` instead.
* `jax.scipy.linalg.tril` and `jax.scipy.linalg.triu` have been deprecated,
following SciPy. Use `jax.numpy.tril` and `jax.numpy.triu` instead.
* `jax.lax.prod` has been removed after being deprecated in JAX v0.4.11.
Use the built-in `math.prod` instead.
* A number of exports from `jax.interpreters.xla` related to defining
HLO lowering rules for custom JAX primitives have been deprecated. Custom
primitives should be defined using the StableHLO lowering utilities in
`jax.interpreters.mlir` instead.
* The following previously-deprecated functions have been removed after a
three-month deprecation period:
* `jax.abstract_arrays.ShapedArray`: use `jax.core.ShapedArray`.
* `jax.abstract_arrays.raise_to_shaped`: use `jax.core.raise_to_shaped`.
* `jax.numpy.alltrue`: use `jax.numpy.all`.
* `jax.numpy.sometrue`: use `jax.numpy.any`.
* `jax.numpy.product`: use `jax.numpy.prod`.
* `jax.numpy.cumproduct`: use `jax.numpy.cumprod`.
* Deprecations/removals:
* The internal submodule `jax.prng` is now deprecated. Its contents are available at
{mod}`jax.extend.random`.
* The internal submodule path `jax.linear_util` has been deprecated. Use
{mod}`jax.extend.linear_util` instead (Part of {ref}`jax-extend-jep`)
* `jax.random.PRNGKeyArray` and `jax.random.KeyArray` are deprecated. Use {class}`jax.Array`
for type annotations, and `jax.dtypes.issubdtype(arr.dtype, jax.dtypes.prng_key)` for
runtime detection of typed prng keys.
* The method `PRNGKeyArray.unsafe_raw_array` is deprecated. Use
{func}`jax.random.key_data` instead.
* `jax.experimental.pjit.with_sharding_constraint` is deprecated. Use
`jax.lax.with_sharding_constraint` instead.
* The internal utilities `jax.core.is_opaque_dtype` and `jax.core.has_opaque_dtype`
have been removed. Opaque dtypes have been renamed to Extended dtypes; use
`jnp.issubdtype(dtype, jax.dtypes.extended)` instead (available since jax v0.4.14).
* The utility `jax.interpreters.xla.register_collective_primitive` has been
removed. This utility did nothing useful in recent JAX releases and calls
to it can be safely removed.
* The internal submodule path `jax.linear_util` has been deprecated. Use
{mod}`jax.extend.linear_util` instead (Part of {ref}`jax-extend-jep`)
jaxlib 0.4.16 (Sept 18, 2023)
* Changes:
* Sparse CSR matrix multiplications via the experimental jax sparse APIs
no longer uses a deterministic algorithm on NVIDIA GPUs. This change was
made to improve compatibility with CUDA 12.2.1.
* Bug fixes:
* Fixed a crash on Windows due to a fatal LLVM error related to out-of-order
sections and IMAGE_REL_AMD64_ADDR32NB relocations
(https://github.com/openxla/xla/commit/cb732a921f0c4184995cbed82394931011d12bd4).