Jax-flash-attn2

0.0.1

The first release of Jax-flash-attn2 provides a flexible and efficient implementation of Flash Attention 2.0 for JAX with multiple backend support.
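For orientation, a Flash Attention kernel computes ordinary scaled dot-product attention, softmax(Q K^T * softmax_scale + bias) V. It changes only how that result is evaluated: in tiles, without storing the full score matrix. Below is a minimal unfused reference in plain JAX, which can be useful for checking outputs on small inputs. It is a sketch under two assumptions. First, the `(batch, seq_len, num_heads, head_dim)` layout is assumed, because this release note does not specify one. Second, `reference_attention` is a hypothetical helper and is not part of the package.

```python
import jax
import jax.numpy as jnp


def reference_attention(query, key, value, bias=None, softmax_scale=None):
    """Unfused scaled dot-product attention, intended for testing only.

    Assumed (hypothetical) layout:
      query:      (batch, q_len, num_heads, head_dim)
      key, value: (batch, kv_len, num_heads, head_dim)
      bias:       broadcastable to (batch, num_heads, q_len, kv_len)
    """
    if softmax_scale is None:
        softmax_scale = query.shape[-1] ** -0.5
    # Scores for every (query, key) pair: (batch, heads, q_len, kv_len).
    scores = jnp.einsum("bqhd,bkhd->bhqk", query, key) * softmax_scale
    if bias is not None:
        scores = scores + bias
    weights = jax.nn.softmax(scores, axis=-1)
    # Weighted sum of values, back to (batch, q_len, heads, head_dim).
    return jnp.einsum("bhqk,bkhd->bqhd", weights, value)


if __name__ == "__main__":
    k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
    q = jax.random.normal(k1, (2, 16, 4, 64))
    k = jax.random.normal(k2, (2, 32, 4, 64))
    v = jax.random.normal(k3, (2, 32, 4, 64))
    out = reference_attention(q, k, v)
    print(out.shape)  # (2, 16, 4, 64)
```

A reference like this uses O(q_len * kv_len) memory per head, which is exactly the cost the fused kernels avoid.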

🚀 Features

- Multiple backend support (GPU/TPU/CPU)
- Multiple platform implementations (Triton/Pallas/JAX)
- Efficient caching of attention instances
- Support for Grouped Query Attention (GQA)
- Head dimensions up to 256
- JAX sharding-friendly implementation
- Automatic platform selection based on the backend
- Compatible with existing JAX mesh patterns
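With Grouped Query Attention, several query heads share one key/value head, so the number of query heads must be a multiple of the number of KV heads. Semantically, GQA is equivalent to repeating each KV head and then running ordinary attention. The sketch below shows that equivalence in plain JAX under an assumed `(batch, seq_len, num_heads, head_dim)` layout. It illustrates GQA itself rather than the package's kernels, which presumably avoid materializing the repeated heads. `expand_kv_heads` and `gqa_attention` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def expand_kv_heads(kv, num_q_heads):
    """Repeat KV heads so each group of query heads sees its shared KV head.

    kv: (batch, kv_len, num_kv_heads, head_dim), an assumed layout.
    """
    num_kv_heads = kv.shape[2]
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be a multiple of num_kv_heads")
    group_size = num_q_heads // num_kv_heads
    # jnp.repeat keeps copies adjacent, so query head h uses KV head h // group_size.
    return jnp.repeat(kv, group_size, axis=2)


def gqa_attention(query, key, value):
    """Plain (unfused) GQA: expand KV heads, then standard softmax attention."""
    num_q_heads = query.shape[2]
    key = expand_kv_heads(key, num_q_heads)
    value = expand_kv_heads(value, num_q_heads)
    scale = query.shape[-1] ** -0.5
    scores = jnp.einsum("bqhd,bkhd->bhqk", query, key) * scale
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, value)


if __name__ == "__main__":
    kq, kk, kv_ = jax.random.split(jax.random.PRNGKey(0), 3)
    q = jax.random.normal(kq, (1, 8, 8, 32))   # 8 query heads
    k = jax.random.normal(kk, (1, 8, 2, 32))   # 2 shared KV heads
    v = jax.random.normal(kv_, (1, 8, 2, 32))
    print(gqa_attention(q, k, v).shape)  # (1, 8, 8, 32)
```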

💻 Installation

```bash
pip install jax-flash-attn2
```


✅ Supported Configurations

Backend-Platform Compatibility Matrix

| Backend | Supported Platforms |
| ------- | ------------------- |
| GPU | Triton, Pallas, JAX |
| TPU | Pallas, JAX |
| CPU | JAX |

📦 Requirements

- Python >=3.10
- JAX >=0.4.33
- JAXlib >=0.4.33
- Triton ~=3.0.0
- scipy ==1.13.1
- einops
- chex

🔍 Known Limitations

- The Triton platform is available only on NVIDIA GPUs
- Platform-backend combinations not listed in the compatibility matrix above are not supported
- Custom attention masks are not yet supported; encode masking as an additive `bias` instead
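Because masks are not accepted directly, you can fold a boolean mask into the additive `bias` argument. Allowed positions get 0, and masked positions get a very large negative value, so they receive (near-)zero weight after the softmax. Below is a minimal sketch for a causal mask. The `(1, 1, seq_len, seq_len)` bias shape, which broadcasts over batch and heads, is an assumption; check the shape the kernels actually expect. `mask_to_bias` and `causal_bias` are hypothetical helpers.

```python
import jax.numpy as jnp


def mask_to_bias(mask, dtype=jnp.float32):
    """Convert a boolean mask (True = attend) into an additive bias.

    Allowed positions become 0. Masked positions become the most negative
    finite value of `dtype`, so softmax gives them ~zero weight without
    producing -inf (which could turn fully masked rows into NaN).
    """
    return jnp.where(mask, 0.0, jnp.finfo(dtype).min).astype(dtype)


def causal_bias(seq_len, dtype=jnp.float32):
    """Causal bias of shape (1, 1, seq_len, seq_len): query i sees keys 0..i."""
    q_idx = jnp.arange(seq_len)[:, None]
    k_idx = jnp.arange(seq_len)[None, :]
    mask = k_idx <= q_idx
    # Leading singleton axes broadcast over (batch, num_heads) (assumed layout).
    return mask_to_bias(mask, dtype)[None, None, :, :]


if __name__ == "__main__":
    # Lower-triangular pattern: 1 = attend, 0 = masked.
    print((causal_bias(3)[0, 0] == 0).astype(int))
```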

📝 Usage Example

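```python
from jax_flash_attn2 import get_cached_flash_attention

# query_states, key_states, value_states: your existing JAX arrays.
# Assumes head_dim is the last axis of query_states.
headdim = query_states.shape[-1]

# Get a cached attention instance.
attention = get_cached_flash_attention(
    backend="gpu",        # "gpu", "tpu", or "cpu"
    platform="triton",    # must be supported on the chosen backend
    blocksize_q=64,       # query block (tile) size
    blocksize_k=128,      # key/value block (tile) size
    softmax_scale=headdim ** -0.5,  # 1 / sqrt(head_dim)
)

outputs = attention(
    query=query_states,
    key=key_states,
    value=value_states,
    bias=attention_bias,  # Optional
)
```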
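`blocksize_q` and `blocksize_k` set the tile sizes the kernels iterate over. The core idea behind Flash Attention 2 can be sketched in plain JAX. Keys and values are processed one block at a time, and each query row keeps a running maximum and a running softmax denominator, so the full score matrix is never stored. This is a readability-oriented illustration of that online-softmax technique for a single head with no bias. It is not the package's Triton or Pallas kernel. `blocked_attention` is a hypothetical name, and the 2-D `(seq_len, head_dim)` inputs are a simplifying assumption.

```python
import jax
import jax.numpy as jnp


def blocked_attention(q, k, v, blocksize_q=64, blocksize_k=128, softmax_scale=None):
    """Single-head attention computed tile by tile with an online softmax.

    q: (q_len, head_dim); k, v: (kv_len, head_dim). Lengths need not be
    multiples of the block sizes; the last tile is simply shorter.
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** -0.5
    out_blocks = []
    for qs in range(0, q.shape[0], blocksize_q):
        q_blk = q[qs:qs + blocksize_q] * softmax_scale
        n = q_blk.shape[0]
        # Running state per query row: max score, softmax denominator, output.
        row_max = jnp.full((n,), -jnp.inf, dtype=q.dtype)
        row_sum = jnp.zeros((n,), dtype=q.dtype)
        acc = jnp.zeros((n, v.shape[-1]), dtype=q.dtype)
        for ks in range(0, k.shape[0], blocksize_k):
            s = q_blk @ k[ks:ks + blocksize_k].T       # (n, block_k) scores
            new_max = jnp.maximum(row_max, s.max(axis=-1))
            p = jnp.exp(s - new_max[:, None])          # unnormalized probabilities
            correction = jnp.exp(row_max - new_max)    # rescales earlier blocks
            row_sum = row_sum * correction + p.sum(axis=-1)
            acc = acc * correction[:, None] + p @ v[ks:ks + blocksize_k]
            row_max = new_max
        # Divide by the softmax denominator once, after all key blocks.
        out_blocks.append(acc / row_sum[:, None])
    return jnp.concatenate(out_blocks, axis=0)


if __name__ == "__main__":
    kq, kk, kv_ = jax.random.split(jax.random.PRNGKey(0), 3)
    q = jax.random.normal(kq, (100, 64))
    k = jax.random.normal(kk, (300, 64))
    v = jax.random.normal(kv_, (300, 64))
    # Unfused reference that materializes the full (100, 300) score matrix.
    ref = jax.nn.softmax((q @ k.T) * 64 ** -0.5, axis=-1) @ v
    out = blocked_attention(q, k, v, blocksize_q=32, blocksize_k=128)
    print(float(jnp.abs(out - ref).max()))  # small, at float32 rounding level
```

Deferring the division by `row_sum` until all key blocks are processed is one of the refinements Flash Attention 2 makes over the original algorithm. The real kernels also fuse these loops on-chip instead of running them in Python.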


🙏 Acknowledgments

- Based on the [FlashAttention-2 paper](https://arxiv.org/abs/2307.08691), which builds on the original [FlashAttention paper](https://arxiv.org/abs/2205.14135)
- Uses [JAX-Triton](https://github.com/jax-ml/jax-triton/)
- Kernels adapted from [EasyDeL](https://github.com/erfanzar/Easydel)

📚 Documentation

Full documentation will soon be available at: https://erfanzar.github.io/jax-flash-attn2

🐛 Bug Reports

Please report any issues on our [GitHub Issues page](https://github.com/erfanzar/jax-flash-attn2/issues)
