VBE
TorchRec now natively supports VBE (variable batched embeddings) within the `EmbeddingBagCollection` module. This allows a variable batch size per feature, unlocking sparse input data deduplication, which can greatly speed up embedding lookup and all-to-all time. To enable it, initialize `KeyedJaggedTensor` with the `stride_per_key_per_rank` and `inverse_indices` fields, which specify the batch size per feature and the inverse indices to re-index the embedding output, respectively.
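As a minimal sketch (with illustrative feature names, values, and batch sizes), a VBE `KeyedJaggedTensor` might be constructed like this:

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Two features on a single rank: "f1" is deduplicated to a batch size of 2,
# while "f2" keeps the full batch size of 3. Values and lengths are illustrative.
kjt = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([1, 2, 3, 4, 5]),
    lengths=torch.tensor([1, 1, 1, 1, 1]),
    # Batch size per feature, per rank: f1 -> 2, f2 -> 3.
    stride_per_key_per_rank=[[2], [3]],
    # Inverse indices re-expand the deduplicated embedding output back to the
    # global batch size of 3 for each feature.
    inverse_indices=(
        ["f1", "f2"],
        torch.tensor([[0, 1, 0], [0, 1, 2]]),
    ),
)
```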
Embedding offloading
Embedding offloading is UVM caching (i.e. storing embedding tables in host memory with a cache in HBM) plus prefetching and optimal sizing of the cache. Embedding offloading allows running larger models with fewer GPUs while maintaining competitive performance. To use it, run with the prefetching pipeline ([PrefetchTrainPipelineSparseDist](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/train_pipeline.py?#L1056)) and pass in the [per table cache load factor](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/types.py#L457) and the [prefetch_pipeline](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/types.py#L460) flag through constraints in the planner, as sketched below.
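A sketch of the planner configuration might look like the following (the table name `large_table` and the specific values are assumptions for illustration):

```python
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.types import CacheParams

constraints = {
    "large_table": ParameterConstraints(
        # Use the UVM caching kernel so the table lives in host memory with an HBM cache.
        compute_kernels=[EmbeddingComputeKernel.FUSED_UVM_CACHING.value],
        cache_params=CacheParams(
            load_factor=0.2,          # per table cache load factor: fraction of rows cached in HBM
            prefetch_pipeline=True,   # pair with PrefetchTrainPipelineSparseDist for prefetching
        ),
    ),
}

planner = EmbeddingShardingPlanner(
    topology=Topology(world_size=2, compute_device="cuda"),
    constraints=constraints,
)
```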
Trec.shard/shard_modules
These APIs replace embedding submodules with their sharded variants. The `shard` API applies to an individual embedding module, while the `shard_modules` API replaces all embedding modules and won’t touch other non-embedding submodules.
Embedding sharding follows behavior similar to the prior TorchRec `DistributedModelParallel` behavior, except that the ShardedModules have been made composable, meaning the modules are backed by [TableBatchedEmbeddingSlices](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/composable/table_batched_embedding_slice.py#L15), which are views into the underlying TBE (including `.grad`). This means that fused parameters are now returned by `named_parameters()`, including in `DistributedModelParallel`.
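A minimal sketch of `shard_modules` (assuming a toy model and an already-initialized `torch.distributed` process group; the table configuration is illustrative):

```python
import torch
import torchrec
from torchrec.distributed.shard import shard_modules

class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.ebc = torchrec.EmbeddingBagCollection(
            device=torch.device("meta"),
            tables=[
                torchrec.EmbeddingBagConfig(
                    name="t1",
                    embedding_dim=16,
                    num_embeddings=1000,
                    feature_names=["f1"],
                    pooling=torchrec.PoolingType.SUM,
                ),
            ],
        )
        self.dense = torch.nn.Linear(16, 1)

model = Model()
# Replaces model.ebc with its sharded variant; model.dense is left untouched.
# Fused embedding parameters remain visible through sharded_model.named_parameters().
sharded_model = shard_modules(module=model, device=torch.device("cuda"))
```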