Jax

Latest version: v0.4.35

0.4.4

* Changes
  * The implementations of `jit` and `pjit` have been merged. The merge changes
    JAX's internals without affecting its public API.
    Previously, `jit` was a final-style primitive: the creation of a jaxpr was
    delayed as long as possible, and transformations were stacked on top of each
    other. After the merge, `jit` becomes an initial-style primitive, which means
    that functions are traced to a jaxpr as early as possible. For more information, see
    [this section in autodidax](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing).
    Moving to initial style should simplify JAX's internals and make it easier
    to develop features such as dynamic shapes.
    The merge can be disabled only through the environment variable
    `JAX_JIT_PJIT_API_MERGE`, e.g. `os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'`.
    An environment variable is required because the merge takes effect when JAX
    is imported, so it must be disabled before `jax` is imported.
  * The `axis_resources` argument of `with_sharding_constraint` is deprecated;
    use `shardings` instead. No change is needed if you pass `axis_resources`
    positionally. If you pass it as a keyword argument, rename it to `shardings`.
    `axis_resources` will be removed 3 months after Feb 13, 2023.
  * Added the {mod}`jax.typing` module, with tools for type annotations of JAX
    functions.
  * The following names have been deprecated:
    * `jax.xla.Device` and `jax.interpreters.xla.Device`: use `jax.Device`.
    * `jax.experimental.maps.Mesh`: use `jax.sharding.Mesh`.
    * `jax.experimental.pjit.NamedSharding`: use `jax.sharding.NamedSharding`.
    * `jax.experimental.pjit.PartitionSpec`: use `jax.sharding.PartitionSpec`.
    * `jax.interpreters.pxla.Mesh`: use `jax.sharding.Mesh`.
    * `jax.interpreters.pxla.PartitionSpec`: use `jax.sharding.PartitionSpec`.
* Breaking Changes
  * The `initial` argument to reduction functions like {func}`jax.numpy.sum`
    is now required to be a scalar, consistent with the corresponding NumPy API.
    The previous behavior of broadcasting the output against non-scalar `initial`
    values was an unintentional implementation detail ({jax-issue}`14446`).
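
The replacement endpoints and the `shardings` keyword above can be combined as follows. This is a minimal sketch, assuming a JAX version in which `jax.sharding` and `jax.typing` are available; the function `scaled_sum` and the one-axis mesh named `"x"` are hypothetical names chosen for illustration.

```python
import numpy as np

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from jax.typing import ArrayLike

# One-axis mesh over all available devices; a single CPU device also works.
mesh = Mesh(np.array(jax.devices()), ("x",))
# An empty PartitionSpec means the value is replicated on every device.
replicated = NamedSharding(mesh, PartitionSpec())


@jax.jit
def scaled_sum(x: ArrayLike) -> jax.Array:
    # Pass the sharding via the `shardings` keyword, not the deprecated `axis_resources`.
    y = jax.lax.with_sharding_constraint(jnp.asarray(x) * 2, shardings=replicated)
    # `initial` must be a scalar; it is the starting value of the reduction.
    return jnp.sum(y, initial=10)


result = scaled_sum(jnp.arange(4.0))
```

Here `result` is a scalar `jax.Array` equal to `2 * (0 + 1 + 2 + 3) + 10 = 22`. Passing a non-scalar `initial` is exactly the pattern the breaking change disallows.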

jaxlib 0.4.4 (Feb 16, 2023)
* Breaking changes
  * Support for NVIDIA Kepler series GPUs has been removed from the default
    `jaxlib` builds. If Kepler support is needed, it is still possible to
    build `jaxlib` from source with Kepler support (via the
    `--cuda_compute_capabilities=sm_35` option to `build.py`); note, however,
    that CUDA 12 has dropped support for Kepler GPUs entirely.

0.4.3

* Breaking changes
  * Deleted {func}`jax.scipy.linalg.polar_unitary`, which was a deprecated JAX
    extension to the scipy API. Use {func}`jax.scipy.linalg.polar` instead.

* Changes
  * Added {func}`jax.scipy.stats.rankdata`.
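
Code that relied on the deleted `polar_unitary` can switch to `polar`, and the new `rankdata` follows `scipy.stats.rankdata`. A minimal sketch, assuming JAX's default `side="right"` for `polar` and the default `method="average"` tie handling for `rankdata`; the input values are arbitrary examples.

```python
import jax.numpy as jnp
from jax.scipy.linalg import polar
from jax.scipy.stats import rankdata

a = jnp.array([[1.0, 2.0], [3.0, 4.0]])

# `polar` returns (unitary, positive semi-definite); with side="right", a = u @ p.
# Take the first element when only the unitary factor is needed.
u, p = polar(a)

# Ties receive the average of the ranks they span.
ranks = rankdata(jnp.array([10, 30, 20, 20]))
```

With this input, the two tied values of 20 occupy ranks 2 and 3, so both receive rank 2.5.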

jaxlib 0.4.3 (Feb 8, 2023)
* `jax.Array` now has the non-blocking `is_ready()` method, which returns `True`
if the array is ready (see also {func}`jax.block_until_ready`).
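
A short sketch of polling versus blocking, assuming a default JAX install. Because JAX dispatches work asynchronously, whether the first `is_ready()` call returns `False` depends on how quickly the backend finishes, so no output is shown for it.

```python
import jax.numpy as jnp

x = jnp.ones((512, 512))
y = x @ x  # dispatched asynchronously; may still be running

# Non-blocking poll: may print False while the computation is in flight.
print(y.is_ready())

# Wait for the result; afterwards is_ready() returns True.
y = y.block_until_ready()
```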

0.4.2

* Breaking changes
  * Deleted `jax.experimental.callback`.
  * Operations with dimensions in the presence of jax2tf shape polymorphism have
    been generalized to work in more scenarios, by converting the symbolic
    dimension to JAX arrays. Operations involving symbolic dimensions and
    `np.ndarray` can now raise errors when the result is used as a shape value
    ({jax-issue}`14106`).
  * jaxpr objects now raise an error on attribute setting in order to avoid
    problematic mutations ({jax-issue}`14102`).

* Changes
  * {func}`jax2tf.call_tf` has a new parameter `has_side_effects` (default `True`)
    that can be used to declare whether an instance can be removed or replicated
    by JAX optimizations such as dead-code elimination ({jax-issue}`13980`).
  * Added more support for floordiv and mod for jax2tf shape polymorphism. Previously,
    certain division operations resulted in errors in the presence of symbolic dimensions
    ({jax-issue}`14108`).
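
The jaxpr immutability change can be observed directly. This sketch assumes the error raised is an `AttributeError` (the entry above only says "an error"); the traced lambda is an arbitrary example.

```python
import jax
import jax.numpy as jnp

closed = jax.make_jaxpr(lambda x: jnp.sin(x) * 2.0)(1.0)
jaxpr = closed.jaxpr  # the core Jaxpr wrapped by the ClosedJaxpr

try:
    # Attempting to replace the equation list in place is rejected.
    jaxpr.eqns = []
    mutation_allowed = True
except AttributeError:
    mutation_allowed = False

print(mutation_allowed)
```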

jaxlib 0.4.2 (Jan 24, 2023)

* Changes
  * Set `JAX_USE_PJRT_C_API_ON_TPU=1` to enable the new Cloud TPU runtime, which features
    automatic device memory defragmentation.

0.4.1

* Changes
  * Support for Python 3.7 has been dropped, in accordance with JAX's
    {ref}`version-support-policy`.
  * We introduce `jax.Array`, a unified array type that subsumes the
    `DeviceArray`, `ShardedDeviceArray`, and `GlobalDeviceArray` types in JAX.
    The `jax.Array` type helps make parallelism a core feature of JAX,
    simplifies and unifies JAX internals, and allows us to unify `jit` and
    `pjit`. `jax.Array` is enabled by default in JAX 0.4 and makes some
    breaking changes to the `pjit` API. The [jax.Array migration
    guide](https://jax.readthedocs.io/en/latest/jax_array_migration.html) can
    help you migrate your codebase to `jax.Array`. You can also look at the
    [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html)
    tutorial to understand the new concepts.
  * `PartitionSpec` and `Mesh` are now out of experimental. The new API endpoints
    are `jax.sharding.PartitionSpec` and `jax.sharding.Mesh`.
    `jax.experimental.maps.Mesh` and `jax.experimental.PartitionSpec` are
    deprecated and will be removed in 3 months.
  * The new public endpoint for `with_sharding_constraint` is
    `jax.lax.with_sharding_constraint`.
  * If using ABSL flags together with `jax.config`, the ABSL flag values are no
    longer read or written after the JAX configuration options are initially
    populated from the ABSL flags. This change improves the performance of reading
    `jax.config` options, which are used pervasively in JAX.
  * The `jax2tf.call_tf` function now uses, for TF lowering, the first TF
    device of the same platform as the embedding JAX computation.
    Previously, it used the 0th device of the JAX default backend.
  * A number of `jax.numpy` functions now have their arguments marked as
    positional-only, matching NumPy.
  * `jnp.msort` is now deprecated, following the deprecation of `np.msort` in NumPy 1.24.
    It will be removed in a future release, in accordance with the {ref}`api-compatibility`
    policy. It can be replaced with `jnp.sort(a, axis=0)`.
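
A minimal sketch of the unified array type, the non-experimental sharding endpoints, and the `msort` replacement, assuming a JAX version that provides `jax.sharding` (0.4.1 or later); the mesh axis name `"x"` and the example arrays are arbitrary.

```python
import numpy as np

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Arrays produced by jax.numpy are instances of the unified jax.Array type.
x = jnp.arange(8.0).reshape(4, 2)
assert isinstance(x, jax.Array)

# Place the array on a mesh via the non-experimental endpoints (replicated here).
mesh = Mesh(np.array(jax.devices()), ("x",))
y = jax.device_put(x, NamedSharding(mesh, PartitionSpec()))

# Replacement for the deprecated jnp.msort: sort along the first axis.
a = jnp.array([[3, 1], [1, 2], [2, 0]])
sorted_cols = jnp.sort(a, axis=0)
```

Sorting with `axis=0` sorts each column independently, which is the behavior `msort` provided.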

jaxlib 0.4.1 (Dec 13, 2022)

* Changes
  * Support for Python 3.7 has been dropped, in accordance with JAX's
    {ref}`version-support-policy`.
  * The behavior of `XLA_PYTHON_CLIENT_MEM_FRACTION=.XX` has changed: it now
    preallocates XX% of the total GPU memory, instead of computing the
    preallocation from the currently available GPU memory. Please refer to
    [GPU memory allocation](https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html) for
    more details.
  * The deprecated method `.block_host_until_ready()` has been removed. Use
    `.block_until_ready()` instead.
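
A minimal sketch of both changes, assuming the variable is set before JAX initializes its backend (setting it after `jax` has already created a GPU client has no effect). The fraction `.50` is an arbitrary example value, and the variable only matters on GPU installs.

```python
import os

# Must be set before JAX initializes its GPU backend. Preallocates 50% of the
# *total* GPU memory; it does not affect CPU-only installs.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"

import jax.numpy as jnp  # noqa: E402  (imported after setting the variable)

x = jnp.arange(10.0)
# `.block_until_ready()` replaces the removed `.block_host_until_ready()`.
total = (x * x).sum().block_until_ready()
```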

0.4.0

* The release was yanked.

jaxlib 0.4.0 (Dec 12, 2022)

* The release was yanked.

0.3.25

* Changes
  * {func}`jax.numpy.linalg.pinv` now supports the `hermitian` option.
  * {func}`jax.scipy.linalg.hessenberg` is now supported on CPU only. Requires
    jaxlib > 0.3.24.
  * New functions {func}`jax.lax.linalg.hessenberg`,
    {func}`jax.lax.linalg.tridiagonal`, and
    {func}`jax.lax.linalg.householder_product` were added. Householder reduction
    is currently CPU-only, and tridiagonal reductions are supported on CPU and
    GPU only.
  * The gradients of `svd` and `jax.numpy.linalg.pinv` are now computed more
    economically for non-square matrices.
* Breaking Changes
  * Deleted the `jax_experimental_name_stack` config option.
  * A string `axis_names` argument to the
    {class}`jax.experimental.maps.Mesh` constructor is now converted into a singleton tuple
    instead of being unpacked into a sequence of single-character axis names.
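
The new linear-algebra options and the `Mesh` argument change can be checked with a short sketch. It assumes a CPU backend (for `hessenberg`) and uses `jax.sharding.Mesh`, the class's later non-experimental endpoint; the matrices and the axis name `"x"` are arbitrary examples.

```python
import numpy as np

import jax
import jax.numpy as jnp
from jax.scipy.linalg import hessenberg
from jax.sharding import Mesh

# pinv with hermitian=True: valid only for Hermitian (here, real symmetric) input.
sym = jnp.array([[2.0, 1.0], [1.0, 2.0]])
sym_pinv = jnp.linalg.pinv(sym, hermitian=True)

# Upper Hessenberg decomposition a = q @ h @ q.T (CPU-only when introduced).
a = jnp.array([[4.0, 1.0, 2.0], [3.0, 5.0, 1.0], [2.0, 6.0, 7.0]])
h, q = hessenberg(a, calc_q=True)

# A string axis name becomes a one-element tuple instead of being split into characters.
mesh = Mesh(np.array(jax.devices()), "x")
```

Because `sym` is invertible, its pseudo-inverse coincides with its ordinary inverse, and `h` has zeros below its first subdiagonal.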

jaxlib 0.3.25 (Nov 15, 2022)
* Changes
  * Added support for tridiagonal reductions on CPU and GPU.
  * Added support for upper Hessenberg reductions on CPU.
* Bugs
  * Fixed a bug that caused frames in tracebacks captured by JAX to be
    incorrectly mapped to source lines under Python 3.10+.
