Jax

Latest version: v0.4.29

Page 10 of 18

0.3.0

* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.28...jax-v0.3.0).

* Changes
* jax version has been bumped to 0.3.0. Please see the [design doc](https://jax.readthedocs.io/en/latest/design_notes/jax_versioning.html)
for the explanation.

jaxlib 0.3.0 (Feb 10, 2022)
* Changes
* Bazel 5.0.0 is now required to build jaxlib.
* jaxlib version has been bumped to 0.3.0. Please see the [design doc](https://jax.readthedocs.io/en/latest/design_notes/jax_versioning.html)
for the explanation.

0.2.28

* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.27...jax-v0.2.28).
* `jax.jit(f).lower(...).compiler_ir()` now defaults to the MHLO dialect if no
`dialect=` is passed.
* `jax.jit(f).lower(...).compiler_ir(dialect='mhlo')` now returns an MLIR
`ir.Module` object instead of its string representation.
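A minimal sketch of inspecting the lowered IR, assuming a reasonably recent JAX install. The function `f` and its input are hypothetical. On the release described here the default dialect was MHLO; later releases may emit a different dialect, so the sketch omits `dialect=` and only converts the returned module object to text.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical function used only for illustration.
    return jnp.sin(x) * 2.0

# Lower (but do not run) the jitted function for a concrete input shape/dtype.
lowered = jax.jit(f).lower(jnp.ones((4,), jnp.float32))

# compiler_ir() returns an MLIR module object rather than a string.
module = lowered.compiler_ir()
module_text = str(module)

print(type(module).__name__)
print(module_text)
```

Converting with `str(...)` recovers the textual form that earlier releases returned directly.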

jaxlib 0.1.76 (Jan 27, 2022)

* New features
* Includes precompiled SASS for NVIDIA compute capability 8.0 GPUs
(e.g. A100). Removes precompiled SASS for compute capability 6.1 so that the
number of precompiled compute capabilities does not grow: GPUs with compute
capability 6.1 can use the 6.0 SASS.
* With jaxlib 0.1.76, JAX uses the MHLO MLIR dialect as its primary target compiler IR
by default.
* Breaking changes
* Support for NumPy 1.18 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
Please upgrade to a supported NumPy version.
* Bug fixes
* Fixed a bug where apparently identical pytreedef objects constructed by different routes
did not compare as equal (#9066).
* The JAX jit cache now requires two static arguments to have identical types for a cache hit (#9311).
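The static-argument rule can be observed by counting traces. The sketch below is an illustration under assumptions, not the project's own test: `scale_by` and `trace_log` are hypothetical names, and the Python side effect inside the function runs only when JAX traces it. Because `1` and `1.0` compare equal in Python but have different types, the second call is expected to miss the cache.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_log = []  # records the static argument's type each time tracing happens

@partial(jax.jit, static_argnums=1)
def scale_by(x, factor):
    # This append is a Python side effect, so it runs only during tracing.
    trace_log.append(type(factor).__name__)
    return x * factor

x = jnp.arange(3.0)
a = scale_by(x, 1)    # first call: traces with an int static argument
b = scale_by(x, 1.0)  # equal value, different type: traced again
c = scale_by(x, 1)    # same value and type as the first call: cache hit

print(trace_log)
```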

0.2.27

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.26...jax-v0.2.27).

* Breaking changes:
* Support for NumPy 1.18 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
Please upgrade to a supported NumPy version.
* The `host_callback` primitives have been simplified to drop the
special autodiff handling for `hcb.id_tap` and `id_print`.
From now on, only the primals are tapped. The old behavior can be
obtained (for a limited time) by setting the ``JAX_HOST_CALLBACK_AD_TRANSFORMS``
environment variable, or the `--jax_host_callback_ad_transforms` flag.
Additionally, documentation has been added on how to implement the old behavior
using JAX custom AD APIs ({jax-issue}`8678`).
* Sorting now matches the behavior of NumPy for ``0.0`` and ``NaN`` regardless of the
bit representation. In particular, ``0.0`` and ``-0.0`` are now treated as equivalent,
where previously ``-0.0`` was treated as less than ``0.0``. Additionally all ``NaN``
representations are now treated as equivalent and sorted to the end of the array.
Previously negative ``NaN`` values were sorted to the front of the array, and ``NaN``
values with different internal bit representations were not treated as equivalent, and
were sorted according to those bit patterns ({jax-issue}`9178`).
* {func}`jax.numpy.unique` now treats ``NaN`` values in the same way as `np.unique` in
NumPy versions 1.21 and newer: at most one ``NaN`` value will appear in the uniquified
output ({jax-issue}`9184`).
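The sorting and `unique` behavior described above can be checked with a short sketch. The input array is hypothetical; the sketch only prints the results so they can be compared against NumPy on your own install.

```python
import jax.numpy as jnp

# Mix of NaNs, signed zeros, and ordinary values.
values = jnp.array([jnp.nan, 1.0, 0.0, -0.0, -1.0, jnp.nan])

# NaNs are sorted to the end; 0.0 and -0.0 are treated as equivalent.
sorted_values = jnp.sort(values)

# At most one NaN survives in the uniquified output.
unique_values = jnp.unique(values)

print(sorted_values)
print(unique_values)
```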

* Bug fixes:
* `host_callback` now supports `ad_checkpoint.checkpoint` ({jax-issue}`8907`).

* New features:
* Added `jax.block_until_ready` ({jax-issue}`8941`).
* Added a new debugging flag/environment variable `JAX_DUMP_IR_TO=/path`.
If set, JAX dumps the MHLO/HLO IR it generates for each computation to a
file under the given path.
* Added `jax.ensure_compile_time_eval` to the public API ({jax-issue}`7987`).
* jax2tf now supports a flag `jax2tf_associative_scan_reductions` to change
the lowering for associative reductions, e.g., `jnp.cumsum`, to behave
like JAX on CPU and GPU (to use an associative scan). See the jax2tf README
for more details ({jax-issue}`9189`).
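A brief sketch of two of the additions above, `jax.ensure_compile_time_eval` and `jax.block_until_ready`, used together. The function `scaled` is a hypothetical example: inside the context manager the `jnp.sin` call is evaluated eagerly while tracing, so it can be converted to a Python float, which would otherwise fail on a traced value.

```python
import math

import jax
import jax.numpy as jnp

@jax.jit
def scaled(x):
    with jax.ensure_compile_time_eval():
        # Evaluated eagerly at trace time, so the result is a concrete value.
        factor = float(jnp.sin(3.0))
    return x * factor

# Dispatch is asynchronous; block_until_ready waits for the result.
result = jax.block_until_ready(scaled(jnp.arange(4.0)))
print(result)
```

Calling `jax.block_until_ready` on the output is mainly useful when timing code, since otherwise the call may return before the computation has finished.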


jaxlib 0.1.75 (Dec 8, 2021)
* New features:
* Support for Python 3.10.

0.2.26

* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.25...jax-v0.2.26).

* Bug fixes:
* Out-of-bounds indices to `jax.ops.segment_sum` will now be handled with
`FILL_OR_DROP` semantics, as documented. This primarily affects the
reverse-mode derivative, where gradients corresponding to out-of-bounds
indices will now be returned as 0 (#8634).
* jax2tf will force the converted code to use XLA for the code fragments
under `jax.jit`, e.g., most `jax.numpy` functions ({jax-issue}`7839`).
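The `segment_sum` fix can be illustrated with a small sketch. The data and segment ids are hypothetical; the id `5` is out of bounds for `num_segments=2`, so its entry is dropped from the sum and its gradient is zero.

```python
import jax
import jax.numpy as jnp

data = jnp.array([1.0, 2.0, 3.0, 4.0])
segment_ids = jnp.array([0, 1, 5, 1])  # 5 is out of bounds for 2 segments

def total(d):
    # Sum over all segments so the result is a scalar we can differentiate.
    return jax.ops.segment_sum(d, segment_ids, num_segments=2).sum()

sums = jax.ops.segment_sum(data, segment_ids, num_segments=2)
grads = jax.grad(total)(data)

print(sums)   # the out-of-bounds entry (3.0) contributes to no segment
print(grads)  # the gradient for the out-of-bounds entry is 0
```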

jaxlib 0.1.74 (Nov 17, 2021)
* Enabled peer-to-peer copies between GPUs. Previously, GPU copies were bounced via
the host, which is usually slower.
* Added experimental MLIR Python bindings for use by JAX.

0.2.25

* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.24...jax-v0.2.25).

* New features:
* (Experimental) `jax.distributed.initialize` exposes multi-host GPU backend.
* `jax.random.permutation` supports new `independent` keyword argument
({jax-issue}`8430`)
* Added `jax.lax.linalg.qdwh`.
* Breaking changes:
* Moved `jax.experimental.stax` to `jax.example_libraries.stax`
* Moved `jax.experimental.optimizers` to `jax.example_libraries.optimizers`
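A minimal sketch of the `independent` keyword argument of `jax.random.permutation` listed for this release. The array and seed are hypothetical. With `independent=True`, each slice along the chosen axis is shuffled on its own, so every row keeps its original elements in a possibly different order.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jnp.arange(12).reshape(3, 4)

# Shuffle along axis 1, with a separate permutation for each row.
shuffled = jax.random.permutation(key, x, axis=1, independent=True)

print(shuffled)
```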

0.2.24

* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.22...jax-v0.2.24).

* New features:
* `jax.random.choice` and `jax.random.permutation` now support
multidimensional arrays and an optional `axis` argument ({jax-issue}`8158`)
* Breaking changes:
* `jax.numpy.take` and `jax.numpy.take_along_axis` now require array-like inputs
(see {jax-issue}`7737`)
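The two 0.2.24 items above can be sketched as follows. The array `table` and the seed are hypothetical. Python lists are converted with `jnp.asarray` before calling `jnp.take`, which is one way to satisfy the array-like requirement.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
table = jnp.arange(12).reshape(3, 4)

# Draw two columns (with replacement) from a 2-D array along axis 1.
picked = jax.random.choice(key, table, shape=(2,), axis=1)

# Permute the rows of a 2-D array.
rows = jax.random.permutation(key, table, axis=0)

# Convert list inputs to arrays explicitly before calling jnp.take.
first_col = jnp.take(jnp.asarray([[1, 2], [3, 4]]), jnp.asarray([0]), axis=1)

print(picked.shape, rows.shape, first_col)
```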

jaxlib 0.1.73 (Oct 18, 2021)

* Multiple cuDNN versions are now supported for jaxlib GPU `cuda11` wheels.
* cuDNN 8.2 or newer. We recommend using the cuDNN 8.2 wheel if your cuDNN
installation is new enough, since it supports additional functionality.
* cuDNN 8.0.5 or newer.

* Breaking changes:
* The install commands for GPU jaxlib are as follows:

```bash
pip install --upgrade pip

# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

# Installs the wheel compatible with CUDA 11 and cuDNN 8.0.5 or newer.
pip install "jax[cuda11_cudnn805]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
```

© 2024 Safety CLI Cybersecurity Inc. All Rights Reserved.