| 13B | 54.4% | 77.1% |
| 20B | 50.9% | 53.8% |
| 65B | 44.6% | 55.5% |
The smaller models underutilize the hardware, but the larger models are better able to saturate the TPU v3-256.
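As a reminder of how these metrics are estimated (following the convention popularized by the PaLM paper), MFU divides the model FLOPs actually achieved, roughly 6 FLOPs per parameter per token for the forward and backward pass, by the hardware's peak FLOPs; HFU additionally counts FLOPs re-spent on recomputation such as gradient checkpointing. Here is a minimal sketch with purely illustrative numbers (not measurements from our runs):

```python
def approx_mfu(n_params: float, tokens_per_second: float, peak_flops_per_second: float) -> float:
    """Rough MFU estimate: ~6 model FLOPs per parameter per token (forward + backward),
    ignoring attention FLOPs and recomputation. HFU would additionally count the FLOPs
    re-spent on gradient checkpointing."""
    model_flops_per_second = 6 * n_params * tokens_per_second
    return model_flops_per_second / peak_flops_per_second

# Hypothetical numbers, purely for illustration: a 20e9-parameter model processing
# 2e5 tokens/s on hardware with a peak of 5e16 FLOP/s.
print(f"{approx_mfu(20e9, 2e5, 5e16):.1%}")  # -> 48.0%
```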
To help contextualize these numbers, on the next-generation **TPU v4-128**s and with a slightly different 22B parameter model,
the performance-focused [MaxText](https://github.com/google/maxtext) library
[gets MFU](https://github.com/google/maxtext#runtime-performance-results) between 53.2% and 56.7%. Our nearest comparison point, the 20B model, is somewhat lower
but roughly in the same ballpark; we hope to close this gap in the future, in part by adopting some of their optimizations.
Though the hardware is different, we can also compare to the [very large table of results](https://github.com/mosaicml/examples/tree/release/v0.0.4/examples/llm/throughput#a100-80gb)
from [MosaicML](https://www.mosaicml.com/), whose numbers are generally in the 45-55% range for MFU and the 55-65% range for HFU. Our results are in the same ballpark, though
our best numbers are not as high as theirs. This is partly because they use [Flash Attention](https://arxiv.org/abs/2205.14135) and
can avoid gradient checkpointing at smaller scales (which is easier to do on the higher-memory A100s), both of which improve MFU.
For a broader comparison (to much larger models trained on much larger clusters), the table from the [PaLM paper](https://arxiv.org/pdf/2204.02311.pdf) gives
a rough sense of how our results stack up against other work:
![table showing MFU and HFU for various models; Table 3 in https://arxiv.org/pdf/2204.02311.pdf](figures/palm_mfu_table.png)
FSDP is likely to perform less well on clusters of the sizes in this table (i.e., a few thousand TPUs or GPUs), since it requires more communication than other approaches.
However, at our scale, we find that FSDP is better than either tensor parallelism or a combination of FSDP and tensor parallelism.
We leave pipeline parallelism and more thorough comparisons as future work.
Our results here demonstrate that you can get good scalability in a highly legible codebase, with the logic of the model decoupled
from the logic of parallelism.
We of course cannot claim full credit for these results: they build on the excellent work of the JAX, XLA, and TPU teams,
as well as all the algorithmic and hardware improvements that they themselves build on. Nevertheless, we hope that our work makes it easier
for others to experiment with models at larger scales than they otherwise would have.
## Reproducibility: Bitwise Determinism with Levanter and JAX
After legibility and scalability, our third pillar is reproducibility, which JAX helps with enormously. In particular, JAX's fine-grained
control over PRNG states makes it easy to ensure bitwise determinism.
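For readers less familiar with JAX, here is a minimal sketch of that PRNG model (this is not Levanter code): randomness is a pure function of the key you pass in, so replaying a run with the same keys replays exactly the same random numbers.

```python
import jax.random as jrandom

# JAX PRNG state is an explicit value rather than hidden global state,
# so the same key produces the same numbers on every run and every resume.
master_key = jrandom.PRNGKey(0)

# Derive independent keys for different purposes (e.g. parameter init vs. dropout).
init_key, dropout_key = jrandom.split(master_key)

print(jrandom.normal(init_key, (3,)))     # identical every time this script runs
print(jrandom.normal(dropout_key, (3,)))  # likewise
```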
Levanter takes advantage of this to offer bitwise reproducibility for training runs, even after preemption. That is,
the same run with the same code on the same hardware configuration (e.g. a v3-32 or a v3-256) will produce the exact same loss curve, even if it is
preempted and resumed multiple times. As an example, here is a screenshot of a training run being resumed multiple times, even on different TPU pod slices:
![plot showing bitwise reproducibility with four training runs superimposed with the exact same loss curve](figures/bitwise_repro_curve.png)
The fact that you can't make out the different lines is the point: the training runs are bitwise identical,
a huge advantage for debugging and reproducibility. For instance, loss spikes are not uncommon when training large models,
and it can be difficult to tell whether a spike is due to a bug, data, optimizer state, or just bad luck with the random
number generator. Without bitwise reproducibility, investigating these issues is challenging because you can't rewind and replay
your training run's state to the time of the spike, and if you make an adjustment, you can't tell whether it
fixed the problem or whether the problem simply went away by chance.
### Experimental Setup Logging and Checkpointing
Levanter also logs everything necessary to exactly reproduce a run: the git SHA, code, configuration,
and a pip-freeze of the environment. Checkpoints serialize the entire model state, including the optimizer state,
as well as the "main" PRNG state, which is used to generate the other PRNG states. This means that you can
exactly reproduce a run by simply checking out the git SHA, installing the dependencies, and running the code (on the same
hardware configuration).
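As a rough illustration of the idea (this is not Levanter's actual logging code, and the file name is made up), capturing enough to recreate the environment can be as simple as:

```python
import json
import subprocess

# Illustrative sketch only: record the information needed to recreate a run's environment.
run_metadata = {
    "git_sha": subprocess.check_output(["git", "rev-parse", "HEAD"], text=True).strip(),
    "pip_freeze": subprocess.check_output(["pip", "freeze"], text=True).splitlines(),
    "config_path": "config/gpt2_small.yaml",  # whichever config the run used
}

with open("run_metadata.json", "w") as f:
    json.dump(run_metadata, f, indent=2)
```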
## Other Features in Levanter
Beyond our three pillars of legibility, scalability, and reproducibility, Levanter also has a number of other
features that make it easier to train large models. We describe some of them here.
### Data Preparation and Visualization
During our collaborations with teams to build domain-specific models, we have found that data preparation can be a significant challenge.
Indeed, it is often the biggest challenge.
In particular, we have found that users want to iterate quickly on different data formats (and more
generally the entire [ETL pipeline](https://en.wikipedia.org/wiki/Extract,_transform,_load)).
Moreover, it can be difficult to visualize the effects of different preprocessing options on the data. To address this,
we have built two features into Levanter: cached on-demand data preprocessing and live visualization during training.
#### Cached On-Demand Data Preprocessing
Training a language model involves taking a large corpus of text and converting it into a sequence of integers called tokens.
When training large autoregressive models, it is typical to concatenate (or "pack") short sequences and break apart longer sequences
so that the resulting sequences are all the same length.
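As a concrete (and deliberately simplified) illustration of packing, the toy sketch below concatenates tokenized documents and slices them into fixed-length chunks; it ignores document-boundary tokens and everything Levanter actually does for caching and parallelism.

```python
from typing import Iterable, Iterator, List

def pack_sequences(docs: Iterable[List[int]], seq_len: int) -> Iterator[List[int]]:
    """Concatenate tokenized documents and emit fixed-length chunks (toy sketch)."""
    buffer: List[int] = []
    for tokens in docs:
        buffer.extend(tokens)
        while len(buffer) >= seq_len:
            yield buffer[:seq_len]
            buffer = buffer[seq_len:]

# Two short documents become two length-4 training sequences:
print(list(pack_sequences([[1, 2, 3], [4, 5, 6, 7, 8]], seq_len=4)))
# [[1, 2, 3, 4], [5, 6, 7, 8]]
```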
Data preprocessing is done in one of two ways: either it is performed offline as a separate preprocessing step, or it is
performed streaming, with the data processed on the fly as it is consumed by training. The former is typically
faster, but the latter is more flexible, since it lets you iterate on the data format without reprocessing
the entire dataset up front. However, streaming, especially when coupled with sequence packing, is difficult to combine with
resuming from preemption, since the data stream must be restarted from the beginning (or one must take care to track byte offsets).
In Levanter, we take a hybrid approach: we preprocess the data online, but we cache the results so
that resumed runs are much faster and subsequent runs on the same data are faster still. As soon
as the first part of the cache is complete, Levanter will start training, and will continue to preprocess the rest of
the data in the background. This allows us to start training as soon as possible, while still allowing us to iterate
on the data format. Moreover, we can resume from preemption without having to reprocess the entire data set.
Our cache format also allows for iterating on sequence length without retokenizing, which in our experience is a commonly requested feature.
Levanter's preprocessing works by spinning up a [Ray cluster](https://www.ray.io/) using the hosts being used for training,
exploiting the typically impressive CPUs of those machines to preprocess data.
This is especially useful for large datasets like [The Pile](https://pile.eleuther.ai/) or the [Red Pajama](https://github.com/togethercomputer/RedPajama-Data) data set.
Preprocessing can also be performed offline using a Ray cluster, or on a single machine. In all cases, the caches
produced by preprocessing are fully reproducible, so that we can assure bitwise reproducibility even when preprocessing
is performed on different machines.
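Concretely, the Ray-based pattern is to fan shards of raw text out to remote tasks and gather the tokenized results. The sketch below is illustrative only: `tokenize_shard` is a stand-in, and Levanter's real pipeline adds caching, ordering guarantees, and fault tolerance on top.

```python
import ray

ray.init()  # on a TPU pod, this would connect to a cluster spanning the training hosts

@ray.remote
def tokenize_shard(lines):
    # Stand-in "tokenizer": a real pipeline would call a Hugging Face tokenizer here.
    return [[ord(c) for c in line] for line in lines]

shards = [["hello world"], ["levanter", "haliax"]]
futures = [tokenize_shard.remote(shard) for shard in shards]
tokenized = ray.get(futures)  # results come back in shard order
```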
Levanter works out of the box with either [Hugging Face Datasets](https://huggingface.co/datasets) (including streaming) or URLs of (compressed) JSONL files. Caches can be stored in any fsspec-compatible file system, including GCS and local file systems. We use
[Hugging Face Tokenizers](https://huggingface.co/docs/tokenizers/) for tokenization.
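As a standalone illustration of those building blocks (not Levanter's internal code; the dataset and tokenizer below are chosen purely as examples), streaming a Hugging Face dataset and tokenizing it looks like this:

```python
from datasets import load_dataset
from tokenizers import Tokenizer

# Stream a dataset from the Hub without downloading it all up front.
dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split="train", streaming=True)

# Any tokenizer on the Hub works; gpt2 is just an example.
tokenizer = Tokenizer.from_pretrained("gpt2")

example = next(iter(dataset))
token_ids = tokenizer.encode(example["text"]).ids
```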
#### Live Visualization during Training
Levanter also provides a feature for visualizing the probability of each token in a sample of the validation set during training.
When training large models, it can be difficult to get a sense of how the model is learning. This is especially true
when training on novel datasets. As an example, we have seen issues with early versions of new datasets where the
model had suspiciously low loss.
The visualization produces a heatmap of the log probability of each token in a sample of the validation set
that is updated periodically during training. Here is an example of the token probability visualization in action on a
small, quick training run:
![video showing heat map of token probabilities for a sample of the validation set evolving as training progresses](figures/token_probabilities.mov)
The darker and more purple the color, the lower the probability of the token; the lighter and more yellow, the higher the probability.
This visualization is logged to WandB as training progresses and can be viewed interactively. We have found this to be a
nice alternative to just staring obsessively at the loss curve (not that we ever do that).
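The quantity being plotted is simple to compute. Here is a sketch in JAX (not the actual visualization code) of extracting per-token log probabilities from a model's logits, assuming the logits are already aligned with the tokens they predict:

```python
import jax
import jax.numpy as jnp

def per_token_log_probs(logits: jnp.ndarray, tokens: jnp.ndarray) -> jnp.ndarray:
    """logits: [seq_len, vocab_size], where logits[i] is the model's distribution over
    tokens[i]; tokens: [seq_len] token ids. Returns log P(tokens[i]) at each position."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    return jnp.take_along_axis(log_probs, tokens[:, None], axis=-1)[..., 0]

# Toy example: 3 positions, vocabulary of 5. A uniform model assigns log(1/5) ~ -1.61 everywhere.
logits = jnp.zeros((3, 5))
tokens = jnp.array([2, 0, 4])
print(per_token_log_probs(logits, tokens))
```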
In the past, we have used our visualization to identify a pattern of highly but not perfectly redundant data in a new data set
(what we call "madlib duplicates"), suggesting that the model is "wasting" time and context length on low-value data.
We've also used it to qualitatively assess how alternative architectures (like [Backpacks](http://backpackmodels.science/))
learn differently from Transformers.
### A Few Other Features
* **Training**: Levanter uses [Optax](https://github.com/deepmind/optax) for optimization,
though our new optimizer, [Sophia](https://arxiv.org/abs/2305.14342), is coming to Levanter soon!
* **Logging**: Logging is done with [WandB](https://wandb.ai/), complete with a fancy online visualization of the validation set during training.
* **Checkpointing**: Distributed checkpointing is supported via Google's [TensorStore](https://google.github.io/tensorstore/) library. Training can even be resumed on a different number of hosts, though this breaks reproducibility for now.
* **Export**: We also support exporting models to the Hugging Face Hub, with exports compatible with PyTorch and Transformers via [SafeTensors](https://github.com/huggingface/safetensors).
* **Stability**: The GPT-2 implementation uses the [Mistral stability trick](https://crfm.stanford.edu/2021/08/26/mistral.html) to improve stability during training.
## Getting Started with Levanter
To get started, first install the appropriate version of JAX for your system; see [JAX's installation instructions](https://github.com/google/jax/blob/main/README.md#installation), since installation varies from platform to platform.
If you're using a TPU, more complete setup documentation is available [here](Getting-Started-TPU-VM.md). GPU support is still in progress; documentation is available [here](Getting-Started-GPU.md).
Next, clone the repository and install it with pip:
```bash
git clone https://github.com/stanford-crfm/levanter.git
cd levanter
pip install -e .
wandb login  # optional, we use wandb for logging
```
### Training a GPT2-nano
As a kind of hello world, here's how you can train a GPT-2 "nano-sized" model on the small [WikiText-103](https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/) dataset:
```bash
python -m levanter.main.train_lm --config_path config/gpt2_nano.yaml

# alternatively, if you didn't use -e and are in a different directory
python -m levanter.main.train_lm --config_path gpt2_nano
```
### Training a GPT2-small on your own data
If your dataset is a [Hugging Face dataset](https://huggingface.co/docs/datasets/loading_datasets.html), you can use the `data.id` field to specify it:
```bash
python -m levanter.main.train_lm --config_path config/gpt2_small.yaml --data.id openwebtext

# optionally, you may specify a tokenizer and/or a cache directory, which may be local or on gcs
python -m levanter.main.train_lm --config_path config/gpt2_small.yaml --data.id openwebtext --data.tokenizer "EleutherAI/gpt-neox-20b" --data.cache_dir "gs://path/to/cache/dir"
```
If instead your data is a list of URLs, you can use the `data.train_urls` and `data.validation_urls` fields to specify them.
Data URLs can be local files, GCS paths, http(s) URLs, or anything else that [fsspec](https://filesystem-spec.readthedocs.io/en/latest/) supports.
Levanter (really, fsspec) will automatically decompress `.gz` and `.zstd` files, and probably other formats too.
```bash
python -m levanter.main.train_lm --config_path config/gpt2_small.yaml --data.train_urls ["https://path/to/train/data_*.jsonl.gz"] --data.validation_urls ["https://path/to/val/data_*.jsonl.gz"]
```
You can also change the dataset by changing the `dataset` field in the config file.
### Next Steps
Please see the [README for Levanter](https://github.com/stanford-crfm/levanter#installing-levanter) for
details, including training with the other supported architectures (currently, [Backpacks](http://backpackmodels.science/) and MosaicML's [MPT](https://www.mosaicml.com/blog/mpt-7b)),
as well as for training on TPUs and GPUs.
### Haliax Tutorials
We have two Colab tutorials for Haliax, which are a great way to get started:
* [Introduction to Haliax with Transformers](https://colab.research.google.com/drive/1TiTcQQ4V5mopbgCu1SVl-oqJtXn7rFnC?usp=sharing)
* [Scaling Transformers in Haliax](https://colab.research.google.com/drive/1QX4yH3zRFF3Xiibf1aahETcSQ5nbcUMz?usp=sharing), including FSDP in JAX.
## Released Models
Along with the release of the code, we are releasing a few models trained using Levanter. These models are available on
the [Hugging Face Hub](https://huggingface.co/stanford-crfm) and can be used with the Hugging Face Transformers library,
in PyTorch (and, for the GPT-2-based models, TensorFlow and JAX). We have more in development and will release them as
they become available.
- We are releasing a suite of music models using the [Anticipatory Music Transformer](https://johnthickstun.com/assets/pdf/anticipatory-music-transformer.pdf), a new architecture for controllable music synthesis,
trained on the [Lakh MIDI](https://colinraffel.com/projects/lmd/) corpus. The largest, a 750M parameter model, is available [here](https://huggingface.co/stanford-crfm/music-large-100k).
Please see [John Thickstun](https://johnthickstun.com/)'s [blog post](https://crfm.stanford.edu/2023/06/16/anticipatory-music-transformer.html) for more details, along with [a cool demo](https://colab.research.google.com/drive/1HCQDtGFwROpHRqcmZbV0byqbxDb74YGu?usp=sharing)!
- We also have a new 1.4B parameter checkpoint of the [Backpack Model](http://backpackmodels.science/) architecture developed by [John Hewitt](https://nlp.stanford.edu/~johnhew/) and coauthors.
This model is available [here](https://huggingface.co/stanford-crfm/levanter-backpack-1b).
- [Levanter GPT](https://huggingface.co/stanford-crfm/levanter-gpt) is a 1.5B parameter GPT-2 model trained on the
[OpenWebText](https://skylion007.github.io/OpenWebTextCorpus/) corpus.
- We have a 1.4B parameter GPT-2 model trained on [The Pile](https://pile.eleuther.ai/).
It is available [here](https://huggingface.co/stanford-crfm/levanter-gpt-pile) and will serve
as a common baseline for future experiments.
## Future and Conclusion
This is just the beginning for Levanter. In the future, look for:
* more models on interesting problem domains,
* scaled up versions of new architectures developed here at Stanford and elsewhere,
* new training techniques, including the newly released [Sophia](https://arxiv.org/abs/2305.14342) optimizer,
* and larger models!
Levanter is still a work in progress, but we are excited to share it with the community. We hope that Levanter will be
useful to others who are interested in training foundation models using JAX and TPUs. Please join us on our journey! You
can find us on [GitHub](https://github.com/stanford-crfm/levanter), [Twitter](https://twitter.com/StanfordCRFM), or on
the (unofficial) [JAX LLM Discord](https://discord.gg/CKazXcbbBm). (And by the way, [we're hiring](https://crfm.stanford.edu/apply.html)!)
## Acknowledgements
In addition to the generous support of the Google TPU Research Cloud, we would like to thank the following people for their help and support:
* John Thickstun, Sidi Lu, John Hewitt, and others for being early adopters and providing feedback. We really appreciate your patience and support.
* Yifan Mai, Tony Lee, Jason Bolton, Ivan Zhou, and the rest of the CRFM engineering team for support and discussions.
* Roy Frostig, Sholto Douglas, Skye Wanderman-Milne, and the rest of the JAX team for help with debugging and support.
* The TRC team for support getting spun up on TPUs and for making the Cloud TPUs available to us.
* Sidd Karamcheti for support and conversations.