## Major Features
Intel® Extension for OpenXLA* is an Intel-optimized PyPI package that extends the official [OpenXLA](https://github.com/openxla/xla) on Intel GPUs. It is based on the [PJRT](https://opensource.googleblog.com/2023/05/pjrt-simplifying-ml-hardware-and-framework-integration.html) plugin mechanism, which enables [JAX](https://jax.readthedocs.io/en/latest/index.html) models to run seamlessly on [Intel® Data Center GPU Max Series](https://www.intel.com/content/www/us/en/products/details/discrete-gpus/data-center-gpu/max-series.html) and [Intel® Data Center GPU Flex Series](https://www.intel.com/content/www/us/en/products/details/discrete-gpus/data-center-gpu/flex-series.html). This release contains the following major features:
- **JAX Upgrade:** Upgrades JAX to **v0.4.30** and supports the [compatibility](https://docs.google.com/document/d/1TKB5NyGtdzrpgw5mpyFjVAhJjpSNdF31T6pjPl_UT2o/edit#heading=h.yukvacyl63d2) scheme of `jax` and `jaxlib`, which allows the Extension to work with multiple `jax` versions. Refer to [How are jax and jaxlib versioned?](https://jax.readthedocs.io/en/latest/jep/9419-jax-versioning.html#how-are-jax-and-jaxlib-versioned) for more details on how `jax` and `jaxlib` versions relate.
|**intel-extension-for-openxla**|**jaxlib**|**jax**|
|:-:|:-:|:-:|
| 0.5.0 | 0.4.30 | >= 0.4.30, <= 0.4.31|
- **Feature Support:**
  - Support for Python 3.9, 3.10, 3.11, and 3.12.
  - Continued improvements to `jax` native distributed scale-up collectives.
  - Improved accuracy for GPT-J with different layer counts.
  - Continued improvements to FMHA backward fusion support.
- **Bug Fix:**
  - Fix a forward MHA accuracy error.
  - Fix **known caveat**: an in-place update error on the Stable Diffusion model.
  - Fix **known caveat**: a deadlock-related hang when working with **Toolkit 2025.0**.
  - Fix **known caveat**: unit test failures with the latest graphics driver.
  - Fix **known caveat**: an OOM related to the deprecated `clear_backends` API.
- **Toolkit Support:** Support [Intel® oneAPI Base Toolkit 2025.0](https://www.intel.com/content/www/us/en/developer/articles/release-notes/intel-oneapi-toolkit-release-notes.html).
- **Driver Support:** Support for the upgraded driver [LTS release 2350.125](https://dgpu-docs.intel.com/releases/LTS-release-notes.html#release-2024-12-04.html).
- **oneDNN Support:** Support [oneDNN v3.6.1](https://github.com/oneapi-src/oneDNN/releases/tag/v3.6.1).
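As a quick sanity check against the version table above, the pairing constraint can be sketched in plain Python. The version bounds come from the table; the helper names and the use of `importlib.metadata` are illustrative only and are not part of the Extension's API:

```python
# Minimal sketch: check whether an installed jax falls inside the range
# supported by intel-extension-for-openxla 0.5.0 (>= 0.4.30, <= 0.4.31,
# per the compatibility table). Helper names are hypothetical.
from importlib import metadata


SUPPORTED_MIN = (0, 4, 30)  # jax >= 0.4.30
SUPPORTED_MAX = (0, 4, 31)  # jax <= 0.4.31


def parse_version(version: str) -> tuple:
    """Turn a 'major.minor.patch' string into a comparable int tuple."""
    return tuple(int(part) for part in version.split(".")[:3])


def jax_version_supported() -> bool:
    """Return True when the installed jax version is inside the supported range."""
    try:
        installed = parse_version(metadata.version("jax"))
    except metadata.PackageNotFoundError:
        return False  # jax is not installed at all
    return SUPPORTED_MIN <= installed <= SUPPORTED_MAX
```

For example, `jax` 0.4.26 (the version paired with the previous release) fails this check, while 0.4.30 and 0.4.31 pass.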
## Known Caveats
- The Flan-T5 and Gemma models depend on TensorFlow Text, which does not support Python 3.12.
- [Multi-process API](https://jax.readthedocs.io/en/latest/multi_process.html) support is still experimental and may cause hang issues with collectives.
## Breaking Changes
- The previous JAX version, **v0.4.26**, is no longer supported by this release. If you encounter version errors, follow the [JAX change log](https://jax.readthedocs.io/en/latest/changelog.html) to update your application, or roll back the Extension to a release that supports your older JAX version.
## Documents
- [Introduction to Intel® Extension for OpenXLA*](https://github.com/intel/intel-extension-for-openxla/blob/main/README.md#intel-extension-for-openxla)
- [Accelerate JAX models on Intel GPUs via PJRT](https://opensource.googleblog.com/2023/06/accelerate-jax-models-on-intel-gpus-via-pjrt.html)
- [How JAX and OpenXLA Enabled an Argonne Workload and Quality Assurance on the Aurora Supercompute](https://www.intel.com/content/www/us/en/developer/articles/technical/jax-openxla-argonne-workload-and-quality-assurance.html?wapkw=openxla)
- [JAX and OpenXLA* Part 1: Run Process and Underlying Logic](https://www.intel.com/content/www/us/en/developer/articles/technical/jax-and-openxla-run-process-and-underlying-logic-1.html?wapkw=openxla)
- [JAX and OpenXLA Part 2: Run Process and Underlying Logic](https://www.intel.com/content/www/us/en/developer/articles/technical/jax-and-openxla-run-process-and-underlying-logic-2.html)
- [How are jax and jaxlib versioned?](https://jax.readthedocs.io/en/latest/jep/9419-jax-versioning.html#how-are-jax-and-jaxlib-versioned)