Jax

Latest version: v0.4.28

0.4.28

* Bug fixes
  * Fixed a memory corruption bug in the type name of Array and JIT Python
    objects in Python 3.10 or earlier.
  * Fixed a warning `'+ptx84' is not a recognized feature for this target`
    under CUDA 12.4.
  * Fixed a slow compilation problem on CPU.

* Changes
  * The Windows build now uses Clang instead of MSVC.

0.4.27

* New Functionality
  * Added {func}`jax.numpy.unstack` and {func}`jax.numpy.cumulative_sum`,
    following their addition in the array API 2023 standard, soon to be
    adopted by NumPy.
  * Added a new config option `jax_cpu_collectives_implementation` to select the
    implementation of cross-process collective operations used by the CPU backend.
    The available choices are `'none'` (default), `'gloo'`, and `'mpi'` (requires jaxlib 0.4.26).
    If set to `'none'`, cross-process collective operations are disabled.

* Changes
  * {func}`jax.pure_callback`, {func}`jax.experimental.io_callback`,
    and {func}`jax.debug.callback` now pass {class}`jax.Array` arguments to the
    callback instead of {class}`np.ndarray`. You can recover the old behavior by
    transforming the arguments via `jax.tree.map(np.asarray, args)` before passing
    them to the callback.
  * `complex_arr.astype(bool)` now follows the same semantics as NumPy, returning
    False where `complex_arr` is equal to `0 + 0j`, and True otherwise.
  * `core.Token` is now a non-trivial class that wraps a `jax.Array`. It can
    be created and threaded into and out of computations to build up dependencies.
    The singleton object `core.token` has been removed; users should now create
    and use fresh `core.Token` objects instead.
  * On GPU, the Threefry PRNG implementation no longer lowers to a kernel call
    by default. This choice can improve runtime memory usage at a compile-time
    cost. The prior behavior, which produces a kernel call, can be recovered with
    `jax.config.update('jax_threefry_gpu_kernel_lowering', True)`. If the new
    default causes issues, please file a bug. Otherwise, we intend to remove
    this flag in a future release.

* Deprecations & Removals
  * Pallas now exclusively uses XLA for compiling kernels on GPU. The old
    lowering pass via Triton Python APIs has been removed and the
    `JAX_TRITON_COMPILE_VIA_XLA` environment variable no longer has any effect.
  * {func}`jax.numpy.clip` has a new argument signature: `a`, `a_min`, and
    `a_max` are deprecated in favor of `x` (positional only), `min`, and
    `max` ({jax-issue}`20550`).
  * The `device()` method of JAX arrays has been removed, after being deprecated
    since JAX v0.4.21. Use `arr.devices()` instead.
  * The `initial` argument to {func}`jax.nn.softmax` and {func}`jax.nn.log_softmax`
    is deprecated; empty inputs to softmax are now supported without setting it.
  * In {func}`jax.jit`, passing invalid `static_argnums` or `static_argnames`
    now leads to an error rather than a warning.
  * The minimum jaxlib version is now 0.4.23.
  * The {func}`jax.numpy.hypot` function now issues a deprecation warning when
    passed complex-valued inputs. This will become an error once the
    deprecation is completed.
  * Scalar arguments to {func}`jax.numpy.nonzero`, {func}`jax.numpy.where`, and
    related functions now raise an error, following a similar change in NumPy.
  * The config option `jax_cpu_enable_gloo_collectives` is deprecated.
    Use `jax.config.update('jax_cpu_collectives_implementation', 'gloo')` instead.
  * The `jax.Array.device_buffer` and `jax.Array.device_buffers` methods have
    been removed after being deprecated in JAX v0.4.22. Instead use
    {attr}`jax.Array.addressable_shards` and {meth}`jax.Array.addressable_data`.
  * The `condition`, `x`, and `y` parameters of `jax.numpy.where` are now
    positional-only, following deprecation of the keywords in JAX v0.4.21.
  * Non-array arguments to functions in {mod}`jax.lax.linalg` now must be
    specified by keyword. Previously, this raised a DeprecationWarning.
  * Array-like arguments are now required in several {mod}`jax.numpy` APIs,
    including {func}`~jax.numpy.apply_along_axis`,
    {func}`~jax.numpy.apply_over_axes`, {func}`~jax.numpy.inner`,
    {func}`~jax.numpy.outer`, {func}`~jax.numpy.cross`,
    {func}`~jax.numpy.kron`, and {func}`~jax.numpy.lexsort`.

* Bug fixes
  * {func}`jax.numpy.astype` will now always return a copy when `copy=True`.
    Previously, no copy would be made when the output array would have the same
    dtype as the input array. This may result in some increased memory usage.
    The default value is set to `copy=False` to preserve backwards compatibility.

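Below is a short migration sketch for a few of the 0.4.27 changes above. It assumes JAX 0.4.27 or newer is installed. The sketch covers the following:

* the new `jnp.unstack` and `jnp.cumulative_sum`;
* the keyword-based `jnp.clip` signature;
* `arr.devices()`;
* complex-to-bool casting;
* a `jax.pure_callback` whose host function converts its `jax.Array` argument to NumPy explicitly.

The helper names `host_sin` and `f` are hypothetical and used only for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np

x = jnp.arange(6.0).reshape(2, 3)

# Array API additions: split along an axis, and a cumulative sum.
rows = jnp.unstack(x, axis=0)           # tuple of two arrays of shape (3,)
totals = jnp.cumulative_sum(x, axis=1)  # running sum along each row

# New jnp.clip signature: `x` is positional-only; bounds are `min` and `max`.
clipped = jnp.clip(x, min=1.0, max=4.0)

# `arr.device()` was removed; `arr.devices()` returns a set of devices.
devs = x.devices()

# Complex -> bool follows NumPy: False only where the value equals 0 + 0j.
flags = jnp.array([0j, 1j, 2 + 0j]).astype(bool)


def host_sin(v):
    # Hypothetical host function. Callbacks now receive jax.Array inputs,
    # so convert to NumPy explicitly when the host code expects np.ndarray.
    v = np.asarray(v)
    return np.sin(v).astype(v.dtype)


@jax.jit
def f(v):
    # Declare the callback's output shape/dtype so it can run under jit.
    out_type = jax.ShapeDtypeStruct(v.shape, v.dtype)
    return jax.pure_callback(host_sin, out_type, v)


y = f(jnp.arange(3.0))
```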
jaxlib 0.4.27 (May 7, 2024)

0.4.26

* New Functionality
  * Added {func}`jax.numpy.trapezoid`, following the addition of this function in
    NumPy 2.0.

* Changes
  * Complex-valued {func}`jax.numpy.geomspace` now chooses the logarithmic spiral
    branch consistent with that of NumPy 2.0.
  * The behavior of `lax.rng_bit_generator`, and in turn the `'rbg'`
    and `'unsafe_rbg'` PRNG implementations, under `jax.vmap`
    [has changed](https://github.com/google/jax/issues/19085) so that
    mapping over keys results in random generation only from the first
    key in the batch.
  * Docs now use `jax.random.key` for construction of PRNG key arrays
    rather than `jax.random.PRNGKey`.

* Deprecations & Removals
  * {func}`jax.tree_map` is deprecated; use `jax.tree.map` instead, or for backward
    compatibility with older JAX versions, use {func}`jax.tree_util.tree_map`.
  * {func}`jax.clear_backends` is deprecated because it does not necessarily do what
    its name suggests and can lead to unexpected consequences; e.g., it will not
    destroy existing backends and release the resources they own. Use
    {func}`jax.clear_caches` if you only want to clean up compilation caches.
    For backward compatibility, or if you really need to switch or reinitialize the
    default backend, use {func}`jax.extend.backend.clear_backends`.
  * The `jax.experimental.maps` module and `jax.experimental.maps.xmap` are
    deprecated. Use `jax.experimental.shard_map` or `jax.vmap` with the
    `spmd_axis_name` argument for expressing SPMD device-parallel computations.
  * The `jax.experimental.host_callback` module is deprecated.
    Use the [new JAX external callbacks](https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html) instead.
    Added the `JAX_HOST_CALLBACK_LEGACY` flag to assist in the transition to the
    new callbacks. See {jax-issue}`20385` for a discussion.
  * Passing arguments to {func}`jax.numpy.array_equal` and {func}`jax.numpy.array_equiv`
    that cannot be converted to a JAX array now results in an exception.
  * The deprecated flag `jax_parallel_functions_output_gda` has been removed.
    This flag was long deprecated and did nothing; its use was a no-op.
  * The previously-deprecated imports `jax.interpreters.ad.config` and
    `jax.interpreters.ad.source_info_util` have now been removed. Use `jax.config`
    and `jax.extend.source_info_util` instead.
  * JAX export no longer supports older serialization versions. Version 9
    has been supported since October 27th, 2023 and has been the default
    since February 1, 2024.
    See [a description of the versions](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions).
    This change could break clients that set a specific
    JAX serialization version lower than 9.

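The following sketch assumes JAX 0.4.26 or newer. It shows the replacements suggested above:

* `jax.tree.map` in place of the deprecated `jax.tree_map`;
* `jax.random.key` for building PRNG keys;
* the new `jnp.trapezoid`;
* `jax.clear_caches` for dropping compilation caches.

The `params` dictionary is a made-up example.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter tree used only for illustration.
params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}

# jax.tree_map is deprecated: use jax.tree.map, or jax.tree_util.tree_map
# when older JAX versions must also be supported.
doubled = jax.tree.map(lambda p: 2 * p, params)

# Typed PRNG key arrays via jax.random.key (instead of jax.random.PRNGKey).
key = jax.random.key(0)
sample = jax.random.normal(key, (3,))

# Trapezoidal rule with unit spacing: (1 + 2) / 2 + (2 + 3) / 2 = 4.0
area = jnp.trapezoid(jnp.array([1.0, 2.0, 3.0]))

# Clears compilation caches only; backends are left untouched.
jax.clear_caches()
```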
jaxlib 0.4.26 (April 3, 2024)

* Changes
  * JAX now supports only CUDA 12.1 or newer. Support for CUDA 11.8 has been
    dropped.
  * JAX now supports NumPy 2.0.

0.4.25

* New Features
  * Added [CUDA Array Interface](https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html)
    import support (requires jaxlib 0.4.24).
  * JAX arrays now support NumPy-style scalar boolean indexing, e.g. `x[True]` or `x[False]`.
  * Added the {mod}`jax.tree` module, with a more convenient interface for referencing functions
    in {mod}`jax.tree_util`.
  * {func}`jax.tree.transpose` (i.e. {func}`jax.tree_util.tree_transpose`) now accepts
    `inner_treedef=None`, in which case the inner treedef will be automatically inferred.

* Changes
  * Pallas now uses XLA instead of the Triton Python APIs to compile Triton
    kernels. You can revert to the old behavior by setting the
    `JAX_TRITON_COMPILE_VIA_XLA` environment variable to `"0"`.
  * Several deprecated APIs in {mod}`jax.interpreters.xla` that were removed in v0.4.24
    have been re-added in v0.4.25, including `backend_specific_translations`,
    `translations`, `register_translation`, `xla_destructure`, `TranslationRule`,
    `TranslationContext`, and `XLAOp`. These are still considered deprecated, and
    will be removed again in the future when better replacements are available.
    Refer to {jax-issue}`19816` for discussion.

* Deprecations & Removals
  * {func}`jax.numpy.linalg.solve` now shows a deprecation warning for batched 1D
    solves with `b.ndim > 1`. In the future these will be treated as batched 2D
    solves.
  * Conversion of a non-scalar array to a Python scalar now raises an error, regardless
    of the size of the array. Previously a deprecation warning was raised in the case of
    non-scalar arrays of size 1. This follows a similar deprecation in NumPy.
  * The previously deprecated configuration APIs have been removed
    following a standard 3-month deprecation cycle (see {ref}`api-compatibility`).
    These include:
    * the `jax.config.config` object and
    * the `define_*_state` and `DEFINE_*` methods of {data}`jax.config`.
  * Importing the `jax.config` submodule via `import jax.config` is deprecated.
    To configure JAX, use `import jax` and then reference the config object
    via `jax.config`.
  * The minimum jaxlib version is now 0.4.20.

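Here is a small sketch of the 0.4.25 additions, assuming JAX 0.4.25 or newer. It covers scalar boolean indexing, the `jax.tree` module, and `jax.tree.transpose` with `inner_treedef=None`. The data values are arbitrary examples.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(3)

# NumPy-style scalar boolean index: adds a leading axis of length 1 (True)
# or 0 (False).
kept = x[True]      # shape (1, 3)
dropped = x[False]  # shape (0, 3)

# jax.tree offers shorter aliases for jax.tree_util functions.
tree = {"a": 1, "b": (2, 3)}
leaves = jax.tree.leaves(tree)  # dict keys are visited in sorted order

# Transpose a list of pairs into a pair of lists; passing None for the
# inner treedef lets JAX infer it from the data.
pairs = [(1, 2), (3, 4)]
outer = jax.tree.structure([0, 0])
swapped = jax.tree.transpose(outer, None, pairs)

# Configure JAX via the `jax.config` object after a plain `import jax`.
jax.config.update("jax_enable_x64", False)
```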
jaxlib 0.4.25 (Feb 26, 2024)

0.4.24

* Changes
  * JAX lowering to StableHLO no longer depends on physical devices.
    If your primitive wraps custom_partitioning or JAX callbacks in the lowering
    rule (i.e., the function passed to the `rule` parameter of `mlir.register_lowering`),
    then add your primitive to the `jax._src.dispatch.prim_requires_devices_during_lowering` set.
    This is needed because custom_partitioning and JAX callbacks need physical
    devices to create `Sharding`s during lowering.
    This is a temporary state until we can create `Sharding`s without physical
    devices.
  * {func}`jax.numpy.argsort` and {func}`jax.numpy.sort` now support the `stable`
    and `descending` arguments.
  * Several changes to the handling of shape polymorphism (used in
    {mod}`jax.experimental.jax2tf` and {mod}`jax.experimental.export`):
    * cleaner pretty-printing of symbolic expressions ({jax-issue}`19227`)
    * added the ability to specify symbolic constraints on the dimension variables.
      This makes shape polymorphism more expressive, and gives a way to work around
      limitations in the reasoning about inequalities.
      See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
    * with the addition of symbolic constraints ({jax-issue}`19235`) we now
      consider dimension variables from different scopes to be different, even
      if they have the same name. Symbolic expressions from different scopes
      cannot interact, e.g., in arithmetic operations.
      Scopes are introduced by {func}`jax.experimental.jax2tf.convert`,
      {func}`jax.experimental.export.symbolic_shape`, and {func}`jax.experimental.export.symbolic_args_specs`.
      The scope of a symbolic expression `e` can be read with `e.scope` and passed
      into the above functions to direct them to construct symbolic expressions in
      a given scope.
      See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
    * simplified and faster equality comparisons, where we consider two symbolic dimensions
      to be equal if the normalized form of their difference reduces to 0
      ({jax-issue}`19231`; note that this may result in user-visible behavior
      changes)
    * improved the error messages for inconclusive inequality comparisons
      ({jax-issue}`19235`).
    * the `core.non_negative_dim` API (introduced recently)
      was deprecated and `core.max_dim` and `core.min_dim` were introduced
      ({jax-issue}`18953`) to express `max` and `min` for symbolic dimensions.
      You can use `core.max_dim(d, 0)` instead of `core.non_negative_dim(d)`.
    * the `shape_poly.is_poly_dim` is deprecated in favor of `export.is_symbolic_dim`
      ({jax-issue}`19282`).
    * the `export.args_specs` is deprecated in favor of `export.symbolic_args_specs`
      ({jax-issue}`19283`).
    * the `shape_poly.PolyShape` and `jax2tf.PolyShape` are deprecated; use
      strings for polymorphic shape specifications ({jax-issue}`19284`).
  * The JAX default native serialization version is now 9. This is relevant
    for {mod}`jax.experimental.jax2tf` and {mod}`jax.experimental.export`.
    See [description of version numbers](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions).
  * Refactored the API for `jax.experimental.export`. Instead of
    `from jax.experimental.export import export` you should now use
    `from jax.experimental import export`. The old way of importing will
    continue to work for a deprecation period of 3 months.
  * Added {func}`jax.scipy.stats.sem`.
  * {func}`jax.numpy.unique` with `return_inverse=True` returns inverse indices
    reshaped to the shape of the input, following a similar change to
    {func}`numpy.unique` in NumPy 2.0.
  * {func}`jax.numpy.sign` now returns `x / abs(x)` for nonzero complex inputs. This is
    consistent with the behavior of {func}`numpy.sign` in NumPy version 2.0.
  * {func}`jax.scipy.special.logsumexp` with `return_sign=True` now uses the NumPy 2.0
    convention for the complex sign, `x / abs(x)`. This is consistent with the behavior
    of {func}`scipy.special.logsumexp` in SciPy v1.13.
  * JAX now supports the bool DLPack type for both import and export.
    Previously, bool values could not be imported and were exported as integers.

* Deprecations & Removals
  * A number of previously deprecated functions have been removed, following a
    standard 3+ month deprecation cycle (see {ref}`api-compatibility`).
    This includes:
    * From {mod}`jax.core`: `TracerArrayConversionError`,
      `TracerIntegerConversionError`, `UnexpectedTracerError`,
      `as_hashable_function`, `collections`, `dtypes`, `lu`, `map`,
      `namedtuple`, `partial`, `pp`, `ref`, `safe_zip`, `safe_map`,
      `source_info_util`, `total_ordering`, `traceback_util`, `tuple_delete`,
      `tuple_insert`, and `zip`.
    * From {mod}`jax.lax`: `dtypes`, `itertools`, `naryop`, `naryop_dtype_rule`,
      `standard_abstract_eval`, `standard_naryop`, `standard_primitive`,
      `standard_unop`, `unop`, and `unop_dtype_rule`.
    * The `jax.linear_util` submodule and all its contents.
    * The `jax.prng` submodule and all its contents.
    * From {mod}`jax.random`: `PRNGKeyArray`, `KeyArray`, `default_prng_impl`,
      `threefry_2x32`, `threefry2x32_key`, `threefry2x32_p`, `rbg_key`, and
      `unsafe_rbg_key`.
    * From {mod}`jax.tree_util`: `register_keypaths`, `AttributeKeyPathEntry`, and
      `GetItemKeyPathEntry`.
    * From {mod}`jax.interpreters.xla`: `backend_specific_translations`, `translations`,
      `register_translation`, `xla_destructure`, `TranslationRule`, `TranslationContext`,
      `axis_groups`, `ShapedArray`, `ConcreteArray`, `AxisEnv`, `backend_compile`,
      and `XLAOp`.
    * From {mod}`jax.numpy`: `NINF`, `NZERO`, `PZERO`, `row_stack`, `issubsctype`,
      `trapz`, and `in1d`.
    * From {mod}`jax.scipy.linalg`: `tril` and `triu`.
  * The previously-deprecated method `PRNGKeyArray.unsafe_raw_array` has been
    removed. Use {func}`jax.random.key_data` instead.
  * `bool(empty_array)` now raises an error rather than returning `False`. This
    previously raised a deprecation warning, and follows a similar change in NumPy.
  * Support for the mhlo MLIR dialect has been deprecated. JAX no longer uses
    the mhlo dialect; it uses stablehlo instead. APIs that refer to "mhlo" will be
    removed in the future. Use the "stablehlo" dialect instead.
  * {mod}`jax.random`: passing batched keys directly to random number generation functions,
    such as {func}`~jax.random.bits`, {func}`~jax.random.gamma`, and others, is deprecated
    and will emit a `FutureWarning`. Use `jax.vmap` for explicit batching.
  * {func}`jax.lax.tie_in` is deprecated: it has been a no-op since JAX v0.2.0.

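This sketch illustrates several 0.4.24 behaviors, assuming JAX 0.4.24 or newer:

* the `stable`/`descending` sorting arguments;
* `jnp.unique` returning inverse indices in the input's shape;
* the NumPy 2.0 `sign` convention for complex inputs;
* `jax.vmap` as the replacement for passing batched keys directly to samplers.

The arrays are arbitrary examples.

```python
import jax
import jax.numpy as jnp

x = jnp.array([3, 1, 2, 1])

# Indices that sort x in descending order; with stable=True, equal values
# keep their original relative order.
order = jnp.argsort(x, stable=True, descending=True)

# Inverse indices now match the input's shape, so vals[inv] rebuilds `a`.
a = jnp.array([[1, 2], [2, 1]])
vals, inv = jnp.unique(a, return_inverse=True)

# For nonzero complex z, sign(z) == z / abs(z): here (3 + 4j) / 5.
s = jnp.sign(jnp.array(3 + 4j))

# Batched keys: map a sampler over the keys with vmap instead of passing
# the whole batch straight to jax.random.normal.
keys = jax.random.split(jax.random.key(0), 4)
draws = jax.vmap(lambda k: jax.random.normal(k, (2,)))(keys)
```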
jaxlib 0.4.24 (Feb 6, 2024)

* Changes
  * JAX now supports CUDA 12.3 and CUDA 11.8. Support for CUDA 12.2 has been
    dropped.
  * `cost_analysis` now works with cross-compiled `Compiled` objects (i.e. when
    using `.lower().compile()` with a topology object, e.g., to compile for
    Cloud TPU from a non-TPU computer).
  * Added [CUDA Array Interface](https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html)
    import support (requires jax 0.4.25).

0.4.23

jaxlib 0.4.23 (Dec 13, 2023)

* Fixed a bug that caused verbose logging from the GPU compiler during
  compilation.
