* Breaking Changes
* This release lands "stackless", an internal change to JAX's tracing
machinery. We made trace dispatch purely a function of context rather than a
function of both context and data. This let us delete a lot of machinery for
managing data-dependent tracing: levels, sublevels, `post_process_call`,
`new_base_main`, `custom_bind`, and so on. The change should only affect
users who use JAX internals.
If you do use JAX internals, you may need to
update your code (see
https://github.com/jax-ml/jax/commit/c36e1f7c1ad4782060cbc8e8c596d85dfb83986f
for clues about how to do this). There might also be version skew
issues with JAX libraries that depend on these internals. If you find that this
change breaks code that does not use JAX internals, try the
`config.jax_data_dependent_tracing_fallback` flag as a workaround, and if
you need help updating your code, please file a bug.
* {func}`jax.experimental.jax2tf.convert` with `native_serialization=False`
or with `enable_xla=False` has been deprecated since July 2024 (JAX
v0.4.31), and support for these use cases has now been removed. `jax2tf`
with native serialization is still supported.
* In `jax.interpreters.xla`, the `xb`, `xc`, and `xe` symbols have been removed
after being deprecated in JAX v0.4.31. Instead use `xb = jax.lib.xla_bridge`,
`xc = jax.lib.xla_client`, and `xe = jax.lib.xla_extension`.
* The deprecated module `jax.experimental.export` has been removed. It was replaced
by {mod}`jax.export` in JAX v0.4.30. See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export)
for information on migrating to the new API.
* The `initial` argument to {func}`jax.nn.softmax` and {func}`jax.nn.log_softmax`
has been removed, after being deprecated in v0.4.27.
* Calling `np.asarray` on typed PRNG keys (i.e., keys produced by {func}`jax.random.key`)
now raises an error. Previously, this returned a scalar object array.
* The following deprecated methods and functions in {mod}`jax.export` have
been removed:
  * `jax.export.DisabledSafetyCheck.shape_assertions`: it already had no
    effect.
  * `jax.export.Exported.lowering_platforms`: use `platforms`.
  * `jax.export.Exported.mlir_module_serialization_version`:
    use `calling_convention_version`.
  * `jax.export.Exported.uses_shape_polymorphism`:
    use `uses_global_constants`.
  * The `lowering_platforms` kwarg for {func}`jax.export.export`: use
    `platforms` instead.
* The kwargs `symbolic_scope` and `symbolic_constraints` from
{func}`jax.export.symbolic_args_specs` have been removed. They were
deprecated in June 2024. Use `scope` and `constraints` instead.
* Hashing of tracers, which has been deprecated since version 0.4.30, now
results in a `TypeError`.
* Refactor: the JAX build CLI (`build/build.py`) now uses a subcommand structure
that replaces the previous `build.py` usage. Run `python build/build.py --help`
for more details. Brief overview of the new subcommands:
  * `build`: Builds JAX wheel packages. For example,
    `python build/build.py build --wheels=jaxlib,jax-cuda-pjrt`.
  * `requirements_update`: Updates the `requirements_lock.txt` files.
* {func}`jax.scipy.linalg.toeplitz` now does implicit batching on multi-dimensional
inputs. To recover the previous behavior, you can call {func}`jax.numpy.ravel`
on the function inputs.
* {func}`jax.scipy.special.gamma` and {func}`jax.scipy.special.gammasgn` now
return NaN for negative integer inputs, to match the behavior of SciPy from
https://github.com/scipy/scipy/pull/21827.
* `jax.clear_backends` was removed after being deprecated in v0.4.26.
* We removed the custom call "__gpu$xla.gpu.triton" from the list of custom
calls for which we guarantee export stability, because this custom call
relies on Triton IR, which is not guaranteed to be stable. If you need
to export code that uses this custom call, you can use the `disabled_checks`
parameter of {func}`jax.export.export`. See more details in the [documentation](https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees-for-custom-calls).
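
For the removed {mod}`jax.export` members, here is a minimal sketch of the replacement spellings. It exports a hypothetical toy function `f` for the `cpu` platform only. It assumes the calling process also runs on CPU, since `Exported.call` needs a matching platform.

```python
import jax
import jax.numpy as jnp
from jax import export

def f(x):  # hypothetical example function
    return jnp.sin(x) * 2.0

# `platforms=` replaces the removed `lowering_platforms=` kwarg.
exp = export.export(jax.jit(f), platforms=["cpu"])(
    jax.ShapeDtypeStruct((3,), jnp.float32))

print(exp.platforms)                   # replaces `lowering_platforms`
print(exp.calling_convention_version)  # replaces `mlir_module_serialization_version`
print(exp.uses_global_constants)       # replaces `uses_shape_polymorphism`

# The exported artifact can still be called (here on CPU).
y = exp.call(jnp.arange(3, dtype=jnp.float32))
```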
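
The renamed `symbolic_args_specs` keyword can be sketched with a hypothetical `row_sums` function. Its leading dimension `b` is left symbolic and constrained through `constraints=`; the `scope=` keyword is not exercised here. The sketch assumes the default backend is the one the export targets.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import export

x = np.zeros((8, 4), dtype=np.float32)

# `constraints=` replaces the removed `symbolic_constraints=` kwarg
# (`scope=` likewise replaces `symbolic_scope=`; unused here).
spec = export.symbolic_args_specs(x, "b, 4", constraints=("b >= 2",))
print(spec.shape, spec.dtype)

def row_sums(a):  # hypothetical example function
    return jnp.sum(a, axis=1)

exp = export.export(jax.jit(row_sums))(spec)

# Any batch size satisfying the constraint is accepted at call time.
out = exp.call(np.ones((5, 4), dtype=np.float32))
```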
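
For the Triton custom-call change, here is a minimal sketch of passing `disabled_checks` to {func}`jax.export.export`. The function `f` is a hypothetical stand-in that does not actually call a Triton kernel. It only shows where `jax.export.DisabledSafetyCheck.custom_call` plugs in.

```python
import jax
import jax.numpy as jnp
from jax import export

# Opt out of the stability check for the Triton custom call only.
triton_check = export.DisabledSafetyCheck.custom_call("__gpu$xla.gpu.triton")

def f(x):
    return jnp.tanh(x)  # stand-in; real code would launch a Triton kernel here

exp = export.export(jax.jit(f), disabled_checks=[triton_check])(
    jax.ShapeDtypeStruct((4,), jnp.float32))
print(exp.disabled_safety_checks)
```

Disabling only the specific custom call keeps all other export safety checks active.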
* New Features
* {func}`jax.jit` gained a new `compiler_options: dict[str, Any]` argument for
passing compilation options to XLA. For the moment, it is undocumented and
may be in flux.
* {func}`jax.tree_util.register_dataclass` now allows metadata fields to be
declared inline via {func}`dataclasses.field`. See the function documentation
for examples.
* Added {func}`jax.numpy.put_along_axis`.
* {func}`jax.lax.linalg.eig` and the related `jax.numpy` functions
({func}`jax.numpy.linalg.eig` and {func}`jax.numpy.linalg.eigvals`) are now
supported on GPU. See {jax-issue}`24663` for more details.
* Added two new configuration flags, `jax_exec_time_optimization_effort` and
`jax_memory_fitting_effort`, to control the amount of effort the compiler
spends minimizing execution time and memory usage, respectively. Valid values
range from -1.0 to 1.0; the default is 0.0.
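
This sketch calls {func}`jax.numpy.linalg.eig` on a real matrix with complex eigenvalues. The same code runs on CPU, and according to the entry about it, GPU backends are now supported too. The residual check is an illustration, not a formal accuracy bound.

```python
import jax.numpy as jnp

# A 90-degree rotation: real input with complex eigenvalues (+i and -i),
# which calls for eig rather than eigh.
a = jnp.array([[0.0, -1.0],
               [1.0,  0.0]])

w, v = jnp.linalg.eig(a)  # eigenvalues and right eigenvectors (complex)

# Check A @ v[:, k] == w[k] * v[:, k] for every column k.
residual = jnp.abs(a.astype(w.dtype) @ v - v * w).max()
print(w, residual)
```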
* Bug fixes
* Fixed a bug where the GPU implementations of LU and QR decomposition would
result in an indexing overflow for batch sizes close to int32 max. See
{jax-issue}`24843` for more details.
* Deprecations
* `jax.lib.xla_extension.ArrayImpl` and `jax.lib.xla_client.ArrayImpl` are deprecated;
use `jax.Array` instead.
* `jax.lib.xla_extension.XlaRuntimeError` is deprecated; use `jax.errors.JaxRuntimeError`
instead.
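
Migrating off the deprecated aliases is mostly a matter of swapping names. This is a minimal sketch with the hypothetical helpers `is_jax_array` and `run_or_none`.

```python
import numpy as np
import jax
import jax.numpy as jnp

def is_jax_array(x):
    # Use the public jax.Array type instead of the deprecated ArrayImpl aliases.
    return isinstance(x, jax.Array)

def run_or_none(fn, *args):
    # Catch runtime failures via jax.errors.JaxRuntimeError rather than the
    # deprecated jax.lib.xla_extension.XlaRuntimeError.
    try:
        return fn(*args)
    except jax.errors.JaxRuntimeError:
        return None

print(is_jax_array(jnp.ones(2)), is_jax_array(np.ones(2)))
print(run_or_none(jnp.sum, jnp.arange(3)))
```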