import os
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch_npu
from ..constants import FP8BwdTensorIdx, FP8FwdTensorIdx, ParallelMode, TensorUsage, dist_group_type
from ..ops.basic.npu_activation import (
ACTIVATION_FWD,
ACTIVATION_BWD,
GLU_VARIANTS,
)
from ..distributed import (
DummyHandle,
_fsdp_gather_tensors,
_fsdp_scatter_tensors,
gather_along_dim,
get_distributed_world_size,
in_fp8_activation_recompute_phase,
is_fp8_activation_recompute_enabled,
reduce_scatter_along_dim,
)
from ..jit import no_torch_dynamo
from ..ops.fused.overlap import CommOverlapOps
from ..ops.gemm import matmul_add
from ..quantization import FP8GlobalStateManager
from ..quantized_tensor import QuantizedTensorStorage, Quantizer, prepare_for_saving, restore_from_saved
from ..tensor.utils import clear_columnwise_cache
from ..utils import cast_if_needed, divide, get_default_init_method, init_method_constant, requires_grad
from ._common import WeightGradStore
from .base import TransformerEngineBaseModule
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
"tensor_model_parallel": False,
"partition_dim": -1,
"partition_stride": 1,
}
def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
assert not hasattr(tensor, attribute)
setattr(tensor, "tensor_model_parallel", is_parallel)
setattr(tensor, "partition_dim", dim)
setattr(tensor, "partition_stride", stride)
class _LayerNormMLPNonTensorArgs(NamedTuple):
    """Non-tensor configuration passed as the last argument of ``_LayerNormMLP``.

    Everything that is not a differentiable tensor is bundled into this one
    tuple, so ``_LayerNormMLP.backward`` returns a single ``None`` for it.
    """

    # None disables the named weight-workspace cache (see fp8_forward); True/False
    # marks whether this is the first microbatch (controls workspace updates).
    is_first_microbatch: Optional[bool]
    # Selects fp8_forward / fp8_backward instead of the high-precision path.
    fp8: bool
    # Queues weight-gradient GEMMs when wgrad_store.delay_wgrad_compute() is true.
    wgrad_store: WeightGradStore
    # Epsilon passed to the normalization.
    eps: float
    # Per-GEMM quantizers; all None when FP8 is disabled (_get_quantizers returns
    # [None] * 12 in that case).
    fc1_input_quantizer: Optional[Quantizer]
    fc1_weight_quantizer: Optional[Quantizer]
    fc1_output_quantizer: Optional[Quantizer]
    fc1_grad_input_quantizer: Optional[Quantizer]
    fc1_grad_weight_quantizer: Optional[Quantizer]
    fc1_grad_output_quantizer: Optional[Quantizer]
    fc2_input_quantizer: Optional[Quantizer]
    fc2_weight_quantizer: Optional[Quantizer]
    fc2_output_quantizer: Optional[Quantizer]
    fc2_grad_input_quantizer: Optional[Quantizer]
    fc2_grad_weight_quantizer: Optional[Quantizer]
    fc2_grad_output_quantizer: Optional[Quantizer]
    # Accumulate weight gradients into ``weight.main_grad`` instead of returning them.
    fuse_wgrad_accumulation: bool
    # Tensor-parallel process group and its size.
    tp_group: Optional[dist_group_type]
    tp_size: int
    # Activations are sharded along the token dim across the TP group.
    sequence_parallel: bool
    # Dtype inputs are cast to and GEMM outputs are produced in.
    activation_dtype: torch.dtype
    # tp_size > 1.
    tensor_parallel: bool
    # Gates all tensor-parallel communication in forward/backward.
    set_parallel_mode: bool
    # Grad mode captured by the module's forward; ctx is only populated when True.
    is_grad_enabled: bool
    # Also return the normalized input; optionally all-gathered over the TP group.
    return_layernorm_output: bool
    return_layernorm_output_gathered: bool
    # When True the parameter stores (gamma - 1); the effective scale is 1 + ln_weight.
    zero_centered_gamma: bool
    # "LayerNorm" or "RMSNorm".
    normalization: str
    # Key into ACTIVATION_FWD / ACTIVATION_BWD, plus extra args (clamped_swiglu only).
    activation: str
    activation_params: tuple
    # FP8 path only: route TP communication through the quantized tensor's fused
    # allgather_matmul / matmul_reduce_scatter methods.
    overlap_ag_fprop: bool
    overlap_rs_fprop: bool
    overlap_ag_dgrad: bool
    overlap_rs_dgrad: bool
    # Owning module; provides get_weight_workspace() and fp8_meta.
    module: "LayerNormMLP"
    # Forwarded to module.get_weight_workspace(skip_update_flag=...).
    skip_fp8_weight_update: Optional[bool]
    # Enables quantized reduce-scatter of the FC2 output with the MXFP8 recipe.
    fp8_output: bool
    # FSDP process group used to shard saved activations; None disables sharding.
    fsdp_group: Optional[Any]
    is_fsdp2: bool
