from functools import wraps
from typing import Optional, Callable
import torch
from torch.nn.parameter import Parameter
from megatron.core import parallel_state, ModelParallelConfig
from megatron.core.tensor_parallel import (
copy_to_tensor_model_parallel_region,
gather_from_tensor_model_parallel_region
)
from megatron.core.tensor_parallel.utils import VocabUtility
from megatron.core.tensor_parallel.layers import (
linear_with_frozen_weight,
linear_with_grad_accumulation_and_async_allreduce,
ColumnParallelLinear,
_initialize_affine_weight_cpu,
_initialize_affine_weight_gpu,
VocabParallelEmbedding,
)
from megatron.legacy.model.fused_layer_norm import MixedFusedLayerNorm
from megatron.training import get_args
import torch.nn.functional as F
from megatron.core.tensor_parallel.mappings import (
reduce_scatter_to_sequence_parallel_region,
reduce_from_tensor_model_parallel_region,
)
from mindspeed_llm.training.utils import get_actual_seq_len_list, set_actual_seq_len_list
try:
import bitsandbytes as bnb
except ImportError:
bnb = None
def vocab_embedding_init_func(
self,
num_embeddings: int,
embedding_dim: int,
*,
init_method: Callable,
config: ModelParallelConfig,
reduce_scatter_embeddings: bool = False,
skip_weight_param_allocation: bool = False,
):
"""Patch for legacy norm."""
super(VocabParallelEmbedding, self).__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.reduce_scatter_embeddings = reduce_scatter_embeddings
self.tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
(
self.vocab_start_index,
self.vocab_end_index,
) = VocabUtility.vocab_range_from_global_vocab_size(
self.num_embeddings, parallel_state.get_tensor_model_parallel_rank(), self.tensor_model_parallel_size
)
self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index
self.deterministic_mode = config.deterministic_mode
if not skip_weight_param_allocation:
if config.use_cpu_initialization:
self.weight = Parameter(
torch.empty(
self.num_embeddings_per_partition, self.embedding_dim, dtype=config.params_dtype
)
)
if config.perform_initialization:
_initialize_affine_weight_cpu(
self.weight,
self.num_embeddings,
self.embedding_dim,
self.num_embeddings_per_partition,
0,
init_method,
params_dtype=config.params_dtype,
)
else:
self.weight = Parameter(
torch.empty(
self.num_embeddings_per_partition,
self.embedding_dim,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
)
if config.perform_initialization:
_initialize_affine_weight_gpu(self.weight, init_method, partition_dim=0, stride=1)
else:
self.weight = None
args = get_args()
if parallel_state.is_pipeline_first_stage() and args.embed_layernorm:
norm = MixedFusedLayerNorm(args.hidden_size)
self.norm = norm
def vocab_parallel_embedding_forward(self, input_, weight=None):
if weight is None:
if self.weight is None:
raise RuntimeError(
"weight was not supplied to VocabParallelEmbedding forward pass "
"and skip_weight_param_allocation is True."
)
weight = self.weight
if self.tensor_model_parallel_size > 1:
input_mask = (input_ < self.vocab_start_index) | \
(input_ >= self.vocab_end_index)
masked_input = input_.clone() - self.vocab_start_index
masked_input *= ~input_mask
else:
masked_input = input_
output_parallel = F.embedding(masked_input, weight)
if self.tensor_model_parallel_size > 1:
output_parallel *= ~input_mask[..., None]
if self.reduce_scatter_embeddings:
output_parallel = output_parallel.transpose(0, 1).contiguous()
output = reduce_scatter_to_sequence_parallel_region(output_parallel)
else:
output = reduce_from_tensor_model_parallel_region(output_parallel)
args_ = get_args()
if hasattr(self, 'norm'):
output = self.norm(output)
return output * args_.embedding_multiplier_scale if args_.embedding_multiplier_scale else output
class SegmentedColumnParallelLinear(ColumnParallelLinear):
def __int__(self):
super(ColumnParallelLinear, self).__init__()
def forward(self, input_: torch.Tensor, weight: Optional[torch.Tensor] = None):
"""Forward of ColumnParallelLinear
Args:
input_: 3D tensor whose order of dimension is [sequence, batch, hidden]
weight (optional): weight tensor to use, compulsory when
skip_weight_param_allocation is True.
Returns:
- output
- bias
"""
args_ = get_args()
if weight is None:
if self.weight is None:
raise RuntimeError(
"weight was not supplied to ColumnParallelLinear forward pass "
"and skip_weight_param_allocation is True."
)
weight = self.weight
else:
expected_shape = (self.output_size_per_partition, self.input_size)
if weight.shape != expected_shape:
raise RuntimeError(
f"supplied weight's shape is {tuple(weight.shape)}, "
f"not {expected_shape} as expected"
)
if self.config._cpu_offloading_context is not None:
if self.config._cpu_offloading_context.inside_context:
if self.config.cpu_offloading:
raise ValueError("CPU Offloading cannot be enabled while using non-TE modules")
bias = self.bias if not self.skip_bias_add else None
if (
self.allreduce_dgrad
or self.sequence_parallel
or self.explicit_expert_comm
or self.disable_grad_reduce
):
input_parallel = input_
else:
input_parallel = copy_to_tensor_model_parallel_region(input_)
if self.config.defer_embedding_wgrad_compute:
self.embedding_activation_buffer.append(input_parallel)
if not weight.requires_grad:
self._forward_impl = linear_with_frozen_weight
else:
self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
allreduce_dgrad = False if self.explicit_expert_comm else self.allreduce_dgrad
weight = torch.split(weight, weight.shape[0] // args_.output_layer_slice_num, dim=0)
output_parallel = []
for i in range(args_.output_layer_slice_num):
output_parallel.append(self._forward_impl(
input=input_parallel,
weight=weight[i],
bias=bias,
gradient_accumulation_fusion=self.gradient_accumulation_fusion,
async_grad_allreduce=allreduce_dgrad,
sequence_parallel=False if self.explicit_expert_comm else self.sequence_parallel,
grad_output_buffer=self.grad_output_buffer
if self.config.defer_embedding_wgrad_compute
else None,
allreduce_dgrad=allreduce_dgrad,
))
output_parallel = torch.cat(output_parallel, dim=2)
if self.gather_output:
if self.sequence_parallel:
raise ValueError("Cannot gather output when sequence parallel is enabled")
output = gather_from_tensor_model_parallel_region(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
def checkpoint_forward_wrapper(fn):
"""
Fixes TypeError caused by Megatron CheckpointFunction not supporting list[Tensor] inputs.
Background:
In hybrid_cp_algo + attention_mask_type=general mode, attention_mask is converted to
list[Tensor], triggering CheckpointFunction native limitation.
Solution:
Flatten the list into independent Tensors during forward to bypass the check; restore
the structure via closure during recomputation to ensure logical consistency.
"""
def wrapper(ctx, run_function, distribute_saved_activations, *args):
ctx.actual_seq_len_list = get_actual_seq_len_list()
flat_args = []
arg_structure = []
for arg in args:
if isinstance(arg, (list, tuple)) and len(arg) > 0 and torch.is_tensor(arg[0]):
arg_structure.append(("seq", len(arg), type(arg)))
flat_args.extend(arg)
else:
arg_structure.append(("item", 1, None))
flat_args.append(arg)
ctx.custom_arg_structure = arg_structure
def wrapped_run_func(*rebuilt_flat_args):
original_args = []
idx = 0
for stype, length, seq_type in ctx.custom_arg_structure:
if stype == "seq":
seq = rebuilt_flat_args[idx : idx + length]
original_args.append(seq_type(seq))
else:
original_args.append(rebuilt_flat_args[idx])
idx += length
return run_function(*original_args)
return fn(ctx, wrapped_run_func, distribute_saved_activations, *flat_args)
return wrapper
def checkpoint_backward_wrapper(fn):
def wrapper(ctx, *args):
set_actual_seq_len_list(ctx.actual_seq_len_list)
return fn(ctx, *args)
return wrapper
class LinearNoTP(torch.nn.Linear):
def __init__(
self,
input_size,
output_size,
config,
**kwargs,
):
super().__init__(
input_size,
output_size,
bias=kwargs.get('bias', True),
dtype=config.params_dtype,
)
self.config = config
current_seed = torch.random.initial_seed()
torch.manual_seed(123)
torch.nn.init.xavier_uniform_(self.weight)
torch.random.manual_seed(current_seed)
setattr(self.weight, 'sequence_parallel', config.sequence_parallel)
setattr(self.weight, 'all_reduce', True)
self._register_load_state_dict_pre_hook(
lambda state_dict, prefix, *args, **kwargs: state_dict.pop(
f'{prefix}_extra_state', None)
)
def forward(self, input_):
if hasattr(self.weight, "quant_state"):
output = bnb.matmul_4bit(input_, self.weight.t(), self.weight.quant_state, bias=self.bias)
elif get_args().fp8:
from mindspeed.te.pytorch.fp8.recipes import matmul_fp8
output = matmul_fp8(input_, self.weight)
if self.bias is not None:
output = output + self.bias
else:
output = torch.matmul(input_, self.weight.t())
if self.bias is not None:
output = output + self.bias
return output
def _save_to_state_dict(self, destination, prefix, keep_vars):
args = get_args()
if args.qlora_save_dequantize and getattr(self.weight, "quant_state", None) is not None:
self.weight = torch.nn.Parameter(bnb.functional.dequantize_4bit(self.weight.data, self.weight.quant_state))
super()._save_to_state_dict(destination, prefix, keep_vars)
if getattr(self.weight, "quant_state", None) is not None:
for k, v in self.weight.quant_state.as_dict(packed=True).items():
destination[prefix + "weight." + k] = v if keep_vars else v.detach()
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
if any(['bitsandbytes' in i for i in state_dict.keys()]):
qs_dict = {}
for k, v in state_dict.items():
key = k.replace(prefix, "")
if key != '_extra_state':
qs_dict[key] = v
self.weight = bnb.nn.Params4bit.from_prequantized(
data=qs_dict.get('weight'),
quantized_stats={key.replace('weight.', ''): qs_dict[key] for key in qs_dict if key != 'weight'},
requires_grad=False,
device='npu')
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)