JAX

Latest version: v0.4.35


0.4.35

* Breaking Changes
* {func}`jax.numpy.isscalar` now returns True for any array-like object with
zero dimensions. Previously it only returned True for zero-dimensional
array-like objects with a weak dtype.
* `jax.experimental.host_callback` has been deprecated since March 2024 (JAX
version 0.4.26) and has now been removed.
See {jax-issue}`20385` for a discussion of alternatives.
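The new `isscalar` behavior can be checked directly; a minimal sketch (the expected output is noted only for the case that was already true before this release):

```python
import jax.numpy as jnp
import numpy as np

# Python scalars have always been reported as scalars.
print(jnp.isscalar(3.14))  # True

# As of this release, any zero-dimensional array-like is also a scalar,
# regardless of whether its dtype is weak.
print(jnp.isscalar(jnp.array(1.0)))
print(jnp.isscalar(np.float32(2.0)))
```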

* Changes:
* `jax.lax.FftType` was introduced as a public name for the enum of FFT
operations. The semi-public API `jax.lib.xla_client.FftType` has been
deprecated.
* TPU: JAX now installs TPU support from the `libtpu` package rather than
`libtpu-nightly`. For the next few releases JAX will pin an empty version of
`libtpu-nightly` as well as `libtpu` to ease the transition; that dependency
will be removed in Q1 2025.

* Deprecations:
* The semi-public API `jax.lib.xla_client.PaddingType` has been deprecated.
No JAX APIs consume this type, so there is no replacement.
* The default behavior of {func}`jax.pure_callback` and
{func}`jax.extend.ffi.ffi_call` under `vmap` has been deprecated and so has
the `vectorized` parameter to those functions. The `vmap_method` parameter
should be used instead for better defined behavior. See the discussion in
{jax-issue}`23881` for more details.
* The semi-public API `jax.lib.xla_client.register_custom_call_target` has
been deprecated. Use the JAX FFI instead.
* The semi-public APIs `jax.lib.xla_client.dtype_to_etype`,
`jax.lib.xla_client.ops`,
`jax.lib.xla_client.shape_from_pyval`, `jax.lib.xla_client.PrimitiveType`,
`jax.lib.xla_client.Shape`, `jax.lib.xla_client.XlaBuilder`, and
`jax.lib.xla_client.XlaComputation` have been deprecated. Use StableHLO
instead.

0.4.34

* New Functionality
* This release includes wheels for Python 3.13. Free-threading mode is not yet
supported.
* `jax.errors.JaxRuntimeError` has been added as a public alias for the
formerly private `XlaRuntimeError` type.

* Breaking changes
* `jax_pmap_no_rank_reduction` flag is set to `True` by default.
* `array[0]` on a pmap result now introduces a reshape (use `array[0:1]`
instead).
* The per-shard shape (accessible via `jax_array.addressable_shards` or
`jax_array.addressable_data(0)`) now has a leading `(1, ...)`. Update code
that directly accesses shards accordingly. The rank of the per-shard shape
now matches that of the global shape, which is the same behavior as `jit`.
This avoids costly reshapes when passing results from `pmap` into `jit`.
* `jax.experimental.host_callback` has been deprecated since March 2024 (JAX
version 0.4.26). The default value of the `--jax_host_callback_legacy`
configuration option is now set to `True`, which means that if your code uses
`jax.experimental.host_callback` APIs, those API calls will be implemented in
terms of the new `jax.experimental.io_callback` API. If this breaks your
code, for a very limited time, you can set `--jax_host_callback_legacy` to
`False`. That configuration option will be removed soon, so you should
instead transition to using the new JAX callback APIs. See
{jax-issue}`20385` for a discussion.

* Deprecations
* In {func}`jax.numpy.trim_zeros`, non-arraylike arguments or arraylike
arguments with `ndim != 1` are now deprecated, and in the future will result
in an error.
* Internal pretty-printing tools `jax.core.pp_*` have been removed, after
being deprecated in JAX v0.4.30.
* `jax.lib.xla_client.Device` is deprecated; use `jax.Device` instead.
* `jax.lib.xla_client.XlaRuntimeError` has been deprecated. Use
`jax.errors.JaxRuntimeError` instead.

* Deletion:
* `jax.xla_computation` has been deleted, three months after its deprecation
in the JAX 0.4.30 release.
Please use the AOT APIs to get the same functionality as `jax.xla_computation`.
* `jax.xla_computation(fn)(*args, **kwargs)` can be replaced with
`jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')`.
* You can also use `.out_info` property of `jax.stages.Lowered` to get the
output information (like tree structure, shape and dtype).
* For cross-backend lowering, you can replace
`jax.xla_computation(fn, backend='tpu')(*args, **kwargs)` with
`jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')`.
* {class}`jax.ShapeDtypeStruct` no longer accepts the `named_shape` argument.
The argument was only used by `xmap` which was removed in 0.4.31.
* `jax.tree.map(f, None, non-None)`, which previously emitted a
`DeprecationWarning`, now raises an error. `None` is only a tree-prefix of
itself. To preserve the previous behavior, you can ask `jax.tree.map` to
treat `None` as a leaf value by writing:
`jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)`.
* `jax.sharding.XLACompatibleSharding` has been removed. Please use
`jax.sharding.Sharding`.
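The AOT replacement for `jax.xla_computation` described above can be exercised end to end; a minimal sketch:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) ** 2

# jax.xla_computation(f)(x)  ->  jax.jit(f).lower(x).compiler_ir('hlo')
lowered = jax.jit(f).lower(jnp.ones((3,), jnp.float32))
hlo = lowered.compiler_ir('hlo')

# Output structure, shape, and dtype are available without compiling.
info = lowered.out_info
print(info.shape, info.dtype)
```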

* Bug fixes
* Fixed a bug where {func}`jax.numpy.cumsum` would produce incorrect outputs
if a non-boolean input was provided and `dtype=bool` was specified.
* The implementation of {func}`jax.numpy.ldexp` was fixed to produce the correct gradient.

0.4.33

This is a patch release on top of jax 0.4.32 that fixes two bugs found in that
release.

A TPU-only data corruption bug was found in the version of libtpu pinned by JAX 0.4.32, which is why that release was yanked.

0.4.32

Note: This release was yanked from PyPI because of a data corruption bug on TPU.
See the 0.4.33 release notes for more details.

* New Functionality
* Added {func}`jax.extend.ffi.ffi_call` and {func}`jax.extend.ffi.ffi_lowering`
to support the use of the new {ref}`ffi-tutorial` to interface with custom
C++ and CUDA code from JAX.

* Changes
* `jax_enable_memories` flag is set to `True` by default.
* {mod}`jax.numpy` now supports v2023.12 of the Python Array API Standard.
See {ref}`python-array-api` for more information.
* Computations on the CPU backend may now be dispatched asynchronously in
more cases. Previously non-parallel computations were always dispatched
synchronously. You can recover the old behavior by setting
`jax.config.update('jax_cpu_enable_async_dispatch', False)`.
* Added new {func}`jax.process_indices` function to replace the
`jax.host_ids()` function that was deprecated in JAX v0.2.13.
* To align with the behavior of `numpy.fabs`, `jax.numpy.fabs` no longer
supports complex dtypes.
* ``jax.tree_util.register_dataclass`` now checks that ``data_fields``
and ``meta_fields`` includes all dataclass fields with ``init=True``
and only them, if ``nodetype`` is a dataclass.
* Several {mod}`jax.numpy` functions now have full {class}`~jax.numpy.ufunc`
interfaces, including {obj}`~jax.numpy.add`, {obj}`~jax.numpy.multiply`,
{obj}`~jax.numpy.bitwise_and`, {obj}`~jax.numpy.bitwise_or`,
{obj}`~jax.numpy.bitwise_xor`, {obj}`~jax.numpy.logical_and`,
{obj}`~jax.numpy.logical_or`, and {obj}`~jax.numpy.logical_xor`.
In future releases we plan to expand these to other ufuncs.
* Added {func}`jax.lax.optimization_barrier`, which allows users to prevent
compiler optimizations such as common-subexpression elimination and to
control scheduling.

* Breaking changes
* The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the
`stablehlo` dialect instead.

* Deprecations
* Complex inputs to {func}`jax.numpy.clip` and {func}`jax.numpy.hypot` are
no longer allowed, after being deprecated since JAX v0.4.27.
* Deprecated the following APIs:
* `jax.lib.xla_bridge.xla_client`: use {mod}`jax.lib.xla_client` directly.
* `jax.lib.xla_bridge.get_backend`: use {func}`jax.extend.backend.get_backend`.
* `jax.lib.xla_bridge.default_backend`: use {func}`jax.extend.backend.default_backend`.
* The `jax.experimental.array_api` module is deprecated, and importing it is no
longer required to use the Array API. `jax.numpy` supports the array API
directly; see {ref}`python-array-api` for more information.
* The internal utilities `jax.core.check_eqn`, `jax.core.check_type`, and
`jax.core.check_valid_jaxtype` are now deprecated, and will be removed in
the future.
* `jax.numpy.round_` has been deprecated, following removal of the corresponding
API in NumPy 2.0. Use {func}`jax.numpy.round` instead.
* Passing a DLPack capsule to {func}`jax.dlpack.from_dlpack` is deprecated.
The argument to {func}`jax.dlpack.from_dlpack` should be an array from
another framework that implements the ``__dlpack__`` protocol.
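Under the new protocol-based path, `from_dlpack` takes the producing array itself rather than a capsule; a minimal sketch using NumPy as the producer (NumPy arrays implement `__dlpack__`):

```python
import numpy as np
import jax.dlpack

a = np.arange(4.0)
b = jax.dlpack.from_dlpack(a)  # pass the array, not a.__dlpack__() capsule
```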

jaxlib 0.4.32 (September 11, 2024)

Note: This release was yanked from PyPI because of a data corruption bug on TPU.
See the 0.4.33 release notes for more details.

* Breaking changes
* This release of jaxlib switched to a new version of the CPU backend, which
should compile faster and leverage parallelism better. If you experience
any problems due to this change, you can temporarily enable the old CPU
backend by setting the environment variable
`XLA_FLAGS=--xla_cpu_use_thunk_runtime=false`. If you need to do this,
please file a JAX bug with instructions to reproduce.
* Hermetic CUDA support is added.
Hermetic CUDA uses a specific downloadable version of CUDA instead of the
user’s locally installed CUDA. Bazel will download CUDA, CUDNN and NCCL
distributions, and then use CUDA libraries and tools as dependencies in
various Bazel targets. This enables more reproducible builds for JAX and its
supported CUDA versions.

* Changes
* SparseCore profiling is added.
* JAX now supports profiling [SparseCore](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#sparsecore) on TPUv5p chips. These traces will be viewable in Tensorboard Profiler's [TraceViewer](https://www.tensorflow.org/guide/profiler#trace_viewer).

0.4.31

* Deletion
* xmap has been deleted. Please use {func}`shard_map` as the replacement.

* Changes
* The minimum cuDNN version is v9.1. This was true in previous releases also,
but we now declare this version constraint formally.
* The minimum Python version is now 3.10. 3.10 will remain the minimum
supported version until July 2025.
* The minimum NumPy version is now 1.24. NumPy 1.24 will remain the minimum
supported version until December 2024.
* The minimum SciPy version is now 1.10. SciPy 1.10 will remain the minimum
supported version until January 2025.
* {func}`jax.numpy.ceil`, {func}`jax.numpy.floor` and {func}`jax.numpy.trunc` now return the output
of the same dtype as the input, i.e. no longer upcast integer or boolean inputs to floating point.
* `libdevice.10.bc` is no longer bundled with CUDA wheels. It must be
installed either as a part of local CUDA installation, or via NVIDIA's CUDA
pip wheels.
* {class}`jax.experimental.pallas.BlockSpec` now expects `block_shape` to
be passed *before* `index_map`. The old argument order is deprecated and
will be removed in a future release.
* Updated the repr of GPU devices to be more consistent
with TPUs/CPUs. For example, `cuda(id=0)` will now be `CudaDevice(id=0)`.
* Added the `device` property and `to_device` method to {class}`jax.Array`, as
part of JAX's [Array API](https://data-apis.org/array-api) support.

* Deprecations
* Removed a number of previously-deprecated internal APIs related to
polymorphic shapes. From {mod}`jax.core`: removed `canonicalize_shape`,
`dimension_as_value`, `definitely_equal`, and `symbolic_equal_dim`.
* HLO lowering rules should no longer wrap singleton ir.Values in tuples.
Instead, return singleton ir.Values unwrapped. Support for wrapped values
will be removed in a future version of JAX.
* {func}`jax.experimental.jax2tf.convert` with `native_serialization=False`
or `enable_xla=False` is now deprecated and this support will be removed in
a future version.
Native serialization has been the default since JAX 0.4.16 (September 2023).
* The previously-deprecated function `jax.random.shuffle` has been removed;
instead use `jax.random.permutation` with `independent=True`.
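The `jax.random.shuffle` replacement above can be sketched as:

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
x = jnp.arange(5)

# jax.random.shuffle(key, x)  ->  permutation with independent=True
shuffled = jax.random.permutation(key, x, independent=True)
```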

jaxlib 0.4.31 (July 29, 2024)

* Bug fixes
* Fixed a bug that meant that negative `static_argnums` to a `jit` were mishandled
by the jit dispatch fast path.
* Fixed a bug that meant triangular solves of batches of singular matrices
produced nonsensical finite values instead of `inf` or `nan` (3589, 15429).

0.4.30

* Changes
* JAX supports ml_dtypes >= 0.2. In the 0.4.29 release, the minimum ml_dtypes
version was bumped to 0.4.0, but this has been rolled back in this release to
give users of both TensorFlow and JAX more time to migrate to a newer
TensorFlow release.
* `jax.experimental.mesh_utils` can now create an efficient mesh for TPU v5e.
* jax now depends on jaxlib directly. This change was enabled by the CUDA
plugin switch: there are no longer multiple jaxlib variants. You can install
a CPU-only jax with `pip install jax`, no extras required.
* Added an API for exporting and serializing JAX functions. This used
to exist in `jax.experimental.export` (which is being deprecated),
and will now live in `jax.export`.
See the [documentation](https://jax.readthedocs.io/en/latest/export/index.html).
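The new `jax.export` API supports a full export/serialize/rehydrate round trip; a minimal sketch:

```python
import jax
import jax.numpy as jnp
from jax import export

def f(x):
    return jnp.sin(x)

# Export a jitted function for concrete argument shapes/dtypes,
# serialize it to bytes, and rehydrate it later.
exported = export.export(jax.jit(f))(jnp.ones((3,), jnp.float32))
blob = exported.serialize()
restored = export.deserialize(blob)
out = restored.call(jnp.ones((3,), jnp.float32))
```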

* Deprecations
* Internal pretty-printing tools `jax.core.pp_*` are deprecated, and will be removed
in a future release.
* Hashing of tracers is deprecated, and will lead to a `TypeError` in a future JAX
release. This previously was the case, but there was an inadvertent regression in
the last several JAX releases.
* `jax.experimental.export` is deprecated. Use {mod}`jax.export` instead.
See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export).
* Passing an array in place of a dtype is now deprecated in most cases; e.g. for arrays
`x` and `y`, `x.astype(y)` will raise a warning. To silence it use `x.astype(y.dtype)`.
* `jax.xla_computation` is deprecated and will be removed in a future release.
Please use the AOT APIs to get the same functionality as `jax.xla_computation`.
* `jax.xla_computation(fn)(*args, **kwargs)` can be replaced with
`jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')`.
* You can also use `.out_info` property of `jax.stages.Lowered` to get the
output information (like tree structure, shape and dtype).
* For cross-backend lowering, you can replace
`jax.xla_computation(fn, backend='tpu')(*args, **kwargs)` with
`jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')`.


jaxlib 0.4.30 (June 18, 2024)

* Support for monolithic CUDA jaxlibs has been dropped. You must use the
plugin-based installation (`pip install jax[cuda12]` or
`pip install jax[cuda12_local]`).
