# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2026, Huawei Technologies Co., Ltd. All rights reserved.
#
# See LICENSE for license information.

import os
from typing import (
    Any,
    Callable,
    Dict,
    NamedTuple,
    Optional,
    Tuple,
    Union,
)

import torch
import torch_npu

from ..constants import (
    FP8BwdTensorIdx,
    FP8FwdTensorIdx,
    ParallelMode,
    TensorUsage,
    dist_group_type,
)
from ..distributed import (
    DummyHandle,
    _fsdp_gather_tensors,
    _fsdp_scatter_tensors,
    gather_along_dim,
    get_distributed_world_size,
    in_fp8_activation_recompute_phase,
    is_fp8_activation_recompute_enabled,
    reduce_scatter_along_dim,
)
from ..jit import no_torch_dynamo
from ..ops.fused.overlap import CommOverlapOps
from ..ops.gemm import (
    general_gemm,
    general_gemm_add,
)
from ..quantization import FP8GlobalStateManager
from ..quantized_tensor import (
    QuantizedTensorStorage,
    Quantizer,
    prepare_for_saving,
    restore_from_saved,
)
from ..tensor.utils import clear_columnwise_cache
from ..utils import (
    cast_if_needed,
    divide,
    get_default_init_method,
    init_method_constant,
    requires_grad,
)
from .base import (
    TransformerEngineBaseModule,
    setup_dummy_wgrad,
)


class _LayerNormLinearNonTensorArgs(NamedTuple):
    """Bundle of every non-tensor argument consumed by ``_LayerNormLinear``.

    ``LayerNormLinear.forward`` builds one instance per call and passes it as
    the last positional argument of the autograd function, which also stores
    it on ``ctx.args`` for use in ``backward``.
    """

    # Forwarded from ``LayerNormLinear.forward``; ``None`` is a valid value.
    is_first_microbatch: Optional[bool]
    # Whether the FP8 (quantized) code path is taken.
    fp8: bool
    # Epsilon used by the normalization.
    eps: float
    # Quantizers are all ``None`` when ``fp8`` is off (see ``_get_quantizers``).
    input_quantizer: Optional[Quantizer]
    weight_quantizer: Optional[Quantizer]
    output_quantizer: Optional[Quantizer]
    grad_input_quantizer: Optional[Quantizer]
    grad_weight_quantizer: Optional[Quantizer]
    grad_output_quantizer: Optional[Quantizer]
    # Request accumulation of the weight gradient into ``weight.main_grad``.
    fused_wgrad_accumulation: bool
    cpu_offloading: bool
    # Tensor-parallel process group and its size.
    tp_group: torch.distributed.group
    tp_size: int
    sequence_parallel: bool
    # Dtype that inputs and norm parameters are cast to, and GEMM output dtype.
    activation_dtype: torch.dtype
    tensor_parallel: bool
    # Compared against ``ParallelMode.COLUMN`` / ``ParallelMode.ROW``.
    parallel_mode: Optional[str]
    is_grad_enabled: bool
    fp8_output: bool
    # Owning module; used for ``get_weight_workspace`` and ``fp8_meta``.
    module: "LayerNormLinear"
    # ``LayerNormLinear.forward`` currently always passes ``None`` here.
    skip_fp8_weight_update: Optional[bool]
    save_origin_input: bool
    # Communication/GEMM overlap switches for forward (fprop) and input-grad (dgrad).
    overlap_ag_fprop: bool
    overlap_rs_dgrad: bool
    overlap_rs_fprop: bool
    overlap_ag_dgrad: bool
    # Either "LayerNorm" or "RMSNorm".
    normalization: str
    # When set, the effective norm scale is ``1 + ln_weight``.
    zero_centered_gamma: bool
    return_layernorm_output: bool
    return_layernorm_output_gathered: bool
    fsdp_group: Optional[Any]
    is_fsdp2: bool

    @property
    def ub_overlap_ag(self):
        """Alias of ``overlap_ag_dgrad``; copied onto ``ctx.ub_overlap_ag`` in forward."""
        return self.overlap_ag_dgrad


