| | awq-uint4 | 43.59 | 194.93 | 7.31 | 4.47 |
| | int4wo-hqq | 209.19 | 804.32 | 4.89 | 3.84 |
| | int4wo-64 | 201.14 | 751.42 | 4.87 | 3.74 |
Usage:
```python
import torch
from torchao.prototype.awq import insert_awq_observer_, awq_uintx, AWQObservedLinear
from torchao.quantization import quantize_

quant_dtype = torch.uint4
group_size = 64
calibration_limit = 10
calibration_seq_length = 1024

# Insert observers to record activation statistics during calibration
model = model.to(device)
insert_awq_observer_(model, calibration_limit, calibration_seq_length, quant_dtype=quant_dtype, group_size=group_size)

# Run calibration data through the model
with torch.no_grad():
    for batch in calibration_data:
        model(batch.to(device))

# Quantize the observed linear layers
is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size=group_size), is_observed_linear)
```
New Features
- [Prototype] Added Float8 support for AQT tensor parallel (1003)
- Added composable QAT quantizer (938)
- Introduced torchchat quantizer (897)
- Added INT8 mixed-precision training (748)
- Implemented sparse marlin AQT layout (621)
- Added a PerTensor static quant api (787)
- Added uintx quantization to the generate and eval scripts (811)
- Added Float8 weight-only quantization and Float8 dynamic activation + Float8 weight quantization (740); see the example after this list
- Implemented Auto-Round support (581)
- Added 2, 3, 4, 5 bit custom ops (828)
- Introduced symmetric quantization with no clipping error in the tensor subclass based API (845)
- Added int4 weight-only embedding QAT (947)
- Added support for 1-bit and 6-bit quantization for Llama in torchchat (910, 1007)
- Added a linear_observer class for doing static activation calibration (807)
- Exposed hqq through the uintx_weight_only API (786); see the example after this list
- Added RowWise scaling option for Float8 dynamic activation quantization (819)
- Added Float8 weight only to autoquant api (866)
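As a rough illustration of two of the new weight-only options above (740, 786), the sketch below applies Float8 weight-only quantization and HQQ-based uint4 weight-only quantization via `quantize_`. This is a minimal sketch, not a definitive recipe: the exact import paths and argument names (e.g. `use_hqq`) reflect the APIs as of this release and may differ in other versions.
```python
# Illustrative sketch only; assumes torchao v0.6.x-era APIs (float8_weight_only,
# uintx_weight_only with a use_hqq flag). Verify against your installed version.
import copy
import torch
from torchao.quantization import quantize_, float8_weight_only, uintx_weight_only

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()

# Float8 weight-only quantization (740)
m_fp8 = copy.deepcopy(model)
quantize_(m_fp8, float8_weight_only())

# uint4 weight-only quantization backed by HQQ (786)
m_hqq = copy.deepcopy(model)
quantize_(m_hqq, uintx_weight_only(torch.uint4, group_size=64, use_hqq=True))
```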
Improvements
- Enhanced Auto-Round functionality (870)
- Improved FSDP support for low-bit optimizers (538)
- Added support for using AffineQuantizedTensor with `weights_only=True` for torch.load (630)
- Optimized 3-bit packing (1029)
- Added more evaluation metrics to llama/eval.sh (934)
- Improved eager numerics for dynamic scales in float8 (904)
Bug Fixes
- Fixed inference_mode issues (885)
- Fixed failing FP6 benchmark (931)
- Resolved various issues with float8 support (918, 923)
- Fixed loading the state dict onto a different device for low-bit optimizers (1021)
Performance
- Added SM75 (Turing) support for FP6 kernel (942)
- Implemented int8 dynamic quant + bsr support (821)
- Added a workaround to recover performance for quantized ViT under torch.compile (926)
INT8 Mixed-Precision Training
On NVIDIA GPUs, INT8 Tensor Cores are approximately 2x faster than their BF16/FP16 counterparts. In mixed-precision training, we can dynamically down-cast activations and weights to INT8 to leverage these faster matmuls. However, since INT8 has a very limited range [-128, 127], we perform row-wise quantization, similar to how INT8 post-training quantization (PTQ) is done. The weights themselves remain stored in the original precision; they are only down-cast on the fly for the matmuls.
```python
from torchao.prototype.quantized_training import int8_mixed_precision_training, Int8MixedPrecisionTrainingConfig
from torchao.quantization import quantize_

model = ...

# Apply INT8 matmul to all 3 matmuls (output, grad_input, grad_weight)
quantize_(model, int8_mixed_precision_training())

# Alternatively, customize which matmuls are left in the original precision
config = Int8MixedPrecisionTrainingConfig(
    output=True,
    grad_input=True,
    grad_weight=False,
)
quantize_(model, int8_mixed_precision_training(config))
```
**End-to-end speed benchmark** using `benchmarks/quantized_training/pretrain_llama2.py`:
| Model & GPU | bs x seq_len | Config | Tok/s | Peak mem (GB) |
|-----|-----|-----|-----|-----|
| Llama2-7B, A100 | 8 x 2048 | BF16 (baseline) | ~4400 | 59.69 |
| Llama2-7B, A100 | 8 x 2048 | INT8 mixed-precision | ~6100 (**+39%**) | 58.28 |
| Llama2-1B, 4090 | 16 x 2048 | BF16 (baseline) | ~17,900 | 18.23 |
| Llama2-1B, 4090 | 16 x 2048 | INT8 mixed-precision | ~30,700 (**+72%**) | 18.34 |
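For intuition, the row-wise dynamic quantization described above can be sketched as follows. This is illustrative only: the helper `int8_rowwise_quantize` is made up for this sketch and is not torchao's kernel, and the matmul is emulated in float for portability (on CUDA, the actual work is dispatched to INT8 Tensor Core matmuls). It shows how each row gets its own scale so the limited INT8 range is used effectively.
```python
# Illustrative sketch of row-wise INT8 quantization; not the torchao implementation.
import torch

def int8_rowwise_quantize(x: torch.Tensor):
    # One scale per row: map each row's absolute maximum onto 127
    scale = x.abs().amax(dim=-1, keepdim=True) / 127
    scale = scale.clamp(min=1e-12)  # avoid division by zero for all-zero rows
    x_int8 = (x / scale).round().clamp(-128, 127).to(torch.int8)
    return x_int8, scale

a, b = torch.randn(8, 64), torch.randn(16, 64)
a_int8, a_scale = int8_rowwise_quantize(a)
b_int8, b_scale = int8_rowwise_quantize(b)

# Integer matmul (emulated in float here), then fold the per-row scales back in
out = (a_int8.float() @ b_int8.float().T) * a_scale * b_scale.T
print((out - a @ b.T).abs().max())  # quantization error is small relative to the outputs
```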
Docs
- Updated README with more current float8 speedup information (816)
- Added tutorial for trainable tensor subclass (908)
- Improved documentation for float8 unification and inference (895, 896)
Devs
- Added compile tests to test suite (906)
- Improved CI setup and build processes (887)
- Added M1 wheel support (822)
- Added more benchmarking and profiling tools (1017)
- Renamed `fpx` to `floatx` (877)
- Removed torchao_nightly package (661)
- Added more lint fixes (827)
- Added better subclass testing support (839)
- Added CI to catch syntax errors (861)
- Added tutorial on composing quantized subclass w/ Dtensor based TP (785)
Security
No significant security updates in this release.
Uncategorized
- Added basic SAM2 AutomaticMaskGeneration example server (1039)
New Contributors
* iseeyuan made their first contribution in https://github.com/pytorch/ao/pull/805
* YihengBrianWu made their first contribution in https://github.com/pytorch/ao/pull/860
* kshitij12345 made their first contribution in https://github.com/pytorch/ao/pull/863
* ZainRizvi made their first contribution in https://github.com/pytorch/ao/pull/887
* alexsamardzic made their first contribution in https://github.com/pytorch/ao/pull/899
* vaishnavi17 made their first contribution in https://github.com/pytorch/ao/pull/911
* tobiasvanderwerff made their first contribution in https://github.com/pytorch/ao/pull/931
* kwen2501 made their first contribution in https://github.com/pytorch/ao/pull/937
* y-sq made their first contribution in https://github.com/pytorch/ao/pull/912
* jimexist made their first contribution in https://github.com/pytorch/ao/pull/969
* danielpatrickhug made their first contribution in https://github.com/pytorch/ao/pull/914
* ramreddymounica made their first contribution in https://github.com/pytorch/ao/pull/1007
* yushangdi made their first contribution in https://github.com/pytorch/ao/pull/1006
* ringohoffman made their first contribution in https://github.com/pytorch/ao/pull/1023
**Full Changelog**: https://github.com/pytorch/ao/compare/v0.5.0...v0.6.1