## Highlights
We are excited to announce the 0.4 release of torchao! This release adds support for KV cache quantization, quantization aware training (QAT), low bit optimizer support, composing quantization and sparsity, and more!
### KV cache quantization (https://github.com/pytorch/ao/pull/532)
We've added support for KV cache quantization, showing a peak memory reduction from 19.7 -> 19.2 GB on Llama3-8B at an 8192 context length. We plan to investigate Llama3.1 next.
<img src="https://github.com/user-attachments/assets/31946f46-e8eb-45c2-ac1c-3a7d981c58a2" width="300" height="auto">
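For intuition only, the sketch below shows the basic idea behind int8 KV cache quantization: cache K/V tensors in int8 with per-token scales and dequantize them when attention reads the cache back. The helper names and shapes are hypothetical illustrations, not the torchao implementation; see the PR above for the real code path.

```python
import torch

# Illustrative helpers (hypothetical, not the torchao API): symmetric per-token
# int8 quantization of a cached K or V tensor, and dequantization on read.
def quantize_kv(x: torch.Tensor, eps: float = 1e-6):
    # x: [batch, heads, seq_len, head_dim]; one scale per token
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=eps) / 127.0
    q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return q, scale

def dequantize_kv(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return q.to(scale.dtype) * scale

k = torch.randn(1, 8, 8192, 128)      # toy K tensor at an 8192 context length
k_q, k_scale = quantize_kv(k)         # stored in the cache as int8 + scales
k_hat = dequantize_kv(k_q, k_scale)   # dequantized when attention reads the cache
```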
### Quantization-Aware Training (QAT) ([#383](https://github.com/pytorch/ao/pull/383), [#555](https://github.com/pytorch/ao/pull/555))
We now support two QAT schemes for linear layers: Int8 per token dynamic activations + int4 per group weights, and int4 per group weights (using the efficient [tinygemm int4 kernel](https://github.com/pytorch/pytorch/blob/a672f6c84e318bbf455f13dfdd3fd7c68a388bf5/aten/src/ATen/native/cuda/int4mm.cu#L1097) after training). Users can access this feature by transforming their models before and after training using the appropriate quantizer, for example:
```python
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

# Quantizer for int8 dynamic per token activations +
# int4 grouped per channel weights, only for linear layers
qat_quantizer = Int8DynActInt4WeightQATQuantizer()

# Insert "fake quantize" operations into linear layers.
# These operations simulate quantization numerics during
# training without performing any dtype casting
model = qat_quantizer.prepare(model)

# Convert fake quantize to actual quantize operations
model = qat_quantizer.convert(model)
```
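Between `prepare` and `convert`, the model is fine-tuned as usual so the weights adapt to the simulated quantization numerics. The loop below is a minimal sketch; `train_loader`, the loss, and the hyperparameters are placeholders, not part of the torchao API.

```python
import torch

# Hypothetical fine-tuning loop that sits between prepare() and convert();
# train_loader and the hyperparameters are placeholders.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
loss_fn = torch.nn.CrossEntropyLoss()

for input_ids, labels in train_loader:
    optimizer.zero_grad()
    logits = model(input_ids)  # forward pass runs through the fake-quantize ops
    loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
    loss.backward()
    optimizer.step()
```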
Initial evaluation results indicate that QAT in torchao can recover up to 96% of quantized accuracy degradation on hellaswag and up to 68% of quantized perplexity degradation on wikitext for Llama3 compared to post-training quantization (PTQ). For more details, please refer to the [README](https://github.com/pytorch/ao/tree/main/torchao/quantization/prototype/qat) and [this blog post](https://pytorch.org/blog/quantization-aware-training/).
### Composing quantization and sparsity (#457, #473)
We've added support for composing int8 dynamic quantization with 2:4 sparsity, using the `quantize_` API. We also added SAM benchmarks that show a 7% speedup over standalone sparsity / int8 dynamic quantization [here](https://github.com/pytorch/ao/tree/main/torchao/sparsity#segment-anything-fast).
```python
from torchao.quantization import quantize_, int8_dynamic_activation_int8_semi_sparse_weight

quantize_(model, int8_dynamic_activation_int8_semi_sparse_weight())
```
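The 2:4 kernels assume each linear weight already follows a 2:4 sparsity pattern. One way to get there, sketched below with a toy model, is magnitude-based pruning via `torch.ao.pruning` before applying the composed config; the model, the `tensor_fqn` entry, and the device setup are placeholders.

```python
import torch
from torch.ao.pruning import WeightNormSparsifier
from torchao.quantization import quantize_, int8_dynamic_activation_int8_semi_sparse_weight

# Toy model purely for illustration; a GPU with 2:4 sparse support is assumed.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).half().cuda()

# Prune each linear weight to a 2:4 pattern (2 zeros in every block of 4 elements).
sparsifier = WeightNormSparsifier(sparsity_level=1.0, sparse_block_shape=(1, 4), zeros_per_block=2)
sparsifier.prepare(model, [{"tensor_fqn": "0.weight"}])
sparsifier.step()
sparsifier.squash_mask()

# Then compose int8 dynamic quantization with 2:4 sparsity as above.
quantize_(model, int8_dynamic_activation_int8_semi_sparse_weight())
```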
## Community Contributions
### Low-bit optimizer support (#478, #463, #482, #484, #538)
gau-nernst added implementations for 4-bit, 8-bit, and FP8 Adam with FSDP2/FSDP support. Our API is a drop-in replacement for `torch.optim.Adam` and can be used as follows:
```python
from torchao.prototype.low_bit_optim import Adam8bit, Adam4bit, AdamFp8
from torchao.prototype.low_bit_optim import AdamW8bit, AdamW4bit, AdamWFp8

model = ...
optim = Adam8bit(model.parameters())  # replace with Adam4bit and AdamFp8 for the 4-bit / fp8 versions
```
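Because the API matches `torch.optim.Adam`, the rest of the training step is unchanged. A minimal sketch with a placeholder model and batch (the CUDA device and dimensions here are assumptions, not requirements documented in this snippet):

```python
import torch
from torchao.prototype.low_bit_optim import Adam8bit

# Placeholder model and batch purely for illustration; a CUDA device is assumed.
model = torch.nn.Linear(1024, 1024).cuda()
optim = Adam8bit(model.parameters(), lr=1e-3)

x = torch.randn(16, 1024, device="cuda")
loss = model(x).pow(2).mean()
loss.backward()
optim.step()       # optimizer states are stored in 8 bits
optim.zero_grad()
```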
For more information about low-bit optimizer support, please refer to our [README](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim).
### Improvements to 4-bit quantization (https://github.com/pytorch/ao/pull/517, https://github.com/pytorch/ao/pull/552, https://github.com/pytorch/ao/pull/544, #479)
bdhirsh, jeromeku, yanbing-j, manuelcandales, and larryliu0820 added torch.compile support for NF4 Tensor, custom CUDA int4 tinygemm unpacking ops, and several bugfixes to torchao.
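As a rough illustration of the NF4 + `torch.compile` path, the sketch below converts a bf16 weight to NF4 and runs a linear op under `torch.compile`. The import location (`torchao.dtypes.nf4tensor`) and the block-size arguments are assumptions on my part; check the NF4 code for the exact entry points.

```python
import torch
from torchao.dtypes.nf4tensor import to_nf4, linear_nf4

# Convert a bf16 weight to NF4 and run a linear layer through torch.compile.
weight = torch.randn(1024, 1024, dtype=torch.bfloat16)
nf4_weight = to_nf4(weight, block_size=64, scaler_block_size=256)

@torch.compile
def forward(x):
    return linear_nf4(x, nf4_weight)

out = forward(torch.randn(8, 1024, dtype=torch.bfloat16))
```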
## BC breaking
* `quantize` has been renamed to `quantize_` https://github.com/pytorch/ao/pull/467
```python