This release mainly includes the following improvements:
1. Fixes for some fidelity issues.
2. Refactored schedule primitives, plus the new `.fork_rng()`, `.annotate()`, and `.replace_all()` primitives.
3. Other bug fixes.

If your existing schedule was written for v0.0.2 and matches any of the following cases, you have to update it to work with v0.0.3.
1. Tagging parameters so that the DeepSpeed pipeline runtime performs an additional all-reduce over the tensor-parallel (TP) group. For example, you may have the following code snippet that tags LayerNorm parameters:
```python
def tag_layernorm(sch):
    for m in sch.mod.modules():
        if isinstance(m, nn.LayerNorm):
            for p in m.parameters(recurse=False):
                p.replicated_param = True
```
This can be changed to the following in v0.0.3:
```python
def annotate_layernorm_and_bias(sch):
    for sub_sch in sch.child.values():
        if isinstance(sub_sch.mod, nn.LayerNorm):
            for name, _ in sub_sch.mod.named_parameters(recurse=False):
                sub_sch.annotate(name, "replicated_param", True)
        if issubclass(sub_sch.mod.__class__, LinearWithSyncFunc):
            sub_sch.annotate("bias", "replicated_param", True)
        annotate_layernorm_and_bias(sub_sch)
```
Reference: https://github.com/awslabs/slapo/blob/main/slapo/model_schedule/gpt2.py#L529
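
To make the recursive traversal in the v0.0.3 snippet easier to follow, here is a minimal sketch that runs without Slapo or PyTorch. `ToySchedule`, `LayerNorm`, and `Block` are hypothetical stand-ins that only mimic the three members used above (`.mod`, `.child`, and `.annotate()`); they are not Slapo's classes, and the real `.annotate()` may do more than record a value.

```python
# Toy stand-ins for illustration only; these are NOT Slapo or PyTorch classes.
class LayerNorm:
    """Placeholder for nn.LayerNorm with two parameters."""
    param_names = ("weight", "bias")


class Block:
    """Placeholder for a container module without its own parameters."""
    param_names = ()


class ToySchedule:
    """Mimics the members used above: .mod, .child, and .annotate()."""

    def __init__(self, mod, children=None):
        self.mod = mod
        self.child = children or {}
        self.annotations = {}  # parameter name -> {attribute: value}

    def annotate(self, param_name, key, value):
        self.annotations.setdefault(param_name, {})[key] = value


def annotate_layernorm(sch):
    """Depth-first walk that tags every LayerNorm parameter."""
    for sub_sch in sch.child.values():
        if isinstance(sub_sch.mod, LayerNorm):
            for name in sub_sch.mod.param_names:
                sub_sch.annotate(name, "replicated_param", True)
        # Recurse so that LayerNorms nested at any depth are reached.
        annotate_layernorm(sub_sch)


inner_ln = ToySchedule(LayerNorm())
layer = ToySchedule(Block(), {"ln": inner_ln})
root = ToySchedule(Block(), {"layer0": layer})
annotate_layernorm(root)
print(inner_ln.annotations)
```

The point of the sketch is that annotations are attached through the sub-schedule that owns the parameter, and the recursion is what reaches modules nested more than one level deep.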
2. RNG control can now be done easily with the newly introduced schedule primitive `.fork_rng()`. Accordingly, the old `slapo.op.AttentionOpWithRNG` has been removed. If you have the following code snippet:
```python
new_op = AttentionOpWithRNG(
    sub_sch["module"]["attn_op"].mod.attn_op_name,
    sub_sch["module"]["attn_op"].mod.apply_causal_mask,
    sub_sch["module"]["attn_op"].mod.scale,
)
sub_sch["module"]["attn_op"].replace(new_op)
```
It has to be changed to:
```python
sub_sch["module"]["attn_op"].fork_rng()
```
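
For readers unfamiliar with the term, the general idea of "forking" an RNG can be sketched with Python's standard `random` module: a region of code runs with its own seed, and the surrounding random stream is restored afterwards. This is only a conceptual illustration under that assumption. The helper `forked_rng` is hypothetical, and nothing here describes how Slapo's `.fork_rng()` is implemented.

```python
import contextlib
import random


@contextlib.contextmanager
def forked_rng(seed):
    """Run a block with its own seed, then restore the outer RNG state."""
    saved_state = random.getstate()
    random.seed(seed)
    try:
        yield
    finally:
        random.setstate(saved_state)


random.seed(0)
# A separate generator with the same seed tells us what the outer
# stream would produce first, without touching the global state.
expected_next = random.Random(0).random()

with forked_rng(seed=1234):
    inside = random.random()  # drawn from the forked stream

after = random.random()  # the outer stream continues undisturbed
print(after == expected_next)
```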
3. The primitive `.trace_for_pipeline()` has been renamed to `.trace_until()`. Since the arguments remain the same, you can simply replace all occurrences.
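
If a schedule calls the old primitive in many places, the rename can be applied mechanically. The helper below is a hypothetical sketch, not a tool shipped with Slapo: `migrate_trace_calls` and the argument name `stage_modules` are made up for illustration, and the pattern only rewrites method-call sites.

```python
import re

# Match the old primitive only when it is used as a method call,
# e.g. `sch.trace_for_pipeline(`.
OLD_CALL = re.compile(r"\.trace_for_pipeline\s*\(")


def migrate_trace_calls(source: str) -> str:
    """Return `source` with `.trace_for_pipeline(` rewritten to `.trace_until(`."""
    return OLD_CALL.sub(".trace_until(", source)


# `stage_modules` is a placeholder argument, not a real Slapo parameter name.
old_line = "sch.trace_for_pipeline(stage_modules)\n"
new_line = migrate_trace_calls(old_line)
print(new_line, end="")
```

Review the resulting diff before committing, since a text-level rewrite cannot tell a Slapo schedule apart from an unrelated object that happens to have a method with the same name.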
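4. If you use `slapo.op.FusedMLP` with sharding, you need to update your schedule to reflect the change in the `FusedMLP` implementation: `act` is no longer part of the sharded submodules, and the bias that was sharded on `act` is now sharded on `fc_in`. For example: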
```python
fc_names = ["fc_in", "act", "fc_out"]
sub_sch[fc_names[0]].shard("weight", axis=0)
sub_sch[fc_names[1]].shard("bias", axis=0)
sub_sch[fc_names[2]].shard("weight", axis=1)
sub_sch[fc_names[0]].sync(mode="bwd_post", sync_op_or_fn="all_reduce")
sub_sch[fc_names[2]].sync(mode="fwd_post", sync_op_or_fn="all_reduce")
```
changes to:
```python
fc_names = ["fc_in", "fc_out"]
sub_sch[fc_names[0]].shard("weight", axis=0)
sub_sch[fc_names[0]].shard("bias", axis=0)
sub_sch[fc_names[1]].shard("weight", axis=1)
sub_sch[fc_names[0]].sync(mode="bwd_post", sync_op_or_fn="all_reduce")
sub_sch[fc_names[1]].sync(mode="fwd_post", sync_op_or_fn="all_reduce")
```
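
The sharding pattern above (first weight and bias on axis 0, second weight on axis 1, then an all-reduce after the forward pass) can be checked on paper with a small pure-Python sketch. It assumes the PyTorch `nn.Linear` layout, where a weight has shape `(out_features, in_features)`, two hypothetical ranks, and ReLU as a stand-in activation. It does not use Slapo and does not model the `bwd_post` synchronization.

```python
def linear(weight, x, bias=None):
    """y = weight @ x + bias, with weight laid out as (out_features, in_features)."""
    out = [sum(w * v for w, v in zip(row, x)) for row in weight]
    if bias is not None:
        out = [o + b for o, b in zip(out, bias)]
    return out


def relu(v):
    return [max(0, e) for e in v]


x = [1, -2]
w_in = [[1, 0], [0, 1], [2, 1], [-1, -1]]  # fc_in.weight: (4, 2)
b_in = [0, 1, -1, 2]                        # fc_in.bias: (4,)
w_out = [[1, 2, 3, 4], [0, -1, 1, 0]]      # fc_out.weight: (2, 4)

# Reference: unsharded forward pass.
full = linear(w_out, relu(linear(w_in, x, b_in)))

# Two hypothetical ranks, each owning half of the hidden dimension.
partials = []
for rank in range(2):
    rows = slice(2 * rank, 2 * rank + 2)
    w_in_shard = w_in[rows]                     # like shard("weight", axis=0)
    b_in_shard = b_in[rows]                     # like shard("bias", axis=0)
    w_out_shard = [row[rows] for row in w_out]  # like shard("weight", axis=1)
    hidden = relu(linear(w_in_shard, x, b_in_shard))
    partials.append(linear(w_out_shard, hidden))

# Summing the partial outputs plays the role of the fwd_post all-reduce.
reduced = [sum(vals) for vals in zip(*partials)]
print(full, reduced)
```

Because the activation is applied element-wise, each rank can process its slice of the hidden dimension independently, and only the final partial outputs need to be summed across ranks.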

## What's Changed
* [Action] Fix release flow by comaniac in https://github.com/awslabs/slapo/pull/69
* [Refactor] Schedule primitives by comaniac in https://github.com/awslabs/slapo/pull/68
* [Primitive] .fork_rng() by comaniac in https://github.com/awslabs/slapo/pull/70
* [Primitive] .annotate() and .trace_until() by comaniac in https://github.com/awslabs/slapo/pull/71
* [CI] Update CI rules for docs by chhzh123 in https://github.com/awslabs/slapo/pull/72
* [Op] Fuse bias+dropout in FusedMLP by comaniac in https://github.com/awslabs/slapo/pull/73
* [Refactor] Modulize sharding methods by comaniac in https://github.com/awslabs/slapo/pull/74
* [CI] Quick fix by chhzh123 in https://github.com/awslabs/slapo/pull/75
* [Primitive][fork_rng] Do not replace module by comaniac in https://github.com/awslabs/slapo/pull/76
* [Bugfix] Include other custom LinearWithXX by comaniac in https://github.com/awslabs/slapo/pull/77
* [Primitive] Add fallback fusion by chhzh123 in https://github.com/awslabs/slapo/pull/78
* [examples] Refactor dataloader to support BERT by chhzh123 in https://github.com/awslabs/slapo/pull/79
* [Bugfix] Shard embedding hooks by comaniac in https://github.com/awslabs/slapo/pull/80
* [Version] Refactor version updating logic by comaniac in https://github.com/awslabs/slapo/pull/82
* [Op] Print by comaniac in https://github.com/awslabs/slapo/pull/81
* [Primitive] Add .replace_all() by chhzh123 in https://github.com/awslabs/slapo/pull/85
* [Version] Update version to v0.0.3 by chhzh123 in https://github.com/awslabs/slapo/pull/84
**Full Changelog**: https://github.com/awslabs/slapo/compare/v0.0.2...v0.0.3