JAX

Latest version: v0.5.2

0.4.16

* Changes
* Added {class}`jax.numpy.ufunc`, as well as {func}`jax.numpy.frompyfunc`, which can convert
any scalar-valued function into a {func}`numpy.ufunc`-like object, with methods such as
{meth}`~jax.numpy.ufunc.outer`, {meth}`~jax.numpy.ufunc.reduce`,
{meth}`~jax.numpy.ufunc.accumulate`, {meth}`~jax.numpy.ufunc.at`, and
{meth}`~jax.numpy.ufunc.reduceat` ({jax-issue}`17054`). A usage sketch appears after this list.
* Added {func}`jax.scipy.integrate.trapezoid`.
* When not running under IPython: when an exception is raised, JAX now filters out all
of its internal frames from the traceback, without the separate "unfiltered stack trace"
that previously appeared. This should produce much friendlier-looking tracebacks. See
[here](https://github.com/jax-ml/jax/pull/16949) for an example.
This behavior can be changed by setting `JAX_TRACEBACK_FILTERING=remove_frames` (for two
separate unfiltered/filtered tracebacks, which was the old behavior) or
`JAX_TRACEBACK_FILTERING=off` (for one unfiltered traceback).
* jax2tf default serialization version is now 7, which introduces new shape
[safety assertions](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism).
* Devices passed to `jax.sharding.Mesh` should be hashable. This specifically
applies to mock devices and user-created devices; devices returned by `jax.devices()`
are already hashable.
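
A minimal usage sketch of the new ufunc machinery; `scalar_add` is an illustrative function, not something shipped in this release:

```python
import jax.numpy as jnp

# Wrap a scalar-valued Python function as a numpy.ufunc-like object.
# `identity=0` gives reductions a well-defined starting value.
scalar_add = jnp.frompyfunc(lambda x, y: x + y, nin=2, nout=1, identity=0)

x = jnp.arange(4)

print(scalar_add.outer(x, x))      # 4x4 table of pairwise sums
print(scalar_add.reduce(x))        # 0 + 1 + 2 + 3 = 6
print(scalar_add.accumulate(x))    # running sums: [0, 1, 3, 6]
```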

* Breaking changes:
* jax2tf now uses native serialization by default. See
the [jax2tf documentation](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md)
for details and for mechanisms to override the default.
* The option `--jax_coordination_service` has been removed. It is now always
`True`.
* `jax.jaxpr_util` has been removed from the public JAX namespace.
* `JAX_USE_PJRT_C_API_ON_TPU` no longer has an effect (i.e. it always defaults to true).
* The backwards-compatibility flag `--jax_host_callback_ad_transforms`,
introduced in December 2021, has been removed.

* Deprecations:
* Several `jax.numpy` APIs have been deprecated following
[NumPy NEP-52](https://numpy.org/neps/nep-0052-python-api-cleanup.html); replacement
usage is sketched after this list:
* `jax.numpy.NINF` has been deprecated. Use `-jax.numpy.inf` instead.
* `jax.numpy.PZERO` has been deprecated. Use `0.0` instead.
* `jax.numpy.NZERO` has been deprecated. Use `-0.0` instead.
* `jax.numpy.issubsctype(x, t)` has been deprecated. Use `jax.numpy.issubdtype(x.dtype, t)`.
* `jax.numpy.row_stack` has been deprecated. Use `jax.numpy.vstack` instead.
* `jax.numpy.in1d` has been deprecated. Use `jax.numpy.isin` instead.
* `jax.numpy.trapz` has been deprecated. Use `jax.scipy.integrate.trapezoid` instead.
* `jax.scipy.linalg.tril` and `jax.scipy.linalg.triu` have been deprecated,
following SciPy. Use `jax.numpy.tril` and `jax.numpy.triu` instead.
* `jax.lax.prod` has been removed after being deprecated in JAX v0.4.11.
Use the built-in `math.prod` instead.
* A number of exports from `jax.interpreters.xla` related to defining
HLO lowering rules for custom JAX primitives have been deprecated. Custom
primitives should be defined using the StableHLO lowering utilities in
`jax.interpreters.mlir` instead.
* The following previously-deprecated functions have been removed after a
three-month deprecation period:
* `jax.abstract_arrays.ShapedArray`: use `jax.core.ShapedArray`.
* `jax.abstract_arrays.raise_to_shaped`: use `jax.core.raise_to_shaped`.
* `jax.numpy.alltrue`: use `jax.numpy.all`.
* `jax.numpy.sometrue`: use `jax.numpy.any`.
* `jax.numpy.product`: use `jax.numpy.prod`.
* `jax.numpy.cumproduct`: use `jax.numpy.cumprod`.
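
A short sketch of the suggested replacements above, written against current `jax.numpy` and `jax.scipy` (the comments name the deprecated spellings):

```python
import jax.numpy as jnp
from jax.scipy.integrate import trapezoid

x = jnp.linspace(0.0, jnp.pi, 101)

neg_inf = -jnp.inf                                  # instead of jnp.NINF
is_float = jnp.issubdtype(x.dtype, jnp.floating)    # instead of jnp.issubsctype(x, ...)
stacked = jnp.vstack([x, x])                        # instead of jnp.row_stack([x, x])
mask = jnp.isin(jnp.arange(5), jnp.array([1, 3]))   # instead of jnp.in1d(...)
area = trapezoid(jnp.sin(x), x)                     # instead of jnp.trapz(...); ~2.0
```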

* Deprecations/removals:
* The internal submodule `jax.prng` is now deprecated. Its contents are available at
{mod}`jax.extend.random`.
* The internal submodule path `jax.linear_util` has been deprecated. Use
{mod}`jax.extend.linear_util` instead (part of {ref}`jax-extend-jep`).
* `jax.random.PRNGKeyArray` and `jax.random.KeyArray` are deprecated. Use {class}`jax.Array`
for type annotations, and `jax.dtypes.issubdtype(arr.dtype, jax.dtypes.prng_key)` for
runtime detection of typed PRNG keys; see the sketch after this list.
* The method `PRNGKeyArray.unsafe_raw_array` is deprecated. Use
{func}`jax.random.key_data` instead.
* `jax.experimental.pjit.with_sharding_constraint` is deprecated. Use
`jax.lax.with_sharding_constraint` instead.
* The internal utilities `jax.core.is_opaque_dtype` and `jax.core.has_opaque_dtype`
have been removed. Opaque dtypes have been renamed to Extended dtypes; use
`jnp.issubdtype(dtype, jax.dtypes.extended)` instead (available since jax v0.4.14).
* The utility `jax.interpreters.xla.register_collective_primitive` has been
removed. This utility did nothing useful in recent JAX releases and calls
to it can be safely removed.
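
A minimal sketch of the suggested replacements, assuming new-style typed keys created with `jax.random.key`:

```python
import jax

typed_key = jax.random.key(0)       # new-style typed PRNG key (extended dtype)
legacy_key = jax.random.PRNGKey(0)  # legacy uint32 key array

# Type annotations: just use jax.Array.
def split2(k: jax.Array) -> jax.Array:
    return jax.random.split(k, 2)

# Runtime detection of typed PRNG keys, instead of isinstance checks
# against the deprecated PRNGKeyArray / KeyArray names.
assert jax.dtypes.issubdtype(typed_key.dtype, jax.dtypes.prng_key)
assert not jax.dtypes.issubdtype(legacy_key.dtype, jax.dtypes.prng_key)

# Raw key bits, replacing the deprecated PRNGKeyArray.unsafe_raw_array.
raw = jax.random.key_data(typed_key)   # uint32 array, e.g. shape (2,) for threefry
```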

jaxlib 0.4.16 (Sept 18, 2023)

* Changes:
* Sparse CSR matrix multiplications via the experimental jax sparse APIs
no longer use a deterministic algorithm on NVIDIA GPUs. This change was
made to improve compatibility with CUDA 12.2.1.

* Bug fixes:
* Fixed a crash on Windows due to a fatal LLVM error related to out-of-order
sections and IMAGE_REL_AMD64_ADDR32NB relocations
(https://github.com/openxla/xla/commit/cb732a921f0c4184995cbed82394931011d12bd4).

0.4.14

* Changes
* `jax.jit` now takes `donate_argnames` as an argument. Its semantics are similar
to `static_argnames`. If neither `donate_argnums` nor `donate_argnames` is provided,
no arguments are donated. If only one of the two is provided, JAX uses
`inspect.signature(fun)` to find any positional arguments that correspond to
`donate_argnames` (or argument names that correspond to `donate_argnums`). If both
are provided, `inspect.signature` is not used, and only the actual parameters listed
in either `donate_argnums` or `donate_argnames` will be donated. See the sketch after
this list.
* {func}`jax.random.gamma` has been re-factored to a more efficient algorithm
with more robust endpoint behavior ({jax-issue}`16779`). This means that the
sequence of values returned for a given `key` will change between JAX v0.4.13
and v0.4.14 for `gamma` and related samplers (including {func}`jax.random.ball`,
{func}`jax.random.beta`, {func}`jax.random.chisquare`, {func}`jax.random.dirichlet`,
{func}`jax.random.generalized_normal`, {func}`jax.random.loggamma`, {func}`jax.random.t`).
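
A minimal sketch of donation by argument name; `step` is an illustrative function, and note that donation may be ignored (with a warning) on backends that do not support it, such as CPU:

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, donate_argnames="state")
def step(state, update):
    # `state`'s buffer is donated: XLA may reuse its memory for the output,
    # so the original `state` array must not be used after this call.
    return state + update

state = jnp.zeros(1024)
state = step(state, jnp.ones(1024))   # rebind to the returned value
```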

* Deletions
* `in_axis_resources` and `out_axis_resources` have been deleted from pjit, more than
3 months after their deprecation. Please use `in_shardings` and `out_shardings` as the
replacement. This is a safe and trivial name replacement: it does not change any of the
current pjit semantics and doesn't break any code. You can still pass `PartitionSpec`s
to `in_shardings` and `out_shardings`, as in the sketch after this list.
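
A sketch of the rename under an illustrative one-axis mesh (the mesh and array sizes here are arbitrary):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, PartitionSpec as P

ndev = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), axis_names=("x",))

# Formerly pjit(..., in_axis_resources=P("x"), out_axis_resources=P("x")).
double = pjit(lambda a: a * 2, in_shardings=P("x"), out_shardings=P("x"))

with mesh:
    y = double(jnp.arange(8.0 * ndev))   # the single axis is sharded over "x"
```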


* Deprecations
* Python 3.8 support has been dropped as per
https://jax.readthedocs.io/en/latest/deprecation.html
* JAX now requires NumPy 1.22 or newer as per
https://jax.readthedocs.io/en/latest/deprecation.html
* Passing optional arguments to {func}`jax.numpy.ndarray.at` by position is
no longer supported, after being deprecated in JAX version 0.4.7.
For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)`
* The following `jax.Array` methods have been removed, after being deprecated
in JAX v0.4.5:
* `jax.Array.broadcast`: use {func}`jax.lax.broadcast` instead.
* `jax.Array.broadcast_in_dim`: use {func}`jax.lax.broadcast_in_dim` instead.
* `jax.Array.split`: use {func}`jax.numpy.split` instead.
* The following APIs have been removed after previous deprecation:
* `jax.ad`: use {mod}`jax.interpreters.ad`.
* `jax.curry`: use ``curry = lambda f: partial(partial, f)``.
* `jax.partial_eval`: use {mod}`jax.interpreters.partial_eval`.
* `jax.pxla`: use {mod}`jax.interpreters.pxla`.
* `jax.xla`: use {mod}`jax.interpreters.xla`.
* `jax.ShapedArray`: use {class}`jax.core.ShapedArray`.
* `jax.interpreters.pxla.device_put`: use {func}`jax.device_put`.
* `jax.interpreters.pxla.make_sharded_device_array`: use {func}`jax.make_array_from_single_device_arrays`.
* `jax.interpreters.pxla.ShardedDeviceArray`: use {class}`jax.Array`.
* `jax.numpy.DeviceArray`: use {class}`jax.Array`.
* `jax.stages.Compiled.compiler_ir`: use {func}`jax.stages.Compiled.as_text`.

* Breaking changes
* JAX now requires ml_dtypes version 0.2.0 or newer.
* To fix a corner case, calls to {func}`jax.lax.cond` with five
arguments will always resolve to the "common operands" `cond`
behavior (as documented) if the second and third arguments are
callable, even if other operands are callable as well. See
[#16413](https://github.com/jax-ml/jax/issues/16413) and the sketch after this list.
* The deprecated config options `jax_array` and `jax_jit_pjit_api_merge`,
which did nothing, have been removed. These options have been true by
default for many releases.
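
A sketch of the five-argument "common operands" form; `true_fn`/`false_fn` are illustrative names:

```python
import jax
import jax.numpy as jnp

def true_fn(x, y):
    return x + y

def false_fn(x, y):
    return x - y

x = jnp.float32(2.0)
y = jnp.float32(3.0)

# Five arguments: pred, true_fn, false_fn, and two common operands.
# Both branches receive (x, y); per the change above, this resolves to the
# documented "common operands" behavior even if the operands are callable.
out = jax.lax.cond(x > 0, true_fn, false_fn, x, y)   # 5.0
```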

* New features
* JAX now supports a configuration flag `--jax_serialization_version`
and a `JAX_SERIALIZATION_VERSION` environment variable to control the
serialization version ({jax-issue}`16746`); a sketch follows this list.
* jax2tf in presence of shape polymorphism now generates code that checks
certain shape constraints, if the serialization version is at least 7.
See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism.
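
A hedged sketch of combining the new serialization-version flag with a shape-polymorphic jax2tf export; it assumes TensorFlow is installed and that the flag is settable via `jax.config.update` like other JAX flags:

```python
import jax
import jax.numpy as jnp
from jax.experimental import jax2tf
import tensorflow as tf

# Equivalent to --jax_serialization_version=7 or JAX_SERIALIZATION_VERSION=7.
jax.config.update("jax_serialization_version", 7)

def f(x):
    return jnp.sin(x)

# "b" is a symbolic batch dimension; with version >= 7 the generated code
# includes shape safety assertions for it.
f_tf = jax2tf.convert(f, polymorphic_shapes=["(b, 4)"])
print(f_tf(tf.ones((3, 4))))
```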

jaxlib 0.4.14 (July 27, 2023)

* Deprecations
* Python 3.8 support has been dropped as per
https://jax.readthedocs.io/en/latest/deprecation.html

0.4.13

* Changes
* `jax.jit` now allows `None` to be passed to `in_shardings` and
`out_shardings`. The semantics are as follows:
* For in_shardings, JAX will mark it as replicated, but this behavior
can change in the future.
* For out_shardings, we will rely on the XLA GSPMD partitioner to
determine the output shardings.
* `jax.experimental.pjit.pjit` also allows `None` to be passed to
`in_shardings` and `out_shardings`. The semantics are as follows:
* If the mesh context manager is *not* provided, JAX has the freedom to
choose whatever sharding it wants.
* For in_shardings, JAX will mark it as replicated, but this behavior
can change in the future.
* For out_shardings, we will rely on the XLA GSPMD partitioner to
determine the output shardings.
* If the mesh context manager is provided, None will imply that the value
will be replicated on all devices of the mesh.
* `Executable.cost_analysis()` now works on Cloud TPU.
* Added a warning if a non-allowlisted `jaxlib` plugin is in use.
* Added `jax.tree_util.tree_leaves_with_path`; a usage sketch follows this list.
* `None` is not a valid input to
`jax.experimental.multihost_utils.host_local_array_to_global_array` or
`jax.experimental.multihost_utils.global_array_to_host_local_array`.
Please use `jax.sharding.PartitionSpec()` if you want to replicate your
input.
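
A minimal sketch of the new path-aware leaf traversal; the `params` pytree is illustrative:

```python
import jax

params = {"w": jax.numpy.zeros((2, 3)), "stats": {"count": 0}}

# Each leaf is paired with the key path that reaches it.
for path, leaf in jax.tree_util.tree_leaves_with_path(params):
    print(jax.tree_util.keystr(path), getattr(leaf, "shape", leaf))
# ['stats']['count'] 0
# ['w'] (2, 3)
```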

* Bug fixes
* Fixed incorrect wheel name in CUDA 12 releases (16362); the correct wheel
is named `cudnn89` instead of `cudnn88`.

* Deprecations
* The `native_serialization_strict_checks` parameter to
{func}`jax.experimental.jax2tf.convert` is deprecated in favor of the
new `native_serialization_disabled_checks` ({jax-issue}`16347`).

jaxlib 0.4.13 (June 22, 2023)

* Changes
* Added Windows CPU-only wheels to the `jaxlib` PyPI release.

* Bug fixes
* `__cuda_array_interface__` was broken in previous jaxlib versions and is now
fixed ({jax-issue}`16440`).
* Concurrent CUDA kernel tracing is now enabled by default on NVIDIA GPUs.

0.4.12

* Changes
* Added {class}`scipy.spatial.transform.Rotation` and {class}`scipy.spatial.transform.Slerp`
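
A hedged sketch of the new class, assuming it mirrors the corresponding SciPy interface (`from_quat`, `apply`, `as_matrix`):

```python
import jax.numpy as jnp
from jax.scipy.spatial.transform import Rotation

# 90-degree rotation about the z axis, as an (x, y, z, w) quaternion.
r = Rotation.from_quat(jnp.array([0.0, 0.0, jnp.sin(jnp.pi / 4), jnp.cos(jnp.pi / 4)]))

v = jnp.array([1.0, 0.0, 0.0])
print(r.apply(v))      # approximately [0., 1., 0.]
print(r.as_matrix())   # the 3x3 rotation matrix
```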

* Deprecations
* `jax.abstract_arrays` and its contents are now deprecated. See related
functionality in {mod}`jax.core`.
* `jax.numpy.alltrue`: use `jax.numpy.all`. This follows the deprecation
of `numpy.alltrue` in NumPy version 1.25.0.
* `jax.numpy.sometrue`: use `jax.numpy.any`. This follows the deprecation
of `numpy.sometrue` in NumPy version 1.25.0.
* `jax.numpy.product`: use `jax.numpy.prod`. This follows the deprecation
of `numpy.product` in NumPy version 1.25.0.
* `jax.numpy.cumproduct`: use `jax.numpy.cumprod`. This follows the deprecation
of `numpy.cumproduct` in NumPy version 1.25.0.
* `jax.sharding.OpShardingSharding` has been removed, 3 months after its
deprecation.

jaxlib 0.4.12 (June 8, 2023)

* Changes
* Includes PTX/SASS for Hopper (SM version 9.0+) GPUs. Previous
versions of jaxlib should work on Hopper but would have a long
JIT-compilation delay the first time a JAX operation was executed.

* Bug fixes
* Fixes incorrect source line information in JAX-generated Python tracebacks
under Python 3.11.
* Fixes crash when printing local variables of frames in JAX-generated Python
tracebacks (16027).

0.4.11

* Deprecations
* The following APIs have been removed after a 3 month deprecation period, in
accordance with the {ref}`api-compatibility` policy:
* `jax.experimental.PartitionSpec`: use `jax.sharding.PartitionSpec`.
* `jax.experimental.maps.Mesh`: use `jax.sharding.Mesh`
* `jax.experimental.pjit.NamedSharding`: use `jax.sharding.NamedSharding`.
* `jax.experimental.pjit.PartitionSpec`: use `jax.sharding.PartitionSpec`.
* `jax.experimental.pjit.FROM_GDA`. Instead pass sharded `jax.Array` objects
as input and remove the optional `in_shardings` argument to `pjit`.
* `jax.interpreters.pxla.PartitionSpec`: use `jax.sharding.PartitionSpec`.
* `jax.interpreters.pxla.Mesh`: use `jax.sharding.Mesh`
* `jax.interpreters.xla.Buffer`: use `jax.Array`.
* `jax.interpreters.xla.Device`: use `jax.Device`.
* `jax.interpreters.xla.DeviceArray`: use `jax.Array`.
* `jax.interpreters.xla.device_put`: use `jax.device_put`.
* `jax.interpreters.xla.xla_call_p`: use `jax.experimental.pjit.pjit_p`.
* The `axis_resources` argument of `with_sharding_constraint` has been removed. Please
use `shardings` instead.


jaxlib 0.4.11 (May 31, 2023)

* Changes
* Added `memory_stats()` method to `Device`s. If supported, this returns a
dict of string stat names with int values, e.g. `"bytes_in_use"`, or None if
the platform doesn't support memory statistics. The exact stats returned may
vary across platforms. Currently only implemented on Cloud TPU. See the sketch
after this list.
* Readded support for the Python buffer protocol (`memoryview`) on CPU
devices.
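
A minimal sketch of the new method; on platforms without support it simply returns None:

```python
import jax

dev = jax.devices()[0]
stats = dev.memory_stats()          # dict of str -> int, or None if unsupported
if stats is not None:
    print(stats.get("bytes_in_use"))
```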

0.4.10

jaxlib 0.4.10 (May 11, 2023)

* Changes
* Fixed the `'apple-m1' is not a recognized processor for this target (ignoring
processor)` issue that prevented the previous release from running on Mac M1.
