With `accelerate` 1.0, we are officially declaring that the core parts of the API are now "stable" and ready for the future of distributed training and PyTorch. These release notes first cover the major breaking changes so you can fix your code, followed by what is new between 0.34.0 and 1.0.
To read more, check out our official blog [here](https://huggingface.co/blog/accelerate-v1)
## Migration assistance
* Passing in `dispatch_batches`, `split_batches`, `even_batches`, and `use_seedable_sampler` to the `Accelerator()` should now be handled by creating an `accelerate.utils.DataLoaderConfiguration()` and passing this to the `Accelerator()` instead (`Accelerator(dataloader_config=DataLoaderConfiguration(...))`)
* `Accelerator().use_fp16` and `AcceleratorState().use_fp16` have been removed; this should be replaced by checking `accelerator.mixed_precision == "fp16"`
* `Accelerator().autocast()` no longer accepts a `cache_enabled` argument. Instead, an `AutocastKwargs()` instance should be used, which handles this flag (among others) and is passed to the `Accelerator` (`Accelerator(kwargs_handlers=[AutocastKwargs(cache_enabled=True)])`)
* `accelerate.utils.is_tpu_available` should be replaced with `accelerate.utils.is_torch_xla_available`
* `accelerate.utils.modeling.shard_checkpoint` should be replaced with `split_torch_state_dict_into_shards` from the `huggingface_hub` library
* `accelerate.tqdm.tqdm()` no longer accepts `True`/`False` as the first argument, and instead, `main_process_only` should be passed in as a named argument
## Multiple Model DeepSpeed Support
After many requests, we finally have multiple-model DeepSpeed support in Accelerate (though it is still quite early). Read the full tutorial [here](https://huggingface.co/docs/accelerate/v1.0.0/en/usage_guides/deepspeed_multiple_model#using-multiple-models-with-deepspeed); in essence:

When using multiple models, a DeepSpeed plugin should be created for each model (and, as a result, a separate config). A few examples are below:
### Knowledge distillation
(where we train only one model with ZeRO-2, and use another, under ZeRO-3, for inference)
```python
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

zero2_plugin = DeepSpeedPlugin(hf_ds_config="zero2_config.json")
zero3_plugin = DeepSpeedPlugin(hf_ds_config="zero3_config.json")

deepspeed_plugins = {"student": zero2_plugin, "teacher": zero3_plugin}
accelerator = Accelerator(deepspeed_plugins=deepspeed_plugins)
```
To select which plugin is active at a given time (i.e., when calling `prepare`), call `accelerator.state.select_deepspeed_plugin("name")`; the first plugin is active by default:
```python
accelerator.state.select_deepspeed_plugin("student")
student_model, optimizer, scheduler = ...
student_model, optimizer, scheduler, train_dataloader = accelerator.prepare(student_model, optimizer, scheduler, train_dataloader)

accelerator.state.select_deepspeed_plugin("teacher")  # This will automatically enable zero init
teacher_model = AutoModel.from_pretrained(...)
teacher_model = accelerator.prepare(teacher_model)
```
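Once both models are prepared, the distillation step itself is ordinary PyTorch. A minimal sketch with stand-in models (the KL loss and temperature are illustrative choices here, not part of the Accelerate API):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-ins for the prepared student and teacher models
student = torch.nn.Linear(8, 4)
teacher = torch.nn.Linear(8, 4)

batch = torch.randn(16, 8)
temperature = 2.0

with torch.no_grad():  # the teacher only runs inference
    teacher_probs = F.softmax(teacher(batch) / temperature, dim=-1)

student_log_probs = F.log_softmax(student(batch) / temperature, dim=-1)

# Soft-label distillation loss; in an Accelerate training loop you would
# follow this with accelerator.backward(loss) rather than loss.backward()
loss = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
```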
### Multiple disjoint models
For disjoint models, a separate accelerator should be used for each model, and each one's own `.backward()` should be called:
```python
for batch in dl:
    outputs1 = first_model(**batch)
    first_accelerator.backward(outputs1.loss)
    first_optimizer.step()
    first_scheduler.step()
    first_optimizer.zero_grad()

    outputs2 = second_model(**batch)
    second_accelerator.backward(outputs2.loss)
    second_optimizer.step()
    second_scheduler.step()
    second_optimizer.zero_grad()
```
## FP8
We've enabled MS-AMP support for everything up to FSDP. At this time we are not moving forward with FSDP support for MS-AMP, due to design differences between the two libraries that make them difficult to interoperate.
## FSDP
* Fixed FSDP `auto_wrap` matching on characters instead of the full string for layer names
* Re-enabled setting the state dict type manually
## Big Modeling
* Removed the CPU restriction for bitsandbytes (bnb) training
## What's Changed
* Fix FSDP auto_wrap using characters instead of full str for layers by muellerzr in https://github.com/huggingface/accelerate/pull/3075
* Allow DataLoaderAdapter subclasses to be pickled by implementing `__reduce__` by byi8220 in https://github.com/huggingface/accelerate/pull/3074
* Fix three typos in src/accelerate/data_loader.py by xiabingquan in https://github.com/huggingface/accelerate/pull/3082
* Re-enable setting state dict type by muellerzr in https://github.com/huggingface/accelerate/pull/3084
* Support sequential cpu offloading with torchao quantized tensors by a-r-r-o-w in https://github.com/huggingface/accelerate/pull/3085
* fix bug in `_get_named_modules` by faaany in https://github.com/huggingface/accelerate/pull/3052
* use the correct available memory API for XPU by faaany in https://github.com/huggingface/accelerate/pull/3076
* fix `skip_keys` usage in forward hooks by 152334H in https://github.com/huggingface/accelerate/pull/3088
* Update README.md to include distributed image generation gist by sayakpaul in https://github.com/huggingface/accelerate/pull/3077
* MAINT: Upgrade ruff to v0.6.4 by BenjaminBossan in https://github.com/huggingface/accelerate/pull/3095
* Revert "Enable Unwrapping for Model State Dicts (FSDP)" by SunMarc in https://github.com/huggingface/accelerate/pull/3096
* MS-AMP support (w/o FSDP) by muellerzr in https://github.com/huggingface/accelerate/pull/3093
* [docs] DataLoaderConfiguration docstring by stevhliu in https://github.com/huggingface/accelerate/pull/3103
* MAINT: Permission for GH token in stale.yml by BenjaminBossan in https://github.com/huggingface/accelerate/pull/3102
* [docs] Doc sprint by stevhliu in https://github.com/huggingface/accelerate/pull/3099
* Update image ref for docs by muellerzr in https://github.com/huggingface/accelerate/pull/3105
* No more t5 by muellerzr in https://github.com/huggingface/accelerate/pull/3107
* [docs] More docstrings by stevhliu in https://github.com/huggingface/accelerate/pull/3108
* 🚨🚨🚨 The Great Deprecation 🚨🚨🚨 by muellerzr in https://github.com/huggingface/accelerate/pull/3098
* POC: multiple model/configuration DeepSpeed support by muellerzr in https://github.com/huggingface/accelerate/pull/3097
* Fixup test_sync w/ deprecated stuff by muellerzr in https://github.com/huggingface/accelerate/pull/3109
* Switch to XLA instead of TPU by SunMarc in https://github.com/huggingface/accelerate/pull/3118
* [tests] skip pippy tests for XPU by faaany in https://github.com/huggingface/accelerate/pull/3119
* Fixup multiple model DS tests by muellerzr in https://github.com/huggingface/accelerate/pull/3131
* remove cpu restriction for bnb training by jiqing-feng in https://github.com/huggingface/accelerate/pull/3062
* fix deprecated `torch.cuda.amp.GradScaler` FutureWarning for pytorch 2.4+ by Mon-ius in https://github.com/huggingface/accelerate/pull/3132
* 🐛 [HotFix] Handle Profiler Activities Based on PyTorch Version by yhna940 in https://github.com/huggingface/accelerate/pull/3136
* only move model to device when model is in cpu and target device is xpu by faaany in https://github.com/huggingface/accelerate/pull/3133
* fix tip brackets typo by davanstrien in https://github.com/huggingface/accelerate/pull/3129
* typo of "scalar" instead of "scaler" by tonyzhaozh in https://github.com/huggingface/accelerate/pull/3116
* MNT Permission for PRs for GH token in stale.yml by BenjaminBossan in https://github.com/huggingface/accelerate/pull/3112
## New Contributors
* xiabingquan made their first contribution in https://github.com/huggingface/accelerate/pull/3082
* a-r-r-o-w made their first contribution in https://github.com/huggingface/accelerate/pull/3085
* 152334H made their first contribution in https://github.com/huggingface/accelerate/pull/3088
* sayakpaul made their first contribution in https://github.com/huggingface/accelerate/pull/3077
* Mon-ius made their first contribution in https://github.com/huggingface/accelerate/pull/3132
* davanstrien made their first contribution in https://github.com/huggingface/accelerate/pull/3129
* tonyzhaozh made their first contribution in https://github.com/huggingface/accelerate/pull/3116
**Full Changelog**: https://github.com/huggingface/accelerate/compare/v0.34.2...v1.0.0