Major Features
Intel® Extension for OpenXLA* is an Intel-optimized PyPI package that extends the official [OpenXLA](https://github.com/openxla/xla) to Intel GPUs. It is based on the [PJRT](https://opensource.googleblog.com/2023/05/pjrt-simplifying-ml-hardware-and-framework-integration.html) plugin mechanism, which allows [JAX](https://jax.readthedocs.io/en/latest/index.html) models to run seamlessly on the Intel® Data Center GPU Max Series. This release contains the following major features:
- Upgraded JAX to **v0.4.20**.
- Experimental support for JAX native distributed scale-up collectives based on [JAX pmap](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html).
- Continued optimization of common kernels, with GEMM kernels optimized via [Intel® Xe Templates for Linear Algebra](https://github.com/intel/xetla). Three inference models (Stable Diffusion, GPT-J, FLAN-T5) are verified on a single Intel® Data Center GPU Max Series device and added to the [examples](https://github.com/intel/intel-extension-for-openxla/tree/r0.2.0/example).
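The experimental scale-up collectives are driven through `jax.pmap`. The following is a minimal sketch using only standard JAX API (not specific to this extension) that all-reduces a per-device value across the local devices:

```python
import jax
import jax.numpy as jnp

# All-reduce a per-device scalar across every local device.
def device_sum(x):
    return jax.lax.psum(x, axis_name="devices")

n = jax.local_device_count()       # e.g. 2/4/6/8/10/12 on a Max Series node
per_device = jnp.arange(float(n))  # one scalar shard per device
result = jax.pmap(device_sum, axis_name="devices")(per_device)
# Every device now holds the same total: sum(0 .. n-1)
```

On a node with the supported device counts, each entry of `result` holds the sum of the shards; the same pattern runs unchanged on a single device for debugging.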
Known Caveats
- The device count is restricted to **2/4/6/8/10/12** for the experimentally supported collectives on a single node.
- `XLA_ENABLE_MULTIPLE_STREAM=1` should be set when using [JAX parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#) on multiple devices without collectives. It adds synchronization between devices to avoid possible accuracy issues.
- `MHA=0` should be set to disable MHA fusion in training. MHA fusion is not yet supported in training and causes a runtime error like the following:

```
ir_emission_utils.cc:109] Check failed: lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)) == rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0))
```
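The two environment-variable workarounds above can also be set from Python. A sketch (the variable names and values are from the caveats above; setting them before `jax` is imported is assumed to be required so the plugin picks them up at initialization):

```python
import os

# Set the workarounds before importing jax.
os.environ["XLA_ENABLE_MULTIPLE_STREAM"] = "1"  # sync multiple devices when running without collectives
os.environ["MHA"] = "0"                         # disable MHA fusion, which is unsupported in training
```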
Breaking changes
- The previous JAX **v0.4.13** is no longer supported. Please follow the [JAX change log](https://jax.readthedocs.io/en/latest/changelog.html) to update your application if you encounter version errors.
- GCC **10.0.0** or newer is required when building from source. Please refer to the [installation guide](https://github.com/intel/intel-extension-for-openxla/blob/r0.2.0/README.md#3-install) for more details.
Documents
- [Introduction to Intel® Extension for OpenXLA*](https://github.com/intel/intel-extension-for-openxla/blob/r0.2.0/README.md#intel-extension-for-openxla)
- [Accelerate JAX models on Intel GPUs via PJRT](https://opensource.googleblog.com/2023/06/accelerate-jax-models-on-intel-gpus-via-pjrt.html)
- [Accelerate Stable Diffusion on Intel GPUs with Intel® Extension for OpenXLA*](https://www.intel.com/content/www/us/en/developer/articles/technical/accelerate-stable-diffusion-on-intel-gpus-openxla.html)