Jax

Latest version: v0.4.29


0.4.17

* New features
* Added a new {func}`jax.numpy.bitwise_count` function, matching the API of the similar
function recently added to NumPy.
* Deprecations
* Removed the deprecated module `jax.abstract_arrays` and all its contents.
* Named key constructors in {mod}`jax.random` are deprecated. Pass the `impl` argument
to {func}`jax.random.PRNGKey` or {func}`jax.random.key` instead:
* `random.threefry2x32_key(seed)` becomes `random.PRNGKey(seed, impl='threefry2x32')`
* `random.rbg_key(seed)` becomes `random.PRNGKey(seed, impl='rbg')`
* `random.unsafe_rbg_key(seed)` becomes `random.PRNGKey(seed, impl='unsafe_rbg')`
* Changes:
* CUDA: JAX now verifies that the CUDA libraries it finds are at least as new
as the CUDA libraries that JAX was built against. If older libraries are
found, JAX raises an exception since that is preferable to mysterious
failures and crashes.
* Removed the "No GPU/TPU found" warning. Instead, JAX now warns if, on Linux, an
NVIDIA GPU or a Google TPU is found but not used and `--jax_platforms` was
not specified.
* {func}`jax.scipy.stats.mode` now returns a 0 count if the mode is taken
across a size-0 axis, matching the behavior of `scipy.stats.mode` in SciPy
1.11.
* Most `jax.numpy` functions and attributes now have fully-defined type stubs.
Previously many of these were treated as `Any` by static type checkers like
`mypy` and `pytype`.
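The two user-facing API items above can be sketched as follows, assuming JAX 0.4.17 or newer (where both `jax.numpy.bitwise_count` and the `impl` argument exist). The seeds and variable names are illustrative only.

```python
import jax
import jax.numpy as jnp

# Per-element count of set bits; following NumPy, the result dtype is uint8.
counts = jnp.bitwise_count(jnp.array([0, 1, 7, 255]))
print(counts.tolist())  # [0, 1, 3, 8]

# Replacement for a deprecated named key constructor:
#   random.threefry2x32_key(0) -> random.PRNGKey(0, impl='threefry2x32')
legacy_key = jax.random.PRNGKey(0, impl="threefry2x32")  # raw uint32 array, shape (2,)

# jax.random.key accepts the same `impl` argument and returns a typed key array.
typed_key = jax.random.key(0, impl="rbg")
```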

jaxlib 0.4.17 (Oct 3, 2023)

* Changes:
* Python 3.12 wheels were added in this release.
* The CUDA 12 wheels now require CUDA 12.2 or newer and cuDNN 8.9.4 or newer.

* Bug fixes:
* Fixed log spam from ABSL when the JAX CPU backend was initialized.

0.4.16

* Changes
* Added {class}`jax.numpy.ufunc`, as well as {func}`jax.numpy.frompyfunc`, which can convert
any scalar-valued function into a {func}`numpy.ufunc`-like object, with methods such as
{meth}`~jax.numpy.ufunc.outer`, {meth}`~jax.numpy.ufunc.reduce`,
{meth}`~jax.numpy.ufunc.accumulate`, {meth}`~jax.numpy.ufunc.at`, and
{meth}`~jax.numpy.ufunc.reduceat` ({jax-issue}`17054`).
* Added {func}`jax.scipy.integrate.trapezoid`.
* When not running under IPython and an exception is raised, JAX now filters all of
its internal frames out of tracebacks, without the "unfiltered stack trace" that
previously appeared. This should produce much friendlier-looking tracebacks. See
[here](https://github.com/google/jax/pull/16949) for an example.
This behavior can be changed by setting `JAX_TRACEBACK_FILTERING=remove_frames` (for two
separate unfiltered/filtered tracebacks, which was the old behavior) or
`JAX_TRACEBACK_FILTERING=off` (for one unfiltered traceback).
* The jax2tf default serialization version is now 7, which introduces new shape
[safety assertions](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism).
* Devices passed to `jax.sharding.Mesh` should be hashable. This specifically
applies to mock devices or user-created devices. Devices returned by
`jax.devices()` are already hashable.
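A minimal sketch of {func}`jax.numpy.frompyfunc`, assuming JAX 0.4.16 or newer. The wrapped lambda and its `identity=0` are illustrative choices, not something the release prescribes.

```python
import jax.numpy as jnp

# Wrap a scalar-valued Python function as a ufunc-like object.
# identity=0 declares 0 as the identity element of this binary operation.
add = jnp.frompyfunc(lambda x, y: x + y, nin=2, nout=1, identity=0)

x = jnp.arange(5)
table = add.outer(jnp.arange(3), jnp.arange(3))  # 3x3 addition table
total = add.reduce(x)                            # 0 + 1 + 2 + 3 + 4
running = add.accumulate(x)                      # running sums along axis 0

print(table.tolist())    # [[0, 1, 2], [1, 2, 3], [2, 3, 4]]
print(int(total))        # 10
print(running.tolist())  # [0, 1, 3, 6, 10]
```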

* Breaking changes:
* jax2tf now uses native serialization by default. See
the [jax2tf documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md)
for details and for mechanisms to override the default.
* The option `--jax_coordination_service` has been removed. It is now always
`True`.
* `jax.jaxpr_util` has been removed from the public JAX namespace.
* `JAX_USE_PJRT_C_API_ON_TPU` no longer has an effect (i.e. it always defaults to true).
* The backwards compatibility flag `--jax_host_callback_ad_transforms`,
introduced in December 2021, has been removed.

* Deprecations:
* Several `jax.numpy` APIs have been deprecated following
[NumPy NEP-52](https://numpy.org/neps/nep-0052-python-api-cleanup.html):
* `jax.numpy.NINF` has been deprecated. Use `-jax.numpy.inf` instead.
* `jax.numpy.PZERO` has been deprecated. Use `0.0` instead.
* `jax.numpy.NZERO` has been deprecated. Use `-0.0` instead.
* `jax.numpy.issubsctype(x, t)` has been deprecated. Use `jax.numpy.issubdtype(x.dtype, t)` instead.
* `jax.numpy.row_stack` has been deprecated. Use `jax.numpy.vstack` instead.
* `jax.numpy.in1d` has been deprecated. Use `jax.numpy.isin` instead.
* `jax.numpy.trapz` has been deprecated. Use `jax.scipy.integrate.trapezoid` instead.
* `jax.scipy.linalg.tril` and `jax.scipy.linalg.triu` have been deprecated,
following SciPy. Use `jax.numpy.tril` and `jax.numpy.triu` instead.
* `jax.lax.prod` has been removed after being deprecated in JAX v0.4.11.
Use the built-in `math.prod` instead.
* A number of exports from `jax.interpreters.xla` related to defining
HLO lowering rules for custom JAX primitives have been deprecated. Custom
primitives should be defined using the StableHLO lowering utilities in
`jax.interpreters.mlir` instead.
* The following previously-deprecated functions have been removed after a
three-month deprecation period:
* `jax.abstract_arrays.ShapedArray`: use `jax.core.ShapedArray`.
* `jax.abstract_arrays.raise_to_shaped`: use `jax.core.raise_to_shaped`.
* `jax.numpy.alltrue`: use `jax.numpy.all`.
* `jax.numpy.sometrue`: use `jax.numpy.any`.
* `jax.numpy.product`: use `jax.numpy.prod`.
* `jax.numpy.cumproduct`: use `jax.numpy.cumprod`.
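For code migrating off the deprecations above, a sketch of the suggested replacements side by side, assuming JAX 0.4.16 or newer; the input values are arbitrary.

```python
import math

import jax.numpy as jnp
from jax.scipy.integrate import trapezoid

x = jnp.array([1.0, 2.0, 3.0])

ninf = -jnp.inf                                    # replaces jnp.NINF
is_float = jnp.issubdtype(x.dtype, jnp.floating)   # replaces jnp.issubsctype(x, jnp.floating)
stacked = jnp.vstack([x, x])                       # replaces jnp.row_stack
mask = jnp.isin(x, jnp.array([2.0]))               # replaces jnp.in1d
area = trapezoid(x)                                # replaces jnp.trapz; (1+2)/2 + (2+3)/2 = 4.0
lower = jnp.tril(jnp.ones((3, 3)))                 # replaces jax.scipy.linalg.tril
n = math.prod((2, 3, 4))                           # replaces the removed jax.lax.prod
```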

* Deprecations/removals:
* The internal submodule `jax.prng` is now deprecated. Its contents are available at
{mod}`jax.extend.random`.
* The internal submodule path `jax.linear_util` has been deprecated. Use
{mod}`jax.extend.linear_util` instead (part of {ref}`jax-extend-jep`).
* `jax.random.PRNGKeyArray` and `jax.random.KeyArray` are deprecated. Use {class}`jax.Array`
for type annotations, and `jax.dtypes.issubdtype(arr.dtype, jax.dtypes.prng_key)` for
runtime detection of typed PRNG keys.
* The method `PRNGKeyArray.unsafe_raw_array` is deprecated. Use
{func}`jax.random.key_data` instead.
* `jax.experimental.pjit.with_sharding_constraint` is deprecated. Use
`jax.lax.with_sharding_constraint` instead.
* The internal utilities `jax.core.is_opaque_dtype` and `jax.core.has_opaque_dtype`
have been removed. Opaque dtypes have been renamed to Extended dtypes; use
`jnp.issubdtype(dtype, jax.dtypes.extended)` instead (available since jax v0.4.14).
* The utility `jax.interpreters.xla.register_collective_primitive` has been
removed. This utility did nothing useful in recent JAX releases and calls
to it can be safely removed.
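A sketch of the typed-key replacements described above, assuming JAX 0.4.16 or newer with the default `threefry2x32` PRNG implementation. `is_typed_key` is a hypothetical helper, not a JAX API.

```python
import jax
import jax.numpy as jnp


def is_typed_key(arr: jax.Array) -> bool:
    # Hypothetical helper: runtime detection of typed PRNG keys, in place of
    # isinstance checks against the deprecated jax.random.KeyArray.
    return jax.dtypes.issubdtype(arr.dtype, jax.dtypes.prng_key)


typed = jax.random.key(0)     # typed key array (extended dtype)
raw = jax.random.PRNGKey(0)   # legacy raw uint32 key data

# Replaces the deprecated PRNGKeyArray.unsafe_raw_array().
data = jax.random.key_data(typed)

print(is_typed_key(typed), is_typed_key(raw))  # True False
```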

jaxlib 0.4.16 (Sept 18, 2023)

* Changes:
* Sparse CSR matrix multiplications via the experimental JAX sparse APIs
no longer use a deterministic algorithm on NVIDIA GPUs. This change was
made to improve compatibility with CUDA 12.2.1.

* Bug fixes:
* Fixed a crash on Windows due to a fatal LLVM error related to out-of-order
sections and IMAGE_REL_AMD64_ADDR32NB relocations
(https://github.com/openxla/xla/commit/cb732a921f0c4184995cbed82394931011d12bd4).

0.4.14

* Changes
* `jax.jit` now accepts a `donate_argnames` argument. Its semantics are similar
to those of `static_argnames`.
If neither `donate_argnums` nor `donate_argnames` is provided, no
arguments are donated. If `donate_argnums` is not provided but
`donate_argnames` is, or vice versa, JAX uses
`inspect.signature(fun)` to find any positional arguments that
correspond to `donate_argnames` (or vice versa). If both `donate_argnums` and
`donate_argnames` are provided, `inspect.signature` is not used, and only actual
parameters listed in either `donate_argnums` or `donate_argnames` will
be donated.
* {func}`jax.random.gamma` has been refactored to a more efficient algorithm
with more robust endpoint behavior ({jax-issue}`16779`). This means that the
sequence of values returned for a given `key` will change between JAX v0.4.13
and v0.4.14 for `gamma` and related samplers (including {func}`jax.random.ball`,
{func}`jax.random.beta`, {func}`jax.random.chisquare`, {func}`jax.random.dirichlet`,
{func}`jax.random.generalized_normal`, {func}`jax.random.loggamma`, {func}`jax.random.t`).
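A sketch of `donate_argnames` with `jax.jit`; the function `step` is hypothetical. On a backend that cannot honor donation, JAX may emit a warning and copy instead, so the computed result is the same either way.

```python
import functools

import jax
import jax.numpy as jnp


# Donate the `state` argument by name; JAX resolves the name to its positional
# index via inspect.signature, as if donate_argnums=(0,) had been passed.
@functools.partial(jax.jit, donate_argnames=("state",))
def step(state, delta):
    return state + delta


state = jnp.zeros(4)
# Rebind the name: the old buffer may have been donated and must not be reused.
state = step(state, jnp.ones(4))
print(state.tolist())  # [1.0, 1.0, 1.0, 1.0]
```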

* Deletions
* `in_axis_resources` and `out_axis_resources` have been deleted from pjit, more
than 3 months after their deprecation. Please use
`in_shardings` and `out_shardings` as the replacement.
This is a safe and trivial name replacement. It does not change any of the
current pjit semantics and doesn't break any code.
You can still pass `PartitionSpec`s to `in_shardings` and `out_shardings`.
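A sketch of the renamed keywords, assuming a recent JAX where `jax.jit` accepts the same `in_shardings`/`out_shardings` arguments as `pjit`. It uses a `NamedSharding` over a 1-D mesh of whatever devices are present, rather than bare `PartitionSpec`s (which would need an active mesh context); the axis name `"x"` is arbitrary.

```python
import numpy as np

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1-D mesh over all available devices (a single CPU device works too).
mesh = Mesh(np.array(jax.devices()), axis_names=("x",))
sharding = NamedSharding(mesh, P("x"))

# in_shardings / out_shardings replace in_axis_resources / out_axis_resources.
double = jax.jit(lambda a: a * 2, in_shardings=sharding, out_shardings=sharding)

n = 2 * len(jax.devices())  # keep the sharded dimension divisible by the mesh size
x = jax.device_put(jnp.arange(n, dtype=jnp.float32), sharding)
y = double(x)
```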


* Deprecations
* Python 3.8 support has been dropped as per
https://jax.readthedocs.io/en/latest/deprecation.html
* JAX now requires NumPy 1.22 or newer as per
https://jax.readthedocs.io/en/latest/deprecation.html
* Passing optional arguments to {func}`jax.numpy.ndarray.at` by position is
no longer supported, after being deprecated in JAX version 0.4.7.
For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)`.
* The following `jax.Array` methods have been removed, after being deprecated
in JAX v0.4.5:
* `jax.Array.broadcast`: use {func}`jax.lax.broadcast` instead.
* `jax.Array.broadcast_in_dim`: use {func}`jax.lax.broadcast_in_dim` instead.
* `jax.Array.split`: use {func}`jax.numpy.split` instead.
* The following APIs have been removed after previous deprecation:
* `jax.ad`: use {mod}`jax.interpreters.ad`.
* `jax.curry`: use ``curry = lambda f: partial(partial, f)``.
* `jax.partial_eval`: use {mod}`jax.interpreters.partial_eval`.
* `jax.pxla`: use {mod}`jax.interpreters.pxla`.
* `jax.xla`: use {mod}`jax.interpreters.xla`.
* `jax.ShapedArray`: use {class}`jax.core.ShapedArray`.
* `jax.interpreters.pxla.device_put`: use {func}`jax.device_put`.
* `jax.interpreters.pxla.make_sharded_device_array`: use {func}`jax.make_array_from_single_device_arrays`.
* `jax.interpreters.pxla.ShardedDeviceArray`: use {class}`jax.Array`.
* `jax.numpy.DeviceArray`: use {class}`jax.Array`.
* `jax.stages.Compiled.compiler_ir`: use {func}`jax.stages.Compiled.as_text`.
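The keyword-only `.at[...].get()` call and the replacements for several removed APIs can be sketched as follows, assuming JAX 0.4.14 or newer; the array values are arbitrary.

```python
from functools import partial

import jax
import jax.numpy as jnp

x = jnp.arange(6.0)

# Optional arguments to .at[...].get() must now be passed by keyword.
picked = x.at[jnp.array([1, 3])].get(indices_are_sorted=True, unique_indices=True)

# Replacements for the removed jax.Array methods.
b = jax.lax.broadcast(x, (2,))                  # previously x.broadcast((2,)); shape (2, 6)
b2 = jax.lax.broadcast_in_dim(x, (2, 6), (1,))  # previously x.broadcast_in_dim(...)
parts = jnp.split(x, 3)                         # previously x.split(3)

# Replacement for the removed jax.curry.
curry = lambda f: partial(partial, f)
add3 = curry(lambda u, v: u + v)(3)
print(add3(4))  # 7
```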

* Breaking changes
* JAX now requires ml_dtypes version 0.2.0 or newer.
* To fix a corner case, calls to {func}`jax.lax.cond` with five
arguments will always resolve to the "common operands" `cond`
behavior (as documented) if the second and third arguments are
callable, even if other operands are callable as well. See
[16413](https://github.com/google/jax/issues/16413).
* The deprecated config options `jax_array` and `jax_jit_pjit_api_merge`,
which did nothing, have been removed. These options have been true by
default for many releases.
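To illustrate the `jax.lax.cond` corner case above: with five positional arguments whose second and third are callables, the call is read as `cond(pred, true_fun, false_fun, *operands)`. The function `combine` is a hypothetical example.

```python
import jax


def combine(pred, a, b):
    # Five positional arguments; the 2nd and 3rd are callables, so `a` and `b`
    # are the common operands passed to whichever branch is selected.
    return jax.lax.cond(pred, lambda u, v: u + v, lambda u, v: u - v, a, b)


print(float(combine(True, 5.0, 2.0)), float(combine(False, 5.0, 2.0)))  # 7.0 3.0
```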

* New features
* JAX now supports a configuration flag `--jax_serialization_version`
and a `JAX_SERIALIZATION_VERSION` environment variable to control the
serialization version ({jax-issue}`16746`).
* In the presence of shape polymorphism, jax2tf now generates code that checks
certain shape constraints if the serialization version is at least 7.
See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism.

jaxlib 0.4.14 (July 27, 2023)

* Deprecations
* Python 3.8 support has been dropped as per
https://jax.readthedocs.io/en/latest/deprecation.html

0.4.13

* Changes
* `jax.jit` now allows `None` to be passed to `in_shardings` and
`out_shardings`. The semantics are as follows:
* For in_shardings, JAX will mark it as replicated, but this behavior
can change in the future.
* For out_shardings, we will rely on the XLA GSPMD partitioner to
determine the output shardings.
* `jax.experimental.pjit.pjit` also allows `None` to be passed to
`in_shardings` and `out_shardings`. The semantics are as follows:
* If the mesh context manager is *not* provided, JAX has the freedom to
choose whatever sharding it wants.
* For in_shardings, JAX will mark it as replicated, but this behavior
can change in the future.
* For out_shardings, we will rely on the XLA GSPMD partitioner to
determine the output shardings.
* If the mesh context manager is provided, `None` implies that the value
will be replicated on all devices of the mesh.
* `Executable.cost_analysis()` works on Cloud TPU.
* Added a warning if a non-allowlisted `jaxlib` plugin is in use.
* Added `jax.tree_util.tree_leaves_with_path`.
* `None` is not a valid input to
`jax.experimental.multihost_utils.host_local_array_to_global_array` or
`jax.experimental.multihost_utils.global_array_to_host_local_array`.
Please use `jax.sharding.PartitionSpec()` if you wanted to replicate your
input.
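A short sketch of `jax.tree_util.tree_leaves_with_path`, paired with `jax.tree_util.keystr` to render each key path; the tree contents are arbitrary. Dictionary keys are flattened in sorted order, so `"layers"` comes before `"w"`.

```python
import jax

tree = {"w": 1.0, "layers": (2.0, 3.0)}

# Each entry is (key_path, leaf); key_path is a tuple such as
# (DictKey(key='layers'), SequenceKey(idx=0)).
leaves = jax.tree_util.tree_leaves_with_path(tree)
for path, leaf in leaves:
    print(jax.tree_util.keystr(path), leaf)
# ['layers'][0] 2.0
# ['layers'][1] 3.0
# ['w'] 1.0
```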

* Bug fixes
* Fixed incorrect wheel name in CUDA 12 releases ({jax-issue}`16362`); the correct wheel
is named `cudnn89` instead of `cudnn88`.

* Deprecations
* The `native_serialization_strict_checks` parameter to
{func}`jax.experimental.jax2tf.convert` is deprecated in favor of the
new `native_serialization_disabled_checks` ({jax-issue}`16347`).

jaxlib 0.4.13 (June 22, 2023)

* Changes
* Added Windows CPU-only wheels to the `jaxlib` PyPI release.

* Bug fixes
* `__cuda_array_interface__` was broken in previous jaxlib versions and is now
fixed ({jax-issue}`16440`).
* Concurrent CUDA kernel tracing is now enabled by default on NVIDIA GPUs.

0.4.12

* Changes
* Added {class}`jax.scipy.spatial.transform.Rotation` and {class}`jax.scipy.spatial.transform.Slerp`.
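A sketch of the new `Rotation` class, assuming JAX 0.4.12 or newer and its `jax.scipy.spatial.transform` module. It builds a 90-degree rotation about the z axis from a rotation vector and applies it to the x unit vector.

```python
import jax.numpy as jnp
from jax.scipy.spatial.transform import Rotation

# Rotation vector = axis * angle (radians): 90 degrees about z.
r = Rotation.from_rotvec(jnp.array([0.0, 0.0, jnp.pi / 2]))

v = r.apply(jnp.array([1.0, 0.0, 0.0]))  # approximately [0, 1, 0]
m = r.as_matrix()                        # equivalent 3x3 rotation matrix
```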

* Deprecations
* `jax.abstract_arrays` and its contents are now deprecated. See related
functionality in {mod}`jax.core`.
* `jax.numpy.alltrue`: use `jax.numpy.all`. This follows the deprecation
of `numpy.alltrue` in NumPy version 1.25.0.
* `jax.numpy.sometrue`: use `jax.numpy.any`. This follows the deprecation
of `numpy.sometrue` in NumPy version 1.25.0.
* `jax.numpy.product`: use `jax.numpy.prod`. This follows the deprecation
of `numpy.product` in NumPy version 1.25.0.
* `jax.numpy.cumproduct`: use `jax.numpy.cumprod`. This follows the deprecation
of `numpy.cumproduct` in NumPy version 1.25.0.
* `jax.sharding.OpShardingSharding` has been removed, 3 months after its
deprecation.
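The replacements for the four deprecated `jax.numpy` aliases above are drop-in; a sketch with arbitrary input:

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 4])

all_pos = jnp.all(x > 0)   # replaces jnp.alltrue(x > 0)
any_big = jnp.any(x > 3)   # replaces jnp.sometrue(x > 3)
total = jnp.prod(x)        # replaces jnp.product(x)
running = jnp.cumprod(x)   # replaces jnp.cumproduct(x)

print(int(total), running.tolist())  # 24 [1, 2, 6, 24]
```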

jaxlib 0.4.12 (June 8, 2023)

* Changes
* Includes PTX/SASS for Hopper (SM version 9.0+) GPUs. Previous
versions of jaxlib should work on Hopper but would have a long
JIT-compilation delay the first time a JAX operation was executed.

* Bug fixes
* Fixes incorrect source line information in JAX-generated Python tracebacks
under Python 3.11.
* Fixes a crash when printing local variables of frames in JAX-generated Python
tracebacks ({jax-issue}`16027`).

0.4.11

* Deprecations
* The following APIs have been removed after a 3-month deprecation period, in
accordance with the {ref}`api-compatibility` policy:
* `jax.experimental.PartitionSpec`: use `jax.sharding.PartitionSpec`.
* `jax.experimental.maps.Mesh`: use `jax.sharding.Mesh`.
* `jax.experimental.pjit.NamedSharding`: use `jax.sharding.NamedSharding`.
* `jax.experimental.pjit.PartitionSpec`: use `jax.sharding.PartitionSpec`.
* `jax.experimental.pjit.FROM_GDA`: instead, pass sharded `jax.Array` objects
as input and remove the optional `in_shardings` argument to `pjit`.
* `jax.interpreters.pxla.PartitionSpec`: use `jax.sharding.PartitionSpec`.
* `jax.interpreters.pxla.Mesh`: use `jax.sharding.Mesh`.
* `jax.interpreters.xla.Buffer`: use `jax.Array`.
* `jax.interpreters.xla.Device`: use `jax.Device`.
* `jax.interpreters.xla.DeviceArray`: use `jax.Array`.
* `jax.interpreters.xla.device_put`: use `jax.device_put`.
* `jax.interpreters.xla.xla_call_p`: use `jax.experimental.pjit.pjit_p`.
* The `axis_resources` argument of `with_sharding_constraint` has been removed. Please
use `shardings` instead.
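A sketch of the post-removal import locations and of `with_sharding_constraint` taking a sharding as its second argument, assuming a recent JAX. The mesh axis name `"data"` and the function `scale` are illustrative.

```python
import numpy as np

import jax
import jax.numpy as jnp
# New import locations for the removed jax.experimental.* / jax.interpreters.* aliases.
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ("data",))
row_sharding = NamedSharding(mesh, PartitionSpec("data"))


@jax.jit
def scale(x):
    y = x * 3
    # The second argument is the `shardings` parameter (formerly `axis_resources`).
    return jax.lax.with_sharding_constraint(y, row_sharding)


x = jnp.ones((2 * len(jax.devices()), 4))
out = scale(x)
print(isinstance(out, jax.Array))  # True; jax.Array replaces DeviceArray
```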


jaxlib 0.4.11 (May 31, 2023)

* Changes
* Added a `memory_stats()` method to `Device`s. If supported, this returns a
dict of string stat names with int values, e.g. `"bytes_in_use"`, or `None` if
the platform doesn't support memory statistics. The exact stats returned may
vary across platforms. Currently only implemented on Cloud TPU.
* Re-added support for the Python buffer protocol (`memoryview`) on CPU
devices.
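A sketch of the two jaxlib 0.4.11 changes. Because `memory_stats()` support is platform-dependent, the code treats both a dict and `None` as valid results, and it only takes a `memoryview` when the default device is a CPU; the array contents are arbitrary.

```python
import numpy as np

import jax
import jax.numpy as jnp

dev = jax.devices()[0]
# A dict of stat names to ints (e.g. "bytes_in_use") where supported, else None.
stats = dev.memory_stats()

x = jnp.arange(4, dtype=jnp.float32)  # placed on the default device
if dev.platform == "cpu":
    view = memoryview(x)     # Python buffer protocol, supported on CPU devices
    host = np.asarray(view)  # NumPy array built from the exported buffer
```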
