[Lightning AI](https://lightning.ai) :zap: is excited to announce the release of Lightning 2.5.
Lightning 2.5 comes with improvements on several fronts, with **zero** API changes. Our users love it stable, so we keep it stable :smile:.
Speaking of love :heart:, the `lightning`, `pytorch-lightning` and `lightning-fabric` packages are collectively getting more than **10M downloads per month** :open_mouth:, for a total of over **180M downloads** :exploding_head: since the early days. It's incredible to see PyTorch Lightning enjoying such strong adoption across industry and the sciences.
Release 2.5 embraces PyTorch 2.5 and marks some of PyTorch's more recent directions as officially supported, namely tensor subclass-based APIs like [Distributed Tensors](https://pytorch.org/docs/stable/distributed.tensor.html) and [TorchAO](https://pytorch.org/blog/pytorch-native-architecture-optimization/), in combination with `torch.compile`.
Here are a couple of examples:
<details><summary>Distributed FP8 transformer with PyTorch Lightning</summary>
Full example [here](https://github.com/Lightning-AI/pytorch-lightning/tree/master/examples/pytorch/fp8_distributed_transformer)
```python
import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning.pytorch.demos import Transformer, WikiText2
from lightning.pytorch.strategies import ModelParallelStrategy
from torch.distributed._composable.fsdp.fully_shard import fully_shard
from torch.utils.data import DataLoader
from torchao.float8 import Float8LinearConfig, convert_to_float8_training


class LanguageModel(L.LightningModule):
    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.model = None

    def configure_model(self):
        if self.model is not None:
            return

        with torch.device("meta"):
            model = Transformer(
                vocab_size=self.vocab_size,
                nlayers=16,
                nhid=4096,
                ninp=1024,
                nhead=32,
            )

        float8_config = Float8LinearConfig(
            # pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly  # noqa
            pad_inner_dim=True,
        )

        def module_filter_fn(mod: torch.nn.Module, fqn: str):
            # we skip the decoder because its vocabulary size
            # is typically not divisible by 16, as required by float8
            return fqn != "decoder"

        convert_to_float8_training(model, config=float8_config, module_filter_fn=module_filter_fn)

        for module in model.modules():
            if isinstance(module, (nn.TransformerEncoderLayer, nn.TransformerDecoderLayer)):
                fully_shard(module, mesh=self.device_mesh)

        fully_shard(model, mesh=self.device_mesh)

        self.model = torch.compile(model)

    def training_step(self, batch):
        input, target = batch
        output = self.model(input, target)
        loss = F.nll_loss(output, target.view(-1))
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)


def train():
    L.seed_everything(42)

    dataset = WikiText2()
    train_dataloader = DataLoader(dataset, num_workers=8, batch_size=1)

    model = LanguageModel(vocab_size=dataset.vocab_size)

    mp_strategy = ModelParallelStrategy(
        data_parallel_size=4,
        tensor_parallel_size=1,
    )

    trainer = L.Trainer(strategy=mp_strategy, max_steps=100, precision="bf16-true", accumulate_grad_batches=8)

    trainer.fit(model, train_dataloader)

    trainer.print(torch.cuda.memory_summary())


if __name__ == "__main__":
    torch.set_float32_matmul_precision("high")

    train()
```
</details>
<details><summary>Distributed FP8 transformer with Fabric</summary>
Full example [here](https://github.com/Lightning-AI/pytorch-lightning/tree/master/examples/fabric/fp8_distributed_transformer)
```python
import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning.fabric.strategies import ModelParallelStrategy
from lightning.pytorch.demos import Transformer, WikiText2
from torch.distributed._composable.fsdp.fully_shard import fully_shard
from torch.distributed.device_mesh import DeviceMesh
from torch.utils.data import DataLoader
from torchao.float8 import Float8LinearConfig, convert_to_float8_training
from tqdm import tqdm


def configure_model(model: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
    float8_config = Float8LinearConfig(
        # pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly  # noqa
        pad_inner_dim=True,
    )

    def module_filter_fn(mod: torch.nn.Module, fqn: str):
        # we skip the decoder because its vocabulary size
        # is typically not divisible by 16, as required by float8
        return fqn != "decoder"

    convert_to_float8_training(model, config=float8_config, module_filter_fn=module_filter_fn)

    for module in model.modules():
        if isinstance(module, (torch.nn.TransformerEncoderLayer, torch.nn.TransformerDecoderLayer)):
            fully_shard(module, mesh=device_mesh)

    fully_shard(model, mesh=device_mesh)

    return torch.compile(model)


def train():
    L.seed_everything(42)

    batch_size = 8
    micro_batch_size = 1
    max_steps = 100

    dataset = WikiText2()
    dataloader = DataLoader(dataset, num_workers=8, batch_size=micro_batch_size)

    with torch.device("meta"):
        model = Transformer(
            vocab_size=dataset.vocab_size,
            nlayers=16,
            nhid=4096,
            ninp=1024,
            nhead=32,
        )

    strategy = ModelParallelStrategy(data_parallel_size=4, tensor_parallel_size=1, parallelize_fn=configure_model)

    fabric = L.Fabric(precision="bf16-true", strategy=strategy)
    fabric.launch()

    model = fabric.setup(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    optimizer = fabric.setup_optimizers(optimizer)

    dataloader = fabric.setup_dataloaders(dataloader)

    iterable = tqdm(enumerate(dataloader), total=len(dataloader)) if fabric.is_global_zero else enumerate(dataloader)

    steps = 0

    for i, batch in iterable:
        input, target = batch

        is_accumulating = i % (batch_size // micro_batch_size) != 0

        with fabric.no_backward_sync(model, enabled=is_accumulating):
            output = model(input, target)
            loss = F.nll_loss(output, target.view(-1))
            fabric.backward(loss)

        if not is_accumulating:
            fabric.clip_gradients(model, optimizer, max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()
            steps += 1

        if fabric.is_global_zero:
            iterable.set_postfix_str(f"train_loss={loss.item():.2f}")

        if steps == max_steps:
            break

    fabric.print(torch.cuda.memory_summary())


if __name__ == "__main__":
    torch.set_float32_matmul_precision("high")

    train()
```
</details>
As these examples show, it's now easier than ever to take your PyTorch Lightning module and run it with **FSDP2 and/or tensor parallelism in FP8 precision**, using the `ModelParallelStrategy` we introduced in 2.4.
Also note the use of distributed tensor APIs, TorchAO APIs, and `torch.compile` directly in the `configure_model` hook (or in the parallelize function in Fabric's `ModelParallelStrategy`), as opposed to the `LightningModule` as a whole. The advantage of this approach is that you can just **copy-paste the parallelize functions** that come with native PyTorch models directly into `configure_model` and get the same effect, no head-scratching involved :nerd_face:.
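To make that pattern concrete, here is a minimal sketch (not one of the shipped examples above): `parallelize` is a hypothetical stand-in for whatever PyTorch-native parallelization function comes with your model, and the `nn.Linear` is just a placeholder for the real network.
```python
import lightning as L
import torch
import torch.nn as nn


def parallelize(model: nn.Module, device_mesh) -> nn.Module:
    # Hypothetical stand-in for a PyTorch-native parallelization function:
    # DTensor sharding, TorchAO conversions, fully_shard calls, etc. would go here.
    return model


class MyLitModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = None  # built lazily in configure_model

    def configure_model(self):
        if self.model is not None:
            return  # the hook may run more than once; only materialize the model once
        model = nn.Linear(32, 32)  # placeholder for your real model
        # self.device_mesh is populated by ModelParallelStrategy before this hook runs
        model = parallelize(model, self.device_mesh)
        self.model = torch.compile(model)
```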
Speaking of head-scratching, we also made a pass over the PyTorch Lightning internals and **hardened** the parts that track **progress counters** during training, validation, and testing, as well as learning rate scheduling, in relation to **resuming from checkpoints**. To the best of our knowledge, there are no longer edge cases where stopping and resuming from a checkpoint can change the sequence of loops or other internal state. **Fault tolerance for the win** :partying_face:!
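As a reminder, picking up an interrupted run is just a matter of handing `Trainer.fit` a checkpoint path; a minimal sketch, reusing the hypothetical `MyLitModule` from the snippet above with a placeholder dataloader and path:
```python
import lightning as L

# Placeholders: any LightningModule and dataloader will do; "checkpoints/last.ckpt"
# stands in for a checkpoint written by a previous (possibly interrupted) run.
model = MyLitModule()
trainer = L.Trainer(max_steps=1_000)

# Loop counters, LR schedulers, and other internal state are restored from the
# checkpoint, so training continues as if it had never been interrupted.
trainer.fit(model, train_dataloader, ckpt_path="checkpoints/last.ckpt")
```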
Alright! Feel free to take a look at the **full changelog** below.
And of course: the best way to use PyTorch Lightning and Fabric is through [Lightning Studio](https://lightning.ai/) :zap:. Access GPUs, train models, deploy and more with **zero setup**. Focus on data and models - not infrastructure.
<a name="changelog"></a>
# Changes
<a name="changelog-pytorch"></a>
## PyTorch Lightning
<details open><summary>Added</summary>
- Added `step` parameter to `TensorBoardLogger.log_hyperparams` to visualize changes during training ([20176](https://github.com/Lightning-AI/pytorch-lightning/pull/20176))
- Added `str` method to datamodule ([20301](https://github.com/Lightning-AI/pytorch-lightning/pull/20301))
- Added timeout to DeepSpeedStrategy ([20474](https://github.com/Lightning-AI/pytorch-lightning/pull/20474))
- Added doc for Truncated Back-Propagation Through Time ([20422](https://github.com/Lightning-AI/pytorch-lightning/pull/20422))
- Added FP8 + FSDP2 + torch.compile examples for PyTorch Lightning ([20440](https://github.com/Lightning-AI/pytorch-lightning/pull/20440))
- Added profiling to `Trainer.save_checkpoint` ([20405](https://github.com/Lightning-AI/pytorch-lightning/pull/20405))
- Added `after_instantiate_classes` hook to CLI ([20401](https://github.com/Lightning-AI/pytorch-lightning/pull/20401))
</details>
<details open><summary>Changed</summary>
- Updated checkpointing documentation to mark `resume_from_checkpoint` as deprecated ([20477](https://github.com/Lightning-AI/pytorch-lightning/pull/20477))
- Made plugin type checks more flexible ([20186](https://github.com/Lightning-AI/pytorch-lightning/pull/20186))
- Changed seeding NumPy using `np.random.SeedSequence()` in `pl_worker_init_function()` to robustly seed NumPy-dependent dataloader workers ([20369](https://github.com/Lightning-AI/pytorch-lightning/pull/20369))
- Allowed callbacks to be restored not just during training ([20403](https://github.com/Lightning-AI/pytorch-lightning/pull/20403))
- Changed LightningCLI tests to account for future fix in jsonargparse ([20372](https://github.com/Lightning-AI/pytorch-lightning/pull/20372))
- Bumped PyTorch to version `2.5` ([20351](https://github.com/Lightning-AI/pytorch-lightning/pull/20351))
- Decoupled checkpoint artifact path from model artifact path ([20325](https://github.com/Lightning-AI/pytorch-lightning/pull/20325))
- Updated BitsAndBytes version ([20313](https://github.com/Lightning-AI/pytorch-lightning/pull/20313))
- Changed merging of hparams when logging to ignore parameter names that start with an underscore `_` ([20221](https://github.com/Lightning-AI/pytorch-lightning/pull/20221))
- Re-enabled passing `BytesIO` as path in `.to_onnx()` ([20172](https://github.com/Lightning-AI/pytorch-lightning/pull/20172))
</details>
<details open><summary>Removed</summary>
- Removed `List[int]` as input type for Trainer when `accelerator="cpu"` ([20399](https://github.com/Lightning-AI/pytorch-lightning/pull/20399))
</details>
<details open><summary>Fixed</summary>
- Fixed `UnboundLocalError` when using the predict method with `return_predictions=False` ([20484](https://github.com/Lightning-AI/pytorch-lightning/pull/20484))
- Fixed use of `convert_module` in FSDP to avoid using more memory than necessary during initialization ([20323](https://github.com/Lightning-AI/pytorch-lightning/pull/20323))
- Fixed TypeError in `configure_optimizers` when running with `ReduceLROnPlateau` ([20471](https://github.com/Lightning-AI/pytorch-lightning/pull/20471))
- Fixed return type in `configure_optimizers` example ([20420](https://github.com/Lightning-AI/pytorch-lightning/pull/20420))
- Fixed incorrect URI prefix stripping in MLFlowLogger ([20365](https://github.com/Lightning-AI/pytorch-lightning/pull/20365))
- Fixed shuffling behavior when using a custom sampler in data module ([20327](https://github.com/Lightning-AI/pytorch-lightning/pull/20327))
- Ensured restarting from checkpoints leads to consistent internal counters compared to uninterrupted training ([20379](https://github.com/Lightning-AI/pytorch-lightning/pull/20379))
- Fixed LightningCLI failing when both module and data module save hyperparameters due to conflicting internal `_class_path` parameter ([20221](https://github.com/Lightning-AI/pytorch-lightning/pull/20221))
</details>
<a name="changelog-fabric"></a>
## Lightning Fabric
<details open><summary>Added</summary>
- Added `step` parameter to `TensorBoardLogger.log_hyperparams` to visualize changes during training ([20176](https://github.com/Lightning-AI/pytorch-lightning/pull/20176))
- Added timeout to DeepSpeedStrategy ([20474](https://github.com/Lightning-AI/pytorch-lightning/pull/20474))
- Added FP8 + FSDP2 + torch.compile examples for Fabric ([20440](https://github.com/Lightning-AI/pytorch-lightning/pull/20440))
- Added RTX 4080 super to chips dictionary ([20285](https://github.com/Lightning-AI/pytorch-lightning/pull/20285))
- Added device property to lazy load functionality ([20183](https://github.com/Lightning-AI/pytorch-lightning/pull/20183))
- Added `ddp_find_unused_parameters_true` alias in Fabric's DDPStrategy ([20125](https://github.com/Lightning-AI/pytorch-lightning/pull/20125))
</details>
<details open><summary>Changed</summary>
- Changed seeding NumPy using `np.random.SeedSequence()` in `pl_worker_init_function()` to robustly seed NumPy-dependent dataloader workers ([20369](https://github.com/Lightning-AI/pytorch-lightning/pull/20369))
- Bumped PyTorch to version `2.5` ([20351](https://github.com/Lightning-AI/pytorch-lightning/pull/20351))
- Updated BitsAndBytes version ([20313](https://github.com/Lightning-AI/pytorch-lightning/pull/20313))
</details>
<details open><summary>Removed</summary>
- Nothing to see here :smile:
</details>
<details open><summary>Fixed</summary>
- Fixed use of `convert_module` in FSDP to avoid using more memory than necessary during initialization ([20323](https://github.com/Lightning-AI/pytorch-lightning/pull/20323))
</details>
<br/>
**Full commit list**: [2.4.0 -> 2.5.0](https://github.com/Lightning-AI/pytorch-lightning/compare/2.4.0...2.5.0)
<a name="contributors"></a>
# Contributors
We thank **all folks** who submitted issues, features, fixes and doc changes. It's the only way we can **collectively** make Lightning :zap: better for everyone. Nice job!
In particular, we would like to thank the authors of the pull requests above, in no particular order:
ringohoffman MrWhatZitToYaa jedyang97 chualanagit lantiga AlessandroW kazuar t-vi 01AbhiSingh WangYue0000 amorehead EricCousineau-TRI mauvilsa Borda pete-mcelroy ali-alshaar7 GdoongMathew farhadrgh tshu-w LukasSalchow awindmann dadwadw233 qingquansong
Thank you :heart: and we hope you'll keep them coming!