JAX

Latest version: v0.4.35


Page 12 of 19

0.2.22

* [GitHub
commits](https://github.com/jax-ml/jax/compare/jax-v0.2.21...jax-v0.2.22).
* Breaking Changes
* Static arguments to `jax.pmap` must now be hashable.

Unhashable static arguments have long been disallowed on `jax.jit`, but they
were still permitted on `jax.pmap`; `jax.pmap` compared unhashable static
arguments using object identity.

This behavior is a footgun, since comparing arguments using
object identity leads to recompilation each time the object identity
changes. Instead, we now ban unhashable arguments: if a user of `jax.pmap`
wants to compare static arguments by object identity, they can define
`__hash__` and `__eq__` methods on their objects that do that, or wrap their
objects in an object that has those operations with object identity
semantics. Another option is to use `functools.partial` to encapsulate the
unhashable static arguments into the function object.
* `jax.util.partial` was an accidental export that has now been removed. Use
`functools.partial` from the Python standard library instead.
* Deprecations
* The functions `jax.ops.index_update`, `jax.ops.index_add` etc. are
deprecated and will be removed in a future JAX release. Please use
[the `.at` property on JAX arrays](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html)
instead, e.g., `x.at[idx].set(y)`. For now, these functions produce a
`DeprecationWarning`.
* New features:
* An optimized C++ code-path improving the dispatch time for `pmap` is now the
default when using jaxlib 0.1.72 or newer. The feature can be disabled using
the `--experimental_cpp_pmap` flag (or `JAX_CPP_PMAP` environment variable).
* `jax.numpy.unique` now supports an optional `fill_value` argument ({jax-issue}`8121`)
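
The two workarounds described above for unhashable `jax.pmap` static arguments can be sketched in plain Python. The names `IdentityKey`, `config`, and `scaled` are hypothetical and not part of JAX; the sketch only illustrates the hashing and equality semantics and does not call `jax.pmap` itself.

```python
import functools


class IdentityKey:
    """Hypothetical wrapper: hashable, and equal only to wrappers of the same object."""

    def __init__(self, value):
        self.value = value

    def __hash__(self):
        # Hash by the identity of the wrapped object, not by its contents.
        return id(self.value)

    def __eq__(self, other):
        return isinstance(other, IdentityKey) and self.value is other.value


config = {"scale": 2.0}             # a dict is unhashable on its own
key_a = IdentityKey(config)
key_b = IdentityKey(config)         # same underlying object as key_a
key_c = IdentityKey(dict(config))   # equal contents, but a different object


def scaled(cfg, x):
    return cfg["scale"] * x


# Alternative: bind the unhashable argument into the function object, so it is
# no longer passed as a static argument at all.
scaled_with_config = functools.partial(scaled, config)
```

With identity semantics, `key_a` and `key_b` compare equal while `key_c` does not, so a cache keyed on such a wrapper would miss whenever a new object is passed, even one with equal contents.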

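As an illustration of the `.at` property recommended in the deprecation note above, here is a minimal sketch of functional index updates. The array values are made up for the example, and the comments showing the deprecated spellings are for orientation only.

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# Deprecated spelling: jax.ops.index_update(x, 2, 7.0)
y = x.at[2].set(7.0)

# Deprecated spelling: jax.ops.index_add(y, 2, 1.0)
z = y.at[2].add(1.0)

# Slices work the same way; the original array x is never modified in place.
w = x.at[1:3].set(jnp.array([1.0, 2.0]))
```

Each call returns a new array, so `x` still holds zeros after all three updates.
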
jaxlib 0.1.72 (Oct 12, 2021)
* Breaking changes:
* Support for CUDA 10.2 and CUDA 10.1 has been dropped. Jaxlib now supports
CUDA 11.1+.
* Bug fixes:
* Fixes https://github.com/jax-ml/jax/issues/7461, which caused wrong
outputs on all platforms due to incorrect buffer aliasing inside the XLA
compiler.

0.2.21

* [GitHub
commits](https://github.com/jax-ml/jax/compare/jax-v0.2.20...jax-v0.2.21).
* Breaking Changes
* `jax.api` has been removed. Functions that were available as `jax.api.*`
were aliases for functions in `jax.*`; please use the functions in
`jax.*` instead.
* `jax.partial` and `jax.lax.partial` were accidental exports that have now
been removed. Use `functools.partial` from the Python standard library
instead.
* Boolean scalar indices now raise a `TypeError`; previously this silently
returned wrong results ({jax-issue}`7925`).
* Many more `jax.numpy` functions now require array-like inputs, and will error
if passed a list ({jax-issue}`7747` {jax-issue}`7802` {jax-issue}`7907`).
See {jax-issue}`7737` for a discussion of the rationale behind this change.
* When inside a transformation such as `jax.jit`, `jax.numpy.array` always
stages the array it produces into the traced computation. Previously
`jax.numpy.array` would sometimes produce an on-device array, even under
a `jax.jit` decorator. This change may break code that used JAX arrays to
perform shape or index computations that must be known statically; the
workaround is to perform such computations using classic NumPy arrays
instead.
* `jnp.ndarray` is now a true base-class for JAX arrays. In particular, this
means that for a standard numpy array `x`, `isinstance(x, jnp.ndarray)` will
now return `False` ({jax-issue}`7927`).
* New features:
* Added {func}`jax.numpy.insert` implementation ({jax-issue}`7936`).
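
For the `jax.numpy.array` staging change described above, the suggested workaround of doing static shape arithmetic with classic NumPy might look like the following sketch. The function name `flatten` is hypothetical and chosen only for this example.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def flatten(x):
    # x.shape is a tuple of Python ints, so NumPy produces a concrete value
    # at trace time that can be used as a static shape.
    n = int(np.prod(x.shape))
    # Computing n with jnp.prod(jnp.array(x.shape)) would instead stage the
    # value into the traced computation, where it cannot serve as a shape.
    return jnp.reshape(x, (n,))


flat = flatten(jnp.ones((2, 3)))
```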

0.2.20

* [GitHub
commits](https://github.com/jax-ml/jax/compare/jax-v0.2.19...jax-v0.2.20).
* Breaking Changes
* `jnp.poly*` functions now require array-like inputs ({jax-issue}`7732`)
* `jnp.unique` and other set-like operations now require array-like inputs
({jax-issue}`7662`)
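
A minimal sketch of adapting to the array-like requirement: convert Python lists explicitly before calling the function. The variable names are hypothetical, and the `try` block only records whether the installed release accepts a raw list rather than assuming a particular outcome.

```python
import jax.numpy as jnp

values = [3, 1, 3, 2]

# On releases that include this change, passing the list directly is expected
# to raise a TypeError; record the outcome instead of relying on it.
try:
    jnp.unique(values)
    list_accepted = True
except TypeError:
    list_accepted = False

# Converting to an array first works regardless.
uniq = jnp.unique(jnp.array(values))
```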

jaxlib 0.1.71 (Sep 1, 2021)
* Breaking changes:
* Support for CUDA 11.0 and CUDA 10.1 has been dropped. Jaxlib now supports
CUDA 10.2 and CUDA 11.1+.

0.2.19

* [GitHub
commits](https://github.com/jax-ml/jax/compare/jax-v0.2.18...jax-v0.2.19).
* Breaking changes:
* Support for NumPy 1.17 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
Please upgrade to a supported NumPy version.
* The `jit` decorator has been added around the implementation of a number of
operators on JAX arrays. This speeds up dispatch times for common
operators such as `+`.

This change should largely be transparent to most users. However, there is
one known behavioral change, which is that large integer constants may now
produce an error when passed directly to a JAX operator
(e.g., `x + 2**40`). The workaround is to cast the constant to an
explicit type (e.g., `np.float64(2**40)`).
* New features:
* Improved the support for shape polymorphism in jax2tf for operations that
need to use a dimension size in array computation, e.g., `jnp.mean`.
({jax-issue}`7317`)
* Bug fixes:
* Fixed some leaked trace errors from the previous release ({jax-issue}`7613`)
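
The large-integer workaround mentioned in the breaking changes above can be sketched as follows. The array `x` and the name `big` are made up for the example; whether the uncast expression actually fails depends on the installed release and its default integer width, so it is left as a comment.

```python
import jax.numpy as jnp
import numpy as np

x = jnp.ones(3, dtype=jnp.float32)
big = 2**40

# `x + big` may raise an error because the Python int does not fit in the
# default 32-bit integer type. Casting to an explicit type avoids this:
y = x + np.float64(big)
```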

jaxlib 0.1.70 (Aug 9, 2021)
* Breaking changes:
* Support for Python 3.6 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
Please upgrade to a supported Python version.
* Support for NumPy 1.17 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
Please upgrade to a supported NumPy version.

* The host_callback mechanism now uses one thread per local device for
making the calls to the Python callbacks. Previously there was a single
thread for all devices. This means that the callbacks may now be called
interleaved. The callbacks corresponding to one device will still be
called in sequence.

0.2.18

* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.17...jax-v0.2.18).

* Breaking changes:
* Support for Python 3.6 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
Please upgrade to a supported Python version.
* The minimum jaxlib version is now 0.1.69.
* The `backend` argument to {py:func}`jax.dlpack.from_dlpack` has been
removed.

* New features:
* Added a polar decomposition ({py:func}`jax.scipy.linalg.polar`).

* Bug fixes:
* Tightened the checks for `lax.argmin` and `lax.argmax` to ensure they are
not used with an invalid `axis` value, or with an empty reduction dimension.
({jax-issue}`7196`)
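
A short usage sketch for the polar decomposition added in this release. The input matrix is arbitrary, and the sketch assumes the default `side="right"` convention, under which the input factors as `u @ p`.

```python
import jax.numpy as jnp
from jax.scipy.linalg import polar

a = jnp.array([[1.0, 2.0], [3.0, 4.0]])

# u has orthonormal columns and p is positive semidefinite.
u, p = polar(a)

# Both errors should be close to zero, up to floating-point precision.
reconstruction_error = float(jnp.max(jnp.abs(u @ p - a)))
orthogonality_error = float(jnp.max(jnp.abs(u.T @ u - jnp.eye(2))))
```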


jaxlib 0.1.69 (July 9, 2021)
* Fixed bugs in the TFRT CPU backend that produced incorrect results.

0.2.17

* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.16...jax-v0.2.17).
* Bug fixes:
* Default to the older "stream_executor" CPU runtime for jaxlib <= 0.1.68
to work around {jax-issue}`7229`, which caused wrong outputs on CPU due to a concurrency
problem.
* New features:
* New SciPy function {py:func}`jax.scipy.special.sph_harm`.
* Reverse-mode autodiff functions ({func}`jax.grad`,
{func}`jax.value_and_grad`, {func}`jax.vjp`, and
{func}`jax.linear_transpose`) support a parameter that indicates which named
axes should be summed over in the backward pass if they were broadcasted
over in the forward pass. This enables use of these APIs in a
non-per-example way inside maps (initially only
{func}`jax.experimental.maps.xmap`) ({jax-issue}`6950`).
