A case study in reproducibility of evaluation with RewardBench
A very technical deep dive. Lessons in evaluating, using, and training reward models on open-source infrastructure.
Recently, we added an inference command line interface (CLI) to RewardBench that made it easy for people to compute agreement of a reward model with any preference dataset. The workflow is as simple as:
pip install rewardbench
rewardbench --model={yourmodel} --dataset={yourdataset} --batch_size=8
The goal is to provide a simple way to get more people evaluating their reward models in the oldest way out there: preference agreement. Preference agreement is the percentage of the time that a reward model agrees with the chosen and rejected labels associated with two completions to a prompt, i.e. it scores the chosen completion higher than the rejected one.
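In code, preference agreement is just a win rate over pairs. A minimal sketch, assuming you already have one scalar reward per completion; `preference_agreement` is a hypothetical helper for illustration, not a function exported by the rewardbench package, and counting ties as disagreements is our own assumption.

```python
from typing import Sequence


def preference_agreement(chosen: Sequence[float], rejected: Sequence[float]) -> float:
    """Fraction of pairs where the reward model scores the chosen completion higher.

    Ties count as disagreement (an assumption of this sketch).
    """
    if len(chosen) != len(rejected):
        raise ValueError("chosen and rejected must have the same length")
    if not chosen:
        raise ValueError("need at least one pair")
    wins = sum(1 for c, r in zip(chosen, rejected) if c > r)
    return wins / len(chosen)


# Three prompts: the model prefers the chosen completion on the first and third.
print(preference_agreement([1.2, 0.3, 2.0], [0.5, 0.9, -1.0]))
```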
We had no idea if people were really using this. A few days ago, we were greeted with a scary issue about results varying with the RewardBench CLI on the UltraFeedback dataset.
For example, running
rewardbench --model=PKU-Alignment/beaver-7b-v1.0-cost --dataset=allenai/ultrafeedback_binarized_cleaned --split=test_gen --chat_template=raw --save_all --batch_size=1
rewardbench --model=PKU-Alignment/beaver-7b-v1.0-cost --dataset=allenai/ultrafeedback_binarized_cleaned --split=test_gen --chat_template=raw --save_all --batch_size=2
rewardbench --model=PKU-Alignment/beaver-7b-v1.0-cost --dataset=allenai/ultrafeedback_binarized_cleaned --split=test_gen --chat_template=raw --save_all --batch_size=3
results in these different scores, one per batch size:
Batch size 1: 0.429
Batch size 2: 0.476
Batch size 3: 0.480
This made it seem like something important was broken, so it was sort of a “drop everything to fix it” kind of day. We’ve done a bunch of testing on this and found that the variance is pretty much fundamental to the open reward models out there, i.e. it arises within model.forward(**) rather than from our code. It affects most models, but there are a lot of reasons why (and they’re mostly out of our control).
In summary, we found that:
DeBERTa’s batched inference on HuggingFace is effectively not implemented (identical inputs in one batch get different scores), but it doesn’t throw a warning or error.
Truncation within tokenizers, especially when EOS tokens need to be added for classification, can cause problems.
All models have a fundamental amount of variance (and are arguably a little broken in Transformers).
Most popular reward models are implemented wrong, in a way that makes results unclear.
Wrong chat templates induce further variance.
Bigger batch sizes may give slightly less variance in inference (if you know CUDA, please tell me if this is right).
DeBERTa issues
Looking at our simple pipeline for minimally running inference on raw text, we can check whether the underlying model is deterministic when batched.
Let’s build a batch of five copies of the same sample going into the tokenizer:
samples_2 = [samples[0]] * 5  # five identical copies of the first sample
Then, tokenize the samples (which will be identical):
inputs = self.tokenizer(
samples_2,
truncation=truncation,
max_length=max_length,
padding=padding,
# return_special_tokens_mask=True,
return_tensors="pt",
).to("cuda")
> {'input_ids': tensor([[ 1, 2569, 4873, ..., 4873, 1504, 2],
[ 1, 2569, 4873, ..., 4873, 1504, 2],
[ 1, 2569, 4873, ..., 4873, 1504, 2],
[ 1, 2569, 4873, ..., 4873, 1504, 2],
[ 1, 2569, 4873, ..., 4873, 1504, 2]], device='cuda:0'), 'token_type_ids': tensor([[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1]], device='cuda:0')}
And run inference. The numbers should be identical:
self.model(**inputs)
> SequenceClassifierOutput(loss={'logits': tensor([[-1.9756],
[-2.7129],
[-2.5938],
[-2.7422],
[-2.5684]], device='cuda:0', dtype=torch.float16,
grad_fn=<AddmmBackward0>)}, logits=tensor([[-1.9756],
[-2.7129],
[-2.5938],
[-2.7422],
[-2.5684]], device='cuda:0', dtype=torch.float16,
grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)
Ultimately, the batching within the DeBERTa model’s forward pass is wrong: five identical inputs should get five identical scores. For comparison, here is what a normal model (a GPTNeoX reward model) outputs when passed 5 repeated inputs directly:
GPTNeoXRewardModelOutput(logits=tensor([[-1.1133],
[-1.1133],
[-1.1133],
[-1.1133],
[-1.1133]], device='cuda:0', dtype=torch.float16,
grad_fn=<AddmmBackward0>))
We didn’t test if this goes beyond AutoModelForSequenceClassification, but my colleagues seem to think it does.
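The check above generalizes into a small harness: score every sample alone, then again in batches, and report the worst deviation. This is a sketch under stated assumptions: `score_fn` is a hypothetical callable mapping a list of texts to one reward each (in practice it would wrap the tokenizer call and `self.model(**inputs)` shown above), and `spread` and `batch_invariance_report` are names we made up. Applied to the logits printed above, the spread makes the DeBERTa problem obvious.

```python
from typing import Callable, Dict, List, Sequence

ScoreFn = Callable[[List[str]], List[float]]  # texts in, one reward per text out


def spread(scores: Sequence[float]) -> float:
    """Max minus min of a batch of scores; 0.0 means every score agrees exactly."""
    return max(scores) - min(scores)


def batch_invariance_report(score_fn: ScoreFn, samples: List[str],
                            batch_sizes: Sequence[int]) -> Dict[int, float]:
    """For each batch size, the largest absolute deviation from batch-size-1 scores."""
    reference = [score_fn([s])[0] for s in samples]
    report: Dict[int, float] = {}
    for bs in batch_sizes:
        batched: List[float] = []
        for start in range(0, len(samples), bs):
            batched.extend(score_fn(samples[start:start + bs]))
        report[bs] = max(abs(a - b) for a, b in zip(reference, batched))
    return report


# Logits copied from the two outputs above (five identical inputs each).
deberta_logits = [-1.9756, -2.7129, -2.5938, -2.7422, -2.5684]
gptneox_logits = [-1.1133] * 5
print(round(spread(deberta_logits), 4))  # 0.7666
print(spread(gptneox_logits))            # 0.0
```

A batch-invariant model should report 0.0, or something on the order of fp16 rounding, for every batch size in `batch_invariance_report`.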
Potential unrelated issue with truncation
In testing this, we looked at the raw token outputs of many tokenizer configurations to make sure they were as expected. One case we found was the tokenizer incorrectly truncating off the last tokens of the sequence, which means reward models implemented with HuggingFace won’t know where to predict the reward. In HuggingFace’s AutoModelForSequenceClassification, the position where the reward is read off is located from the first padding token in each sequence (and the padding token can be tied to the EOS token). In this case, with the reward model RLHFlow/RewardModel-Mistral-7B-for-DPA-v1, we found that the last tokens coming out of the tokenizer with truncation=True were incorrect.
tensor([[ 2, 2, 2, ..., 873, 28723, 2],
[ 2, 2, 2, ..., 23798, 28723, 2],
[ 1, 1, 733, ..., 21824, 5020, 970],
[ 1, 1, 733, ..., 272, 10725, 473],
[ 2, 2, 2, ..., 22447, 28723, 2]], device='cuda:0')
For context:
2 is the padding token ID (it is also the EOS token ID). Every sequence should have one at its end so the model can then predict the reward.
1 is the BOS token ID, which every sequence should start with. Here, the third and fourth rows are truncated from the right, and the EOS token is never added back, even though we set
tokenizer.add_eos_token = True
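One way to guarantee the EOS token survives is to truncate only the content and append EOS afterwards, instead of right-truncating a sequence that already ends in EOS. A minimal sketch on plain token-ID lists; `build_rm_input` and `naive_truncate` are hypothetical helpers, the IDs follow the convention above (BOS = 1, EOS = 2), and this is not what RewardBench does (its change, described later, was to turn truncation off).

```python
from typing import List

BOS_ID = 1  # matches the Mistral tokenizer example above
EOS_ID = 2


def naive_truncate(ids_with_specials: List[int], max_length: int) -> List[int]:
    """Right truncation applied after EOS was appended: EOS is the first token lost."""
    return ids_with_specials[:max_length]


def build_rm_input(content_ids: List[int], max_length: int) -> List[int]:
    """Wrap content in BOS/EOS, truncating only the content so EOS always survives."""
    budget = max_length - 2  # reserve one slot each for BOS and EOS
    if budget < 0:
        raise ValueError("max_length must leave room for BOS and EOS")
    return [BOS_ID] + content_ids[:budget] + [EOS_ID]


content = [10, 11, 12, 13, 14, 15]  # toy content token IDs
print(naive_truncate([BOS_ID] + content + [EOS_ID], max_length=5))  # [1, 10, 11, 12, 13]
print(build_rm_input(content, max_length=5))                        # [1, 10, 11, 12, 2]
```

The trade-off is that the end of a long completion is still cut off, which changes what is being scored; that is one reason to avoid truncation entirely when the model's context allows it.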
Every model is “a little broken”
If you go down the rabbit hole of comparing LLM outputs on different devices and with minor changes to the configuration, you learn that it’s very hard to get exactly the same outputs. In our case, we found a few Transformers issues that closely align with the problems we’re seeing.
For example, for most modern models, the positional embeddings make this tricky. Not running in half precision (fp16) helps, but doesn’t solve it (we tested it to confirm this).
As one Transformers issue puts it: “We are aware of this phenomenon on all (or nearly all) models that contain rotary position embeddings (Llama, Llama2, Falcon, GPTNeoX, ...). Running things in fp32 helps avoid this problem, but that is far from a good solution.”
Here’s a good summary of the situation, in another issue on what looked like a KV caching bug, but was just numerical weirdness.
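Why would the batch size change a score at all? One ingredient is that floating-point addition is not associative, and a different batch size can lead the GPU to run different kernels that accumulate sums in a different order. The toy below is only an analogy, not a model of any real kernel: it emulates fp16 rounding in pure Python with the `struct` module's half-precision `'e'` format and sums the same three numbers in two orders. `to_fp16` and `sum_fp16` are our own helpers.

```python
import struct
from typing import Iterable


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def sum_fp16(values: Iterable[float]) -> float:
    """Left-to-right sum that rounds to fp16 after every addition (a toy fp16 accumulator)."""
    total = 0.0
    for v in values:
        total = to_fp16(total + to_fp16(v))
    return total


values = [2048.0, 1.0, 1.0]
print(sum_fp16(values))        # 2048.0: fp16 spacing near 2048 is 2, so 2049 ties and rounds to even
print(sum_fp16(values[::-1]))  # 2050.0: 1 + 1 = 2 is exact, and 2048 + 2 is representable
print(sum(values), sum(values[::-1]))  # 2050.0 2050.0 in float64
```

In float64 both orders agree, which is the intuition behind "running things in fp32 helps": more mantissa bits make each rounding error smaller, though not zero.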
How, specifically, are reward models implemented wrong?
When people train a reward model with HuggingFace, they’re normally using the AutoModelForSequenceClassification class. This class outputs a score for each class it is trained on, i.e. it is a multi-class predictor. For reward models, we only want to train and predict on one class, which looks like model.config.num_labels = 1. Many, many reward models have more than one class.
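Checking the head size before trusting a model is cheap because it only needs the config. A minimal sketch over a parsed `config.json` dict; the helper names are hypothetical, and falling back to two labels when the config does not say is our reading of Transformers' default for `num_labels`, so treat it as an assumption.

```python
import json
from typing import Any, Dict

DEFAULT_NUM_LABELS = 2  # assumed Transformers default when a config does not specify labels


def num_labels_from_config(config: Dict[str, Any]) -> int:
    """Infer the size of the classification head from a parsed config.json."""
    if "id2label" in config:
        return len(config["id2label"])
    return int(config.get("num_labels", DEFAULT_NUM_LABELS))


def is_scalar_reward_model(config: Dict[str, Any]) -> bool:
    """A vanilla reward model should output exactly one number per sequence."""
    return num_labels_from_config(config) == 1


scalar = json.loads('{"model_type": "llama", "id2label": {"0": "LABEL_0"}}')
unspecified = json.loads('{"model_type": "llama"}')
print(is_scalar_reward_model(scalar), is_scalar_reward_model(unspecified))  # True False
```

With Transformers installed, `AutoConfig.from_pretrained(model_name).num_labels` gives the same number without downloading the weights. A value other than 1 is not automatically a bug, since some models are multi-head on purpose, but it is a signal to read the model card or custom code before trusting scores.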
Some models, such as RLHFlow/RewardModel-Mistral-7B-for-DPA-v1, are trained differently to do fine-grained RLHF and should have more than one class. Others have custom code to make sure the output is handled correctly.
At the end of the day, check the models you grab from the hub before you trust them. Here’s an example of a vanilla RM that seems to be implemented correctly.
Adding a feature to make this clearer: margin of scores
Models whose mean chosen and rejected scores are closer together are more sensitive to sources of noise changing the classification results. To observe this effect, we’ve added logging of the mean chosen and rejected rewards, along with the margin between them. The hypothesis is that better reward models will have larger margins and be less sensitive to inference noise (which we can partially see in the distributions of RewardBench data in the paper). See the Starling models below compared to other models: there is a clear separation of the chosen and rejected rewards.
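The logged statistics are simple to compute from per-pair rewards. A minimal sketch, assuming population standard deviation (the logs do not say whether population or sample std is used); `reward_summary` is a hypothetical helper, and `near_tie_fraction` is an extra, hypothetical fragility measure that counts pairs close enough for small inference noise to flip them.

```python
import statistics
from typing import Dict, Sequence


def reward_summary(chosen: Sequence[float], rejected: Sequence[float]) -> Dict[str, float]:
    """Mean and std of chosen/rejected rewards plus the mean margin between them."""
    mean_c = statistics.fmean(chosen)
    mean_r = statistics.fmean(rejected)
    return {
        "mean_chosen": mean_c,
        "std_chosen": statistics.pstdev(chosen),  # population std (assumed)
        "mean_rejected": mean_r,
        "std_rejected": statistics.pstdev(rejected),
        # The difference of means equals the mean of per-pair margins.
        "mean_margin": mean_c - mean_r,
    }


def near_tie_fraction(chosen: Sequence[float], rejected: Sequence[float], eps: float) -> float:
    """Hypothetical fragility metric: share of pairs with |chosen - rejected| < eps."""
    close = sum(1 for c, r in zip(chosen, rejected) if abs(c - r) < eps)
    return close / len(chosen)


chosen = [4.0, 2.0, 3.0]
rejected = [1.0, 1.95, 3.02]
print(reward_summary(chosen, rejected))
print(near_tie_fraction(chosen, rejected, eps=0.1))  # two of the three pairs are within 0.1
```

A model with a large mean margin can still be fragile if many individual pairs sit near zero, which is why a per-pair view like `near_tie_fraction` complements the averages.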
Given that this issue focused on batch size variance, I re-ran some of our experiments on the UltraFeedback test set to compare the margins of a strong, correctly implemented model and an older OpenAssistant reward model (which has the wrong number of classes for AutoModelForSequenceClassification).
To improve the stability of inference, we made two minor changes to token preprocessing, which we found reduced the variance of outputs:
truncation=False (because of the weirdness we saw above); more on truncation here: https://huggingface.co/docs/transformers/en/pad_truncation
tokenizer.padding_side = 'left', to avoid weird AutoModelForSequenceClassification behavior.
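To see what the padding side changes, here is a pure-Python stand-in for `tokenizer(..., padding=True)` under each `padding_side`. `pad_batch` is hypothetical, and we use pad ID 0, distinct from EOS = 2, purely for readability (in the Mistral example above, pad and EOS share ID 2). With left padding, the last column holds every row's final real token, so anything that reads the reward from the final position sees each row's EOS.

```python
from typing import List, Tuple


def pad_batch(seqs: List[List[int]], pad_id: int, side: str) -> Tuple[List[List[int]], List[List[int]]]:
    """Pad token-ID lists to a common length; returns (input_ids, attention_mask)."""
    if side not in ("left", "right"):
        raise ValueError("side must be 'left' or 'right'")
    width = max(len(s) for s in seqs)
    input_ids, attention_mask = [], []
    for s in seqs:
        n_pad = width - len(s)
        if side == "left":
            input_ids.append([pad_id] * n_pad + s)
            attention_mask.append([0] * n_pad + [1] * len(s))
        else:
            input_ids.append(s + [pad_id] * n_pad)
            attention_mask.append([1] * len(s) + [0] * n_pad)
    return input_ids, attention_mask


batch = [[1, 10, 11, 2], [1, 12, 2]]  # BOS=1 ... EOS=2; pad=0 only for readability
left_ids, left_mask = pad_batch(batch, pad_id=0, side="left")
right_ids, right_mask = pad_batch(batch, pad_id=0, side="right")
print(left_ids, left_mask)  # [[1, 10, 11, 2], [0, 1, 12, 2]] [[1, 1, 1, 1], [0, 1, 1, 1]]
print([row[-1] for row in left_ids])   # [2, 2]: every row ends in its EOS
print([row[-1] for row in right_ids])  # [2, 0]: the shorter row ends in padding
```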
# Batch size 1
Results: 0.7644670050761422, on 985 prompts
Mean chosen: 4.246033551366196, std: 3.2116607020053767
Mean rejected: 1.6019037817940494, std: 3.062244882625771
Mean margin: 2.644129769572147
# Batch size 5
Results: 0.7705583756345178, on 985 prompts
Mean chosen: 4.248798556255205, std: 3.2116428887517916
Mean rejected: 1.6112177001643302, std: 3.057724844634685
Mean margin: 2.63758085609087
# Batch size 20
Results: 0.7756345177664975, on 985 prompts
Mean chosen: 4.23749801713198, std: 3.21121643868046
Mean rejected: 1.6067456318037159, std: 3.054211084773242
Mean margin: 2.6307523853282637
OpenAssistant/oasst-rm-2.1-pythia-1.4b-epoch-2.5
# Batch size 1
Results: 0.5898477157360406, on 985 prompts
Mean chosen: 1.0486571950960886, std: 3.2982734634650694
Mean rejected: 0.22475792066699962, std: 2.9019193261387204
Mean margin: 0.823899274429089
# Batch size 5
Results: 0.5989847715736041, on 985 prompts
Mean chosen: 1.0520252711881841, std: 3.29494203957563
Mean rejected: 0.2284763636322796, std: 2.9001652644001243
Mean margin: 0.8235489075559045
# Batch size 20
Results: 0.6, on 985 prompts
Mean chosen: 1.0458758029840924, std: 3.296012226103798
Mean rejected: 0.22225384010276214, std: 2.897595056971502
Mean margin: 0.8236219628813303
PKU-Alignment/beaver-7b-v1.0-cost with correct chat template
# Batch size 1
Results: 0.4517766497461929, on 985 prompts
Mean chosen: 3.610776872199199, std: 3.541165682239221
Mean rejected: 3.7702910912218432, std: 3.5245946555000787
Mean margin: -0.15951421902264434
# Batch size 3
Results: 0.45989847715736043, on 985 prompts
Mean chosen: 3.608702903592647, std: 3.5389827447119897
Mean rejected: 3.766082879855548, std: 3.5319931323583234
Mean margin: -0.15737997626290104
# Batch size 5
Results: 0.4558375634517767, on 985 prompts
Mean chosen: 3.6083962261374225, std: 3.544413085756003
Mean rejected: 3.769583555889614, std: 3.53038180513649
Mean margin: -0.16118732975219108
# Batch size 20
Results: 0.46091370558375633, on 985 prompts
Mean chosen: 3.610106167091331, std: 3.5459553079842574
Mean rejected: 3.7643620176363717, std: 3.5255408296954722
Mean margin: -0.15425585054504085
PKU-Alignment/beaver-7b-v1.0-cost with the incorrect chat template, as in the original GitHub issue, which used batch sizes 1, 2, and 3. The numbers I have are also slightly different, which could be due to different PyTorch versions and machines. This is also without the changes to truncation and padding.
# Batch size 1
Results: 0.46598984771573604, on 985 prompts
Mean chosen: 2.681750867814582, std: 3.2592924729848964
Mean rejected: 2.769996768932052, std: 3.193354047852257
Mean margin: -0.08824590111747006
# Batch size 2
Results: 0.4781725888324873, on 985 prompts
Mean chosen: 2.67659324316809, std: 3.257516704246201
Mean rejected: 2.7611247355562782, std: 3.190987615210259
Mean margin: -0.08453149238818793
# Batch size 3
Results: 0.47918781725888326, on 985 prompts
Mean chosen: 2.6791953678663614, std: 3.261352735964595
Mean rejected: 2.763327192897119, std: 3.191630759537142
Mean margin: -0.0841318250307577
# Batch size 5
Results: 0.467005076142132, on 985 prompts
Mean chosen: 2.677452583119349, std: 3.262681034870804
Mean rejected: 2.7661390672480395, std: 3.1932779168537717
Mean margin: -0.08868648412869061
# Batch size 20
Results: 0.4649746192893401, on 985 prompts
Mean chosen: 2.6717060611937855, std: 3.262159265332147
Mean rejected: 2.7622602046443725, std: 3.191878860191202
Mean margin: -0.09055414345058693
Without the fixes (truncation and padding side), there is more variance in the results, but it is still small.
# Batch size 1
Results: 0.43147208121827413, on 985 prompts
Mean chosen: 2.6370499509240166, std: 3.2039855146820213
Mean rejected: 2.7235231719041235, std: 3.1652781950390882
Mean margin: -0.08647322098010687
# Batch size 5
Results: 0.47715736040609136, on 985 prompts
Mean chosen: 2.6371255923043653, std: 3.2056168015074675
Mean rejected: 2.7238458967450914, std: 3.1643515098948725
Mean margin: -0.08672030444072588
In the end, we’re pretty okay with our implementation. We’ve merged some small changes to make this easier to track down (such as our own pipeline abstraction rather than HuggingFace’s), but it’s unavoidable that there will be 1-3% error on most results when using an accuracy-based metric.
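To put the 1-3% figure in terms of individual comparisons: with 985 prompts, each pair is worth about 0.1 accuracy points, and converting the logged accuracies back into counts gives the net number of pairs that changed sides (a lower bound on the actual flips). `correct_counts` is a hypothetical helper; the accuracies are copied from the logs above.

```python
from typing import List, Sequence


def correct_counts(accuracies: Sequence[float], n_prompts: int) -> List[int]:
    """Convert logged accuracies back to the integer number of correctly ranked pairs."""
    return [round(acc * n_prompts) for acc in accuracies]


N_PROMPTS = 985  # "on 985 prompts" in every log above

# First model in the logs above, batch sizes 1, 5, and 20.
strong = correct_counts([0.7644670050761422, 0.7705583756345178, 0.7756345177664975], N_PROMPTS)
# beaver-7b-v1.0-cost without the fixes, batch sizes 1 and 5.
no_fix = correct_counts([0.43147208121827413, 0.47715736040609136], N_PROMPTS)

print(strong, max(strong) - min(strong))  # [753, 759, 764] 11
print(no_fix, max(no_fix) - min(no_fix))  # [425, 470] 45
print(f"one pair = {100 / N_PROMPTS:.2f} accuracy points")  # one pair = 0.10 accuracy points
```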
Extra findings
In the end, there are a ton of little things that can bite you on your path to reproducibility. Hopefully this helps you mitigate future issues. Other things that could come up are below, but please keep evaluating your reward models!
Don’t use Accelerator.prepare(dataloader) with evaluation sets, as it can add or remove a certain number of datapoints from the dataset: https://github.com/huggingface/accelerate/issues/2316
How AutoModelForSequenceClassification selects the index to return: https://github.com/huggingface/transformers/blob/b6c9f47fd6f911450024c52e382e544e5d04387a/src/transformers/models/llama/modeling_llama.py#L1372
If you force the tokenizer to add an EOS token (which the SeqClassifiers should predict reward on), sometimes you end up with two EOS tokens. This shouldn’t be a big deal, but could cause small problems:
reward_pipe.tokenizer.add_eos_token = True
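A pure-Python paraphrase of that index selection, as we read it in the linked modeling_llama.py: find the first padding token, step back one position, and wrap -1 around to the last position. Other architectures and later Transformers versions may select the index differently, and `pooled_index` is our own name. With pad and EOS tied, the same rule scores the EOS in one row and the token before it in another:

```python
from typing import List


def pooled_index(input_ids: List[int], pad_token_id: int) -> int:
    """Position whose logit becomes the sequence score (a paraphrase of the linked code).

    Find the first pad token, step back one position, and wrap -1 around to the
    last position, which is what indexing with -1 does in PyTorch.
    """
    is_pad = [int(t == pad_token_id) for t in input_ids]
    first_pad = is_pad.index(max(is_pad))  # like argmax: 0 when there is no pad at all
    return (first_pad - 1) % len(input_ids)


PAD_ID = 2  # tied to EOS, as in the Mistral-based example earlier in this post
rows = {
    "unpadded, ends with EOS": [1, 10, 11, 2],
    "left-padded, ends with EOS": [2, 2, 1, 10, 11, 2],
    "truncated, no EOS": [1, 10, 11, 12],
}
for name, ids in rows.items():
    i = pooled_index(ids, PAD_ID)
    print(f"{name}: index {i}, token {ids[i]}")
# unpadded, ends with EOS: index 2, token 11      (the token before EOS)
# left-padded, ends with EOS: index 5, token 2    (the EOS itself)
# truncated, no EOS: index 3, token 12            (whatever content was kept last)
```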