import os
from typing import (
Any,
Callable,
Dict,
NamedTuple,
Optional,
Tuple,
Union,
)
import torch
import torch_npu
from ..constants import (
FP8BwdTensorIdx,
FP8FwdTensorIdx,
ParallelMode,
TensorUsage,
dist_group_type,
)
from ..distributed import (
DummyHandle,
_fsdp_gather_tensors,
_fsdp_scatter_tensors,
gather_along_dim,
get_distributed_world_size,
in_fp8_activation_recompute_phase,
is_fp8_activation_recompute_enabled,
reduce_scatter_along_dim,
)
from ..jit import no_torch_dynamo
from ..ops.fused.overlap import CommOverlapOps
from ..ops.gemm import (
general_gemm,
general_gemm_add,
)
from ..quantization import FP8GlobalStateManager
from ..quantized_tensor import (
QuantizedTensorStorage,
Quantizer,
prepare_for_saving,
restore_from_saved,
)
from ..tensor.utils import clear_columnwise_cache
from ..utils import (
cast_if_needed,
divide,
get_default_init_method,
init_method_constant,
requires_grad,
)
from .base import (
TransformerEngineBaseModule,
setup_dummy_wgrad,
)
class _LayerNormLinearNonTensorArgs(NamedTuple):
    """Non-tensor settings handed to ``_LayerNormLinear`` as a single argument.

    ``LayerNormLinear.forward`` builds one instance per call (always by keyword)
    and ``_LayerNormLinear.forward`` stores it on ``ctx.args`` so that
    ``backward`` reads the very same settings.
    """

    # None or True: the weight workspace is refreshed; None additionally
    # disables the "weight" workspace cache.
    is_first_microbatch: Optional[bool]
    # Selects the quantized path (quantize input, use a weight workspace).
    fp8: bool
    # Epsilon added to the variance / mean-square inside the normalization.
    eps: float
    # The six quantizers are all None when the module is not running in FP8;
    # output/grad_input quantizers are also None unless explicitly requested.
    input_quantizer: Optional[Quantizer]
    weight_quantizer: Optional[Quantizer]
    output_quantizer: Optional[Quantizer]
    # NOTE(review): grad_input_quantizer and grad_weight_quantizer are carried
    # along but not read by the autograd function in this file.
    grad_input_quantizer: Optional[Quantizer]
    grad_weight_quantizer: Optional[Quantizer]
    grad_output_quantizer: Optional[Quantizer]
    # Request accumulation of the weight gradient into ``weight.main_grad``.
    fused_wgrad_accumulation: bool
    # The module always passes False; not read by the autograd function here.
    cpu_offloading: bool
    # Tensor-parallel process group and its size.
    tp_group: torch.distributed.group
    tp_size: int
    sequence_parallel: bool
    # Dtype the input and normalization parameters are cast to, and the GEMM
    # output dtype.
    activation_dtype: torch.dtype
    # The module passes ``tp_size > 1``.
    tensor_parallel: bool
    # Compared against ParallelMode.COLUMN / ParallelMode.ROW; None disables
    # the tensor-parallel communication branches.
    parallel_mode: Optional[str]
    # Snapshot of torch.is_grad_enabled() taken by the module; gates all of the
    # ``ctx`` bookkeeping in forward.
    is_grad_enabled: bool
    fp8_output: bool
    # Owning module; used for get_weight_workspace() and fp8_meta["recipe"].
    module: "LayerNormLinear"
    # Forwarded to get_weight_workspace() as ``skip_update_flag``; the module
    # currently always passes None.
    skip_fp8_weight_update: Optional[bool]
    # Mirrors the module's ``save_original_input``; not read by the autograd
    # function in this file.
    save_origin_input: bool
    # Communication/GEMM overlap switches for forward (fprop) and the
    # input-gradient GEMM (dgrad).
    overlap_ag_fprop: bool
    overlap_rs_dgrad: bool
    overlap_rs_fprop: bool
    overlap_ag_dgrad: bool
    # "LayerNorm" or "RMSNorm".
    normalization: str
    # When True the effective scale is ``1 + ln_weight``.
    zero_centered_gamma: bool
    return_layernorm_output: bool
    return_layernorm_output_gathered: bool
    # When not None, saved activations are scattered in forward and gathered
    # again in backward across this group.
    fsdp_group: Optional[Any]
    is_fsdp2: bool

    @property
    def ub_overlap_ag(self) -> bool:
        """Alias for ``overlap_ag_dgrad``; forward copies it to ``ctx.ub_overlap_ag``.

        NOTE(review): presumably read by ``grad_output_preprocess`` to decide
        whether the output gradient still has to be all-gathered - confirm.
        """
        return self.overlap_ag_dgrad
