[Paper Express] Data Selection for Language Models via Importance Resampling (DSIR)
README. Data Selection (DS) aims to select a given number of samples from a large, unlabeled dataset for training a capable model in a target domain. For training language models, practical DS methods need to select efficiently from raw text corpora containing trillions of tokens. This paper, Data Selection for Language Models via Importance Resampling, proposes such a DS technique: it uses importance resampling to obtain samples that follow the target distribution. I will give you a working knowledge of their methods and results in 3 minutes.
Data Selection with Importance Resampling (DSIR)
The problem is formulated as selecting \(K\) training examples from a raw dataset of \(N \gg K\) samples drawn from distribution \(q\), such that the selection matches a desired target distribution \(p\), given only unlabeled target samples. The DSIR framework consists of three steps:
- Learn proxy distributions on “Bag of hashed n-grams” features: This consists of two steps. (1) Featurize a subset of the raw data and the entire target dataset. Each featurized example is a 10,000-dimensional vector; each dimension holds the count of n-grams (unigrams and bigrams) that hash into the bucket associated with that dimension. (2) Fit a bag-of-n-grams (BoN) model on the featurized data. The BoN model has 10,000 parameters, one per hash bucket; the maximum-likelihood estimate of each parameter is the count of n-grams hashed into that bucket over the dataset, normalized to sum to one. Two BoN models, \(\hat p_{feat}\) and \(\hat q_{feat}\), are fitted on the featurized target and raw datasets, respectively (see the first sketch after this list).
- Compute importance weights: Use the learned BoN models to compute the importance weight of each raw text chunk \(z_i\): \(w_i = \frac{\hat p_{feat}(z_i)}{\hat q_{feat}(z_i)}\). Text chunks are 128 space-delimited words that pass a set of manually specified quality filters (e.g., word length, repeat ratio; see Appendix J of the paper).
- Sample: Sample \(K\) examples without replacement from the categorical distribution with probabilities \(\frac{w_i}{\sum_{j=1}^N w_j}\). This is implemented efficiently with the Gumbel top-\(k\) trick (see the second sketch after this list).
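The sketch below, in Python, illustrates the first two steps: hashed n-gram featurization, fitting the BoN models, and computing log importance weights. Function names, the use of Python's built-in `hash`, and the smoothing constant are my assumptions for illustration, not the authors' implementation.

```python
import numpy as np

NUM_BUCKETS = 10_000  # dimensionality of the hashed feature space


def featurize(text: str, num_buckets: int = NUM_BUCKETS) -> np.ndarray:
    """Count unigrams and bigrams of a text chunk, hashing each n-gram into a bucket."""
    words = text.split()
    ngrams = words + [" ".join(pair) for pair in zip(words, words[1:])]
    counts = np.zeros(num_buckets)
    for ngram in ngrams:
        counts[hash(ngram) % num_buckets] += 1
    return counts


def fit_bag_of_ngrams(texts) -> np.ndarray:
    """MLE of a multinomial over hash buckets: aggregate counts, normalized to sum to one."""
    total = np.zeros(NUM_BUCKETS)
    for text in texts:
        total += featurize(text)
    total += 1e-8  # smoothing so empty buckets do not zero out likelihoods
    return total / total.sum()


def log_importance_weight(text: str, p_feat: np.ndarray, q_feat: np.ndarray) -> float:
    """log w_i = log p_feat(z_i) - log q_feat(z_i) under the two BoN models."""
    counts = featurize(text)
    return float(counts @ (np.log(p_feat) - np.log(q_feat)))
```

Working in log space keeps the product over 10,000 buckets numerically stable; the multinomial normalization constant cancels in the ratio, so only the per-bucket probabilities matter.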
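And a sketch of the resampling step via the Gumbel top-\(k\) trick, under the same caveats; `log_weights` is assumed to be the array of per-chunk log importance weights produced above.

```python
import numpy as np


def gumbel_top_k(log_weights: np.ndarray, k: int, seed: int = 0) -> np.ndarray:
    """Sample k indices without replacement, with probabilities proportional to exp(log_weights)."""
    rng = np.random.default_rng(seed)
    # Perturb each log-weight with i.i.d. Gumbel(0, 1) noise; taking the top-k
    # perturbed keys is equivalent to sequentially sampling without replacement
    # from the categorical distribution softmax(log_weights).
    keys = log_weights + rng.gumbel(size=log_weights.shape)
    return np.argsort(keys)[-k:][::-1]


# Hypothetical usage: select K=2 of 5 chunks.
if __name__ == "__main__":
    selected = gumbel_top_k(np.array([-1.0, 0.5, 2.0, -0.3, 1.1]), k=2)
    print(selected)  # indices of the selected chunks
```

The trick needs only one pass over the weights and a single top-\(k\) selection, which is what makes resampling from billions of chunks cheap.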
Domain-Specific Continued Pretraining Experiments
DSIR is used to select data from the 825GB Pile dataset for 8 target domains (Computer Science papers, Biomedicine, News, etc.) to continue pretraining from a RoBERTa checkpoint. The baselines are random selection, discriminative selection using scores from binary classifiers (heuristic classification), and manual curation. The amount of training is equal across data selection methods, and F1 scores on downstream datasets serve as the performance metric.
Performance-wise, DSIR on average improves over random selection by 1.2%, over heuristic classification by 0.9%, and over manually curated data by 0.3%.
Compute-wise, DSIR takes 4.5 hours on a single machine with 96 CPU cores to select 100M documents from the Pile; most of that time (4.36 hours) is spent computing importance weights.
General-Domain Training Experiments
Setting Wikipedia and books as the target distribution, DSIR can be used to select data for training general-purpose language models. A qualitative study shows that DSIR selects more formal text than heuristic classification (as used for GPT-3) and random selection; the latter two frequently select code snippets. DSIR improves average GLUE performance over the baselines by ~2%, achieving 82.3% average GLUE accuracy.
Concluding Thoughts
From my reading of the paper, DSIR is a practical data selection technique that yields consistent gains when the downstream task is text classification. This is plausible, as classification often relies heavily on token (n-gram) identities. It remains to be shown whether the simple bag-of-n-grams model captures the information essential for generative language models. It may also be interesting to experiment with a vision equivalent of bag-of-n-grams for vision or vision-language tasks.