from typing import Dict, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoConfig
from megatron.training import get_args, print_rank_0
from megatron.training.arguments import core_transformer_config_from_args
from megatron.core import mpu

from mindspeed_mm.models.common.module import MultiModalModule
from mindspeed_mm.models.common.chunkloss import chunk_loss, calculate_lm_loss
from mindspeed_mm.models.common.communications import cal_split_sizes, split_forward_gather_backward
from mindspeed_mm.models.transformers.modelhub import ModelHub


class TransformersModel(MultiModalModule):
    """Training wrapper around a Hugging Face model class resolved by ``ModelHub``.

    The wrapped model is stored in ``self.model``. ``forward`` runs it and
    returns a dict holding the language-model loss together with either the
    loss mask or the per-sample token counts, depending on the loss mode and
    on whether context parallelism is active.
    """

    def __init__(self, config) -> None:
        """Instantiate the wrapped model and put it in training mode.

        Args:
            config: Model configuration. It is forwarded to the base class,
                to ``ModelHub.build`` and, when the model defines a callable
                ``freeze``, to ``self.model.freeze``.
        """
        super().__init__(config=config)
        args = get_args()
        hf_path = args.mm.model.init_from_hf_path
        trust_remote_code = args.trust_remote_code

        # ``self.config`` holds the Megatron core transformer config derived
        # from the global args, not the ``config`` argument.
        self.config = core_transformer_config_from_args(args)
        self.transformer_config = AutoConfig.from_pretrained(hf_path, trust_remote_code=trust_remote_code)
        model_cls = ModelHub.build(config, self.transformer_config)
        self._set_loss_cfg(args)

        # Optional hook: a model class may rewrite the HF config before use.
        if callable(getattr(model_cls, 'overwrite_transformer_config', None)):
            self.transformer_config = model_cls.overwrite_transformer_config(self.transformer_config)

        if args.init_model_with_meta_device:
            # Build from the config only (no checkpoint is read here) and
            # cast the parameters to float32.
            self.model = model_cls._from_config(self.transformer_config).float()
            # Clear the HF "already initialized" marker on every submodule.
            # NOTE(review): presumably so a later weight initialization is not
            # skipped for these modules - confirm against the loading code.
            for m in self.model.modules():
                if getattr(m, "_is_hf_initialized", False):
                    m._is_hf_initialized = False
        else:
            self.model = model_cls.from_pretrained(
                hf_path,
                config=self.transformer_config,
                dtype=torch.float32,
                low_cpu_mem_usage=True,
                device_map="cpu",
                trust_remote_code=trust_remote_code
            )
            print_rank_0("> load model successfully")

        self.model.train()
        # Optional hook: a model may freeze part of its parameters.
        if callable(getattr(self.model, 'freeze', None)):
            self.model.freeze(config)
        self.model.use_cache = False

    def compute_language_model_loss_cp(
        self, logits: Tensor, labels: Tensor, ignore_index: int = -100
    ) -> Tuple[Tensor, Tensor]:
        """Compute the next-token loss when context parallelism is enabled.

        Args:
            logits: Model logits. They are permuted with ``(0, 2, 1)`` so the
                class dimension sits at index 1 as ``F.cross_entropy``
                requires; assumes ``(batch, local_seq, vocab)`` - TODO confirm.
            labels: Unshifted target ids for the full sequence.
            ignore_index: Label value excluded from the loss.

        Returns:
            ``(loss, token_nums)``:

            * ``calculate_per_token_loss``: summed loss and the total number
              of valid targets.
            * ``calculate_per_sample_loss``: mean over samples of
              ``local loss sum / valid targets of the sample`` and the mean
              number of valid targets per sample (as a float tensor).
            * ``calculate_token_loss``: summed loss divided by the total
              number of valid targets, and that total.

        Raises:
            NotImplementedError: If ``args.context_parallel_algo`` is not
                ``"ulysses_cp_algo"`` or none of the loss flags above is set.
        """
        args = get_args()
        logits = logits.permute(0, 2, 1).contiguous()
        if args.context_parallel_algo != "ulysses_cp_algo":
            raise NotImplementedError("Only support ulysses_cp_algo now")

        # Pad one ignored label on the right and drop the first label, so
        # position t carries the target for token t + 1.
        labels = F.pad(labels, (0, 1), value=ignore_index)
        shift_labels = labels[..., 1:].contiguous()
        # Valid targets are counted on the full sequence, before the split.
        token_nums = (shift_labels > -1).sum(dim=1)
        split_sizes = cal_split_sizes(shift_labels.shape[-1], mpu.get_context_parallel_world_size())
        # NOTE(review): presumably keeps only this rank's slice of the labels
        # so they line up with the local logits - confirm in communications.
        shift_labels = split_forward_gather_backward(
            shift_labels,
            mpu.get_context_parallel_group(),
            dim=1,
            grad_scale="down",
            split_sizes=split_sizes
        )

        loss = F.cross_entropy(logits, shift_labels, reduction='none', ignore_index=ignore_index)
        loss = loss * (shift_labels > -1)

        if args.calculate_per_token_loss:
            return loss.sum(), token_nums.sum()
        if args.calculate_per_sample_loss:
            # NOTE(review): a sample without any valid target divides by zero.
            batch_mean_loss = loss.sum(dim=1) / token_nums
            # ``token_nums`` is an integer tensor and ``Tensor.mean`` rejects
            # integer dtypes, so cast before averaging.
            return batch_mean_loss.mean(), token_nums.float().mean()
        if args.calculate_token_loss:
            total_tokens = torch.sum(token_nums)
            return loss.sum() / total_tokens, total_tokens
        raise NotImplementedError("Unsupported loss type now")

    def compute_language_model_loss(
        self, logits: Tensor, labels: Tensor, ignore_index: int = -100, **kwargs
    ) -> Tuple[Tensor, Tensor]:
        """Compute the next-token loss without context parallelism.

        The reduction is chosen by the first flag set on the global args, in
        this order: ``calculate_per_sample_loss``, ``calculate_per_token_loss``,
        ``calculate_token_loss``, ``calculate_square_loss``; otherwise the
        default mean of ``F.cross_entropy`` is used.

        Args:
            logits: Model logits with the vocabulary on the last dimension;
                assumes ``(batch, seq, vocab)`` - TODO confirm.
            labels: Unshifted target ids.
            ignore_index: Label value excluded from the loss.
            **kwargs: Accepted and ignored, so ``forward`` can pass its extra
                keyword arguments straight through.

        Returns:
            ``(loss, loss_mask)`` where ``loss_mask`` is ``shift_labels > -1``
            in the un-flattened label shape.
        """
        args = get_args()
        # Pad one ignored label on the right and drop the first label, so
        # position t carries the target for token t + 1.
        labels = F.pad(labels, (0, 1), value=ignore_index)
        shift_labels = labels[..., 1:].contiguous()
        loss_mask = shift_labels > -1

        if args.calculate_per_sample_loss:
            # Keep the batch dimension: cross_entropy wants classes at dim 1.
            token_loss = F.cross_entropy(
                logits.permute(0, 2, 1).contiguous(),
                shift_labels,
                reduction='none',
                ignore_index=ignore_index,
            )
            # NOTE(review): a sample without any valid target yields 0 / 0.
            batch_mean_loss = token_loss.sum(dim=1) / loss_mask.sum(dim=1)
            return batch_mean_loss.mean(), loss_mask

        # Every remaining mode works on flattened (token, vocab) logits.
        flat_labels = shift_labels.view(-1)
        flat_logits = logits.view(-1, logits.shape[-1])

        if args.calculate_per_token_loss:
            token_loss = F.cross_entropy(flat_logits, flat_labels, reduction='none', ignore_index=ignore_index)
            loss = torch.sum(token_loss.view(-1) * loss_mask.view(-1))
        elif args.calculate_token_loss:
            token_loss = F.cross_entropy(flat_logits, flat_labels, reduction='none', ignore_index=ignore_index)
            # Weight 1 for every valid target, 0 for ignored positions.
            loss_weight = (labels != ignore_index).float()
            loss = self._weighted_mean_loss(token_loss, loss_weight, flat_logits.device)
        elif args.calculate_square_loss:
            token_loss = F.cross_entropy(flat_logits, flat_labels, reduction='none', ignore_index=ignore_index)
            # Each valid target is weighted by 1 / sqrt(valid labels of its
            # sample). The count is taken on the padded, unshifted labels, so
            # it also includes position 0.
            valid = labels != ignore_index
            inv_sqrt_count = 1 / valid.sum(dim=-1).float().sqrt()
            loss_weight = torch.where(valid, inv_sqrt_count.unsqueeze(1), 0.0)
            loss = self._weighted_mean_loss(token_loss, loss_weight, flat_logits.device)
        else:
            loss = F.cross_entropy(flat_logits, flat_labels, ignore_index=ignore_index)
        return loss, loss_mask

    @staticmethod
    def _weighted_mean_loss(token_loss: Tensor, loss_weight: Tensor, device) -> Tensor:
        """Return ``sum(token_loss * w) / avg_over_ranks(sum(w))`` for shifted weights ``w``.

        ``loss_weight`` is aligned with the padded, unshifted labels; its
        first column is dropped so it lines up with the shifted targets.
        """
        shift_weights = loss_weight[..., 1:].contiguous().view(-1).to(device)
        shift_weights_sum = shift_weights.sum()
        # The denominator is averaged in place over the default process group.
        torch.distributed.all_reduce(shift_weights_sum, op=torch.distributed.ReduceOp.AVG)
        return (token_loss * shift_weights).sum() / shift_weights_sum

    def forward(
        self,
        input_ids: torch.Tensor,
        pixel_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        *args, **kwargs
    ) -> Dict[str, Tensor]:
        """Run the wrapped model and compute the training loss.

        Args:
            input_ids, pixel_values, image_grid_thw, attention_mask,
            position_ids, cache_position: Forwarded unchanged to the model.
            labels: Unshifted target ids. Required, because every code path
                computes a loss from them.
            *args: Ignored.
            **kwargs: Forwarded to the model and, in the default non-CP path,
                also to ``compute_language_model_loss``.

        Returns:
            A dict with ``"loss"`` plus ``"loss_mask"`` (chunk mode and the
            default non-CP path) or ``"token_nums"`` (default CP path).

        Raises:
            ValueError: If ``labels`` is ``None``.
        """
        if labels is None:
            # Fail before the expensive model call instead of inside F.pad.
            raise ValueError("`labels` must be provided: forward() always computes the loss.")

        model_inputs = {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
            "image_grid_thw": image_grid_thw,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "cache_position": cache_position,
            "use_cache": False,
        }

        if self.loss_compute_mode == "chunk":
            # The model computes the loss itself through the loss context.
            loss_ctx, loss_mask = self.build_loss_ctx(labels, chunk_size=self.loss_chunk_size)
            outputs = self.model(**model_inputs, loss_ctx=loss_ctx, **kwargs)
            return {"loss": outputs.loss, "loss_mask": loss_mask}

        outputs = self.model(**model_inputs, **kwargs)
        logits = outputs.logits.contiguous().float()
        if mpu.get_context_parallel_world_size() > 1:
            loss, token_nums = self.compute_language_model_loss_cp(logits, labels)
            return {"loss": loss, "token_nums": token_nums}

        loss, loss_mask = self.compute_language_model_loss(logits, labels, **kwargs)
        return {"loss": loss, "loss_mask": loss_mask}

    def fully_shard(
        self,
        process_group,
        fsdp2_config_path,
        **kwargs
    ):
        """Delegate sharding to the wrapped model when it supports it.

        Returns:
            The result of ``self.model.fully_shard(...)`` when the model
            defines a callable ``fully_shard``, otherwise ``False``.
        """
        model_fully_shard = getattr(self.model, 'fully_shard', None)
        if not callable(model_fully_shard):
            return False
        return model_fully_shard(
            process_group=process_group,
            fsdp2_config_path=fsdp2_config_path,
            **kwargs
        )

    def build_loss_ctx(
        self,
        labels,
        ignore_index=-100,
        chunk_size=1024,
    ):
        """Build the chunked-loss callable handed to the model in chunk mode.

        Args:
            labels: Unshifted target ids.
            ignore_index: Label value excluded from the loss.
            chunk_size: Chunk length used to split the shifted labels along
                dim 1; also forwarded to ``chunk_loss``.

        Returns:
            ``(loss_ctx, loss_mask)``. ``loss_ctx(hidden_states, head_weight,
            head_bias)`` calls ``chunk_loss`` with ``calculate_lm_loss`` and
            one kwargs dict per label chunk. ``loss_mask`` is
            ``shift_labels > -1`` computed on the full sequence.

        Raises:
            NotImplementedError: For ``calculate_token_loss``,
                ``calculate_square_loss``, or a context-parallel algorithm
                other than ``"ulysses_cp_algo"``.
        """
        # Pad one ignored label on the right and drop the first label, so
        # position t carries the target for token t + 1.
        labels = F.pad(labels, (0, 1), value=ignore_index)
        shift_labels = labels[..., 1:].contiguous()
        loss_mask = shift_labels > -1

        args = get_args()
        # NOTE(review): ``alpha`` is consumed by ``calculate_lm_loss``;
        # presumably the normalization divisor - confirm in chunkloss.
        if args.calculate_per_sample_loss:
            alpha = loss_mask.sum(1) * loss_mask.shape[0]
            reduction = "none"
        elif args.calculate_per_token_loss:
            alpha = torch.tensor(1)
            reduction = "sum"
        elif args.calculate_token_loss:
            raise NotImplementedError("Chunk loss not support token_loss now")
        elif args.calculate_square_loss:
            raise NotImplementedError("Chunk loss not support square_loss now")
        else:
            alpha = loss_mask.sum()
            reduction = "sum"

        if mpu.get_context_parallel_world_size() > 1:
            if args.context_parallel_algo != "ulysses_cp_algo":
                raise NotImplementedError("Only support ulysses_cp_algo now")
            # ``alpha`` and ``loss_mask`` above were taken on the full
            # sequence; only the labels are split across CP ranks.
            split_gather_sizes = cal_split_sizes(shift_labels.shape[1], mpu.get_context_parallel_world_size())
            shift_labels = split_forward_gather_backward(
                shift_labels,
                mpu.get_context_parallel_group(),
                dim=-1,
                grad_scale="down",
                split_sizes=split_gather_sizes
            )

        loss_ctx_kwargs = [
            {
                "shift_labels": label_chunk,
                "ignore_index": ignore_index,
                "reduction": reduction,
                "alpha": alpha,
            }
            for label_chunk in torch.split(shift_labels, chunk_size, dim=1)
        ]

        def loss_ctx(hidden_states, head_weight, head_bias):
            """Compute the chunked LM loss for the given hidden states and head."""
            return chunk_loss(
                hidden_states,
                head_weight,
                head_bias,
                loss_forward=calculate_lm_loss,
                loss_kwargs_chunks=loss_ctx_kwargs,
                chunk_size=chunk_size
            )

        return loss_ctx, loss_mask

    def _set_loss_cfg(self, args):
        """Read ``args.mm.model.loss_cfg`` into the loss settings of this module.

        Sets ``self.loss_compute_mode`` (``"default"`` unless configured) and
        ``self.loss_chunk_size`` (1024 unless chunk mode configures another
        ``chunk_size``).

        Raises:
            NotImplementedError: If ``compute_mode`` is neither ``"default"``
                nor ``"chunk"``.
        """
        loss_cfg = getattr(args.mm.model, "loss_cfg", None)
        self.loss_compute_mode = "default"
        self.loss_chunk_size = 1024
        if loss_cfg is not None:
            self.loss_compute_mode = getattr(loss_cfg, "compute_mode", "default")

        if self.loss_compute_mode == "chunk":
            self.loss_chunk_size = getattr(loss_cfg, "chunk_size", 1024)
        elif self.loss_compute_mode != "default":
            raise NotImplementedError(f"Unrecognized loss_compute_mode: {self.loss_compute_mode}.")