Changed
- Improve training speed by using `torch.nn.functional.multi_head_attention_forward` for self- and encoder-attention
  during training. This requires reorganizing the parameter layout of the key-value input projections,
  since the current Sockeye attention interleaves them for faster inference.
  Attention masks (both source masks and autoregressive masks) need some shape adjustments, as the requirements
  of the fused MHA op differ slightly.
- Non-interleaved format for the joint key-value input projection parameters:
  `in_features=hidden, out_features=2*hidden` -> shape `(2*hidden, hidden)`
- Interleaved format for the joint key-value input projection stores key and value parameters grouped by head:
  shape `((num_heads * 2 * hidden_per_head), hidden)`
- Models save and load key-value projection parameters in interleaved format.
- When `model.training == True`, key-value projection parameters are converted to the
  non-interleaved format required by `torch.nn.functional.multi_head_attention_forward`.
- When `model.training == False`, i.e. `model.eval()` is called, key-value projection
  parameters are converted back into interleaved format in place (see the sketch after this list).
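
Below is a minimal sketch of what this layout conversion and the train/eval toggle can look like, assuming a joint
key-value weight of shape `(2*hidden, hidden)` and the interleaved layout described above. The helper and module names
(`interleave_kv`, `deinterleave_kv`, `KVProjection`, `_interleaved`) are illustrative assumptions, not Sockeye's actual
implementation.

```python
# Illustrative sketch only: converts a joint key-value projection weight between
# the non-interleaved layout ([all key rows; all value rows]) expected by
# torch.nn.functional.multi_head_attention_forward and a per-head interleaved
# layout ([k_head0; v_head0; k_head1; v_head1; ...]) used for inference.
import torch


def interleave_kv(weight_kv: torch.Tensor, num_heads: int) -> torch.Tensor:
    """Non-interleaved (2*hidden, hidden) -> interleaved (num_heads * 2 * hidden_per_head, hidden)."""
    hidden = weight_kv.shape[1]
    hidden_per_head = hidden // num_heads
    w_k, w_v = weight_kv.split(hidden, dim=0)             # each (hidden, hidden)
    w_k = w_k.reshape(num_heads, hidden_per_head, hidden)
    w_v = w_v.reshape(num_heads, hidden_per_head, hidden)
    stacked = torch.stack([w_k, w_v], dim=1)              # (num_heads, 2, hidden_per_head, hidden)
    return stacked.reshape(num_heads * 2 * hidden_per_head, hidden).contiguous()


def deinterleave_kv(weight_kv: torch.Tensor, num_heads: int) -> torch.Tensor:
    """Interleaved (num_heads * 2 * hidden_per_head, hidden) -> non-interleaved (2*hidden, hidden)."""
    hidden = weight_kv.shape[1]
    hidden_per_head = hidden // num_heads
    per_head = weight_kv.reshape(num_heads, 2, hidden_per_head, hidden)
    w_k = per_head[:, 0].reshape(hidden, hidden)
    w_v = per_head[:, 1].reshape(hidden, hidden)
    return torch.cat([w_k, w_v], dim=0).contiguous()


class KVProjection(torch.nn.Module):
    """Hypothetical module that stores/checkpoints its parameters in interleaved
    format and switches layouts in place when toggling between train() and eval()."""

    def __init__(self, hidden: int, num_heads: int) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.weight_kv = torch.nn.Parameter(torch.empty(2 * hidden, hidden))
        torch.nn.init.xavier_uniform_(self.weight_kv)
        self._interleaved = True  # format used for saving/loading

    def train(self, mode: bool = True) -> "KVProjection":
        # model.eval() calls train(False), so both directions pass through here.
        super().train(mode)
        if mode and self._interleaved:
            # training: de-interleave for the fused MHA op
            self.weight_kv.data = deinterleave_kv(self.weight_kv.data, self.num_heads)
            self._interleaved = False
        elif not mode and not self._interleaved:
            # inference: restore the interleaved layout in place
            self.weight_kv.data = interleave_kv(self.weight_kv.data, self.num_heads)
            self._interleaved = True
        return self


if __name__ == "__main__":
    # The round trip is exact: interleave(deinterleave(w)) == w.
    w = torch.randn(2 * 512, 512)
    assert torch.equal(interleave_kv(deinterleave_kv(w, 8), 8), w)
```

Since `model.eval()` internally calls `self.train(False)`, overriding `train()` in this sketch covers both directions of
the in-place conversion.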