from logging import getLogger
from typing import Any, Mapping
import copy
import torch
import torch_npu
from megatron.core import mpu
from megatron.training import get_args, print_rank_0
from megatron.training.arguments import core_transformer_config_from_args
from torch import nn
from mindspeed_mm.models.ae import AEModel
from mindspeed_mm.models.diffusion import DiffusionModel
from mindspeed_mm.models.predictor import PredictModel
from mindspeed_mm.models.text_encoder import TextEncoder
# Module-level logger; SoRADPOModel.load_state_dict uses it to report
# missing / unexpected checkpoint keys.
logger = getLogger(__name__)
class SoRADPOModel(nn.Module):
    """
    The hyper model wraps multiple models required in reinforcement learning into a single model,
    maintaining the original distributed perspective unchanged.

    Sub-modules built by ``_model_provider``:
      * ``ae`` / ``text_encoder`` -- frozen encoders, created only on pipeline rank 0 and only
        when the corresponding features are not pre-computed.
      * ``diffusion`` -- object returned by ``DiffusionModel(...).get_model()``; used for ``q_sample``.
      * ``reference`` -- frozen predictor (``eval()`` + ``requires_grad_(False)``).
      * ``actor`` -- trainable predictor; the only module saved by ``state_dict_for_save_checkpoint``.

    NOTE(review): ``nn.Module.train()`` recurses into children, so a caller invoking
    ``model.train()`` would presumably put ``reference`` / ``ae`` / ``text_encoder`` back into
    training mode -- confirm whether that matters for the configured predictor.
    """

    def __init__(self, config):
        """Read pipeline-stage flags from ``mpu`` and build all sub-modules from ``config``."""
        super().__init__()
        args = get_args()
        self.config = core_transformer_config_from_args(args)
        self.task = getattr(config, "task", "t2v")
        self.pre_process = mpu.is_pipeline_first_stage()
        self.post_process = mpu.is_pipeline_last_stage()
        self.input_tensor = None
        self.share_embeddings_and_output_weights = False
        self._model_provider(config)

    def _model_provider(self, config) -> None:
        """Builds the model."""
        print_rank_0("building SoRADPOModel related modules ...")

        # The encoders (and the flags telling whether they are needed) only exist
        # on the first pipeline rank.
        if mpu.get_pipeline_model_parallel_rank() == 0:
            self.load_video_features = config.load_video_features
            self.load_text_features = config.load_text_features
            if not self.load_video_features:
                print_rank_0("init AEModel....")
                self.ae = AEModel(config.ae).eval()
                self.ae.requires_grad_(False)
            if not self.load_text_features:
                print_rank_0("init TextEncoder....")
                self.text_encoder = TextEncoder(config.text_encoder).eval()
                self.text_encoder.requires_grad_(False)

        self.diffusion = DiffusionModel(config.diffusion).get_model()

        # Actor and reference each get their own deep copy of the predictor config.
        predictor_config_ref = copy.deepcopy(config.predictor)
        predictor_config = copy.deepcopy(config.predictor)
        self.reference = PredictModel(predictor_config_ref).get_model().eval()
        self.reference.requires_grad_(False)
        self.actor = PredictModel(predictor_config).get_model()
        print_rank_0("finish building SoRADPOModel related modules ...")

    def set_input_tensor(self, input_tensor):
        """Store the tensor received from the previous pipeline stage and forward it to the actor."""
        self.input_tensor = input_tensor
        self.actor.set_input_tensor(input_tensor)

    def forward(self, video, video_lose, prompt_ids, video_mask=None, prompt_mask=None, **kwargs):
        """
        video: high-scoring raw video tensors, or ae encoded latent
        video_lose: low-scoring raw video tensors, or ae encoded latent
        prompt_ids: tokenized input_ids, or encoded hidden states
        video_mask: mask for video/image
        prompt_mask: mask for prompt(text)
        kwargs: extra model inputs; on the first pipeline stage it must contain
            "score" and "score_lose".

        Returns:
            On the last pipeline stage, the list
            ``[output, score, score_lose, latents, noised_latents, timesteps, noise]`` where
            ``output`` is the actor output concatenated with the reference output along dim 0.
            On other stages, whatever ``self.actor.pipeline_set_next_stage_tensor`` returns.

        Raises:
            NotImplementedError: the task is neither "t2v" nor "i2v" while raw videos must be encoded.
            ValueError: this is not the first pipeline stage and the actor lacks
                ``pipeline_set_prev_stage_tensor``.
        """
        if self.pre_process:
            # NOTE(review): the source's indentation was lost; the no_grad scope is assumed
            # to end after the two encode steps -- confirm against the original file.
            with torch.no_grad():
                i2v_results = None
                # Visual encode (skipped when latents are pre-computed).
                if self.load_video_features:
                    latents = video
                    latents_lose = video_lose
                else:
                    if self.task == "t2v":
                        latents, _ = self.ae.encode(video)
                        latents_lose, _ = self.ae.encode(video_lose)
                    elif self.task == "i2v":
                        latents, i2v_results = self.ae.encode(video, **kwargs)
                        latents_lose, _ = self.ae.encode(video_lose, **kwargs)
                    else:
                        raise NotImplementedError(f"Task {self.task} is not Implemented!")

                # Text encode (skipped when text features are pre-computed).
                if self.load_text_features:
                    prompts = prompt_ids
                else:
                    prompts, prompt_mask = self.text_encoder.encode(prompt_ids, prompt_mask, **kwargs)

            # Noise the winning sample first, then reuse the same noise and timesteps
            # for the losing sample.
            noised_latents_win, noise, timesteps = self.diffusion.q_sample(
                latents, model_kwargs=kwargs, mask=video_mask)
            noised_latents_lose, _, _ = self.diffusion.q_sample(
                latents_lose, noise=noise, t=timesteps, model_kwargs=kwargs, mask=video_mask)

            # Stack win/lose along dim 0 so one predictor call handles both.
            noised_latents = torch.cat((noised_latents_win, noised_latents_lose), dim=0)
            noise = torch.cat((noise, noise), dim=0)
            timesteps = timesteps.repeat(2)
            latents = torch.cat((latents, latents_lose), dim=0)

            # The prompt is shared by both samples, so it is simply duplicated.
            if isinstance(prompts, list):
                prompt = [torch.cat((item, item), dim=0) for item in prompts]
                prompt_mask = [torch.cat((mask, mask), dim=0) for mask in prompt_mask]
            else:
                prompt = torch.cat((prompts, prompts), dim=0)
                prompt_mask = torch.cat((prompt_mask, prompt_mask), dim=0)

            if i2v_results is not None:
                for k, v in i2v_results.items():
                    kwargs[k] = torch.cat((v, v), dim=0)

            # On the first stage actor and reference consume exactly the same inputs.
            # NOTE(review): video_mask is passed through un-duplicated although the latents
            # are doubled along dim 0 -- confirm the predictor handles this.
            predictor_input_latent, predictor_timesteps, predictor_prompt = noised_latents, timesteps, prompt
            predictor_video_mask, predictor_prompt_mask = video_mask, prompt_mask
            predictor_input_latent_ref, predictor_timesteps_ref, predictor_prompt_ref = (
                noised_latents, timesteps, prompt)
            predictor_video_mask_ref, predictor_prompt_mask_ref = video_mask, prompt_mask
            kwargs_ref = kwargs
            score = kwargs["score"]
            score_lose = kwargs["score_lose"]
        else:
            if not hasattr(self.actor, "pipeline_set_prev_stage_tensor"):
                # Report the actor's class name; the attribute previously formatted here
                # (`self.actor_cls`) is never assigned on this class.
                raise ValueError(f"PP has not been implemented for {type(self.actor).__name__} yet. ")

            kwargs_ref = copy.deepcopy(kwargs)
            ori_keys = kwargs_ref.keys()
            predictor_input_list, training_loss_input_list, score_list = self.actor.pipeline_set_prev_stage_tensor(
                self.input_tensor, extra_kwargs=kwargs)

            # Keys that pipeline_set_prev_stage_tensor added to kwargs: tensors are split
            # into an actor half and a reference half, anything else is shared.
            new_keys = kwargs.keys()
            extra_keys = new_keys - ori_keys
            for key in extra_keys:
                if isinstance(kwargs[key], torch.Tensor):
                    extra_input_list = kwargs.pop(key)
                    kwargs[key], kwargs_ref[key] = torch.chunk(extra_input_list, 2, dim=0)
                else:
                    kwargs_ref[key] = kwargs[key]

            # Each received tensor is split in two along dim 0: first half for the actor,
            # second half for the reference.
            (predictor_input_latent_list, predictor_timesteps_list, predictor_prompt_list,
             _, predictor_prompt_mask_list) = predictor_input_list
            predictor_input_latent, predictor_input_latent_ref = torch.chunk(predictor_input_latent_list, 2, dim=0)
            predictor_timesteps, predictor_timesteps_ref = torch.chunk(predictor_timesteps_list, 2, dim=0)
            predictor_prompt, predictor_prompt_ref = torch.chunk(predictor_prompt_list, 2, dim=0)
            predictor_video_mask, predictor_video_mask_ref = None, None
            predictor_prompt_mask, predictor_prompt_mask_ref = torch.chunk(predictor_prompt_mask_list, 2, dim=0)
            latents, noised_latents, timesteps, noise, video_mask = training_loss_input_list
            score, score_lose = score_list

        # The reference model is frozen: run it without building an autograd graph.
        with torch.no_grad():
            refer_output = self.reference(
                predictor_input_latent_ref,
                timestep=predictor_timesteps_ref,
                prompt=predictor_prompt_ref,
                video_mask=predictor_video_mask_ref,
                prompt_mask=predictor_prompt_mask_ref,
                **kwargs_ref,
            )

        actor_output = self.actor(
            predictor_input_latent,
            timestep=predictor_timesteps,
            prompt=predictor_prompt,
            video_mask=predictor_video_mask,
            prompt_mask=predictor_prompt_mask,
            **kwargs,
        )

        if self.post_process:
            if isinstance(refer_output, tuple):
                refer_output = refer_output[0]
            if isinstance(actor_output, tuple):
                actor_output = actor_output[0]
            output = torch.cat((actor_output, refer_output), dim=0)
            return [output, score, score_lose, latents, noised_latents, timesteps, noise]

        # Intermediate stage: pair up actor/reference outputs element by element so the
        # next stage can split them again with torch.chunk.
        output = [
            torch.cat((actor_part, refer_output[index]), dim=0)
            for index, actor_part in enumerate(actor_output)
        ]
        output = output + [score, score_lose]
        return self.actor.pipeline_set_next_stage_tensor(
            input_list=[latents, noised_latents, timesteps, noise, video_mask],
            output_list=output,
            extra_kwargs=kwargs)

    def state_dict_for_save_checkpoint(self, prefix="", keep_vars=False):
        """Customized state_dict: return only the actor's state dict."""
        return self.actor.state_dict(prefix=prefix, keep_vars=keep_vars)

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True) -> None:
        """Customized load: load the same ``state_dict`` into both actor and reference.

        ``strict`` is accepted for interface compatibility only; both sub-models are
        always loaded non-strictly and key mismatches are logged instead of raised.

        Raises:
            TypeError: ``state_dict`` is not a ``Mapping``.
        """
        if not isinstance(state_dict, Mapping):
            raise TypeError(f"Expected state_dict to be dict-like, got {type(state_dict)}.")

        missing_keys, unexpected_keys = self.actor.load_state_dict(state_dict, False)
        # Only log when there is something to report (empty lists are not None).
        if missing_keys:
            logger.info("Actor missing keys in state_dict: %s.", missing_keys)
        if unexpected_keys:
            logger.info("Actor unexpected key(s) in state_dict: %s.", unexpected_keys)

        missing_keys_ref, unexpected_keys_ref = self.reference.load_state_dict(state_dict, False)
        if missing_keys_ref:
            logger.info("Reference missing keys in state_dict: %s.", missing_keys_ref)
        if unexpected_keys_ref:
            logger.info("Reference unexpected key(s) in state_dict: %s.", unexpected_keys_ref)
        return None