def _apply_norm(inputmat, ln_weight, ln_bias, eps, normalization, zero_centered_gamma, activation_dtype=None):
if normalization == "LayerNorm":
if zero_centered_gamma:
gamma = 1.0 + ln_weight
else:
gamma = ln_weight
if activation_dtype is not None:
gamma = cast_if_needed(gamma, activation_dtype)
if ln_bias is not None:
ln_bias = cast_if_needed(ln_bias, activation_dtype)
ln_out = F.layer_norm(inputmat, [inputmat.shape[-1]], weight=gamma, bias=ln_bias, eps=eps)
mu = inputmat.mean(dim=-1, keepdim=True)
rsigma = 1.0 / torch.sqrt((inputmat - mu).pow(2).mean(dim=-1, keepdim=True) + eps)
elif normalization == "RMSNorm":
if zero_centered_gamma:
gamma = 1.0 + ln_weight
else:
gamma = ln_weight
if activation_dtype is not None:
gamma = cast_if_needed(gamma, activation_dtype)
ln_out, rrsigma = torch_npu.npu_rms_norm(inputmat, gamma, epsilon=eps)
mu = None
rsigma = rrsigma
else:
raise ValueError(f"Unknown normalization: {normalization}")
return ln_out, mu, rsigma
def _norm_bwd(dgrad, inputmat, ln_weight, mu, rsigma, zero_centered_gamma, normalization):
n = inputmat.shape[-1]
if normalization == "LayerNorm":
x_hat = (inputmat - mu) * rsigma
if zero_centered_gamma:
dx_hat = dgrad * (1.0 + ln_weight)
else:
dx_hat = dgrad * ln_weight
dvar = (dx_hat * (inputmat - mu) * (-0.5) * rsigma.pow(3)).sum(dim=-1, keepdim=True)
dmean = (-dx_hat * rsigma).sum(dim=-1, keepdim=True) + dvar * (-2.0 / n) * (inputmat - mu).sum(dim=-1, keepdim=True)
dx = dx_hat * rsigma + dvar * 2.0 / n * (inputmat - mu) + dmean / n
dgamma = (dgrad * x_hat).sum(dim=tuple(range(len(dgrad.shape) - 1)))
dbeta = dgrad.sum(dim=tuple(range(len(dgrad.shape) - 1)))
elif normalization == "RMSNorm":
rrms = rsigma
x_hat = inputmat * rrms
if zero_centered_gamma:
dx_hat = dgrad * (1.0 + ln_weight)
else:
dx_hat = dgrad * ln_weight
dvar = (dx_hat * inputmat).sum(dim=-1, keepdim=True) * (-0.5) * rrms.pow(3)
dx = dx_hat * rrms + dvar * 2.0 * inputmat / n
dgamma = (dgrad * x_hat).sum(dim=tuple(range(len(dgrad.shape) - 1)))
dbeta = None
else:
raise ValueError(f"Unknown normalization: {normalization}")
return dx, dgamma, dbeta
class _LayerNormMLP(torch.autograd.Function):
    """Autograd function fusing normalization, FC1, activation and FC2.

    ``forward``/``backward`` implement the high-precision path and delegate to
    ``fp8_forward``/``fp8_backward`` when ``args.fp8`` is set. All non-tensor
    configuration arrives in one ``_LayerNormMLPNonTensorArgs`` tuple, which is
    why both backward paths return ``None`` as their last gradient.

    Tensor-parallel communication uses the helpers from ``..distributed``. Only
    the FP8 path can fuse communication into GEMMs (``overlap_*`` flags, via the
    quantized tensor's ``allgather_matmul`` / ``matmul_reduce_scatter``). The
    high-precision path therefore always communicates explicitly.
    """

    # ----------------------------------------------------------------- helpers

    @staticmethod
    def _apply_activation(fc1_out, args):
        """Evaluate the configured activation on ``fc1_out``.

        ``args.activation_params`` are forwarded only for ``clamped_swiglu``.
        """
        act_fn = ACTIVATION_FWD[args.activation]
        if args.activation == "clamped_swiglu" and args.activation_params:
            return act_fn(fc1_out, *args.activation_params)
        return act_fn(fc1_out)

    @staticmethod
    def _activation_grad(fc1_out, grad, args):
        """Backpropagate ``grad`` through the activation evaluated at ``fc1_out``."""
        act_bwd_fn = ACTIVATION_BWD[args.activation]
        if args.activation == "clamped_swiglu" and args.activation_params:
            return act_bwd_fn(fc1_out, grad, *args.activation_params)
        return act_bwd_fn(fc1_out, grad)

    @staticmethod
    def _reshape_output(out, inp_shape):
        """Give a 2-D ``[tokens, features]`` result the leading dims of the module input.

        Supports inputs of any rank >= 2; the previous form
        ``reshape(inp.shape[0], inp.shape[1], -1)`` only worked for 3-D inputs.
        """
        return out.reshape(*inp_shape[:-1], out.shape[-1])

    @staticmethod
    def _launch_fc1_dgrad_reduction(fc1_dgrad, args, tp_world_size, already_reduce_scattered):
        """Start the asynchronous tensor-parallel reduction of FC1's input gradient.

        Args:
            fc1_dgrad: Partial input gradient of the column-parallel FC1.
            args: The ``_LayerNormMLPNonTensorArgs`` of this call.
            tp_world_size: Size of ``args.tp_group``.
            already_reduce_scattered: True when a fused GEMM already performed
                the sequence-parallel reduce-scatter.

        Returns:
            ``(fc1_dgrad, handle)``. ``handle`` is None when no communication
            was issued. Otherwise the caller MUST ``handle.wait()`` before
            reading ``fc1_dgrad``.
        """
        if not (args.tensor_parallel and args.set_parallel_mode and tp_world_size > 1):
            return fc1_dgrad, None
        if args.sequence_parallel:
            if already_reduce_scattered:
                return fc1_dgrad, None
            fc1_dgrad, handle = reduce_scatter_along_dim(fc1_dgrad, args.tp_group, async_op=True)
            return fc1_dgrad, handle
        handle = torch.distributed.all_reduce(fc1_dgrad, group=args.tp_group, async_op=True)
        return fc1_dgrad, handle

    @staticmethod
    def _requantize_weight(weight, quantizer):
        """Return ``weight`` quantized with both rowwise and columnwise usage.

        Used by FSDP2 backward when no quantized weight was saved in forward.
        Already-quantized weights are returned unchanged.
        """
        if isinstance(weight, QuantizedTensorStorage):
            return weight
        quantizer.set_usage(rowwise=True, columnwise=True)
        return quantizer(weight)

    @staticmethod
    def _create_grad_weight_placeholder(weight):
        """Return a stand-in weight gradient after fused accumulation into ``main_grad``.

        Returns None if ``weight`` has no ``grad_added_to_main_grad`` attribute.
        Otherwise it allocates a tensor shaped like ``weight.main_grad``: zeros
        when ``weight.zero_out_wgrad`` is truthy, uninitialized memory otherwise.
        It also sets ``weight.grad_added_to_main_grad = True``.
        """
        if not hasattr(weight, "grad_added_to_main_grad"):
            return None
        alloc = torch.zeros if getattr(weight, "zero_out_wgrad", False) else torch.empty
        grad_weight = alloc(
            weight.main_grad.shape,
            dtype=weight.dtype,
            device=torch.npu.current_device(),
            requires_grad=False,
        )
        weight.grad_added_to_main_grad = True
        return grad_weight

    # ------------------------------------------------------ high-precision path

    @staticmethod
    def forward(
        ctx,
        ln_weight,
        ln_bias,
        fc1_weight,
        fc1_bias,
        fc2_weight,
        fc2_bias,
        inp,
        args: _LayerNormMLPNonTensorArgs,
    ):
        """Normalize ``inp`` and apply FC1 -> activation -> FC2.

        ``ctx`` may be None when the module calls this directly with grad
        disabled; it is only touched when ``args.is_grad_enabled``.

        Returns:
            The MLP output with ``inp``'s leading dims, plus the normalized
            input when ``args.return_layernorm_output``.
        """
        inputmat = inp.reshape(-1, inp.shape[-1])
        inputmat = cast_if_needed(inputmat, args.activation_dtype)
        ln_out, mu, rsigma = _apply_norm(
            inputmat,
            ln_weight,
            ln_bias,
            args.eps,
            args.normalization,
            args.zero_centered_gamma,
            args.activation_dtype,
        )

        if args.return_layernorm_output:
            ln_out_return = ln_out.reshape(inp.shape)
            if args.return_layernorm_output_gathered and args.tensor_parallel and args.sequence_parallel:
                # NOTE(review): this gather also runs on the FP8 path, whose
                # fp8_forward returns the local ln_out instead. Both backward
                # paths also add the returned-LN gradient after the fc1 dgrad
                # reduce-scatter, which assumes a local (not gathered) shape.
                # Confirm the intended semantics of the gathered option.
                ln_out_return, _ = gather_along_dim(ln_out_return, args.tp_group)

        # Sequence parallelism: FC1 needs the full sequence. Only the FP8 path
        # can fuse this all-gather into its GEMM (overlap_ag_fprop); the
        # high-precision path must always gather here or it would silently
        # compute on the local shard only.
        mm_inp = ln_out
        needs_input_gather = args.tensor_parallel and args.sequence_parallel and args.set_parallel_mode
        if needs_input_gather and not (args.fp8 and args.overlap_ag_fprop):
            mm_inp, _ = gather_along_dim(mm_inp, args.tp_group)

        if args.fp8:
            return _LayerNormMLP.fp8_forward(
                ctx, ln_weight, ln_bias, fc1_weight, fc1_bias, fc2_weight, fc2_bias,
                inp, inputmat, ln_out, mm_inp, mu, rsigma, args,
            )

        fc1_out = torch.matmul(mm_inp, fc1_weight.t())
        if fc1_bias is not None:
            fc1_out = fc1_out + fc1_bias
        act_out = _LayerNormMLP._apply_activation(fc1_out, args)
        fc2_out = torch.matmul(act_out, fc2_weight.t())
        if fc2_bias is not None:
            # NOTE(review): the bias is added before the TP reduction below, so
            # with tp_size > 1 it appears to be summed tp_size times. Confirm.
            fc2_out = fc2_out + fc2_bias

        # FC2 is row-parallel: reduce partial sums across the TP group. There is
        # no fused reduce-scatter GEMM here, so overlap_rs_fprop cannot skip it.
        tp_world_size = get_distributed_world_size(args.tp_group)
        if args.tensor_parallel and args.set_parallel_mode and tp_world_size > 1:
            if args.sequence_parallel:
                fc2_out, _ = reduce_scatter_along_dim(fc2_out, args.tp_group)
            else:
                torch.distributed.all_reduce(fc2_out, group=args.tp_group)

        if args.is_grad_enabled:
            ctx.args = args
            ctx.inp_shape = inp.shape
            ctx.fp8 = args.fp8
            ctx.sequence_parallel = args.sequence_parallel
            ctx.requires_dgrad = inp.requires_grad
            ctx.requires_wgrad = fc1_weight.requires_grad or fc2_weight.requires_grad
            ctx.fsdp_group = args.fsdp_group
            ctx.is_fsdp2 = args.is_fsdp2
            ctx.fsdp_shapes = _fsdp_scatter_tensors(
                args.fsdp_group,
                mu,
                rsigma,
                mm_inp,
                fc1_out,
                act_out,
            )
            tensors_to_save, tensor_objects = prepare_for_saving(
                inputmat,
                fc1_weight,
                fc1_bias,
                fc2_weight,
                fc2_bias,
                ln_weight,
                ln_bias,
                mm_inp,
                fc1_out,
                act_out,
                mu,
                rsigma,
            )
            ctx.save_for_backward(*tensors_to_save)
            ctx.tensor_objects = tensor_objects

        out = _LayerNormMLP._reshape_output(fc2_out, inp.shape)
        if args.return_layernorm_output:
            return out, ln_out_return
        return out

    @staticmethod
    def backward(ctx, grad_output, *rest):
        """Gradients for (ln_weight, ln_bias, fc1_weight, fc1_bias, fc2_weight, fc2_bias, inp, args).

        ``rest[0]`` (if present) is the gradient of the returned normalized input.
        """
        if ctx.args.fp8:
            return _LayerNormMLP.fp8_backward(ctx, grad_output, rest)
        args = ctx.args
        (
            inputmat, fc1_weight, fc1_bias, fc2_weight, fc2_bias, ln_weight, ln_bias,
            mm_inp, fc1_out, act_out, mu, rsigma,
        ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
        ctx.tensor_objects = None

        fsdp_group = getattr(ctx, "fsdp_group", None)
        if fsdp_group is not None:
            _fsdp_gather_tensors(
                fsdp_group,
                ctx.fsdp_shapes,
                mu,
                rsigma,
                mm_inp,
                fc1_out,
                act_out,
            )

        grad_output = grad_output.reshape(-1, grad_output.shape[-1])
        grad_ln_out = None
        if args.return_layernorm_output and len(rest) > 0 and rest[0] is not None:
            grad_ln_out = rest[0].reshape(-1, rest[0].shape[-1])
        tp_world_size = get_distributed_world_size(args.tp_group)
        delay_wgrad = args.wgrad_store is not None and args.wgrad_store.delay_wgrad_compute()

        # No fused all-gather GEMM on this path: gather unconditionally under SP.
        if args.tensor_parallel and args.set_parallel_mode and args.sequence_parallel:
            grad_output, _ = gather_along_dim(grad_output, args.tp_group)
        d_fc2 = grad_output
        fc2_dgrad = torch.matmul(d_fc2, fc2_weight)

        # ---- FC2 weight / bias gradients
        if delay_wgrad:
            fc2_bias_grad_delayed = d_fc2.sum(dim=0) if fc2_bias is not None else None

            def fc2_wgrad_fn(act, dy):
                if args.fuse_wgrad_accumulation and fc2_weight.main_grad.dtype == torch.float32:
                    matmul_add(fc2_weight.main_grad, act, dy)
                    return _LayerNormMLP._create_grad_weight_placeholder(fc2_weight), fc2_bias_grad_delayed
                return torch.matmul(dy.t(), act), fc2_bias_grad_delayed

            args.wgrad_store.put([act_out, d_fc2], fc2_wgrad_fn)
            fc2_wgrad = None
        elif args.fuse_wgrad_accumulation and fc2_weight.main_grad.dtype == torch.float32:
            matmul_add(fc2_weight.main_grad, act_out, d_fc2)
            fc2_wgrad = _LayerNormMLP._create_grad_weight_placeholder(fc2_weight)
        else:
            fc2_wgrad = torch.matmul(d_fc2.t(), act_out)
        # NOTE(review): in delayed mode the bias gradient is returned here AND
        # produced by the deferred closure; confirm WeightGradStore does not
        # apply it a second time. The same applies to the FC1 bias below.
        fc2_bias_grad = d_fc2.sum(dim=0) if fc2_bias is not None else None

        # ---- FC1 input gradient (reduction overlaps with the FC1 wgrad below)
        dact = _LayerNormMLP._activation_grad(fc1_out, fc2_dgrad, args)
        fc1_dgrad = torch.matmul(dact, fc1_weight)
        fc1_dgrad, handle = _LayerNormMLP._launch_fc1_dgrad_reduction(
            fc1_dgrad, args, tp_world_size, already_reduce_scattered=False
        )

        # ---- FC1 weight / bias gradients
        if delay_wgrad:
            fc1_bias_grad_delayed = dact.sum(dim=0) if fc1_bias is not None else None

            def fc1_wgrad_fn(x, dy):
                if args.fuse_wgrad_accumulation and fc1_weight.main_grad.dtype == torch.float32:
                    matmul_add(fc1_weight.main_grad, x, dy)
                    return _LayerNormMLP._create_grad_weight_placeholder(fc1_weight), fc1_bias_grad_delayed
                return torch.matmul(dy.t(), x), fc1_bias_grad_delayed

            args.wgrad_store.put([mm_inp, dact], fc1_wgrad_fn)
            fc1_wgrad = None
        elif args.fuse_wgrad_accumulation and fc1_weight.main_grad.dtype == torch.float32:
            matmul_add(fc1_weight.main_grad, mm_inp, dact)
            fc1_wgrad = _LayerNormMLP._create_grad_weight_placeholder(fc1_weight)
        else:
            fc1_wgrad = torch.matmul(dact.t(), mm_inp)
        fc1_bias_grad = dact.sum(dim=0) if fc1_bias is not None else None

        # The async reduce-scatter / all-reduce writes fc1_dgrad; it must
        # complete before fc1_dgrad is read below.
        if handle is not None:
            handle.wait()
        if grad_ln_out is not None:
            fc1_dgrad = fc1_dgrad + grad_ln_out
        dx, dgamma, dbeta = _norm_bwd(
            fc1_dgrad, inputmat, ln_weight, mu, rsigma, args.zero_centered_gamma, args.normalization
        )
        return dgamma, dbeta, fc1_wgrad, fc1_bias_grad, fc2_wgrad, fc2_bias_grad, dx.reshape(ctx.inp_shape), None

    # ---------------------------------------------------------------- FP8 path

    @staticmethod
    def fp8_forward(ctx, ln_weight, ln_bias, fc1_weight, fc1_bias, fc2_weight, fc2_bias, inp, inputmat, ln_out, mm_inp, mu, rsigma, args):
        """FP8 counterpart of the FC1 -> activation -> FC2 part of ``forward``.

        ``mm_inp`` is the normalized FC1 input. It is already all-gathered unless
        ``args.overlap_ag_fprop``, in which case the gather is fused into the
        FC1 GEMM. Returns the same structure as ``forward``.
        """
        tp_world_size = get_distributed_world_size(args.tp_group)

        # Decide which quantized layouts backward will need.
        backward_needs_fc1_input = args.is_grad_enabled and fc1_weight.requires_grad
        # NOTE(review): autograd.Function.forward runs with grad mode off, so
        # ln_out.requires_grad is expected to be False here. Confirm this is the
        # intended condition for keeping the FC1 weight's columnwise data.
        backward_needs_fc1_weight = args.is_grad_enabled and ln_out.requires_grad
        backward_needs_fc2_input = args.is_grad_enabled and fc2_weight.requires_grad
        backward_needs_fc2_weight = args.is_grad_enabled
        if is_fp8_activation_recompute_enabled() and not in_fp8_activation_recompute_phase():
            # Activation recompute is on and this is not the recompute pass:
            # keep rowwise data only (presumably the recompute pass regenerates
            # what backward needs).
            for quantizer in (
                args.fc1_input_quantizer,
                args.fc1_weight_quantizer,
                args.fc2_input_quantizer,
                args.fc2_weight_quantizer,
            ):
                quantizer.set_usage(columnwise=False)
        else:
            if not backward_needs_fc1_input:
                args.fc1_input_quantizer.set_usage(columnwise=False)
            if not backward_needs_fc1_weight:
                args.fc1_weight_quantizer.set_usage(columnwise=False)
            if not backward_needs_fc2_input:
                args.fc2_input_quantizer.set_usage(columnwise=False)
            if not backward_needs_fc2_weight:
                args.fc2_weight_quantizer.set_usage(columnwise=False)
        if args.is_fsdp2:
            # FSDP2 backward re-quantizes weights on demand (see fp8_backward).
            args.fc1_weight_quantizer.set_usage(columnwise=False)
            args.fc2_weight_quantizer.set_usage(columnwise=False)

        update_workspace = args.is_first_microbatch is None or args.is_first_microbatch
        weight_cache_disabled = args.is_first_microbatch is None or args.is_fsdp2

        # ---- FC1
        fc1_input = args.fc1_input_quantizer.quantize(mm_inp)
        fc1_weightmat = args.module.get_weight_workspace(
            tensor=fc1_weight,
            quantizer=args.fc1_weight_quantizer,
            cache_name=None if weight_cache_disabled else "fc1_weight",
            update_workspace=update_workspace,
            skip_update_flag=args.skip_fp8_weight_update,
            workspace_dtype=args.activation_dtype,
        )
        fc1_weightmat.update_usage(rowwise_usage=True)
        fc1_mm_kwargs = {
            "usage": TensorUsage.LHS,
            "usage_b": TensorUsage.RHS_TRANS,
            "out_dtype": args.activation_dtype,
        }
        if args.overlap_ag_fprop:
            # Fused all-gather + GEMM. The bias goes to the fused op, and the
            # returned fc1_input (presumably the gathered input) is what
            # backward saves.
            fc1_out, fc1_input = fc1_input.allgather_matmul(
                fc1_weightmat, fc1_bias, tp_world_size, args.tp_group, **fc1_mm_kwargs
            )
        else:
            fc1_out = fc1_input.matmul(fc1_weightmat, **fc1_mm_kwargs)
            if fc1_bias is not None:
                fc1_out = fc1_out + fc1_bias
        fc1_input.clear_wise(rowwise=True)

        act_out = _LayerNormMLP._apply_activation(fc1_out, args)

        # ---- FC2
        fc2_input = args.fc2_input_quantizer.quantize(act_out)
        fc2_weightmat = args.module.get_weight_workspace(
            tensor=fc2_weight,
            quantizer=args.fc2_weight_quantizer,
            cache_name=None if weight_cache_disabled else "fc2_weight",
            update_workspace=update_workspace,
            skip_update_flag=args.skip_fp8_weight_update,
            workspace_dtype=args.activation_dtype,
        )
        fc2_weightmat.update_usage(rowwise_usage=True)
        fc2_mm_kwargs = {
            "usage": TensorUsage.LHS,
            "usage_b": TensorUsage.RHS_TRANS,
            "out_dtype": args.activation_dtype,
        }
        if args.overlap_rs_fprop:
            fc2_out = fc2_input.matmul_reduce_scatter(
                fc2_weightmat, fc2_bias, tp_world_size, args.tp_group, **fc2_mm_kwargs
            )
        else:
            fc2_out = fc2_input.matmul(fc2_weightmat, **fc2_mm_kwargs)
            if fc2_bias is not None:
                # NOTE(review): bias added before the TP reduction below; see forward().
                fc2_out = fc2_out + fc2_bias
        fc2_input.clear_wise(rowwise=True)

        if args.tensor_parallel and args.set_parallel_mode and not args.overlap_rs_fprop and tp_world_size > 1:
            if args.sequence_parallel:
                fc2_out, _ = CommOverlapOps.reduce_scatter(
                    fc2_out,
                    args.fc2_output_quantizer,
                    tp_world_size,
                    args.tp_group,
                    use_quant=args.fp8_output and args.module.fp8_meta["recipe"].mxfp8(),
                )
            else:
                torch.distributed.all_reduce(fc2_out, group=args.tp_group)

        if args.is_grad_enabled:
            ctx.args = args
            ctx.inp_shape = inp.shape
            ctx.fp8 = args.fp8
            ctx.sequence_parallel = args.sequence_parallel
            ctx.requires_dgrad = inp.requires_grad
            ctx.requires_wgrad = fc1_weight.requires_grad or fc2_weight.requires_grad
            ctx.is_weight_param_quantized = isinstance(fc1_weight, QuantizedTensorStorage)
            ctx.fsdp_group = args.fsdp_group
            ctx.is_fsdp2 = args.is_fsdp2
            ctx.fsdp_shapes = _fsdp_scatter_tensors(
                args.fsdp_group,
                mu,
                rsigma,
                fc1_weightmat if not ctx.is_weight_param_quantized else None,
                fc2_weightmat if not ctx.is_weight_param_quantized else None,
                fc1_input,
                fc1_out,
                fc2_input,
                act_out,
            )
            # Under FSDP2, do not keep separately materialized weight workspaces
            # alive; backward re-quantizes them when these slots are None.
            fc1_wt_save = None if (args.is_fsdp2 and fc1_weightmat is not fc1_weight) else fc1_weightmat
            fc2_wt_save = None if (args.is_fsdp2 and fc2_weightmat is not fc2_weight) else fc2_weightmat
            tensors_to_save, tensor_objects = prepare_for_saving(
                inputmat,
                fc1_wt_save,
                fc1_weight,
                fc1_bias,
                fc2_wt_save,
                fc2_weight,
                fc2_bias,
                ln_weight,
                ln_bias,
                fc1_input,
                fc1_out,
                fc2_input,
                act_out,
                mu,
                rsigma,
            )
            ctx.save_for_backward(*tensors_to_save)
            ctx.tensor_objects = tensor_objects

            # Only touch ctx when it exists (ctx is None with grad disabled),
            # and always define the flag so fp8_backward can read it.
            ctx.reduce_and_update_bwd_fp8_tensors = False
            if requires_grad(inp, fc1_weight, fc2_weight):
                qstate = FP8GlobalStateManager.quantization_state
                _first_fp8_module = qstate.is_first_fp8_module
                ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
                if in_fp8_activation_recompute_phase():
                    # Restore the flag read above so the recompute pass leaves it unchanged.
                    qstate.is_first_fp8_module = _first_fp8_module

        out = _LayerNormMLP._reshape_output(fc2_out, inp.shape)
        if args.return_layernorm_output:
            # NOTE(review): return_layernorm_output_gathered is not honored on
            # this path; the local ln_out is returned.
            return out, ln_out.reshape(inp.shape)
        return out

    @staticmethod
    def fp8_backward(ctx, grad_output, rest):
        """FP8 counterpart of ``backward``; returns the same gradient tuple."""
        args = ctx.args
        (
            inputmat, fc1_weight_fp8, fc1_weight, fc1_bias, fc2_weight_fp8, fc2_weight, fc2_bias,
            ln_weight, ln_bias, fc1_input, fc1_out, fc2_input, act_out, mu, rsigma,
        ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
        ctx.tensor_objects = None

        is_fsdp2 = getattr(ctx, "is_fsdp2", False)
        fsdp_group = getattr(ctx, "fsdp_group", None)
        is_weight_param_quantized = getattr(ctx, "is_weight_param_quantized", False)
        if fsdp_group is not None:
            _fsdp_gather_tensors(
                fsdp_group,
                ctx.fsdp_shapes,
                mu,
                rsigma,
                fc1_weight_fp8 if not is_weight_param_quantized else None,
                fc2_weight_fp8 if not is_weight_param_quantized else None,
                fc1_input,
                fc1_out,
                fc2_input,
                act_out,
            )

        grad_output = grad_output.reshape(-1, grad_output.shape[-1])
        grad_ln_out = None
        if args.return_layernorm_output and len(rest) > 0 and rest[0] is not None:
            grad_ln_out = rest[0].reshape(-1, rest[0].shape[-1])
        tp_world_size = get_distributed_world_size(args.tp_group)
        delay_wgrad = args.wgrad_store is not None and args.wgrad_store.delay_wgrad_compute()

        # Weight gradients are accumulated into main_grad only with the MXFP8
        # recipe, and never on the first microbatch when that flag is given.
        accumulate_wgrad = args.fuse_wgrad_accumulation and args.module.fp8_meta["recipe"].mxfp8()
        if args.is_first_microbatch is not None:
            accumulate_wgrad = accumulate_wgrad and not args.is_first_microbatch

        # ---- FC2 grad output
        if not ctx.requires_dgrad:
            args.fc2_grad_output_quantizer.set_usage(rowwise=False)
        if not ctx.requires_wgrad:
            args.fc2_grad_output_quantizer.set_usage(columnwise=False)
        if args.tensor_parallel and args.set_parallel_mode and args.sequence_parallel and not args.overlap_ag_dgrad:
            grad_output, _ = gather_along_dim(grad_output, args.tp_group)
        fc2_grad_output = args.fc2_grad_output_quantizer.quantize(grad_output)
        if fc2_weight_fp8 is None and is_fsdp2:
            fc2_weight_fp8 = _LayerNormMLP._requantize_weight(fc2_weight, args.fc2_weight_quantizer)

        # ---- FC2 input gradient
        fc2_dgrad_kwargs = {
            "usage": TensorUsage.LHS,
            "usage_b": TensorUsage.RHS,
            "out_dtype": args.activation_dtype,
        }
        if args.overlap_ag_dgrad:
            fc2_dgrad, fc2_grad_output = fc2_grad_output.allgather_matmul(
                fc2_weight_fp8, None, tp_world_size, args.tp_group, **fc2_dgrad_kwargs
            )
        else:
            # FC2 is row-parallel: each rank owns a slice of the FFN dim, so its
            # input gradient needs no TP reduction. (It previously went through
            # matmul_reduce_scatter when overlap_rs_dgrad was set. That summed
            # per-rank slices and scattered the token dim, so the result did
            # not match fc1_out.)
            fc2_dgrad = fc2_grad_output.matmul(fc2_weight_fp8, **fc2_dgrad_kwargs)
        if is_fsdp2 and isinstance(fc2_weight_fp8, QuantizedTensorStorage):
            clear_columnwise_cache(fc2_weight_fp8)

        # ---- FC2 weight / bias gradients
        fc2_wgrad_kwargs = {
            "usage": TensorUsage.LHS_TRANS,
            "usage_b": TensorUsage.RHS,
            "out_dtype": fc2_weight.main_grad.dtype if accumulate_wgrad else args.activation_dtype,
        }
        if delay_wgrad:
            fc2_bias_grad_delayed = grad_output.sum(dim=0) if fc2_bias is not None else None

            def fc2_wgrad_fn(x, dy):
                if accumulate_wgrad:
                    dy.matmul_add(fc2_weight.main_grad, x, **fc2_wgrad_kwargs)
                    return _LayerNormMLP._create_grad_weight_placeholder(fc2_weight), fc2_bias_grad_delayed
                return dy.matmul(x, **fc2_wgrad_kwargs), fc2_bias_grad_delayed

            args.wgrad_store.put([fc2_input, fc2_grad_output], fc2_wgrad_fn)
            fc2_wgrad = None
        elif accumulate_wgrad:
            fc2_grad_output.matmul_add(fc2_weight.main_grad, fc2_input, **fc2_wgrad_kwargs)
            fc2_wgrad = _LayerNormMLP._create_grad_weight_placeholder(fc2_weight)
        else:
            fc2_wgrad = fc2_grad_output.matmul(fc2_input, **fc2_wgrad_kwargs)
        # NOTE(review): with overlap_ag_dgrad, grad_output here is the
        # un-gathered local shard. Confirm the bias-grad reduction convention.
        fc2_bias_grad = grad_output.sum(dim=0) if fc2_bias is not None else None

        # ---- FC1 grad output and input gradient
        dact = _LayerNormMLP._activation_grad(fc1_out, fc2_dgrad, args)
        if not ctx.requires_dgrad:
            args.fc1_grad_output_quantizer.set_usage(rowwise=False)
        if not ctx.requires_wgrad:
            args.fc1_grad_output_quantizer.set_usage(columnwise=False)
        fc1_grad_output = args.fc1_grad_output_quantizer.quantize(dact)
        if fc1_weight_fp8 is None and is_fsdp2:
            fc1_weight_fp8 = _LayerNormMLP._requantize_weight(fc1_weight, args.fc1_weight_quantizer)
        fc1_dgrad_kwargs = {
            "usage": TensorUsage.LHS,
            "usage_b": TensorUsage.RHS,
            "out_dtype": args.activation_dtype,
        }
        if args.overlap_rs_dgrad:
            fc1_dgrad = fc1_grad_output.matmul_reduce_scatter(
                fc1_weight_fp8, None, tp_world_size, args.tp_group, **fc1_dgrad_kwargs
            )
        else:
            fc1_dgrad = fc1_grad_output.matmul(fc1_weight_fp8, **fc1_dgrad_kwargs)
        if is_fsdp2 and isinstance(fc1_weight_fp8, QuantizedTensorStorage):
            clear_columnwise_cache(fc1_weight_fp8)
        fc1_dgrad, handle = _LayerNormMLP._launch_fc1_dgrad_reduction(
            fc1_dgrad, args, tp_world_size, already_reduce_scattered=args.overlap_rs_dgrad
        )

        # ---- FC1 weight / bias gradients (overlap with the reduction above)
        fc1_wgrad_kwargs = {
            "usage": TensorUsage.LHS_TRANS,
            "usage_b": TensorUsage.RHS,
            "out_dtype": fc1_weight.main_grad.dtype if accumulate_wgrad else args.activation_dtype,
        }
        if delay_wgrad:
            fc1_bias_grad_delayed = dact.sum(dim=0) if fc1_bias is not None else None

            def fc1_wgrad_fn(x, dy):
                if accumulate_wgrad:
                    dy.matmul_add(fc1_weight.main_grad, x, **fc1_wgrad_kwargs)
                    return _LayerNormMLP._create_grad_weight_placeholder(fc1_weight), fc1_bias_grad_delayed
                return dy.matmul(x, **fc1_wgrad_kwargs), fc1_bias_grad_delayed

            args.wgrad_store.put([fc1_input, fc1_grad_output], fc1_wgrad_fn)
            fc1_wgrad = None
        elif accumulate_wgrad:
            fc1_grad_output.matmul_add(fc1_weight.main_grad, fc1_input, **fc1_wgrad_kwargs)
            fc1_wgrad = _LayerNormMLP._create_grad_weight_placeholder(fc1_weight)
        else:
            fc1_wgrad = fc1_grad_output.matmul(fc1_input, **fc1_wgrad_kwargs)
        fc1_bias_grad = dact.sum(dim=0) if fc1_bias is not None else None

        # The async reduction writes fc1_dgrad; wait before reading it.
        if handle is not None:
            handle.wait()
        if grad_ln_out is not None:
            fc1_dgrad = fc1_dgrad + grad_ln_out
        dx, dgamma, dbeta = _norm_bwd(
            fc1_dgrad, inputmat, ln_weight, mu, rsigma, args.zero_centered_gamma, args.normalization
        )
        if getattr(ctx, "reduce_and_update_bwd_fp8_tensors", False):
            FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
        return dgamma, dbeta, fc1_wgrad, fc1_bias_grad, fc2_wgrad, fc2_bias_grad, dx.reshape(ctx.inp_shape), None
class LayerNormMLP(TransformerEngineBaseModule):
def __init__(
    self,
    hidden_size: int,
    ffn_hidden_size: int,
    eps: float = 1e-5,
    sequence_parallel: bool = False,
    return_layernorm_output: bool = False,
    return_layernorm_output_gathered: bool = True,
    zero_centered_gamma: bool = False,
    normalization: str = "LayerNorm",
    activation: str = "gelu",
    activation_params: Optional[tuple] = None,
    init_method: Optional[Callable] = None,
    output_layer_init_method: Optional[Callable] = None,
    fuse_wgrad_accumulation: bool = False,
    tp_group: Optional[dist_group_type] = None,
    tp_size: int = 1,
    get_rng_state_tracker: Optional[Callable] = None,
    rng_tracker_name: Optional[str] = None,
    bias: bool = True,
    params_dtype: Optional[torch.dtype] = None,
    return_bias: bool = False,
    device: Union[torch.device, str] = "npu",
    ub_overlap_ag: bool = False,
    ub_overlap_rs: bool = False,
    ub_overlap_rs_dgrad: bool = False,
    ub_bulk_dgrad: bool = False,
    ub_bulk_wgrad: bool = False,
    ub_name: Optional[str] = None,
    delay_wgrad_compute: bool = False,
    symmetric_ar_type: Optional[str] = None,
    name: Optional[str] = None,
    normalization_eps: Optional[float] = None,
):
    """Create the LayerNorm/RMSNorm + FC1 + activation + FC2 module.

    Selected arguments:
        hidden_size: Model width (input/output features).
        ffn_hidden_size: Total FFN width. Each TP rank holds
            ``ffn_hidden_size / tp_size`` of it (doubled for GLU activations in FC1).
        eps / normalization_eps: Normalization epsilon; ``normalization_eps``
            wins when given.
        zero_centered_gamma: Store ``gamma - 1`` in ``ln_weight`` (init 0).
        return_bias: Return the FC2 bias instead of adding it.
        device: Parameter device; ``"meta"`` defers initialization.
        ub_overlap_ag / ub_overlap_rs / ub_overlap_rs_dgrad: Request fused
            communication/GEMM overlap. Only effective with TP + sequence
            parallelism.
        delay_wgrad_compute: Defer weight-gradient GEMMs through ``WeightGradStore``.
    """
    super(LayerNormMLP, self).__init__(name)
    self.hidden_size = hidden_size
    self.ffn_hidden_size = ffn_hidden_size
    self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
    self.use_bias = bias
    self.return_bias = return_bias
    # When the FC2 bias is handed back to the caller it is not added in-module.
    self.apply_bias = bias and not return_bias
    self.get_rng_state_tracker = get_rng_state_tracker
    self.rng_tracker_name = rng_tracker_name
    self.return_layernorm_output = return_layernorm_output
    self.return_layernorm_output_gathered = return_layernorm_output_gathered
    self.zero_centered_gamma = zero_centered_gamma
    self.normalization = normalization
    self.activation = activation
    self.activation_params = activation_params or ()
    self.symmetric_ar_type = symmetric_ar_type
    self.eps = eps if normalization_eps is None else normalization_eps
    self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)

    # ---- Tensor-parallel setup
    if tp_group is None:
        self.tp_size = tp_size
        if tp_size == 1:
            self.set_tensor_parallel_group(tp_group)
    else:
        self.tp_size = get_distributed_world_size(tp_group)
        self.set_tensor_parallel_group(tp_group)
    self.set_nccl_overlap_warning_if_tp()
    self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
    self.tensor_parallel = self.tp_size > 1
    self.set_parallel_mode = self.tensor_parallel
    size_per_partition = (
        divide(self.ffn_hidden_size, self.tp_size) if self.tensor_parallel else self.ffn_hidden_size
    )
    # GLU variants need two FC1 output halves per FFN unit.
    fc1_out_features = 2 * size_per_partition if activation in GLU_VARIANTS else size_per_partition

    # Communication/GEMM overlap is only meaningful with TP + sequence parallelism.
    overlap_enabled = self.set_parallel_mode and self.sequence_parallel
    self.overlap_ag_fprop = overlap_enabled and ub_overlap_ag
    self.overlap_rs_dgrad = overlap_enabled and ub_overlap_rs_dgrad
    self.overlap_rs_fprop = overlap_enabled and ub_overlap_rs
    self.overlap_ag_dgrad = overlap_enabled and ub_overlap_ag
    with_fp8_params = FP8GlobalStateManager.with_fp8_parameters()

    # ---- Normalization parameters
    # With zero-centered gamma the parameter stores (gamma - 1), so its identity
    # value is 0. The registered init_fn must agree: it is presumably re-applied
    # by the base-class reset_parameters() called below. With a constant 1.0 it
    # would reset the weight to 1.0, an effective gamma of 2.
    ln_weight_init = 0.0 if zero_centered_gamma else 1.0
    ln_weight_tensor = torch.full((hidden_size,), ln_weight_init, device=device, dtype=params_dtype)
    self.register_parameter(
        "ln_weight",
        torch.nn.Parameter(ln_weight_tensor),
        init_fn=init_method_constant(ln_weight_init),
    )
    if normalization != "RMSNorm":
        ln_bias_tensor = torch.zeros(hidden_size, device=device, dtype=params_dtype)
        self.register_parameter(
            "layer_norm_bias",
            torch.nn.Parameter(ln_bias_tensor),
            init_fn=init_method_constant(0.0),
        )
    else:
        self.layer_norm_bias = None

    # ---- FC1: weight (fc1_out_features, hidden_size), partitioned along dim 0
    fc1_weight_tensor = torch.empty(fc1_out_features, hidden_size, device=device, dtype=params_dtype)
    self.register_parameter(
        "fc1_weight",
        torch.nn.Parameter(fc1_weight_tensor),
        init_fn=init_method or get_default_init_method(),
        get_rng_state_tracker=get_rng_state_tracker,
        fp8_meta_index=FP8FwdTensorIdx.GEMM1_WEIGHT,
    )
    if self.use_bias:
        fc1_bias_tensor = torch.empty(fc1_out_features, device=device, dtype=params_dtype)
        self.register_parameter(
            "fc1_bias",
            torch.nn.Parameter(fc1_bias_tensor),
            init_fn=init_method_constant(0.0),
        )
    else:
        self.fc1_bias = torch.Tensor().to(dtype=params_dtype, device=device)

    # ---- FC2: weight (hidden_size, size_per_partition), partitioned along dim 1
    fc2_weight_tensor = torch.empty(hidden_size, size_per_partition, device=device, dtype=params_dtype)
    self.register_parameter(
        "fc2_weight",
        torch.nn.Parameter(fc2_weight_tensor),
        init_fn=output_layer_init_method or init_method or get_default_init_method(),
        get_rng_state_tracker=get_rng_state_tracker,
        fp8_meta_index=FP8FwdTensorIdx.GEMM2_WEIGHT,
    )
    if self.use_bias:
        fc2_bias_tensor = torch.empty(hidden_size, device=device, dtype=params_dtype)
        self.register_parameter(
            "fc2_bias",
            torch.nn.Parameter(fc2_bias_tensor),
            init_fn=init_method_constant(0.0),
        )
    else:
        self.fc2_bias = torch.Tensor().to(dtype=params_dtype, device=device)

    if with_fp8_params:
        self.init_fp8_metadata(num_gemms=2)
    # Compare device types so both "meta" and torch.device("meta") defer init
    # (a torch.device never compares equal to the string "meta").
    self.reset_parameters(defer_init=torch.device(device).type == "meta")

    if self.wgrad_store.delay_wgrad_compute():
        # Weight/bias grads of the linear layers are produced later by WeightGradStore.
        for param_name, param in self.named_parameters():
            if param_name in ("fc1_weight", "fc1_bias", "fc2_weight", "fc2_bias"):
                param.skip_backward_post_hook = True
@no_torch_dynamo()
def forward(
    self,
    inp: torch.Tensor,
    is_first_microbatch: Optional[bool] = None,
    fp8_output: Optional[bool] = False,
    fp8_grad: Optional[bool] = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
    """Apply normalization followed by the two-layer MLP to ``inp``.

    Args:
        inp: Input whose last dim is ``hidden_size``.
        is_first_microbatch: Controls FP8 weight-workspace caching/updates.
        fp8_output: Request FP8 output quantizers.
        fp8_grad: Forwarded to ``_get_quantizers``.

    Returns:
        ``out``, or a tuple ``(out[, fc2_bias][, ln_out])``. The FC2 bias is
        included when ``return_bias`` is set and has not been added to ``out``.
        The normalized input is included when ``return_layernorm_output`` is set.
    """
    is_grad_enabled = torch.is_grad_enabled()
    inp = self.prepare_forward(inp, num_gemms=2)
    skip_fp8_weight_update = None
    if is_grad_enabled:
        ln_mlp_fn = _LayerNormMLP.apply
        autograd_ctx = []
    else:
        # No graph is needed: call forward directly with ctx=None.
        ln_mlp_fn = _LayerNormMLP.forward
        autograd_ctx = [None]

    (
        fc1_input_quantizer,
        fc1_weight_quantizer,
        fc1_output_quantizer,
        fc1_grad_input_quantizer,
        fc1_grad_weight_quantizer,
        fc1_grad_output_quantizer,
        fc2_input_quantizer,
        fc2_weight_quantizer,
        fc2_output_quantizer,
        fc2_grad_input_quantizer,
        fc2_grad_weight_quantizer,
        fc2_grad_output_quantizer,
    ) = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)

    non_tensor_args = _LayerNormMLPNonTensorArgs(
        is_first_microbatch=is_first_microbatch,
        fp8=self.fp8,
        wgrad_store=self.wgrad_store,
        eps=self.eps,
        fc1_input_quantizer=fc1_input_quantizer,
        fc1_weight_quantizer=fc1_weight_quantizer,
        fc1_output_quantizer=fc1_output_quantizer,
        fc1_grad_input_quantizer=fc1_grad_input_quantizer,
        fc1_grad_weight_quantizer=fc1_grad_weight_quantizer,
        fc1_grad_output_quantizer=fc1_grad_output_quantizer,
        fc2_input_quantizer=fc2_input_quantizer,
        fc2_weight_quantizer=fc2_weight_quantizer,
        fc2_output_quantizer=fc2_output_quantizer,
        fc2_grad_input_quantizer=fc2_grad_input_quantizer,
        fc2_grad_weight_quantizer=fc2_grad_weight_quantizer,
        fc2_grad_output_quantizer=fc2_grad_output_quantizer,
        fuse_wgrad_accumulation=self.fuse_wgrad_accumulation,
        tp_group=self.tp_group,
        tp_size=self.tp_size,
        sequence_parallel=self.sequence_parallel,
        activation_dtype=self.activation_dtype,
        tensor_parallel=self.tensor_parallel,
        set_parallel_mode=self.set_parallel_mode,
        is_grad_enabled=is_grad_enabled,
        return_layernorm_output=self.return_layernorm_output,
        return_layernorm_output_gathered=self.return_layernorm_output_gathered,
        zero_centered_gamma=self.zero_centered_gamma,
        normalization=self.normalization,
        activation=self.activation,
        activation_params=self.activation_params,
        overlap_ag_fprop=self.overlap_ag_fprop,
        overlap_rs_fprop=self.overlap_rs_fprop,
        overlap_ag_dgrad=self.overlap_ag_dgrad,
        overlap_rs_dgrad=self.overlap_rs_dgrad,
        module=self,
        skip_fp8_weight_update=skip_fp8_weight_update,
        fp8_output=fp8_output,
        fsdp_group=self.fsdp_group,
        is_fsdp2=self.is_fsdp2,
    )

    # FC1's bias is always applied when the module has biases. Only FC2's bias
    # can be handed back to the caller (return_bias).
    fc1_bias = self.fc1_bias if self.use_bias else None
    fc2_bias = self.fc2_bias if self.apply_bias else None
    try:
        out = ln_mlp_fn(
            *autograd_ctx,
            self.ln_weight,
            self.layer_norm_bias,
            self.fc1_weight,
            fc1_bias,
            self.fc2_weight,
            fc2_bias,
            inp,
            non_tensor_args,
        )
    finally:
        self.end_forward()

    if self.return_layernorm_output:
        out, ln_out = out
    if self.return_bias:
        bias_out = cast_if_needed(self.fc2_bias, self.activation_dtype)
        if self.return_layernorm_output:
            return out, bias_out, ln_out
        return out, bias_out
    if self.return_layernorm_output:
        return out, ln_out
    return out
def reset_parameters(self, defer_init=False):
    """Re-initialise parameters, then attach parallelism metadata to them.

    Initialisation itself is delegated to the base class. Unless
    ``defer_init`` is true, this then:

    * copies ``self.sequence_parallel`` onto ``ln_weight``, onto
      ``layer_norm_bias`` (when present) and, if biases are used, onto
      ``fc2_bias``;
    * records tensor-parallel attributes on ``fc1_weight`` (partition dim 0),
      ``fc2_weight`` (partition dim 1) and, if biases are used, ``fc1_bias``
      (partition dim 0), all with stride 1.

    Args:
        defer_init: forwarded to the base-class ``reset_parameters``; when
            true, no attributes are attached here.
    """
    super().reset_parameters(defer_init=defer_init)
    if defer_init:
        return

    seq_parallel = self.sequence_parallel
    self.ln_weight.sequence_parallel = seq_parallel
    if self.layer_norm_bias is not None:
        self.layer_norm_bias.sequence_parallel = seq_parallel

    # (weight, partition_dim) pairs, tagged in this order.
    for weight, split_dim in ((self.fc1_weight, 0), (self.fc2_weight, 1)):
        set_tensor_model_parallel_attributes(
            tensor=weight,
            is_parallel=True,
            dim=split_dim,
            stride=1,
        )

    if not self.use_bias:
        return
    set_tensor_model_parallel_attributes(
        tensor=self.fc1_bias,
        is_parallel=True,
        dim=0,
        stride=1,
    )
    self.fc2_bias.sequence_parallel = seq_parallel
def _get_quantizers(self, fp8_output, fp8_grad, is_grad_enabled):
if not self.fp8:
return [None] * 12
fc1_input_quantizer = self.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_INPUT]
fc1_input_quantizer.internal = True
if not (self.set_parallel_mode and self.sequence_parallel):
fc1_input_quantizer.optimize_for_gemm = True
fc2_input_quantizer = self.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM2_INPUT]
fc2_input_quantizer.internal = True
if not (self.set_parallel_mode and self.sequence_parallel):
fc2_input_quantizer.optimize_for_gemm = True
fc1_weight_quantizer, fc2_weight_quantizer = self._get_weight_quantizers()
fc1_output_quantizer = None
fc2_output_quantizer = None
if fp8_output:
fc1_output_quantizer = self.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_OUTPUT]
fc2_output_quantizer = self.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM2_OUTPUT]
fc1_grad_input_quantizer = None
fc1_grad_weight_quantizer = None
fc1_grad_output_quantizer = None
fc2_grad_input_quantizer = None
fc2_grad_weight_quantizer = None
fc2_grad_output_quantizer = None
if is_grad_enabled:
fc1_grad_output_quantizer = self.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_OUTPUT1]
fc1_grad_output_quantizer.internal = True
if not (self.set_parallel_mode and self.sequence_parallel):
fc1_grad_output_quantizer.optimize_for_gemm = True
fc2_grad_output_quantizer = self.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_OUTPUT2]
fc2_grad_output_quantizer.internal = True
if not (self.set_parallel_mode and self.sequence_parallel):
fc2_grad_output_quantizer.optimize_for_gemm = True
if fp8_grad:
fc1_grad_input_quantizer = self.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_INPUT1]
fc2_grad_input_quantizer = self.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_INPUT2]
return (
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
)
def _get_weight_quantizers(self):
if not self.fp8:
return [None, None]
fc1_weight_quantizer = self.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_WEIGHT]
fc1_weight_quantizer.internal = True
fc2_weight_quantizer = self.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM2_WEIGHT]
fc2_weight_quantizer.internal = True
return [fc1_weight_quantizer, fc2_weight_quantizer]
def _get_weight_tensors(self):
    """Return ``[fc1_weight, fc2_weight]`` in a form usable by the compute path.

    If neither weight is a ``QuantizedTensor``, or FP8 compute is enabled,
    the weights are returned untouched. Otherwise (quantized weights without
    quantized compute) a warning is emitted and every quantized weight is
    replaced by its ``dequantize()`` result.
    """
    weights = [self.fc1_weight, self.fc2_weight]
    from ..quantized_tensor import QuantizedTensor

    has_quantized = any(isinstance(weight, QuantizedTensor) for weight in weights)
    if not has_quantized or self.fp8:
        return weights

    import warnings

    warnings.warn(
        "You are using quantized weights without quantized compute. "
        "Please make sure this is intentional."
    )
    return [
        weight.dequantize() if isinstance(weight, QuantizedTensor) else weight
        for weight in weights
    ]
