* Changes
  * Support for Python 3.7 has been dropped, in accordance with JAX's
    {ref}`version-support-policy`.
  * We introduce `jax.Array`, a unified array type that subsumes the
    `DeviceArray`, `ShardedDeviceArray`, and `GlobalDeviceArray` types in JAX.
    The `jax.Array` type helps make parallelism a core feature of JAX,
    simplifies and unifies JAX internals, and allows us to unify `jit` and
    `pjit`. `jax.Array` is enabled by default as of JAX 0.4 and introduces some
    breaking changes to the `pjit` API. The [jax.Array migration
    guide](https://jax.readthedocs.io/en/latest/jax_array_migration.html) can
    help you migrate your codebase to `jax.Array`. You can also look at the
    [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html)
    tutorial to understand the new concepts.
  * `PartitionSpec` and `Mesh` are no longer experimental. The new API endpoints
    are `jax.sharding.PartitionSpec` and `jax.sharding.Mesh`.
    `jax.experimental.maps.Mesh` and `jax.experimental.PartitionSpec` are
    deprecated and will be removed in 3 months.
  * The new public endpoint for `with_sharding_constraint` is
    `jax.lax.with_sharding_constraint`.
  * If you use ABSL flags together with `jax.config`, the ABSL flag values are no
    longer read or written after the JAX configuration options are initially
    populated from the ABSL flags. This change improves the performance of reading
    `jax.config` options, which are used pervasively in JAX.
  * The `jax2tf.call_tf` function now uses, for TF lowering, the first TF
    device of the same platform as the embedding JAX computation.
    Previously, it used the 0th device of the JAX default backend.
  * A number of `jax.numpy` functions now have their arguments marked as
    positional-only, matching NumPy.
  * `jnp.msort` is now deprecated, following the deprecation of `np.msort` in NumPy 1.24.
    It will be removed in a future release, in accordance with the {ref}`api-compatibility`
    policy. It can be replaced with `jnp.sort(a, axis=0)`.
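
As a quick illustration of the unified type, the sketch below (assuming a recent JAX release where `jax.Array` is the default) checks that arrays created by `jax.numpy`, returned from a `jit`-compiled function, and placed with `jax.device_put` are all instances of `jax.Array`. The variable names are illustrative only.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(4.0)                      # created by jax.numpy
y = jax.jit(lambda a: a * 2)(x)          # output of a jit-compiled function
z = jax.device_put(x, jax.devices()[0])  # explicitly placed on one device

# All three share the single unified array type.
for arr in (x, y, z):
    print(type(arr).__name__, isinstance(arr, jax.Array))

# Every jax.Array carries its sharding, even on a single device.
print(y.sharding)
```
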
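
Code that previously imported `Mesh` from `jax.experimental.maps` and `PartitionSpec` from `jax.experimental` can switch to the `jax.sharding` endpoints. A minimal sketch, assuming a JAX version that also exposes `NamedSharding` under `jax.sharding` (it is not mentioned in this entry and is used here only to pair a `Mesh` with a `PartitionSpec`); the axis name `"data"` is arbitrary:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Arrange all available devices along one mesh axis named "data".
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

# Shard the leading dimension of an array across the "data" axis.
# The length is a multiple of the device count so it divides evenly.
sharding = NamedSharding(mesh, PartitionSpec("data"))
x = jax.device_put(jnp.arange(8 * len(devices)), sharding)

print(x.sharding)
print(len(x.addressable_shards))  # one shard per device
```
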
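
The relocated `jax.lax.with_sharding_constraint` is typically called inside a `jit`-compiled function to constrain the layout of an intermediate value. The sketch below assumes a JAX version whose `with_sharding_constraint` accepts a `NamedSharding`; the function name `scale` and the row-wise layout are hypothetical choices for illustration.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
# Shard rows across the "data" axis and replicate columns.
row_sharded = NamedSharding(mesh, PartitionSpec("data", None))

@jax.jit
def scale(x):
    y = x * 2.0
    # Ask the compiler to lay out the intermediate `y` sharded by rows.
    return jax.lax.with_sharding_constraint(y, row_sharded)

x = jnp.ones((4 * len(jax.devices()), 3))
out = scale(x)
print(out.sharding)
```
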
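
Replacing `jnp.msort` is a one-line change: sorting along `axis=0` orders each column independently, which is what `msort` did. A small sketch with made-up input data:

```python
import numpy as np
import jax.numpy as jnp

a = jnp.array([[3, 1],
               [1, 2],
               [2, 0]])

# Equivalent of the deprecated jnp.msort(a): sort along the first axis.
sorted_a = jnp.sort(a, axis=0)
print(sorted_a.tolist())  # [[1, 0], [2, 1], [3, 2]]
```
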

## jaxlib 0.4.1 (Dec 13, 2022)
* Changes
  * Support for Python 3.7 has been dropped, in accordance with JAX's
    {ref}`version-support-policy`.
  * The behavior of `XLA_PYTHON_CLIENT_MEM_FRACTION=.XX` has changed: it now preallocates XX% of
    the total GPU memory, instead of XX% of the currently available GPU memory as before.
    Please refer to
    [GPU memory allocation](https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html) for
    more details.
  * The deprecated method `.block_host_until_ready()` has been removed. Use
    `.block_until_ready()` instead.
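
To see what the preallocation change means in numbers, consider a hypothetical 16 GiB GPU on which another process already uses 4 GiB. With `.50`, the new behavior preallocates 50% of the total (8 GiB), whereas the old behavior took 50% of the 12 GiB still available (6 GiB). The sketch below sets the variable before the first `import jax` (so it is in place before the GPU backend initializes) and encodes that arithmetic in two hypothetical helper functions; the helpers are not part of JAX.

```python
import os

# Set before JAX initializes its GPU backend; setting it before the first
# `import jax` is the simplest way to guarantee that.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"

import jax  # noqa: E402  (imported after configuring the environment)

GIB = 1024 ** 3

def prealloc_new(total_bytes: int, fraction: float) -> int:
    """Hypothetical helper: jaxlib >= 0.4.1 applies the fraction to total GPU memory."""
    return int(total_bytes * fraction)

def prealloc_old(available_bytes: int, fraction: float) -> int:
    """Hypothetical helper: earlier jaxlib applied it to currently available memory."""
    return int(available_bytes * fraction)

# Hypothetical 16 GiB GPU with 4 GiB already in use by another process.
total, used = 16 * GIB, 4 * GIB
print(prealloc_new(total, 0.50) / GIB)         # 8.0
print(prealloc_old(total - used, 0.50) / GIB)  # 6.0
```
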
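
Because JAX dispatches work asynchronously, `.block_until_ready()` is the usual way to wait for a result, for example when timing code. It returns the array itself, so calls to the removed `.block_host_until_ready()` can be renamed in place. A minimal timing sketch (the `matmul` function and sizes are arbitrary):

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def matmul(a, b):
    return a @ b

a = jnp.ones((512, 512))
b = jnp.ones((512, 512))
matmul(a, b).block_until_ready()  # warm-up: compile outside the timed region

start = time.perf_counter()
out = matmul(a, b).block_until_ready()  # replaces .block_host_until_ready()
elapsed = time.perf_counter() - start
print(f"{elapsed * 1e3:.2f} ms")
```
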