RegMix: Data Mixture as Regression for Language Model Pre-training

Community Article Published July 11, 2024

Still following your human intuition to mix corpora from different sources for pre-training 🧠?

Everyone says that data mixture has a big impact on model performance, but how - and why 🕵️?

Did you know that web corpora are actually highly impactful for downstream tasks 🏆?

Check out our preprint "RegMix: Data Mixture as Regression for Language Model Pre-training" 📄

🔬 In this paper, we propose RegMix, an automatic data mixture method that achieves a 6.3% improvement over human selection on the widely used HellaSwag benchmark - and it needs only about 2% extra training FLOPs! 📈

Data Mixture is Important, but Challenging

🤖📚 Large Language Models (LLMs) are powered by vast, diverse datasets from the Internet, including academic papers, books, and various online sources (Gao et al. 2020). As LLMs grow in scale and complexity, the composition of their training data becomes increasingly crucial. The importance of data mixture was recognized early on by the creators of GPT-3, one of the pioneering LLMs. They deliberately chose to upsample Wikipedia content due to its perceived high quality.

🧩 The challenge: As the volume and diversity of data used in LLM pre-training continue to expand, determining the ideal data mixture becomes increasingly complex, and manual data selection may lead to suboptimal choices.

🔬 Key research question: How can we decide on a high-performing data mixture for training LLMs in a scalable and automatic manner?

Gao et al. 2020. The Pile: An 800GB Dataset of Diverse Text for Language Modeling, https://arxiv.org/abs/2101.00027

Core idea: small to large generalization

💡 With the challenge of selecting the optimal data mixture in mind, our core idea is straightforward: train and identify the best-performing small-scale models using different data mixtures, and then directly generalize those findings to large-scale model training.

RegMix: Data Mixture as Regression

Concretely, our method RegMix treats data mixture selection as a regression task. Here's how it works:

  1. Train some small-scale proxy models on various data mixtures, each on a small number of tokens 🐣
  2. Fit a regression model using these results 📈
  3. Use the regression model to predict the best mixture for large-scale training 🔮
  4. Train the large-scale model on this optimized mixture 🚀

Training the small-scale proxy models requires only ~2% of the computational cost (in FLOPs) of the final large-scale model training.

To visualize the procedure, we provide a concrete example using Hacker News, GitHub, and PhilPapers as the training domains. The validation loss on StackExchange is used as the target metric to optimize during the proxy model training phase.
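
The four steps can also be sketched in code for this three-domain setup. This is a minimal illustration under stated assumptions, not the RegMix implementation: `train_proxy_and_eval` is a hypothetical stand-in that fakes a StackExchange validation loss from a made-up linear rule instead of training a model, and Dirichlet sampling plus plain least squares are choices made here for brevity.

```python
import numpy as np

rng = np.random.default_rng(0)
domains = ["HackerNews", "GitHub", "PhilPapers"]

def train_proxy_and_eval(mixture):
    """Hypothetical stand-in for training a small proxy model on `mixture`
    and returning its validation loss. A made-up linear rule plus noise
    is used so the sketch runs instantly."""
    hidden = np.array([-0.8, -0.3, 0.1])  # invented per-domain effects
    return 3.0 + mixture @ hidden + rng.normal(scale=0.01)

# Step 1: sample mixtures (weights sum to 1) and "train" proxy models.
train_mixtures = rng.dirichlet(np.ones(len(domains)), size=64)
losses = np.array([train_proxy_and_eval(m) for m in train_mixtures])

# Step 2: fit a regression from mixture weights to validation loss.
# Weights sum to 1, so the intercept is absorbed into the coefficients.
coef, *_ = np.linalg.lstsq(train_mixtures, losses, rcond=None)

# Step 3: simulate many candidate mixtures and predict their loss.
candidates = rng.dirichlet(np.ones(len(domains)), size=100_000)
predicted = candidates @ coef
best = candidates[np.argmin(predicted)]

# Step 4: `best` is the mixture that would be used for large-scale training.
print({d: round(float(w), 3) for d, w in zip(domains, best)})
```

In this toy setup the regression should push almost all weight onto the domain with the most negative invented effect; with real proxy models the relationship need not be linear, and a stronger regressor may be a better fit.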

Regression Works Well Across Model Scales

๐ŸŽ๏ธ What's particularly exciting about RegMix is its efficiency. It allows you to explore a vast space of potential mixtures (even with 40+ domains) by training only a small number of models.

Specifically, training 1M-parameter models on 1B tokens each is enough to predict the performance of 256 1M-parameter models trained on unseen data mixtures, with 98.45% correlation.

Moreover, RegMix can automatically identify the best-performing data mixture among 64 1B-parameter models, each trained on 25B tokens, before actually training them 💡💰.
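
A correlation figure like the one above can be checked by holding out some mixtures, predicting their loss with the fitted regression, and comparing predicted and observed rankings. The sketch below does this on synthetic data with Spearman rank correlation; the data-generating function, the choice of rank correlation, and the least-squares fit are all assumptions for illustration, and the number it prints says nothing about the paper's results.

```python
import numpy as np

def spearman(x, y):
    """Spearman rank correlation (assumes no ties, fine for continuous data)."""
    rank_x = np.argsort(np.argsort(x))
    rank_y = np.argsort(np.argsort(y))
    return float(np.corrcoef(rank_x, rank_y)[0, 1])

rng = np.random.default_rng(1)
n_domains = 5
true_effect = rng.normal(size=n_domains)  # synthetic, not real data

def observed_loss(mixtures):
    # Synthetic "ground truth": linear in the mixture plus small noise.
    noise = rng.normal(scale=0.02, size=len(mixtures))
    return mixtures @ true_effect + noise

# Mixtures used to fit the regression, and unseen mixtures to evaluate it.
fit_x = rng.dirichlet(np.ones(n_domains), size=128)
heldout_x = rng.dirichlet(np.ones(n_domains), size=64)
fit_y, heldout_y = observed_loss(fit_x), observed_loss(heldout_x)

coef, *_ = np.linalg.lstsq(fit_x, fit_y, rcond=None)
rho = spearman(heldout_x @ coef, heldout_y)
print(f"rank correlation on unseen mixtures: {rho:.3f}")
```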

Insight 1: Data mixture significantly impacts downstream performance

We experiment with 64 models, each with 1B parameters and trained on a different data mixture, and evaluate their performance across various benchmarks. The results show that data mixture significantly impacts downstream performance - up to 14.6% difference on some tasks! 😮

Insight 2: Web corpora benefit downstream performance the most

Web corpora like CommonCrawl 🌐 surprisingly show the strongest positive correlation with downstream performance for language models, even more than curated sources like Wikipedia! 📚 This pattern holds across most web domains, suggesting the diversity of CommonCrawl drives today's LM success. 🚀

Moreover, whether it's gaming sites like IGN 🎮 or YouTube 📺, they exhibit similar patterns. But http://patents.google.com 📄 and http://springer.com 📗 seem to follow different trends.

Insight 3: Domain interactions are challenging for humans to understand

Domain interactions are complex and often counterintuitive, highlighting the need for automated approaches like RegMix. 🧩

For example, the PhilPapers domain appears to provide gains for all other domains under linear regression modeling, which challenges intuitive human understanding. 🤯📚 So, what is PhilPapers? It is a database for philosophy …
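
One way to read such interactions from a linear model, sketched here on synthetic data, is to fit one regression per validation domain and inspect the coefficient matrix: each entry estimates how shifting weight toward a training domain moves the loss on a validation domain. The domain names are reused purely as labels and the effects are invented; centring the coefficients is a choice made here because mixture weights sum to 1, so only differences between coefficients carry meaning.

```python
import numpy as np

rng = np.random.default_rng(2)
train_domains = ["HackerNews", "GitHub", "PhilPapers"]
valid_domains = ["HackerNews", "GitHub", "PhilPapers", "StackExchange"]

# Invented ground truth: rows = validation domains, cols = training domains.
true_effect = rng.normal(size=(len(valid_domains), len(train_domains)))

mixtures = rng.dirichlet(np.ones(len(train_domains)), size=200)
noise = rng.normal(scale=0.01, size=(200, len(valid_domains)))
losses = mixtures @ true_effect.T + noise  # one loss column per validation domain

# One least-squares fit per validation domain, solved jointly.
coef, *_ = np.linalg.lstsq(mixtures, losses, rcond=None)  # shape (train, valid)

# Centre each validation domain's coefficients around their mean, so a
# negative value means "lowers this loss more than the average domain".
interaction = (coef - coef.mean(axis=0)).T  # shape (valid, train)

for name, row in zip(valid_domains, interaction):
    print(name, {d: round(float(v), 2) for d, v in zip(train_domains, row)})
```

A training domain whose column is negative in every row would be one that appears to help all validation domains, which is the kind of pattern described for PhilPapers.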

RegMix considers token availability

🔑 Previous data mixture methods struggle to balance token availability and usefulness. RegMix, however, can easily control token availability by constraining the simulation space - especially considering the 4-epoch practice from Muennighoff et al. 2023.

🔬 For example, if HackerNews amounts to 3% of your expected training tokens and you can afford to repeat it for 4 epochs, you can set its maximum weight to 12% (3% x 4) in the simulation.

Muennighoff et al. 2023. Scaling Data-Constrained Language Models, https://arxiv.org/abs/2305.16264
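
The HackerNews arithmetic generalizes to a per-domain cap that can be used to filter simulated mixtures. The helper names below (`max_mixture_weight`, `respects_caps`) are hypothetical and only illustrate the idea; token counts are expressed in arbitrary units.

```python
def max_mixture_weight(available_tokens, target_tokens, max_epochs=4):
    """Largest share of the training budget a domain can fill if it is
    repeated at most `max_epochs` times."""
    return min(1.0, max_epochs * available_tokens / target_tokens)

def respects_caps(mixture, caps):
    """True if every domain weight is at or below its cap (default cap: 1.0)."""
    return all(weight <= caps.get(domain, 1.0) for domain, weight in mixture.items())

# HackerNews holds 3% of the planned token budget; up to 4 epochs are allowed.
cap = max_mixture_weight(available_tokens=3, target_tokens=100)
print(cap)  # 0.12

caps = {"HackerNews": cap}
print(respects_caps({"HackerNews": 0.10, "GitHub": 0.90}, caps))  # True
print(respects_caps({"HackerNews": 0.20, "GitHub": 0.80}, caps))  # False
```

Candidate mixtures that fail the check can simply be dropped from the simulation before the regression picks the best one.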

RegMix has already been applied to a 14B model

🔬 While the experiments in our current paper were limited to models of up to 1B parameters due to computational constraints, we successfully applied the same data mixture approach in our Sailor paper (Dou et al. 2024).

🚀 Notably, we discovered that the optimal data mixing strategy identified using a 0.5B proxy model demonstrated impressive scalability, performing effectively across models of up to 14B parameters! 💪

Dou et al. 2024. Sailor: Open Language Models for South-East Asia, https://arxiv.org/abs/2404.03608 (also available at https://huggingface.co/papers/2404.03608)

Try RegMix on your dataset

We also provide instructions for applying the RegMix method to your own dataset. Please try it and leave comments here!
