Jax

Latest version: v0.4.29



Page 12 of 18

0.2.16

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.15...jax-v0.2.16).

0.2.15

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.14...jax-v0.2.15).
* New features:
* [7042](https://github.com/google/jax/pull/7042) Turned on the TFRT CPU backend,
with significant dispatch performance improvements on CPU.
* {func}`jax2tf.convert` now supports inequalities and min/max for booleans
({jax-issue}`6956`).
* New SciPy function {py:func}`jax.scipy.special.lpmn_values`.

* Breaking changes:
* Support for NumPy 1.16 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).

* Bug fixes:
* Fixed a bug that prevented round-tripping from JAX to TF and back:
`jax2tf.call_tf(jax2tf.convert)` ({jax-issue}`6947`).

jaxlib 0.1.68 (June 23 2021)
* Bug fixes:
* Fixed a bug in the TFRT CPU backend that produced NaNs when transferring a
TPU buffer to CPU.

0.2.14

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.13...jax-v0.2.14).
* New features:
* {func}`jax2tf.convert` now supports `pjit` and `sharded_jit`.
* A new configuration option, `JAX_TRACEBACK_FILTERING`, controls how JAX
filters tracebacks.
* A new traceback filtering mode using `__tracebackhide__` is now enabled by
default in sufficiently recent versions of IPython.
* {func}`jax2tf.convert` supports shape polymorphism even when the
unknown dimensions are used in arithmetic operations, e.g., `jnp.reshape(-1)`
({jax-issue}`6827`).
* {func}`jax2tf.convert` adds custom attributes carrying location information
to the TF ops it generates, so the code that XLA generates after jax2tf
has the same location information as JAX/XLA.
* New SciPy function {py:func}`jax.scipy.special.lpmn`.

* Bug fixes:
* {func}`jax2tf.convert` now uses the same typing rules as JAX for Python
scalars and for choosing between 32-bit and 64-bit computations
({jax-issue}`6883`).
* {func}`jax2tf.convert` now scopes the `enable_xla` conversion parameter
properly, so that it applies only during the just-in-time conversion
({jax-issue}`6720`).
* {func}`jax2tf.convert` now converts `lax.dot_general` using the
`XlaDot` TensorFlow op, for better fidelity with respect to JAX numerical
precision ({jax-issue}`6717`).
* {func}`jax2tf.convert` now supports inequality comparisons and
min/max for complex numbers ({jax-issue}`6892`).
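The new traceback filtering option can also be set from Python through the JAX config. The sketch below assumes a recent JAX release where the option is exposed as the `jax_traceback_filtering` config flag and accepts the value `"off"`; the function `bad_reshape` is a hypothetical example that fails at trace time so there is a traceback to inspect.

```python
import traceback

import jax
import jax.numpy as jnp

# Equivalent to running with JAX_TRACEBACK_FILTERING=off in the environment:
# keep JAX-internal frames in tracebacks instead of filtering them out.
jax.config.update("jax_traceback_filtering", "off")


@jax.jit
def bad_reshape(x):
    # Hypothetical helper: a 4-element array cannot be reshaped to (3,),
    # so tracing raises an error.
    return jnp.reshape(x, (3,))


try:
    bad_reshape(jnp.ones(4))
    raised = False
except Exception as err:
    raised = True
    frames = traceback.extract_tb(err.__traceback__)
    # With filtering off, the frame count includes JAX-internal frames.
    print(f"{type(err).__name__} with {len(frames)} traceback frames")
```

Switching the value back to `"auto"` restores the default filtering behavior.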

jaxlib 0.1.67 (May 17 2021)

jaxlib 0.1.66 (May 11 2021)

* New features:
* CUDA 11.1 wheels now work with all CUDA 11 versions from 11.1 onward.

NVIDIA now promises compatibility between CUDA minor releases starting with
CUDA 11.1. This means that JAX can release a single CUDA 11.1 wheel that
is compatible with CUDA 11.2 and 11.3.

There is no longer a separate jaxlib release for CUDA 11.2 (or higher); use
the CUDA 11.1 wheel (cuda111) for those versions.
* Jaxlib now bundles `libdevice.10.bc` in CUDA wheels. There should be no need
to point JAX to a CUDA installation to find this file.
* Added automatic support for static keyword arguments to the {func}`jit`
implementation.
* Added support for pretransformation exception traces.
* Initial support for pruning unused arguments from {func}`jit`-transformed
computations.
Pruning is still a work in progress.
* Improved the string representation of {class}`PyTreeDef` objects.
* Added support for XLA's variadic ReduceWindow.
* Bug fixes:
* Fixed a bug in remote Cloud TPU support when large numbers of arguments
are passed to a computation.
* Fixed a bug that meant that JAX garbage collection was not triggered by
{func}`jit`-transformed functions.
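The improved `PyTreeDef` string representation is easiest to see by flattening a small pytree. This is a minimal sketch using `jax.tree_util`; the `params` dictionary is a hypothetical example, and the exact printed form of the treedef differs between versions, so no output is shown.

```python
import jax

# Hypothetical nested container: a dict holding a tuple and a scalar.
params = {"w": (1.0, 2.0), "b": 3.0}

# Flattening yields the leaves (dict keys are visited in sorted order)
# plus a PyTreeDef describing the container structure.
leaves, treedef = jax.tree_util.tree_flatten(params)
print(treedef)  # human-readable structure; format varies by release

# The treedef is enough to rebuild the original structure from the leaves.
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
```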

0.2.13

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.12...jax-v0.2.13).
* New features:
* When combined with jaxlib 0.1.66, {func}`jax.jit` now supports static
keyword arguments. A new `static_argnames` option has been added to specify
keyword arguments as static.
* {func}`jax.numpy.nonzero` has a new optional `size` argument that allows it to
be used within `jit` ({jax-issue}`6501`).
* {func}`jax.numpy.unique` now supports the `axis` argument ({jax-issue}`6532`).
* {func}`jax.experimental.host_callback.call` now supports `pjit.pjit` ({jax-issue}`6569`).
* Added {func}`jax.scipy.linalg.eigh_tridiagonal` that computes the
eigenvalues of a tridiagonal matrix. Only eigenvalues are supported at
present.
* The order of the filtered and unfiltered stack traces in exceptions has been
changed. The traceback attached to an exception thrown from JAX-transformed
code is now filtered, with an `UnfilteredStackTrace` exception
containing the original trace as the `__cause__` of the filtered exception.
Filtered stack traces now also work with Python 3.6.
* If an exception is thrown by code that has been transformed by reverse-mode
automatic differentiation, JAX now attempts to attach as a `__cause__` of
the exception a `JaxStackTraceBeforeTransformation` object that contains the
stack trace that created the original operation in the forward pass.
Requires jaxlib 0.1.66.

* Breaking changes:
* The following function names have changed. There are still aliases, so this
should not break existing code, but the aliases will eventually be removed,
so please update your code.
* `host_id` --> {func}`~jax.process_index`
* `host_count` --> {func}`~jax.process_count`
* `host_ids` --> `range(jax.process_count())`
* Similarly, the argument to {func}`~jax.local_devices` has been renamed from
`host_id` to `process_index`.
* Arguments to {func}`jax.jit` other than the function are now marked as
keyword-only. This change is to prevent accidental breakage when arguments
are added to `jit`.
* Bug fixes:
* {func}`jax2tf.convert` now works in the presence of gradients for functions
with integer inputs ({jax-issue}`6360`).
* Fixed assertion failure in {func}`jax2tf.call_tf` when used with captured
`tf.Variable` ({jax-issue}`6572`).
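Several of the 0.2.13 changes can be exercised together. The sketch below assumes a current JAX release running in a single process; `scaled_power` and `first_nonzero_indices` are hypothetical helpers. It uses `static_argnames` (passed by keyword, since `jit` options are now keyword-only), `jnp.nonzero` with a fixed `size` inside `jit`, `jnp.unique` with `axis`, and the renamed `process_index`/`process_count` functions.

```python
from functools import partial

import jax
import jax.numpy as jnp


# `power` is listed in static_argnames, so it is a compile-time constant and
# the Python loop below is unrolled during tracing.
@partial(jax.jit, static_argnames=("power",))
def scaled_power(x, *, power):
    result = jnp.ones_like(x)
    for _ in range(power):
        result = result * x
    return result


@jax.jit
def first_nonzero_indices(x):
    # A fixed `size` gives the output a static shape, which jit requires;
    # unused slots are padded with the default fill value of 0.
    (idx,) = jnp.nonzero(x, size=3)
    return idx


cubes = scaled_power(jnp.arange(4.0), power=3)
idx = first_nonzero_indices(jnp.array([0, 5, 0, 7]))
rows = jnp.unique(jnp.array([[1, 2], [1, 2], [3, 4]]), axis=0)

# Renamed from host_id / host_count.
print(jax.process_index(), jax.process_count())
```

Here `cubes` is `[0, 1, 8, 27]`, `idx` is `[1, 3, 0]` (two real indices plus one padding entry), and `rows` keeps only the distinct rows `[[1, 2], [3, 4]]`.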

jaxlib 0.1.65 (April 7 2021)

0.2.12

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.11...v0.2.12).
* New features:
* New profiling APIs: {func}`jax.profiler.start_trace`,
{func}`jax.profiler.stop_trace`, and {func}`jax.profiler.trace`.
* {func}`jax.lax.reduce` is now differentiable.
* Breaking changes:
* The minimum jaxlib version is now 0.1.64.
* Some profiler API names have been changed. There are still aliases, so this
should not break existing code, but the aliases will eventually be removed,
so please update your code.
* `TraceContext` --> {func}`~jax.profiler.TraceAnnotation`
* `StepTraceContext` --> {func}`~jax.profiler.StepTraceAnnotation`
* `trace_function` --> {func}`~jax.profiler.annotate_function`
* Omnistaging can no longer be disabled. See [omnistaging](https://github.com/google/jax/blob/main/docs/design_notes/omnistaging.md)
for more information.
* Python integers larger than the maximum `int64` value will now lead to an overflow
in all cases, rather than being silently converted to `uint64` in some cases ({jax-issue}`6047`).
* Outside X64 mode, Python integers outside the range representable by `int32` will now lead to an
`OverflowError` rather than having their value silently truncated.
* Bug fixes:
* `host_callback` now supports empty arrays in arguments and results ({jax-issue}`6262`).
* {func}`jax.random.randint` now clips out-of-bounds limits rather than wrapping
them, and can now generate integers in the full range of the specified dtype ({jax-issue}`5868`).
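A few of the 0.2.12 changes can be checked directly. This is a minimal sketch assuming a current JAX release with 64-bit mode disabled (the default); `product` is a hypothetical helper. It differentiates through `lax.reduce`, draws `int8` samples with `randint` using a `maxval` above the dtype's maximum, and shows the `OverflowError` for a Python integer that does not fit in `int32`.

```python
import jax
import jax.numpy as jnp
from jax import lax


def product(x):
    # Multiplicative reduction over axis 0 written with lax.reduce.
    return lax.reduce(x, jnp.array(1.0, dtype=x.dtype), lax.mul, (0,))


# d/dx_i prod(x) = prod(x) / x_i, so for [1, 2, 3, 4] this is [24, 12, 8, 6].
grad_product = jax.grad(product)(jnp.array([1.0, 2.0, 3.0, 4.0]))

# maxval=128 exceeds the int8 maximum of 127; it is clipped rather than
# wrapped, so samples cover the full int8 range.
key = jax.random.PRNGKey(0)
samples = jax.random.randint(key, (1000,), minval=-128, maxval=128, dtype=jnp.int8)

# Outside X64 mode, a Python int beyond the int32 range raises instead of
# being silently truncated.
try:
    jnp.asarray(2**40)
    overflowed = False
except OverflowError:
    overflowed = True
```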

0.2.11

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.10...jax-v0.2.11).
* New features:
* [6112](https://github.com/google/jax/pull/6112) added context managers:
`jax.enable_checks`, `jax.check_tracer_leaks`, `jax.debug_nans`,
`jax.debug_infs`, `jax.log_compiles`.
* [6085](https://github.com/google/jax/pull/6085) added `jnp.delete`.

* Bug fixes:
* [6136](https://github.com/google/jax/pull/6136) generalized
`jax.flatten_util.ravel_pytree` to handle integer dtypes.
* [6129](https://github.com/google/jax/issues/6129) fixed a bug with handling
some constants like `enum.IntEnums`.
* [6145](https://github.com/google/jax/pull/6145) fixed batching issues with
incomplete beta functions.
* [6014](https://github.com/google/jax/pull/6014) fixed host-to-device (H2D)
transfers during tracing.
* [6165](https://github.com/google/jax/pull/6165) avoids `OverflowError`s when
converting some large Python integers to floats.
* Breaking changes:
* The minimum jaxlib version is now 0.1.62.
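The 0.2.11 additions can be tried in a few lines. This sketch assumes a current JAX release, where `jax.debug_nans` is still available as a context manager; the `tree` dictionary is a hypothetical example. It deletes an element with `jnp.delete`, round-trips an integer pytree through `ravel_pytree`, and uses `jax.debug_nans` to turn a NaN-producing operation into an error.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Remove the element at index 1: [0, 1, 2, 3, 4] -> [0, 2, 3, 4].
trimmed = jnp.delete(jnp.arange(5), 1)

# ravel_pytree now accepts integer leaves; flat is a 1-D integer array.
tree = {"a": jnp.array([1, 2]), "b": jnp.array([3])}
flat, unravel = ravel_pytree(tree)
restored = unravel(flat)

# Inside the context manager, an operation that yields NaN raises
# FloatingPointError instead of silently returning NaN.
with jax.debug_nans(True):
    try:
        jnp.log(jnp.array(-1.0))
        caught_nan = False
    except FloatingPointError:
        caught_nan = True
```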


jaxlib 0.1.64 (March 18 2021)

jaxlib 0.1.63 (March 17 2021)

