Generalization release
The main focus of this release was adding flexibility and generalization to support a broad range of research use cases.
The next release will be on Dec 7th (we release every 30 days).
Internal Facebook support
lorenzoFabbri, tullie, myleott, ashwinb, shootingsoul, vreis
These features were added to support FAIR, FAIAR, and broader ML work across other FB teams.
In general, we can expose any part of Lightning that isn't exposed yet wherever someone might want to override the default implementation.
1. Added truncated backpropagation through time (TBPTT) support (thanks tullie).
```python
from pytorch_lightning import Trainer

# backprop every 2 time steps of a longer sequence
Trainer(truncated_bptt_steps=2)
```
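To make the idea concrete, here is a minimal pure-Python sketch of how truncated BPTT splits a long sequence into windows along the time dimension. With truncated_bptt_steps=2, each window holds at most 2 time steps; gradients would flow only within a window, while the hidden state is carried from one window to the next. The helper split_time_windows is hypothetical and is not Lightning's actual batch-splitting code.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def split_time_windows(sequence: Sequence[T], steps: int) -> List[List[T]]:
    """Split a sequence into consecutive windows of at most `steps` time steps.

    Hypothetical helper illustrating the TBPTT split; not Lightning's code.
    """
    if steps <= 0:
        raise ValueError("steps must be a positive integer")
    # each window would get its own forward/backward pass, with the hidden
    # state passed (detached) from one window to the next
    return [list(sequence[i:i + steps]) for i in range(0, len(sequence), steps)]


# a sequence of 5 time steps split with truncated_bptt_steps=2
print(split_time_windows(["t0", "t1", "t2", "t3", "t4"], 2))
# → [['t0', 't1'], ['t2', 't3'], ['t4']]
```

The last window is shorter when the sequence length is not a multiple of the step count.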
2. Added support for iterable datasets.
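```python
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader

# return an IterableDataset from your dataloader hook
def train_dataloader(self):
    ds = MyIterableDataset(...)  # hypothetical subclass of torch.utils.data.IterableDataset
    return DataLoader(ds)

# an iterable dataset may have no length, so set validation
# to run after a fixed number of training batches
# (checks val every 100 train batches)
Trainer(val_check_interval=100)
```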
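A plain-Python sketch of what an integer val_check_interval means for a stream of unknown length: validation is triggered by counting training batches, never by calling len(). The function train_with_val_interval is hypothetical and is not Lightning's training loop.

```python
from typing import Callable, Iterable, List


def train_with_val_interval(
    batches: Iterable[int],
    val_check_interval: int,
    run_validation: Callable[[int], None],
) -> int:
    """Run validation every `val_check_interval` training batches.

    Hypothetical loop for illustration only. It never calls len(), so it also
    works for iterables whose length is unknown (like an IterableDataset).
    Returns the number of training batches seen.
    """
    seen = 0
    for _batch in batches:
        # ... a training step on _batch would go here ...
        seen += 1
        if seen % val_check_interval == 0:
            run_validation(seen)
    return seen


def stream():
    # a generator has no __len__, mimicking an iterable dataset
    for i in range(250):
        yield i


checks: List[int] = []
total = train_with_val_interval(stream(), 100, checks.append)
print(total, checks)  # → 250 [100, 200]
```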
3. Added the ability to customize backward and other parts of training:
```python
from apex import amp  # requires NVIDIA apex


def backward(self, use_amp, loss, optimizer):
    """
    Override backward with your own implementation if you need to
    :param use_amp: Whether amp was requested or not
    :param loss: Loss is already scaled by accumulated grads
    :param optimizer: Current optimizer being used
    :return: None
    """
    if use_amp:
        # let apex scale the loss to avoid fp16 gradient underflow
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()
```
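The hook pattern can be sketched in plain Python: the trainer owns the loop and calls whatever backward the module defines. HookedModule, CustomBackwardModule, MiniTrainer and backward_step are hypothetical names; the hooks return tuples instead of calling loss.backward() so the sketch runs without PyTorch, and dividing by accumulate_grad_batches is our reading of the docstring note that the loss is "already scaled by accumulated grads". This is not Lightning's trainer code.

```python
class HookedModule:
    """Hypothetical stand-in for a module exposing a backward hook."""

    def backward(self, use_amp, loss, optimizer):
        # default hook; a real module would call loss.backward() here
        return ("default", loss)


class CustomBackwardModule(HookedModule):
    def backward(self, use_amp, loss, optimizer):
        # user override: e.g. inspect or modify the loss before backprop
        return ("custom", loss)


class MiniTrainer:
    """Hypothetical training loop; NOT Lightning's actual internals."""

    def __init__(self, model, accumulate_grad_batches=1, use_amp=False):
        self.model = model
        self.accumulate_grad_batches = accumulate_grad_batches
        self.use_amp = use_amp

    def backward_step(self, raw_loss, optimizer=None):
        # divide by the accumulation window so the summed grads become an average
        loss = raw_loss / self.accumulate_grad_batches
        # dispatch to whichever backward the module defines
        return self.model.backward(self.use_amp, loss, optimizer)


print(MiniTrainer(HookedModule(), accumulate_grad_batches=4).backward_step(2.0))
# → ('default', 0.5)
print(MiniTrainer(CustomBackwardModule()).backward_step(2.0))
# → ('custom', 2.0)
```

Because the trainer only ever calls self.model.backward, overriding that one method is enough to change how gradients are computed without touching the loop.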
4. Added support for custom DDP implementations (override these hooks):
```python
import os

import torch.distributed as dist


def configure_ddp(self, model, device_ids):
    """
    Override to init DDP in a different way or use your own wrapper.
    Must return model.
    :param model: the model to wrap
    :param device_ids: the device ids to pass to the DDP wrapper
    :return: DDP wrapped model
    """
    model = LightningDistributedDataParallel(
        model,
        device_ids=device_ids,
        find_unused_parameters=True
    )
    return model


def init_ddp_connection(self, proc_rank, world_size):
    """
    Connect all procs in the world using the env:// init
    Use the first node as the root address
    """
    # use the slurm job id for the port number
    # makes port collisions across jobs from the same grid search unlikely
    try:
        # use the last 4 digits of the job id as the id
        default_port = os.environ['SLURM_JOB_ID']
        default_port = default_port[-4:]

        # all ports should be in the 10k+ range
        default_port = int(default_port) + 15000

    except (KeyError, ValueError):
        default_port = 12910

    # if the user gave a port number, use that one instead
    try:
        default_port = os.environ['MASTER_PORT']
    except KeyError:
        os.environ['MASTER_PORT'] = str(default_port)

    # figure out the root node addr
    try:
        root_node = os.environ['SLURM_NODELIST'].split(' ')[0]
    except KeyError:
        root_node = '127.0.0.2'

    root_node = self.trainer.resolve_root_node_address(root_node)
    os.environ['MASTER_ADDR'] = root_node
    dist.init_process_group('nccl', rank=proc_rank, world_size=world_size)
```
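The port-selection logic in init_ddp_connection can be pulled out into a pure function for testing. resolve_master_port is a hypothetical helper that mirrors that logic (it returns the port instead of writing MASTER_PORT back into the environment); it is not part of Lightning.

```python
import os
from typing import Mapping, Optional


def resolve_master_port(env: Optional[Mapping[str, str]] = None) -> int:
    """Mirror the port logic of init_ddp_connection (illustrative only).

    Priority: an explicit MASTER_PORT, else a port derived from the last
    4 digits of SLURM_JOB_ID plus 15000, else the fallback 12910.
    """
    env = os.environ if env is None else env
    if "MASTER_PORT" in env:
        # a user-provided port always wins
        return int(env["MASTER_PORT"])
    try:
        # jobs from the same grid search get different job ids, hence ports
        return int(env["SLURM_JOB_ID"][-4:]) + 15000
    except (KeyError, ValueError):
        return 12910


print(resolve_master_port({"SLURM_JOB_ID": "4815162342"}))  # → 17342
```

The + 15000 offset keeps every derived port in the 15000 to 24999 range; two jobs whose ids share their last four digits would still collide.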
5. Added support for your own apex initialization or implementation.
```python
def configure_apex(self, amp, model, optimizers, amp_level):
    """
    Override to init AMP your own way
    Must return a model and list of optimizers
    :param amp: the apex amp module
    :param model: the model to initialize
    :param optimizers: list of optimizers
    :param amp_level: apex opt level (e.g. 'O1')
    :return: Apex wrapped model and optimizers
    """
    model, optimizers = amp.initialize(
        model, optimizers, opt_level=amp_level,
    )
    return model, optimizers
```
6. Added a DDP2 implementation (inspired by ParlAI and stephenroller).
DDP2 acts like DP within a node and like DDP across nodes.
As a result, this release introduces an optional training_end method, where you can take the outputs of training_step (which runs on each GPU with a portion of the batch) and do something with the outputs from all GPUs on the node (e.g., negative sampling).
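```python
import torch.nn.functional as F
from pytorch_lightning import Trainer

trainer = Trainer(distributed_backend='ddp2')

# inside your LightningModule:
def training_step(self, batch, batch_nb):
    # on each GPU, batch holds 1/nb_gpus of the full batch
    x, y = batch
    out = self.forward(x)
    return {'out': out, 'y': y}

def training_end(self, outputs):
    # outputs holds the gathered outputs from ALL GPUs on this node
    all_outs = outputs['out']
    all_ys = outputs['y']
    # any computation that needs the whole node-level batch goes here
    loss = F.cross_entropy(all_outs, all_ys)
    return {'loss': loss}
```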
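The DDP2 data flow on one node can be simulated without GPUs: the batch is split across devices, training_step runs on each slice, the outputs are concatenated, and training_end sees everything at once. split_batch and ddp2_node_step are hypothetical helpers, lists stand in for tensors, and the averaging "loss" is only a placeholder; this illustrates the flow described above, not Lightning's implementation.

```python
from typing import Callable, Dict, List


def split_batch(batch: List[float], nb_gpus: int) -> List[List[float]]:
    """Split a batch into contiguous chunks, one per device (like DP scatter)."""
    if nb_gpus <= 0:
        raise ValueError("nb_gpus must be positive")
    chunk = -(-len(batch) // nb_gpus)  # ceiling division
    return [batch[i:i + chunk] for i in range(0, len(batch), chunk)]


def ddp2_node_step(
    batch: List[float],
    nb_gpus: int,
    training_step: Callable[[List[float]], Dict[str, List[float]]],
    training_end: Callable[[Dict[str, List[float]]], Dict[str, float]],
) -> Dict[str, float]:
    """Simulate one DDP2 step on a single node (hypothetical, no real GPUs)."""
    if not batch:
        raise ValueError("batch must not be empty")
    # training_step runs independently on each device's slice
    per_gpu = [training_step(chunk) for chunk in split_batch(batch, nb_gpus)]
    # gather: concatenate each output key across devices
    gathered = {key: [v for out in per_gpu for v in out[key]] for key in per_gpu[0]}
    return training_end(gathered)


def my_training_step(x: List[float]) -> Dict[str, List[float]]:
    # each "GPU" only sees its slice of the batch
    return {"out": [2 * v for v in x]}


def my_training_end(outputs: Dict[str, List[float]]) -> Dict[str, float]:
    # every output on the node is available here, e.g. to draw negatives
    # from other GPUs' samples; averaging is just a stand-in loss
    all_outs = outputs["out"]
    return {"loss": sum(all_outs) / len(all_outs)}


print(ddp2_node_step([1.0, 2.0, 3.0, 4.0], 2, my_training_step, my_training_end))
# → {'loss': 5.0}
```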
Logging
- More logger options, including Comet.ml.
- Versioned logs for all loggers.
- Switched from print statements to logging.
Progress bar
- The progress bar now has one bar spanning the full train + val epoch and a second bar that is visible only during validation.
Loading
- Checkpoints now store hparams.
- There is no longer any need to pass tags.csv to restore state, because the hparams live in the checkpoint.
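A minimal sketch of why tags.csv is no longer needed: if the checkpoint file carries the hparams next to the weights, restoring needs only that one file. save_checkpoint, load_checkpoint and the "state_dict"/"hparams" keys are hypothetical, and json stands in for torch.save so the example runs without PyTorch; Lightning's actual checkpoint layout may differ.

```python
import json
import os
import tempfile
from typing import Any, Dict, Tuple


def save_checkpoint(path: str, state: Dict[str, Any], hparams: Dict[str, Any]) -> None:
    # hparams travel inside the checkpoint itself, so no separate tags.csv is needed
    with open(path, "w") as f:
        json.dump({"state_dict": state, "hparams": hparams}, f)


def load_checkpoint(path: str) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    with open(path) as f:
        ckpt = json.load(f)
    # everything needed to rebuild the model comes from one file
    return ckpt["state_dict"], ckpt["hparams"]


path = os.path.join(tempfile.mkdtemp(), "demo.ckpt")
save_checkpoint(path, {"w": [0.1, 0.2]}, {"learning_rate": 0.02, "batch_size": 32})
state, hparams = load_checkpoint(path)
print(hparams)  # → {'learning_rate': 0.02, 'batch_size': 32}
```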
SLURM resubmit with apex + DDP
- Fixes an issue where restoring weights under DDP blew out GPU memory (weights are now loaded on CPU first, then moved to GPU).
- Saves apex state automatically and restores it from a checkpoint.
Refactoring
- Internal code was made modular through mixins, for readability and to minimize merge conflicts.
Docs
- Tons of doc improvements.
Thanks!
Thank you to the amazing contributor community, especially neggert and Borda for reviewing PRs and taking care of a good number of GitHub issues. The community is thriving and has really embraced making Lightning better.
Great job everyone!