class _LayerNormLinear(torch.autograd.Function):
    """Autograd function for normalization followed by a linear projection.

    ``forward`` normalizes the input (LayerNorm via ``torch.nn.functional`` or
    RMSNorm via ``torch_npu.npu_rms_norm``), optionally quantizes it, and runs
    the GEMM together with the tensor-parallel communication selected by
    ``non_tensor_args``.  ``backward`` computes the input/weight/bias gradients
    of the GEMM and then back-propagates through the normalization by hand.
    """

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        ln_weight: torch.Tensor,
        ln_bias: Union[torch.Tensor, None],
        weight: torch.Tensor,
        bias: torch.Tensor,
        non_tensor_args: _LayerNormLinearNonTensorArgs,
    ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
        """Run normalization + GEMM and record what ``backward`` needs.

        Args:
            ctx: Autograd context, or ``None`` when the caller invokes this
                method directly with gradients disabled.
            inp: Input whose last dimension equals ``weight.shape[1]``.
            ln_weight: Normalization scale.
            ln_bias: Normalization shift, or ``None``.
            weight: Linear weight of shape ``(out_features, in_features)``.
            bias: Linear bias handed to the GEMM, or ``None``.
            non_tensor_args: All non-tensor settings for this call.

        Returns:
            ``out``, or ``(out, ln_out)`` when ``return_layernorm_output`` is set.

        Raises:
            ValueError: If ``non_tensor_args.normalization`` is not supported.
        """
        args = non_tensor_args
        out_features, in_features = weight.shape
        inp_shape = inp.shape
        inp_requires_grad = inp.requires_grad
        assert inp_shape[-1] == in_features, "GEMM not possible"

        # Flatten to 2D and bring input / normalization params to the
        # activation dtype.
        inp = inp.view((-1, in_features))
        inputmat = inp
        inputmat = cast_if_needed(inputmat, args.activation_dtype)
        ln_weight = cast_if_needed(ln_weight, args.activation_dtype)
        if ln_bias is not None:
            ln_bias = cast_if_needed(ln_bias, args.activation_dtype)

        tp_world_size = get_distributed_world_size(args.tp_group)
        with_input_all_gather = args.parallel_mode == ParallelMode.COLUMN and args.sequence_parallel

        # ---- Normalization -------------------------------------------------
        if args.normalization == "LayerNorm":
            if args.zero_centered_gamma:
                gamma = 1 + ln_weight
            else:
                gamma = ln_weight
            ln_out = torch.nn.functional.layer_norm(
                inputmat,
                [inputmat.shape[-1]],
                weight=gamma,
                bias=ln_bias,
                eps=args.eps,
            )
            # Statistics are saved for backward (which currently recomputes
            # them from the saved input).
            mu = inputmat.mean(dim=-1, keepdim=True)
            var = inputmat.var(dim=-1, unbiased=False, keepdim=True)
            rsigma = torch.rsqrt(var + args.eps)
        elif args.normalization == "RMSNorm":
            if args.zero_centered_gamma:
                gamma = 1 + ln_weight
            else:
                gamma = ln_weight
            ln_out, rrsigma = torch_npu.npu_rms_norm(inputmat, gamma, epsilon=args.eps)
            mu = None
            rsigma = rrsigma
        else:
            raise ValueError(f"Unsupported normalization type: {args.normalization}")

        # Keep a handle on the un-gathered, un-quantized normalization output.
        ln_out_return = None
        if args.return_layernorm_output or args.return_layernorm_output_gathered:
            ln_out_return = ln_out

        # ---- GEMM operands -------------------------------------------------
        if args.fp8:
            backward_needs_input = args.is_grad_enabled and weight.requires_grad
            if is_fp8_activation_recompute_enabled() and not in_fp8_activation_recompute_phase():
                args.input_quantizer.set_usage(columnwise=False)
                args.weight_quantizer.set_usage(columnwise=False)
            else:
                args.input_quantizer.set_usage(
                    rowwise=True,
                    columnwise=backward_needs_input,
                )
            if args.is_fsdp2:
                args.weight_quantizer.set_usage(columnwise=False)
            if with_input_all_gather and not args.overlap_ag_fprop:
                ln_out, _ = gather_along_dim(ln_out, args.tp_group)
            mm_inp = args.input_quantizer.quantize(ln_out)
            update_workspace = args.is_first_microbatch is None or args.is_first_microbatch
            weightmat = args.module.get_weight_workspace(
                tensor=weight,
                quantizer=args.weight_quantizer,
                cache_name=(
                    None if (args.is_first_microbatch is None or args.is_fsdp2) else "weight"
                ),
                update_workspace=update_workspace,
                skip_update_flag=args.skip_fp8_weight_update,
                workspace_dtype=args.activation_dtype,
            )
            weightmat.update_usage(rowwise_usage=True)
        else:
            mm_inp = ln_out
            weightmat = weight
            if with_input_all_gather and not args.overlap_ag_fprop:
                mm_inp, _ = gather_along_dim(mm_inp, args.tp_group)

        # ---- GEMM (optionally overlapped with communication) ---------------
        mm_kwargs = {
            "usage_a": TensorUsage.LHS,
            "usage_b": TensorUsage.RHS_TRANS,
            "out_dtype": args.activation_dtype,
        }
        if args.overlap_ag_fprop:
            out, mm_inp = CommOverlapOps.allgather_matmul(
                mm_inp,
                weightmat,
                bias,
                tp_world_size,
                args.tp_group,
                **mm_kwargs,
            )
        elif args.overlap_rs_fprop:
            out = CommOverlapOps.matmul_reduce_scatter(
                mm_inp,
                weightmat,
                bias,
                tp_world_size,
                args.tp_group,
                **mm_kwargs,
            )
        else:
            out = general_gemm(
                mm_inp,
                weightmat,
                bias=bias,
                **mm_kwargs,
            )

        # Row-parallel output reduction when it was not fused into the GEMM.
        if (
            args.parallel_mode == ParallelMode.ROW
            and not args.overlap_rs_fprop
            and args.tp_size > 1
        ):
            if args.sequence_parallel:
                out, _ = CommOverlapOps.reduce_scatter(
                    out,
                    args.output_quantizer,
                    tp_world_size,
                    args.tp_group,
                    use_quant=args.fp8_output and args.module.fp8_meta["recipe"].mxfp8(),
                )
            else:
                torch.distributed.all_reduce(out, group=args.tp_group)
        out = out.view(-1, *inp_shape[1:-1], out_features)

        # Build the optional second output here, but do NOT return yet: the
        # autograd context below must be populated first, otherwise backward
        # fails on the missing ``ctx.args`` whenever the normalization output
        # is requested.
        ln_out_to_return = None
        if args.return_layernorm_output:
            if args.return_layernorm_output_gathered:
                if with_input_all_gather:
                    ln_out_return = mm_inp
                shape = list(inp_shape)
                shape[0] *= tp_world_size if with_input_all_gather else 1
                ln_out_to_return = ln_out_return.view(shape)
            else:
                ln_out_to_return = ln_out_return.view(inp_shape)

        # ---- Bookkeeping for backward ---------------------------------------
        if args.is_grad_enabled:
            ctx.args = args
            ctx.use_bias = bias is not None
            ctx.inp_shape = inp_shape
            ctx.fp8 = args.fp8
            ctx.debug = False
            ctx.sequence_parallel = args.sequence_parallel
            ctx.ub_overlap_ag = args.ub_overlap_ag
            ctx.requires_dgrad = inp_requires_grad
            ctx.requires_wgrad = weight.requires_grad
            ctx.normalization = args.normalization
            ctx.zero_centered_gamma = args.zero_centered_gamma
            ctx.return_layernorm_output = args.return_layernorm_output
            ctx.return_layernorm_output_gathered = args.return_layernorm_output_gathered
            ctx.reduce_and_update_bwd_fp8_tensors = False
            ctx.fsdp_group = args.fsdp_group
            ctx.is_fsdp2 = args.is_fsdp2

            # The positional order of the scattered tensors must match the
            # gather call in backward.
            if args.fp8:
                ctx.is_weight_param_quantized = isinstance(weight, QuantizedTensorStorage)
                wt_save = weightmat
                if args.is_fsdp2 and weightmat is not weight:
                    # Under FSDP2 the workspace is not saved; backward
                    # re-quantizes the weight instead.
                    wt_save = None
                ctx.fsdp_shapes = _fsdp_scatter_tensors(
                    args.fsdp_group,
                    mu,
                    rsigma,
                    weightmat if not ctx.is_weight_param_quantized else None,
                    mm_inp if weight.requires_grad else None,
                )
                tensors_to_save, tensor_objects = prepare_for_saving(
                    inputmat,
                    mm_inp,
                    wt_save,
                    weight,
                    bias,
                    ln_weight,
                    mu,
                    rsigma,
                )
            else:
                ctx.fsdp_shapes = _fsdp_scatter_tensors(
                    args.fsdp_group,
                    mu,
                    rsigma,
                    mm_inp if weight.requires_grad else None,
                )
                tensors_to_save, tensor_objects = prepare_for_saving(
                    inputmat,
                    mm_inp,
                    weight,
                    bias,
                    ln_weight,
                    mu,
                    rsigma,
                )
            ctx.save_for_backward(*tensors_to_save)
            ctx.tensor_objects = tensor_objects

            if args.fused_wgrad_accumulation and weight.requires_grad:
                if hasattr(weight, "__fsdp_param__"):
                    ctx.main_grad_func = weight.get_main_grad
                else:
                    ctx.main_grad_func = lambda: weight.main_grad

            if args.fp8 and requires_grad(inputmat, ln_weight, ln_bias, weight, bias):
                qstate = FP8GlobalStateManager.quantization_state
                _first_fp8_module = qstate.is_first_fp8_module
                ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
                if in_fp8_activation_recompute_phase():
                    # Restore the flag so the recompute pass does not consume it.
                    qstate.is_first_fp8_module = _first_fp8_module

        if args.return_layernorm_output:
            return out, ln_out_to_return
        return out

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Compute gradients for ``(inp, ln_weight, ln_bias, weight, bias)``.

        Args:
            ctx: Context populated by ``forward``.
            *grad_outputs: Gradient of ``out`` and, when the normalization
                output was returned as well, the gradient of that output.

        Returns:
            A 6-tuple matching the inputs of ``forward``; the entry for
            ``non_tensor_args`` is always ``None``.
        """
        args: _LayerNormLinearNonTensorArgs = ctx.args
        saved_tensors = ctx.saved_tensors
        is_fp8 = getattr(ctx, 'fp8', False)
        if is_fp8:
            inputmat, mm_inp, weightmat, weight, bias, ln_weight, mu, rsigma = restore_from_saved(
                ctx.tensor_objects, saved_tensors
            )
        else:
            inputmat, mm_inp, weight, bias, ln_weight, mu, rsigma = restore_from_saved(
                ctx.tensor_objects, saved_tensors
            )
            weightmat = None
        ctx.tensor_objects = None

        fsdp_group = getattr(ctx, "fsdp_group", None)
        if fsdp_group is not None:
            # Gather exactly the positional list that forward scattered, so
            # that ``ctx.fsdp_shapes`` lines up with the tensors: four entries
            # on the FP8 path, three otherwise.
            wgrad_input = mm_inp if ctx.requires_wgrad else None
            if is_fp8:
                gather_args = (
                    mu,
                    rsigma,
                    weightmat if not getattr(ctx, "is_weight_param_quantized", False) else None,
                    wgrad_input,
                )
            else:
                gather_args = (mu, rsigma, wgrad_input)
            _fsdp_gather_tensors(fsdp_group, ctx.fsdp_shapes, *gather_args)

        # NOTE(review): the result is not used below (``weight.main_grad`` is
        # read directly); the call is kept because the getter may have side
        # effects for FSDP-managed parameters - confirm.
        _main_grad = (
            ctx.main_grad_func()
            if weight is not None and hasattr(ctx, 'main_grad_func') and ctx.requires_wgrad
            else None
        )

        tp_world_size = get_distributed_world_size(args.tp_group)
        dgrad = None
        wgrad = None

        if is_fp8:
            if not ctx.requires_wgrad:
                args.grad_output_quantizer.set_usage(columnwise=False)
        mm_grad, grad_bias = TransformerEngineBaseModule.grad_output_preprocess(
            ctx,
            grad_outputs[0],
            args.parallel_mode == ParallelMode.ROW,
            args.grad_output_quantizer,
        )

        handle = DummyHandle
        is_fsdp2 = getattr(ctx, "is_fsdp2", False)
        if weightmat is None and is_fsdp2 and is_fp8:
            # The workspace was not saved in forward; rebuild it.
            if isinstance(weight, QuantizedTensorStorage):
                weightmat = weight
            else:
                args.weight_quantizer.set_usage(rowwise=True, columnwise=True)
                weightmat = args.weight_quantizer(weight)

        # ---- Input gradient of the GEMM -------------------------------------
        dgrad_weight = weightmat if weightmat is not None else weight
        dgrad_kwargs = {
            "usage_a": TensorUsage.LHS,
            "usage_b": TensorUsage.RHS,
            "out_dtype": args.activation_dtype,
        }
        if args.overlap_ag_dgrad:
            dgrad, mm_grad = CommOverlapOps.allgather_matmul(
                mm_grad,
                dgrad_weight,
                None,
                tp_world_size,
                args.tp_group,
                **dgrad_kwargs,
            )
        elif args.overlap_rs_dgrad:
            dgrad = CommOverlapOps.matmul_reduce_scatter(
                mm_grad,
                dgrad_weight,
                None,
                tp_world_size,
                args.tp_group,
                **dgrad_kwargs,
            )
        else:
            dgrad = general_gemm(mm_grad, dgrad_weight, **dgrad_kwargs)
            # Column-parallel: reduce the partial input gradients
            # asynchronously while the weight gradient is computed.
            if tp_world_size > 1 and args.parallel_mode == ParallelMode.COLUMN:
                if args.sequence_parallel:
                    dgrad, handle = reduce_scatter_along_dim(
                        dgrad,
                        args.tp_group,
                        async_op=True,
                    )
                else:
                    handle = torch.distributed.all_reduce(
                        dgrad,
                        group=args.tp_group,
                        async_op=True,
                    )

        if is_fsdp2 and is_fp8 and isinstance(weightmat, QuantizedTensorStorage):
            clear_columnwise_cache(weightmat)

        # ---- Weight gradient of the GEMM ------------------------------------
        use_fuse_wgrad_accumulation = False
        if ctx.requires_wgrad:
            if is_fp8:
                use_fuse_wgrad_accumulation = (
                    args.fused_wgrad_accumulation and args.module.fp8_meta["recipe"].mxfp8()
                )
            else:
                use_fuse_wgrad_accumulation = (
                    args.fused_wgrad_accumulation and weight.main_grad.dtype == torch.float32
                )
            out_dtype = (
                weight.main_grad.dtype if use_fuse_wgrad_accumulation else args.activation_dtype
            )
            if not is_fp8:
                mm_inp = mm_inp.view(-1, mm_inp.shape[-1])
                mm_grad = mm_grad.view(-1, mm_grad.shape[-1])
            wgrad_kwargs = {
                "usage_a": TensorUsage.LHS_TRANS,
                "usage_b": TensorUsage.RHS,
                "out_dtype": out_dtype,
            }
            if use_fuse_wgrad_accumulation:
                general_gemm_add(weight.main_grad, mm_grad, mm_inp, **wgrad_kwargs)
                wgrad = setup_dummy_wgrad(weight)
            else:
                wgrad = general_gemm(mm_grad, mm_inp, **wgrad_kwargs)

        if is_fp8 and ctx.reduce_and_update_bwd_fp8_tensors:
            FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)

        if ctx.use_bias and grad_bias is None:
            grad_bias = torch.reshape(grad_outputs[0], (-1, grad_outputs[0].shape[-1])).sum(dim=0)

        # Make sure the asynchronous dgrad reduction has finished.
        if handle is not None:
            handle.wait()
        if dgrad is not None:
            dgrad = dgrad.view(inputmat.shape)

        # The normalization output was also returned to the caller, so its
        # incoming gradient joins the GEMM input gradient before the
        # normalization backward.
        if ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered:
            if dgrad is not None:
                dgrad = dgrad + grad_outputs[1].view_as(dgrad)

        # ---- Normalization backward -----------------------------------------
        dgamma = None
        dbeta = None
        if ctx.normalization == "LayerNorm":
            dgrad_2d = dgrad.view(-1, dgrad.shape[-1])
            inputmat_2d = inputmat.view(-1, inputmat.shape[-1])
            ln_weight_2d = ln_weight
            mean = inputmat_2d.mean(dim=-1, keepdim=True)
            var = inputmat_2d.var(dim=-1, unbiased=False, keepdim=True)
            rsigma_comp = torch.rsqrt(var + args.eps)
            x_hat = (inputmat_2d - mean) * rsigma_comp
            # d(1 + w)/dw == 1, so dgamma/dbeta are the same in both modes;
            # only the scale applied to the upstream gradient differs.
            if ctx.zero_centered_gamma:
                dgamma = (dgrad_2d * x_hat).sum(dim=0)
                dbeta = dgrad_2d.sum(dim=0)
                dx_hat = dgrad_2d * (1 + ln_weight_2d)
            else:
                dgamma = (dgrad_2d * x_hat).sum(dim=0)
                dbeta = dgrad_2d.sum(dim=0)
                dx_hat = dgrad_2d * ln_weight_2d
            n = inputmat_2d.shape[-1]
            dvar = (dx_hat * (inputmat_2d - mean) * (-0.5) * (rsigma_comp**3)).sum(
                dim=-1, keepdim=True
            )
            dmean = (-dx_hat * rsigma_comp).sum(dim=-1, keepdim=True) + dvar * (-2.0 / n) * (
                inputmat_2d - mean
            ).sum(dim=-1, keepdim=True)
            dgrad = dx_hat * rsigma_comp + dvar * 2.0 / n * (inputmat_2d - mean) + dmean / n
            dgrad = dgrad.reshape(inputmat.shape)
        elif ctx.normalization == "RMSNorm":
            dgrad_2d = dgrad.view(-1, dgrad.shape[-1])
            inputmat_2d = inputmat.view(-1, inputmat.shape[-1])
            ln_weight_2d = ln_weight
            if rsigma is not None:
                rrms = rsigma
            else:
                variance = inputmat_2d.pow(2).mean(dim=-1, keepdim=True)
                rrms = torch.rsqrt(variance + args.eps)
            x_hat = inputmat_2d * rrms
            if ctx.zero_centered_gamma:
                dgamma = (dgrad_2d * x_hat).sum(dim=0)
                dx_hat = dgrad_2d * (1 + ln_weight_2d)
            else:
                dgamma = (dgrad_2d * x_hat).sum(dim=0)
                dx_hat = dgrad_2d * ln_weight_2d
            n = inputmat_2d.shape[-1]
            dvar = (dx_hat * inputmat_2d).sum(dim=-1, keepdim=True) * (-0.5) * (rrms**3)
            dgrad = dx_hat * rrms + dvar * 2.0 * inputmat_2d / n
            dgrad = dgrad.reshape(inputmat.shape)
            dbeta = None

        if not ctx.use_bias:
            grad_bias = None

        if ctx.requires_wgrad:
            # Hand autograd a placeholder only when the gradient really went
            # into ``weight.main_grad``.  If fused accumulation was requested
            # but could not be used, ``wgrad`` holds the only copy of the
            # gradient and must be returned as is.
            if use_fuse_wgrad_accumulation:
                wgrad = _LayerNormLinear._create_grad_weight_placeholder(weight)
        else:
            wgrad = None

        return (
            dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
            dgamma,
            dbeta,
            wgrad,
            grad_bias,
            None,
        )

    @staticmethod
    def _create_grad_weight_placeholder(weight: torch.Tensor) -> Optional[torch.Tensor]:
        """Return a stand-in ``weight.grad`` after accumulation into ``main_grad``.

        Returns ``None`` when ``weight`` has no ``grad_added_to_main_grad``
        attribute.  Otherwise the attribute is set to ``True`` and a tensor
        shaped like ``weight.main_grad`` is returned: zero-filled when
        ``weight.zero_out_wgrad`` is truthy, uninitialised otherwise.
        """
        if not hasattr(weight, "grad_added_to_main_grad"):
            return None
        if getattr(weight, "zero_out_wgrad", False):
            grad_weight = torch.zeros(
                weight.main_grad.shape,
                dtype=weight.dtype,
                device=torch.npu.current_device(),
                requires_grad=False,
            )
        else:
            grad_weight = torch.empty(
                weight.main_grad.shape,
                dtype=weight.dtype,
                device=torch.npu.current_device(),
                requires_grad=False,
            )
        weight.grad_added_to_main_grad = True
        return grad_weight
# Attribute names (with their default values) used to tag tensor-parallel
# parameters.
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
    "tensor_model_parallel": False,
    "partition_dim": -1,
    "partition_stride": 1,
}


def set_tensor_model_parallel_attributes(
    tensor: torch.Tensor, is_parallel: bool, dim: int, stride: int
) -> None:
    """Tag ``tensor`` with its tensor-parallel partitioning metadata.

    Args:
        tensor: Object to annotate; must not carry any of the attributes yet.
        is_parallel: Value stored as ``tensor_model_parallel``.
        dim: Value stored as ``partition_dim``.
        stride: Value stored as ``partition_stride``.

    Raises:
        AssertionError: If ``tensor`` already has one of the attributes.
    """
    assert not any(hasattr(tensor, name) for name in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS)
    tensor.tensor_model_parallel = is_parallel
    tensor.partition_dim = dim
    tensor.partition_stride = stride
class LayerNormLinear(TransformerEngineBaseModule):
    """Normalization (LayerNorm or RMSNorm) followed by a linear projection.

    Owns the normalization parameters, the linear weight and (optionally) the
    linear bias, and dispatches the computation to ``_LayerNormLinear``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        eps: float = 1e-5,
        sequence_parallel: bool = False,
        fuse_wgrad_accumulation: bool = False,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        init_method: Optional[Callable] = None,
        bias: bool = True,
        normalization: str = "LayerNorm",
        return_bias: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        parallel_mode: Optional[str] = None,
        return_layernorm_output: bool = False,
        return_layernorm_output_gathered: bool = False,
        parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
        zero_centered_gamma: bool = False,
        device: Union[torch.device, str] = "npu",
        ub_overlap_ag: bool = False,
        ub_overlap_rs: bool = False,
        ub_overlap_rs_dgrad: bool = False,
        ub_bulk_wgrad: bool = False,
        ub_bulk_dgrad: bool = False,
        ub_name: Optional[str] = None,
        delay_wgrad_compute: bool = False,
        symmetric_ar_type: Optional[str] = None,
        save_original_input: bool = False,
        name: Optional[str] = None,
    ) -> None:
        """Create the parameters and record the parallelism configuration.

        ``in_features`` / ``out_features`` are the unsharded sizes; the sharded
        dimension is divided by the tensor-parallel size.  ``parameters_split``
        must be ``None``.  ``ub_bulk_wgrad``, ``ub_bulk_dgrad``, ``ub_name`` and
        ``delay_wgrad_compute`` are accepted but not read by this class.
        """
        super().__init__(name)
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()

        # Plain configuration.
        self.in_features = in_features
        self.out_features = out_features
        self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
        self.normalization = normalization
        assert normalization in ["LayerNorm", "RMSNorm"], "Unsupported normalization type!"
        self.use_bias = bias
        self.return_bias = return_bias
        self.apply_bias = self.use_bias and not return_bias
        self.return_layernorm_output = return_layernorm_output
        # The gathered variant only makes sense when the output is returned.
        if return_layernorm_output:
            self.return_layernorm_output_gathered = return_layernorm_output_gathered
        else:
            self.return_layernorm_output_gathered = False
        self.zero_centered_gamma = zero_centered_gamma
        self.symmetric_ar_type = symmetric_ar_type
        self.save_original_input = save_original_input

        # Tensor-parallel group: an explicit group determines the size.
        if tp_group is not None:
            self.tp_size = get_distributed_world_size(tp_group)
            self.set_tensor_parallel_group(tp_group)
        else:
            self.tp_size = tp_size
            if tp_size == 1:
                self.set_tensor_parallel_group(tp_group)

        self.parallel_mode = parallel_mode
        assert self.parallel_mode in [None, "column", "row"], (
            f"parallel_mode {parallel_mode} not supported"
        )
        is_column = self.parallel_mode == "column"
        is_row = self.parallel_mode == "row"
        if is_column:
            self.out_features = divide(self.out_features, self.tp_size)
        elif is_row:
            self.in_features = divide(self.in_features, self.tp_size)

        if init_method is None:
            init_method = get_default_init_method()

        # Communication overlap is only meaningful with sequence parallelism.
        self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
        column_sp = is_column and self.sequence_parallel
        row_sp = is_row and self.sequence_parallel
        self.overlap_ag_fprop = column_sp and ub_overlap_ag
        self.overlap_rs_dgrad = column_sp and ub_overlap_rs_dgrad
        self.overlap_rs_fprop = row_sp and ub_overlap_rs
        self.overlap_ag_dgrad = row_sp and ub_overlap_ag

        with_fp8_params = FP8GlobalStateManager.with_fp8_parameters()
        assert parameters_split is None

        # Normalization parameters.
        self.eps = eps
        ln_weight_param = torch.nn.Parameter(
            torch.empty(self.in_features, device=device, dtype=params_dtype)
        )
        self.register_parameter(
            "layer_norm_weight",
            ln_weight_param,
            init_fn=init_method_constant(float(not self.zero_centered_gamma)),
        )
        if self.normalization == "RMSNorm":
            self.layer_norm_bias = None
        else:
            ln_bias_param = torch.nn.Parameter(
                torch.empty(self.in_features, device=device, dtype=params_dtype)
            )
            self.register_parameter(
                "layer_norm_bias", ln_bias_param, init_fn=init_method_constant(0.0)
            )

        # Backing storage for the linear weight and bias.
        weight_tensor = torch.empty(
            self.out_features,
            self.in_features,
            device=device,
            dtype=params_dtype,
        )
        bias_tensor = (
            torch.empty(self.out_features, device=device, dtype=params_dtype)
            if self.use_bias
            else None
        )

        # A single, unsplit parameter per kind.
        self.weight_names = ["weight"]
        self.bias_names = ["bias"]
        self.parameter_split_sizes = [out_features]
        if sum(self.parameter_split_sizes) != out_features:
            raise ValueError(
                f"Trying to split weight buffer ({out_features=}) "
                f"with split sizes {self.parameter_split_sizes}"
            )

        # Column parallelism shards every split along the output dimension.
        if is_column:
            for i, size in enumerate(self.parameter_split_sizes):
                if size % self.tp_size != 0:
                    raise RuntimeError(
                        f"Attempting to distribute a parameter with out_features={size} "
                        f"between {self.tp_size} tensor-parallel processes"
                    )
                self.parameter_split_sizes[i] = size // self.tp_size

        # Register the weight parameter(s) as views into the weight buffer.
        split_start = 0
        for weight_name, split_size in zip(self.weight_names, self.parameter_split_sizes):
            split_end = split_start + split_size
            if with_fp8_params and (split_start, split_end) != (0, self.out_features):
                raise RuntimeError(
                    "Splitting QuantizedTensor into multiple params is not supported"
                )
            self.register_parameter(
                weight_name,
                torch.nn.Parameter(weight_tensor[split_start:split_end]),
                init_fn=init_method,
                get_rng_state_tracker=get_rng_state_tracker,
                fp8_meta_index=FP8FwdTensorIdx.GEMM1_WEIGHT,
            )
            split_start = split_end

        # Register the bias parameter(s), or install empty tensors instead.
        if self.use_bias:
            split_start = 0
            for bias_name, split_size in zip(self.bias_names, self.parameter_split_sizes):
                split_end = split_start + split_size
                self.register_parameter(
                    bias_name,
                    torch.nn.Parameter(bias_tensor[split_start:split_end]),
                    init_fn=init_method_constant(0.0),
                )
                split_start = split_end
        else:
            for bias_name in self.bias_names:
                setattr(self, bias_name, torch.Tensor().to(dtype=params_dtype, device=device))

        if with_fp8_params:
            self.init_fp8_metadata()
        self.reset_parameters(defer_init=device == "meta")

        # In row-parallel mode the bias is added by ``forward`` after the
        # autograd function has returned.
        self.gemm_bias_unfused_add = bool(is_row and self.apply_bias)

        self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
        self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))

    def reset_parameters(self, defer_init=False):
        """Re-initialise parameters and tag them with parallelism attributes.

        The tagging is skipped when ``defer_init`` is true.
        """
        super().reset_parameters(defer_init=defer_init)
        if defer_init:
            return

        # Normalization parameters.
        self.layer_norm_weight.sequence_parallel = self.sequence_parallel
        if self.normalization != "RMSNorm":
            self.layer_norm_bias.sequence_parallel = self.sequence_parallel

        # Linear weight: partitioned along dim 1 for row mode, dim 0 otherwise.
        partition_dim = 1 if self.parallel_mode == "row" else 0
        for weight_name in self.weight_names:
            set_tensor_model_parallel_attributes(
                tensor=getattr(self, weight_name),
                is_parallel=True,
                dim=partition_dim,
                stride=1,
            )

        if not self.use_bias:
            return
        # Linear bias.
        for bias_name in self.bias_names:
            bias_param = getattr(self, bias_name)
            if self.parallel_mode == "row":
                bias_param.sequence_parallel = self.sequence_parallel
            elif self.parallel_mode == "column":
                set_tensor_model_parallel_attributes(bias_param, True, 0, 1)

    @no_torch_dynamo()
    def forward(
        self,
        inp: torch.Tensor,
        is_first_microbatch: Optional[bool] = None,
        fp8_output: Optional[bool] = False,
        fp8_grad: Optional[bool] = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """Apply normalization and the linear projection to ``inp``.

        Returns ``out``, followed by the bias when ``return_bias`` is set and by
        the normalization output when ``return_layernorm_output`` is set (in
        that order); a bare tensor is returned when neither is requested.
        """
        is_grad_enabled = torch.is_grad_enabled()
        inp = self.prepare_forward(inp, allow_non_contiguous=False)
        skip_fp8_weight_update = None

        weight_tensor = getattr(self, self.weight_names[0])
        bias_tensor = getattr(self, self.bias_names[0]) if self.use_bias else None

        (
            input_quantizer,
            weight_quantizer,
            output_quantizer,
            grad_input_quantizer,
            grad_weight_quantizer,
            grad_output_quantizer,
        ) = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)

        # Without autograd the forward is called directly with a dummy context.
        fwd_fn, leading_args = (
            (_LayerNormLinear.apply, ())
            if is_grad_enabled
            else (_LayerNormLinear.forward, (None,))
        )

        non_tensor_args = _LayerNormLinearNonTensorArgs(
            is_first_microbatch=is_first_microbatch,
            fp8=self.fp8,
            eps=self.eps,
            input_quantizer=input_quantizer,
            weight_quantizer=weight_quantizer,
            output_quantizer=output_quantizer,
            grad_input_quantizer=grad_input_quantizer,
            grad_weight_quantizer=grad_weight_quantizer,
            grad_output_quantizer=grad_output_quantizer,
            fused_wgrad_accumulation=self.fuse_wgrad_accumulation,
            cpu_offloading=False,
            tp_group=self.tp_group,
            tp_size=self.tp_size,
            sequence_parallel=self.sequence_parallel,
            activation_dtype=self.activation_dtype,
            tensor_parallel=self.tp_size > 1,
            parallel_mode=self.parallel_mode,
            is_grad_enabled=is_grad_enabled,
            fp8_output=fp8_output,
            module=self,
            skip_fp8_weight_update=skip_fp8_weight_update,
            save_origin_input=self.save_original_input,
            overlap_ag_fprop=self.overlap_ag_fprop,
            overlap_rs_dgrad=self.overlap_rs_dgrad,
            overlap_rs_fprop=self.overlap_rs_fprop,
            overlap_ag_dgrad=self.overlap_ag_dgrad,
            normalization=self.normalization,
            zero_centered_gamma=self.zero_centered_gamma,
            return_layernorm_output=self.return_layernorm_output,
            return_layernorm_output_gathered=self.return_layernorm_output_gathered,
            fsdp_group=self.fsdp_group,
            is_fsdp2=self.is_fsdp2,
        )

        # The bias goes into the GEMM only when it is applied and not added
        # separately afterwards.
        fused_bias = (
            bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None
        )
        try:
            out = fwd_fn(
                *leading_args,
                inp,
                self.layer_norm_weight,
                self.layer_norm_bias,
                weight_tensor,
                fused_bias,
                non_tensor_args,
            )
        finally:
            self.end_forward()

        ln_out = None
        if self.return_layernorm_output:
            out, ln_out = out
        if self.gemm_bias_unfused_add:
            out = out + cast_if_needed(bias_tensor, self.activation_dtype)

        # Assemble the result: out[, bias][, ln_out].
        results = [out]
        if self.return_bias:
            results.append(cast_if_needed(bias_tensor, self.activation_dtype))
        if self.return_layernorm_output:
            results.append(ln_out)
        if len(results) == 1:
            return out
        return tuple(results)

    def _get_quantizers(self, fp8_output, fp8_grad, is_grad_enabled):
        """Return the six quantizers in the order ``forward`` unpacks them.

        Order: input, weight, output, grad_input, grad_weight, grad_output.
        All six are ``None`` when FP8 is disabled.
        """
        if not self.fp8:
            return [None] * 6

        fwd_quantizers = self.quantizers["scaling_fwd"]
        input_quantizer = fwd_quantizers[FP8FwdTensorIdx.GEMM1_INPUT]
        input_quantizer.internal = True
        if self.parallel_mode != "column" or not self.sequence_parallel:
            input_quantizer.optimize_for_gemm = True

        (weight_quantizer,) = self._get_weight_quantizers()

        output_quantizer = None
        if fp8_output:
            output_quantizer = self.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_OUTPUT]

        grad_output_quantizer = None
        grad_input_quantizer = None
        grad_weight_quantizer = None
        if is_grad_enabled:
            bwd_quantizers = self.quantizers["scaling_bwd"]
            grad_output_quantizer = bwd_quantizers[FP8BwdTensorIdx.GRAD_OUTPUT1]
            grad_output_quantizer.internal = True
            if self.parallel_mode != "row" or not self.sequence_parallel:
                grad_output_quantizer.optimize_for_gemm = True
            if fp8_grad:
                grad_input_quantizer = bwd_quantizers[FP8BwdTensorIdx.GRAD_INPUT1]

        return (
            input_quantizer,
            weight_quantizer,
            output_quantizer,
            grad_input_quantizer,
            grad_weight_quantizer,
            grad_output_quantizer,
        )

    def _get_weight_quantizers(self):
        """Return a one-element list with the weight quantizer (``None`` without FP8)."""
        if not self.fp8:
            return [None]
        quantizer = self.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_WEIGHT]
        quantizer.internal = True
        return [quantizer]
# Public API of this module: only the module class is exported.
__all__ = ["LayerNormLinear"]