* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.5...jax-v0.2.6).
* New Features:
  * Add support for shape-polymorphic tracing for the `jax.experimental.jax2tf` converter.
    See [README.md](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md).
* Breaking change cleanup
  * Raise an error on non-hashable static arguments for `jax.jit` and
    `xla_computation`. See [cb48f42](https://github.com/jax-ml/jax/commit/cb48f42).
  * Improve consistency of type promotion behavior ({jax-issue}`4744`):
    * Adding a complex Python scalar to a JAX floating point number respects the precision of
      the JAX float. For example, `jnp.float32(1) + 1j` now returns `complex64`, where previously
      it returned `complex128`.
    * Results of type promotion with 3 or more terms involving uint64, a signed int, and a third type
      are now independent of the order of arguments. For example,
      `jnp.result_type(jnp.uint64, jnp.int64, jnp.float16)` and
      `jnp.result_type(jnp.float16, jnp.uint64, jnp.int64)` both return `float16`, where previously
      the first returned `float64` and the second returned `float16`.
  * The contents of the (undocumented) `jax.lax_linalg` linear algebra module
    are now exposed publicly as `jax.lax.linalg`.
  * `jax.random.PRNGKey` now produces the same results in and out of JIT compilation
    ({jax-issue}`4877`).
    This required changing the result for a given seed in a few particular cases:
    * With `jax_enable_x64=False`, negative seeds passed as Python integers now return a different result
      outside JIT mode. For example, `jax.random.PRNGKey(-1)` previously returned
      `[4294967295, 4294967295]`, and now returns `[0, 4294967295]`. This matches the behavior in JIT.
    * Seeds outside the range representable by `int64` outside JIT now result in an `OverflowError`
      rather than a `TypeError`. This matches the behavior in JIT.

    To recover the keys returned previously for negative integers with `jax_enable_x64=False`
    outside JIT, you can use:

    ```python
    from jax import random
    key = random.PRNGKey(-1).at[0].set(0xFFFFFFFF)
    ```
  * `DeviceArray` now raises `RuntimeError` instead of `ValueError` when its value
    is accessed after the array has been deleted.
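The non-hashable static argument check can be seen with a small `jax.jit` example. This is a minimal sketch assuming a recent JAX install; `scale` is a hypothetical function, and because the exact exception class for a non-hashable static argument has varied between versions, the sketch catches both `TypeError` and `ValueError`.

```python
import jax
import jax.numpy as jnp

def scale(factors, x):
    # Hypothetical function: `factors` is marked static below, so JAX hashes it
    # and uses it as part of the compilation cache key.
    return x * sum(factors)

scale_jit = jax.jit(scale, static_argnums=0)
x = jnp.ones(3)

try:
    scale_jit([1.0, 2.0], x)  # a list is not hashable, so this is rejected
    list_rejected = False
except (TypeError, ValueError) as err:  # exact exception type is version-dependent
    list_rejected = True
    print(type(err).__name__, err)

# A tuple is hashable, so it is accepted as a static argument.
result = scale_jit((1.0, 2.0), x)
print(result)
```

Tuples (or other hashable, immutable values) are the usual substitute for lists when an argument must be static.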
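Both type promotion changes can be checked directly. A minimal sketch, assuming a recent JAX install in which the post-change behavior described above still holds:

```python
import jax.numpy as jnp
import numpy as np

# A complex Python scalar is weakly typed, so it adopts the precision of the JAX float.
z = jnp.float32(1) + 1j
print(z.dtype)

# uint64, a signed int, and float16: the result no longer depends on argument order.
r1 = jnp.result_type(jnp.uint64, jnp.int64, jnp.float16)
r2 = jnp.result_type(jnp.float16, jnp.uint64, jnp.int64)
print(r1, r2)
```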
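The in/out-of-JIT consistency of `jax.random.PRNGKey` can be spot-checked by comparing an eager call with a jitted one. A minimal sketch assuming a recent JAX install; the exact key for a negative seed depends on `jax_enable_x64`, so the sketch only checks that the two paths agree, then applies the recovery trick from the entry above.

```python
import jax
import jax.numpy as jnp
from jax import random

seed = -1
eager_key = random.PRNGKey(seed)          # computed outside JIT
jit_key = jax.jit(random.PRNGKey)(seed)   # seed is traced inside JIT
same = bool(jnp.array_equal(eager_key, jit_key))
print(eager_key, jit_key, same)

# Reproduce the pre-change eager result for a negative seed (jax_enable_x64=False).
old_style_key = random.PRNGKey(-1).at[0].set(0xFFFFFFFF)
print(old_style_key)
```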
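`DeviceArray` has since been replaced by `jax.Array`, but the deleted-buffer behavior can still be illustrated. A minimal sketch, assuming a recent JAX install where `jax.Array.delete()` frees the underlying buffer and a later read of the value raises `RuntimeError`:

```python
import jax.numpy as jnp
import numpy as np

x = jnp.arange(3)
x.delete()  # explicitly free the device buffer

try:
    np.asarray(x)  # reading the value of a deleted array
    error_name = None
except RuntimeError as err:
    error_name = type(err).__name__
    print(error_name, err)
```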
## jaxlib 0.1.58 (January 12ish 2021)
* Fixed a bug that meant JAX sometimes returned platform-specific types (e.g.,
  `np.cint`) instead of standard types (e.g., `np.int32`) ({jax-issue}`4903`).
* Fixed a crash when constant-folding certain int16 operations ({jax-issue}`4971`).
* Added an `is_leaf` predicate to {func}`pytree.flatten`.
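In Python, the `is_leaf` predicate is reachable through `jax.tree_util.tree_flatten`. A minimal sketch assuming a recent JAX install, using a hypothetical example tree in which tuples should be kept whole rather than flattened:

```python
import jax

tree = {"a": (1, 2), "b": [3, (4, 5)]}

# Default: flatten all the way down to scalars.
default_leaves = jax.tree_util.tree_leaves(tree)

# With is_leaf: stop descending at tuples and treat each tuple as a single leaf.
leaves, treedef = jax.tree_util.tree_flatten(
    tree, is_leaf=lambda node: isinstance(node, tuple)
)

# The tree structure round-trips with the coarser leaves.
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
print(default_leaves, leaves, rebuilt)
```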
## jaxlib 0.1.57 (November 12 2020)
* Fixed manylinux2010 compliance issues in GPU wheels.
* Switched the CPU FFT implementation from Eigen to PocketFFT.
* Fixed a bug where the hash of bfloat16 values was not correctly initialized
  and could change ({jax-issue}`4651`).
* Added support for retaining ownership when passing arrays to DLPack ({jax-issue}`4636`).
* Fixed a bug for batched triangular solves with sizes greater than 128 but not
a multiple of 128.
* Fixed a bug when performing concurrent FFTs on multiple GPUs ({jax-issue}`3518`).
* Fixed a bug in the profiler where tools were missing ({jax-issue}`4427`).
* Dropped support for CUDA 10.0.
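The batched triangular solve fix, together with the `jax.lax.linalg` module made public in jax 0.2.6, can be exercised with matrices of size 130 (greater than 128 and not a multiple of it). A minimal sketch assuming a recent JAX install; the matrix construction is an arbitrary choice that keeps the systems well conditioned, and the check uses the residual `a @ x - b` rather than a reference solution.

```python
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax.lax import linalg as lax_linalg

rng = np.random.default_rng(0)
batch, n = 4, 130  # 130 > 128 and not a multiple of 128

# Lower-triangular systems: identity plus small entries strictly below the diagonal.
noise = jnp.asarray(rng.random((batch, n, n), dtype=np.float32)) / n
a = jnp.tril(noise, k=-1) + jnp.eye(n, dtype=jnp.float32)
b = jnp.asarray(rng.random((batch, n, 3), dtype=np.float32))

# Solve a @ x = b for every matrix in the batch.
x = lax_linalg.triangular_solve(a, b, left_side=True, lower=True)

# Maximum absolute residual across the whole batch.
residual = jnp.max(jnp.abs(jnp.matmul(a, x, precision=lax.Precision.HIGHEST) - b))
print(x.shape, residual)
```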