# Motivation
- **Increasing parameters for fixed compute**:
    - **Sub-quadratic self-attention**: with optimal patch sizes, decomposing long sequences into two shorter sequences reduces the quadratic self-attention cost to $O(N^\frac{4}{3})$.
    - **Per-patch feedforward layers**: using large feedforward layers per patch rather than per position enables much larger and more expressive models for the same cost. Assuming a patch size of $P$, instead of using the same feedforward layer with $m$ parameters $P$ times, MegaByte uses a single layer with $mP$ parameters once for the same cost.
    - **Parallelism in decoding**: generating representations for patches in parallel. A MegaByte model with 1.5B parameters can generate sequences 40% faster than a standard 350M [[Attention Is All You Need|Transformer]].
- **End-to-end training**: tokenization complicates pre-processing, multi-modal modeling, and transfer to new domains, while hiding useful structure from the model.
- **Re-use of established components**: increasing the likelihood that the model will inherit the desirable scaling properties of transformers.
- **Model's expressiveness**: using only the global model would resemble a decoder version of [[Vision Transformer|ViT]], modeling a joint distribution over the patch, $p(x_{t+1},\dots,x_{t+P}|x_{0\ldots t})$, whose output space of size $256^P$ is only tractable for very small patches. On the other hand, factoring the joint distribution into conditionally independent distributions $p(x_{t+1}|x_{0\ldots t})\ldots p(x_{t+P}|x_{0\ldots t})$ would greatly limit the model's expressive power.
# MegaByte Transformer
## Components
![[IMG_0188.jpeg|center|400]]
### Patch Embedder
- Inputs a discrete byte sequence $x_{0\ldots T}$ with vocabulary size $V$.
- Embeds each element (byte) with a global lookup table $E^\text{global-embed}\in \mathbb{R}^{V\times D_G}$ and adds positional embeddings: $$h_t^\text{embed}=E^\text{global-embed}_{x_t}+E_t^\text{pos}\qquad t\in [0\ldots T]$$
- Chunks the sequence into patches of length $P$, creating a sequence of $K=\frac{T}{P}$ patch embeddings with dimension $P\cdot D_G$. To allow autoregressive modeling, the patch sequence is padded to start with a trainable padding embedding ($E^\text{global-pad}\in \mathbb{R}^{P\times D_G}$), and the last patch is removed from the input: $$h_k^\text{global-in}=\left\{ \begin{array}{rcl} E^\text{global-pad}\quad \ \ & \text{if } k=0 \\ h^\text{embed}_{((k-1)\cdot P):(k\cdot P)} & k\in [1,\ldots,K) \end{array}\right.$$
### Global Model
Contextualizes patch representations by performing self-attention over previous patches:
- Inputs a sequence of patches: $h_k^\text{global-in}\in\mathbb{R}^{K\times(P\cdot D_G)}$
- Performs self-attention over previous patches (decoder-only layers): $$h^\text{global-out}_{0:K}=\text{transformer}^\text{global}(h_{0:K}^\text{global-in})$$
- To prepare the output for the local model:
    - Each patch in the output sequence is reshaped into a sequence of length $P$ and dimension $D_G$, where position $p$ uses dimensions $p \cdot D_G$ to $(p+1) \cdot D_G$.
    - Each position is then projected to the dimension of the local model with a matrix $w^{GL}\in \mathbb{R}^{D_G\times D_L}$, where $D_L$ is the local model dimension.
    - The projected position is combined with input bytes embedded with a local lookup table $E^\text{local-embed}\in \mathbb{R}^{V\times D_L}$.
    - The local byte embeddings are offset by one with a trainable local padding embedding ($E^\text{local-pad} \in \mathbb{R}^{D_L}$): $$h_{k,p}^\text{local-in}=w^{GL}h^\text{global-out}_{k,(p\cdot D_G):((p+1)\cdot D_G)} + \left\{ \begin{array}{rcl} E^\text{local-pad}\quad \ \ \ & \text{if } p=0 \\ E^\text{local-embed}_{x_{(k\cdot P+p-1)}} & p\in [1,\ldots,P) \end{array}\right.$$
### Local Model
- Inputs a contextualized patch representation: $h^\text{local-in}\in\mathbb{R}^{K\times P\times D_L}$
- Operates on a single patch $k$ to autoregressively predict the next patch: $$h^\text{local-out}_{k,0:P}=\text{transformer}^\text{local}(h^\text{local-in}_{k,0:P})$$
- Computes the probability distribution over the vocabulary at each position $t=k\cdot P+p$: $$p(x_t|x_{0:t})=\text{softmax}(E^\text{local-embed}h_{k,p}^\text{local-out})_{x_t}$$
## Variations and Extensions
### Convolutional Patch Encoder
- Limitation: chunking sequences is not translation invariant, so byte sequences may receive a different representation depending on their position in the patch.
- Mitigation: augmenting the patch embedder with causal convolutional layers (a stack of convolutional layers with filter sizes of 3, 5, and 7).
### Cross-Patch Attention
- Limitation: the local model depends on the global model for long-range information.
- Mitigation: allowing the local model to condition on $r$ elements from the previous patch, similar to [[Transformer-XL]] but fully differentiable.
### Strided Inference
- Limitation: the per-token loss within each patch increases towards the end of the patch, as the prediction relies more on the weaker local model.
- Mitigation: using **strided inference**, the sequence is predicted with two forward passes of the full model, whose inputs are offset by $P/2$ positions from each other.
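The data flow through the three components can be sketched end to end in NumPy. This is a minimal illustration with toy shapes: the transformer stacks are replaced by identity stubs, and all lookup tables are random stand-ins rather than trained parameters, so only the reshaping, shifting, and projection logic is faithful to the description above.

```python
import numpy as np

rng = np.random.default_rng(0)

V, T, P = 256, 8, 4          # vocab size, sequence length, patch size (toy values)
D_G, D_L = 16, 8             # global / local model dimensions
K = T // P                   # number of patches

# Lookup tables and projections (random stand-ins for trained parameters)
E_global_embed = rng.normal(size=(V, D_G))
E_pos          = rng.normal(size=(T, D_G))
E_global_pad   = rng.normal(size=(P, D_G))
E_local_embed  = rng.normal(size=(V, D_L))
E_local_pad    = rng.normal(size=(D_L,))
w_GL           = rng.normal(size=(D_G, D_L))

def transformer_stub(h):
    """Identity stand-in; a real model applies causal self-attention layers."""
    return h

x = rng.integers(0, V, size=T)                 # input byte sequence

# --- Patch embedder ---
h_embed = E_global_embed[x] + E_pos            # (T, D_G)
patches = h_embed.reshape(K, P * D_G)          # chunk into K patches
# Shift right: prepend the trainable pad patch, drop the last patch.
h_global_in = np.vstack([E_global_pad.reshape(1, P * D_G), patches[:-1]])

# --- Global model ---
h_global_out = transformer_stub(h_global_in)   # (K, P*D_G)

# --- Prepare local inputs ---
h_proj = h_global_out.reshape(K, P, D_G) @ w_GL          # (K, P, D_L)
local_bytes = E_local_embed[x].reshape(K, P, D_L)
# Offset byte embeddings by one within each patch; pad position 0.
local_shifted = np.concatenate(
    [np.broadcast_to(E_local_pad, (K, 1, D_L)), local_bytes[:, :-1]], axis=1)
h_local_in = h_proj + local_shifted

# --- Local model and per-position output distribution ---
h_local_out = transformer_stub(h_local_in)     # (K, P, D_L)
logits = h_local_out @ E_local_embed.T         # (K, P, V)
probs = np.exp(logits - logits.max(-1, keepdims=True))
probs /= probs.sum(-1, keepdims=True)          # softmax over the byte vocabulary
```

Note the two right-shifts: the pad patch at the global level and the pad byte at the local level are what keep both models strictly autoregressive.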
# Efficiency Analysis
## Training Efficiency
- **Attention**: $O(\frac{T^2}{P^2}+TP)$
- **Feedforward layers**: $2T(\frac{m_g}{P}+m_l)$ FLOPS, or $\frac{2Tm_g}{P}$ when $m_g\gg m_l$
## Generation Efficiency
A standard [[Attention Is All You Need|transformer]] with $L_\text{local}+L_\text{global}$ layers requires $O(P\cdot L_\text{local}+P \cdot L_\text{global})$ serial operations to generate a patch of $P$ tokens, while MegaByte requires only $O(P\cdot L_\text{local}+L_\text{global})$, since the global layers run once per patch rather than once per token.
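Minimizing the attention cost over the patch size (a quick sketch, treating $P$ as continuous and ignoring constant factors) recovers the sub-quadratic figure quoted in the motivation:

$$\frac{d}{dP}\left(\frac{T^2}{P^2}+TP\right)=-\frac{2T^2}{P^3}+T=0 \quad\Longrightarrow\quad P^*=(2T)^{1/3}$$

Substituting $P^*$ back into the cost gives

$$\frac{T^2}{(P^*)^2}+TP^*=2^{-2/3}\,T^{4/3}+2^{1/3}\,T^{4/3}=O(T^{4/3}).$$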