Jax

Latest version: v0.5.2



Page 13 of 19

0.2.21

* [GitHub
commits](https://github.com/jax-ml/jax/compare/jax-v0.2.20...jax-v0.2.21).
* Breaking Changes
  * `jax.api` has been removed. Functions that were available as `jax.api.*`
    were aliases for functions in `jax.*`; please use the functions in
    `jax.*` instead.
  * `jax.partial` and `jax.lax.partial` were accidental exports that have now
    been removed. Use `functools.partial` from the Python standard library
    instead.
  * Boolean scalar indices now raise a `TypeError`; previously this silently
    returned wrong results ({jax-issue}`7925`).
  * Many more `jax.numpy` functions now require array-like inputs and will raise
    an error if passed a list ({jax-issue}`7747` {jax-issue}`7802` {jax-issue}`7907`).
    See {jax-issue}`7737` for a discussion of the rationale behind this change.
  * When inside a transformation such as `jax.jit`, `jax.numpy.array` always
    stages the array it produces into the traced computation. Previously,
    `jax.numpy.array` would sometimes produce an on-device array, even under
    a `jax.jit` decorator. This change may break code that used JAX arrays to
    perform shape or index computations that must be known statically; the
    workaround is to perform such computations using classic NumPy arrays
    instead.
  * `jnp.ndarray` is now a true base class for JAX arrays. In particular, this
    means that for a standard NumPy array `x`, `isinstance(x, jnp.ndarray)` will
    now return `False` ({jax-issue}`7927`).
* New features:
  * Added an implementation of {func}`jax.numpy.insert` ({jax-issue}`7936`).
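
The following is a minimal sketch of several of these changes. It assumes a recent JAX release in which the behaviors described above still hold. The names `add_ten` and `flatten_doubled` and the sample arrays are hypothetical and were chosen only for illustration.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np

# `jax.partial` is gone: use the standard-library functools.partial instead.
add_ten = functools.partial(jnp.add, 10)

# jnp.ndarray is a true base class for JAX arrays, so plain NumPy arrays
# are no longer considered instances of it.
np_x = np.arange(3)
jnp_x = jnp.arange(3)
is_np_jax_array = isinstance(np_x, jnp.ndarray)    # False
is_jnp_jax_array = isinstance(jnp_x, jnp.ndarray)  # True

# jnp.insert follows numpy.insert semantics: insert 10 before index 1.
inserted = jnp.insert(jnp.array([1, 2, 3]), 1, 10)


@jax.jit
def flatten_doubled(x):
    # Under jit, values built with jnp.array(...) are traced, so shape
    # arithmetic that must be static is done with classic NumPy here.
    # x.shape is a concrete tuple of Python ints even inside jit.
    size = int(np.prod(x.shape))
    return jnp.reshape(x, (size,)) * 2


flat = flatten_doubled(jnp.ones((2, 3)))
```

Computing `size` with `jnp` operations instead would yield a traced value inside `jit`, and converting it with `int(...)` would fail. This is why the changelog recommends classic NumPy for static shape computations.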

0.2.20

* [GitHub
commits](https://github.com/jax-ml/jax/compare/jax-v0.2.19...jax-v0.2.20).
* Breaking Changes
  * `jnp.poly*` functions now require array-like inputs ({jax-issue}`7732`).
  * `jnp.unique` and other set-like operations now require array-like inputs
    ({jax-issue}`7662`).
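
As a rough illustration, assuming a current JAX install, list inputs to these functions should be converted explicitly, for example with `jnp.asarray`. The sample values below are hypothetical.

```python
import jax.numpy as jnp

values = [3, 1, 3, 2]
coeffs = [1.0, 2.0]  # represents p(x) = 1.0 * x + 2.0

# Passing a Python list directly is rejected with a TypeError.
try:
    jnp.unique(values)
    unique_list_error = False
except TypeError:
    unique_list_error = True

# Convert lists to arrays before calling set-like or polynomial functions.
uniq = jnp.unique(jnp.asarray(values))       # sorted unique values
val = jnp.polyval(jnp.asarray(coeffs), 3.0)  # 1.0 * 3.0 + 2.0
```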

jaxlib 0.1.71 (Sep 1, 2021)
* Breaking changes:
  * Support for CUDA 11.0 and CUDA 10.1 has been dropped. Jaxlib now supports
    CUDA 10.2 and CUDA 11.1+.

0.2.19

* [GitHub
commits](https://github.com/jax-ml/jax/compare/jax-v0.2.18...jax-v0.2.19).
* Breaking changes:
  * Support for NumPy 1.17 has been dropped, per the
    [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
    Please upgrade to a supported NumPy version.
  * The `jit` decorator has been added around the implementation of a number of
    operators on JAX arrays. This speeds up dispatch times for common
    operators such as `+`.

    This change should be transparent to most users. However, there is
    one known behavioral change: large integer constants may now
    produce an error when passed directly to a JAX operator
    (e.g., `x + 2**40`). The workaround is to cast the constant to an
    explicit type (e.g., `np.float64(2**40)`).
* New features:
  * Improved the support for shape polymorphism in jax2tf for operations that
    need to use a dimension size in array computation, e.g., `jnp.mean`
    ({jax-issue}`7317`).
* Bug fixes:
  * Fixed some leaked trace errors from the previous release ({jax-issue}`7613`).
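
The following is a minimal sketch of the large-constant behavior. It assumes JAX's default configuration with 64-bit mode disabled, so the default integer type is `int32`. With `jax_enable_x64` turned on, the first addition would not overflow.

```python
import jax.numpy as jnp
import numpy as np

x = jnp.arange(3, dtype=jnp.int32)

# Assuming 64-bit mode is disabled (the default), the Python int 2**40
# does not fit in int32, so adding it directly is rejected.
try:
    x + 2**40
    overflowed = False
except OverflowError:
    overflowed = True

# Workaround from the changelog: give the constant an explicit type.
# int32 combined with float64 promotes to a floating dtype (float32 when
# 64-bit mode is disabled).
y = x + np.float64(2**40)
```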

jaxlib 0.1.70 (Aug 9, 2021)
* Breaking changes:
  * Support for Python 3.6 has been dropped, per the
    [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
    Please upgrade to a supported Python version.
  * Support for NumPy 1.17 has been dropped, per the
    [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
    Please upgrade to a supported NumPy version.

* The host_callback mechanism now uses one thread per local device for
  making the calls to the Python callbacks. Previously there was a single
  thread for all devices. This means that callbacks from different devices
  may now be called interleaved with one another. The callbacks corresponding
  to one device will still be called in sequence.

0.2.18

* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.17...jax-v0.2.18).

* Breaking changes:
  * Support for Python 3.6 has been dropped, per the
    [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
    Please upgrade to a supported Python version.
  * The minimum jaxlib version is now 0.1.69.
  * The `backend` argument to {py:func}`jax.dlpack.from_dlpack` has been
    removed.

* New features:
  * Added a polar decomposition ({py:func}`jax.scipy.linalg.polar`).

* Bug fixes:
  * Tightened the checks for `lax.argmin` and `lax.argmax` to ensure they are
    not used with an invalid `axis` value or with an empty reduction dimension
    ({jax-issue}`7196`).
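
The following short sketch exercises the new polar decomposition and the stricter `lax.argmin` checks, assuming a recent JAX release. The input matrices are arbitrary examples. The exact exception type for an invalid axis is not stated above, so the sketch catches both `ValueError` and `TypeError`.

```python
import jax.numpy as jnp
from jax import lax
from jax.scipy.linalg import polar

a = jnp.array([[2.0, 1.0], [0.0, 3.0]])

# Right polar decomposition (the default side): a = u @ p, where u is
# orthogonal and p is symmetric positive semi-definite.
u, p = polar(a)

# lax.argmin takes an explicit axis and an index dtype.
x = jnp.array([[4.0, 1.0, 3.0], [0.0, 5.0, 2.0]])
row_argmin = lax.argmin(x, 1, jnp.int32)

# An out-of-range axis (x has only axes 0 and 1) is rejected with an error.
try:
    lax.argmin(x, 2, jnp.int32)
    bad_axis_rejected = False
except (ValueError, TypeError):
    bad_axis_rejected = True
```

Because `side` defaults to `'right'`, the factors satisfy `a` approximately equal to `u @ p`. Passing `side='left'` would instead give `a` approximately equal to `p @ u`.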


jaxlib 0.1.69 (July 9, 2021)
* Fixed bugs in the TFRT CPU backend that resulted in incorrect results.

0.2.17

* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.16...jax-v0.2.17).
* Bug fixes:
  * Default to the older "stream_executor" CPU runtime for jaxlib <= 0.1.68
    to work around {jax-issue}`7229`, which caused wrong outputs on CPU due
    to a concurrency problem.
* New features:
  * New SciPy function {py:func}`jax.scipy.special.sph_harm`.
  * Reverse-mode autodiff functions ({func}`jax.grad`,
    {func}`jax.value_and_grad`, {func}`jax.vjp`, and
    {func}`jax.linear_transpose`) support a parameter that indicates which named
    axes should be summed over in the backward pass if they were broadcast
    over in the forward pass. This enables use of these APIs in a
    non-per-example way inside maps (initially only
    {func}`jax.experimental.maps.xmap`) ({jax-issue}`6950`).

0.2.16

* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.15...jax-v0.2.16).


© 2025 Safety CLI Cybersecurity Inc. All Rights Reserved.