Intel-extension-for-openxla

Latest version: v0.5.0

0.5.0

Major Features

Intel® Extension for OpenXLA* is an Intel-optimized PyPI package that extends the official [OpenXLA](https://github.com/openxla/xla) to Intel GPUs. It is based on the [PJRT](https://opensource.googleblog.com/2023/05/pjrt-simplifying-ml-hardware-and-framework-integration.html) plugin mechanism, which lets it seamlessly run [JAX](https://jax.readthedocs.io/en/latest/index.html) models on [Intel® Data Center GPU Max Series](https://www.intel.com/content/www/us/en/products/details/discrete-gpus/data-center-gpu/max-series.html) and [Intel® Data Center GPU Flex Series](https://www.intel.com/content/www/us/en/products/details/discrete-gpus/data-center-gpu/flex-series.html). This release contains the following major features:

- **JAX Upgrade:** Upgrade to JAX **v0.4.30** and support the [compatibility](https://docs.google.com/document/d/1TKB5NyGtdzrpgw5mpyFjVAhJjpSNdF31T6pjPl_UT2o/edit#heading=h.yukvacyl63d2) between `jax` and `jaxlib`, which allows the Extension to support multiple versions of `jax`. Please refer to [How are jax and jaxlib versioned?](https://jax.readthedocs.io/en/latest/jep/9419-jax-versioning.html#how-are-jax-and-jaxlib-versioned) for details on how `jax` and `jaxlib` versions relate.
|**intel-extension-for-openxla**|**jaxlib**|**jax**|
|:-:|:-:|:-:|
| 0.5.0 | 0.4.30 | >= 0.4.30, <= 0.4.31|
- **Feature Support:**
  - Support for Python 3.9, 3.10, 3.11, and 3.12.
  - Continue to improve `jax` native distributed scale-up collectives.
  - Support accuracy for GPT-J models with different numbers of layers.
  - Continue to improve support for FMHA backward fusion.
- **Bug Fix:**
  - Fix an accuracy error in forward MHA.
  - Fix the **known caveat** fix-in-place error that occurred on the Stable Diffusion model.
  - Fix the **known caveat** hang caused by a deadlock when working with **Toolkit 2025.0**.
  - Fix the **known caveat** of unit test failures with the latest graphics driver.
  - Fix the **known caveat** OOM related to the deprecated API `clear_backends`.
- **Toolkit Support:** Support [Intel® oneAPI Base Toolkit 2025.0](https://www.intel.com/content/www/us/en/developer/articles/release-notes/intel-oneapi-toolkit-release-notes.html).
- **Driver Support:** Support upgraded driver [LTS release 2350.125](https://dgpu-docs.intel.com/releases/LTS-release-notes.html#release-2024-12-04.html).
- **oneDNN Support:** Support [oneDNN v3.6.1](https://github.com/oneapi-src/oneDNN/releases/tag/v3.6.1).
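
The compatibility tables in these notes can be turned into a quick pre-flight check before installing or upgrading. The sketch below is not part of the Extension: `SUPPORTED_JAX`, `parse_version`, `jax_is_supported`, and `installed_jax_supported` are hypothetical helpers whose ranges are copied from the 0.5.0 and 0.4.0 tables, and the installed `jax` version is read through the standard `importlib.metadata` module.

```python
import importlib.metadata
import re

# Hypothetical helper data: inclusive jax ranges copied from the
# compatibility tables for releases 0.5.0 and 0.4.0 in these notes.
SUPPORTED_JAX = {
    "0.5.0": ("0.4.30", "0.4.31"),  # built against jaxlib 0.4.30
    "0.4.0": ("0.4.26", "0.4.27"),  # built against jaxlib 0.4.26
}


def parse_version(text):
    """Return the leading MAJOR.MINOR.PATCH of a version string as ints.

    Suffixes such as '.dev20240701' or 'rc1' are ignored, which is a
    simplification; packaging.version handles the full PEP 440 grammar.
    """
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", text)
    if match is None:
        raise ValueError(f"unrecognized version string: {text!r}")
    return tuple(int(group) for group in match.groups())


def jax_is_supported(extension, jax_version):
    """True if jax_version lies in the documented range for this release."""
    low, high = SUPPORTED_JAX[extension]
    return parse_version(low) <= parse_version(jax_version) <= parse_version(high)


def installed_jax_supported(extension="0.5.0"):
    """Check the jax distribution installed in the current environment."""
    try:
        installed = importlib.metadata.version("jax")
    except importlib.metadata.PackageNotFoundError:
        return False  # jax is not installed at all
    return jax_is_supported(extension, installed)


if __name__ == "__main__":
    print(jax_is_supported("0.5.0", "0.4.31"))  # True
    print(jax_is_supported("0.5.0", "0.4.26"))  # False: documented for Extension 0.4.0
    print(installed_jax_supported())
```

A `False` result for an older `jax` matches the breaking-change notes: either upgrade `jax` into the documented range or roll back to the Extension release whose table covers it.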


Known Caveats
- FLAN-T5 and Gemma models depend on TensorFlow Text, which doesn't support Python 3.12.
- [Multi-process API](https://jax.readthedocs.io/en/latest/multi_process.html) support is still experimental and may cause hang issues with collectives.


Breaking changes

- JAX **v0.4.26**, supported by the previous release, is no longer supported. Please follow the [JAX changelog](https://jax.readthedocs.io/en/latest/changelog.html) to update your application if you encounter version errors. To keep using an older JAX version, roll back to an earlier version of the Extension.


Documents

- [Introduction to Intel® Extension for OpenXLA*](https://github.com/intel/intel-extension-for-openxla/blob/main/README.md#intel-extension-for-openxla)
- [Accelerate JAX models on Intel GPUs via PJRT](https://opensource.googleblog.com/2023/06/accelerate-jax-models-on-intel-gpus-via-pjrt.html)
- [How JAX and OpenXLA Enabled an Argonne Workload and Quality Assurance on the Aurora Supercomputer](https://www.intel.com/content/www/us/en/developer/articles/technical/jax-openxla-argonne-workload-and-quality-assurance.html?wapkw=openxla)
- [JAX and OpenXLA* Part 1: Run Process and Underlying Logic](https://www.intel.com/content/www/us/en/developer/articles/technical/jax-and-openxla-run-process-and-underlying-logic-1.html?wapkw=openxla)
- [JAX and OpenXLA Part 2: Run Process and Underlying Logic](https://www.intel.com/content/www/us/en/developer/articles/technical/jax-and-openxla-run-process-and-underlying-logic-2.html)
- [How are jax and jaxlib versioned?](https://jax.readthedocs.io/en/latest/jep/9419-jax-versioning.html#how-are-jax-and-jaxlib-versioned)

0.4.0

Major Features

Intel® Extension for OpenXLA* is an Intel-optimized PyPI package that extends the official [OpenXLA](https://github.com/openxla/xla) to Intel GPUs. It is based on the [PJRT](https://opensource.googleblog.com/2023/05/pjrt-simplifying-ml-hardware-and-framework-integration.html) plugin mechanism, which lets it seamlessly run [JAX](https://jax.readthedocs.io/en/latest/index.html) models on [Intel® Data Center GPU Max Series](https://www.intel.com/content/www/us/en/products/details/discrete-gpus/data-center-gpu/max-series.html) and [Intel® Data Center GPU Flex Series](https://www.intel.com/content/www/us/en/products/details/discrete-gpus/data-center-gpu/flex-series.html). This release contains the following major features:

- **JAX Upgrade:** Upgrade to JAX **v0.4.26** and support the [compatibility](https://docs.google.com/document/d/1TKB5NyGtdzrpgw5mpyFjVAhJjpSNdF31T6pjPl_UT2o/edit#heading=h.yukvacyl63d2) between `jax` and `jaxlib`, which allows the Extension to support multiple versions of `jax`. Please refer to [How are jax and jaxlib versioned?](https://jax.readthedocs.io/en/latest/jep/9419-jax-versioning.html#how-are-jax-and-jaxlib-versioned) for details on how `jax` and `jaxlib` versions relate.
|**intel-extension-for-openxla**|**jaxlib**|**jax**|
|:-:|:-:|:-:|
| 0.4.0 | 0.4.26 | >= 0.4.26, <= 0.4.27|
- **Feature Support:**
  - Support [Float8 training and inference](https://keras.io/examples/keras_recipes/float8_training_and_inference_with_transformer/) based on [Keras 3.0](https://keras.io/keras_3/). A new [FP8 case](https://github.com/intel/intel-extension-for-openxla/tree/r0.4/example/fp8) has been added to `example`.
  - Continue to improve `jax` native distributed scale-up collectives. A new distributed scale-up inference case, [Grok](https://github.com/intel/intel-extension-for-openxla/tree/r0.4/example/grok), has been added to `example`.
  - Support FMHA backward fusion on Intel GPUs.
- **Bug Fix:**
  - Fix a crash in the `jax` native [multi-process API](https://jax.readthedocs.io/en/latest/multi_process.html).
  - Fix an accuracy error in dynamic slice fusion.
  - Fix the **known caveat** crash related to binary operations and the SPMD multi-device parallelism API [`psum_scatter`](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html#psum-scatter) under the same `partial` annotation.
  - Fix the **known caveat** hang caused by a deadlock when working with **Toolkit 2024.1**.
  - Fix the **known caveat** OOM related to the deprecated API `clear_backends`.
- **Toolkit Support:** Support [Intel® oneAPI Base Toolkit 2024.2](https://www.intel.com/content/www/us/en/developer/articles/release-notes/intel-oneapi-toolkit-release-notes.html).
- **Driver Support:** Support upgraded driver [LTS release 2350.63](https://dgpu-docs.intel.com/releases/LTS_803.63_20240617.html).
- **oneDNN Support:** Support [oneDNN v3.5.1](https://github.com/oneapi-src/oneDNN/releases/tag/v3.5.1).


Known Caveats
- Some models show performance regressions when working with **Toolkit 2024.2**. We recommend using **Toolkit 2024.1** if you encounter performance issues.
- [Multi-process API](https://jax.readthedocs.io/en/latest/multi_process.html) support is still experimental and may cause hang issues with collectives.


Breaking changes

- JAX **v0.4.24**, supported by the previous release, is no longer supported. Please follow the [JAX changelog](https://jax.readthedocs.io/en/latest/changelog.html) to update your application if you encounter version errors. To keep using an older JAX version, roll back to an earlier version of the Extension.


Documents

- [Introduction to Intel® Extension for OpenXLA*](https://github.com/intel/intel-extension-for-openxla/blob/main/README.md#intel-extension-for-openxla)
- [Accelerate JAX models on Intel GPUs via PJRT](https://opensource.googleblog.com/2023/06/accelerate-jax-models-on-intel-gpus-via-pjrt.html)
- [Accelerate Stable Diffusion on Intel GPUs with Intel® Extension for OpenXLA*](https://www.intel.com/content/www/us/en/developer/articles/technical/accelerate-stable-diffusion-on-intel-gpus-openxla.html)
- [Float8 training and inference with a simple Transformer model](https://keras.io/examples/keras_recipes/float8_training_and_inference_with_transformer/)
- [How are jax and jaxlib versioned?](https://jax.readthedocs.io/en/latest/jep/9419-jax-versioning.html#how-are-jax-and-jaxlib-versioned)

0.3.0

Major Features

Intel® Extension for OpenXLA* is an Intel-optimized PyPI package that extends the official [OpenXLA](https://github.com/openxla/xla) to Intel GPUs. It is based on the [PJRT](https://opensource.googleblog.com/2023/05/pjrt-simplifying-ml-hardware-and-framework-integration.html) plugin mechanism, which lets it seamlessly run [JAX](https://jax.readthedocs.io/en/latest/index.html) models on [Intel® Data Center GPU Max Series](https://www.intel.com/content/www/us/en/products/details/discrete-gpus/data-center-gpu/max-series.html) and [Intel® Data Center GPU Flex Series](https://www.intel.com/content/www/us/en/products/details/discrete-gpus/data-center-gpu/flex-series.html). This release contains the following major features:

- **JAX Upgrade:** Upgrade to JAX **v0.4.24**.
- **Feature Support:**
  - Support the custom call registration mechanism through the new OpenXLA C API. This feature makes it possible to interact with third-party software, such as [mpi4jax](https://github.com/mpi4jax/mpi4jax).
  - Continue to improve JAX native distributed scale-up collectives, which now support any number of devices **less than 16** in a single node.
  - Experimental support for Intel® Data Center GPU Flex Series.
- **Bug Fix:**
  - Fix accuracy issues in the GEMM kernel when it is optimized by Intel® Xe Templates for Linear Algebra (XeTLA).
  - Fix a crash when the input batch size is greater than **65535**.
- **Toolkit Support:** Support [Intel® oneAPI Base Toolkit 2024.1](https://www.intel.com/content/www/us/en/developer/articles/release-notes/intel-oneapi-toolkit-release-notes.html).


Known Caveats

- The Extension crashes when binary operations (e.g., `Mul`, `MatMul`) and the SPMD multi-device parallelism API [`psum_scatter`](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html#psum-scatter) are used under the same `partial` annotation. Please refer to the JAX unit test [test_matmul_reduce_scatter](https://github.com/google/jax/blob/jaxlib-v0.4.24/tests/shard_map_test.py#L153-L159) to better understand the error scenario.
- JAX collectives deadlock and hang the Extension when working with **Toolkit 2024.1**. We recommend using **Toolkit 2024.0** if you need collectives.
- The `clear_backends` API doesn't work and may cause an OOM exception like the following when working with **Toolkit 2024.0**:

terminate called after throwing an instance of 'sycl::_V1::runtime_error'
what(): Native API failed. Native API returns: -5 (PI_ERROR_OUT_OF_RESOURCES) -5 (PI_ERROR_OUT_OF_RESOURCES)
Fatal Python error: Aborted


**Note**: The `clear_backends` API will soon be deprecated by JAX.


Breaking changes

- JAX **v0.4.20**, supported by the previous release, is no longer supported. Please follow the [JAX changelog](https://jax.readthedocs.io/en/latest/changelog.html) to update your application if you encounter version errors.


Documents

- [Introduction to Intel® Extension for OpenXLA*](https://github.com/intel/intel-extension-for-openxla/blob/r0.2.0/README.md#intel-extension-for-openxla)
- [Accelerate JAX models on Intel GPUs via PJRT](https://opensource.googleblog.com/2023/06/accelerate-jax-models-on-intel-gpus-via-pjrt.html)
- [Accelerate Stable Diffusion on Intel GPUs with Intel® Extension for OpenXLA*](https://www.intel.com/content/www/us/en/developer/articles/technical/accelerate-stable-diffusion-on-intel-gpus-openxla.html)

0.2.1

Bug Fixes and Other Changes
* Fix the **known caveat** related to `XLA_ENABLE_MULTIPLE_STREAM=1`. The accuracy issue is fixed, and this environment variable no longer needs to be set.
* Fix the **known caveat** related to `MHA=0`. The crash is fixed, and this environment variable no longer needs to be set.
* Fix a compatibility issue with the upgraded driver [LTS release 2350.29](https://dgpu-docs.intel.com/releases/LTS_803.29_20240131.html).
* Fix a random accuracy issue caused by the `AllToAll` collective.
* Upgrade [transformers](https://github.com/huggingface/transformers) used by the [examples](https://github.com/intel-innersource/frameworks.ai.intel-extension-for-openxla.intel-extension-for-openxla/tree/r0.2.1/example) to 4.36 to fix an open CVE.

Known Caveats
* The number of devices is restricted to **2/4/6/8/10/12** for the experimentally supported collectives in a single node.
* Do not use collectives (e.g., `AllReduce`) in nested `pjit`; doing so may cause random accuracy issues. Please refer to the JAX unit test [`testAutodiff`](https://github.com/google/jax/blob/jaxlib-v0.4.20/tests/pjit_test.py#L646-L661) to better understand the error scenario.
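
The single-node device-count limits for collectives can be encoded as a pre-flight check before launching a multi-device job. This is a hypothetical helper, not an Extension API: it covers only the limits stated in these notes (the 2/4/6/8/10/12 restriction for 0.2.0 and 0.2.1, and the "less than 16" rule announced in 0.3.0, read here as 1 to 15 devices). In a real script the count would typically come from `jax.local_device_count()`.

```python
# Hypothetical pre-flight check for single-node collectives, based only on
# the device-count limits stated in these release notes.
ALLOWED_DEVICES_0_2 = frozenset({2, 4, 6, 8, 10, 12})


def collectives_device_count_ok(extension, num_devices):
    """Return True if num_devices is within the documented single-node limit."""
    if extension in ("0.2.0", "0.2.1"):
        # Experimental collectives: only these device counts are allowed.
        return num_devices in ALLOWED_DEVICES_0_2
    if extension == "0.3.0":
        # "Any number of devices less than 16", read as 1..15 inclusive.
        return 1 <= num_devices < 16
    # Other releases in these notes do not state a device-count limit.
    raise ValueError(f"no documented device-count limit for release {extension}")


if __name__ == "__main__":
    for n in (3, 8, 14):
        print(n, collectives_device_count_ok("0.2.1", n), collectives_device_count_ok("0.3.0", n))
```

Raising for unlisted releases, rather than returning `True`, keeps the check from silently approving a configuration the notes say nothing about.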

---
**Full Changelog**: https://github.com/intel/intel-extension-for-openxla/compare/0.2.0...0.2.1

0.2.0

Major Features

Intel® Extension for OpenXLA* is an Intel-optimized PyPI package that extends the official [OpenXLA](https://github.com/openxla/xla) to Intel GPUs. It is based on the [PJRT](https://opensource.googleblog.com/2023/05/pjrt-simplifying-ml-hardware-and-framework-integration.html) plugin mechanism, which lets it seamlessly run [JAX](https://jax.readthedocs.io/en/latest/index.html) models on Intel® Data Center GPU Max Series. This release contains the following major features:

- Upgrade JAX version to **v0.4.20**.

- Experimental support for JAX native distributed scale-up collectives based on [JAX pmap](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html).

- Continuously optimize common kernels, and optimize GEMM kernels with [Intel® Xe Templates for Linear Algebra](https://github.com/intel/xetla). Three inference models (Stable Diffusion, GPT-J, FLAN-T5) are verified on a single Intel® Data Center GPU Max Series device and added to [examples](https://github.com/intel/intel-extension-for-openxla/tree/r0.2.0/example).


Known Caveats

- The number of devices is restricted to **2/4/6/8/10/12** for the experimentally supported collectives in a single node.

- `XLA_ENABLE_MULTIPLE_STREAM=1` should be set when using [JAX parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#) on multiple devices without collectives. It adds synchronization between devices to avoid possible accuracy issues.

- `MHA=0` should be set to disable MHA fusion in training. MHA fusion is not yet supported in training and causes a runtime error like the following:

`ir_emission_utils.cc:109] Check failed: lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)) == rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0))`
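
Both 0.2.0 workarounds are plain environment variables. A minimal sketch of applying them from Python follows; `apply_0_2_0_workarounds` is a hypothetical helper, and calling it before `import jax` is an assumption made so that the variables are already present whenever the plugin reads them. Release 0.2.1 fixed both caveats, so neither variable is needed there.

```python
import os


def apply_0_2_0_workarounds(env=None, training=False, parallel_without_collectives=False):
    """Set the environment variables recommended for release 0.2.0.

    Hypothetical helper. `env` defaults to os.environ; pass a dict to build
    an environment for a subprocess instead. Call it before importing jax.
    """
    env = os.environ if env is None else env
    if parallel_without_collectives:
        # Adds synchronization between devices to avoid accuracy issues.
        env["XLA_ENABLE_MULTIPLE_STREAM"] = "1"
    if training:
        # Disables MHA fusion, which 0.2.0 does not support in training.
        env["MHA"] = "0"
    return env


if __name__ == "__main__":
    apply_0_2_0_workarounds(training=True, parallel_without_collectives=True)
    print(os.environ["XLA_ENABLE_MULTIPLE_STREAM"], os.environ["MHA"])
```

Exporting the same variables in the shell before launching Python has the same effect.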



Breaking changes

- JAX **v0.4.13**, supported by the previous release, is no longer supported. Please follow the [JAX changelog](https://jax.readthedocs.io/en/latest/changelog.html) to update your application if you encounter version errors.

- GCC **10.0.0** or newer is required to build from source. Please refer to the [installation guide](https://github.com/intel/intel-extension-for-openxla/blob/r0.2.0/README.md#3-install) for more details.
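
A quick way to confirm the GCC 10.0.0 requirement before a source build is sketched below. `gcc_meets_minimum` and `parse_gcc_version` are hypothetical helpers; they rely on `gcc -dumpfullversion` (available in GCC 7 and later) and fall back to `gcc -dumpversion`, which may print only the major version.

```python
import re
import shutil
import subprocess

MIN_GCC = (10, 0, 0)  # minimum stated in the 0.2.0 breaking changes


def parse_gcc_version(text):
    """Turn '11.4.0' or '11' into a 3-tuple of ints, padding missing parts with 0."""
    parts = [int(p) for p in re.findall(r"\d+", text)[:3]]
    if not parts:
        raise ValueError(f"no version number in {text!r}")
    return tuple(parts + [0] * (3 - len(parts)))


def gcc_meets_minimum(compiler="gcc"):
    """Return True if `compiler` is on PATH and reports a version >= MIN_GCC."""
    path = shutil.which(compiler)
    if path is None:
        return False
    for flag in ("-dumpfullversion", "-dumpversion"):
        result = subprocess.run([path, flag], capture_output=True, text=True, check=False)
        if result.returncode == 0 and result.stdout.strip():
            return parse_gcc_version(result.stdout) >= MIN_GCC
    return False


if __name__ == "__main__":
    print(gcc_meets_minimum())
```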


Documents

- [Introduction to Intel® Extension for OpenXLA*](https://github.com/intel/intel-extension-for-openxla/blob/r0.2.0/README.md#intel-extension-for-openxla)
- [Accelerate JAX models on Intel GPUs via PJRT](https://opensource.googleblog.com/2023/06/accelerate-jax-models-on-intel-gpus-via-pjrt.html)
- [Accelerate Stable Diffusion on Intel GPUs with Intel® Extension for OpenXLA*](https://www.intel.com/content/www/us/en/developer/articles/technical/accelerate-stable-diffusion-on-intel-gpus-openxla.html)

0.1.0

Major Features



Intel® Extension for OpenXLA* is an Intel-optimized Python package that extends the official [OpenXLA](https://github.com/openxla/xla) to Intel GPUs. It is based on the [PJRT](https://opensource.googleblog.com/2023/05/pjrt-simplifying-ml-hardware-and-framework-integration.html) plugin mechanism, which lets it seamlessly run [JAX](https://jax.readthedocs.io/en/latest/index.html) models on Intel® Data Center GPU Max Series. The PJRT API simplified the integration by allowing the Intel XPU plugin to be developed separately and quickly integrated into JAX. This release contains the following major features:



- **Kernel enabling and optimization**



Common kernels are enabled with the LLVM/SPIR-V software stack. Convolution and GEMM are enabled with oneDNN. [Stable Diffusion](https://www.intel.com/content/www/us/en/developer/articles/technical/accelerate-stable-diffusion-on-intel-gpus-openxla.html) is verified.
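
The 0.1.0 overview notes that PJRT lets the Intel XPU plugin be developed separately and then integrated into JAX. From the user's side, that separation shows up as plugin discovery: JAX is understood to look for PJRT plugins registered under a `jax_plugins` entry-point group or installed as modules in a `jax_plugins` namespace package. The sketch below lists both kinds of candidates using only the standard library; which mechanism a given Extension release uses is not stated in these notes, and `list_jax_plugin_candidates` is a hypothetical helper.

```python
import importlib
import importlib.metadata
import pkgutil
import sys


def list_jax_plugin_candidates():
    """List possible PJRT plugins visible to JAX (hypothetical helper).

    Assumes JAX's discovery convention: an entry-point group named
    "jax_plugins" and/or modules inside a `jax_plugins` namespace package.
    """
    if sys.version_info >= (3, 10):
        entry_points = importlib.metadata.entry_points(group="jax_plugins")
    else:
        # On Python 3.9, entry_points() returns a dict keyed by group name.
        entry_points = importlib.metadata.entry_points().get("jax_plugins", ())
    entry_point_names = sorted(ep.name for ep in entry_points)

    try:
        namespace = importlib.import_module("jax_plugins")
    except ImportError:
        namespace = None  # no plugin installed as a namespace module
    search_path = getattr(namespace, "__path__", [])
    module_names = sorted(info.name for info in pkgutil.iter_modules(search_path))

    return {"entry_points": entry_point_names, "namespace_modules": module_names}


if __name__ == "__main__":
    print(list_jax_plugin_candidates())
```

Listing candidates without importing them avoids initializing any device backend, so the check is safe to run on a machine without an Intel GPU.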



Known Issues



* Limited support for collective ops due to limitations in oneCCL.


Related Blog
* [Accelerate JAX models on Intel GPUs via PJRT](https://opensource.googleblog.com/2023/06/accelerate-jax-models-on-intel-gpus-via-pjrt.html)
