Good folks at @nvidia have released exciting new research on normalized Transformers (nGPT) for faster and more efficient language modeling!
Here is what they are proposing:
1. Remove all normalization layers, like RMSNorm or LayerNorm, from the standard Transformer architecture.
2. Normalize all matrices along their embedding dimension after each training step. This includes input and output embeddings, attention matrices (Q, K, V), output projection matrices, and MLP matrices.
3. Replace the standard residual connections with normalized update equations using learnable eigen learning rates for the attention and MLP blocks.
4. Change the softmax scaling factor in the attention mechanism from 1/sqrt(d_k) to sqrt(d_k).
5. Implement rescaling and optional normalization of query (q) and key (k) vectors in the attention mechanism using learnable scaling factors.
6. Rescale the intermediate states of the MLP block using learnable scaling factors.
7. Implement rescaling of the output logits using learnable scaling factors.
8. Remove weight decay and learning rate warmup from the optimization process.
9. Initialize the eigen learning rates and scaling factors with appropriate values as specified in the paper.
10. During training, treat all vectors and matrices as residing on a unit hypersphere, interpreting matrix-vector multiplications as cosine similarities.
11. Implement the update equations for the hidden states using the normalized outputs from attention and MLP blocks, controlled by the eigen learning rates.
12. After each training step, re-normalize all parameter matrices (as in step 2) to ensure they remain on the unit hypersphere.
13. Use the Adam optimizer without weight decay for training the model.
14. When computing loss, apply the learnable scaling factor to the logits before the softmax operation.
15. During inference, follow the same normalization and scaling procedures as in training.
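To make steps 3 and 11 a bit more concrete, here is a minimal sketch of one way to read the normalized update: step the hidden state toward the normalized block output, then project it back onto the unit hypersphere. The names `normalize`, `normalized_update`, and `alpha` are hypothetical, and the fixed per-dimension `alpha` values only stand in for the learnable eigen learning rates; this is not the paper's reference code.

```python
import math


def normalize(v):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def normalized_update(h, block_out, alpha):
    """Step h toward the normalized block output, then re-normalize.

    alpha holds one step size per dimension (an assumed stand-in for
    the learnable eigen learning rates mentioned in the post).
    """
    h_block = normalize(block_out)  # block output placed on the unit sphere
    stepped = [hi + a * (bi - hi) for hi, bi, a in zip(h, h_block, alpha)]
    return normalize(stepped)  # project the result back onto the unit sphere


h = normalize([1.0, 2.0, 2.0])   # current hidden state (unit norm)
block_out = [0.5, -1.0, 3.0]     # raw attention or MLP output, made up for illustration
alpha = [0.05, 0.05, 0.05]       # small fixed step sizes, made up for illustration
h_new = normalized_update(h, block_out, alpha)
```

With `alpha` at 0 the hidden state stays put, and with `alpha` at 1 it jumps straight to the normalized block output, so the step size behaves like a per-dimension learning rate.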
Excited to see how it scales to larger models and datasets!
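A quick sketch of why the softmax scale flips in step 4, assuming q and k are normalized to unit length: their dot product is then a cosine similarity in [-1, 1], so it gets multiplied by sqrt(d_k) rather than divided by it. The function names and toy vectors are hypothetical, and the learnable q/k scaling factors from step 5 are left out for brevity.

```python
import math


def normalize(v):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def softmax(xs):
    """Numerically stable softmax over a list of scores."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attention_weights(q, keys):
    """Attention weights for one query over unit-norm keys.

    The dot product of unit vectors is a cosine similarity, which is
    scaled UP by sqrt(d_k) before the softmax.
    """
    d_k = len(q)
    q_unit = normalize(q)
    scores = []
    for k in keys:
        cosine = sum(a * b for a, b in zip(q_unit, normalize(k)))
        scores.append(cosine * math.sqrt(d_k))
    return softmax(scores)


q = [1.0, 0.0, 0.0, 0.0]
keys = [
    [2.0, 0.0, 0.0, 0.0],   # same direction as q (cosine 1)
    [0.0, 3.0, 0.0, 0.0],   # orthogonal to q (cosine 0)
    [-1.0, 0.0, 0.0, 0.0],  # opposite to q (cosine -1)
]
weights = attention_weights(q, keys)
```

In this toy case the key pointing the same way as the query gets the largest weight and the opposite key the smallest, regardless of the keys' original lengths.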