* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.13...jax-v0.3.14).
* Breaking changes
* {func}`jax.experimental.compilation_cache.initialize_cache` no longer supports
`max_cache_size_bytes` and no longer accepts it as an argument.
* `JAX_PLATFORMS` now raises an exception when platform initialization fails.
* Changes
* Fixed compatibility problems with NumPy 1.23.
* {func}`jax.numpy.linalg.slogdet` now accepts an optional `method` argument
that allows selection between an LU-decomposition based implementation and
an implementation based on QR decomposition.
* {func}`jax.numpy.linalg.qr` now supports `mode="raw"`.
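
As a rough illustration of the two linear algebra changes above, the following sketch computes the sign and log-determinant of a small matrix with both implementations and requests a raw QR factorization. The example matrix is an assumption chosen for illustration, and the string `"qr"` is assumed to be the value that selects the QR-based path.

```python
import jax.numpy as jnp

# Small, well-conditioned example matrix (illustrative only); det(a) = 5.
a = jnp.array([[2.0, 1.0], [1.0, 3.0]])

# Default implementation (LU-decomposition based).
sign_lu, logdet_lu = jnp.linalg.slogdet(a)

# QR-based implementation, selected through the keyword-only `method` argument.
sign_qr, logdet_qr = jnp.linalg.slogdet(a, method="qr")

# mode="raw" returns the packed Householder reflectors and their scaling factors.
h, tau = jnp.linalg.qr(a, mode="raw")

print(sign_lu, logdet_lu, sign_qr, logdet_qr)
print(h.shape, tau.shape)
```

Both `slogdet` calls should agree up to floating-point error; which one is preferable depends on the backend and the matrix.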
* `pickle`, `copy.copy`, and `copy.deepcopy` now have more complete support when
used on jax arrays ({jax-issue}`10659`). In particular:
- `pickle` and `deepcopy` previously returned `np.ndarray` objects when used
on a `DeviceArray`; now `DeviceArray` objects are returned. For `deepcopy`,
the copied array is on the same device as the original. For `pickle` the
deserialized array will be on the default device.
- Within function transformations (i.e. traced code), `deepcopy` and `copy`
previously were no-ops. Now they use the same mechanism as `DeviceArray.copy()`.
- Calling `pickle` on a traced array now results in an explicit
`ConcretizationTypeError`.
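
A minimal sketch of the copy and pickle behavior described above, outside of any function transformation. The variable names are illustrative, and the concrete array class that gets printed depends on the installed JAX version.

```python
import copy
import pickle

import numpy as np
import jax.numpy as jnp

x = jnp.arange(4.0)

# deepcopy returns a JAX array rather than a NumPy array.
x_copy = copy.deepcopy(x)

# A pickle round trip also yields a JAX array, placed on the default device.
x_roundtrip = pickle.loads(pickle.dumps(x))

print(type(x_copy).__name__, type(x_roundtrip).__name__)
print(isinstance(x_copy, np.ndarray))
```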
* The implementation of singular value decomposition (SVD) and
symmetric/Hermitian eigendecomposition should be significantly faster on
TPU, especially for matrices above 1000x1000 or so. Both now use a spectral
divide-and-conquer algorithm for eigendecomposition (QDWH-eig).
* {func}`jax.numpy.ldexp` no longer silently promotes all inputs to float64.
Instead, it promotes to float32 for integer inputs of size int32 or smaller
({jax-issue}`10921`).
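
To check the promotion behavior of `ldexp` locally, a sketch along these lines may help. It assumes the default configuration (64-bit mode disabled); the input values are arbitrary.

```python
import jax.numpy as jnp

# int32 inputs: mantissa * 2**exponent, computed in floating point.
mantissa = jnp.arange(3, dtype=jnp.int32)
exponent = jnp.full(3, 2, dtype=jnp.int32)

out = jnp.ldexp(mantissa, exponent)

# Inspect the result dtype; per the note above, int32 inputs should give float32.
print(out, out.dtype)
```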
* Added a `create_perfetto_link` option to {func}`jax.profiler.start_trace` and
{func}`jax.profiler.trace`. When used, the profiler will generate a
link to the Perfetto UI to view the trace.
* Changed the semantics of {func}`jax.profiler.start_server(...)` to store the
keepalive globally, rather than requiring the user to keep a reference to
it.
* Added {func}`jax.random.generalized_normal`.
* Added {func}`jax.random.ball`.
* Added {func}`jax.default_device`.
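
The three additions above can be tried with a short sketch like the one below. The seed, shapes, and parameter values are arbitrary choices for illustration, and the CPU backend is used for `jax.default_device` only because it is assumed to be available everywhere.

```python
import jax
import jax.numpy as jnp

key_normal, key_ball = jax.random.split(jax.random.PRNGKey(0))

# Generalized normal samples with shape parameter p = 2.
gn = jax.random.generalized_normal(key_normal, 2.0, shape=(4,))

# Five points drawn uniformly from the 3-dimensional unit ball (Euclidean norm).
pts = jax.random.ball(key_ball, 3, shape=(5,))
norms = jnp.linalg.norm(pts, axis=-1)

# Arrays created inside the context manager are placed on the chosen device.
cpu = jax.devices("cpu")[0]
with jax.default_device(cpu):
    ones = jnp.ones(3)

print(gn.shape, pts.shape, norms.max())
```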
* Added a `python -m jax.collect_profile` script to manually capture program
traces as an alternative to the TensorBoard UI.
* Added a `jax.named_scope` context manager that adds profiler metadata to
Python programs (similar to `jax.named_call`).
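
A minimal sketch of `jax.named_scope` inside a jitted function. The scope name and the function itself are hypothetical; the scope only attaches metadata for the profiler and should not change the computed values.

```python
import jax
import jax.numpy as jnp


@jax.jit
def scale_and_shift(x):
    # Operations traced inside this block carry the scope name as metadata.
    with jax.named_scope("scale_and_shift"):
        return 2.0 * x + 1.0


result = scale_and_shift(jnp.arange(3.0))
print(result)
```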
* In scatter-update operations (i.e. {attr}`jax.numpy.ndarray.at`), unsafe implicit
dtype casts are deprecated, and now result in a `FutureWarning`.
In a future release, this will become an error. An example of an unsafe implicit
cast is `jnp.zeros(4, dtype=int).at[0].set(1.5)`, in which `1.5` previously was
silently truncated to `1`.
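
One way to avoid the deprecated implicit cast is to convert the update value to the array's dtype explicitly before the scatter, as in this sketch. The explicit `astype` call is a suggested pattern rather than something the release notes prescribe.

```python
import jax.numpy as jnp

x = jnp.zeros(4, dtype=int)

# Explicit cast: the truncation from 1.5 to 1 is now visible in the code,
# instead of happening implicitly inside the scatter update.
value = jnp.asarray(1.5).astype(x.dtype)
updated = x.at[0].set(value)

print(updated, updated.dtype)
```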
* {func}`jax.experimental.compilation_cache.initialize_cache` now supports a GCS
bucket path as input.
* Added {func}`jax.scipy.stats.gennorm`.
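
As a quick sanity check of `jax.scipy.stats.gennorm`, the sketch below evaluates the density for shape parameter `beta = 2`, where the generalized normal density reduces to `exp(-x**2) / sqrt(pi)`. The evaluation points are arbitrary.

```python
import math

import jax.numpy as jnp
from jax.scipy.stats import gennorm

x = jnp.array([0.0, 0.5, 1.0])
beta = 2.0

pdf = gennorm.pdf(x, beta)
logpdf = gennorm.logpdf(x, beta)

# Closed form for beta = 2, used here only for comparison.
reference = jnp.exp(-x**2) / math.sqrt(math.pi)
print(pdf, reference)
```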
* {func}`jax.numpy.roots` is now better behaved when called with
`strip_zeros=False` on coefficients that have leading zeros ({jax-issue}`11215`).
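
The sketch below shows `strip_zeros=False` on a polynomial with one leading zero coefficient, `0*x**3 + x**2 - 3*x + 2`, whose true roots are 1 and 2. The polynomial is an arbitrary example, and the sketch assumes a backend where the underlying eigenvalue computation is available (for instance CPU). The output has a fixed length, with the slot corresponding to the leading zero filled with NaN.

```python
import jax
import jax.numpy as jnp

# Coefficients in descending order, with a leading zero.
coeffs = jnp.array([0.0, 1.0, -3.0, 2.0])

# strip_zeros=False keeps the output shape static, which is what allows jit.
roots_fn = jax.jit(lambda p: jnp.roots(p, strip_zeros=False))
roots = roots_fn(coeffs)

# Separate the defined roots from the NaN padding.
is_nan = jnp.isnan(roots)
defined = jnp.sort(roots[~is_nan].real)

print(roots)
print(defined)
```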

## jaxlib 0.3.14 (June 27, 2022)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jaxlib-v0.3.10...jaxlib-v0.3.14).
* x86-64 Mac wheels now require Mac OS 10.14 (Mojave) or newer. Mac OS 10.14
was released in 2018, so this should not be a very onerous requirement.
* The bundled version of NCCL was updated to 2.12.12, fixing some deadlocks.
* The Python flatbuffers package is no longer a dependency of jaxlib.