# Motivation
- **Increasing parameters for fixed compute**:
	- **Sub-quadratic self-attention**: with a suitably chosen patch size, MegaByte reduces the quadratic self-attention cost for a sequence of length $N$ to $O(N^\frac{4}{3})$ by decomposing the long sequence into two shorter ones (a sequence of patches and the bytes within each patch).
	- **Per-patch feedforward layers**: using large feedforward layers per-patch rather than per-position, enabling much larger and more expressive models for the same cost. Assuming a patch size of $P$, instead of using the same feedforward layer with $m$ parameters $P$ times, MegaByte uses a single layer with $mP$ parameters once for the same cost.
- **Parallelism in decoding**: generating representations for patches in parallel. A 1.5B-parameter MegaByte model can generate sequences 40% faster than a standard 350M-parameter [[Attention Is All You Need|Transformer]].
- **End-to-end training**: modeling raw bytes removes tokenization, which complicates pre-processing, multi-modal modeling, and transfer to new domains while hiding useful structure from the model.
- **Reuse of established components**: building on standard transformer components increases the likelihood that the model will inherit the desirable scaling properties of transformers.
- **Model's expressiveness**: Using only the global model would resemble a decoder version of [[Vision Transformer|ViT]], which must model the joint distribution over each patch $p(x_{t+1},\dots,x_{t+P}|x_{0\ldots t})$; its output space has size $256^P$, which is tractable only for very small patches. On the other hand, factoring the joint distribution into conditionally independent distributions $p(x_{t+1}|x_{0\ldots t})\ldots p(x_{t+P}|x_{0\ldots t})$ would greatly limit the model's expressive power.

# MegaByte Transformer
## Components
### Patch Embedder
- Inputs a discrete byte sequence $x_{0\ldots T}$, where each byte $x_t\in\{0,\ldots,V-1\}$ and $V=256$
- Embeds each element (byte) with a global lookup table $E^\text{global-embed}\in \mathbb{R}^{V\times D_G}$ and adds positional embeddings: 
  $$h_t^\text{embed}=E^\text{global-embed}_{x_t}+E_t^\text{pos}\qquad t\in [0\ldots T]$$
- Chunks the embedded sequence into patches of length $P$ to create a sequence of $K=\frac{T}{P}$ patch embeddings, each of dimension $P\cdot D_G$. To allow autoregressive modeling, the patch sequence is padded to start with a trainable padding embedding ($E^\text{global-pad}\in \mathbb{R}^{P\times D_G}$), and the last patch is removed from the input:
  $$h_k^\text{global-in}=\begin{cases} E^\text{global-pad} & \text{if } k=0 \\ h^\text{embed}_{((k-1)\cdot P):(k\cdot P)} & k\in [1,\ldots,K) \end{cases}$$
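
The embedder steps above can be sketched in NumPy. This is a minimal illustration under stated assumptions, not the paper's code: the function name `patch_embed`, the random toy parameters, and the requirement that $T$ be a multiple of $P$ are hypothetical choices for the example.

```python
import numpy as np

def patch_embed(x, E_embed, E_pos, E_pad, P):
    """Map a byte sequence to the global model's input patches (hypothetical helper).

    x:       (T,) integer byte values in [0, V)
    E_embed: (V, D_G) global byte embedding table
    E_pos:   (T, D_G) positional embeddings
    E_pad:   (P, D_G) trainable padding patch
    Returns  (K, P * D_G) with K = T // P.
    """
    T = x.shape[0]
    assert T % P == 0, "this sketch assumes T is a multiple of P"
    K = T // P
    D_G = E_embed.shape[1]
    h_embed = E_embed[x] + E_pos[:T]          # (T, D_G): byte embedding + position
    patches = h_embed.reshape(K, P * D_G)     # row k concatenates bytes k*P .. k*P + P - 1
    pad = E_pad.reshape(1, P * D_G)
    # Shift right by one patch: prepend the padding patch, drop the last patch.
    return np.concatenate([pad, patches[:-1]], axis=0)

# Toy example with hypothetical sizes.
rng = np.random.default_rng(0)
V, D_G, T, P = 256, 4, 12, 4
x = rng.integers(0, V, size=T)
E_embed = rng.normal(size=(V, D_G))
E_pos = rng.normal(size=(T, D_G))
E_pad = rng.normal(size=(P, D_G))
h_global_in = patch_embed(x, E_embed, E_pos, E_pad, P)
print(h_global_in.shape)  # (3, 16)
```

Because `reshape` is row-major, position $p$ of a patch occupies columns $p\cdot D_G$ to $(p+1)\cdot D_G$, which is the same layout used when the global output is later split back into positions.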

### Global Model
Contextualizes patch representations by performing self-attention over previous patches:
- Inputs a sequence of patches: $h^\text{global-in}\in\mathbb{R}^{K\times(P\cdot D_G)}$
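- Performs self-attention over previous patches (decoder-only layers):
  $$h^\text{global-out}_{0:K}=\text{transformer}^\text{global}\left(h^\text{global-in}_{0:K}\right)$$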
- To prepare the output for the local model:
	- Each patch in the output sequence is reshaped into a sequence of length $P$ and dimension $D_G$, where position $p$ uses dimensions $p \cdot D_G$ to $(p+1) \cdot D_G$.
	- Each position is then projected to the dimension of the local model with a matrix $w^{GL}\in \mathbb{R}^{D_G\times D_L}$, where $D_L$ is the local model dimension.
	- The projected position is added to the embedding of the corresponding input byte, taken from a local lookup table $E^\text{local-embed}\in \mathbb{R}^{V\times D_L}$.
	- The local byte embeddings are offset by one position, with a trainable local padding embedding ($E^\text{local-pad} \in \mathbb{R}^{D_L}$) at the first position of each patch:
	$$h_{k,p}^\text{local-in}=w^{GL}h^\text{global-out}_{k,(p\cdot D_G):((p+1)\cdot D_G)} + \begin{cases} E^\text{local-pad} & \text{if } p=0 \\ E^\text{local-embed}_{x_{(k\cdot P+p-1)}} & p\in [1,\ldots,P) \end{cases}$$
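
The reshaping, projection, and one-position offset described above can be sketched in NumPy as follows. This is a minimal sketch under assumptions: `global_to_local` and the toy sizes are hypothetical, and the projection is applied with a row-vector convention (`h @ w_GL`) so that `w_GL` keeps the stated shape $D_G\times D_L$.

```python
import numpy as np

def global_to_local(h_global_out, x, w_GL, E_local_embed, E_local_pad, P):
    """Build local-model inputs h^local-in from global-model outputs (hypothetical helper).

    h_global_out:  (K, P * D_G) one output vector per patch
    x:             (K * P,) input bytes
    w_GL:          (D_G, D_L) projection, applied as row vector @ w_GL
    E_local_embed: (V, D_L) local byte embedding table
    E_local_pad:   (D_L,) local padding embedding
    Returns        (K, P, D_L)
    """
    K = h_global_out.shape[0]
    D_G = h_global_out.shape[1] // P
    D_L = E_local_pad.shape[0]
    # Split each patch vector into P slices of width D_G, then project each slice to D_L.
    projected = h_global_out.reshape(K, P, D_G) @ w_GL            # (K, P, D_L)
    byte_emb = E_local_embed[x.reshape(K, P)]                      # (K, P, D_L)
    pad = np.broadcast_to(E_local_pad, (K, 1, D_L))
    # Offset by one: position p gets byte k*P + p - 1; position 0 gets the padding embedding.
    shifted = np.concatenate([pad, byte_emb[:, :-1]], axis=1)      # (K, P, D_L)
    return projected + shifted

# Toy example with hypothetical sizes.
rng = np.random.default_rng(1)
K, P, D_G, D_L, V = 2, 3, 4, 5, 256
h_global_out = rng.normal(size=(K, P * D_G))
x = rng.integers(0, V, size=K * P)
w_GL = rng.normal(size=(D_G, D_L))
E_local_embed = rng.normal(size=(V, D_L))
E_local_pad = rng.normal(size=D_L)
h_local_in = global_to_local(h_global_out, x, w_GL, E_local_embed, E_local_pad, P)
print(h_local_in.shape)  # (2, 3, 5)
```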

### Local Model
- Inputs a contextualized patch representation: $h^\text{local-in}\in\mathbb{R}^{K\times P\times D_L}$
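- Operates on a single patch $k$ to autoregressively predict the bytes within that patch:
  $$h^\text{local-out}_{k,0:P}=\text{transformer}^\text{local}\left(h^\text{local-in}_{k,0:P}\right)$$
- Computes the probability distribution over the vocabulary at each position $t=k\cdot P+p$:
  $$p(x_t|x_{0:t})=\text{softmax}\left(E^\text{local-embed}h^\text{local-out}_{k,p}\right)_{x_t}$$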
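
To illustrate that the local model treats each patch as an independent short causal sequence, the sketch below stands in for the local transformer with a single toy causal self-attention layer (no learned weights, one residual connection) and, as an assumption, computes output logits by multiplying with the local embedding table. All function names and sizes are hypothetical; this is not the paper's architecture.

```python
import numpy as np

def causal_self_attention(h):
    """Toy single-head causal attention with no learned weights.

    h: (K, P, D). The K patches act as a batch, so no information flows
    between patches, and position p attends only to positions <= p.
    """
    P, D = h.shape[1], h.shape[2]
    scores = h @ h.transpose(0, 2, 1) / np.sqrt(D)                 # (K, P, P)
    future = np.arange(P)[None, :] > np.arange(P)[:, None]          # key index > query index
    scores = np.where(future, -np.inf, scores)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ h                                              # (K, P, D)

def next_byte_probs(h_local_in, E_local_embed):
    """Distribution over V byte values at every position t = k*P + p, shape (K, P, V)."""
    h_local_out = h_local_in + causal_self_attention(h_local_in)    # stand-in for the local model
    logits = h_local_out @ E_local_embed.T                          # tied output embedding (assumed)
    logits = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(logits)
    return probs / probs.sum(axis=-1, keepdims=True)

# Toy check with hypothetical sizes: perturbing position 2 of patch 0
# should leave earlier positions and the other patch unchanged.
rng = np.random.default_rng(2)
K, P, D_L, V = 2, 4, 8, 256
h_in = rng.normal(size=(K, P, D_L))
E_local_embed = rng.normal(size=(V, D_L))
probs = next_byte_probs(h_in, E_local_embed)
h_perturbed = h_in.copy()
h_perturbed[0, 2] += 1.0
probs_perturbed = next_byte_probs(h_perturbed, E_local_embed)
```

Processing patches as a batch is what lets the local model run in parallel across patches during training; only the global model carries information between them.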

## Variations and Extensions
### Convolutional Patch Encoder
- Limitation: Chunking sequences is not translation invariant, so byte sequences may receive a different representation depending on their position in the patch.
- Mitigation: Augmenting the patch encoder with causal convolutional layers, using a stack of convolutional layers with filter sizes of 3, 5, and 7.
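
A causal convolution stack like the one described can be sketched in NumPy. This is a hedged illustration: the helper names are hypothetical, and applying the stack to the byte embeddings before chunking, with activations, normalization, and residual connections omitted, are assumptions of the sketch rather than details from the notes.

```python
import numpy as np

def causal_conv1d(h, w):
    """Causal 1-D convolution along the sequence axis.

    h: (T, D_in), w: (k, D_in, D_out). Left zero-padding by k - 1 means the
    output at position t depends only on inputs t - k + 1 .. t.
    """
    k = w.shape[0]
    T, D_in = h.shape
    padded = np.concatenate([np.zeros((k - 1, D_in)), h], axis=0)
    return np.stack([
        sum(padded[t + j] @ w[j] for j in range(k)) for t in range(T)
    ])

def conv_stack(h, filters):
    """Apply causal convolutions in sequence (activations omitted in this sketch)."""
    for w in filters:
        h = causal_conv1d(h, w)
    return h

# Toy check with hypothetical sizes and filter widths 3, 5, 7.
rng = np.random.default_rng(3)
T, D = 10, 4
h_embed = rng.normal(size=(T, D))
filters = [0.3 * rng.normal(size=(k, D, D)) for k in (3, 5, 7)]
out = conv_stack(h_embed, filters)
h_changed = h_embed.copy()
h_changed[5] += 1.0
out_changed = conv_stack(h_changed, filters)  # outputs before position 5 are unaffected
```

Stacking widths 3, 5, and 7 gives each output a receptive field of $1+2+4+6=13$ preceding bytes, regardless of where patch boundaries fall, while causality keeps the autoregressive setup intact.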

### Cross-Patch Attention
- Limitation: Local model depends on the global model for long-range information.
- Mitigation: Allowing the local model to condition on $r$ elements from the previous patch, similar to [[Transformer-XL]] but fully differentiable.
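
One way to picture cross-patch attention is as a boolean attention mask over flattened byte positions $t=k\cdot P+p$. The sketch below is illustrative only: `cross_patch_mask` is hypothetical, it assumes $r\le P$, and an actual implementation may instead attend to the previous patch's keys and values directly rather than materializing a full mask.

```python
import numpy as np

def cross_patch_mask(K, P, r):
    """Boolean local-attention mask over flattened positions t = k*P + p.

    mask[t, s] is True if position t may attend to position s: causally within
    its own patch, plus the last r positions of the previous patch.
    Assumes 0 <= r <= P.
    """
    assert 0 <= r <= P
    T = K * P
    mask = np.zeros((T, T), dtype=bool)
    for t in range(T):
        start = (t // P) * P                  # first position of t's patch
        mask[t, start:t + 1] = True           # causal attention inside the patch
        if start > 0:
            mask[t, start - r:start] = True   # last r bytes of the previous patch
    return mask

mask = cross_patch_mask(K=3, P=4, r=2)
print(np.flatnonzero(mask[4]).tolist())  # [2, 3, 4]
```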

### Strided Inference
- Limitation: Per-token loss within each patch would increase towards the end of the patch, as the prediction relies more on the weaker local model.
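- Mitigation: Using **strided inference**, the sequence is predicted with two forward passes of the full model whose inputs are offset by $P/2$ positions from each other; the predictions for the first $P/2$ positions of each patch from both passes are then combined to cover the complete sequence.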
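
The merging step of strided inference can be sketched in plain Python. This assumes, hypothetically, that each pass yields one prediction per byte position in sequence order; the name `combine_strided` is illustrative.

```python
def combine_strided(pred_a, pred_b, P):
    """Merge per-byte predictions from two passes whose patch boundaries differ by P // 2.

    pred_a[t]: prediction for byte t from the pass with boundaries at 0, P, 2P, ...
    pred_b[t]: prediction for byte t from the pass shifted by P // 2.
    Each byte is taken from the pass in which it lies in the first half of its
    patch, where the prediction relies least on the local model.
    """
    half = P // 2
    return [pred_a[t] if t % P < half else pred_b[t] for t in range(len(pred_a))]

P = 4
pred_a = [f"a{t}" for t in range(8)]
pred_b = [f"b{t}" for t in range(8)]
print(combine_strided(pred_a, pred_b, P))
# ['a0', 'a1', 'b2', 'b3', 'a4', 'a5', 'b6', 'b7']
```

The trade-off is roughly twice the inference compute in exchange for never using predictions from the second half of a patch.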

# Efficiency Analysis
## Training Efficiency
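- **Attention**: $O(\frac{T^2}{P^2}+TP)$, since the global model attends over $K=\frac{T}{P}$ patches ($O(K^2)$) and the local model attends within each of the $K$ patches ($O(K\cdot P^2)=O(TP)$)
- **Feedforward layers**: $2T(\frac{m_g}{P}+m_l)$ FLOPs, where $m_g$ and $m_l$ are the parameter counts of the global and local models, or approximately $\frac{2Tm_g}{P}$ when $m_g\gg m_l$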
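
Setting the derivative of $\frac{T^2}{P^2}+TP$ with respect to $P$ to zero gives $-\frac{2T^2}{P^3}+T=0$, i.e. $P=(2T)^{1/3}$, at which point both terms are $O(T^{4/3})$. The plain-Python sketch below checks this numerically; the function names and the 1M-byte sequence length are hypothetical, and costs are counted only up to constant factors.

```python
import math

def attention_cost(T, P):
    """Attention cost up to constants: global over K = T / P patches plus local within each patch."""
    return (T / P) ** 2 + T * P

def feedforward_flops(T, P, m_g, m_l):
    """2T(m_g / P + m_l): the global model runs once per patch, the local model once per byte."""
    return 2 * T * (m_g / P + m_l)

def best_patch_size(T):
    """Integer P minimizing attention_cost(T, P), searched over 1 .. isqrt(T)."""
    return min(range(1, math.isqrt(T) + 1), key=lambda P: attention_cost(T, P))

T = 2 ** 20                 # hypothetical 1M-byte sequence
P = best_patch_size(T)
print(P)                    # 128, matching P^3 = 2T
```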

## Generation Efficiency 
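Consider a MegaByte model with $L_\text{global}$ global layers, $L_\text{local}$ local layers, and patch size $P$, compared with a standard [[Attention Is All You Need|transformer]] with $L_\text{local}+L_\text{global}$ layers. Generating one patch of $P$ bytes requires a sequence of $O(P\cdot L_\text{local}+P \cdot L_\text{global})$ serial operations with the standard transformer, while MegaByte requires only $O(P\cdot L_\text{local}+L_\text{global})$, because the global model runs once per patch rather than once per byte.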
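
The serial-operation counts can be compared with a few lines of Python. The layer counts and patch size below are a hypothetical configuration chosen for illustration, not one reported for MegaByte.

```python
def serial_ops_transformer(P, L_global, L_local):
    """Serial layer applications for a plain transformer with L_local + L_global layers to emit P bytes."""
    return P * (L_local + L_global)

def serial_ops_megabyte(P, L_global, L_local):
    """MegaByte: the global model runs once per patch, the local model once per byte."""
    return P * L_local + L_global

# Hypothetical configuration.
P, L_global, L_local = 8, 24, 8
print(serial_ops_transformer(P, L_global, L_local))  # 256
print(serial_ops_megabyte(P, L_global, L_local))     # 88
```

The savings grow with the share of layers placed in the global model, which is why a deep global model paired with a shallow local model suits fast generation.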