import copy
from copy import deepcopy
from collections.abc import Mapping
import torch
from transformers import AutoProcessor
from qwen_vl_utils import process_vision_info
import mindspeed.megatron_adaptor
from megatron.core.enums import ModelType
from megatron.training import get_args
from megatron.training.checkpointing import load_checkpoint
from megatron.training.training import get_model
from mindspeed_mm.utils.transformer_model_config import get_model_config
from mindspeed_mm.tasks.inference.pipeline.pipeline_mixin.generation_mixin import GenerationMixin
from mindspeed_mm.models.text_encoder import Tokenizer
from mindspeed_mm.models.reward_model import Qwen2VLRewardModelBT
from mindspeed_mm.data.data_utils.reward_preprocess import fill_data_template, clean_examples
from mindspeed_mm.data.data_utils.func_utils.convert import reward_setting_processor
class VideoAlignPipeline(GenerationMixin):
def __init__(self, model_config, preprocess_params, norm_param=None):
self.model_config = model_config
self.preprocess_params = preprocess_params
self.norm_param = norm_param
self.device = torch.cuda.current_device()
self.dtype = torch.bfloat16 if getattr(model_config.text_decoder, 'bf16', False) else torch.float32
model_preprocess_params = copy.deepcopy(preprocess_params)
self.video_reader, self.video_processor, self.tokenizer, self.processor, self.model_args = reward_setting_processor(
model_preprocess_params)
model_type = ModelType.encoder_or_decoder
self.model = get_model(self.model_provider, model_type, wrap_with_ddp=False)
load_checkpoint(self.model, None, None, strict=True)
self.model = self.model[0].to(self.device, self.dtype)
def model_provider(self, pre_process=True, post_process=True):
"""Builds the model."""
vlm_config = deepcopy(self.model_config)
vlm_config.pre_process = pre_process
vlm_config.post_process = post_process
vlm_config.reward_process = True
if vlm_config.image_encoder and vlm_config.text_decoder:
vlm_config.image_encoder.vision_encoder = get_model_config(vlm_config.image_encoder.vision_encoder)
vlm_config.image_encoder.vision_projector = get_model_config(vlm_config.image_encoder.vision_projector)
vlm_config.text_decoder = get_model_config(vlm_config.text_decoder)
model = Qwen2VLRewardModelBT(config=vlm_config, extra_config=self.model_args)
model.freeze(freeze_image_encoder=getattr(vlm_config.image_encoder.vision_encoder, 'freeze', False),
freeze_image_projection=getattr(vlm_config.image_encoder.vision_projector, 'freeze', False),
freeze_text_decoder=getattr(vlm_config.text_decoder, 'freeze', False))
return model
def _norm(self, reward):
if self.norm_param is None:
return reward
else:
reward['VQ'] = (reward['VQ'] - self.norm_param['VQ_mean']) / self.norm_param['VQ_std']
reward['MQ'] = (reward['MQ'] - self.norm_param['MQ_mean']) / self.norm_param['MQ_std']
reward['TA'] = (reward['TA'] - self.norm_param['TA_mean']) / self.norm_param['TA_std']
return reward
def _prepare_input(self, data):
"""
Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and
handling potential state.
"""
if isinstance(data, Mapping):
def mapping_key(k):
if k == 'video_grid_thw':
return 'image_grid_thw'
elif k == 'pixel_values_videos':
return 'pixel_values'
else:
return k
return type(data)({mapping_key(k): self._prepare_input(v) for k, v in data.items()})
elif isinstance(data, (tuple, list)):
return type(data)(self._prepare_input(v) for v in data)
elif isinstance(data, torch.Tensor):
kwargs = {"device": self.device}
return data.to(**kwargs)
return data
def _prepare_inputs(self, inputs):
"""
Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and
handling potential state.
"""
inputs = self._prepare_input(inputs)
if len(inputs) == 0:
raise ValueError
return inputs
def prepare_batch(self, data_folder, batch_data):
setting_fps = self.preprocess_params.get('fps', None)
video_max_pixels = self.preprocess_params.get('video_max_pixels', None)
eval_dim = self.preprocess_params.get('eval_dim', None)
sample_nframe = self.preprocess_params.get('sample_nframe', None)
sample_type = self.preprocess_params.get('sample_type', None)
prompt_template_type = self.preprocess_params.get('prompt_template_type', None)
video_inputs = []
chat_datas = [
fill_data_template(data_folder=data_folder,
relative_path=video_path,
prompt=prompt,
fps=setting_fps if setting_fps else fps,
max_pixels=video_max_pixels,
num_frames=num_frames,
eval_dim=eval_dim,
sample_nframe=sample_nframe,
sample_type=sample_type,
prompt_template_type=prompt_template_type)
for (video_path, prompt, fps, num_frames) in batch_data
]
clean_chat_datas = []
for chat_data in chat_datas:
clean_chat_data = clean_examples(chat_data)
clean_chat_datas.append(clean_chat_data)
video_input = self.video_processor(self.video_reader(clean_chat_data[0]['content'][0]["video"]))
video_inputs.append(video_input)
batch = self.processor(
text=self.processor.apply_chat_template(clean_chat_datas, tokenize=False, add_generation_prompt=True),
images=None,
videos=video_inputs,
padding=True,
return_tensors="pt",
videos_kwargs={"do_rescale": True},
)
batch = self._prepare_inputs(batch)
batch['pixel_values'] = batch['pixel_values'].to(self.dtype)
return batch
def __call__(self, data_folder, batch_data, use_norm=False, model_reward_return=False):
batch = self.prepare_batch(data_folder=data_folder, batch_data=batch_data)
rewards = self.model(
return_dict=True,
**batch
)
if model_reward_return:
return rewards
rewards = [{'VQ': reward[0].item(), 'MQ': reward[1].item(), 'TA': reward[2].item()} for reward in rewards]
for i in range(len(rewards)):
if use_norm:
rewards[i] = self._norm(rewards[i])
rewards[i]['Overall'] = rewards[i]['VQ'] + rewards[i]['MQ'] + rewards[i]['TA']
return rewards