Cloud TPUs now support the PyTorch 2.5 release via PyTorch/XLA integration. On top of the underlying improvements and bug fixes in PyTorch 2.5, this release introduces several new features and PyTorch/XLA-specific bug fixes.
Highlights
We are excited to announce the release of PyTorch/XLA 2.5! PyTorch/XLA 2.5 adds the `torch_xla.compile` function, which improves the debugging experience during development, and aligns the distributed APIs with upstream PyTorch through traceable collective support for both Dynamo and non-Dynamo cases. Starting with PyTorch/XLA 2.5, we have proposed a [clarified vision](https://github.com/pytorch/xla/issues/8000) for deprecating the older torch_xla API in favor of the existing PyTorch API, providing a simplified developer experience.
If you’ve used [vLLM](https://docs.vllm.ai/en/latest/index.html) for serving models on GPUs, you’ll now be able to seamlessly switch to its TPU backend. vLLM is a widely adopted inference framework that also serves as an excellent way to drive accelerator interoperability. With vLLM on TPU, you’ll retain the same vLLM interface you’ve grown to love, with direct integration with [Hugging Face Models](https://huggingface.co/models) to make model experimentation easy.
STABLE FEATURES
Eager
- Increase max in-flight operations to accommodate eager mode [[7263](https://github.com/pytorch/xla/pull/7263)]
- Unify the logic to check eager mode [[7709](https://github.com/pytorch/xla/pull/7709)]
- Update `eager.md` [[7710](https://github.com/pytorch/xla/pull/7710)]
- Optimize execution for ops that have multiple outputs in eager mode [[7680](https://github.com/pytorch/xla/pull/7680)]
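A minimal sketch of enabling eager mode, which the items above refine; this assumes the `torch_xla.experimental.eager_mode` toggle described in `eager.md`:

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

# Enable eager mode so each op dispatches immediately instead of being
# accumulated into a lazily executed graph.
torch_xla.experimental.eager_mode(True)

device = xm.xla_device()
a = torch.randn(4, 4, device=device)
b = torch.randn(4, 4, device=device)
c = a @ b          # executed eagerly on the XLA device
print(c.cpu())
```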
Quantization / Low Precision
- Asymmetric quantized `matmul` support [[7626](https://github.com/pytorch/xla/pull/7626)]
- Add blockwise quantized dot support [[7605](https://github.com/pytorch/xla/pull/7605)]
- Support `int4` weight in quantized matmul / linear [[7235](https://github.com/pytorch/xla/pull/7235)]
- Support the `fp8e5m2` dtype [[7740](https://github.com/pytorch/xla/pull/7740)]
- Add `fp8e4m3fn` support [[7842](https://github.com/pytorch/xla/pull/7842)]
- Support dynamic activation quant for per-channel quantized matmul [[7867](https://github.com/pytorch/xla/pull/7867)]
- Enable cross entropy loss for xla autocast with FP32 precision [[8094]](https://github.com/pytorch/xla/pull/8094)
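As a rough illustration of the new `fp8` dtypes ([7740], [7842]), the sketch below casts a weight tensor to fp8 on the XLA device and upcasts it for the matmul; it treats fp8 purely as a storage dtype and is not tied to any particular quantization recipe:

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
w = torch.randn(16, 16)

# Cast weights to the newly supported fp8 dtypes on the XLA device.
w_e5m2 = w.to(torch.float8_e5m2).to(device)
w_e4m3 = w.to(torch.float8_e4m3fn).to(device)

# Upcast before the matmul; fp8 is used here only for storage.
x = torch.randn(4, 16, device=device, dtype=torch.bfloat16)
y = x @ w_e5m2.to(torch.bfloat16).t()
xm.mark_step()
```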
Pallas Kernels
- Support the `ab` (attention bias) input for `flash_attention` [[7840](https://github.com/pytorch/xla/pull/7840)]; the actual kernel is implemented in [JAX](https://github.com/jax-ml/jax/blob/3e634d95304afae56e01de0145d9cb068351df3c/jax/experimental/pallas/ops/tpu/flash_attention.py#L144)
- Support the `logits_soft_cap` parameter in `paged_attention` [[7704](https://github.com/pytorch/xla/pull/7704)]; the actual kernel is implemented in [JAX](https://github.com/jax-ml/jax/blob/3e634d95304afae56e01de0145d9cb068351df3c/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py#L136)
- Support `trace_pallas` caching for `gmm` and `tgmm` [[7921](https://github.com/pytorch/xla/pull/7921)]
- Cache flash attention tracing [[8026](https://github.com/pytorch/xla/pull/8026)]
- Improve the user guide [[7625](https://github.com/pytorch/xla/pull/7625)]
- Update pallas doc with `paged_attention` [[7591](https://github.com/pytorch/xla/pull/7591)]
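A minimal sketch of calling the Pallas `flash_attention` wrapper with an attention-bias (`ab`) input on a TPU; argument names and shape constraints should be verified against `torch_xla.experimental.custom_kernel`, and the shapes below are only illustrative:

```python
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import flash_attention

device = xm.xla_device()
batch, heads, seq, head_dim = 1, 2, 128, 128

q = torch.randn(batch, heads, seq, head_dim, device=device)
k = torch.randn(batch, heads, seq, head_dim, device=device)
v = torch.randn(batch, heads, seq, head_dim, device=device)
ab = torch.zeros(batch, heads, seq, seq, device=device)  # additive attention bias

out = flash_attention(q, k, v, ab=ab)  # the kernel itself lives in JAX/Pallas
xm.mark_step()
```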
StableHLO
- Add user guide for stablehlo composite op [[7826](https://github.com/pytorch/xla/pull/7826)]
gSPMD
- Handle the parameter wrapping for SPMD [[7604](https://github.com/pytorch/xla/pull/7604)]
- Add helper function to get 1d mesh [[7577](https://github.com/pytorch/xla/pull/7577)]
- Support manual `all-reduce` [[7576](https://github.com/pytorch/xla/pull/7576)]
- Expose `apply_backward_optimization_barrier` [[7477](https://github.com/pytorch/xla/pull/7477)]
- Support reduce-scatter in manual sharding [[7231](https://github.com/pytorch/xla/pull/7231)]
- Allow `MpDeviceLoader` to shard dictionaries of tensors [[8202](https://github.com/pytorch/xla/pull/8202)]
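To illustrate the SPMD items above, here is a minimal sketch of sharding host input through `MpDeviceLoader`; the dictionary form of `input_sharding` is the behavior described in [8202], and the toy loader is only an assumption for the example:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
import torch_xla.distributed.parallel_loader as pl

xr.use_spmd()

# One-dimensional device mesh; the 1d-mesh helper from [7577] can likely
# replace these two lines.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(list(range(num_devices)), (num_devices,), ('data',))

# Shard the batch dimension of each input across the 'data' mesh axis.
sharding = xs.ShardingSpec(mesh, ('data', None))

# A toy loader that yields dictionaries of tensors.
dataset = TensorDataset(torch.randn(64, 128), torch.randint(0, 2, (64, 128)))
loader = DataLoader(
    dataset, batch_size=16,
    collate_fn=lambda b: {'input_ids': torch.stack([x for x, _ in b]),
                          'labels': torch.stack([y for _, y in b])})

# Per [8202], input_sharding can be a dictionary keyed by the batch entries.
device_loader = pl.MpDeviceLoader(
    loader, xm.xla_device(),
    input_sharding={'input_ids': sharding, 'labels': sharding})

for batch in device_loader:
    print(batch['input_ids'].shape)
    break
```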
Dynamo
- Optimize dynamo dynamic shape caching [[7726](https://github.com/pytorch/xla/pull/7726)]
- Add support for dynamic shape in dynamo [[7676](https://github.com/pytorch/xla/pull/7676)]
- Avoid unnecessary `set_attr` in dynamo `optim_mode` [[7915](https://github.com/pytorch/xla/pull/7915)]
- Fix the crash with copy op in dynamo [[7902](https://github.com/pytorch/xla/pull/7902)]
- Optimize `_split_xla_args_tensor_sym_constant` [[7900](https://github.com/pytorch/xla/pull/7900)]
- Optimize the Dynamo RNG seed update [[7884](https://github.com/pytorch/xla/pull/7884)]
- Support `mark_dynamic` [[7812](https://github.com/pytorch/xla/pull/7812)]
- Support `gmm` as a custom op for dynamo [[7672](https://github.com/pytorch/xla/pull/7672)]
- Fix dynamo inplace copy [[7933](https://github.com/pytorch/xla/pull/7933)]
- CPU time optimization for `GraphInputMatcher` [[7895](https://github.com/pytorch/xla/pull/7895)]
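A minimal sketch of the Dynamo path these changes target, combining `torch.compile` with the `openxla` backend and the upstream `mark_dynamic` hint from [7812]; shapes are illustrative:

```python
import torch
import torch._dynamo
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(128, 64).to(device)

# Compile through Dynamo with the XLA backend.
compiled = torch.compile(model, backend='openxla')

x = torch.randn(8, 128, device=device)
torch._dynamo.mark_dynamic(x, 0)   # mark the batch dimension as dynamic
out = compiled(x)
```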
PJRT
- Improve device auto-detection [[7787](https://github.com/pytorch/xla/pull/7787)]
- Move `_xla_register_custom_call_target` implementation into `PjRtComputationClient` [[7801](https://github.com/pytorch/xla/pull/7801)]
- Handle the SPMD case inside `ComputationClient::WaitDeviceOps` [[7796](https://github.com/pytorch/xla/pull/7796)]
GKE
- Add TPU example for torchrun on GKE [[7620](https://github.com/pytorch/xla/pull/7620)]
- Add an example of using GKE with torchrun [[7589](https://github.com/pytorch/xla/pull/7589)]
Functionalization
- Add 1-layer gradient accumulation test to check aliasing [[7692](https://github.com/pytorch/xla/pull/7692)]
AMP
- Fix norm data-type when using AMP [[7878](https://github.com/pytorch/xla/pull/7878)]
BETA FEATURES
Op Lowering
- Lower `aten::_linalg_eigh` [[7674](https://github.com/pytorch/xla/pull/7674)]
- Fall back `_embedding_bag_backward` and force `sparse=False` [[7584](https://github.com/pytorch/xla/pull/7584)]
- Support trilinear by using the upstream decomposition [[7586](https://github.com/pytorch/xla/pull/7586)]
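As an example of the newly lowered op in [7674], `torch.linalg.eigh` can now run on the XLA device without falling back to CPU; a minimal sketch:

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
a = torch.randn(4, 4, device=device)
sym = a @ a.mT + 4 * torch.eye(4, device=device)    # symmetric positive definite

eigenvalues, eigenvectors = torch.linalg.eigh(sym)  # lowered via aten::_linalg_eigh
xm.mark_step()
```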
Higher order ops
- [Fori_loop] Update randint max range to support `bool` dtype [[7632](https://github.com/pytorch/xla/pull/7632)]
TorchBench Integration
- [benchmarks] API alignment with PyTorch profiler events [[7930](https://github.com/pytorch/xla/pull/7930)]
- [benchmarks] Add IR dump option when running torchbench [[7927](https://github.com/pytorch/xla/pull/7927)]
- [benchmarks] Use the same `matmul` precision between PyTorch and PyTorch/XLA [[7748](https://github.com/pytorch/xla/pull/7748)]
- [benchmarks] Introduce a verifier to check model output correctness against native PyTorch [[7724](https://github.com/pytorch/xla/pull/7724), [#7777](https://github.com/pytorch/xla/pull/7777)]
- [benchmarks] Fix moco model issue on XLA [[7257](https://github.com/pytorch/xla/pull/7257), [#7598](https://github.com/pytorch/xla/pull/7598)]
- Type annotation for `benchmarks/` [[7289](https://github.com/pytorch/xla/pull/7289)]
- Default to `CUDAGraphs` on for Inductor [[7749](https://github.com/pytorch/xla/pull/7749)]
GPU
- Deprecate `XRT` for `XLA:CUDA` [[8006](https://github.com/pytorch/xla/pull/8006)]
EXPERIMENTAL FEATURES
[Backward Compatibility](https://github.com/pytorch/xla/issues/8000) & APIs that will be removed in the 2.7 release:
- Deprecate APIs (deprecated → new):
| Deprecated | New | PRs |
| -------- | ------- | ------- |
| `xla_model.xrt_world_size()` | `runtime.world_size()` | [[7679](https://github.com/pytorch/xla/pull/7679)][[#7743](https://github.com/pytorch/xla/pull/7743)] |
| `xla_model.get_ordinal()` | `runtime.global_ordinal()` | [[7679](https://github.com/pytorch/xla/pull/7679)] |
| `xla_model.get_local_ordinal()` | `runtime.local_ordinal()` | [[7679](https://github.com/pytorch/xla/pull/7679)] |
- Internalize APIs
- `xla_model.parse_xla_device()` [[7675](https://github.com/pytorch/xla/pull/7675)]
- Improvements
- Automatic PJRT device detection when importing `torch_xla` [[7787](https://github.com/pytorch/xla/pull/7787)]
- Add deprecated decorator [[7703](https://github.com/pytorch/xla/pull/7703)]
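A short migration sketch for the deprecated runtime queries above, with the old calls kept as comments:

```python
import torch_xla.runtime as xr

# Before (deprecated, removal planned for the 2.7 release):
#   import torch_xla.core.xla_model as xm
#   world_size = xm.xrt_world_size()
#   rank       = xm.get_ordinal()
#   local_rank = xm.get_local_ordinal()

# After:
world_size = xr.world_size()
rank = xr.global_ordinal()
local_rank = xr.local_ordinal()
```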
Distributed
- Enable bucketized all-reduce for gradients [[7216](https://github.com/pytorch/xla/pull/7216)]
- Use reduce-scatter coalescing for FSDP [[6024](https://github.com/pytorch/xla/pull/6024)]
Distributed API
We have aligned our distributed APIs with upstream PyTorch. Previously, we implemented custom distributed APIs, such as `torch_xla.core.xla_model.all_reduce`. With traceable collective support, we now enable `torch.distributed.all_reduce` and similar functions for both Dynamo and non-Dynamo cases in `torch_xla`; a minimal sketch follows the list below.
- Support upstream distributed APIs (`torch.distributed.*`) such as `all_reduce`, `all_gather`, `reduce_scatter_tensor`, and `all_to_all`; previously we used XLA-specific distributed APIs in `xla_model` [[7860](https://github.com/pytorch/xla/pull/7860), [#7950](https://github.com/pytorch/xla/pull/7950/), [#8064](https://github.com/pytorch/xla/pull/8064)].
- Introduce `torch_xla.launch()` to launch multiple processes, unifying torchrun and `torch_xla.distributed.xla_multiprocessing.spawn()` [[7764](https://github.com/pytorch/xla/pull/7764), [#7648](https://github.com/pytorch/xla/pull/7648), [#7695](https://github.com/pytorch/xla/pull/7695)].
- `torch.distributed.reduce_scatter_tensor()`: [[7950]](https://github.com/pytorch/xla/pull/7950/)
- Register SDP lower-precision autocast [[7299](https://github.com/pytorch/xla/pull/7299)]
- Add Python binding for `xla::DotGeneral` [[7863](https://github.com/pytorch/xla/pull/7863)]
- Fix input output alias for custom inplace ops [[7822](https://github.com/pytorch/xla/pull/7822)]
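As referenced above, here is a minimal sketch of using the upstream collectives with the XLA process-group backend and `torch_xla.launch()`; the shapes and printed output are illustrative only:

```python
import torch
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend  # registers the 'xla' process group backend

def _mp_fn(index):
    dist.init_process_group('xla', init_method='xla://')
    device = xm.xla_device()

    t = torch.full((2, 2), float(dist.get_rank()), device=device)
    dist.all_reduce(t)      # upstream collective; traceable for Dynamo and non-Dynamo
    xm.mark_step()
    print(index, t.cpu())

if __name__ == '__main__':
    # Unifies the torchrun and xla_multiprocessing.spawn() entry points.
    torch_xla.launch(_mp_fn)
```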
`torch_xla.compile`
- Support `full_graph`, which errors out if more than one graph would be executed in the compiled region. [[7776](https://github.com/pytorch/xla/pull/7776)][[#7789](https://github.com/pytorch/xla/pull/7789)]
- Support dynamic shape detection, which prints a useful error message when the number of different graphs executed across runs exceeds a predefined limit. [[7918](https://github.com/pytorch/xla/pull/7918)]
- Support naming each compiled program, which makes debug messages more informative; a minimal usage sketch follows this list. [[7802](https://github.com/pytorch/xla/pull/7802)]
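A minimal sketch of `torch_xla.compile` using the `full_graph` and `name` options named above; other options (such as the dynamic-shape detection limit) should be checked against the API docs:

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(16, 4).to(device)

def step(x):
    return model(x).sum()

# full_graph=True errors out if the region would execute more than one graph;
# name= labels the compiled program in debug output.
compiled_step = torch_xla.compile(step, full_graph=True, name='step')

loss = compiled_step(torch.randn(8, 16, device=device))
```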
Usability & Debuggability
- Wheel name change to support pip>=24.1: [[issue7697](https://github.com/pytorch/xla/issues/7697)]
- Add `tpu-info` as a dependency of `torch_xla[tpu]` and test: [[7938](https://github.com/pytorch/xla/pull/7938)][[#7337](https://github.com/pytorch/xla/pull/7337)]
- Support `torch_xla.manual_seed`: [[7340](https://github.com/pytorch/xla/pull/7340)]
- Support callbacks on tensors when async execution finishes [[7984](https://github.com/pytorch/xla/pull/7984)]
- Implement `torch.ops._c10d_functional.broadcast`: [[7770](https://github.com/pytorch/xla/pull/7770)]
- The `XLA_USE_BF16` and `XLA_DOWNCAST_BF16` flags will be removed in the 2.6 release [[7582](https://github.com/pytorch/xla/pull/7582)][[#7945](https://github.com/pytorch/xla/pull/7945)]
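A minimal sketch of the new seeding helper from [7340]; `torch_xla.manual_seed` seeds the RNG used for random ops on the XLA device:

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

torch_xla.manual_seed(42)                      # seed the XLA device RNG
x = torch.randn(2, 2, device=xm.xla_device())
print(x.cpu())
```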
AWS Neuron
- Update Neuron initializations [[7952](https://github.com/pytorch/xla/pull/7952)]
- Pass `local_world_size` into `neuron.initialize_env` [[7852](https://github.com/pytorch/xla/pull/7852)]
- Update and short circuit the Neuron initialization [[8041](https://github.com/pytorch/xla/pull/8041)]
- Introduce multi-node SPMD support for Neuron [[8224](https://github.com/pytorch/xla/pull/8224)]