from copy import deepcopy
from functools import partial
import math
import torch
import torch.nn.functional as F
from megatron.core.enums import ModelType
from megatron.training import get_args, print_rank_0
from megatron.training.checkpointing import load_checkpoint
from megatron.training.global_vars import set_args
from megatron.training.training import get_model
from megatron.training.utils import average_losses_across_data_parallel_group
from mindspeed_mm.models.reward_model import Qwen2VLRewardModelBT
from mindspeed_mm.tasks.finetune.lora.utils import is_enable_lora
from mindspeed_mm.tasks.rl.dpo.dpo_trainer import DPOTrainer
from mindspeed_mm.utils.transformer_model_config import get_model_config
from mindspeed_mm.data.data_utils.func_utils.convert import load_reward_tokenizer
class PartialEmbeddingUpdater:
    """Train only the special-token rows of the input embedding table.

    The object is used as a trainer callback: ``setup`` snapshots the word
    embedding weights, and every later call copies the snapshot back into all
    rows except the ones listed in ``special_token_ids``. The net effect is that
    regular-token embeddings stay at their initial values while special-token
    embeddings keep whatever the optimizer wrote.

    NOTE(review): when the callback is invoked (the original documentation says
    "after each training step") is decided by the trainer's callback mechanism,
    which is not visible in this file.
    """

    def __init__(self):
        # Ids of the rows that are allowed to change; ``None``/empty disables the updater.
        self.special_token_ids = None
        # Master switch, set through ``get_model_args``.
        self.enable_partial_update = False
        # Detached copy of the embedding weights taken in ``setup``.
        self.orig_embeds_params = None
        # Number of rows in the embedding table (first dim of the snapshot).
        self.vocab_size = None
        # Device holding the snapshot and the row mask; assigned in ``setup``.
        self.device = None
        # Lazily built boolean mask, True for every row that must be restored.
        self._frozen_rows = None

    def get_model_args(self, special_token_ids, enable_partial_update):
        """Store the updater configuration.

        Args:
            special_token_ids: Ids of the embedding rows that may be updated.
            enable_partial_update: When falsy, calling the updater is a no-op.
        """
        self.special_token_ids = special_token_ids
        self.enable_partial_update = enable_partial_update
        # The cached mask depends on the ids, so drop it whenever they change.
        self._frozen_rows = None

    def setup(self, model):
        """Back up the model's input word-embedding weights before training.

        Args:
            model: Object exposing ``text_decoder.embedding.word_embeddings.weight``.
        """
        self.device = torch.cuda.current_device()
        input_embeddings = model.text_decoder.embedding.word_embeddings
        self.orig_embeds_params = input_embeddings.weight.clone().detach().to(self.device)
        self.vocab_size = self.orig_embeds_params.shape[0]
        # A new snapshot may have a different size; rebuild the mask on next use.
        self._frozen_rows = None

    def __call__(self, model, *args, **kwargs):
        """Restore every non-special embedding row from the snapshot.

        Extra positional and keyword arguments are accepted and ignored so the
        updater can be invoked through any callback signature.

        Args:
            model: Object exposing ``text_decoder.embedding.word_embeddings.weight``.

        Raises:
            RuntimeError: If partial updating is enabled but ``setup`` has not
                stored a snapshot yet.
        """
        if not (self.special_token_ids and self.enable_partial_update):
            return

        if self.orig_embeds_params is None or self.vocab_size is None:
            raise RuntimeError(
                "PartialEmbeddingUpdater.setup() must be called before the updater is invoked."
            )

        if self._frozen_rows is None:
            # Built once: the mask only depends on vocab size and special ids.
            frozen_rows = torch.ones(self.vocab_size, dtype=torch.bool, device=self.device)
            frozen_rows[self.special_token_ids] = False
            self._frozen_rows = frozen_rows

        # In-place write on a parameter must not be tracked by autograd.
        with torch.no_grad():
            input_embeddings = model.text_decoder.embedding.word_embeddings
            input_embeddings.weight[self._frozen_rows] = self.orig_embeds_params[self._frozen_rows]
