Language Mixers IV: Text Compression and Memory Models

This page focuses on information compression, specifically how we can achieve better compression via new language model architectures and how this can be used for dataset augmentation and large-context models.

Introduction

In work detailed in a blog post and paper, a major finding was that one can modify a causal language model architecture termed the 'masked mixer' (essentially a transformer with self-attention replaced by masked 1D convolutions) to effectively autoencode inputs with high compression: small models can be trained in reasonable time to regenerate (with some error) a sequence of 512 or more tokens from the embedding of only one token, with excellent generalization properties. It was also found that using masked mixers for the encoder and decoder allows for far greater input autoencoding accuracy than a transformer encoder/decoder pair at a given model size and training compute budget, opening the door to compressing inputs with new and potentially more efficient methods than have been possible with the transformer.

It may be wondered why text compression ability is important: even if large language models achieve 2x or 4x the compression of commonly used algorithms like gzip, they are thousands of times more computationally expensive to actually run and thus not preferred today (although this may change in the future). The answer is that effective compression from many to one token allows one to design architectures that have interesting properties such as extended context windows or meta-context guidance, which we will investigate under the broad name of ‘memory’.

We will first tweak the autoencoder mixer architecture to try to obtain optimal text compression within fixed compute budgets, before using this information to test whether one can obtain better text compression ratios than the current best methods (which are large causal language modeling transformers). We will conclude by examining the suitability of autoencoder embeddings for extending generative language model context with sub-quadratic complexity.

Transformer-based autoencoders train and scale poorly

In previous work, evidence was presented for the idea that masked mixers are far better autoencoders than transformers, the primary large-scale evidence being the following: a $d_m=1024$ masked mixer autoencoder with $n_l=8$ layers in the encoder and decoder reaches a far lower CEL than a compute- and memory-matched transformer with $d_m=512$ (transformers contain far more activations per $d_m$ due to their $K, Q, V$ projections). One may object that this could be mostly due to differences in the amount of information available in a 512-dimensional vector compared to a 1024-dimensional one, but that was not found to be the case on the much smaller TinyStories dataset, where equivalent-dimensional transformers far underperformed their masked mixer counterparts despite requiring around double the compute and device memory.

We are now in a position to further explore the training efficiencies of transformer-based versus mixer-based autoencoders. Perhaps the first question one might have when observing the very large performance gap between transformer and mixer is whether there is some implementation error in the transformer architecture. The original implementation used pre-built Llama model blocks from the Huggingface transformers library and looped through these blocks for the encoder and decoder, while using the same vector broadcasting and tensor rearrangements, word-token embedding, and language modeling head transformations as the masked mixer counterpart, feeding positional information to each layer. It may be wondered whether a much simpler implementation would perform better, where the encoder and decoder are each pre-built LlamaModel implementations and we also include an attention mask on the decoder. From the figure below, we can see that the transformer autoencoder with pre-built encoder and decoder is actually somewhat worse than the original modular architecture when trained on the FineMath 4+ dataset.

prebuilt vs module transformer autoencoder

The other main outstanding question is whether the increased masked mixer autoencoder training efficiency might be due to the larger embedding dimension of that model versus the transformer, at least for larger and more diverse datasets like FineMath 4+ or FineWeb-edu. This is primarily a question of scaling with respect to the model $d_m$ (and thus the embedding dimension in this case), so one can obtain a more general idea of the differences between masked mixers and transformers in autoencoder training efficiency by comparing the losses achieved as the model $d_m$ is scaled for a given training dataset.

From the following figure, it can be appreciated that transformers are indeed far less efficient to train as autoencoders than masked mixers across multiple $d_m$ values, providing evidence that the differences in autoencoding ability between these models are not due to embedding dimension but are instead model-intrinsic (specifically attention-intrinsic). For context, a total of $n = 200,000 * 128 * 512 \approx 13.1 * 10^9$ tokens are trained upon at 200k steps for each model in the following figure.

transformer versus mixer autoencoders

It is apparent that transformers scale badly in terms of samples per model, as seen by the negative asymptotic slope of the transformer training curves being far smaller in magnitude than that of the masked mixer. Transformer-based autoencoders also scale poorly with the embedding size (or equivalently the model width): doubling the $d_m$ of a transformer autoencoder twice gives a smaller loss reduction with each doubling, whereas the opposite is true for masked mixer autoencoders.

None of this is particularly surprising given the results and theory outlined in the masked mixer introductory paper. One major finding there is that transformers are relatively inefficient to train for tasks requiring retention of information present in the input, either in a single or multiple output embeddings.

Why transformers struggle in this autoencoding paradigm

From the last section we observed that transformers make far worse autoencoders (with the architecture presented, at least) than masked mixers. The next question to ask is why: why would transformers be so much worse than masked mixers, given that previous work has shown more modest differences in information retention between the two? More specifically, why would attention lead to poor autoencoding in this architecture?

Some reflection will provide a reasonable hypothesis: recall that attention may be expressed as follows:

\[A(Q, K, V) = \mathrm{softmax} \left( \frac{QK^T}{\sqrt{d_k}} \right)V\]

where $Q, K, V$ are matrices of packed vectors projected from each token embedding. To make these operations clearer from the perspective of token indices, we can ignore the $d_k$ scaling factor and express this equation as follows: the attention between the query projection for token $i$ and the key and value projections at index $j$, for $i, j \in \{1, \dots, n\}$, is

\[A(q_i, k_j, v_j) = \frac{\exp \left( q_i \cdot k_j \right)}{ \sum_n \exp \left( q_i \cdot k_n \right)} v_j\]

Now consider what we have done by repeating the input embedding for all tokens in the decoder: as the projection weight matrices $W_k, W_q, W_v$ are identical for all tokens, necessarily we have the following:

\[k_i = k_j, \quad q_i = q_j, \quad v_i = v_j \quad \forall \; i, j\]

and thus $q_i \cdot k_j = q_i \cdot k_l \; \; \forall i, j, l$. Therefore we have $A(q_i, k_j, v_j) = A(q_i, k_l, v_l) \; \forall i, j, l$ such that output activations from the attention layer are identical across all token embeddings.
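To make this concrete, a minimal Pytorch sketch (without positional encoding) verifies that repeating one embedding across all positions yields identical attention outputs at every position:

import torch
import torch.nn.functional as F

d_m, n_ctx = 64, 8
e = torch.randn(1, 1, d_m)                     # one encoder embedding
x = e.repeat(1, n_ctx, 1)                      # repeated across all positions
W_q, W_k, W_v = (torch.randn(d_m, d_m) for _ in range(3))
q, k, v = x @ W_q, x @ W_k, x @ W_v            # identical q_i, k_i, v_i for all i
out = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(out[0, 0], out[0, -1]))   # True: every output row is the same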

Given this observation, it is not hard to see that this identity will persist for more than one layer, as each feedforward module following attention is identical. But this is not the whole picture: as implemented, Llama-style transformers apply positional encoding (RoPE in this case) before self-attention, such that the embeddings at each position are actually unique, assuming the positional encoding is itself unique for the token indices we are training on (on this page it always will be). Thus it is not strictly correct to point to identical activations from self-attention as the cause of the poor transformer training for repeat-embedding autoencoders, but one might still wonder whether transformers are less well-suited than masked mixers to autoencoding with repeated embeddings.

One indirect way we can test this is as follows: if a transformer were significantly worse at decoding repeated embeddings than a masked mixer, we would expect an autoencoder with a mixer encoder and transformer decoder to perform worse than an autoencoder with a mixer encoder and decoder, or one with a transformer encoder and mixer decoder. As shown in the following figure, this is indeed what is found (although an optimized masked mixer autoencoder is more efficient than either compound architecture).

transformer versus mixer autoencoders

Given some evidence for our idea, how would one go about injecting an encoder's embedding into a transformer decoder while avoiding the identical attention problem? One simple but effective way is to 'unroll' the embedding by projecting unique subsets of the encoder's embedding into each of the decoder's input positions. A relatively simple but flexible method of getting unique subsets of a vector is a sliding window, where we project from a certain number of contiguous elements of the vector in our 'window' and shift the elements projected from at each of the decoder's indices, keeping the projection weights identical across all input windows. This requires an embedding vector satisfying $d_m \geq n_{ctx}$ for every decoder index vector to be unique, but we can simply add a projection to enlarge the $d_m$ of the embedding if required.

For cases where we want to project from $n_d$ elements and $d_m < n_d + n_{ctx} - 1$, or in other words where our window slides off the embedding vector before all decoder inputs are made, we can simply wrap our window around to the first index of the embedding, concatenate, and project accordingly. For the case where $n_{ctx}=4, d_m=6, n_d=4$, the following diagram illustrates the projection and wrapping process:

mixer autoencoder efficiencies

This can be implemented as follows: given a linear projection layer that assumes the input is half the size of the output, self.projection = nn.Linear(dim//2, dim), we can replace the embedding repeat with our unrolled projections as follows:

encoder_embedding = x[:, -1, :].unsqueeze(1) # dims = [batch, token, hidden]
embedding_stack = []
# sliding window unroll over the hidden dim: one unique slice per decoder index
for i in range(self.tokenized_length):
    sliding_window = encoder_embedding[..., i:i+dim//2]
    if i + dim//2 > dim:
        # window slides off the end of the embedding: wrap around to the first index
        residual = i + dim//2 - dim
        sliding_window = torch.cat((sliding_window, encoder_embedding[..., :residual]), dim=2)
    embedding_stack.append(sliding_window)
encoder_embedding = torch.cat(embedding_stack, dim=1)
encoder_embedding = self.projection(encoder_embedding) # project each window to dim

Note here that the implementation most faithful to our figure above would apply the projection at each index inside the for loop before concatenation, but this is much less efficient: applying the projection to the concatenated output allows us to make use of device (GPU) parallelization that is otherwise tricky to add to the loop via Pytorch primitives.

For a $d_m=512$, $n_l=8$ transformer autoencoder (eight layers for the encoder, eight for the decoder) applied to $n_{ctx}=512$ FineWeb-edu, we have the following:

transformer versus mixer autoencoders

We can also take the effort to find the optimal number of heads for our autoencoder; the figure below shows training efficiencies for various head counts. It can be appreciated that eight heads are optimal or near-optimal for this autoencoder using the unrolled embedding technique.

transformer versus mixer autoencoders

Does the use of unrolled embeddings lead to masked mixer autoencoders training more efficiently? Generally it does not, at least at this $d_m$: for the most performant convolutional kernel sizes, the use of repeated embeddings leads to more efficient training.

transformer versus mixer autoencoders

It is interesting, therefore, that for a model with a larger embedding ($d_m=1024$) we find that this is not the case:

transformer versus mixer autoencoders

It should be noted that the masked mixers in the above figure use around half the compute and device memory of the transformers in the figure before, and so cannot be compared directly for training efficiency purposes. It is interesting to note that expanding the transformer width and using a compression layer (two linear transformations that compress $d_m \to d_m/2 \to d_m$) leads to substantially worse training efficiency than using a smaller $d_m$ with no compression between encoder and decoder, despite the smaller-$d_m$ model requiring half the compute and memory to train.

transformer versus mixer autoencoders

Causal masking increases autoencoder training efficiency

To begin with, it is helpful to recall the architecture of the masked mixer-based autoencoder as presented in the work linked in the last section:

mixer autoencoder architecture

This architecture received little to no optimization in that work, as it was mostly presented as evidence for a hypothesis involving bijective functions and autoencoding. But now that we want to improve and elaborate upon this architecture, where should we start?

One obvious question is whether the convolutions really need to be masked: after all, we generate the output in one step and do not adhere to the causal language modeling objective of next token prediction, so why mask the model's convolutions at all? Bearing this in mind, it is somewhat unexpected that removing the encoder convolution's mask results in substantially less efficient training, and even more surprising that removing the decoder's mask as well (keeping the encoder unmasked) results in highly unstable training, with gradient norm spikes leading to rapid rises in loss. The following figure details the cross-entropy losses achieved by masked mixer-based autoencoders ($d_m=1024, n_{ctx}=512, n_l=8, b=128$) on the FineWeb-edu (10BT) dataset, where the evaluation is a hold-out sample from the corpus (measured to contain <5% duplicate documents). Here 'masked' indicates a causal (left-to-right) mask implemented as a lower triangular mask on the 1D convolution weight matrix, and the right panel is simply an extended training run of the conditions in the left panel (omitting the unstable unmasked model).

mixer autoencoder efficiencies

Why would causal masking be so important to a model that does not perform causal modeling? There is usually some benefit to matching a model's structure to any prior we know about the dataset being modeled, and from that perspective one could perhaps guess that enforcing causality is beneficial because the data being modeled (text) is in some way fundamentally causal, as it is understood in one orientation. It is less clear why removing all causality masks leads to highly unstable training, as one might expect a simple decrease in efficiency in this paradigm rather than exploding gradients.

Multi-headed mixer autoencoders

If the mixer is the better encoder and decoder in this autoencoding paradigm (where we regenerate all tokens of a sequence in one pass), how might one improve the mixer's architecture for further training efficiency? One straightforward guess is to increase the number of inter-token trainable parameters, which may be achieved in a number of ways (multiple convolutional kernels, expanded convolutions with nonlinearities, multi-headed masked mixers); when a number of such architectures were tested for causal language modeling, the most performant among them was the multi-headed mixer. The linear algebraic transformations that occur in the multi-headed mixer for the case of $n_h=2$ heads and a total projection dimension greater than the hidden dimension may be illustrated as follows:

multi-headed autoencoders

We modify the original multi-headed mixer architecture for the case where there are $n_h=2$ heads and the projection dimension is just $d_m / n_h$, fix $n_l=8$ for both encoder and decoder, and replace the mixer autoencoder's masked convolutions with these multi-headed convolutions.

multi-headed autoencoders

As long as we stick to the $d_m / n_h$ total projection dimension limitation, the number of inter-token parameters $p$ for a model with $n_l$ layers and $n_{ctx}$ context is

\[p = n_l (n_h * n_{ctx}^2 + 2d_m^2)\]

whereas we have $p = n_l * n_{ctx}^2$ inter-token parameters for the 'flat' masked mixer. We therefore have a linear increase in inter-token parameters with the number of heads, at a constant additional cost per head. To see how to arrive at this count, observe that each head contains one 1D convolution (with a weight matrix of size $n_{ctx}^2$), while the input and output projections are unchanged as the number of heads increases: the output projection is of size $d_m^2$ and the input projections total $d_m * (d_m / n_h) * n_h = d_m^2$.
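To make these counts concrete, a small helper (a sketch following the formulas above):

def inter_token_params(n_l, n_ctx, d_m, n_h=None):
    # flat masked mixer: one n_ctx x n_ctx masked convolution per layer
    if n_h is None:
        return n_l * n_ctx**2
    # multi-headed: one convolution per head plus input/output projections
    return n_l * (n_h * n_ctx**2 + 2 * d_m**2)

print(inter_token_params(8, 512, 1024))         # flat: 2,097,152
print(inter_token_params(8, 512, 1024, n_h=4))  # four heads: 25,165,824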

For the case where we keep the concatenated projection dimension equal to $d_m$ as above, we have a notable increase in autoencoder training efficiency (see the figure below) relative to the 'flat' masked mixer, which has no projections and only a single 1D convolution per layer. From the results below, it is clear that increasing the number of heads leads to greater training efficiency, and the gap between autoencoders with four or more heads and the flat masked mixer architecture is substantial.

decoder options

From the figure above, one clearly reaches diminishing returns when expanding beyond four heads: there is a significant computational burden but very little efficiency benefit when increasing from 4 to 8 heads, and an efficiency detriment when increasing from 8 to 16. It is curious that four heads are also optimal for causal transformer models of similar size with respect to loss achieved per unit of training compute on this dataset.

Interestingly, however, the multi-headed mixer autoencoder experiences instabilities late in training, manifesting as very rapidly exploding gradients for models with one or two heads. As this was observed early in training for autoencoders without causal masks, one straightforward explanation would be a problem with the multi-headed mixer's masks; however, we causally mask the convolution in each head, and a quick test shows that the encoder and decoder modules are indeed causal. A relatively straightforward solution would be to add layer or RMS normalization to the concatenated projections, or to add residuals across head elements. We can avoid these architectural changes by carefully choosing a datatype, however, as explained below.

A straightforward way to address gradient overflow is to use a datatype with a wider dynamic range than fp16 (e5m10), for example either fp32 (e8m23) or the nearly equivalently ranged bf16 (e8m7). bf16 is supported on newer hardware (A100s and H100s and RTX 30 series and above; although it can be emulated on V100s, this emulation cannot make use of that GPU's tensor cores and is thus low-throughput), and bf16/fp32 mixed precision training is well supported by Pytorch. When we train the same models using bf16/fp32 precision, we no longer observe catastrophic instabilities in multi-headed autoencoder trainings (see the above figure for an example), which supports the hypothesis that numerical overflow is the cause of training instabilities in multi-headed mixer autoencoders.

There is a substantial decrease in loss per step for causal autoencoders trained on the FineWeb as well, although we also find exploding gradients leading to a catastrophic increase in training loss for multi-headed mixer autoencoders. As the exploding gradient and loss problem appears pervasive for multi-headed masked mixer autoencoders, one can attempt to understand why this is the case. A good first candidate is the datatype we are using for training, fp16/fp32 mixed precision, chosen for older hardware (V100) compatibility. Although this training is usually decently stable across a wide range of models, one can argue that an autoencoder of this type is inherently unstable with respect to gradient norms, as all gradients flow through one vector that is susceptible to numerical overflow if its partial derivatives are sufficiently large.

decoder options

One can increase the number of trainable inter-token parameters in a simpler way than using multiple heads: using a convolutional kernel of size greater than unity (here $k=2$ denotes a kernel of size two, i.e. two weight maps) scales the number of inter-token parameters as

\[p=n_l (k * n_{ctx}^2)\]

as there are simply $k$ weight maps per convolutional layer. From the figure above it is apparent that a flat masked mixer with $k=8$ achieves identical loss to the 4-headed mixer, but does not suffer the numerical instabilities associated with the latter. For clarity, a figure depicting how a $k=2, n_{ctx}=3, d_m=2$ layer operates is provided below. Note that kernels of size $k>1$ do not convolve across only one embedding index at a time as $k=1$ mixers do, but contain inter-embedding parameters such that a token's embedding element at position $n$ (causally) affects other tokens' embedding elements at positions $\lvert n - p \rvert < k$. Effectively we have both inter-token and intra-token weights in one layer when $k>1$.

non-unitary kernel

It is also noteworthy that there is such a large difference in training efficiency between multi-headed and flat masked mixer autoencoders. One can estimate the increase in training efficiency from the number of samples required to reach a certain loss value: in this case the flat masked mixer requires more than 6x the number of samples to approach the loss achieved by the 4- or 8-headed autoencoder at 200k steps. For comparison, the difference in causal language modeling loss per step between flat and multi-headed mixers is very small: the flat mixer requires only around 1.05x the number of samples to match the 4-headed mixer when trained on TinyStories. If we implement a causal multi-headed masked mixer while keeping the projection dimension equal to $d_m/n_h$, we find very little difference in per-step loss between this model and the flat masked mixer when trained on the FineWeb-edu (10BT) dataset.

causal mixer heads

Text Compression Background

Although it may not seem very important to the field of artificial intelligence, text compression has proven time and time again to be an apt predictor of a language model's abilities across the spectrum, from language generation to question-and-answer chat capability to chain-of-thought reasoning.

Language compression is an old goal, and attempts to understand how text can be compressed, and by how much, go back to the beginnings of information theory. Shannon's seminal work focuses on the problem of compression and transmission of textual information, as does earlier work in the field from Hartley and Nyquist. There were practical reasons for this: it has long been appreciated that one needs to send much less information to regenerate a string of characters than to regenerate the sound of someone speaking those characters, so figuring out exactly how much information is required was and remains an important problem for data transmission.

We focus on compression rather than transmission, although one can view deep learning models themselves as being somewhat similar to noisy communication channels. One of the most general methods of measuring text compression is bits per byte (bpb), which is the number of bits required for encoding a byte of input text. Often the input is assumed to be encoded in UTF-8, which uses one byte per character and makes this measure effectively the number of bits per input character if single-byte encoding predominates.

Although less well known than other model capabilities, the most effective text compression methods today are frontier large language models trained to predict each next token in a very large corpus of text. The gap between classical compression algorithms and these models is vast on the scale of information theory: gzip, perhaps the most widely used compression algorithm, achieves up to around 2 bits per byte; highly tuned dictionary and bit-matcher decoders achieve around 1.2 bits per byte; Deepseek v3 achieves 0.54 bits per byte.

The way large language models are usually used to compress text is simply by predicting next tokens, where the compressed size is the number of bits required to correct the errors in the model's prediction for each token. Causal language model-style compression is nearly as old as text compression itself: Shannon, for example, used next-character prediction by people to estimate the source entropy of the English language, arriving at a lower bound of around 0.6 bits per character, very similar to what we see from large language models today.

There is a clear upper bound to causal language model text compression, however: in natural languages such as English, there is a certain amount of inherent ambiguity such that no particular word necessarily follows from a given previous sequence of words. This is what Shannon referred to as ‘source entropy’, and it may be thought of as irreducible error and provides a hard lower bound on the bits-per-byte compression of a causal-style model.

With this in mind, we have a clear alternative to next-token-prediction-based compression: we can use our new masked mixer-based autoencoder to compress the input directly, and thereby avoid the problem of source entropy altogether. The reason is that our autoencoder compresses the entire input into one vector and uses this vector to reconstruct that input, so the source entropy becomes irrelevant for an arbitrarily powerful autoencoder capable of capturing all necessary information in the embedding vector. In real-world applications the source entropy clearly matters to the ease of actually training a model (as we shall see for mathematical versus general text later), but in the idealized scenario of arbitrary power our only lower bound is the size of the autoencoder's embedding.

Text Compression via Autoencoders

If we have a negative log likelihood loss $\Bbb L$, we can compute the number of bits per input byte for a given segment of text if we know the length of text in bytes $L_b$ and the number of tokens required to encode that text for a given tokenizer $L_t$.

\[\mathtt{BPB} = (L_t / L_b) * \Bbb L / \ln (2)\]

On this page we report loss as the torch implementation of CrossEntropyLoss, which is equivalent to taking the negative log likelihood of a softmax-transformed logit output. This means we can simply substitute our CEL values for the negative log likelihood $\Bbb L$ (the softmax simply transforms the model's outputs into a probability distribution). We also make the simplifying assumption that our text is encoded in single-byte UTF-8.
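For concreteness, this conversion can be sketched as a small helper, where cel is the mean per-token cross-entropy in nats and bytes_per_token is $L_b / L_t$:

import math

def bits_per_byte(cel, bytes_per_token):
    # convert nats to bits, then amortize over the bytes each token represents
    return (cel / math.log(2)) / bytes_per_token

print(bits_per_byte(1.4, 2.82))  # ~0.72, as in the causal model example below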

We can now compare the compression achieved using masked mixer autoencoders to that obtained using next-token-prediction models. Taking the FineMath 4+ dataset, a tokenizer averaging 2.82 characters per token (which equals 2.82 bytes assuming single-byte UTF-8), and a model with a 512-dimensional embedding and $n_{ctx}=512$ whose embedding is stored using 4 bits per parameter, we can calculate the amortized BPB as follows:

\[\mathtt{BPB} = \frac{n_p * b_p}{n_{ctx} * (L_b / L_t)} \\ \mathtt{BPB} = \frac{512 * 4}{512 * 2.82} \approx 1.42\]

assuming that we have zero loss after training (we actually have around $\Bbb L=0.7$). This compares unfavorably with the compression achieved by a causal language model transformer on this dataset using approximately the same compute,

\[\mathtt{BPB} = (L_t / L_b) * \Bbb L / \ln (2) \\ \mathtt{BPB} = (1/2.82) * 1.4 / \ln(2) \approx 0.72\]

A straightforward way to attempt to increase the compression of our autoencoder is to use a smaller embedding, say 128 elements. In the following figure, we train mixer- and transformer-based autoencoders using embeddings of size 128 on FineWeb with $n_{ctx}=512$. The 8k-vocabulary FineWeb tokenizer averages 3.92 characters per token on that dataset, resulting in the following lower bound for compression using this model:

\[\mathtt{BPB} = \frac{n_p * b_p}{n_{ctx} * (L_b / L_t)} \\ \mathtt{BPB} = \frac{128 * 4}{512 * 3.92} \approx 0.255\]

Unfortunately, after 200k steps (13 billion tokens, or around 12 hours on a 4x A100 node) these autoencoders are nowhere close to achieving this compression, as their loss remains far from zero. Somewhat encouraging is the observation that these models do not exhibit exponential decay scaling of loss with the number of tokens trained (below right; observe the lack of linearity on the semilog axes), suggesting that with much more compute they may be capable of effectively reaching zero loss.

causal mixer heads

Embedding-augmented causal language models

Recalling our original motivation of introducing input embeddings to remove the source entropy of text for a language model compressor, it may be wondered whether one can combine the advantages of the autoencoder with next-token-prediction-based compression. The reason this might be beneficial is as follows: masked mixer autoencoders evidently require a compressed $d_m$ that is too large (in bits per $n_{ctx}$ tokens) to improve upon the compression of next token prediction models, given the relatively small amount of compute we have been applying to train these models.

The primary reason for this is that each token (which usually represents between 3 and 5 bytes of text) is greatly expanded by the embedding transformation, whereby each token becomes represented by vectors of at least 512 elements, or at least 1024 bytes in fp16. This is such a large expansion that even our subsequent many-to-one token compression abilities do not give greater total compression than causal language models.

Some reflection may convince us that there is a good reason for this: it may be conjectured that our autoencoder is less efficiently trainable than a next-token-prediction model (for our compute) because it must generate the entire context window in one forward pass, whereas the causal language model generates one next token at a time using information from all previous tokens.

With this in mind, it may be wondered whether we can combine the advantages of causal modeling and autoencoding to make an encoder-decoder architecture with superior compression ability to either the autoencoder or the causal language model alone, given similar compute to what has been used previously.

There are a number of options for how one could introduce the encoder's embedding into the decoder. Three straightforward candidates are concatenation in the token dimension, concatenation in the embedding dimension, and a linear combination in the embedding dimension. For embedding concatenation, we decrease the size of the decoder's token embeddings commensurately to maintain a constant decoder size when comparing to the other methods. These are illustrated in the following figure for convenience.

memory decoder architectures
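A minimal sketch of these three methods is as follows, where e is the encoder embedding of shape [batch, 1, d_m] and t the decoder token embeddings of shape [batch, n_ctx, d_m] (the halving projection used for embedding concatenation is an illustrative assumption):

import torch

def token_concat(e, t):
    # the embedding occupies one extra index in the token dimension
    return torch.cat((e, t), dim=1)

def embedding_concat(e, t_half, proj):
    # t_half: [batch, n_ctx, d_m//2] token embeddings of reduced width;
    # proj compresses e to d_m//2 elements, repeated at every position
    e = proj(e).expand(-1, t_half.shape[1], -1)
    return torch.cat((t_half, e), dim=2)

def embedding_combination(e, t):
    # broadcast linear combination in the embedding dimension
    return t + e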

When we test these three methods on FineMath 4+ data using a masked mixer decoder as our causal language model, we find that in general they are similarly efficient to train, with the embedding concatenation method obtaining slightly lower loss than embedding combination or token concatenation.

memory decoder performances

One might expect a transformer decoder to exhibit greater training efficiency given token concatenation relative to embedding concatenation or combination, and indeed this is what we find (this time applied to the FineWeb-edu dataset):

memory decoder performances

It appears from the above results that transformers are relatively invariant to how exactly the encoder’s embedding is introduced among these three methods, so for text compression purposes we will use them interchangeably.

Can adding an encoder’s embedding lead to increased compression? To answer this we first need to know how large our embeddings are (particularly how many bytes they require) and then we can convert this to a bits-per-byte value. Suppose one trains an embedding-augmented causal model where the embedding is of dimension $n_p$, each parameter being stored using $b_p$ bits, for a context window of size $n_{ctx}$ and $L_b / L_t$ bytes of input text per token. Then we can calculate the bits per byte required to store this embedding (amortized over the input) as previously via

\[\mathtt{BPB} = \frac{n_p * b_p}{n_{ctx} * (L_b / L_t)}\]

Once this value is known, we can find the loss offset $\Bbb L_o$ that corresponds to this extra required information,

\[\Bbb L_o = \frac{\mathtt{BPB} * \ln (2)}{(L_t / L_b)}\]

and add this offset to the causal language modeling negative log likelihood loss for next token prediction to find the total loss.

\[\Bbb L = \Bbb L(O(a, \theta), y) + \Bbb L_o\]
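Putting the last three equations together, the computation may be sketched as follows (the parameter values in the example are illustrative):

import math

def normalized_loss(cel, n_p, b_p, n_ctx, bytes_per_token):
    # bits required to store the embedding, amortized over the input bytes
    bpb = (n_p * b_p) / (n_ctx * bytes_per_token)
    # loss offset corresponding to this extra stored information
    offset = bpb * math.log(2) * bytes_per_token
    return cel + offset

# 64-element embedding at 4 bits/parameter, n_ctx=1024, 3.92 bytes/token
print(normalized_loss(2.0, n_p=64, b_p=4, n_ctx=1024, bytes_per_token=3.92))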

We call this the 'normalized loss' for brevity. For a relatively small embedding ($d_m=64$), assuming 4 bits per parameter and a context window of $n_{ctx}=1024$, we have the following normalized losses (all masked mixers in the following figure are flat, and all models use embedding concatenation introduction except the transformer->transformer, which uses token concatenation introduction):

memory decoder performances

There is some expected behavior here: the embedding-augmented models begin training with higher loss (the result of the offset required to store the 64 floats of the embedding vector) but then approach or surpass the purely causal model's loss (or equivalently its bpb compression of the input). We can see this most clearly by observing the difference between the embedding-augmented and plain causal model at a given step or causal model loss value, as shown in the following figure. As we expect the embedding to allow a model to circumvent the difficulty of learning a corpus with a certain amount of language-intrinsic source entropy, we would expect the embedding to be less useful early in training, when this source entropy is a small part of the model's total cross-entropy loss with respect to the text corpus.

memory decoder performances

It is somewhat less expected that the masked mixer decoder appears able to learn to use the information present in the embedding more efficiently than the transformer decoder, a pattern particularly apparent later in training. As the transformer encoder -> transformer decoder model exhibits the same tendency, this could result from an architectural alignment between encoder and decoder.

We may wonder whether this near-linear loss decrease for embedding-augmented models persists as models are trained on more tokens. We find that this is indeed the case: training on more samples (with the same lr scheduler, batch size etc.) leads to the masked mixer achieving the lowest total bits per byte, assuming 4 bits per parameter for the embedding. Even assuming that one can only quantize the embedding to 8 bits per parameter, we find that the memory transformer exceeds the compression of the causally-trained transformer on this dataset.

memory decoder performances

From these results it appears that the informational content of the embedding (only 64 parameters in this case) is not saturated even after a relatively long training run such that the loss continues to decrease nearly linearly. It seems safe to assume that the embedding-augmented mixer would be a more efficient compression model for a fixed compute applied at training even if the embedding were quantized with 8 or 16 bits per parameter.

The above results were obtained with flat masked mixers, and as we observed superior training efficiencies with multi-headed mixers of the same size, one might expect better training properties for embedding-augmented causal mixers as well. It comes as some surprise, therefore, that an embedding-augmented causal masked mixer (using embedding dimension concatenation) with four-headed mixer layers in the encoder slightly underperforms the same model without heads, as shown in the figure below. Note here that $d_m=64$ denotes the embedding dimension; we use a 256-width encoder and a 1024-width decoder.

memory decoder performances

A similar result is obtained when we instead increase the number of inter-token parameters by using convolutional kernels of size four rather than one: we see no increase in training efficiency when the encoder's embedding is introduced via embedding concatenation, but curiously there is a notable improvement in training efficiency when the embedding is introduced in the token dimension.

memory decoder performances

When we measure the ability of embedding-augmented masked mixers to compress FineWeb at a larger $n_{ctx}=1024$, however, we see that the multi-kernel mixer actually exhibits decreased per-step loss relative to single-kernel mixers. As with FineMath, we again observe embedding concatenation to yield slightly more efficient training after 200k steps (~13 billion tokens).

memory decoder performances

Embedding Quantization

In the section above, we assumed an 8 bit per parameter quantization would be possible with minimal loss. Is this a reasonable assumption?

There are typically two approaches one can take to quantizing a model: either train the model using the quantization (or a facsimile of it), which is called quantization-aware training, or train the model at full precision (here fp16/fp32 mixed precision) and quantize after training, which is known as post-training quantization. In our case, all we care about is quantizing the activations of one particular layer (the activations between the down and up compression transformations) rather than the weights and activations, making this a somewhat atypical quantization investigation. We start by observing post-training quantization accuracy.

Perhaps the simplest post-training quantization we can perform is to cast the activations of our compressed representation layer to the desired data type, assuming both data types are floating point. There are two Pytorch-native 8-bit floating point data types: float8_e5m2 (five bits for the exponent, two for the mantissa, one for the sign) and float8_e4m3fn (four exponent, three mantissa, and one sign bit, with one bit pattern reserved for NaN overflow), which were originally introduced in this paper. The difference between these data types is effectively more precision for e4m3 versus more range for e5m2.
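A minimal sketch of this casting, which requires a recent Pytorch version with native float8 dtypes:

import torch

emb = torch.randn(1, 512, dtype=torch.float16)   # compressed embedding activations
for dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
    roundtrip = emb.to(dtype).to(torch.float16)  # quantize, then dequantize
    print(dtype, (roundtrip - emb).abs().max().item())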

Embedding Datatype Loss
Float16 (E5M10) 2.26
BFloat16 (E8M7) 2.28
Float8 (E4M3fn) 4.86
Float8 (E5M2) 6.65

Thus we have a large increase in evaluation loss when we go from seven precision bits to three. This is somewhat surprising given that we are performing relatively mild quantization on the activations of only one layer. It may be wondered whether this is typical of layers in oracle memory models or not, and we can investigate this question by quantizing the output activations of each linear transformation in the model in turn, observing the resulting evaluation loss for each quantized layer. In the following figure, we see that the activations of the down and up projections are indeed by a wide margin the most sensitive of all layers in the model.

memory quantization
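One way to run such a sweep (a sketch assuming a trained model and an evaluate function that returns evaluation loss) is to register a forward hook that casts each linear layer's output activations to float8 and back, one layer at a time:

import torch

def quantize_output_hook(module, args, output):
    # simulate post-training quantization of this layer's output activations
    return output.to(torch.float8_e4m3fn).to(output.dtype)

for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear):
        handle = module.register_forward_hook(quantize_output_hook)
        print(name, evaluate(model))  # loss with only this layer quantized
        handle.remove()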

When we observe the distribution of activations typical of this embedding layer, we see the reason: activations for ten samples divided into 128 buckets (the most representable by 8 bits) lie entirely within [-7, 7], and thus we only require three bits in the mantissa, but there is substantial overlap in activation values given this bucket size near the origin. In the following figure, we observe the distribution of embedding activation values for ten samples, aggregated on the left and separated by color (i.e. all red bars correspond to activations from one sample). The right panel shows that there is generally not a distinct drift in mean and variance between samples.

memory quantization

A more sophisticated quantization approach is to use the statistics of the activation distribution to assign buckets (quantized values) such that the probability of two distinct values being quantized to the same value is minimized. From the above figure, our data appears to follow an approximately normal distribution, which is typical of language model activations. One successful approach to near-lossless quantization of such values is the Bitsandbytes Int8 matrix multiplication algorithm, in which activations and weights are quantized to 8 bits per element, multiplied, and dequantized, with large outlier parameters held out and processed at full precision. We observe few to no large outliers in our activation distribution, and thus can assume that 8-bit quantization is applied to nearly every activation element.

We can apply the Int8() approach while avoiding any loss penalty from quantizing weights in addition to activations as follows: we insert a linear layer 'probe' into the model between the down and up compression transformations, and assign this probe weights corresponding to the identity matrix such that the input and output of this linear layer are identical, $Px = Ix = x$. As the identity matrix contains only zeros and ones, it may be trivially quantized to 8 bits without precision degradation. We can then compare the loss obtained quantizing this probe to the loss obtained quantizing the up layer, to observe the relative importance of activation-only versus activation-and-weight quantization.
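A sketch of the probe insertion using bitsandbytes (the down and up names are illustrative; the Int8 quantization occurs when the module is moved to the GPU):

import torch
import bitsandbytes as bnb

d_embed = 64
probe = bnb.nn.Linear8bitLt(d_embed, d_embed, bias=False, has_fp16_weights=False)
with torch.no_grad():
    probe.weight.copy_(torch.eye(d_embed))  # identity weights: Px = x
probe = probe.cuda()  # weights (trivially) quantized to Int8 here

# forward pass: e = down(x); e = probe(e)  <- activations Int8-quantized here
# out = up(e)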

Embedding Datatype Loss
Float16 (E5M10) 2.13
BNB Int8() on probe 2.47
BNB Int8() on up 2.52

Int8() is clearly a far more powerful quantization method than naive casting, but a substantial performance gap remains. This model (and indeed this particular embedding layer) is much more sensitive to quantization than a typical causal language model, which would normally see a <2% loss difference upon Int8() quantization from FP16 data.

If there is some difficult-to-reduce error upon post-training quantization, the usual strategy is to train a quantization-aware model. We have a particularly simple quantization goal: one layer's activations quantized to around 8 bits per parameter. As most trainings are performed on older hardware that does not natively support the newer 8-bit datatypes in its tensor cores, we do not actually train using 8 bits per activation but instead use a facsimile: we add noise to Float16 activations to approximate the precision achievable using 8 bits, specifically targeting the E4M3 datatype with three mantissa bits.

To perform quantization simulation, we add to our embedding vector a vector of identical size of uniform noise scaled to 1/2 the desired precision of $2^{-3}$ (i.e. x += (torch.rand_like(x) * 2 - 1) * 2**-4). This makes our down and up operations equivalent to the following, where $q=2^{-4}$ and $\mathcal{U}_e(-q, q)$ indicates a vector with the same dimensionality as our compressed embedding $e$, formed by sampling a uniform distribution bounded by $[-q, q)$,

\[O_{up} = W_{up} \left( W_{down}x + \mathcal{U}_e(-q, q) \right)\]
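A sketch of this noise-injected forward pass as a module (the layer names are illustrative; noise is applied only during training):

import torch

class NoisedCompression(torch.nn.Module):
    def __init__(self, d_model, d_embed, q=2**-4):
        super().__init__()
        self.down = torch.nn.Linear(d_model, d_embed)
        self.up = torch.nn.Linear(d_embed, d_model)
        self.q = q  # half the target quantization precision

    def forward(self, x):
        e = self.down(x)
        if self.training:
            # uniform noise on [-q, q) simulates ~E4M3 activation precision
            e = e + (torch.rand_like(e) * 2 - 1) * self.q
        return self.up(e)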

The use of noise addition to weights to estimate the information required to store those weights is an old trick in the field; an early use of uniform noise addition to weights in order to estimate the number of bits required to store them is found in Sejnowski and Rosenberg. In the linked papers, the authors sought to understand the contribution of each weight (layer) to the model in question by adding noise of various magnitudes and observing the inference accuracy and the ability to re-train the 'damaged' model. In the latter, they estimate the number of bits required per parameter based on this noise amount combined with knowledge of the range of values present. We modify those approaches for quantization-aware training by injecting noise into activations rather than weights, upon each forward pass (rather than only once), and by training from scratch rather than re-training a previously-trained model after one-shot noise addition.

At first glance it might seem strange to add noise to simulate quantization, but doing so simply decreases the effective precision of the embedding vector as given a noised output one can only assign an approximate value. As quantization decreases the effective precision (and range, but that is not as relevant to this implementation) in a similar way, we only have to scale the noise appropriately for the target quantization to achieve near-zero loss gap between noised and quantized implementations. We use a factor of 1/2 the desired quantization precision as the maximum distance any point is away from the closest quantized value is precisely half the distance between quantized values.

We first note that there is minimal (<0.1%) difference between the loss achieved by our noise-added model and a standard transformer-based embedding-augmented model after 100k training steps, indicating that our noised quantization-aware model trains as efficiently as its non-quantization-aware counterpart. When we evaluate this model, we observe the following:

Embedding Datatype Loss
Float16 (E5M10) 2.56
BNB Int8() 2.55
Float8 (E4M3fn) 2.58
Float8 (E5M2) 2.65

We observe very little loss difference between the unquantized and Float8 E4M3-quantized model, which is what we wanted. We also find essentially no loss change upon bitsandbytes Int8() probe quantization, and only a modest loss increase with Float8 E5M2. When we observe the distribution of embedding activations, we find less clustering around the origin but, notably, still no outliers. We thus expect to be able to quantize to 6 bits per activation using E3M2 with the same amount of loss as E5M2.

memory quantization

After 200k training steps, we have the following:

Embedding Datatype Loss
Float16 (E5M10) 2.424
BNB Int8() 2.437
Float8 (E4M3fn) 2.672
Float8 (E5M2) 3.253

and the distribution for activations is as follows:

memory quantization

The approximately normal distribution present in these activations would be expected to account for the gap between BNB Int8() and Float8 E4M3fn accuracy.

memory quantization

When we compare the distribution of this quantization-aware model to the non-quantized-aware model, we see that the mean is still centered around the origin but the variance is larger, and the distribution is notably flattened. This is to be expected, as activation values must be pushed farther apart in order to remain distinct upon addition of the uniform noise.

memory quantization

When we observe the sensitivity of each layer of a noise-injected quantization-aware model compared to our original full-precision model, we find something interesting: not only are the activations of the layer to which we added noise (down) no longer sensitive to 8-bit quantization, but neither are the other layers that were previously found to be relatively quantization-sensitive (early transformer layers in the decoder and encoder, the up projection, etc.). This suggests that injecting noise into an arbitrary layer's activations is capable of imparting quantization insensitivity to all layers.

memory quantization

This is also true of the weight layers: noise injection into the compressed embedding layer activations imparts quantization insensitivity to both encoder and decoder weight layers as well as the compression layer weights.

memory quantization

We can hypothesize that the converse of the last statement may also be true: if noise is injected into another layer, it may impart quantization insensitivity to our vector of interest (the compressed layer activations). Rather than injecting noise directly into the compressed embedding layer, we inject it before the compression (down) transformation and add a residual across the injection point. As shown in the following figure, we do indeed see quantization insensitivity in our layer of interest as well as in all others in this model.

memory quantization

The real benefit of this approach is that it empirically results in greater quantization insensitivity in the compressed embedding, while also leading to slightly lower loss after training: the following table shows that Bitsandbytes Int8 quantization now results in less than a 0.1% loss increase from the Float16 reference.

Embedding Datatype Loss
Float16 (E5M10) 2.4203
BNB Int8() 2.4228
Float8 (E4M3fn) 2.4687
Float8 (E5M2) 2.6070

When we re-train noise-injected models (at the compressed embedding unless otherwise noted) and compare to a re-trained full-precision model, we find that there is a non-obvious relationship between noise magnitude and training efficiency. Most notably, there is little to no difference between the training efficiencies of the QAT model with $2^{-4}$ uniform noise injected at the compressed embedding, a model with $2^{-2}$ noise injected inside a residual before compression, and the full-precision model.

memory qat model training

To conclude, we find that embeddings can indeed be quantized to 8 bits per parameter with no real loss difference (using BitsandBytes Int8()) from the unquantized embedding using noise-injected QAT models. As this QAT method imparts minimal increase in evaluation loss compared to full-precision training, we conclude that the compressed embeddings of a QAT model are quantizable without loss, with little to no training efficiency difference from the full-precision model training detailed previously.

Estimation of language entropy with embedding-augmented models

We have seen that the introduction of an (oracle) embedding into a causal decoder allows for more efficient language compression than either a standard one-pass encoder-decoder or a causal decoder-only architecture. It is particularly noteworthy that embedding-augmented models exhibit different asymptotic training characteristics relative to causal decoder models: for this section, it is important to observe that causal model loss during training is well approximated by a log-linear relationship, where a constant decrease in loss occurs upon a fixed magnitude change in training inputs (i.e. number of tokens or training steps). For example, if we see a 0.4 CEL decrease upon a doubling of the number of input steps, we would expect another 0.4 CEL decrease upon a further doubling.

Now consider the following question: how effective can language models possibly become? We can think about this in terms of how much incompressible information language exhibits, or equivalently in terms of the intrinsic entropy present in language itself. To see that language does have non-zero intrinsic entropy, observe that at any given word on this page the rest of the text could be replaced with an equally grammatically (and factually) valid but distinct completion.

One straightforward way to estimate this intrinsic language entropy is to train large causal decoders on more and more tokens using a cross-entropy loss metric, and observe the loss after training has converged. But even supposing that we had an arbitrarily large model capable of training on finite hardware, this approach is not computationally feasible: as long as loss curves remain log-linear, not only does convergence never actually occur, but we require an exponential amount of data and compute to achieve a linear increase in estimation accuracy.

With the observation that embedding-augmented models experience linear rather than exponential decay during training, it is clear that these models are much more suitable for intrinsic language entropy estimation. The approach is straightforward: train sufficiently large embedding-augmented models with smaller and smaller embedding sizes, and observe the size required for convergence to 0 CEL. The smallest embedding size with which the model achieves zero loss (on hold-out data) is the language entropy estimate.

Combining embeddings with causal decoders allows for token-specific entropy estimation

When large language models are trained, the current approach is to feed a large number (on the order of trillions) of tokens to these models, which requires a very large amount of compute for the many forward and backward passes over this data. The paradigm of language model prediction ability scaling with input data size has been observed since 2016 or so, and is likely to hold further into the future as well. But the finite amount of language data available means that input data cannot be scaled without reaching a limit. More recently it has been observed that models trained on filtered data (selected for factual accuracy, informational content etc.) perform substantially better on downstream tasks such as those present in benchmarks used to measure large language models (MMLU etc.).

We approach the filtering of data at a very granular level: the token rather than the passage. Given a certain passage (which we assume can be tokenized to fit in a given context window), which tokens are 'more important' to train on and which are less? This is a somewhat subjective question as it depends on what one wants a language model to do, but from the perspective of training a model to minimize language entropy there is a single answer: as no model can surpass the intrinsic entropy of a conditional next token, training should not modify the model's weights in an attempt to do so. Attempting to train a model past each token's intrinsic entropy is impossible, and in practical terms is a waste of compute and data.

How can one perform this entropy-aware training? The first step is to be able to estimate the entropy of each (conditional) token, and happily we have the perfect model to do so in an efficient manner, at least in relative terms compared to the other tokens in the given corpus.

Once the relative token entropy is estimated, the second step is to incorporate this information into the training algorithm such that the model is only marginally modified to fit high-entropy tokens, while low-entropy tokens are more strongly fit. This can be done by simply assigning cross-entropy loss weights to be the complement $(1-x)$ of our relative entropy values, such that larger loss weights are assigned to tokens with lower entropy (see the sketch below). The idea is that at the start of training, a model predicts all tokens with high entropy (witness the cross-entropy loss at the start of training). Tokens with high conditional entropy require less modification of this initial model state than tokens of low entropy, so taking smaller weight steps for high-entropy tokens relative to low-entropy tokens allows the model to reach the intrinsic entropy loss value for both, assuming that the scale of weight modification is proportional to the scale of the loss per token.
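A sketch of such a weighted loss, assuming relative_entropy holds per-token estimates scaled to $[0, 1]$:

import torch.nn.functional as F

def entropy_weighted_cel(logits, targets, relative_entropy):
    # logits: [batch, seq, vocab]; targets, relative_entropy: [batch, seq]
    per_token = F.cross_entropy(logits.transpose(1, 2), targets, reduction="none")
    weights = 1.0 - relative_entropy  # larger weights for lower-entropy tokens
    return (weights * per_token).mean()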

Taking a step back, does it make sense to scale the changes made to a model with respect to high- and low-entropy tokens? One approach to language modeling is to simply train on everything one can get one's hands on, with the idea that a model can 'soak up' the data and will perform better than when trained on curated data. This is an inefficient way to train models, however, as it has been shown numerous times that models trained on filtered data far outperform models trained on unfiltered data. It is not inaccurate to think of a model as a sponge that can indeed 'soak up' the training data, but the sponge is finite in size and can only soak up so much given a fixed amount of compute (or data). In this analogy, we want the model to attempt to learn the aspects of a dataset that are indeed learnable, rather than the ones that fundamentally are not, such as prediction of tokens containing large intrinsic entropy.

Memory Model Introduction

The ability to compress information from a sequence of tokens into one embedding in an efficient manner has another utility: we can use these embeddings to provide extended context to a model without increasing its inference computation. Extensive research has been performed on methods to reduce the amount of computation and memory required to train and perform inference on a model applied to $n$ tokens, and this problem has been particularly relevant to recent advances in code generation, mathematical problem solving, and other domains benefiting from chain-of-thought test-time compute scaling. In this paradigm, the performance of a model scales with the number of tokens generated (before a final answer is given), such that inference compute and memory become significant bottlenecks.

On the compute side, transformers scale with $O(n^2)$ compute for $n$ tokens (assuming $K, V$ caching), making very long context window inference prohibitively expensive unless one has access to large GPU clusters. This is true regardless of whether one uses a mixture of experts or attention innovations such as multi-headed latent attention, as these only introduce constant factors to the scaling complexity. Memory requirements are $O(n)$ during inference if $K, V$ caching is used and $O(n^2)$ during training, as caching cannot be used when one must backpropagate across all activations.

Decreasing both memory and compute growth is an active area of current research, with most efforts aimed at sub-quadratic-complexity attention or attention alternatives. Here we take a different approach: our generative model is still $O(n^2)$ in compute (and $O(n)$ in memory if caching is implemented) for inference regardless of mixer or transformer use, but we provide embeddings representing entire sequences as inputs to that model at certain indices, rather than only embeddings representing tokens. For the case of one encoder model and one decoder of fixed length, the compute required is $O(n)$ with length (a slight abuse of notation, as this holds only up to length $n_{ctx}^2$) since one simply uses more of the decoder’s input indices for sequences rather than tokens; similarly the memory scaling is $O(1)$, again only up to $n_{ctx}^2$.

The idea of compressing input sequences into embeddings that take the place of transformer token embeddings is not new, and has been explored in various forms (see particularly the recurrent memory transformer). Such models were shown to be able to perform simple information retrieval (needle-in-haystack style) on the compressed inputs but little more than that, and certainly do not retain most of the input’s information. The insight here is that, as we have seen, transformers are quite inefficient with respect to capturing input information in encodings while masked mixers are not, so we can use masked mixer encoders instead to greatly increase the amount of information stored in each embedding.

The architecture we will experiment with here is similar to the embedding-augmented causal language model architecture implemented above, where we use token-dimension concatenation to maximize the number of embeddings we can provide to the decoder model. We begin by testing the ability of a memory embedding to increase decoder causal modeling accuracy when this embedding is formed from all tokens in the sequence, both previous and future, which we refer to as an ‘oracle’ memory. For clarity, a transformer-based architecture used for such oracle memories is as follows:

memory decoder architectures

After testing the ability of oracle memories to capture necessary information of the input to effectively minimize causal language modeling loss, we then explore models where the memories are only of past tokens.

The notable difference between this architecture and the token concatenation-based autoencoder introduced above is that we no longer care about compressing the embedding fed from the encoder to the decoder. This is because if one uses token concatenation, each token in the decoder is converted to a vector of dimension $d_m$, such that it is natural to supply the encoder’s embedding as a vector of that same dimension. This also allows us to provide embeddings of $n$ encoded text sequences as $n$ embeddings, taking the place of $n$ tokens of the decoder. This is clearly much more efficient in terms of decoder input space than embedding-dimension concatenation, and avoids the problems of input ambiguity present when performing linear combinations in the embedding dimension.
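A minimal sketch of this token-dimension concatenation, with hypothetical `encoder`, `decoder`, `wte`, and `lm_head` modules standing in for the models used on this page (for the ‘oracle’ memory below, `memory_ids` is simply the full input sequence itself):

```python
import torch

def memory_forward(encoder, decoder, wte, lm_head, input_ids, memory_ids):
    # each memory embedding occupies one token position of the decoder
    memory = encoder(memory_ids).unsqueeze(1)      # [B, 1, d_m]
    tokens = wte(input_ids)                        # [B, n_ctx, d_m]
    hidden = decoder(torch.cat([memory, tokens], dim=1))
    return lm_head(hidden[:, 1:])                  # drop the memory position
```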

Oracle memories can be nearly perfect with limited training

The first question we can ask is as follows: can a causal decoder make use of the information present in an autoencoder’s embedding? Essentially, we want to know whether the information in the embedding passed from encoder to decoder in the autoencoding architecture (where all tokens are generated in one forward pass) is beneficial for the case where we train a causal decoder (where each forward pass yields only one token). We can test this by observing training efficiencies for a causal masked mixer and an embedding-augmented masked mixer, where the latter may be implemented by training an autoencoder, discarding its decoder, adding a causal decoder, and freezing the encoder’s parameters during training. To save memory and compute during training, one could instead save the embeddings generated by this autoencoder to storage and reference them when training and evaluating the decoder, but we choose the frozen-encoder implementation to avoid a significant storage overhead. From the figure below, we see that the answer is yes: including an autoencoder’s embedding leads to substantially lower cross-entropy loss for the causal decoder.
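A sketch of this frozen-encoder setup, with hypothetical `MixerAutoencoder` and `CausalMixerDecoder` classes (and a decoder assumed to return its causal cross-entropy loss directly) standing in for the models used here:

```python
import torch

# hypothetical stand-ins for the trained autoencoder and new causal decoder
autoencoder = MixerAutoencoder(d_model=1024, n_layers=8)
autoencoder.load_state_dict(torch.load("autoencoder.pt"))

encoder = autoencoder.encoder
encoder.requires_grad_(False)          # freeze: only the decoder trains
encoder.eval()

decoder = CausalMixerDecoder(d_model=1024, n_layers=8)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=2e-4)

for batch in dataloader:
    with torch.no_grad():
        memory = encoder(batch)        # [B, d_m] sequence embedding
    loss = decoder(batch, memory)      # causal CEL given the memory
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```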

What makes a frozen encoder effective? Does an embedding from an autoencoder that has been trained more effectively (to a lower cross-entropy loss, that is) lead to more efficient training than an embedding from a less-trained autoencoder, and are autoencoders more effective encoders than next-token-prediction-trained models (which we refer to as causal language models, CLMs)?

From the figure below, we see that the answer to the first question is yes: there is a monotonic decrease in memory model loss as the autoencoder encoder’s loss decreases, although not to the degree that a perfect encoder results in near-zero causal loss after our short training run (which is the case for optimized trainable encoders, as we will shortly see). The latter observation likely indicates that the information learned during autoencoding (for one-pass decoding) is fundamentally different from the information that is useful for next token prediction.

The answer to the second question is that autoencoder encoders are indeed far more effective encoders than next-token-prediction-trained models, to such an extent that a rather poor autoencoder encoder (one that only reaches a CEL of 3.5) far outperforms the embedding of a causally trained model that reaches a lower CEL of 2.6. Note that these models are all roughly compute-matched, such that one should properly compare the causal embedding model with the most efficiently trained autoencoder. This further illustrates the finding elsewhere that causal language models are not particularly effective information retention models, but rather tend to filter input information.

memory decoder architectures

Given that we have seen more effective encoders result in lower causal decoder loss, it may be assumed that the same would be true if one fixes the encoder and uses a more powerful decoder, say by doubling the number of layers from 8 to 16 (which for strictly causal decoders results in a drop in FineWeb evaluation loss from ~2.9 to ~2.65). Somewhat surprisingly, however, this is not the case: the same layer doubling in the decoder leads to no significant change in loss for a memory model given the optimal frozen (CEL=0.3) encoder used above. This counterintuitive result suggests that the decoder’s power is not a limiting factor for frozen memory models with effective encoders; rather, the number of samples passed to the model is more important.

memory decoder architectures

Now that we have seen that causal decoders can indeed make use of encoded information, we can investigate the question of efficiency: is it better to train encoder and causal decoder simultaneously, or to use an encoder from a previously-trained autoencoder and train the causal decoder separately? As shown in the left figure below, at least for a non-optimized encoder, training both simultaneously is far more efficient.

memory decoder architectures

As was the case for memory models designed for information compression (i.e. with very small embeddings), a multi-headed mixer showed no benefit over the flat masked mixer early in its training run. Instead, we see a precipitous drop in cross-entropy loss when increasing the convolutional kernel size to four (from one), to the point that our $d_m=1024$ memory model is able to effectively perform a $512 \to 1$ compression on token embeddings after training for half a day on two H100s, with very little error resulting from this compression.

memory decoder architectures

To conclude, the answer to our initial question in this section is yes: decoders (transformers or masked mixers) can certainly be trained to use information present in an encoder’s embedding.

Memory model training efficiency

Now that we have seen that decoders are indeed capable of learning to use practically all information present in an encoder, we can proceed with training memory models where encoders store information from previous tokens, not the token to be predicted.

A little reflection can convince one that, were it efficiently trainable, the use of such memory embeddings would be extremely useful both for increasing a model’s effective context window without increasing its inference computation and for cases where one wants to cache a model’s previous state, as is common in multi-turn conversation or agentic workflows.

Thus the question remains whether a memory model is actually efficiently trainable, with an upper bound on performance being the per-step loss achieved by ‘full-context’ models, meaning models that do not separate the input into chunks and form embeddings of these chunks but instead train on the entire input at once. We start by training relatively small memory models in which both encoder and decoder are trainable.

The experimental setup to address this question is as follows: we first obtain tokenized inputs of a certain length (here 1024 total tokens, potentially including padding) and then divide these into four chunks of 256 tokens each. Our encoder then forms embeddings of the first three chunks, and the decoder uses zero (for the first chunk), one (for the second), two (for the third), or three previous embeddings (for the fourth chunk) as it predicts each next token in that chunk. We compare this memory model training efficiency to the same architecture with no memory embeddings added, as well as to full-context models all trained on the same dataset (full-context models of course do not use chunked inputs).
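As a sketch of one training step under this setup (hypothetical `encoder` and `decoder` modules as before, where the decoder accepts zero or more memory embeddings prepended to its token inputs):

```python
import torch

n_chunks, chunk_len, d_model = 4, 256, 1024
chunks = input_ids.view(-1, n_chunks, chunk_len)       # [B, 4, 256]

# embed the first three chunks in parallel with the (smaller) encoder
memories = encoder(chunks[:, :3].reshape(-1, chunk_len))
memories = memories.view(-1, 3, d_model)               # [B, 3, d_m]

# chunk i is predicted using the i previous chunk embeddings as memory;
# chunk 0 receives no memory at all
loss = sum(
    decoder(chunks[:, i], memories[:, :i]) for i in range(n_chunks)
) / n_chunks
```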

We use smaller encoders than decoders ($d_m/2$ of the decoder, to be precise) in order to make the memory models more similar to the full-context models in terms of the memory and compute required per training step than would be the case with full-size encoders, at least to within around 20% of the throughput.

In the following figure (note the semilog axes), observe that transformers and masked mixers both benefit substantially from the addition of memory embeddings, as expected. What is more surprising is that the memory mixer is nearly identical in per-step training efficiency to the full-context model, albeit with ~25% lower throughput. The full-context model uses right rather than left padding, as mixers don’t use attention masks (left-padded training is much less efficient for full-context models but not for memory models).

This is not the case for transformers, however: the full-context version of that model remains substantially more efficient to train than the memory-augmented version. The per-step loss difference between the full-context transformer and the memory mixer halves as training proceeds, making it likely that with more training the memory mixer would equal or surpass the full-context transformer in loss at a given training step. In the following figure, each model is trained on $128 * 1024 * 200000 \approx 26 \times 10^9$ tokens.

memory decoder architectures

It may be wondered whether it is beneficial to fix the positions of memory embeddings and token embeddings or else allow the indices of the start of token embeddings to vary. The difference between the fixed-position and variable-position embedding implementation may be depicted as follows:

fixed vs variable position

Masked mixers effectively use fixed, absolute positional encodings, such that it is natural to use fixed-position embeddings. As this is not the case for transformers, it is useful to compare training efficiencies between fixed- and variable-position embeddings. As shown in the following figure, there is a rather small increase in efficiency when using fixed positional encodings for transformers.

fixed vs variable results

A separation between encoder and decoder allows for memory- and compute-efficient training

It may also be wondered how these encoder-decoder memory models compare with decoder-only-style memory models with respect to training efficiency. A notable example of this is the recurrent memory transformer architecture in which a decoder model reserves one or more embeddings as memory ‘tokens’. For causal language modeling, this means that these decoders are tasked with both encoding (in the sense of storing information in the memory embeddings) as well as decoding, in the sense of using embeddings of tokens as well as sequences of tokens to generate individual tokens.

To show the difference between the encoder-decoder memory models as defined above (which we can think of as ‘parallel’ memory models) and recurrent memory models, the following diagram illustrates how each model type processes inputs of up to three chunks.

memory decoder architectures

It is apparent that the encoder-decoder memory model exhibits a number of notable benefits compared to the decoder-only recurrent model, both for training and inference. If both encoder and decoder are trainable, the total memory would be approximately equivalent to that of the recurrent model if the encoder were smaller than the decoder, but the encoding can occur in parallel, avoiding the high sequential computational cost of backpropagation through time. In a similar manner, using parallel encoders also allows one to avoid the known difficulties associated with normalizing gradients across many decoder segments, typical of BPTT and recurrent memory transformers in particular. The same parallelizations are possible during inference, meaning that a long prompt may be processed much more quickly using the parallelizable memory model than the recurrent version.

We test the training efficiency of parallel and recurrent memory models by comparing losses achieved on FineWeb inputs of up to 1024 total tokens with chunk sizes of 256 tokens, meaning up to three parallel memory tokens or a depth-four decoder stack for the recurrent model. We use one memory embedding for the recurrent model and perform backpropagation through time across the entire recurrence. In the figure below, we see that parallelized memory models are slightly more efficiently trainable in this paradigm, but this effect is rather subtle and, in the case of transformers, likely not significant. The difference in transformer versus masked mixer loss per step is mostly accounted for by the superior performance mixers exhibit for small-context (here 256 tokens) training windows.

memory decoder architectures

If the memory models contain trainable encoders, these two architectures are very similar in memory and compute requirements for a given input and model size. This is because these models form gradients on all $n$ tokens of their extended context, which occurs for RMTs via backpropagation through time and for memory models via gradients on encoders. In the case of RMTs, it was shown to be necessary to perform this backpropagation through time in order to maintain training efficiency, and other approaches that only backpropagate in local chunks (e.g. Transformer-XL) exhibit worse performance.

Recalling that recurrent memory models combine encoder and decoder tasks into one architectural unit, this is not particularly surprising: clearly training an encoder is not efficient if one does not backpropagate the encoder’s gradients to model blocks at the token indices that are actually required for information retention. Separating encoders from decoders would be expected to largely ameliorate this issue: instead, we can train an encoder first on all necessary token chunks and then use this model to form the memory embeddings used to train the decoder, without requiring gradient flow to the encoder. The fundamental idea is that one may separate the encoding (information-saving) function from the decoding (information-discriminating) function in order to achieve very large memory savings during training.

Is a memory model with a frozen encoder efficiently trainable? It would not be of much use to train with a frozen encoder if the decoder were not able to learn to use the encoder’s information efficiently in the first place. We can test this by comparing per-step losses of frozen-encoder models to no-memory and trainable-encoder memory models. The following figure (left) shows the training losses achieved using a frozen encoder with an architecture matched to the memory model decoder, where both encoders achieve an autoencoder CEL of <0.3 on 512 tokens.

memory decoder architectures

To display the differences between frozen and unfrozen memory model encoder training efficiencies more effectively, the right panel shows the proportion of the memory loss improvement (1.0 being equal to the trainable-encoder memory model’s loss, 0.0 being equal to the no-memory model’s loss at each step) achieved by mixer and transformer frozen memory models. For both architectures, we see that the difference between frozen and trainable memory encoders decreases as training proceeds, but it is also apparent that mixers are more readily capable of using frozen encoder information than transformers. This is notably not due to the encoder itself, as the transformer encoder used here achieved a slightly lower evaluation cross-entropy loss ($\Bbb L = 0.289$) than the mixer encoder ($\Bbb L = 0.292$) on the same dataset.

memory decoder architectures

As we have seen in the text compression work that even trainable encoders are much more efficiently trainable when their architecture matches that of the decoder (meaning that we pair mixer encoders with mixer decoders and transformer encoders with transformer decoders), it may be wondered if a similar phenomenon would occur here, such that the transformer decoder would be less capable of using the information present in the mixer’s embedding. We find that there is actually an increase in per-step loss for a transformer memory model given the mixer embedding compared to a model with no memory at all, suggesting that this transformer decoder is more or less completely incapable of using a frozen mixer encoder’s information, at least early in training.

memory decoder architectures

Autoencoders and memory models do not learn trivial encodings

Thus far we have seen curiously divergent training efficiencies with architectural changes for the case where the encoder’s embedding is as large as or larger than the context window, versus the case where the encoder’s embedding is significantly smaller than the context window. For example, consider the widely different effect of using more mixer heads or a $k>1$ convolution for large embeddings (where this leads to much more efficient training) compared to small embeddings (where it leads to a decrease in training efficiency). Another example is the sharp drop in efficiency in both autoencoders and memory models as one decreases the encoder embedding size past the $n_{ctx}$ boundary.

Why would this occur? One hypothesis is that large-embedding models simply form trivial autoencodings that are easy to train, and that this is assisted by the architectural modifications explored above, whereas it is impossible for small-embedding models to form trivial autoencodings. By a ‘trivial’ autoencoding we mean one in which information from input token indices is passed to the output (or at least to the decoder) such that nothing is actually learned about the data distribution itself (educational or mathematical text in this case).

A good example of a trivial autoencoding occurs when the model’s context window equals the embedding dimension, $n_{ctx} = d_m$, and the model learns to represent each input token in a single embedding element. On this page we typically use a tokenizer of size 8000 and embeddings encoded in 16 bits per parameter. Each embedding element can therefore encode a token fairly accurately (for all tokenizers up to size $2^{16}=65536$), so a powerful autoencoder might simply learn this encoding.

Testing for this one specific trivial encoding is not difficult, but one can imagine many other forms of equally distribution-free autoencoding, making it difficult to test for all such encodings directly. Happily there is a more general test for all trivial autoencodings as defined above: we can observe the loss (compression) of the model in question on tokens drawn from a uniform random distribution and compare this to the in-distribution loss. If nothing has been learned about the actual data distribution, these values will be approximately equivalent.

Generating a uniform distribution on all possible tokens of a tokenizer is fairly straightforward: we can just check the size of an input and reassign the input as a tensor of random integers of the same shape.

# replace the batch with uniform random token ids over the whole vocabulary
input_ids = torch.randint(1, self.n_vocab, input_ids.shape)

For reference, we can decode these random inputs and confirm that they are indeed far outside the distribution of FineWeb-edu data, as the following sample shows:

adequate mot smart receive ruralgment wonvis requestusaloney |lessictues Pl legislationager guarduresiverse.comulin minutesí excessive ign-G blue pictures Environment kit hoursCE task enhanceuff oral Cast<|reserved_special_token_147|> individual.Cil Glick examined changing awayolesplace wid sector twenty du tox covered White<|reserved_special_token_13|> famouses influen.e

Does loss on these random tokens mirror loss on in-distribution data for large-embedding models, either autoencoders or memory models? The answer for both is no: across all models tested, the loss for these random strings is much larger than the in-distribution loss and indeed exceeds the untrained model loss (which is typically around 9.5). This is strong evidence against these models forming a trivial autoencoding as defined above.

random loss

Why does the untrained model have a loss of around 9.5? Untrained models (with either uniform or Kaiming normal weight initialization) typically exhibit activations that approximate Gaussian distributions, which is observed for vision models as well as for language models. As we are sampling target tokens $n$ from a uniform distribution, we can compute the average cross-entropy loss between a normal distribution $\mathcal{N}$ over the tokenizer size (here $\vert t \vert = 8000$) and all possible token indices,

\[\frac{1}{n} \sum_n \Bbb L \left( \mathcal{N}(|t|; 0, 1), n \right) = 9.501\]

Thus the loss we observe for our untrained model is nearly equivalent to the loss obtained if we assume that the activations of the output layer are approximately normal.
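This can be checked empirically by scoring uniform random targets against Gaussian logits; a minimal sketch:

```python
import torch
import torch.nn.functional as F

n_vocab, n_samples = 8000, 4096
logits = torch.randn(n_samples, n_vocab)           # N(0, 1) output activations
targets = torch.randint(0, n_vocab, (n_samples,))  # uniform random tokens
print(F.cross_entropy(logits, targets).item())     # ~9.5
```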

We can also observe the generalization of a given model by comparing the loss achieved on in-distribution versus marginally out-of-distribution data. We use FineMath as our marginally out-of-distribution dataset for models trained on FineWeb, and FineWeb for models trained on FineMath. We have already observed good generalization on in-distribution data for most models on this page (there is <5% duplication between train and eval datasets for either FineWeb or FineMath, and very little difference between train and test loss).

in and ood

These results tell us that near-distribution generalization scales in a very similar manner between autoencoders and memory models. Curiously, however, masked mixer-based models of both types tend to be somewhat better generalizers than transformer models, as shown in the following figure.

mixer generalization

Thus neither memory models nor one-pass autoencoders learn trivial encodings, regardless of whether masked mixer or transformer architectures are used. It is natural to wonder next whether these models are capable of learning a trivial encoding at all. As we observe similar generalization properties for mixers and transformers, we are free to pick either architecture and test the ability of a model that otherwise learns non-trivial autoencodings to learn a trivial autoencoding, by simply training on the uniform random tokens used earlier for evaluation.

We employ an efficiently-trainable autoencoder architecture ($n_{ctx}=512, d_m=1024, k=8$ with $n_l=8$ for both encoder and decoder and repeated embeddings in the decoder) that is capable of reaching ~0.6 CEL on FineMath or ~1.5 on FineWeb after training on 13 billion tokens. We use bf16/fp32 mixed precision training after encountering numerical instabilities with fp16/fp32 mixed precision on this dataset.

As shown in the following figure, this autoencoder experiences virtually no loss minimization and thus does not learn a trivial autoencoding on these random tokens.

random train

How much information do embeddings contain?

We have seen that one can train autoencoders to compress information from many tokens into the embedding of one, and that this compression is non-trivial, reflecting the distribution of the training data. This means that the embeddings of these autoencoders (the encoder’s embedding passed to the decoder) are capable of near-lossless 512:1 compression with respect to token embeddings.

It may be wondered whether this is at all remarkable: after all, one can train a language model in many ways, and it is possible that many types of training result in similar information compression characteristics. Does this compression result if we train models on objectives that are more common in the field, such as causal modeling to predict next tokens or noise-contrastive estimation for retrieval?

To restate, this section will address the following question: how much information is contained in embeddings from models trained for different tasks? It is often assumed that large models trained with lots of compute will be best at most tasks, but is this true if the task requires much or all of the information in the input to be compressed into an embedding?

Our experimental design is as follows: we load the weights of the encoder model of whichever type we want, discard the word-token embedding and language modeling head (and decoder, if relevant), freeze the blocks, and train a decoder (and word-token embedding transformation) to minimize the autoencoding loss on the dataset the encoder was trained on (here FineWeb 10BT). With one decoder architecture and fixed-compute encoders, we can get an idea of how much information is present in these embeddings, provided we can train the decoder to convergence or at least sufficiently close to it. We train on the same context window used for encoder training ($n_{ctx}=512$) and the decoder is of size $d_m=1024, n_l=8$. We repeat the embedding unless otherwise noted.
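A sketch of this probing design, with hypothetical class and helper names (`MixerEncoder`, `MixerDecoder`, `train_autoencoding`, `evaluate`) standing in for the models and training loop used here:

```python
import torch

checkpoints = {
    "autoencoder": "autoencoder_encoder.pt",
    "causal": "causal.pt",
    "causal_retrieval": "retrieval.pt",
    "untrained": None,
}

for name, path in checkpoints.items():
    encoder = MixerEncoder(d_model=1024, n_layers=8)
    if path is not None:
        # load block weights only; the WTE and LM head are discarded
        encoder.load_state_dict(torch.load(path), strict=False)
    encoder.requires_grad_(False)                      # freeze the blocks

    decoder = MixerDecoder(d_model=1024, n_layers=8)   # fresh decoder + WTE
    train_autoencoding(encoder, decoder)               # minimize input CEL
    print(name, evaluate(encoder, decoder))            # loss per encoder type
```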

We first consider these questions for masked mixers. As a positive control, we confirm that this approach is capable of recovering nearly all of the information present in an autoencoder’s embedding (below left) when the decoder receives an unrolled projection of the embedding. This shows that our decoder-only training approach is sufficient to recover most of the information present in the embedding.

In the following figure, we find that causal (next-token) training results in a small increase in the informational content of a model’s embedding (here the second-to-last token embedding, as the last is untrained) compared to an untrained model, and that an embedding (at the same position) from a noise-contrastive estimation retrieval model has somewhat greater informational content than the causal model’s. This indicates that the retrieval training process increases information retention, as this model was first trained for causal modeling before InfoNCE-based retrieval training was applied.

mixer information recovery

We see slightly less information recovery from a frozen transformer’s encoder compared to what was observed for mixers, and the same general small increase in information retention for causal models compared to untrained ones. Curiously there is actually a small decrease in embedding information for retrieval models relative to causal models, which may provide some basis for the finding that mixers are generally more efficient for retrieval.

mixer information recovery

For both model architectures, retrieval and causal embeddings contain only a small fraction of the information present in the autoencoder embedding. This is not a subtle difference, and it cannot be reasonably argued that extended training of the decoder would result in any other conclusion.

The natural question to ask is how much information these cross-entropy loss values represent. The answer depends on our definition of information: one could define information as a Hamming metric on the tokenized output and target (input) tokens, such that the information present in the embedding is a measure of the proportion of correct tokens predicted.

Alternatively, we can define information retention using the cross-entropy as the fraction of the cross-entropy loss the model reaches relative to the loss of an ‘informationless’ model. In this definition we want to understand what the cross-entropy losses would be for a model with perfect information and a model with no information, and normalize our obtained losses by these values. A model with perfect information in its encoder will clearly obtain zero cross-entropy loss (assuming an arbitrarily powerful decoder). The distribution with the least Shannon information is the uniform distribution by definition, so we can compute the cross-entropy loss corresponding to an informationless model by simply assuming that the model exhibits uniform $\mathcal{U}[0, 1)$ activations over the token dimension. As our tokenizer is of size 8000, we find the following for $n$ tokens:

\[H(p_0, q) = \frac{1}{n} \sum_{n} \Bbb L \left( \mathcal{U}(|t|), t \right) = 9.03\]

where $t$ is sampled from the input distribution, or equivalently any distribution given that the reference is uniform, such that we have a range of $\Bbb L \in [0, 9.03]$ for our tokenizer. We can therefore define the embedding information as the complement of the fraction of our cross-entropy loss

\[I_e = 1 - \frac{H(p, q)}{H(p_0, q)} = 1 - \frac{- \sum_x q(x) \log (p(x))}{- \sum_x q(x) \log (p_0(x))}\]

which for our tokenizer simplifies to

\[I_e = 1 - \frac{H(p, q)}{9.03}\]
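Both the informationless reference and the resulting metric are easy to verify numerically; a minimal sketch:

```python
import torch
import torch.nn.functional as F

n_vocab, n_samples = 8000, 4096
logits = torch.rand(n_samples, n_vocab)            # uniform [0, 1) activations
targets = torch.randint(0, n_vocab, (n_samples,))
print(F.cross_entropy(logits, targets).item())     # ~9.03, informationless loss

def embedding_information(loss, h0=9.03):
    # fraction of input information retained, per the definition above
    return 1 - loss / h0

print(embedding_information(0.435))                # ~0.952 for the best encoder
```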

For mixers, we have the following:

| Encoder Model | Loss | Input Information (%) |
| --- | --- | --- |
| Autoencoder (validation) encoder | 0.435 | 95.2 |
| Autoencoder encoder | 1.528 | 83.1 |
| Untrained | 5.937 | 34.3 |
| Causal Trained | 5.815 | 35.6 |
| Causal -> Retrieval Trained | 5.594 | 38.1 |
| Autoencoder -> Retrieval Trained | 5.846 | 35.3 |

and for transformers,

| Encoder Model | Loss | Input Information (%) |
| --- | --- | --- |
| Autoencoder (validation) encoder | 2.924 | 67.6 |
| Autoencoder encoder | 2.935 | 67.5 |
| Untrained | 6.643 | 26.4 |
| Causal Trained | 6.214 | 31.1 |
| Causal -> Retrieval Trained | 6.380 | 29.3 |

By this metric, therefore, we observe that causal language modeling and retrieval training result in small increases in information retention, on the scale of 1-4%, compared to untrained models, but that autoencoder training results in an order-of-magnitude larger increase. We conclude that causal models do not retain most input information (indeed barely more than untrained models) and, somewhat surprisingly, neither do retrieval models, whereas autoencoders do.

This is a notably different conclusion from another study using similar techniques to measure informational content in large causal transformers by Morris and colleagues. There, the authors found that one can achieve at least somewhat accurate inversion of models using the output logits of a next token predicted after a hidden prompt is fed to the model. We note that this is likely due to a difference in scale: the authors were interested in regenerating prompts rather than entire text segments, and accordingly trained decoders using a context window of 64 tokens rather than the 512 used here. Most models in that work are furthermore much larger than those considered here, and the dataset considered is much more restricted (system prompts rather than arbitrary text). Here and elsewhere it has been observed that information retention in an embedding is highly dependent on context window size, with smaller contexts being much easier to retain information from. In this light, the finding that causal models struggle to retain most information of arbitrary text with a much larger context window is perhaps unsurprising.

Oracle memories are compressed even if they don’t need to be

We can also use this method to determine the informational content of the embedding-augmented ‘oracle’ memory models introduced in an earlier section on this page. Recall that these models combine an encoder with a causal language modeling decoder, and that for large-dimensional (i.e. $d_m \geq n_{ctx}$) transformers and mixers with some mild architectural constraints, these models approach zero loss with limited training budgets. This raises the question: how much information is contained in the embedding generated by the encoder? One estimate is as follows: the decoder-only models achieve a CEL of ~2.6 on this dataset, giving a bits-per-byte compression of

\[\mathtt{BPB} = (L_t/L_b) \Bbb L / \ln(2) = (1/3.92) * 2.60 / \ln(2) \approx 0.957\]

with the decoder alone. With the encoder, we have a compression (disregarding the encoder) of

\[\mathtt{BPB} = (L_t/L_b) \Bbb L / \ln(2) = (1/3.92) * 0.1 / \ln(2) \approx 0.036\]

meaning that the encoder is responsible for approximately 0.921 bits per byte. This is not very remarkable given that the embedding itself amortizes to roughly 8 extra bits per input byte for these large models (1024 parameters at 16 bits each over a 512-token, ~2007-byte window). The embedding nevertheless does not retain enough information to accurately regenerate the 512-token context window, as shown below:

mixer information recovery

Thus the large-dimensional oracle memory embeddings contain more input information than causal model embeddings and untrained models, but still retain only a fraction of the total information in the input. Recall previous results showing that this relatively low-information embedding results in better next token prediction than a frozen high-information autoencoder embedding when paired with a causal decoder. As the decoder is fed all previous tokens at each forward pass, this suggests that only a small amount of input information is necessary to provide next-token information when paired with this previous-token information.

How much information does this memory model encoder embedding contain compared to its capacity in bits per byte? After training our decoder, we have a BPB compression of

\[\mathtt{BPB} = (L_t/L_b) \Bbb L / \ln(2) = (1/3.92) * 4.93 / \ln(2) \approx 1.81\]

whereas the amortized bits per input byte of the embedding ($d_m=1024, n_{ctx}=512$) assuming no compression is

\[\mathtt{BPB} = \frac{n_p * b_p}{n_{ctx} * (L_b / L_t)} = \frac{1024 * 8}{512 * 3.92} \approx 4.08\]
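These bits-per-byte figures follow directly from the loss values above; a small script to reproduce the arithmetic:

```python
import math

BYTES_PER_TOKEN = 3.92                 # L_b / L_t for this tokenizer

def bits_per_byte(cel):
    # convert a nats-per-token cross-entropy loss to bits per input byte
    return cel / BYTES_PER_TOKEN / math.log(2)

print(bits_per_byte(2.60))             # ~0.957, decoder only
print(bits_per_byte(0.10))             # ~0.036, decoder with oracle embedding
print(bits_per_byte(4.93))             # ~1.81, recovered from the memory encoder

# uncompressed embedding capacity at 8 bits/parameter, amortized per input byte
print(1024 * 8 / (512 * BYTES_PER_TOKEN))   # ~4.08
```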

Thus the encoder has learned to compress the input from 4.08 down to 1.81 bits per byte, even though it did not have to: an uncompressed embedding holds 4.08 bits per input byte, which exceeds the roughly $\log_2(8000)/3.92 \approx 3.3$ bits per byte required to store the raw token IDs directly, and would therefore be sufficient to allow the decoder to achieve zero loss.

Equivalently, with an input of 512 tokens each representing 3.92 bytes, we have an input of about 2007 bytes, and thus the encoder’s embedding contains around 2007 bytes × 1.81 bits/byte ≈ 3633 bits of information per context window. This is much smaller than the (uncompressed) 1024 parameters × 8 bits/parameter = 8192 bits present in the encoder’s embedding, assuming 8 bits per parameter quantization.