def need_backward_dw(self):
    """Report whether weight-gradient computation has been deferred.

    Returns ``False`` when there is no ``wgrad_store``; otherwise returns
    whatever ``self.wgrad_store.delay_wgrad_compute()`` returns.
    """
    store = self.wgrad_store
    if store is None:
        return False
    return store.delay_wgrad_compute()
def backward_dw(self):
    """Write deferred weight/bias gradients from ``self.wgrad_store`` into the parameters.

    Returns immediately unless ``need_backward_dw()`` is true. Otherwise it
    pops two entries from ``self.wgrad_store``; the first is used as FC2's
    result and the second as FC1's. It then fills ``.grad`` on the bias and
    weight parameters as described inline.
    """
    if not self.need_backward_dw():
        return
    # Each popped entry unpacks as ((wgrad, bias_grad, ...), tensor_list).
    # NOTE(review): FC2 is popped before FC1. This assumes WeightGradStore.pop()
    # yields FC2's entry first; confirm against the order in which the backward
    # pass submits the two entries.
    # tensor_list_fc2 is not used by the code visible here.
    (fc2_wgrad, fc2_bias_grad, *_), tensor_list_fc2 = self.wgrad_store.pop()
    (fc1_wgrad, fc1_bias_grad, *_), _ = self.wgrad_store.pop()
    if self.use_bias:
        # Only fill a bias .grad that is still unset, and only when a bias
        # gradient was actually produced; cast it to the parameter's dtype.
        if self.fc2_bias.grad is None and fc2_bias_grad is not None:
            self.fc2_bias.grad = fc2_bias_grad.to(self.fc2_bias.dtype)
        if self.fc1_bias.grad is None and fc1_bias_grad is not None:
            self.fc1_bias.grad = fc1_bias_grad.to(self.fc1_bias.dtype)
    # Without fused accumulation, .grad is overwritten (not accumulated) with
    # the deferred wgrad cast to the weight's dtype. With fuse_wgrad_accumulation
    # nothing is written here; presumably the wgrad was already accumulated
    # elsewhere (TODO confirm).
    if not self.fuse_wgrad_accumulation:
        self.fc2_weight.grad = fc2_wgrad.to(self.fc2_weight.dtype)
        self.fc1_weight.grad = fc1_wgrad.to(self.fc1_weight.dtype)
    # Drop the local references to the popped gradient tensors.
    del fc2_wgrad, fc2_bias_grad, fc1_wgrad, fc1_bias_grad