* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.7...jax-v0.3.8).
* Changes
* {func}`jax.numpy.linalg.svd` on TPUs uses a qdwh-svd solver.
* {func}`jax.numpy.linalg.cond` on TPUs now accepts complex input.
* {func}`jax.numpy.linalg.pinv` on TPUs now accepts complex input.
* {func}`jax.numpy.linalg.matrix_rank` on TPUs now accepts complex input.
* {func}`jax.scipy.cluster.vq.vq` has been added.
* `jax.experimental.maps.mesh` has been deleted.
Please use `jax.experimental.maps.Mesh` instead. See
https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh
for more information.
* {func}`jax.scipy.linalg.qr` now returns a length-1 tuple rather than the raw array when
`mode='r'`, in order to match the behavior of `scipy.linalg.qr` ({jax-issue}`10452`).
* {func}`jax.numpy.take_along_axis` now takes an optional `mode` parameter
that specifies the behavior of out-of-bounds indexing. By default,
invalid values (e.g., NaN) will be returned for out-of-bounds indices. In
previous versions of JAX, invalid indices were clamped into range. The
previous behavior can be restored by passing `mode="clip"`.
* {func}`jax.numpy.take` now defaults to `mode="fill"`, which returns
invalid values (e.g., NaN) for out-of-bounds indices.
* Scatter operations, such as `x.at[...].set(...)`, now have `"drop"` semantics.
This has no effect on the scatter operation itself, but it means that, when
differentiated, a scatter yields zero cotangents for out-of-bounds indices.
Previously, out-of-bounds indices were clamped into range for the gradient,
which was not mathematically correct.
* {func}`jax.numpy.take_along_axis` now raises a `TypeError` if its indices
are not of an integer type, matching the behavior of
{func}`numpy.take_along_axis`. Previously non-integer indices were silently
cast to integers.
* {func}`jax.numpy.ravel_multi_index` now raises a `TypeError` if its `dims` argument
is not of an integer type, matching the behavior of
{func}`numpy.ravel_multi_index`. Previously non-integer `dims` was silently
cast to integers.
* {func}`jax.numpy.split` now raises a `TypeError` if its `axis` argument
is not of an integer type, matching the behavior of
{func}`numpy.split`. Previously, a non-integer `axis` was silently
cast to an integer.
* {func}`jax.numpy.indices` now raises a `TypeError` if its dimensions
are not of an integer type, matching the behavior of
{func}`numpy.indices`. Previously non-integer dimensions were silently
cast to integers.
* {func}`jax.numpy.diag` now raises a `TypeError` if its `k` argument
is not of an integer type, matching the behavior of
{func}`numpy.diag`. Previously, a non-integer `k` was silently
cast to an integer.
* Added {func}`jax.random.orthogonal`.
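
The out-of-bounds and integer-type changes listed above can be illustrated with a minimal sketch. The array values and the helper name `scatter_sum` are made up for illustration, and `mode="drop"` is passed explicitly to the scatter only to make the semantics visible; this is not taken from the JAX test suite.

```python
import jax
import jax.numpy as jnp

x = jnp.array([10.0, 20.0, 30.0])
idx = jnp.array([0, 2, 5])  # 5 is out of bounds for an array of length 3

# jnp.take: with the default mode ("fill"), the out-of-bounds entry
# is an invalid value (NaN for floating-point input).
filled = jnp.take(x, idx)
# mode="clip" restores the previous behavior of clamping indices into range.
clipped = jnp.take(x, idx, mode="clip")

# jnp.take_along_axis accepts the same optional mode parameter.
along_filled = jnp.take_along_axis(x, idx, axis=0)
along_clipped = jnp.take_along_axis(x, idx, axis=0, mode="clip")

# Scatter with "drop" semantics: the update aimed at index 5 is ignored,
# so its cotangent under differentiation is zero.
scatter_idx = jnp.array([0, 5])

def scatter_sum(updates):
    # Hypothetical helper: scatter the updates into zeros, then reduce to a scalar.
    return jnp.zeros(3).at[scatter_idx].set(updates, mode="drop").sum()

grad_updates = jax.grad(scatter_sum)(jnp.array([1.0, 2.0]))

# Non-integer indices are rejected rather than silently cast.
try:
    jnp.take_along_axis(x, jnp.array([0.0, 1.0]), axis=0)
    raised_type_error = False
except TypeError:
    raised_type_error = True
```

Code that relied on clamping can pass `mode="clip"` at the call site instead of depending on the default.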
* Deprecations
* Many functions and objects available in {mod}`jax.test_util` are now deprecated and will raise a
warning on import. This includes `cases_from_list`, `check_close`, `check_eq`, `device_under_test`,
`format_shape_dtype_string`, `rand_uniform`, `skip_on_devices`, `with_config`, `xla_bridge`, and
`_default_tolerance` ({jax-issue}`10389`). These, along with previously-deprecated `JaxTestCase`,
`JaxTestLoader`, and `BufferDonationTestCase`, will be removed in a future JAX release.
Most of these utilities can be replaced by calls to standard Python and NumPy testing utilities found
in, e.g., {mod}`unittest`, {mod}`absl.testing`, and {mod}`numpy.testing`. JAX-specific functionality
such as device checking can be replaced through the use of public APIs such as {func}`jax.devices`.
Many of the deprecated utilities will still exist in {mod}`jax._src.test_util`, but these are not
public APIs and as such may be changed or removed without notice in future releases.
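
As a rough sketch of the migration described above, a test that previously leaned on the deprecated helpers might be written with {mod}`unittest`, {mod}`numpy.testing`, and {func}`jax.devices` alone. The class name, test names, and tolerance below are illustrative assumptions, not a prescribed replacement.

```python
import unittest

import numpy as np
import jax
import jax.numpy as jnp


class DoublingTest(unittest.TestCase):  # hypothetical test case

    def test_values_close(self):
        x = jnp.arange(4.0)
        # numpy.testing.assert_allclose stands in for the deprecated check_close.
        np.testing.assert_allclose(np.asarray(x * 2), np.arange(4.0) * 2, rtol=1e-6)

    def test_device_specific(self):
        # Querying jax.devices() stands in for device_under_test / skip_on_devices.
        platform = jax.devices()[0].platform
        if platform != "cpu":
            self.skipTest(f"illustrative CPU-only test; running on {platform}")
        self.assertEqual(int(jnp.ones(3).sum()), 3)
```

The suite can be run with `python -m unittest` or any runner that discovers `unittest.TestCase` subclasses.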