class VideoVLMRewardTrainer(DPOTrainer):
    """Trainer for a pairwise (A vs. B) video reward model.

    The class builds a ``Qwen2VLRewardModelBT`` model, runs it on both samples
    of every pair, maps the A/B outputs to chosen/rejected according to the
    preference label, and computes one of several preference / regression
    losses selected by ``args.mm.model.loss_type``.
    """

    # Label value that is excluded from ``valid_mask``.
    _INVALID_LABEL = 22
    # Fixed margin subtracted from the reward gap by the "constant_margin" loss.
    _CONSTANT_MARGIN = 0.57
    # ``k`` used by the "btt" loss (enters as log(k) and log(k**2 - 1)).
    _BTT_K = 5.0
    # Value subtracted from the raw scores before the "reg" MSE loss.
    _REG_SCORE_SHIFT = 3.0

    def __init__(
        self,
        train_valid_test_dataset_provider,
        model_type,
        process_non_loss_data_func=None,
        extra_args_provider=None,
        args_defaults=None,
    ):
        """Create the embedding updater, initialise the base trainer, disable dropout.

        Args:
            train_valid_test_dataset_provider: Forwarded to ``DPOTrainer``.
            model_type: Forwarded to ``DPOTrainer``.
            process_non_loss_data_func: Forwarded to ``DPOTrainer``.
            extra_args_provider: Forwarded to ``DPOTrainer``.
            args_defaults: Forwarded to ``DPOTrainer``.
        """
        # Created first because it is handed to the base class as a callback and
        # is configured from ``model_provider``.
        self.partialEmbeddingUpdater = PartialEmbeddingUpdater()
        super().__init__(
            train_valid_test_dataset_provider,
            model_type,
            process_non_loss_data_func,
            extra_args_provider,
            args_defaults,
            call_backs=[self.partialEmbeddingUpdater],
        )
        self.disable_dropout()

    def model_provider(self, pre_process=True, post_process=True):
        """Build the reward model and configure the partial embedding updater.

        Args:
            pre_process: Stored on the model config as ``pre_process``.
            post_process: Stored on the model config as ``post_process``.

        Returns:
            The constructed ``Qwen2VLRewardModelBT`` instance.

        Raises:
            AttributeError: If the image encoder or text decoder config is missing.
        """
        args = get_args()
        print_rank_0("building VideoReward model ...")

        self.loss_type = args.mm.model.loss_type
        self.loss_dtype = torch.bfloat16 if args.mm.model.text_decoder.bf16 else torch.float32

        # Work on copies so the global args are not mutated below.
        vlm_config = deepcopy(args.mm.model)
        data_config = deepcopy(args.mm.data)
        if not isinstance(data_config, dict):
            data_config = data_config.to_dict()
        preprocess_param = data_config['dataset_param']['preprocess_parameters']

        tokenizer = load_reward_tokenizer(preprocess_param)['tokenizer']

        special_token_ids = None
        token_embedding_length = None
        if preprocess_param['split_special_tokens']:
            special_tokens = ["<|VQ_reward|>", "<|MQ_reward|>", "<|TA_reward|>"]
            tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})
            special_token_ids = tokenizer.convert_tokens_to_ids(special_tokens)
            # Vocabulary size after the reward tokens have been added.
            token_embedding_length = len(tokenizer)

        model_args = {
            "special_token_ids": special_token_ids,
            "token_embedding_length": token_embedding_length,
            "tokenizer_padding_side": "right",
            "pad_token_id": tokenizer.pad_token_id,
        }

        vlm_config.pre_process = pre_process
        vlm_config.post_process = post_process
        vlm_config.reward_process = True

        if not (vlm_config.image_encoder and vlm_config.text_decoder):
            raise AttributeError("image_encoder config or text_decoder config not exist!")

        image_encoder = vlm_config.image_encoder
        image_encoder.vision_encoder = get_model_config(image_encoder.vision_encoder)
        image_encoder.vision_projector = get_model_config(image_encoder.vision_projector)
        vlm_config.text_decoder = get_model_config(vlm_config.text_decoder)

        model = Qwen2VLRewardModelBT(config=vlm_config, extra_config=model_args)
        model.freeze(
            freeze_image_encoder=getattr(image_encoder.vision_encoder, 'freeze', False),
            freeze_image_projection=getattr(image_encoder.vision_projector, 'freeze', False),
            freeze_text_decoder=getattr(vlm_config.text_decoder, 'freeze', False),
        )

        self.token_embedding_length = token_embedding_length
        enable_partial_update = getattr(
            vlm_config.text_decoder, 'word_embeddings_only_update_special', False
        )
        self.partialEmbeddingUpdater.get_model_args(special_token_ids, enable_partial_update)
        return model

    def disable_dropout(self):
        """Set every dropout probability in the global args to 0.0 and store them back."""
        args_ = get_args()
        for option in (
            "attention_dropout",
            "hidden_dropout",
            "retro_encoder_hidden_dropout",
            "retro_encoder_attention_dropout",
        ):
            setattr(args_, option, 0.0)
        set_args(args_)

    def _convert_A_B_to_chosen_rejected(self, rewards_A, rewards_B, scores_A, scores_B, chosen_label, label_dim=None):
        """Re-order A/B rewards and scores into chosen/rejected according to the label.

        A label of 1 selects A as chosen; any other label selects B as chosen.

        Inputs (shapes as stated by the original documentation):
            rewards_A, rewards_B: [B, N]
            scores_A, scores_B: [B, N]
            chosen_label: [B, N]
            label_dim: optional column index; when given, the decision of that
                single label column is broadcast to all N columns.
        Outputs:
            rewards_chosen, rewards_rejected: [B, N]
            scores_chosen, scores_rejected: [B, N]
            nontied_mask: [B, N] float mask, 1 where the label is 1 or -1
            valid_mask: [B, N] float mask, 1 where the label is not 22
        """

        def _select_label_dim(mask):
            # Optionally let one label column decide for every column.
            if label_dim is None:
                return mask
            return mask[:, label_dim].unsqueeze(1).expand(-1, chosen_label.size(1))

        chosen_mask = _select_label_dim(chosen_label == 1)
        rejected_mask = _select_label_dim(chosen_label != 1)

        rewards_chosen = torch.where(chosen_mask, rewards_A, rewards_B)
        rewards_rejected = torch.where(rejected_mask, rewards_A, rewards_B)
        scores_chosen = torch.where(chosen_mask, scores_A, scores_B)
        scores_rejected = torch.where(rejected_mask, scores_A, scores_B)

        nontied_mask = _select_label_dim(((chosen_label == 1) | (chosen_label == -1)).float())
        valid_mask = _select_label_dim((chosen_label != self._INVALID_LABEL).float())

        return rewards_chosen, rewards_rejected, scores_chosen, scores_rejected, nontied_mask, valid_mask

    @staticmethod
    def get_batch(data_iterator):
        """Fetch the next batch and move its fields to the current CUDA device.

        The batch dict returned by the iterator is updated in place.

        Args:
            data_iterator: Iterator yielding dict batches with ``*_A`` / ``*_B``
                model inputs plus ``A_scores``, ``B_scores`` and ``chosen_label``.

        Returns:
            The same batch dict with all listed fields as tensors on the device.

        Raises:
            ValueError: If ``data_iterator`` is ``None``.
        """
        if data_iterator is None:
            raise ValueError("Data iterator is None. Unable to retrieve batch.")
        batch = next(data_iterator)

        device = torch.cuda.current_device()
        for side in ("A", "B"):
            for field in ("input_ids", "attention_mask", "pixel_values", "image_grid_thw"):
                key = f"{field}_{side}"
                batch[key] = batch[key].to(device)

        # ``as_tensor`` accepts both Python sequences and existing tensors without
        # the copy-construct warning that ``torch.tensor(tensor)`` emits.
        for key in ("A_scores", "B_scores", "chosen_label"):
            batch[key] = torch.as_tensor(batch[key]).to(device)
        return batch

    def loss_func(self, rewards_chosen, rewards_rejected, nontied_mask, valid_mask, inputs, output_tensor):
        """Compute the training loss selected by ``self.loss_type``.

        Args:
            rewards_chosen: Rewards of the preferred sample.
            rewards_rejected: Rewards of the non-preferred sample.
            nontied_mask: Mask of pairs whose label is a strict preference.
            valid_mask: Mask of pairs with a valid label.
            inputs: Batch dict; ``"margin"`` is read by the "margin"/"scaled"
                losses, ``"A_scores"``/``"B_scores"`` by the "reg" loss.
            output_tensor: Stacked ``[rewards_A, rewards_B]`` (index 0 is A).

        Returns:
            ``(loss, metrics)`` where ``metrics['loss']`` is the loss averaged
            across the data-parallel group.

        Raises:
            NotImplementedError: If ``self.loss_type`` is not supported.
        """
        loss_type = self.loss_type

        if loss_type == "bt":
            loss = -F.logsigmoid(rewards_chosen - rewards_rejected)
            out_mask = nontied_mask
        elif loss_type == "margin":
            loss = -F.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"])
            out_mask = nontied_mask
        elif loss_type == "constant_margin":
            loss = -F.logsigmoid(rewards_chosen - rewards_rejected - self._CONSTANT_MARGIN)
            out_mask = nontied_mask
        elif loss_type == "scaled":
            loss = -(inputs["margin"] + 0.0) * F.logsigmoid(rewards_chosen - rewards_rejected)
            out_mask = nontied_mask
        elif loss_type == "reg":
            # Regress the raw A/B rewards onto the shifted annotation scores.
            rewards = torch.stack([output_tensor[0], output_tensor[1]], dim=1)
            scores = torch.stack([inputs["A_scores"], inputs["B_scores"]], dim=1)
            # Entries whose raw score is exactly 0 are excluded from the loss.
            out_mask = scores != 0.0
            scores = scores - self._REG_SCORE_SHIFT
            loss = F.mse_loss(rewards, scores, reduction="none")
        elif loss_type == "btt":
            log_k = math.log(self._BTT_K)
            log_k2_sub_1 = math.log(self._BTT_K ** 2 - 1)
            bt_loss = -F.logsigmoid(rewards_chosen - rewards_rejected - log_k)
            same_loss = (
                -F.logsigmoid(rewards_chosen - rewards_rejected - log_k)
                - F.logsigmoid(rewards_rejected - rewards_chosen - log_k)
                - log_k2_sub_1
            )
            # Strict preferences use the BT term, all other pairs the tie term.
            loss = bt_loss * nontied_mask + same_loss * (1 - nontied_mask)
            out_mask = valid_mask
        else:
            raise NotImplementedError(f"Loss type {self.loss_type} not implemented.")

        # NOTE(review): the mean runs over all elements, so masked-out entries
        # contribute zeros to the average instead of being left out of it.
        loss = (loss * out_mask).mean()

        # One data-parallel average is enough; averaging the averaged value again
        # only costs an additional collective.
        metrics = {'loss': average_losses_across_data_parallel_group([loss])}
        return loss, metrics

    def forward_step(self, data_iterator, model):
        """Run the model on both samples of each pair and prepare the loss closure.

        Args:
            data_iterator: Iterator passed to ``get_batch``.
            model: Callable taking ``input_ids``, ``pixel_values``,
                ``image_grid_thw`` and ``attention_mask`` keyword arguments.

        Returns:
            ``(output_tensor, loss_fn)`` where ``output_tensor`` stacks the A and
            B rewards along dim 0 and ``loss_fn`` is ``loss_func`` with every
            argument except ``output_tensor`` already bound.
        """
        batch = self.get_batch(data_iterator)

        def _rewards_for(side):
            # Forward one side of the pair and cast to the loss dtype.
            return model(
                input_ids=batch[f'input_ids_{side}'],
                pixel_values=batch[f'pixel_values_{side}'],
                image_grid_thw=batch[f'image_grid_thw_{side}'],
                attention_mask=batch[f'attention_mask_{side}'],
            ).to(self.loss_dtype)

        rewards_A = _rewards_for('A')
        rewards_B = _rewards_for('B')

        (
            rewards_chosen,
            rewards_rejected,
            scores_chosen,
            scores_rejected,
            nontied_mask,
            valid_mask,
        ) = self._convert_A_B_to_chosen_rejected(
            rewards_A, rewards_B, batch['A_scores'], batch['B_scores'], batch['chosen_label']
        )

        # Score gap between chosen and rejected; read by the margin-based losses.
        batch["margin"] = scores_chosen - scores_rejected
        output_tensor = torch.stack([rewards_A, rewards_B], dim=0)
        return output_tensor, partial(
            self.loss_func, rewards_chosen, rewards_rejected, nontied_mask, valid_mask, batch
        )