"Cramming" even harder

Cropped first page of "Cramming: Training a Language Model on a Single GPU in One Day"

Epistemic status: homework assignment. I wrote this in an hour or so (probably frantically, minutes before the deadline) and it is not up to my usual quality standards (prof had no critical feedback but that’s a very low bar). It may or may not be useful to anyone, but there’s a reason this site is at unoriginal.blog. These are mainly published so I can cite them in other articles. Corrections are especially welcome here—send to contact@unoriginal.blog.

We are to analyze a paper from ICML 2023. Because I’m lazy, the paper I chose to analyze was one I’d already read, “Cramming: Training a Language Model on a Single GPU in One Day” by Jonas Geiping and Tom Goldstein.

Summary

First, this paper poses the challenge of “cramming”: training a language model (specifically, the paper considers variants of BERT, which are still state of the art in encoder-only NLP models) from scratch, of arbitrary size and architecture, with any training data, on a single GPU, for 24 hours or less, to achieve the highest score possible on the GLUE benchmark of natural language understanding (with minimal fine-tuning on sample GLUE tasks beforehand).

The paper goes on to attempt the “cramming” challenge. The authors fairly describe this as “tying our hands behind our back”: the original BERT took almost two orders of magnitude more FLOPs to train than can be extracted from a single GPU within the time limit. They even add the further constraint of a vanilla/naive PyTorch implementation “to limit our gains from the ‘software lottery’” (in which an idea wins not on the strength of its algorithm but because it happens to fit a pattern that compilers and other tooling can optimize). Nevertheless, the authors suggest “the goal of achieving BERT-like performance with modest training resources would have seemed unthinkable in 2018, and yet with modern advances and transformer training techniques this may now be possible.”

The paper validates this statement by, within the “cramming” rules, training a BERT variant with only a few changes:

  • The first models were trained on a custom dataset of books and English Wikipedia under the assumption that when bottlenecked by compute and not training data, only the highest-quality data should be used. This was true in theory, but empirically, later iterations using a subset of the Pile performed better on GLUE.
  • The dataset is sorted such that shorter sentences are trained on first (empirically the best strategy).
  • Sequences from the training data are packed into 128-token sequences separated by a special token (instead of the BERT status quo of just padding short sequences).
  • Biases are not added to linear layers, decoders, or after the QKV projection step.
  • Layer pre-normalization is done instead of post-normalization.
  • Dropout is removed for more updates (in one epoch, the model won’t overfit anyway).
  • Aggressive “triangular” learning rate schedule that starts high, keeps increasing until ~75% of the way through training, then decays.
  • Gradient clipping to stabilize this aggressive learning schedule.
  • Tiny batch sizes as memory is a constraint and we’re not exactly worried about parallelism.
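
To make two of the items above concrete, here is a minimal sketch of sequence packing and a triangular learning rate schedule. It is not the authors' code: the function names, the separator id of 0, the peak learning rate of 1e-3, and the ramp starting from zero are all assumptions made for illustration. Only the 128-token length and the ~75% peak position come from the list above.

```python
def pack_sequences(tokenized_docs, seq_len=128, sep_id=0):
    """Concatenate tokenized documents with a separator token between them,
    then cut the stream into fixed-length blocks (no padding).
    sep_id=0 is a hypothetical separator id, not taken from the paper."""
    stream = []
    for doc in tokenized_docs:
        stream.extend(doc)
        stream.append(sep_id)
    n_blocks = len(stream) // seq_len  # the trailing partial block is dropped
    return [stream[i * seq_len:(i + 1) * seq_len] for i in range(n_blocks)]


def triangular_lr(step, total_steps, peak_lr=1e-3, peak_frac=0.75):
    """Linear ramp up to peak_lr at peak_frac of the budget, then linear
    decay to zero. Starting the ramp at zero is an assumption of this sketch."""
    peak_step = peak_frac * total_steps
    if step <= peak_step:
        return peak_lr * step / peak_step
    return peak_lr * (total_steps - step) / (total_steps - peak_step)


# Tiny usage example with a toy sequence length of 4.
docs = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
blocks = pack_sequences(docs, seq_len=4, sep_id=0)
schedule = [triangular_lr(s, 100) for s in (0, 75, 100)]
```

With padding, the two-token document above would waste half of its block; packing fills every position with real tokens, which matters when the compute budget is fixed.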

In the end, a model using these modifications was able to achieve 80.1 on GLUE compared to the original BERT’s 80.5. Pretty impressive!

Strengths

This paper has many strengths:

  • The “cramming” approach democratizes high-performance language model training, allowing those with a single (beefy) machine to compete.
  • The “crammed” model approaches the original BERT’s GLUE score with much less compute, serving as a sort of ablation study of BERT itself and demonstrating possible efficiency gains.
  • Through extensive testing and ablation studies, the paper pinpoints the specific factors influencing model and training performance in constrained environments.
  • The transparency in sharing detailed negative results is a boon for the research community, saving time for others by indicating unfruitful paths and reinforcing the integrity and openness of the scientific process.
  • Some techniques used in the paper would not scale to BERTs trained on large clusters, but many tricks such as packing sequences together during training could be used at a larger scale as well. The paper provides a good summary of these optimization techniques.

Weaknesses

Nevertheless, I can see a few weaknesses in the paper. There are some avenues to improving efficiency that have already become very popular in the “local LLMs” community which could be considered here, and others based on some very recent research.

Quantization

One of these techniques is quantization. In the appendix of the “Cramming” paper, they specify they use “automated mixed precision for standard 16- and 32-bit floating point precision over full 32-bit float, scaled 16-bit and pure bfloat16.” In other words, they use the de facto default for NVIDIA GPUs (read: throughout ML). I am unsure from their phrasing whether they did not try the other techniques or tried them to no avail; the quote is from the “Negative Results” appendix but, unlike other results mentioned in that section, does not explicitly mention an attempt at using scaled 16-bit or pure bfloat16 during training.

Extra credit: potential quantization methods

Either way, even if (b)float16 worked great once tried, there is still room for improvement. Google recently ran an LLM training job across more than 50,000 TPUs, and one of the technologies that made this feasible was AQT (Accurate Quantized Training), which (allegedly) allows training in int8 and int4 with very little quantization loss. AQT is open source and can be dropped into any JAX computation. (It would not quite be that easy for the PyTorch BERT-like implementation used in the “Cramming” paper, but there are JAX implementations of BERT as well which could be modified as described in the paper and combined with AQT.) Rumor has it Google is investigating int4 training next, so there is even further room for improvement. Int4 is already supported by AQT, but another option would be to build on an ICLR’23 paper, “Accurate Neural Training with 4-bit Matrix Multiplications at Standard Formats”. Finally, I found another paper implementing 4-bit quantization for BERT specifically with minimal effect on GLUE scores that could be a good starting point.
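
For intuition about what integer matrix multiplication involves, here is a minimal NumPy sketch of symmetric per-tensor int8 quantization. To be clear, this is not AQT's API or algorithm, and it only covers a forward matmul (quantized training also has to handle gradients). The helper names `quantize_int8` and `int8_matmul` are made up for this example.

```python
import numpy as np


def quantize_int8(x):
    """Symmetric per-tensor quantization: int8 values plus one float scale."""
    scale = np.max(np.abs(x)) / 127.0
    if scale == 0:
        scale = 1.0  # all-zero tensor: any scale works
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale


def int8_matmul(a, b):
    """Quantize both operands, multiply with int32 accumulation, rescale."""
    qa, sa = quantize_int8(a)
    qb, sb = quantize_int8(b)
    acc = qa.astype(np.int32) @ qb.astype(np.int32)
    return acc.astype(np.float32) * (sa * sb)


rng = np.random.default_rng(0)
a = rng.standard_normal((64, 128)).astype(np.float32)
b = rng.standard_normal((128, 32)).astype(np.float32)

exact = a @ b
approx = int8_matmul(a, b)
# Relative error of the quantized product against the float32 product.
rel_err = np.linalg.norm(approx - exact) / np.linalg.norm(exact)
```

On well-behaved inputs like these the relative error stays small; real systems typically use finer-grained scales (per channel or per block) to cope with outliers.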

Not only do accelerators such as TPUs (Google’s focus, for obvious reasons) have high throughput for these smaller datatypes, recent GPUs do too. NVIDIA has supported int8 forever, int4 since Turing in 2018, and float8 since Hopper in 2022. Hell, Turing onwards even has int1! Given NVIDIA’s quasi-monopoly on GPU ML, it seems feasible to apply quantized training even in the consumer GPU “cramming” setting.

Given that memory bandwidth is usually the bottleneck during training, int4 might quadruple the amount of training data that can be “crammed”! (That’s a conservative estimate; if the paper’s automatic mixed-precision training has some parameters as 32-bit, int4 is even more of an improvement.)
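
The arithmetic behind that estimate, as a back-of-the-envelope sketch. It assumes training is purely bandwidth-bound and that every tensor is moved at the narrower width, so treat the result as an upper bound rather than a prediction. The `traffic_ratio` helper is made up for this example.

```python
# Bits moved per value for each datatype.
BITS = {"float32": 32, "bfloat16": 16, "int8": 8, "int4": 4}


def traffic_ratio(baseline, target):
    """How many times fewer bits per value `target` moves than `baseline`."""
    return BITS[baseline] / BITS[target]


vs_16bit = traffic_ratio("bfloat16", "int4")  # 4.0
vs_32bit = traffic_ratio("float32", "int4")   # 8.0
```

So the "quadruple" figure corresponds to a pure 16-bit baseline, and any tensors that mixed precision keeps in 32-bit push the potential gain toward 8x.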

Data quality

The paper makes a decent effort to compare the effect of different sources of training data. (It turns out when your model trains in 24 hours you can afford to train lots of variations.) They try both a Wikipedia- and book-based corpus we’d expect to be high-quality if more narrow in scope, and a subset of the Pile, a very common source that’s an aggregation of 22 sources from law to website scrapes. However, I feel there is some more obvious low-hanging fruit in terms of data sources that should be explored. For one, very high-quality synthetic data can be generated by a large language model. The recent papers “Textbooks Are All You Need” and “Textbooks Are All You Need II” and the associated models phi-1 and phi-1.5 demonstrate how a synthetic and/or very highly filtered dataset leads to tiny, but very capable models. Another example of this idea is TinyStories. As one reader described Textbooks Are All You Need:

What they did was basically this:

  1. started with The Stack (a 3 TB collection of code) and text from StackOverflow
  2. used a LLM to select 6B “high-quality” tokens from (1)
  3. used GPT-3.5 to generate 1B tokens of text similar to textbooks
  4. trained a small (1.3B parameter) model (“phi-1”) on (2) and (3)
  5. used GPT-3.5 to generate text similar to textbook exercises
  6. fine-tuned phi-1 on (5)
  7. tested phi-1 on HumanEval to evaluate its programming ability

The results were pretty good, better than models 10x the size trained on 100x the data. So, it seems that scaling up isn’t the only thing that matters, and data quality can be more important than data quantity or parameter count.

Going by the listed OpenAI API prices, running GPT-3.5 on The Stack to evaluate quality would’ve been maybe ~$6M. What the authors did instead was:

  1. Use GPT-4 to evaluate a small fraction of it.
  2. Use a much smaller code-specific model to generate embeddings.
  3. Use a classifier to predict which embeddings are from what GPT-4 evaluates as good content.
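
The three-step filtering scheme above can be sketched in a few lines. Everything here is a stand-in: the "embeddings" are random vectors, `expensive_judge` is a placeholder for the large model's verdicts, and the classifier is a plain logistic regression fit by gradient descent. None of this is taken from the phi-1 paper; it only illustrates the shape of the idea.

```python
import numpy as np

rng = np.random.default_rng(0)

# Step 2 stand-in: pretend a small model already embedded every document.
n_docs, dim = 2000, 16
embeddings = rng.standard_normal((n_docs, dim))

# Hidden ground truth that the expensive judge would reveal if asked.
true_direction = rng.standard_normal(dim)
hidden_quality = embeddings @ true_direction > 0


def expensive_judge(indices):
    """Hypothetical placeholder for asking a large model 'is this good?'."""
    return hidden_quality[indices].astype(float)


# Step 1: pay for labels on only a small fraction (10%) of the corpus.
labeled = rng.choice(n_docs, size=200, replace=False)
X, y = embeddings[labeled], expensive_judge(labeled)

# Step 3: fit a logistic-regression classifier on the labeled embeddings.
w, b = np.zeros(dim), 0.0
for _ in range(500):
    p = 1.0 / (1.0 + np.exp(-(X @ w + b)))
    w -= 0.5 * (X.T @ (p - y)) / len(y)
    b -= 0.5 * np.mean(p - y)

# Apply the cheap classifier to the whole corpus and keep the predicted-good part.
predicted_good = (embeddings @ w + b) > 0
accuracy = np.mean(predicted_good == hidden_quality)
```

The point is the cost structure: the expensive model is called on 200 documents, while the other 1,800 are scored with a dot product.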

Some have worried that the reason these papers achieve such good results is contamination of the training data with benchmark answers. This may be true, but if not, there may be large gains from either having a GPT model ramble until you have 24 hours’ worth of grammatically correct, long-range-dependency-full, high-quality tokens or using the classifier scheme mentioned above.

Positional encoding

The “Cramming” authors “implement scaled sinusoidal positional embeddings… finding incremental benefits over learned or unscaled sinusoidal embedding.” Scaled sinusoidal embeddings are soooo 2019. I’m sort of surprised they didn’t at least attempt relative positional encoding a la T5, which was well known at the time the paper was written.

There is an even bigger improvement possible over that, though. Since late 2022, the “standard” has been Rotary Position Embedding (RoPE). (RoPE had technically been around since 2021, but it only made it to the English NLP world more recently.) As explained in this wonderful article by EleutherAI,

…the intuition behind RoPE is that we can represent the token embeddings as complex numbers and their positions as pure rotations that we apply to them. If we shift both the query and key by the same amount, changing absolute position but not relative position, this will lead both representations to be additionally rotated in the same manner … thus the angle between them will remain unchanged and thus the dot product will also remain unchanged. By exploiting the nature of rotations, the dot product used in self-attention will have the property we are looking for, preserving relative positional information while discarding absolute position.

As RoPE stores the positional information without adding sin(mθ) and cos(mθ) terms to every token (it just rotates the tokens, effectively applying a multiplicative factor), using it would result in slightly lower memory bandwidth and fewer computations when “cramming”. Empirically, RoPE can make models converge around 10% faster than learned positional embeddings or relative positional encoding. (I have not seen it directly compared to unscaled or scaled sinusoidal embedding, but I believe RoPE would still be better.)
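
The rotation property from the quote is easy to check numerically. Below is a minimal NumPy sketch of RoPE applied to a single vector, pairing up adjacent features and using the commonly used base of 10000. The function name and layout are my own choices for illustration, not code from the EleutherAI article.

```python
import numpy as np


def rope(x, position, base=10000.0):
    """Rotate each adjacent pair of features of x by position * theta_i."""
    d = x.shape[-1]
    assert d % 2 == 0, "feature dimension must be even"
    theta = base ** (-np.arange(0, d, 2) / d)  # one frequency per pair
    angles = position * theta
    cos, sin = np.cos(angles), np.sin(angles)
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin
    out[..., 1::2] = x_even * sin + x_odd * cos
    return out


rng = np.random.default_rng(0)
q = rng.standard_normal(8)
k = rng.standard_normal(8)

# Same relative offset (3), very different absolute positions.
score_near = rope(q, 5) @ rope(k, 2)
score_far = rope(q, 105) @ rope(k, 102)
same_score = np.allclose(score_near, score_far)
```

Shifting both positions by 100 leaves the attention score unchanged, which is exactly the "relative, not absolute" behavior described in the quote.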

Appendix: AI Use

I was encouraged to use LLMs to help write this. I did end up using Anthropic’s Claude and OpenAI’s GPT-4 via API for a few things:

  • To choose the paper to analyze, I fed a list of all ICML’23 papers and their abstracts into GPT-4 and had it filter down to papers involving optimization of LLMs. I chose the first one I’d heard about before.
  • I had GPT-4 summarize the paper before I read it, then actually read it. A few of its assertions were wrong.
  • I sometimes asked GPT-4 to define ML jargon I was unfamiliar with.
  • “Strengths” were generated by GPT-4 (fed this paper minus the Strengths section and a copy of the “Cramming” paper), but obviously so—they had that bureaucratic ChatGPT tone—so I rewrote them with GPT-4’s as a guide.
  • When I was done with the assignment, I fed it and the assignment instructions to GPT-4 to see if I was missing any of the requirements.