The first release of `jax-flash-attn2` provides a flexible and efficient implementation of Flash Attention 2.0 for JAX, with support for multiple backends and platforms.
## 🚀 Features
- Multiple backend support (GPU/TPU/CPU)
- Multiple platform implementations (Triton/Pallas/JAX)
- Efficient caching of attention instances
- Support for Grouped Query Attention (GQA) (see the sketch after this list)
- Head dimensions up to 256
- JAX sharding-friendly implementation
- Automatic platform selection based on the backend
- Compatible with existing JAX mesh patterns
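For example, GQA inputs can be passed directly by giving the key/value tensors fewer heads than the query tensor. A minimal sketch, assuming a `[batch, seq_len, num_heads, head_dim]` layout and an NVIDIA GPU with the Triton platform available; shapes, sizes, and block sizes below are illustrative only:

```python
import jax
import jax.numpy as jnp

from jax_flash_attn2 import get_cached_flash_attention

# Illustrative GQA setup: 32 query heads share 8 key/value heads.
batch, seq_len, head_dim = 2, 1024, 128
num_q_heads, num_kv_heads = 32, 8

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(k1, (batch, seq_len, num_q_heads, head_dim), dtype=jnp.float16)
key = jax.random.normal(k2, (batch, seq_len, num_kv_heads, head_dim), dtype=jnp.float16)
value = jax.random.normal(k3, (batch, seq_len, num_kv_heads, head_dim), dtype=jnp.float16)

attention = get_cached_flash_attention(
    backend="gpu",
    platform="triton",
    blocksize_q=64,
    blocksize_k=128,
    softmax_scale=head_dim ** -0.5,
)

outputs = attention(query=query, key=key, value=value)  # expected: [batch, seq_len, num_q_heads, head_dim]
```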
## 💻 Installation
```bash
pip install jax-flash-attn2
```
## ✅ Supported Configurations
### Backend-Platform Compatibility Matrix
| Backend | Supported Platforms |
| ------- | ------------------- |
| GPU | Triton, Pallas, JAX |
| TPU | Pallas, JAX |
| CPU | JAX |
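Each row of this matrix maps directly onto the `backend` and `platform` arguments of `get_cached_flash_attention`. One way to pick a valid pairing at runtime is to key off `jax.default_backend()`; the sketch below is a user-side convenience that mirrors the table, with illustrative argument values:

```python
import jax
from jax_flash_attn2 import get_cached_flash_attention

# Choose a platform the table above lists as supported for the active JAX backend.
backend = jax.default_backend()  # "gpu", "tpu", or "cpu"
platform = {"gpu": "triton", "tpu": "pallas", "cpu": "jax"}[backend]

head_dim = 128  # illustrative per-head dimension
attention = get_cached_flash_attention(
    backend=backend,
    platform=platform,
    blocksize_q=64,
    blocksize_k=128,
    softmax_scale=head_dim ** -0.5,
)
```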
## 📦 Requirements
- Python >=3.10
- JAX >=0.4.33
- JAXlib >=0.4.33
- Triton ~=3.0.0
- scipy ==1.13.1
- einops
- chex
## 🔍 Known Limitations
- Triton platform is only available on NVIDIA GPUs
- Some platform-backend combinations are not supported
- Custom attention masks are not yet supported (use bias instead)
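Since custom masks are not yet supported, a boolean mask (`True` = attend) can be folded into an additive bias instead. A minimal sketch; the `[batch, num_heads, q_len, kv_len]` bias layout is an assumption about broadcasting conventions, not documented behaviour:

```python
import jax.numpy as jnp

def mask_to_bias(mask: jnp.ndarray, dtype=jnp.float32) -> jnp.ndarray:
    """Convert a boolean attention mask (True = attend) into an additive bias.

    Masked-out positions get a large negative value so they vanish after softmax.
    """
    return jnp.where(mask, 0.0, jnp.finfo(dtype).min).astype(dtype)

# Example: causal mask for a sequence of length 1024, broadcast over batch and heads.
seq_len = 1024
causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
attention_bias = mask_to_bias(causal_mask)[None, None, :, :]  # assumed [1, 1, q_len, kv_len]
```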
## 📝 Usage Example
```python
from jax_flash_attn2 import get_cached_flash_attention

attention = get_cached_flash_attention(
    backend="gpu",
    platform="triton",
    blocksize_q=64,
    blocksize_k=128,
    softmax_scale=headdim ** -0.5,  # headdim: per-head dimension of your model
)

outputs = attention(
    query=query_states,
    key=key_states,
    value=value_states,
    bias=attention_bias,  # optional
)
```
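The returned callable is a plain JAX function, so it should compose with `jax.jit` and existing mesh patterns; a minimal sketch, assuming the variables from the example above and that the cached attention instance is traceable under `jit`:

```python
import jax

# Hypothetical jitted wrapper around the cached attention instance created above.
attend = jax.jit(lambda q, k, v, b: attention(query=q, key=k, value=v, bias=b))
outputs = attend(query_states, key_states, value_states, attention_bias)
```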
## 🙏 Acknowledgments
- Based on the [FlashAttention](https://arxiv.org/abs/2205.14135) and [FlashAttention-2](https://arxiv.org/abs/2307.08691) papers
- Uses [JAX-Triton](https://github.com/jax-ml/jax-triton/)
- Kernels adapted from [EasyDeL](https://github.com/erfanzar/Easydel)
## 📚 Documentation
Full documentation will soon be available at: https://erfanzar.github.io/jax-flash-attn2
## 🐛 Bug Reports
Please report any issues on our [GitHub Issues page](https://github.com/erfanzar/jax-flash-attn2/issues).