3-minute Pitch: Learning from MBR decoding using Direct Preference Optimization
Minimum Bayes Risk (MBR) decoding generally outperforms temperature sampling and beam search, but it is computationally expensive because every sampled candidate is scored against every other one. We can instead fine-tune the model on MBR decoding outputs so that cheaper decoding methods perform on par with MBR; Google calls this "MBR fine-tuning". Our recent work introduces a more efficient method based on Direct Preference Optimization, and we evaluate it on machine translation.
MBR Decoding Refresher
MBR decoding first draws \(|H|\) candidates \(c_1, c_2, \dots, c_{|H|}\) by temperature sampling. Each candidate \(c_i\) then gets a risk \(r_i = \sum_{j\neq i} L(c_i, c_j)\), where the loss function \(L(x, y)\) measures how poor \(x\) is when \(y\) is used as the reference. The candidate with the minimal risk is the final MBR decoding output. Intuitively, every candidate gets "risk evaluated" by every other candidate, and the candidate with the least total risk wins.
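A minimal Python sketch of the risk computation above. The loss `unigram_loss` (1 minus unigram F1) is a hypothetical stand-in for the learned metrics used in practice, and `mbr_rank` is an illustrative name, not code from the paper.

```python
from typing import Callable, List, Tuple


def unigram_loss(hyp: str, ref: str) -> float:
    """Hypothetical toy loss L(hyp, ref) = 1 - unigram F1 (a stand-in for a real metric)."""
    h, r = hyp.split(), ref.split()
    if not h or not r:
        return 1.0
    overlap = sum(min(h.count(w), r.count(w)) for w in set(h))
    if overlap == 0:
        return 1.0
    precision, recall = overlap / len(h), overlap / len(r)
    return 1.0 - 2 * precision * recall / (precision + recall)


def mbr_rank(candidates: List[str],
             loss: Callable[[str, str], float]) -> List[Tuple[str, float]]:
    """Return (candidate, risk) pairs sorted from lowest to highest risk."""
    scored = []
    for i, c_i in enumerate(candidates):
        # r_i = sum_{j != i} L(c_i, c_j): every other candidate acts as a pseudo-reference.
        r_i = sum(loss(c_i, c_j) for j, c_j in enumerate(candidates) if j != i)
        scored.append((c_i, r_i))
    return sorted(scored, key=lambda pair: pair[1])


candidates = ["the cat sat on the mat", "a cat sat on the mat",
              "the cat is on the mat", "dogs run fast"]
ranking = mbr_rank(candidates, unigram_loss)
print(ranking[0][0])  # the cat sat on the mat  (the MBR output)
```

Note that the full ranking, not just `ranking[0]`, is what the preference-based approach below makes use of.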
Previous MBR learning approaches train the model only on the winning candidate. We hypothesize that if the model learns to rank candidates the way MBR does, it will learn to do MBR itself. To learn this ranking, we use Direct Preference Optimization (DPO).
Direct Preference Optimization (DPO)
DPO learns from a preference dataset \(\mathcal{D}=\{(x, y_w, y_l)\}\) with an ordinary fine-tuning loop, without training a separate reward model. \(x\) is the input, \(y_w\) is the preferred response, and \(y_l\) is the dispreferred response. We fine-tune the model with the DPO loss below. \(\pi_{\theta}(y|x)\) is the likelihood of \(y\) given \(x\) under the model being trained, and \(\pi_{ref}(y|x)\) is the likelihood of \(y\) given \(x\) under the frozen reference model. Training is initialized with \(\pi_{\theta} = \pi_{ref}\).
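\[\mathcal{L}_{DPO}=-\mathbb{E}_{(x, y_w, y_l)\sim\mathcal{D}}\left[\log\sigma\left(\beta\log\frac{\pi_{\theta}(y_w|x)}{\pi_{ref}(y_w|x)} - \beta\log\frac{\pi_{\theta}(y_l|x)}{\pi_{ref}(y_l|x)}\right)\right]\]
Here \(\sigma\) is the logistic sigmoid, and \(\beta\) controls how strongly \(\pi_{\theta}\) is regularized toward \(\pi_{ref}\).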
The term \( r(x,y) = \beta\log\frac{\pi_{\theta}(y|x)}{\pi_{ref}(y|x)}\) is known as the language model's estimated (implicit) reward. Minimizing the DPO loss increases the estimated reward margin \(r(x, y_w) - r(x, y_l)\) over the preference dataset.
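To make the loss concrete, here is a small Python sketch that evaluates it from sequence log-likelihoods (the sum of token log-probabilities of each response). It assumes those four log-likelihoods have already been computed by the trained and reference models; `dpo_loss` and `log_sigmoid` are illustrative helper names, not a library API.

```python
import math
from typing import Sequence, Tuple


def log_sigmoid(z: float) -> float:
    """Numerically stable log(sigmoid(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def dpo_loss(batch: Sequence[Tuple[float, float, float, float]], beta: float) -> float:
    """Mean DPO loss over (logp_w, logp_l, ref_logp_w, ref_logp_l) tuples.

    Each entry is a sequence log-likelihood log pi(y|x) under the trained
    model (logp_*) or the frozen reference model (ref_logp_*).
    """
    total = 0.0
    for logp_w, logp_l, ref_logp_w, ref_logp_l in batch:
        reward_w = beta * (logp_w - ref_logp_w)  # r(x, y_w)
        reward_l = beta * (logp_l - ref_logp_l)  # r(x, y_l)
        total += -log_sigmoid(reward_w - reward_l)
    return total / len(batch)


# At initialization pi_theta == pi_ref, so the margin is 0 and the loss is log 2.
print(dpo_loss([(-10.0, -12.0, -10.0, -12.0)], beta=0.7))  # ≈ 0.693
```

Raising the trained model's likelihood of \(y_w\) relative to the reference (e.g. `logp_w = -9.0`) widens the margin and lowers the loss below log 2.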
Building Preference Dataset from MBR outputs
We cast the MBR ranking into a preference dataset to employ DPO. We investigated several strategies:
- BW. Select the best and worst MBR candidates, forming 1 pair per translation.
- BMW. Add the median candidate, forming 2 pairs per translation.
- CP. Use all consecutive pairs in the ranking, forming \(|H|-1\) pairs per translation.
- CPS. Introduce a stride into the CP strategy. With stride 2, each pair consists of candidates whose rankings differ by 2.
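The four pairing strategies can be sketched in Python as follows, with `ranked[0]` the MBR winner and `ranked[-1]` the highest-risk candidate. Two details are assumptions not fixed by the description above: the BMW "median" is taken as `ranked[n // 2]` and paired as (best, median) and (median, worst), and CPS uses every pair of ranks `i` and `i + stride`. `build_pairs` is a hypothetical helper name.

```python
from typing import List, Tuple

Pair = Tuple[str, str]  # (preferred y_w, dispreferred y_l)


def build_pairs(ranked: List[str], strategy: str, stride: int = 2) -> List[Pair]:
    """Turn an MBR ranking (lowest risk first) into DPO preference pairs."""
    n = len(ranked)
    if strategy == "BW":
        # Best vs. worst: 1 pair per translation.
        return [(ranked[0], ranked[-1])]
    if strategy == "BMW":
        # Assumption: median = middle index; pairs are (best, median), (median, worst).
        mid = ranked[n // 2]
        return [(ranked[0], mid), (mid, ranked[-1])]
    if strategy == "CP":
        # All consecutive pairs: |H| - 1 pairs per translation.
        return [(ranked[i], ranked[i + 1]) for i in range(n - 1)]
    if strategy == "CPS":
        # Assumption: all pairs whose ranks differ by exactly `stride`.
        return [(ranked[i], ranked[i + stride]) for i in range(n - stride)]
    raise ValueError(f"unknown strategy: {strategy}")


ranked = list("abcdefgh")  # |H| = 8 candidates, already MBR-ranked
for s in ("BW", "BMW", "CP", "CPS"):
    print(s, len(build_pairs(ranked, s)))  # BW 1, BMW 2, CP 7, CPS 6
```

The pair counts line up with the descriptions above: BW gives 1 pair, BMW 2, and CP \(|H|-1\).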
To summarize, our method 1. runs MBR decoding with the original model to obtain translation candidates and their rankings; 2. builds a preference dataset from these MBR-ranked candidates; 3. trains the model on the preference dataset with the DPO loss. We name our method DPO-MBR.
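Steps 1 and 2 of this recipe can be tied together in a short Python sketch that produces training records. The sampler and loss are hypothetical callables (in practice, temperature sampling from the model and a translation metric), BW pairing is used only for brevity, and the prompt/chosen/rejected record layout is an assumed convention common among DPO trainers. Step 3, the DPO training itself, is not shown.

```python
from typing import Callable, Dict, List


def build_dpo_mbr_dataset(
    sources: List[str],
    sample_fn: Callable[[str, int], List[str]],  # hypothetical: |H| temperature samples for x
    loss_fn: Callable[[str, str], float],        # hypothetical: MBR loss L(candidate, pseudo-ref)
    num_candidates: int = 8,
) -> List[Dict[str, str]]:
    records = []
    for x in sources:
        # Step 1: sample candidates from the original model and MBR-rank them.
        cands = sample_fn(x, num_candidates)
        risks = [
            sum(loss_fn(c_i, c_j) for j, c_j in enumerate(cands) if j != i)
            for i, c_i in enumerate(cands)
        ]
        order = sorted(range(len(cands)), key=lambda i: risks[i])  # lowest risk first
        ranked = [cands[i] for i in order]
        # Step 2: turn the ranking into preference pairs (BW here; other strategies plug in).
        pairs = [(ranked[0], ranked[-1])]
        records.extend({"prompt": x, "chosen": y_w, "rejected": y_l} for y_w, y_l in pairs)
    # Step 3 (not shown): train on `records` with the DPO loss, starting from pi_theta = pi_ref.
    return records


# Toy usage with a fixed "sampler" and an exact-match loss (both hypothetical).
def toy_sampler(x: str, k: int) -> List[str]:
    return ["A B", "A B", "A C", "D"][:k]


def exact_match_loss(a: str, b: str) -> float:
    return 0.0 if a == b else 1.0


dataset = build_dpo_mbr_dataset(["src"], toy_sampler, exact_match_loss, num_candidates=4)
print(dataset)  # [{'prompt': 'src', 'chosen': 'A B', 'rejected': 'D'}]
```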
Experiments and Results
We experiment with the BLOOMZ and BLOOMZ-mt models on WMT21 and WMT22 zh-en in both directions. We train with DPO-MBR on the previous year's validation set and test on the current year's test set. Key findings are:
- After DPO-MBR training, beam search gains +3 BLEURT. This is the same gain that MBR decoding gives before DPO-MBR training.
- The BMW strategy works best overall.
- Heavy regularization (\(\beta \geq 0.7\)) is needed for DPO training.
- Learning from \(|H|=8\) MBR candidates is sufficient.
Conclusion
We introduced DPO-MBR, our method for learning MBR decoding with Direct Preference Optimization. DPO-MBR is more data-efficient because it uses the full ranking from MBR decoding rather than only the winning candidate. We show that DPO-MBR is effective for BLOOMZ and BLOOMZ-mt in NMT. The full paper is here. If you find this interesting, let me know in the comments or shoot me an email!
The first author, Guangyu Yang, was my supervisee during the 2022-23 Cambridge MLMI MPhil program.