class _LayerNormLinear(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        ln_weight: torch.Tensor,
        ln_bias: Union[torch.Tensor, None],
        weight: torch.Tensor,
        bias: torch.Tensor,
        non_tensor_args: _LayerNormLinearNonTensorArgs,
    ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
        """Normalize ``inp`` and multiply the result by ``weight`` (plus ``bias``).

        Args:
            ctx: Autograd context, or ``None`` when the module calls this
                function directly with gradients disabled.
            inp: Input whose last dimension equals ``weight.shape[1]``.
            ln_weight: Scale of the normalization.
            ln_bias: Shift of the normalization, or ``None``.
            weight: Projection weight of shape ``(out_features, in_features)``.
            bias: Bias handed to the GEMM, or ``None``.
            non_tensor_args: All remaining configuration.

        Returns:
            ``out`` alone, or ``(out, ln_out)`` when
            ``non_tensor_args.return_layernorm_output`` is set.

        Raises:
            ValueError: If ``non_tensor_args.normalization`` is neither
                ``"LayerNorm"`` nor ``"RMSNorm"``.
        """
        args = non_tensor_args
        out_features, in_features = weight.shape
        inp_shape = inp.shape
        inp_requires_grad = inp.requires_grad
        assert inp_shape[-1] == in_features, "GEMM not possible"

        # Work on a 2D view and bring all normalization operands to one dtype.
        inputmat = cast_if_needed(inp.view((-1, in_features)), args.activation_dtype)
        ln_weight = cast_if_needed(ln_weight, args.activation_dtype)
        if ln_bias is not None:
            ln_bias = cast_if_needed(ln_bias, args.activation_dtype)

        tp_world_size = get_distributed_world_size(args.tp_group)
        with_input_all_gather = args.parallel_mode == ParallelMode.COLUMN and args.sequence_parallel

        # With zero-centered gamma the stored parameter is (scale - 1).
        gamma = 1 + ln_weight if args.zero_centered_gamma else ln_weight

        if args.normalization == "LayerNorm":
            ln_out = torch.nn.functional.layer_norm(
                inputmat,
                [inputmat.shape[-1]],
                weight=gamma,
                bias=ln_bias,
                eps=args.eps,
            )
            # Row statistics that are saved alongside the activations.
            mu = inputmat.mean(dim=-1, keepdim=True)
            var = inputmat.var(dim=-1, unbiased=False, keepdim=True)
            rsigma = torch.rsqrt(var + args.eps)
        elif args.normalization == "RMSNorm":
            ln_out, rsigma = torch_npu.npu_rms_norm(  # pylint: disable=no-member
                inputmat, gamma, epsilon=args.eps
            )
            mu = None
        else:
            raise ValueError(f"Unsupported normalization type: {args.normalization}")

        ln_out_return = None
        if args.return_layernorm_output or args.return_layernorm_output_gathered:
            ln_out_return = ln_out

        if args.fp8:
            # Column-wise data of the input is only needed for the weight gradient.
            backward_needs_input = args.is_grad_enabled and weight.requires_grad

            if is_fp8_activation_recompute_enabled() and not in_fp8_activation_recompute_phase():
                args.input_quantizer.set_usage(columnwise=False)
                args.weight_quantizer.set_usage(columnwise=False)
            else:
                args.input_quantizer.set_usage(
                    rowwise=True,
                    columnwise=backward_needs_input,
                )

            if args.is_fsdp2:
                args.weight_quantizer.set_usage(columnwise=False)

            # Gather in high precision first, then quantize the gathered tensor.
            if with_input_all_gather and not args.overlap_ag_fprop:
                ln_out, _ = gather_along_dim(ln_out, args.tp_group)

            mm_inp = args.input_quantizer.quantize(ln_out)

            update_workspace = args.is_first_microbatch is None or args.is_first_microbatch
            weightmat = args.module.get_weight_workspace(
                tensor=weight,
                quantizer=args.weight_quantizer,
                cache_name=(
                    None if (args.is_first_microbatch is None or args.is_fsdp2) else "weight"
                ),
                update_workspace=update_workspace,
                skip_update_flag=args.skip_fp8_weight_update,
                workspace_dtype=args.activation_dtype,
            )
            weightmat.update_usage(rowwise_usage=True)
        else:
            mm_inp = ln_out
            weightmat = weight

            if with_input_all_gather and not args.overlap_ag_fprop:
                mm_inp, _ = gather_along_dim(mm_inp, args.tp_group)

        mm_kwargs = {
            "usage_a": TensorUsage.LHS,
            "usage_b": TensorUsage.RHS_TRANS,
            "out_dtype": args.activation_dtype,
        }

        if args.overlap_ag_fprop:
            # The fused op also hands back its (gathered) input operand.
            out, mm_inp = CommOverlapOps.allgather_matmul(
                mm_inp,
                weightmat,
                bias,
                tp_world_size,
                args.tp_group,
                **mm_kwargs,
            )
        elif args.overlap_rs_fprop:
            out = CommOverlapOps.matmul_reduce_scatter(
                mm_inp,
                weightmat,
                bias,
                tp_world_size,
                args.tp_group,
                **mm_kwargs,
            )
        else:
            out = general_gemm(
                mm_inp,
                weightmat,
                bias=bias,
                **mm_kwargs,
            )

        # Row-parallel output still has to be reduced across the TP group unless
        # the overlapped reduce-scatter GEMM above already did it.
        if (
            args.parallel_mode == ParallelMode.ROW
            and not args.overlap_rs_fprop
            and args.tp_size > 1
        ):
            if args.sequence_parallel:
                out, _ = CommOverlapOps.reduce_scatter(
                    out,
                    args.output_quantizer,
                    tp_world_size,
                    args.tp_group,
                    use_quant=args.fp8_output and args.module.fp8_meta["recipe"].mxfp8(),
                )
            else:
                torch.distributed.all_reduce(out, group=args.tp_group)

        out = out.view(-1, *inp_shape[1:-1], out_features)

        # Build the optional second output here, i.e. before the activations are
        # handed to `_fsdp_scatter_tensors` / `prepare_for_saving` below, but do
        # NOT return yet: returning at this point would leave `ctx` empty and
        # make `backward` unusable whenever the norm output is requested.
        ln_out_result = None
        if args.return_layernorm_output:
            if args.return_layernorm_output_gathered:
                if with_input_all_gather:
                    ln_out_return = mm_inp
                shape = list(inp_shape)
                shape[0] *= tp_world_size if with_input_all_gather else 1
                ln_out_result = ln_out_return.view(shape)
            else:
                ln_out_result = ln_out_return.view(inp_shape)

        if args.is_grad_enabled:
            ctx.args = args
            ctx.use_bias = bias is not None
            ctx.inp_shape = inp_shape
            ctx.fp8 = args.fp8
            ctx.debug = False
            ctx.sequence_parallel = args.sequence_parallel
            ctx.ub_overlap_ag = args.ub_overlap_ag
            ctx.requires_dgrad = inp_requires_grad
            ctx.requires_wgrad = weight.requires_grad
            ctx.normalization = args.normalization
            ctx.zero_centered_gamma = args.zero_centered_gamma
            ctx.return_layernorm_output = args.return_layernorm_output
            ctx.return_layernorm_output_gathered = args.return_layernorm_output_gathered
            ctx.reduce_and_update_bwd_fp8_tensors = False

            ctx.fsdp_group = args.fsdp_group
            ctx.is_fsdp2 = args.is_fsdp2
            if args.fp8:
                ctx.is_weight_param_quantized = isinstance(weight, QuantizedTensorStorage)
                # FSDP2: Don't save FP8 workspace so backward re-quantizes from all-gathered weight
                wt_save = weightmat
                if args.is_fsdp2 and weightmat is not weight:
                    wt_save = None
                ctx.fsdp_shapes = _fsdp_scatter_tensors(
                    args.fsdp_group,
                    mu,
                    rsigma,
                    weightmat if not ctx.is_weight_param_quantized else None,
                    mm_inp if weight.requires_grad else None,
                )
                tensors_to_save, tensor_objects = prepare_for_saving(
                    inputmat,
                    mm_inp,
                    wt_save,
                    weight,
                    bias,
                    ln_weight,
                    mu,
                    rsigma,
                )
            else:
                ctx.fsdp_shapes = _fsdp_scatter_tensors(
                    args.fsdp_group,
                    mu,
                    rsigma,
                    mm_inp if weight.requires_grad else None,
                )
                tensors_to_save, tensor_objects = prepare_for_saving(
                    inputmat,
                    mm_inp,
                    weight,
                    bias,
                    ln_weight,
                    mu,
                    rsigma,
                )
            ctx.save_for_backward(*tensors_to_save)
            ctx.tensor_objects = tensor_objects

            # `main_grad` is looked up lazily in backward rather than saved.
            if args.fused_wgrad_accumulation and weight.requires_grad:
                if hasattr(weight, "__fsdp_param__"):
                    ctx.main_grad_func = weight.get_main_grad
                else:
                    ctx.main_grad_func = lambda: weight.main_grad

            if args.fp8 and requires_grad(inputmat, ln_weight, ln_bias, weight, bias):
                qstate = FP8GlobalStateManager.quantization_state
                _first_fp8_module = qstate.is_first_fp8_module
                ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
                # During recompute, put the "first module" flag back to what it was.
                if in_fp8_activation_recompute_phase():
                    qstate.is_first_fp8_module = _first_fp8_module

        if args.return_layernorm_output:
            return out, ln_out_result
        return out

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Compute gradients for every input of ``forward``.

        Args:
            ctx: Context populated by ``forward``.
            *grad_outputs: Gradient of ``out`` and, when the normalization
                output was returned as well, the gradient of that output.

        Returns:
            A 6-tuple aligned with ``forward``'s inputs:
            ``(d_inp, d_ln_weight, d_ln_bias, d_weight, d_bias, None)``.
            ``d_inp`` is ``None`` when the input did not require grad,
            ``d_ln_bias`` is ``None`` for RMSNorm, and ``d_weight`` is a
            placeholder (or ``None``) when the gradient was accumulated
            straight into ``weight.main_grad``.
        """
        args: _LayerNormLinearNonTensorArgs = ctx.args
        saved_tensors = ctx.saved_tensors

        is_fp8 = getattr(ctx, "fp8", False)

        # The FP8 path saved one extra entry (the quantized weight workspace).
        if is_fp8:
            inputmat, mm_inp, weightmat, weight, bias, ln_weight, mu, rsigma = restore_from_saved(
                ctx.tensor_objects, saved_tensors
            )
        else:
            inputmat, mm_inp, weight, bias, ln_weight, mu, rsigma = restore_from_saved(
                ctx.tensor_objects, saved_tensors
            )
            weightmat = None
        ctx.tensor_objects = None

        fsdp_group = getattr(ctx, "fsdp_group", None)
        if fsdp_group is not None:
            # Mirror of the `_fsdp_scatter_tensors` call made in forward.
            _fsdp_gather_tensors(
                fsdp_group,
                ctx.fsdp_shapes,
                mu,
                rsigma,
                weightmat
                if is_fp8 and not getattr(ctx, "is_weight_param_quantized", False)
                else None,
                mm_inp if ctx.requires_wgrad else None,
            )

        # NOTE(review): the returned buffer is not used below (the code reads
        # `weight.main_grad` directly); the call is kept in case the getter has
        # side effects for FSDP-managed parameters - confirm.
        _main_grad = (
            ctx.main_grad_func()
            if weight is not None and hasattr(ctx, "main_grad_func") and ctx.requires_wgrad
            else None
        )

        tp_world_size = get_distributed_world_size(args.tp_group)
        dgrad = None
        wgrad = None

        if is_fp8 and not ctx.requires_wgrad:
            args.grad_output_quantizer.set_usage(columnwise=False)

        mm_grad, grad_bias = TransformerEngineBaseModule.grad_output_preprocess(
            ctx,
            grad_outputs[0],
            args.parallel_mode == ParallelMode.ROW,
            args.grad_output_quantizer,
        )

        handle = DummyHandle

        is_fsdp2 = getattr(ctx, "is_fsdp2", False)
        if weightmat is None and is_fsdp2 and is_fp8:
            # Forward did not keep the FP8 workspace; rebuild it from `weight`.
            if isinstance(weight, QuantizedTensorStorage):
                weightmat = weight
            else:
                args.weight_quantizer.set_usage(rowwise=True, columnwise=True)
                weightmat = args.weight_quantizer(weight)

        dgrad_weight = weightmat if weightmat is not None else weight

        # ---- gradient w.r.t. the normalization output -------------------
        dgrad_kwargs = {
            "usage_a": TensorUsage.LHS,
            "usage_b": TensorUsage.RHS,
            "out_dtype": args.activation_dtype,
        }
        if args.overlap_ag_dgrad:
            dgrad, mm_grad = CommOverlapOps.allgather_matmul(
                mm_grad,
                dgrad_weight,
                None,
                tp_world_size,
                args.tp_group,
                **dgrad_kwargs,
            )
        elif args.overlap_rs_dgrad:
            dgrad = CommOverlapOps.matmul_reduce_scatter(
                mm_grad,
                dgrad_weight,
                None,
                tp_world_size,
                args.tp_group,
                **dgrad_kwargs,
            )
        else:
            dgrad = general_gemm(mm_grad, dgrad_weight, **dgrad_kwargs)
            if tp_world_size > 1 and args.parallel_mode == ParallelMode.COLUMN:
                # Launched asynchronously; waited on after the wgrad GEMM.
                if args.sequence_parallel:
                    dgrad, handle = reduce_scatter_along_dim(
                        dgrad,
                        args.tp_group,
                        async_op=True,
                    )
                else:
                    handle = torch.distributed.all_reduce(
                        dgrad,
                        group=args.tp_group,
                        async_op=True,
                    )

        # FSDP2: clear columnwise cache derived from all-gathered weight
        if is_fsdp2 and is_fp8 and isinstance(weightmat, QuantizedTensorStorage):
            clear_columnwise_cache(weightmat)

        # ---- gradient w.r.t. the weight ---------------------------------
        # True only when the GEMM below really accumulates into `main_grad`.
        use_fuse_wgrad_accumulation = False
        if ctx.requires_wgrad:
            if is_fp8:
                use_fuse_wgrad_accumulation = (
                    args.fused_wgrad_accumulation and args.module.fp8_meta["recipe"].mxfp8()
                )
            else:
                use_fuse_wgrad_accumulation = (
                    args.fused_wgrad_accumulation and weight.main_grad.dtype == torch.float32
                )

            out_dtype = (
                weight.main_grad.dtype if use_fuse_wgrad_accumulation else args.activation_dtype
            )

            if not is_fp8:
                mm_inp = mm_inp.view(-1, mm_inp.shape[-1])
                mm_grad = mm_grad.view(-1, mm_grad.shape[-1])

            wgrad_kwargs = {
                "usage_a": TensorUsage.LHS_TRANS,
                "usage_b": TensorUsage.RHS,
                "out_dtype": out_dtype,
            }

            if use_fuse_wgrad_accumulation:
                general_gemm_add(weight.main_grad, mm_grad, mm_inp, **wgrad_kwargs)
                wgrad = setup_dummy_wgrad(weight)
            else:
                wgrad = general_gemm(mm_grad, mm_inp, **wgrad_kwargs)

        if is_fp8 and ctx.reduce_and_update_bwd_fp8_tensors:
            FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)

        if ctx.use_bias and grad_bias is None:
            grad_bias = torch.reshape(grad_outputs[0], (-1, grad_outputs[0].shape[-1])).sum(dim=0)

        if handle is not None:
            handle.wait()

        if dgrad is not None:
            dgrad = dgrad.view(inputmat.shape)  # pylint: disable=no-member
        # The un-gathered norm output was returned to the caller as well, so its
        # incoming gradient joins the GEMM's input gradient.
        if ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered:
            if dgrad is not None:
                dgrad = dgrad + grad_outputs[1].view_as(dgrad)

        # ---- backward of the normalization ------------------------------
        dgamma = None
        dbeta = None
        if ctx.normalization == "LayerNorm":
            dgrad_2d = dgrad.view(-1, dgrad.shape[-1])
            inputmat_2d = inputmat.view(-1, inputmat.shape[-1])  # pylint: disable=no-member

            # Row statistics are recomputed from the saved input.
            mean = inputmat_2d.mean(dim=-1, keepdim=True)
            var = inputmat_2d.var(dim=-1, unbiased=False, keepdim=True)
            rsigma_comp = torch.rsqrt(var + args.eps)
            centered = inputmat_2d - mean
            x_hat = centered * rsigma_comp

            dgamma = (dgrad_2d * x_hat).sum(dim=0)
            dbeta = dgrad_2d.sum(dim=0)
            scale = 1 + ln_weight if ctx.zero_centered_gamma else ln_weight
            dx_hat = dgrad_2d * scale

            n = inputmat_2d.shape[-1]
            dvar = (dx_hat * centered * (-0.5) * (rsigma_comp**3)).sum(dim=-1, keepdim=True)
            dmean = (-dx_hat * rsigma_comp).sum(dim=-1, keepdim=True) + dvar * (
                -2.0 / n
            ) * centered.sum(dim=-1, keepdim=True)
            dgrad = dx_hat * rsigma_comp + dvar * 2.0 / n * centered + dmean / n
            dgrad = dgrad.reshape(inputmat.shape)  # pylint: disable=no-member

        elif ctx.normalization == "RMSNorm":
            dgrad_2d = dgrad.view(-1, dgrad.shape[-1])
            inputmat_2d = inputmat.view(-1, inputmat.shape[-1])  # pylint: disable=no-member

            # Prefer the reciprocal RMS saved by forward; recompute as a fallback.
            if rsigma is not None:
                rrms = rsigma
            else:
                variance = inputmat_2d.pow(2).mean(dim=-1, keepdim=True)
                rrms = torch.rsqrt(variance + args.eps)

            x_hat = inputmat_2d * rrms

            dgamma = (dgrad_2d * x_hat).sum(dim=0)
            scale = 1 + ln_weight if ctx.zero_centered_gamma else ln_weight
            dx_hat = dgrad_2d * scale

            n = inputmat_2d.shape[-1]
            dvar = (dx_hat * inputmat_2d).sum(dim=-1, keepdim=True) * (-0.5) * (rrms**3)
            dgrad = dx_hat * rrms + dvar * 2.0 * inputmat_2d / n
            dgrad = dgrad.reshape(inputmat.shape)  # pylint: disable=no-member
            dbeta = None

        if not ctx.use_bias:
            grad_bias = None

        # Hand autograd a placeholder ONLY if the gradient was really added to
        # `main_grad` above. When fused accumulation was requested but not
        # applicable, `wgrad` holds the actual gradient and must be returned
        # as-is; replacing it (and flagging `grad_added_to_main_grad`) would
        # silently drop the weight gradient.
        if ctx.requires_wgrad and use_fuse_wgrad_accumulation:
            wgrad = _LayerNormLinear._create_grad_weight_placeholder(weight)

        return (
            dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
            dgamma,
            dbeta,
            wgrad,
            grad_bias,
            None,
        )

    @staticmethod
    def _create_grad_weight_placeholder(weight: torch.Tensor) -> Optional[torch.Tensor]:
        if not hasattr(weight, "grad_added_to_main_grad"):
            return None
        if getattr(weight, "zero_out_wgrad", False):
            grad_weight = torch.zeros(
                weight.main_grad.shape,
                dtype=weight.dtype,
                device=torch.npu.current_device(),
                requires_grad=False,
            )
        else:
            grad_weight = torch.empty(
                weight.main_grad.shape,
                dtype=weight.dtype,
                device=torch.npu.current_device(),
                requires_grad=False,
            )
        weight.grad_added_to_main_grad = True
        return grad_weight


# Names (with default values) of the attributes that describe how a tensor is
# partitioned for tensor model parallelism.
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
    "tensor_model_parallel": False,
    "partition_dim": -1,
    "partition_stride": 1,
}


def set_tensor_model_parallel_attributes(
    tensor: torch.Tensor, is_parallel: bool, dim: int, stride: int
) -> None:
    """Tag ``tensor`` with its tensor-parallel partitioning metadata.

    Sets ``tensor_model_parallel``, ``partition_dim`` and ``partition_stride``
    on ``tensor``. The tensor must not carry any of these attributes yet;
    an ``AssertionError`` is raised otherwise and nothing is modified.
    """
    assert not any(hasattr(tensor, name) for name in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS)
    tensor.tensor_model_parallel = is_parallel
    tensor.partition_dim = dim
    tensor.partition_stride = stride


class LayerNormLinear(TransformerEngineBaseModule):  # pylint: disable=abstract-method
    def __init__(
        self,
        in_features: int,
        out_features: int,
        eps: float = 1e-5,
        sequence_parallel: bool = False,
        fuse_wgrad_accumulation: bool = False,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        init_method: Optional[Callable] = None,
        bias: bool = True,
        normalization: str = "LayerNorm",
        return_bias: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        parallel_mode: Optional[str] = None,
        return_layernorm_output: bool = False,
        return_layernorm_output_gathered: bool = False,
        parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
        zero_centered_gamma: bool = False,
        device: Union[torch.device, str] = "npu",
        ub_overlap_ag: bool = False,
        ub_overlap_rs: bool = False,
        ub_overlap_rs_dgrad: bool = False,
        ub_bulk_wgrad: bool = False,
        ub_bulk_dgrad: bool = False,
        ub_name: Optional[str] = None,
        delay_wgrad_compute: bool = False,
        symmetric_ar_type: Optional[str] = None,
        save_original_input: bool = False,
        name: Optional[str] = None,
    ) -> None:
        """Create the normalization and linear parameters of the fused layer.

        Args:
            in_features: Size of the input's last dimension (before any
                row-parallel split).
            out_features: Size of the output's last dimension (before any
                column-parallel split).
            eps: Epsilon of the normalization.
            sequence_parallel: Requested sequence parallelism; only effective
                when the tensor-parallel size is greater than one.
            fuse_wgrad_accumulation: Stored and forwarded to the autograd
                function as ``fused_wgrad_accumulation``.
            tp_group: Tensor-parallel process group. When given, its world
                size overrides ``tp_size``.
            tp_size: Tensor-parallel size used when ``tp_group`` is ``None``.
            get_rng_state_tracker: Forwarded to ``register_parameter`` for the
                weight.
            init_method: Weight initializer; ``get_default_init_method()`` is
                used when ``None``.
            bias: Whether the linear layer owns a bias parameter.
            normalization: ``"LayerNorm"`` or ``"RMSNorm"``.
            return_bias: Return the bias from ``forward`` instead of applying it.
            params_dtype: Parameter dtype; torch's default dtype when ``None``.
            parallel_mode: ``None``, ``"column"`` or ``"row"``.
            return_layernorm_output: Also return the normalization output.
            return_layernorm_output_gathered: Only kept when
                ``return_layernorm_output`` is set.
            parameters_split: Must be ``None``.
            zero_centered_gamma: Initialize the norm scale to 0 instead of 1
                and use ``1 + weight`` as the effective scale.
            device: Device on which the parameters are allocated.
            ub_overlap_ag / ub_overlap_rs / ub_overlap_rs_dgrad: Requests for
                communication/GEMM overlap; combined with the parallel mode
                and sequence parallelism below.
            ub_bulk_wgrad / ub_bulk_dgrad / ub_name / delay_wgrad_compute:
                Accepted but not read by this constructor.
            symmetric_ar_type: Stored on the module.
            save_original_input: Stored on the module.
            name: Passed to the base-class constructor.
        """
        super().__init__(name)

        params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
        self.in_features = in_features
        self.out_features = out_features
        self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
        self.normalization = normalization
        assert normalization in ["LayerNorm", "RMSNorm"], "Unsupported normalization type!"
        self.use_bias = bias
        self.return_bias = return_bias
        # The bias is applied by this layer only if it is not handed back to the caller.
        self.apply_bias = self.use_bias and not return_bias
        self.return_layernorm_output = return_layernorm_output
        self.return_layernorm_output_gathered = (
            return_layernorm_output_gathered if return_layernorm_output else False
        )
        self.zero_centered_gamma = zero_centered_gamma
        self.symmetric_ar_type = symmetric_ar_type
        self.save_original_input = save_original_input

        # Without an explicit group the caller's tp_size is trusted, and the
        # (None) group is only registered for the single-rank case.
        if tp_group is None:
            self.tp_size = tp_size
            if tp_size == 1:
                self.set_tensor_parallel_group(tp_group)
        else:
            self.tp_size = get_distributed_world_size(tp_group)
            self.set_tensor_parallel_group(tp_group)

        self.parallel_mode = parallel_mode
        assert self.parallel_mode in [None, "column", "row"], (
            f"parallel_mode {parallel_mode} not supported"
        )

        # Each rank holds only its slice of the partitioned dimension.
        if self.parallel_mode == "column":
            self.out_features = divide(self.out_features, self.tp_size)
        elif self.parallel_mode == "row":
            self.in_features = divide(self.in_features, self.tp_size)

        if init_method is None:
            init_method = get_default_init_method()

        self.sequence_parallel = (self.tp_size > 1) and sequence_parallel

        # Overlap is only enabled for the matching parallel mode and with
        # sequence parallelism active.
        self.overlap_ag_fprop = (
            self.parallel_mode == "column" and self.sequence_parallel and ub_overlap_ag
        )
        self.overlap_rs_dgrad = (
            self.parallel_mode == "column" and self.sequence_parallel and ub_overlap_rs_dgrad
        )
        self.overlap_rs_fprop = (
            self.parallel_mode == "row" and self.sequence_parallel and ub_overlap_rs
        )
        self.overlap_ag_dgrad = (
            self.parallel_mode == "row" and self.sequence_parallel and ub_overlap_ag
        )

        with_fp8_params = FP8GlobalStateManager.with_fp8_parameters()

        # Splitting the weight into several named parameters is not supported here.
        assert parameters_split is None

        self.eps = eps
        # Normalization scale: initialized to 1, or to 0 with zero-centered gamma.
        layer_norm_weight = torch.nn.Parameter(
            torch.empty(self.in_features, device=device, dtype=params_dtype)
        )
        self.register_parameter(
            "layer_norm_weight",
            layer_norm_weight,
            init_fn=init_method_constant(float(not self.zero_centered_gamma)),
        )
        # RMSNorm has no shift parameter.
        if self.normalization != "RMSNorm":
            layer_norm_bias = torch.nn.Parameter(
                torch.empty(self.in_features, device=device, dtype=params_dtype)
            )
            self.register_parameter(
                "layer_norm_bias", layer_norm_bias, init_fn=init_method_constant(0.0)
            )
        else:
            self.layer_norm_bias = None

        # Backing buffers from which the weight/bias parameters are sliced.
        weight_tensor = torch.empty(
            self.out_features,
            self.in_features,
            device=device,
            dtype=params_dtype,
        )
        bias_tensor = None
        if self.use_bias:
            bias_tensor = torch.empty(
                self.out_features,
                device=device,
                dtype=params_dtype,
            )

        # A single weight/bias pair; sizes start from the un-partitioned
        # `out_features` argument and are divided per rank below.
        self.weight_names = ["weight"]
        self.bias_names = ["bias"]
        self.parameter_split_sizes = [out_features]

        if sum(self.parameter_split_sizes) != out_features:
            raise ValueError(
                f"Trying to split weight buffer ({out_features=}) "
                f"with split sizes {self.parameter_split_sizes}"
            )

        if self.parallel_mode == "column":
            for i, size in enumerate(self.parameter_split_sizes):
                if size % self.tp_size != 0:
                    raise RuntimeError(
                        f"Attempting to distribute a parameter with out_features={size} "
                        f"between {self.tp_size} tensor-parallel processes"
                    )
                self.parameter_split_sizes[i] = size // self.tp_size

        # Register each weight as a slice of the backing buffer.
        offset = 0
        for i, split_size in enumerate(self.parameter_split_sizes):
            split_start = offset
            offset += split_size
            split_end = offset

            # A parameter covering less than the whole buffer is a sub-view.
            is_subview = (split_start, split_end) != (0, self.out_features)
            if is_subview and with_fp8_params:
                raise RuntimeError(
                    "Splitting QuantizedTensor into multiple params is not supported"
                )

            self.register_parameter(
                self.weight_names[i],
                torch.nn.Parameter(weight_tensor[split_start:split_end]),
                init_fn=init_method,
                get_rng_state_tracker=get_rng_state_tracker,
                fp8_meta_index=FP8FwdTensorIdx.GEMM1_WEIGHT,
            )

        if self.use_bias:
            offset = 0
            for i, split_size in enumerate(self.parameter_split_sizes):
                split_start = offset
                offset += split_size
                split_end = offset
                self.register_parameter(
                    self.bias_names[i],
                    torch.nn.Parameter(bias_tensor[split_start:split_end]),
                    init_fn=init_method_constant(0.0),
                )
        else:
            # No bias: keep the attribute, but as an empty non-parameter tensor.
            for _name in self.bias_names:
                b = torch.Tensor().to(dtype=params_dtype, device=device)
                setattr(self, _name, b)

        if with_fp8_params:
            self.init_fp8_metadata()

        # Initialization is deferred when the parameters live on the meta device.
        self.reset_parameters(defer_init=device == "meta")

        # For row parallelism the bias is added in `forward` after the GEMM
        # rather than passed into it.
        if self.parallel_mode == "row" and self.apply_bias:
            self.gemm_bias_unfused_add = True
        else:
            self.gemm_bias_unfused_add = False

        # Read from the environment and stored; not used elsewhere in this file.
        self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
        self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))

    def reset_parameters(self, defer_init=False):
        """Re-initialize parameters and tag them with parallelism metadata.

        The base class performs the actual (possibly deferred) initialization.
        Unless ``defer_init`` is set, the normalization parameters are marked
        with ``sequence_parallel`` and the weight (and, depending on the
        parallel mode, the bias) receives its partitioning attributes.
        """
        super().reset_parameters(defer_init=defer_init)

        if defer_init:
            return

        seq_parallel = self.sequence_parallel
        self.layer_norm_weight.sequence_parallel = seq_parallel
        if self.normalization != "RMSNorm":
            self.layer_norm_bias.sequence_parallel = seq_parallel

        # Row-parallel weights are split along dim 1, everything else along dim 0.
        for weight_name in self.weight_names:
            partition_dim = 1 if self.parallel_mode == "row" else 0
            set_tensor_model_parallel_attributes(
                tensor=getattr(self, weight_name),
                is_parallel=True,
                dim=partition_dim,
                stride=1,
            )

        if not self.use_bias:
            return

        for bias_name in self.bias_names:
            bias_param = getattr(self, bias_name)
            if self.parallel_mode == "row":
                bias_param.sequence_parallel = self.sequence_parallel
            elif self.parallel_mode == "column":
                set_tensor_model_parallel_attributes(bias_param, True, 0, 1)

    @no_torch_dynamo()
    def forward(
        self,
        inp: torch.Tensor,
        is_first_microbatch: Optional[bool] = None,
        fp8_output: Optional[bool] = False,
        fp8_grad: Optional[bool] = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """Run normalization followed by the linear projection.

        Args:
            inp: Input tensor.
            is_first_microbatch: Forwarded unchanged to the autograd function.
            fp8_output: Whether an output quantizer is requested.
            fp8_grad: Whether an input-gradient quantizer is requested.

        Returns:
            ``out``, optionally followed by the bias (when ``return_bias`` is
            set) and by the normalization output (when
            ``return_layernorm_output`` is set), in that order.
        """
        grad_mode = torch.is_grad_enabled()

        inp = self.prepare_forward(inp, allow_non_contiguous=False)

        weight_tensor = getattr(self, self.weight_names[0])
        bias_tensor = None
        if self.use_bias:
            bias_tensor = getattr(self, self.bias_names[0])

        (
            q_input,
            q_weight,
            q_output,
            q_grad_input,
            q_grad_weight,
            q_grad_output,
        ) = self._get_quantizers(fp8_output, fp8_grad, grad_mode)

        non_tensor_args = _LayerNormLinearNonTensorArgs(
            is_first_microbatch=is_first_microbatch,
            fp8=self.fp8,
            eps=self.eps,
            input_quantizer=q_input,
            weight_quantizer=q_weight,
            output_quantizer=q_output,
            grad_input_quantizer=q_grad_input,
            grad_weight_quantizer=q_grad_weight,
            grad_output_quantizer=q_grad_output,
            fused_wgrad_accumulation=self.fuse_wgrad_accumulation,
            cpu_offloading=False,
            tp_group=self.tp_group,
            tp_size=self.tp_size,
            sequence_parallel=self.sequence_parallel,
            activation_dtype=self.activation_dtype,
            tensor_parallel=self.tp_size > 1,
            parallel_mode=self.parallel_mode,
            is_grad_enabled=grad_mode,
            fp8_output=fp8_output,
            module=self,
            skip_fp8_weight_update=None,
            save_origin_input=self.save_original_input,
            overlap_ag_fprop=self.overlap_ag_fprop,
            overlap_rs_dgrad=self.overlap_rs_dgrad,
            overlap_rs_fprop=self.overlap_rs_fprop,
            overlap_ag_dgrad=self.overlap_ag_dgrad,
            normalization=self.normalization,
            zero_centered_gamma=self.zero_centered_gamma,
            return_layernorm_output=self.return_layernorm_output,
            return_layernorm_output_gathered=self.return_layernorm_output_gathered,
            fsdp_group=self.fsdp_group,
            is_fsdp2=self.is_fsdp2,
        )

        try:
            # The bias goes into the GEMM only when this layer applies it and
            # it is not added separately afterwards.
            gemm_bias = None
            if self.apply_bias and not self.gemm_bias_unfused_add:
                gemm_bias = bias_tensor
            call_args = (
                inp,
                self.layer_norm_weight,
                self.layer_norm_bias,
                weight_tensor,
                gemm_bias,
                non_tensor_args,
            )
            if grad_mode:
                out = _LayerNormLinear.apply(*call_args)
            else:
                # Without autograd, call forward directly with no context.
                out = _LayerNormLinear.forward(None, *call_args)
        finally:
            self.end_forward()

        ln_out = None
        if self.return_layernorm_output:
            out, ln_out = out

        if self.gemm_bias_unfused_add:
            out = out + cast_if_needed(bias_tensor, self.activation_dtype)

        if not self.return_bias:
            if self.return_layernorm_output:
                return out, ln_out
            return out

        returned_bias = cast_if_needed(bias_tensor, self.activation_dtype)
        if self.return_layernorm_output:
            return out, returned_bias, ln_out
        return out, returned_bias

    def _get_quantizers(self, fp8_output, fp8_grad, is_grad_enabled):
        if not self.fp8:
            return [None] * 6
        grad_input_quantizer = None
        grad_weight_quantizer = None
        grad_output_quantizer = None
        output_quantizer = None
        input_quantizer = self.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_INPUT]
        input_quantizer.internal = True
        if not (self.parallel_mode == "column" and self.sequence_parallel):
            input_quantizer.optimize_for_gemm = True
        (weight_quantizer,) = self._get_weight_quantizers()
        if fp8_output:
            output_quantizer = self.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_OUTPUT]
        if is_grad_enabled:
            grad_output_quantizer = self.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_OUTPUT1]
            grad_output_quantizer.internal = True
            if not (self.parallel_mode == "row" and self.sequence_parallel):
                grad_output_quantizer.optimize_for_gemm = True
            if fp8_grad:
                grad_input_quantizer = self.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_INPUT1]
        return (
            input_quantizer,
            weight_quantizer,
            output_quantizer,
            grad_input_quantizer,
            grad_weight_quantizer,
            grad_output_quantizer,
        )

    def _get_weight_quantizers(self):
        if not self.fp8:
            return [None]
        weight_quantizer = self.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_WEIGHT]
        weight_quantizer.internal = True
        return [weight_quantizer]


# Public API of this module: only the user-facing layer is exported.
__all__ = ["LayerNormLinear"]