Jax

Latest version: v0.4.30




0.4.24

* Changes

* JAX lowering to StableHLO no longer depends on physical devices.
If your primitive wraps custom_partitioning or JAX callbacks in its lowering
rule (i.e., the function passed to the `rule` parameter of `mlir.register_lowering`),
add your primitive to the `jax._src.dispatch.prim_requires_devices_during_lowering` set.
This is needed because custom_partitioning and JAX callbacks need physical
devices to create `Sharding`s during lowering.
This is a temporary state until we can create `Sharding`s without physical
devices.
* {func}`jax.numpy.argsort` and {func}`jax.numpy.sort` now support the `stable`
and `descending` arguments.
* Several changes to the handling of shape polymorphism (used in
{mod}`jax.experimental.jax2tf` and {mod}`jax.experimental.export`):
  * cleaner pretty-printing of symbolic expressions ({jax-issue}`19227`)
  * added the ability to specify symbolic constraints on the dimension variables.
    This makes shape polymorphism more expressive, and gives a way to work around
    limitations in the reasoning about inequalities.
    See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
  * with the addition of symbolic constraints ({jax-issue}`19235`) we now
    consider dimension variables from different scopes to be different, even
    if they have the same name. Symbolic expressions from different scopes
    cannot interact, e.g., in arithmetic operations.
    Scopes are introduced by {func}`jax.experimental.jax2tf.convert`,
    {func}`jax.experimental.export.symbolic_shape`, and {func}`jax.experimental.export.symbolic_args_specs`.
    The scope of a symbolic expression `e` can be read with `e.scope` and passed
    into the above functions to direct them to construct symbolic expressions in
    a given scope.
    See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
  * simplified and faster equality comparisons, where we consider two symbolic dimensions
    to be equal if the normalized form of their difference reduces to 0
    ({jax-issue}`19231`; note that this may result in user-visible behavior
    changes)
  * improved the error messages for inconclusive inequality comparisons
    ({jax-issue}`19235`).
  * the `core.non_negative_dim` API (introduced recently)
    was deprecated and `core.max_dim` and `core.min_dim` were introduced
    ({jax-issue}`18953`) to express `max` and `min` for symbolic dimensions.
    You can use `core.max_dim(d, 0)` instead of `core.non_negative_dim(d)`.
  * the `shape_poly.is_poly_dim` is deprecated in favor of `export.is_symbolic_dim`
    ({jax-issue}`19282`).
  * the `export.args_specs` is deprecated in favor of `export.symbolic_args_specs`
    ({jax-issue}`19283`).
  * the `shape_poly.PolyShape` and `jax2tf.PolyShape` are deprecated; use
    strings for polymorphic shape specifications ({jax-issue}`19284`).
* The default JAX native serialization version is now 9. This is relevant
for {mod}`jax.experimental.jax2tf` and {mod}`jax.experimental.export`.
See [description of version numbers](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions).
* Refactored the API for `jax.experimental.export`. Instead of
`from jax.experimental.export import export`, you should now use
`from jax.experimental import export`. The old way of importing will
continue to work for a deprecation period of 3 months.
* Added {func}`jax.scipy.stats.sem`.
* {func}`jax.numpy.unique` with `return_inverse = True` returns inverse indices
reshaped to the dimension of the input, following a similar change to
{func}`numpy.unique` in NumPy 2.0.
* {func}`jax.numpy.sign` now returns `x / abs(x)` for nonzero complex inputs. This is
consistent with the behavior of {func}`numpy.sign` in NumPy version 2.0.
* {func}`jax.scipy.special.logsumexp` with `return_sign=True` now uses the NumPy 2.0
convention for the complex sign, `x / abs(x)`. This is consistent with the behavior
of {func}`scipy.special.logsumexp` in SciPy v1.13.
* JAX now supports the bool DLPack type for both import and export.
Previously bool values could not be imported and were exported as integers.
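
Several of the NumPy-compatibility changes above can be exercised together. The following is a minimal sketch, assuming JAX 0.4.24 or newer on any backend; the input values and variable names are illustrative only.

```python
import jax.numpy as jnp
from jax.scipy import stats

# Stable and descending sorting via the new keyword arguments.
x = jnp.array([3, 1, 2, 1])
desc = jnp.sort(x, descending=True)   # largest values first
order = jnp.argsort(x, stable=True)   # tied values keep their original order

# Complex sign now follows the NumPy 2.0 convention x / abs(x).
z = jnp.array([3 + 4j, -2j])
sgn = jnp.sign(z)

# With return_inverse=True and axis=None, the inverse indices have the
# input's shape, so jnp.take(values, inverse) rebuilds the input directly.
m = jnp.array([[1, 2], [2, 3]])
values, inverse = jnp.unique(m, return_inverse=True)
rebuilt = jnp.take(values, inverse)

# Standard error of the mean, compared with std(ddof=1) / sqrt(n).
data = jnp.array([1.0, 2.0, 3.0, 4.0])
sem = stats.sem(data)
sem_ref = jnp.std(data, ddof=1) / jnp.sqrt(data.size)

print(desc, order, sgn, inverse.shape, sem)
```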

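A minimal sketch of symbolic constraints and scopes, assuming JAX 0.4.24 or newer. The first import targets `jax.export`, which is an assumption about later releases where this API lives outside `jax.experimental`; the fallback is the `from jax.experimental import export` form described above.

```python
try:
    from jax import export  # assumption: later JAX releases expose jax.export
except ImportError:
    from jax.experimental import export  # import path described in this release

# Dimension variables default to >= 1, so "a >= b" alone is inconclusive;
# the user-specified constraint makes the comparison decidable.
a, b = export.symbolic_shape("a, b", constraints=("a >= b",))
a_ge_b = a >= b

# Build another dimension variable in the same scope as `a`, so the two
# expressions are allowed to interact in arithmetic.
(c,) = export.symbolic_shape("c", scope=a.scope)
total = a + c

print(a_ge_b, total)
```

Variables created without `scope=a.scope` would belong to a fresh scope and could not be combined with `a`, even if they had the same name.
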
* Deprecations & Removals
* A number of previously deprecated functions have been removed, following a
standard 3+ month deprecation cycle (see {ref}`api-compatibility`).
This includes:
  * From {mod}`jax.core`: `TracerArrayConversionError`,
    `TracerIntegerConversionError`, `UnexpectedTracerError`,
    `as_hashable_function`, `collections`, `dtypes`, `lu`, `map`,
    `namedtuple`, `partial`, `pp`, `ref`, `safe_zip`, `safe_map`,
    `source_info_util`, `total_ordering`, `traceback_util`, `tuple_delete`,
    `tuple_insert`, and `zip`.
  * From {mod}`jax.lax`: `dtypes`, `itertools`, `naryop`, `naryop_dtype_rule`,
    `standard_abstract_eval`, `standard_naryop`, `standard_primitive`,
    `standard_unop`, `unop`, and `unop_dtype_rule`.
  * The `jax.linear_util` submodule and all its contents.
  * The `jax.prng` submodule and all its contents.
  * From {mod}`jax.random`: `PRNGKeyArray`, `KeyArray`, `default_prng_impl`,
    `threefry_2x32`, `threefry2x32_key`, `threefry2x32_p`, `rbg_key`, and
    `unsafe_rbg_key`.
  * From {mod}`jax.tree_util`: `register_keypaths`, `AttributeKeyPathEntry`, and
    `GetItemKeyPathEntry`.
  * From {mod}`jax.interpreters.xla`: `backend_specific_translations`, `translations`,
    `register_translation`, `xla_destructure`, `TranslationRule`, `TranslationContext`,
    `axis_groups`, `ShapedArray`, `ConcreteArray`, `AxisEnv`, `backend_compile`,
    and `XLAOp`.
  * From {mod}`jax.numpy`: `NINF`, `NZERO`, `PZERO`, `row_stack`, `issubsctype`,
    `trapz`, and `in1d`.
  * From {mod}`jax.scipy.linalg`: `tril` and `triu`.
* The previously-deprecated method `PRNGKeyArray.unsafe_raw_array` has been
removed. Use {func}`jax.random.key_data` instead.
* `bool(empty_array)` now raises an error rather than returning `False`. This
previously raised a deprecation warning, and follows a similar change in NumPy.
* Support for the mhlo MLIR dialect has been deprecated. JAX no longer uses
the mhlo dialect, in favor of stablehlo. APIs that refer to "mhlo" will be
removed in the future. Use the "stablehlo" dialect instead.
* {mod}`jax.random`: passing batched keys directly to random number generation functions,
such as {func}`~jax.random.bits`, {func}`~jax.random.gamma`, and others, is deprecated
and will emit a `FutureWarning`. Use `jax.vmap` for explicit batching.
* {func}`jax.lax.tie_in` is deprecated: it has been a no-op since JAX v0.2.0.
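
The replacements named in the removals and deprecations above can be sketched as follows, assuming a recent JAX with the default (threefry) PRNG implementation; the sample shapes are arbitrary.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)            # new-style typed PRNG key
raw = jax.random.key_data(key)     # replaces PRNGKeyArray.unsafe_raw_array

# Rather than passing a batch of keys straight to a sampler, map over them.
keys = jax.random.split(key, 4)
samples = jax.vmap(lambda k: jax.random.normal(k, (3,)))(keys)

# bool() of an empty array now raises instead of returning False.
try:
    bool(jnp.array([]))
    empty_raised = False
except ValueError:
    empty_raised = True
```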

jaxlib 0.4.24 (Feb 6, 2024)

* Changes

* JAX now supports CUDA 12.3 and CUDA 11.8. Support for CUDA 12.2 has been
dropped.
* `cost_analysis` now works with cross-compiled `Compiled` objects (i.e. when
using `.lower().compile()` with a topology object, e.g., to compile for
Cloud TPU from a non-TPU computer).
* Added [CUDA Array
Interface](https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html)
import support (requires jax 0.4.25).

0.4.23

jaxlib 0.4.23 (Dec 13, 2023)

* Fixed a bug that caused verbose logging from the GPU compiler during
compilation.

0.4.22

* Deprecations
* The `device_buffer` and `device_buffers` properties of JAX arrays are deprecated.
Explicit buffers have been replaced by the more flexible array sharding interface,
but the previous outputs can be recovered this way:
* `arr.device_buffer` becomes `arr.addressable_data(0)`
* `arr.device_buffers` becomes `[x.data for x in arr.addressable_shards]`
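
A minimal sketch of the migration, assuming a single-process program in which `arr` lives on one device:

```python
import jax.numpy as jnp

arr = jnp.arange(8.0)

# Formerly arr.device_buffer: the data held by the first addressable shard.
first = arr.addressable_data(0)

# Formerly arr.device_buffers: one array per addressable shard.
per_shard = [s.data for s in arr.addressable_shards]
```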

jaxlib 0.4.22 (Dec 13, 2023)

0.4.21

* New Features
* Added {obj}`jax.nn.squareplus`.
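
As a quick check of what {obj}`jax.nn.squareplus` computes, the sketch below compares it with the closed form `(x + sqrt(x**2 + b)) / 2` using the default `b=4`; the formula is taken from the squareplus definition rather than from this changelog.

```python
import jax
import jax.numpy as jnp

x = jnp.linspace(-3.0, 3.0, 7)
y = jax.nn.squareplus(x)                  # default b=4
reference = (x + jnp.sqrt(x**2 + 4.0)) / 2
```

Like softplus, the result is smooth and strictly positive, but it needs only a square root rather than `exp` and `log`.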

* Changes
* The minimum jaxlib version is now 0.4.19.
* Released wheels are now built with clang instead of gcc.
* JAX now enforces that the device backend has not been initialized before `jax.distributed.initialize()` is called.
* Arguments to `jax.distributed.initialize()` are now determined automatically in Cloud TPU environments.

* Deprecations
* The previously-deprecated `sym_pos` argument has been removed from
{func}`jax.scipy.linalg.solve`. Use `assume_a='pos'` instead.
* Passing `None` to {func}`jax.numpy.array` or {func}`jax.numpy.asarray`, either directly or
within a list or tuple, is deprecated and now raises a {obj}`FutureWarning`.
It is currently converted to NaN, and in the future will raise a {obj}`TypeError`.
* Passing the `condition`, `x`, and `y` parameters to `jax.numpy.where` by
keyword arguments has been deprecated, to match `numpy.where`.
* Passing arguments to {func}`jax.numpy.array_equal` and {func}`jax.numpy.array_equiv`
that cannot be converted to a JAX array is deprecated and now raises a
{obj}`DeprecationWarning`. Currently the functions return False; in the future this
will raise an exception.
* The `device()` method of JAX arrays is deprecated. Depending on the context, it may
be replaced with one of the following:
- {meth}`jax.Array.devices` returns the set of all devices used by the array.
- {attr}`jax.Array.sharding` gives the sharding configuration used by the array.
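
The migrations for these deprecations can be sketched as follows; the matrix, threshold, and variable names are illustrative assumptions.

```python
import jax.numpy as jnp
from jax.scipy import linalg

a = jnp.array([[4.0, 1.0], [1.0, 3.0]])   # symmetric positive definite
b = jnp.array([1.0, 2.0])
x = linalg.solve(a, b, assume_a="pos")     # replaces sym_pos=True

# Pass condition, x, and y to jnp.where positionally, not by keyword.
clipped = jnp.where(x > 0.1, x, 0.0)

dev_set = x.devices()    # replaces x.device() when asking "which devices?"
sharding = x.sharding    # replaces x.device() when asking "how is it laid out?"
```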

jaxlib 0.4.21 (Dec 4, 2023)

* Changes
* In preparation for adding distributed CPU support, JAX now treats CPU
devices identically to GPU and TPU devices, that is:

  * `jax.devices()` includes all devices present in a distributed job, even
    those not local to the current process. `jax.local_devices()` still only
    includes devices local to the current process, so if the change to
    `jax.devices()` breaks you, you most likely want to use
    `jax.local_devices()` instead.
  * CPU devices now receive a globally unique ID number within a distributed
    job; previously CPU devices would receive a process-local ID number.
  * The `process_index` of each CPU device will now match any GPU or TPU
    devices within the same process; previously the `process_index` of a CPU
    device was always 0.

* On NVIDIA GPU, JAX now prefers a Jacobi SVD solver for matrices up to
1024x1024. The Jacobi solver appears to be faster than the non-Jacobi version.
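
A small sketch for code that must distinguish local from global devices. In a single-process run both lists coincide, so it mainly shows which call to reach for and what `process_index` now means for CPU devices.

```python
import jax

local = jax.local_devices()   # devices attached to this process only
everything = jax.devices()    # every device in the (possibly distributed) job

# Device IDs are now globally unique, so local IDs are a subset of all IDs.
local_ids = {d.id for d in local}
all_ids = {d.id for d in everything}
```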

* Bug fixes
* Fixed an error or hang when an array with non-finite values is passed to a
non-symmetric eigendecomposition ({jax-issue}`18226`). Arrays with non-finite values now
produce arrays full of NaNs as outputs.

0.4.20

jaxlib 0.4.20 (Nov 2, 2023)

* Bug fixes
* Fixed some type confusion between E4M3 and E5M2 float8 types.

0.4.19

* New Features
* Added {obj}`jax.typing.DTypeLike`, which can be used to annotate objects that
are convertible to JAX dtypes.
* Added `jax.numpy.fill_diagonal`.
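
A minimal sketch of both additions. The helper `as_dtype` is hypothetical and only shows where `DTypeLike` fits in a signature. `jnp.fill_diagonal` is called with `inplace=False` because JAX arrays cannot be modified in place.

```python
import jax
import jax.numpy as jnp
from jax.typing import ArrayLike, DTypeLike


def as_dtype(x: ArrayLike, dtype: DTypeLike) -> jax.Array:
    """Hypothetical helper: `dtype` may be a string, NumPy type, or jnp type."""
    return jnp.asarray(x, dtype=dtype)


# Returns a new array with the diagonal set; the input is left unchanged.
filled = jnp.fill_diagonal(jnp.zeros((3, 3)), 5.0, inplace=False)
```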

* Changes
* JAX now requires SciPy 1.9 or newer.

* Bug fixes
* Only process 0 in a multi-controller distributed JAX program will write
persistent compilation cache entries. This fixes write contention if the
cache is placed on a network file system such as GCS.
* The version check for cusolver and cufft no longer considers the patch
versions when determining if the installed version of these libraries is at
least as new as the versions against which JAX was built.

jaxlib 0.4.19 (Oct 19, 2023)

* Changes
* jaxlib will now always prefer pip-installed NVIDIA CUDA libraries
(nvidia-... packages) over any other CUDA installation if they are
installed, including installations named in `LD_LIBRARY_PATH`. If this
causes problems and the intent is to use a system-installed CUDA, the fix is
to remove the pip-installed CUDA library packages.


© 2024 Safety CLI Cybersecurity Inc. All Rights Reserved.