from logging import getLogger
import torch
from torch import nn
from megatron.core import mpu
from megatron.training import get_args, print_rank_0
from megatron.training.arguments import core_transformer_config_from_args
from mindspeed_mm.models.common.module import MultiModalModule
from mindspeed_mm.models.predictor.dits.lumina_mgpt2 import ChameleonForConditionalGeneration
# Module-level logger; Lumina.load_state_dict reports missing/unexpected
# checkpoint keys through it.
logger = getLogger(__name__)
class Lumina(MultiModalModule):
    """Training wrapper around ``ChameleonForConditionalGeneration``.

    The wrapper pads/truncates a batch of token lists, runs the wrapped model
    under NPU autocast, and optionally reports an auxiliary z-loss (mean of the
    squared log-sum-exp of the logits) next to the model's own loss.
    """

    def __init__(self, config):
        """Build the wrapped model and read the loss/precision settings.

        Args:
            config: Either a ``dict`` or an object exposing ``to_dict()``.
                The ``"predictor"`` entry is expanded into the keyword
                arguments of ``ChameleonForConditionalGeneration``.
                ``"z_loss_weight"`` (default ``1e-5``) and ``"dtype"``
                (default ``torch.bfloat16``) are optional.
        """
        super().__init__(config=config)
        args = get_args()
        # self.config is (re)assigned to the transformer config derived from
        # the Megatron arguments, not the multimodal config passed in.
        self.config = core_transformer_config_from_args(args)

        if not isinstance(config, dict):
            config = config.to_dict()

        self.model = ChameleonForConditionalGeneration(**config["predictor"])
        # Read with .get() rather than .pop(): when the caller hands in a plain
        # dict, popping would silently remove these keys from *their* object.
        self.z_loss_weight = config.get("z_loss_weight", 1e-5)
        self.dtype = config.get("dtype", torch.bfloat16)

    def forward(self, input_ids, labels, **kwargs):
        """Run one forward pass on a batch of variable-length token lists.

        Every sample is truncated to ``min(longest sample,
        self.model.max_position_embeddings)`` tokens and then right-padded to
        that length: ``input_ids`` with ``0`` and ``labels`` with ``-100``.

        Args:
            input_ids: Non-empty sequence of token-id lists (one per sample).
            labels: Sequence of label lists, parallel to ``input_ids``.
            **kwargs: Forwarded unchanged to the wrapped model.

        Returns:
            A tuple ``(loss, additional_loss_dict)``. ``loss`` is element 0 of
            the model output. ``additional_loss_dict`` holds
            ``{"z_loss": (z_loss, z_loss_weight)}`` when ``z_loss_weight > 0``
            and is empty otherwise.

        Raises:
            ValueError: If ``input_ids`` contains no samples.
        """
        if len(input_ids) == 0:
            # The original code failed here too (ValueError from max() on an
            # empty sequence); raise the same type with an actionable message.
            raise ValueError("input_ids must contain at least one sequence.")

        longest = max(len(sample) for sample in input_ids)
        max_tokens = min(longest, self.model.max_position_embeddings)
        device = self.model.device

        # Truncate first so that the padding arithmetic below never goes
        # negative for over-long samples.
        input_ids = [sample[:max_tokens] for sample in input_ids]
        labels = [label[:max_tokens] for label in labels]

        input_ids = [
            sample + [0] * (max_tokens - len(sample)) for sample in input_ids
        ]
        labels = [
            label + [-100] * (max_tokens - len(label)) for label in labels
        ]
        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
        labels = torch.tensor(labels, dtype=torch.int64, device=device)

        with torch.autocast("npu", dtype=self.dtype):
            result = self.model(input_ids=input_ids, labels=labels, **kwargs)
        loss = result[0]

        additional_loss_dict = {}
        if self.z_loss_weight > 0:
            logits: torch.Tensor = result[1]
            # Align logits at position t with the label at position t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Negative labels (the -100 padding above) are excluded.
            valid_mask = shift_labels >= 0
            # NOTE(review): if no label in the batch is valid, the mean over an
            # empty selection is NaN — confirm upstream data never produces
            # fully-masked batches.
            squared_lse = torch.logsumexp(shift_logits, dim=-1).pow(2)
            z_loss = squared_lse[valid_mask].mean()
            additional_loss_dict["z_loss"] = (z_loss, self.z_loss_weight)

        return loss, additional_loss_dict

    def state_dict_for_save_checkpoint(self, prefix="", keep_vars=False):
        """Return the wrapped model's state dict for checkpointing.

        Only ``self.model`` is serialized, so the saved keys carry no extra
        prefix for this wrapper beyond the ``prefix`` argument.
        """
        # print_rank_0 keeps the trace to a single line per job instead of one
        # line per rank.
        print_rank_0("lumina state dict for save checkpoint")
        return self.model.state_dict(prefix=prefix, keep_vars=keep_vars)

    def load_state_dict(self, state_dict, strict=False):
        """Load ``state_dict`` into the wrapped model, always non-strictly.

        Args:
            state_dict: Mapping of parameter names to tensors for
                ``self.model``.
            strict: Accepted for interface compatibility only; loading is
                always performed with ``strict=False`` and mismatches are
                logged instead of raised.

        Returns:
            None.
        """
        missing_keys, unexpected_keys = self.model.load_state_dict(
            state_dict, strict=False
        )
        # Both values are lists, so test for emptiness; an ``is not None``
        # check would log on every load even when nothing is wrong.
        if missing_keys:
            logger.info("Missing keys in state_dict: %s.", missing_keys)
        if unexpected_keys:
            logger.info("Unexpected key(s) in state_dict: %s.", unexpected_keys)
        return None

    def get_fsdp_wrap_module_list(self):
        """Return the sub-modules to wrap for FSDP, grouped under ``'model'``.

        The list contains every layer in ``self.model.model.layers``, followed
        by ``lm_head`` and ``embed_tokens``.
        """
        modules = [
            *self.model.model.layers,
            self.model.lm_head,
            self.model.model.embed_tokens,
        ]
        return [('model', modules)]