Changelog
Bugs
1. The flash attention module was missing from the original codebase, resulting in a `ModuleNotFoundError` during execution.
2. Incorrect integration of flash attention with the main attention module: the `forward` method of the `Attend` class did not handle flash attention correctly.
3. The `flash_attn` method of the `Attend` class made incorrect assumptions about the dimensions of the `k` and `v` tensors, leading to dimension mismatch errors during tensor operations (see the short illustration after this list).
4. The original `flash_attn` method did not apply the attention scale correctly when `qk_norm` was set to `True`.
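
For illustration, here is a minimal, hypothetical reproduction of bug 3. The shapes are assumptions for the sketch, not taken from the original code: `q` carries an explicit heads dimension while `k` does not, so the attention contraction cannot proceed until the missing dimension is restored.

```python
import torch

q = torch.randn(2, 8, 128, 64)  # (batch, heads, seq_q, dim_head)
k = torch.randn(2, 128, 64)     # (batch, seq_k, dim_head) -- heads dimension missing

# naive attention scores fail with a dimension mismatch:
# torch.einsum('b h i d, b h j d -> b h i j', q, k)  -> shape error

# the fix: insert and broadcast the missing heads dimension first
k = k.unsqueeze(1).expand(-1, q.shape[1], -1, -1)   # (2, 8, 128, 64)
sim = torch.einsum('b h i d, b h j d -> b h i j', q, k)
print(sim.shape)  # torch.Size([2, 8, 128, 128])
```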
Improvements
1. Integrated the flash attention module into the main codebase and verified that its dimensions and operations are correct.
2. Modified the `forward` method of the `Attend` class to handle flash attention correctly: it checks whether flash attention is enabled and dispatches to the appropriate attention path (see the dispatch sketch after this list).
3. Adjusted the `flash_attn` method to account for possibly missing dimensions in the `q`, `k`, and `v` tensors and to correct for dimension mismatches.
4. Included a check for whether the tensors are on a CUDA device and, if so, selected the appropriate CUDA configuration for efficient attention.
5. Correctly handled the scale in the `flash_attn` method when `qk_norm` is `True` (improvements 3-5 are sketched together after this list).
6. Added assertions with informative error messages for incompatible options, such as combining talking heads with flash attention.
7. Ensured compatibility by requiring PyTorch version 2.0 or above for flash attention (see the guard sketch at the end of this section).
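
A minimal sketch of the dispatch logic from improvement 2. Beyond the `Attend`, `forward`, and `flash_attn` names mentioned above, everything here (the constructor arguments, the einsum fallback, the boolean mask convention where `True` means "attend") is an assumption for illustration, not the codebase's exact implementation.

```python
import torch
import torch.nn.functional as F
from torch import nn, einsum

class Attend(nn.Module):
    def __init__(self, dropout=0., flash=False, scale=None):
        super().__init__()
        self.dropout = dropout
        self.flash = flash
        self.scale = scale

    def flash_attn(self, q, k, v, mask=None):
        raise NotImplementedError  # sketched separately below

    def forward(self, q, k, v, mask=None):
        # dispatch: use the fused flash kernel when enabled,
        # otherwise fall back to explicit einsum attention
        if self.flash:
            return self.flash_attn(q, k, v, mask=mask)

        scale = self.scale if self.scale is not None else q.shape[-1] ** -0.5
        sim = einsum('b h i d, b h j d -> b h i j', q, k) * scale

        if mask is not None:
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        attn = sim.softmax(dim=-1)
        attn = F.dropout(attn, p=self.dropout, training=self.training)
        return einsum('b h i j, b h j d -> b h i d', attn, v)
```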
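A sketch of the flash path itself, covering improvements 3-5 and written as a standalone function for clarity. The backend-config namedtuple, the shape conventions, and the exact scale algebra are assumptions; `torch.backends.cuda.sdp_kernel` is the PyTorch 2.0-era API for selecting scaled-dot-product backends (later releases deprecate it in favor of `torch.nn.attention.sdpa_kernel`).

```python
from collections import namedtuple

import torch
import torch.nn.functional as F

# which scaled-dot-product backends to allow; the field names mirror
# the keyword arguments of torch.backends.cuda.sdp_kernel
AttnConfig = namedtuple('AttnConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

def flash_attn(q, k, v, mask=None, dropout_p=0., scale=None, qk_norm=False):
    # improvement 3: k / v may arrive as (batch, seq, dim_head);
    # insert and broadcast the heads dimension so they match q
    heads = q.shape[1]
    if k.ndim == 3:
        k = k.unsqueeze(1).expand(-1, heads, -1, -1)
    if v.ndim == 3:
        v = v.unsqueeze(1).expand(-1, heads, -1, -1)

    # improvement 5: with qk_norm, q and k are normalized, so the kernel's
    # built-in 1 / sqrt(dim_head) scaling is wrong. PyTorch 2.0's
    # scaled_dot_product_attention has no scale argument, so fold the learned
    # scale into q instead:
    # (q * scale / default_scale) @ k^T * default_scale == (q @ k^T) * scale
    if qk_norm and scale is not None:
        default_scale = q.shape[-1] ** -0.5
        q = q * (scale / default_scale)

    # improvement 4: enable the fused flash kernel only on CUDA devices;
    # on CPU fall back to the math / memory-efficient implementations
    config = AttnConfig(True, True, True) if q.is_cuda else AttnConfig(False, True, True)

    with torch.backends.cuda.sdp_kernel(**config._asdict()):
        return F.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=dropout_p)
```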
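Finally, the guards from improvements 6 and 7 might look like the following; the flag names and message text are assumptions.

```python
import torch
from packaging import version

def check_flash_compat(flash: bool, talking_heads: bool):
    # improvement 6: talking heads mixes information across heads via the
    # explicit attention matrix, which the fused kernel never materializes
    assert not (flash and talking_heads), \
        'talking heads is not compatible with flash attention'

    # improvement 7: F.scaled_dot_product_attention landed in PyTorch 2.0
    assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), \
        'flash attention requires PyTorch 2.0 or above'
```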