from dataclasses import dataclass
from typing import Any, Dict, Literal, Sequence
import torch
from transformers import DataCollatorForSeq2Seq
from megatron.training import get_args
@dataclass
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
    """
    Data collator for pairwise (chosen / rejected) preference data.

    Each incoming feature carries ``chosen_*`` and ``rejected_*`` variants of
    ``input_ids``, ``attention_mask`` and ``labels``. They are flattened into
    ordinary single-sequence features and handed to the parent collator.
    """

    def __call__(self, features: Sequence[Dict[str, Any]], repeat=1) -> Dict[str, torch.Tensor]:
        """
        Flatten the pairwise features and delegate collation to the parent class.

        For ``n`` input features the flattened list holds ``2 * n * repeat``
        examples. Within each repetition the ``n`` chosen examples come first,
        followed by the ``n`` rejected examples; the whole block is then
        emitted ``repeat`` times back to back.

        Args:
            features: Pairwise examples keyed by ``chosen_*`` / ``rejected_*``.
            repeat: Number of times the chosen/rejected block is emitted (the
                original note mentions this is for the hyper model). Ignored
                when the configured stage is ``"dpo"``, which always uses 2.

        Returns:
            Whatever ``DataCollatorForSeq2Seq.__call__`` produces for the
            flattened feature list.
        """
        # The DPO stage pins the repetition count, overriding the argument.
        rounds = 2 if get_args().stage == "dpo" else repeat

        flattened = []
        for _ in range(rounds):
            self._concat(flattened, features)
        return super().__call__(flattened)

    @staticmethod
    def _concat(concatenated_features, features):
        """
        Append the flattened examples of ``features`` to ``concatenated_features``.

        The list is extended in place: first one entry per feature built from
        its ``chosen_*`` fields, then one entry per feature built from its
        ``rejected_*`` fields. Nothing is returned.
        """
        field_names = ("input_ids", "attention_mask", "labels")
        for prefix in ("chosen", "rejected"):
            for example in features:
                concatenated_features.append(
                    {name: example[f"{prefix}_{name}"] for name in field_names}
                )