Added
- fMHA: `PagedBlockDiagonalGappyKeysMask`
- fMHA: heterogeneous queries in `triton_splitk`
- fMHA: support for paged attention in flash
- fMHA: Added backwards pass for `merge_attentions`
- fMHA: Added `torch.compile` support for 3 biases (`LowerTriangularMask`, `LowerTriangularMaskWithTensorBias` and `BlockDiagonalMask`) - some might require PyTorch 2.4
- fMHA: Added `torch.compile` support in `memory_efficient_attention` when passing the flash operator explicitly (e.g. `memory_efficient_attention(..., op=(flash.FwOp, flash.BwOp))`)
- fMHA: `memory_efficient_attention` now expects its `attn_bias` argument to be on the same device as the other input tensors. Previously, it would convert the bias to the right device.
- fMHA: `AttentionBias` subclasses are now constructed by default on the `cuda` device if available - they used to be created on the CPU device
- 2:4 sparsity: Added `xformers.ops.sp24.sparsify24_ste` for Straight Through Estimator (STE) with options to rescale the gradient differently for masked out/kept values
Improved
- fMHA: Fixed out-of-bounds read in the Split-K Triton implementation
- Profiler: Fixed bug with modules that take a single tuple as argument
- Profiler: Added manual trigger for a profiling step, by creating a `trigger` file in the profiling directory
Removed
- Removed support for PyTorch versions older than 2.2