Jax

Latest version: v0.4.29


Page 13 of 18

0.2.10

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.9...jax-v0.2.10).
* New features:
* {func}`jax.scipy.stats.chi2` is now available as a distribution with logpdf and pdf methods.
* {func}`jax.scipy.stats.betabinom` is now available as a distribution with logpmf and pmf methods.
* Added {func}`jax.experimental.jax2tf.call_tf` to call TensorFlow functions
from JAX ({jax-issue}`5627`
and [README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax)).
* Extended the batching rule for `lax.pad` to support batching of the padding values.
* Bug fixes:
* {func}`jax.numpy.take` properly handles negative indices ({jax-issue}`5768`)
* Breaking changes:
* JAX's promotion rules were adjusted to make promotion more consistent and
invariant to JIT. In particular, binary operations can now result in weakly-typed
values when appropriate. The main user-visible effect of the change is that
some operations result in outputs of different precision than before; for
example, the expression `jnp.bfloat16(1) + 0.1 * jnp.arange(10)`
previously returned a `float64` array, and now returns a `bfloat16` array.
JAX's type promotion behavior is described at {ref}`type-promotion`.
* When an integer `dtype` is requested, {func}`jax.numpy.linspace` now computes
the floor of the values, i.e., it rounds towards -inf rather than towards 0.
This change was made to match NumPy 1.20.0.
* {func}`jax.numpy.i0` no longer accepts complex numbers. Previously the
function computed the absolute value of complex arguments. This change was
made to match the semantics of NumPy 1.20.0.
* Several {mod}`jax.numpy` functions no longer accept tuples or lists in place
of array arguments: {func}`jax.numpy.pad`, {func}`jax.numpy.ravel`,
{func}`jax.numpy.repeat`, {func}`jax.numpy.reshape`.
In general, {mod}`jax.numpy` functions should be used with scalars or array arguments.
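
A minimal sketch of several 0.2.10 changes, assuming a recent JAX install running in the default 32-bit mode (`jax_enable_x64` off). It evaluates the new `chi2` and `betabinom` distributions, vmaps `lax.pad` over its padding value, and checks the promoted dtype of the `jnp.bfloat16(1) + 0.1 * jnp.arange(10)` example. The inputs (`x`, `k`, the `betabinom` parameters, `pad_values`) are arbitrary illustrations.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.scipy import stats

# chi2: a distribution namespace with pdf/logpdf (df=3 chosen arbitrarily).
x = jnp.array([0.5, 1.0, 2.0])
chi2_pdf = stats.chi2.pdf(x, 3.0)
chi2_logpdf = stats.chi2.logpdf(x, 3.0)

# betabinom: pmf/logpmf over k = 0..n with n=4, a=2, b=3 (illustrative values).
k = jnp.arange(5)
bb_pmf = stats.betabinom.pmf(k, 4, 2.0, 3.0)

# lax.pad batching over the padding value: each row of the result is
# padded with one element on each side, using a different constant per row.
operand = jnp.array([1.0, 2.0, 3.0])
pad_values = jnp.array([0.0, -1.0])
padded = jax.vmap(lambda v: lax.pad(operand, v, [(1, 1, 0)]))(pad_values)

# Weak-type promotion: 0.1 * jnp.arange(10) is a weakly typed float,
# so adding it to a bfloat16 scalar keeps bfloat16 precision.
y = jnp.bfloat16(1) + 0.1 * jnp.arange(10)
print(y.dtype)  # bfloat16
```

Because the Python scalar `0.1` is weakly typed, it does not force the result up to `float32`; the explicitly typed `bfloat16` operand decides the precision.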

jaxlib 0.1.62 (March 9 2021)

* New features:
* jaxlib wheels are now built to require AVX instructions on x86-64 machines
by default. If you want to use JAX on a machine that doesn't support AVX,
you can build a jaxlib from source using the `--target_cpu_features` flag
to `build.py`. `--target_cpu_features` also replaces
`--enable_march_native`.

jaxlib 0.1.61 (February 12 2021)

jaxlib 0.1.60 (February 3 2021)

* Bug fixes:
* Fixed a memory leak when converting CPU DeviceArrays to NumPy arrays. The
memory leak was present in jaxlib releases 0.1.58 and 0.1.59.
* `bool`, `int8`, and `uint8` are now considered safe to cast to the
`bfloat16` NumPy extension type.

0.2.9

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.8...jax-v0.2.9).
* New features:
* Extended the {mod}`jax.experimental.loops` module with support for pytrees, and
improved its error checking and error messages.
* Add {func}`jax.experimental.enable_x64` and {func}`jax.experimental.disable_x64`.
These are context managers which allow X64 mode to be temporarily enabled/disabled
within a session.
* Breaking changes:
* {func}`jax.ops.segment_sum` now drops segment IDs that are out of range rather
than wrapping them into the segment ID space. This was done for performance
reasons.
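
The sketch below illustrates two 0.2.9 items, assuming a JAX version in which `jax.experimental.enable_x64` is still importable (later releases may deprecate it) and in which X64 mode is off globally. `enable_x64` turns on 64-bit types only inside the `with` block, and `jax.ops.segment_sum` drops the out-of-range segment ID `5` instead of wrapping it. The data and IDs are made up for illustration.

```python
import jax
import jax.numpy as jnp
from jax.experimental import enable_x64

# Outside the context manager, default integer arrays are 32-bit
# (assuming jax_enable_x64 is not set globally).
default_dtype = jnp.arange(3).dtype

# Inside it, X64 mode is temporarily enabled.
with enable_x64():
    x64_dtype = jnp.arange(3).dtype

# segment_sum: ID 5 lies outside [0, num_segments) and is dropped,
# so the value 4.0 does not contribute to any segment.
data = jnp.array([1.0, 2.0, 3.0, 4.0])
segment_ids = jnp.array([0, 1, 1, 5])
sums = jax.ops.segment_sum(data, segment_ids, num_segments=3)
print(sums)  # [1. 5. 0.]
```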

jaxlib 0.1.59 (January 15 2021)

0.2.8

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.7...jax-v0.2.8).
* New features:
* Add {func}`jax.closure_convert` for use with higher-order custom
derivative functions. ({jax-issue}`5244`)
* Add {func}`jax.experimental.host_callback.call` to call a custom Python
function on the host and return a result to the device computation.
({jax-issue}`5243`)
* Bug fixes:
* `jax.numpy.arccosh` now returns the same branch as `numpy.arccosh` for
complex inputs ({jax-issue}`5156`)
* `host_callback.id_tap` now also works with `jax.pmap`. There is an
optional parameter for `id_tap` and `id_print` to request that the
device from which the value is tapped be passed as a keyword argument
to the tap function ({jax-issue}`5182`).
* Breaking changes:
* `jax.numpy.pad` now takes keyword arguments: `constant_values` can no longer be
passed positionally. In addition, passing unsupported keyword arguments raises an error.
* Changes for {func}`jax.experimental.host_callback.id_tap` ({jax-issue}`5243`):
* Removed support for `kwargs` for {func}`jax.experimental.host_callback.id_tap`.
(This support has been deprecated for a few months.)
* Changed the printing of tuples for {func}`jax.experimental.host_callback.id_print`
to use '(' instead of '['.
* Changed the behavior of {func}`jax.experimental.host_callback.id_print` in the presence of JVP
to print a pair of primal and tangent. Previously, there were two separate
print operations for the primal and the tangent.
* `host_callback.outfeed_receiver` has been removed (it is not necessary,
and was deprecated a few months ago).
* New features:
* New flag for debugging `inf`, analogous to that for `NaN` ({jax-issue}`5224`).
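
A rough sketch of three 0.2.8 items: `jax.closure_convert` hoisting a closed-over traced value into explicit arguments, `jax.numpy.pad` with `constant_values` passed by keyword, and the `inf` debugging flag. The entry does not name the flag; we assume it is the `jax_debug_infs` config option, the counterpart of `jax_debug_nans`. The functions `outer` and `inner` are hypothetical examples.

```python
import jax
import jax.numpy as jnp

# closure_convert: `inner` closes over `scale`, which is a tracer when `outer`
# is differentiated. closure_convert returns an equivalent function that takes
# the hoisted constants as extra trailing arguments, plus those constants.
def outer(scale):  # hypothetical example function
    def inner(x):
        return scale * jnp.sin(x)
    converted, consts = jax.closure_convert(inner, 1.0)
    return converted(2.0, *consts)

grad_scale = jax.grad(outer)(3.0)  # d/dscale of scale * sin(2.0)

# jnp.pad: constant_values must be passed as a keyword argument.
padded = jnp.pad(jnp.array([1, 2, 3]), 1, constant_values=7)
print(padded)  # [7 1 2 3 7]

# Debugging infs (assumed flag name): an operation producing inf raises.
jax.config.update("jax_debug_infs", True)
try:
    jnp.array(1.0) / jnp.array(0.0)
    raised = False
except FloatingPointError:
    raised = True
finally:
    jax.config.update("jax_debug_infs", False)
```

Resetting the flag in `finally` keeps the check from leaking into unrelated code, since it adds runtime overhead to every operation while enabled.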

0.2.7

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.6...jax-v0.2.7).
* New features:
* Add `jax.device_put_replicated`
* Add multi-host support to `jax.experimental.sharded_jit`
* Add support for differentiating eigenvalues computed by `jax.numpy.linalg.eig`
* Add support for building on Windows platforms
* Add support for general in_axes and out_axes in `jax.pmap`
* Add complex support for `jax.numpy.linalg.slogdet`
* Bug fixes:
* Fix higher-than-second order derivatives of `jax.numpy.sinc` at zero
* Fix some hard-to-hit bugs around symbolic zeros in transpose rules
* Breaking changes:
* `jax.experimental.optix` has been deleted, in favor of the standalone
`optax` Python package.
* Indexing JAX arrays with non-tuple sequences now raises a `TypeError`. This type of indexing
has been deprecated in NumPy since v1.16, and in JAX since v0.2.4.
See {jax-issue}`4564`.
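
The following sketch, assuming a recent JAX release, shows the indexing change and complex support in `slogdet`: a list containing a slice raises `TypeError`, while the equivalent tuple index works, and `slogdet` of a complex matrix returns a complex `sign`. The arrays are arbitrary examples.

```python
import jax.numpy as jnp

x = jnp.arange(6).reshape(2, 3)

# Indexing with a non-tuple sequence is rejected...
try:
    x[[slice(None), 0]]
    list_index_raised = False
except TypeError:
    list_index_raised = True

# ...while the equivalent tuple index selects the first column.
first_col = x[(slice(None), 0)]
print(first_col)  # [0 3]

# slogdet on a complex matrix: det = 1j * 2 = 2j, so |det| = 2 and
# sign = det / |det| = 1j.
a = jnp.array([[1j, 0], [0, 2]], dtype=jnp.complex64)
sign, logabsdet = jnp.linalg.slogdet(a)
```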

0.2.6

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.5...jax-v0.2.6).
* New features:
* Add support for shape-polymorphic tracing for the `jax.experimental.jax2tf` converter.
See [README.md](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).
* Breaking change cleanup

* Raise an error on non-hashable static arguments for `jax.jit` and
`xla_computation`. See [cb48f42](https://github.com/google/jax/commit/cb48f42).
* Improve consistency of type promotion behavior ({jax-issue}`4744`):
* Adding a complex Python scalar to a JAX floating point number respects the precision of
the JAX float. For example, `jnp.float32(1) + 1j` now returns `complex64`, where previously
it returned `complex128`.
* Results of type promotion with 3 or more terms involving uint64, a signed int, and a third type
are now independent of the order of arguments. For example:
`jnp.result_type(jnp.uint64, jnp.int64, jnp.float16)` and
`jnp.result_type(jnp.float16, jnp.uint64, jnp.int64)` both return `float16`, where previously
the first returned `float64` and the second returned `float16`.
* The contents of the (undocumented) `jax.lax_linalg` linear algebra module
are now exposed publicly as `jax.lax.linalg`.
* `jax.random.PRNGKey` now produces the same results in and out of JIT compilation
({jax-issue}`4877`).
This required changing the result for a given seed in a few particular cases:
* With `jax_enable_x64=False`, negative seeds passed as Python integers now return a different result
outside JIT mode. For example, `jax.random.PRNGKey(-1)` previously returned
`[4294967295, 4294967295]`, and now returns `[0, 4294967295]`. This matches the behavior in JIT.
* Seeds outside the range representable by `int64` outside JIT now result in an `OverflowError`
rather than a `TypeError`. This matches the behavior in JIT.

To recover the keys returned previously for negative integers with `jax_enable_x64=False`
outside JIT, you can use:

```python
from jax import random
key = random.PRNGKey(-1).at[0].set(0xFFFFFFFF)
```

* DeviceArray now raises `RuntimeError` instead of `ValueError` when trying
to access its value after it has been deleted.
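
A minimal check of the 0.2.6 promotion and PRNG changes, assuming default 32-bit mode. It confirms that adding a complex Python scalar to a `float32` value yields `complex64`, that the two three-term `result_type` calls from the entry above agree regardless of argument order, and that `PRNGKey` gives the same key in and out of `jax.jit` for a non-negative seed (42 is arbitrary).

```python
import jax
import jax.numpy as jnp
from jax import random

# A complex Python scalar is weakly typed, so float32 precision is kept.
z = jnp.float32(1) + 1j
print(z.dtype)  # complex64

# Three-term promotion involving uint64 and a signed int is order-independent.
r1 = jnp.result_type(jnp.uint64, jnp.int64, jnp.float16)
r2 = jnp.result_type(jnp.float16, jnp.uint64, jnp.int64)
print(r1, r2)  # float16 float16

# PRNGKey produces the same raw key eagerly and under jit.
eager_key = random.PRNGKey(42)
jit_key = jax.jit(random.PRNGKey)(42)
```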

jaxlib 0.1.58 (January 12ish 2021)

* Fixed a bug that meant JAX sometimes returned platform-specific types (e.g.,
`np.cint`) instead of standard types (e.g., `np.int32`). (4903)
* Fixed a crash when constant-folding certain int16 operations. (4971)
* Added an `is_leaf` predicate to {func}`pytree.flatten`.
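
In current JAX, the same kind of predicate is accepted by the public `jax.tree_util.tree_flatten` function. A short sketch with an arbitrary example tree, treating tuples as leaves so that they are not flattened further:

```python
import jax

tree = {"a": (1, 2), "b": [3, (4, 5)]}

# Without is_leaf, every scalar becomes a leaf.
all_leaves, _ = jax.tree_util.tree_flatten(tree)
print(all_leaves)  # [1, 2, 3, 4, 5]

# With is_leaf, tuples are kept whole; dict keys are visited in sorted order.
tuple_leaves, treedef = jax.tree_util.tree_flatten(
    tree, is_leaf=lambda node: isinstance(node, tuple)
)
print(tuple_leaves)  # [(1, 2), 3, (4, 5)]
```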

jaxlib 0.1.57 (November 12 2020)

* Fixed manylinux2010 compliance issues in GPU wheels.
* Switched the CPU FFT implementation from Eigen to PocketFFT.
* Fixed a bug where the hash of bfloat16 values was not correctly initialized
and could change (4651).
* Add support for retaining ownership when passing arrays to DLPack (4636).
* Fixed a bug for batched triangular solves with sizes greater than 128 but not
a multiple of 128.
* Fixed a bug when performing concurrent FFTs on multiple GPUs (3518).
* Fixed a bug in the profiler where tools were missing (4427).
* Dropped support for CUDA 10.0.

0.2.5

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.4...jax-v0.2.5).
* Improvements:
* Ensure that `check_jaxpr` does not perform FLOPs. See {jax-issue}`4650`.
* Expanded the set of JAX primitives converted by jax2tf.
See [primitives_with_limited_support.md](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/primitives_with_limited_support.md).
