import logging
import types
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.tensor import DTensor
from mindspeed.fsdp.utils.str_match import module_name_match
logger = logging.getLogger(__name__)
def get_chunkloss_forward_fn():
def chunkloss_forward(self, hidden_states: torch.Tensor, loss_func: callable, labels=None):
if isinstance(self.weight, DTensor):
w = self.weight.to_local()
if self.bias is not None:
if not isinstance(self.bias, DTensor):
raise TypeError(
f"Expected bias to be a DTensor when weight is a DTensor, "
f"but got bias of type {type(self.bias)}."
)
b = self.bias.to_local()
else:
b = None
else:
w = self.weight
b = self.bias
loss = loss_func(hidden_states, w, b, labels)
return loss
return chunkloss_forward
def apply_chunkloss_module(module):
chunkloss_forward = get_chunkloss_forward_fn()
module.forward = types.MethodType(chunkloss_forward, module)
def get_chunkloss_module(modules, plan):
matched_modules = []
for name, module in modules.named_modules():
if module_name_match(plan.apply_module, name):
matched_modules.append(module)
if len(matched_modules) == 0:
raise RuntimeError(f'[CHUNKLOSS] No module named {plan}.')
elif len(matched_modules) > 1:
raise RuntimeError(f'[CHUNKLOSS] gets more than one module named {plan}.')
chunkloss_module = matched_modules[0]
if not isinstance(chunkloss_module, torch.nn.Linear):
raise ValueError(f"Chunk loss configuration error for module '{chunkloss_module}': "
f"expected torch.nn.Linear, got {repr(type(chunkloss_module))}. ")
return chunkloss_module