# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2026, Huawei Technologies Co., Ltd. All rights reserved.
#
# See LICENSE for license information.

import os
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch_npu

from ..constants import FP8BwdTensorIdx, FP8FwdTensorIdx, ParallelMode, TensorUsage, dist_group_type
from ..ops.basic.npu_activation import (
    ACTIVATION_FWD,
    ACTIVATION_BWD,
    GLU_VARIANTS,
)
from ..distributed import (
    DummyHandle,
    _fsdp_gather_tensors,
    _fsdp_scatter_tensors,
    gather_along_dim,
    get_distributed_world_size,
    in_fp8_activation_recompute_phase,
    is_fp8_activation_recompute_enabled,
    reduce_scatter_along_dim,
)
from ..jit import no_torch_dynamo
from ..ops.fused.overlap import CommOverlapOps
from ..ops.gemm import matmul_add
from ..quantization import FP8GlobalStateManager
from ..quantized_tensor import QuantizedTensorStorage, Quantizer, prepare_for_saving, restore_from_saved
from ..tensor.utils import clear_columnwise_cache
from ..utils import cast_if_needed, divide, get_default_init_method, init_method_constant, requires_grad
from ._common import WeightGradStore
from .base import TransformerEngineBaseModule


# Tensor-model-parallel attribute names (with their nominal default values).
# Only the keys are consulted by set_tensor_model_parallel_attributes.
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = dict(
    tensor_model_parallel=False,
    partition_dim=-1,
    partition_stride=1,
)


def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
    """Attach tensor-model-parallel partitioning metadata to ``tensor``.

    Sets ``tensor_model_parallel``, ``partition_dim`` and ``partition_stride``
    from ``is_parallel``, ``dim`` and ``stride`` respectively.

    Raises:
        AssertionError: (only when assertions are enabled) if ``tensor``
            already carries any of these attributes; nothing is set then.
    """
    assert not any(hasattr(tensor, name) for name in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS)
    tensor.tensor_model_parallel = is_parallel
    tensor.partition_dim = dim
    tensor.partition_stride = stride


class _LayerNormMLPNonTensorArgs(NamedTuple):
    """Non-tensor arguments of ``_LayerNormMLP``, passed as its last input.

    ``_LayerNormMLP.backward`` returns ``None`` for this argument. As with any
    NamedTuple, the field order defines the positional constructor order.
    """

    # None: FP8 weight workspaces are not cached. None or True: the workspace is
    # (re)quantized. True: fused FP8 wgrad accumulation into main_grad is skipped.
    is_first_microbatch: Optional[bool]
    # Selects _LayerNormMLP.fp8_forward / fp8_backward.
    fp8: bool
    # When set and its delay_wgrad_compute() is true, weight-gradient GEMMs are
    # queued on it (put()) instead of being computed during backward.
    wgrad_store: Optional[WeightGradStore]
    # Epsilon forwarded to the normalization.
    eps: float
    # Quantizers used by the FP8 path. NOTE(review): fc1_output, fc1_grad_input,
    # fc1_grad_weight, fc2_grad_input and fc2_grad_weight quantizers are not
    # read by _LayerNormMLP in this file.
    fc1_input_quantizer: Optional[Quantizer]
    fc1_weight_quantizer: Optional[Quantizer]
    fc1_output_quantizer: Optional[Quantizer]
    fc1_grad_input_quantizer: Optional[Quantizer]
    fc1_grad_weight_quantizer: Optional[Quantizer]
    fc1_grad_output_quantizer: Optional[Quantizer]
    fc2_input_quantizer: Optional[Quantizer]
    fc2_weight_quantizer: Optional[Quantizer]
    # Passed to the sequence-parallel reduce-scatter of the FP8 FC2 output.
    fc2_output_quantizer: Optional[Quantizer]
    fc2_grad_input_quantizer: Optional[Quantizer]
    fc2_grad_weight_quantizer: Optional[Quantizer]
    fc2_grad_output_quantizer: Optional[Quantizer]
    # Accumulate weight gradients directly into ``weight.main_grad`` when possible.
    fuse_wgrad_accumulation: bool
    tp_group: Optional[dist_group_type]
    # NOTE(review): not read by _LayerNormMLP; the world size is queried from tp_group.
    tp_size: int
    sequence_parallel: bool
    # Compute dtype for the input, the norm parameters and GEMM outputs.
    activation_dtype: torch.dtype
    tensor_parallel: bool
    # Together with tensor_parallel, enables the TP collectives around the GEMMs.
    set_parallel_mode: bool
    # When False, nothing is saved on ctx for backward.
    is_grad_enabled: bool
    # Also return the normalization output.
    return_layernorm_output: bool
    # With tensor_parallel and sequence_parallel, return that output all-gathered over tp_group.
    return_layernorm_output_gathered: bool
    # Use ``1 + ln_weight`` as the norm scale.
    zero_centered_gamma: bool
    # "LayerNorm" or "RMSNorm".
    normalization: str
    # Key into ACTIVATION_FWD / ACTIVATION_BWD.
    activation: str
    # Extra positional arguments passed to the "clamped_swiglu" activation.
    activation_params: tuple
    # FP8-path GEMM/communication overlap switches: all-gather fused into FC1
    # fprop, reduce-scatter fused into FC2 fprop, all-gather fused into FC2
    # dgrad, reduce-scatter fused into dgrad.
    overlap_ag_fprop: bool
    overlap_rs_fprop: bool
    overlap_ag_dgrad: bool
    overlap_rs_dgrad: bool
    # Owning module; provides get_weight_workspace() and fp8_meta.
    module: "LayerNormMLP"
    # Forwarded as skip_update_flag to module.get_weight_workspace().
    skip_fp8_weight_update: Optional[bool]
    # With an MXFP8 recipe, enables quantization in the sequence-parallel FC2 reduce-scatter.
    fp8_output: bool
    # Passed to _fsdp_scatter_tensors / _fsdp_gather_tensors for the saved activations.
    fsdp_group: Optional[Any]
    # FSDP2 mode: FP8 weight workspaces are neither cached nor saved; backward re-quantizes.
    is_fsdp2: bool


def _apply_norm(inputmat, ln_weight, ln_bias, eps, normalization, zero_centered_gamma, activation_dtype=None):
    """Normalize the last dimension of ``inputmat`` with LayerNorm or RMSNorm.

    Args:
        inputmat: input tensor; statistics are taken over its last dimension.
        ln_weight: norm scale, interpreted as ``gamma - 1`` when
            ``zero_centered_gamma`` is set.
        ln_bias: LayerNorm shift (not used for RMSNorm).
        eps: epsilon added to the variance / mean square.
        normalization: ``"LayerNorm"`` or ``"RMSNorm"``.
        zero_centered_gamma: use ``1 + ln_weight`` as the scale.
        activation_dtype: if given, the scale (and the LayerNorm bias) are
            cast to it with ``cast_if_needed``.

    Returns:
        ``(ln_out, mu, rsigma)``. For LayerNorm, ``mu`` is the per-row mean and
        ``rsigma`` is ``1 / sqrt(biased variance + eps)``, both with the last
        dimension kept. For RMSNorm, ``mu`` is None and ``rsigma`` is the
        second output of ``torch_npu.npu_rms_norm``.

    Raises:
        ValueError: for any other ``normalization``.
    """
    if normalization not in ("LayerNorm", "RMSNorm"):
        raise ValueError(f"Unknown normalization: {normalization}")

    # Both normalizations share the same scale handling.
    gamma = 1.0 + ln_weight if zero_centered_gamma else ln_weight
    if activation_dtype is not None:
        gamma = cast_if_needed(gamma, activation_dtype)

    if normalization == "RMSNorm":
        ln_out, rrsigma = torch_npu.npu_rms_norm(inputmat, gamma, epsilon=eps)
        return ln_out, None, rrsigma

    if activation_dtype is not None and ln_bias is not None:
        ln_bias = cast_if_needed(ln_bias, activation_dtype)
    ln_out = F.layer_norm(inputmat, [inputmat.shape[-1]], weight=gamma, bias=ln_bias, eps=eps)

    # Statistics kept for the hand-written backward (_norm_bwd).
    mu = inputmat.mean(dim=-1, keepdim=True)
    centered = inputmat - mu
    rsigma = 1.0 / torch.sqrt(centered.pow(2).mean(dim=-1, keepdim=True) + eps)
    return ln_out, mu, rsigma


def _norm_bwd(dgrad, inputmat, ln_weight, mu, rsigma, zero_centered_gamma, normalization):
    """Backward of the normalization computed by ``_apply_norm``.

    Args:
        dgrad: gradient w.r.t. the normalized output, same shape as ``inputmat``.
        inputmat: the normalization input.
        ln_weight: norm scale (``gamma - 1`` when ``zero_centered_gamma``).
        mu: per-row mean from ``_apply_norm`` (unused for RMSNorm).
        rsigma: per-row reciprocal std (LayerNorm) / reciprocal RMS (RMSNorm).
        zero_centered_gamma: scale is ``1 + ln_weight``.
        normalization: ``"LayerNorm"`` or ``"RMSNorm"``.

    Returns:
        ``(dx, dgamma, dbeta)``; parameter gradients are summed over every
        leading dimension and ``dbeta`` is None for RMSNorm.

    Raises:
        ValueError: for any other ``normalization``.
    """
    n = inputmat.shape[-1]
    if normalization not in ("LayerNorm", "RMSNorm"):
        raise ValueError(f"Unknown normalization: {normalization}")

    row_dims = tuple(range(len(dgrad.shape) - 1))
    # Gradient w.r.t. the normalized (pre-scale) activations.
    dx_hat = dgrad * (1.0 + ln_weight) if zero_centered_gamma else dgrad * ln_weight

    if normalization == "RMSNorm":
        x_hat = inputmat * rsigma
        dvar = (dx_hat * inputmat).sum(dim=-1, keepdim=True) * (-0.5) * rsigma.pow(3)
        dx = dx_hat * rsigma + dvar * 2.0 * inputmat / n
        return dx, (dgrad * x_hat).sum(dim=row_dims), None

    centered = inputmat - mu
    x_hat = centered * rsigma
    dvar = (dx_hat * centered * (-0.5) * rsigma.pow(3)).sum(dim=-1, keepdim=True)
    dmean = (-dx_hat * rsigma).sum(dim=-1, keepdim=True) + dvar * (-2.0 / n) * centered.sum(dim=-1, keepdim=True)
    dx = dx_hat * rsigma + dvar * 2.0 / n * centered + dmean / n
    dgamma = (dgrad * x_hat).sum(dim=row_dims)
    dbeta = dgrad.sum(dim=row_dims)
    return dx, dgamma, dbeta


class _LayerNormMLP(torch.autograd.Function):
    """Fused norm -> FC1 -> activation -> FC2 autograd function.

    ``forward`` normalizes the flattened input with ``_apply_norm``, runs both
    GEMMs in high precision, and delegates to ``fp8_forward`` when
    ``args.fp8`` is set; ``backward`` dispatches to ``fp8_backward`` on the
    same flag. Tensor-parallel communication (sequence-parallel all-gather /
    reduce-scatter, or all-reduce) is issued here, except where the FP8 path
    fuses it into an overlapped GEMM (``args.overlap_*``). The high-precision
    path has no overlapped GEMMs, so it always issues its collectives itself.
    """

    @staticmethod
    def forward(
        ctx,
        ln_weight,
        ln_bias,
        fc1_weight,
        fc1_bias,
        fc2_weight,
        fc2_bias,
        inp,
        args: _LayerNormMLPNonTensorArgs,
    ):
        """Run the MLP forward pass.

        Returns the FC2 output shaped ``(*inp.shape[:-1], out_features)`` and,
        when ``args.return_layernorm_output`` is set, also the norm output
        (all-gathered over ``args.tp_group`` under sequence parallelism when
        ``args.return_layernorm_output_gathered`` is set).
        """
        inputmat = inp.reshape(-1, inp.shape[-1])
        inputmat = cast_if_needed(inputmat, args.activation_dtype)

        ln_out, mu, rsigma = _apply_norm(
            inputmat,
            ln_weight,
            ln_bias,
            args.eps,
            args.normalization,
            args.zero_centered_gamma,
            args.activation_dtype,
        )

        ln_out_return = None
        if args.return_layernorm_output:
            ln_out_return = ln_out.reshape(inp.shape)
            if _LayerNormMLP._ln_out_is_gathered(args):
                ln_out_return, _ = gather_along_dim(ln_out_return, args.tp_group)

        # Only the FP8 path has an all-gather-fused FC1 GEMM; everywhere else
        # the sequence-parallel gather has to happen here.
        ag_fused_into_fc1 = args.fp8 and args.overlap_ag_fprop
        mm_inp = ln_out
        if (
            args.tensor_parallel
            and args.sequence_parallel
            and args.set_parallel_mode
            and not ag_fused_into_fc1
        ):
            mm_inp, _ = gather_along_dim(mm_inp, args.tp_group)

        if args.fp8:
            return _LayerNormMLP.fp8_forward(
                ctx,
                ln_weight,
                ln_bias,
                fc1_weight,
                fc1_bias,
                fc2_weight,
                fc2_bias,
                inp,
                inputmat,
                ln_out,
                mm_inp,
                mu,
                rsigma,
                args,
                ln_out_return=ln_out_return,
            )

        fc1_out = torch.matmul(mm_inp, fc1_weight.t())
        if fc1_bias is not None:
            fc1_out = fc1_out + fc1_bias

        act_out = _LayerNormMLP._activation_forward(args, fc1_out)

        fc2_out = torch.matmul(act_out, fc2_weight.t())
        if fc2_bias is not None:
            fc2_out = fc2_out + fc2_bias

        # No reduce-scatter-fused GEMM on this path, so always reduce here.
        tp_world_size = get_distributed_world_size(args.tp_group)
        if args.tensor_parallel and args.set_parallel_mode and tp_world_size > 1:
            if args.sequence_parallel:
                fc2_out, _ = reduce_scatter_along_dim(fc2_out, args.tp_group)
            else:
                torch.distributed.all_reduce(fc2_out, group=args.tp_group)

        if args.is_grad_enabled:
            ctx.args = args
            ctx.inp_shape = inp.shape
            ctx.fp8 = args.fp8
            ctx.sequence_parallel = args.sequence_parallel
            ctx.requires_dgrad = inp.requires_grad
            ctx.requires_wgrad = fc1_weight.requires_grad or fc2_weight.requires_grad

            ctx.fsdp_group = args.fsdp_group
            ctx.is_fsdp2 = args.is_fsdp2
            ctx.fsdp_shapes = _fsdp_scatter_tensors(
                args.fsdp_group,
                mu,
                rsigma,
                mm_inp,
                fc1_out,
                act_out,
            )

            tensors_to_save, tensor_objects = prepare_for_saving(
                inputmat,
                fc1_weight,
                fc1_bias,
                fc2_weight,
                fc2_bias,
                ln_weight,
                ln_bias,
                mm_inp,
                fc1_out,
                act_out,
                mu,
                rsigma,
            )
            ctx.save_for_backward(*tensors_to_save)
            ctx.tensor_objects = tensor_objects

        # Generalizes to any input rank (the old shape[0]/shape[1] form only fit 3-D inputs).
        out = fc2_out.reshape(*inp.shape[:-1], fc2_out.shape[-1])
        if args.return_layernorm_output:
            return out, ln_out_return
        return out

    @staticmethod
    def backward(ctx, grad_output, *rest):
        """Compute gradients for the eight ``forward`` inputs.

        ``rest[0]`` (if present) is the gradient of the returned norm output.
        Returns ``(dgamma, dbeta, fc1_wgrad, fc1_bias_grad, fc2_wgrad,
        fc2_bias_grad, dx, None)``; a weight gradient is None when its GEMM is
        deferred to ``args.wgrad_store``.
        """
        if ctx.args.fp8:
            return _LayerNormMLP.fp8_backward(ctx, grad_output, rest)

        args = ctx.args
        (
            inputmat,
            fc1_weight,
            fc1_bias,
            fc2_weight,
            fc2_bias,
            ln_weight,
            ln_bias,
            mm_inp,
            fc1_out,
            act_out,
            mu,
            rsigma,
        ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
        ctx.tensor_objects = None

        fsdp_group = getattr(ctx, "fsdp_group", None)
        if fsdp_group is not None:
            _fsdp_gather_tensors(
                fsdp_group,
                ctx.fsdp_shapes,
                mu,
                rsigma,
                mm_inp,
                fc1_out,
                act_out,
            )

        grad_output = grad_output.reshape(-1, grad_output.shape[-1])
        tp_world_size = get_distributed_world_size(args.tp_group)
        grad_ln_out = _LayerNormMLP._grad_ln_out(args, rest, tp_world_size)

        # Mirror of the forward reduce-scatter. No all-gather-fused GEMM exists
        # on this path, so the gather always happens here.
        if args.tensor_parallel and args.set_parallel_mode and args.sequence_parallel:
            grad_output, _ = gather_along_dim(grad_output, args.tp_group)

        d_fc2 = grad_output
        fc2_dgrad = torch.matmul(d_fc2, fc2_weight)

        fc2_wgrad = _LayerNormMLP._wgrad_now_or_deferred(
            args,
            lambda x, dy: _LayerNormMLP._hp_weight_grad(args, fc2_weight, x, dy),
            act_out,
            d_fc2,
            d_fc2 if fc2_bias is not None else None,
        )
        fc2_bias_grad = d_fc2.sum(dim=0) if fc2_bias is not None else None

        dact = _LayerNormMLP._activation_backward(args, fc1_out, fc2_dgrad)
        fc1_dgrad = torch.matmul(dact, fc1_weight)
        fc1_dgrad, handle = _LayerNormMLP._launch_fc1_dgrad_comm(
            args, fc1_dgrad, tp_world_size, rs_fused=False
        )

        # FC1 weight gradient overlaps with the in-flight dgrad collective.
        fc1_wgrad = _LayerNormMLP._wgrad_now_or_deferred(
            args,
            lambda x, dy: _LayerNormMLP._hp_weight_grad(args, fc1_weight, x, dy),
            mm_inp,
            dact,
            dact if fc1_bias is not None else None,
        )
        fc1_bias_grad = dact.sum(dim=0) if fc1_bias is not None else None

        dx, dgamma, dbeta = _LayerNormMLP._norm_input_grad(
            args, fc1_dgrad, handle, grad_ln_out, inputmat, ln_weight, mu, rsigma
        )

        return dgamma, dbeta, fc1_wgrad, fc1_bias_grad, fc2_wgrad, fc2_bias_grad, dx.reshape(ctx.inp_shape), None

    @staticmethod
    def fp8_forward(
        ctx,
        ln_weight,
        ln_bias,
        fc1_weight,
        fc1_bias,
        fc2_weight,
        fc2_bias,
        inp,
        inputmat,
        ln_out,
        mm_inp,
        mu,
        rsigma,
        args,
        ln_out_return=None,
    ):
        """FP8 forward: quantize GEMM operands and run FC1 -> activation -> FC2.

        Called from ``forward`` with the normalized ``ln_out`` and the
        (possibly sequence-gathered) ``mm_inp``. ``ln_out_return`` is the norm
        output handed back when ``return_layernorm_output`` is set; when None
        it falls back to ``ln_out`` reshaped to ``inp.shape``.
        """
        # Columnwise copies are consumed only by backward GEMMs: inputs by the
        # wgrads, weights by the dgrads. Both dgrads always run (FC1's also
        # feeds the norm gradients), so weights need them whenever grad is on.
        # (ln_out.requires_grad, used previously, is always False in here.)
        needs_fc1_input = args.is_grad_enabled and fc1_weight.requires_grad
        needs_fc1_weight = args.is_grad_enabled
        needs_fc2_input = args.is_grad_enabled and fc2_weight.requires_grad
        needs_fc2_weight = args.is_grad_enabled
        if is_fp8_activation_recompute_enabled() and not in_fp8_activation_recompute_phase():
            # Activation recompute is on and this is not the recompute pass: keep
            # no columnwise copies (presumably the recompute pass produces them).
            needs_fc1_input = needs_fc1_weight = needs_fc2_input = needs_fc2_weight = False

        # Quantizers persist across calls, so set usage explicitly each time
        # instead of only ever switching it off. FSDP2 re-quantizes weights in
        # backward, so their columnwise data is never kept here.
        args.fc1_input_quantizer.set_usage(columnwise=needs_fc1_input)
        args.fc1_weight_quantizer.set_usage(columnwise=needs_fc1_weight and not args.is_fsdp2)
        args.fc2_input_quantizer.set_usage(columnwise=needs_fc2_input)
        args.fc2_weight_quantizer.set_usage(columnwise=needs_fc2_weight and not args.is_fsdp2)

        tp_world_size = get_distributed_world_size(args.tp_group)

        fc1_input = args.fc1_input_quantizer.quantize(mm_inp)

        update_workspace = args.is_first_microbatch is None or args.is_first_microbatch
        fc1_weightmat = args.module.get_weight_workspace(
            tensor=fc1_weight,
            quantizer=args.fc1_weight_quantizer,
            cache_name=(None if (args.is_first_microbatch is None or args.is_fsdp2) else "fc1_weight"),
            update_workspace=update_workspace,
            skip_update_flag=args.skip_fp8_weight_update,
            workspace_dtype=args.activation_dtype,
        )
        fc1_weightmat.update_usage(rowwise_usage=True)

        fc1_mm_kwargs = {
            "usage": TensorUsage.LHS,
            "usage_b": TensorUsage.RHS_TRANS,
            "out_dtype": args.activation_dtype,
        }

        if args.overlap_ag_fprop:
            fc1_out, fc1_input = fc1_input.allgather_matmul(
                fc1_weightmat, fc1_bias, tp_world_size, args.tp_group, **fc1_mm_kwargs
            )
        else:
            fc1_out = fc1_input.matmul(fc1_weightmat, **fc1_mm_kwargs)
            if fc1_bias is not None:
                fc1_out = fc1_out + fc1_bias

        # Rowwise data was only needed for this GEMM.
        fc1_input.clear_wise(rowwise=True)

        act_out = _LayerNormMLP._activation_forward(args, fc1_out)

        fc2_input = args.fc2_input_quantizer.quantize(act_out)

        fc2_weightmat = args.module.get_weight_workspace(
            tensor=fc2_weight,
            quantizer=args.fc2_weight_quantizer,
            cache_name=(None if (args.is_first_microbatch is None or args.is_fsdp2) else "fc2_weight"),
            update_workspace=update_workspace,
            skip_update_flag=args.skip_fp8_weight_update,
            workspace_dtype=args.activation_dtype,
        )
        fc2_weightmat.update_usage(rowwise_usage=True)

        fc2_mm_kwargs = {
            "usage": TensorUsage.LHS,
            "usage_b": TensorUsage.RHS_TRANS,
            "out_dtype": args.activation_dtype,
        }

        if args.overlap_rs_fprop:
            fc2_out = fc2_input.matmul_reduce_scatter(
                fc2_weightmat, fc2_bias, tp_world_size, args.tp_group, **fc2_mm_kwargs
            )
        else:
            fc2_out = fc2_input.matmul(fc2_weightmat, **fc2_mm_kwargs)
            if fc2_bias is not None:
                fc2_out = fc2_out + fc2_bias

        fc2_input.clear_wise(rowwise=True)

        if args.tensor_parallel and args.set_parallel_mode and not args.overlap_rs_fprop and tp_world_size > 1:
            if args.sequence_parallel:
                fc2_out, _ = CommOverlapOps.reduce_scatter(
                    fc2_out,
                    args.fc2_output_quantizer,
                    tp_world_size,
                    args.tp_group,
                    use_quant=args.fp8_output and args.module.fp8_meta["recipe"].mxfp8()
                )
            else:
                torch.distributed.all_reduce(fc2_out, group=args.tp_group)

        if args.is_grad_enabled:
            ctx.args = args
            ctx.inp_shape = inp.shape
            ctx.fp8 = args.fp8
            ctx.sequence_parallel = args.sequence_parallel
            ctx.requires_dgrad = inp.requires_grad
            ctx.requires_wgrad = fc1_weight.requires_grad or fc2_weight.requires_grad
            ctx.is_weight_param_quantized = isinstance(fc1_weight, QuantizedTensorStorage)

            ctx.fsdp_group = args.fsdp_group
            ctx.is_fsdp2 = args.is_fsdp2
            ctx.fsdp_shapes = _fsdp_scatter_tensors(
                args.fsdp_group,
                mu,
                rsigma,
                fc1_weightmat if not ctx.is_weight_param_quantized else None,
                fc2_weightmat if not ctx.is_weight_param_quantized else None,
                fc1_input,
                fc1_out,
                fc2_input,
                act_out,
            )

            # FSDP2: Don't save FP8 workspace so backward re-quantizes from all-gathered weight
            fc1_wt_save = fc1_weightmat
            if args.is_fsdp2 and fc1_weightmat is not fc1_weight:
                fc1_wt_save = None
            fc2_wt_save = fc2_weightmat
            if args.is_fsdp2 and fc2_weightmat is not fc2_weight:
                fc2_wt_save = None

            tensors_to_save, tensor_objects = prepare_for_saving(
                inputmat,
                fc1_wt_save,
                fc1_weight,
                fc1_bias,
                fc2_wt_save,
                fc2_weight,
                fc2_bias,
                ln_weight,
                ln_bias,
                fc1_input,
                fc1_out,
                fc2_input,
                act_out,
                mu,
                rsigma,
            )
            ctx.save_for_backward(*tensors_to_save)
            ctx.tensor_objects = tensor_objects

            if requires_grad(inp, fc1_weight, fc2_weight):
                qstate = FP8GlobalStateManager.quantization_state
                _first_fp8_module = qstate.is_first_fp8_module
                ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
                if in_fp8_activation_recompute_phase():
                    qstate.is_first_fp8_module = _first_fp8_module

        out = fc2_out.reshape(*inp.shape[:-1], fc2_out.shape[-1])
        if args.return_layernorm_output:
            if ln_out_return is None:
                ln_out_return = ln_out.reshape(inp.shape)
            return out, ln_out_return
        return out

    @staticmethod
    def fp8_backward(ctx, grad_output, rest):
        """FP8 counterpart of ``backward``; same return layout."""
        args = ctx.args
        (
            inputmat,
            fc1_weight_fp8,
            fc1_weight,
            fc1_bias,
            fc2_weight_fp8,
            fc2_weight,
            fc2_bias,
            ln_weight,
            ln_bias,
            fc1_input,
            fc1_out,
            fc2_input,
            act_out,
            mu,
            rsigma,
        ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
        ctx.tensor_objects = None

        is_fsdp2 = getattr(ctx, "is_fsdp2", False)
        fsdp_group = getattr(ctx, "fsdp_group", None)
        is_weight_param_quantized = getattr(ctx, "is_weight_param_quantized", False)
        if fsdp_group is not None:
            _fsdp_gather_tensors(
                fsdp_group,
                ctx.fsdp_shapes,
                mu,
                rsigma,
                fc1_weight_fp8 if not is_weight_param_quantized else None,
                fc2_weight_fp8 if not is_weight_param_quantized else None,
                fc1_input,
                fc1_out,
                fc2_input,
                act_out,
            )

        grad_output = grad_output.reshape(-1, grad_output.shape[-1])
        tp_world_size = get_distributed_world_size(args.tp_group)
        grad_ln_out = _LayerNormMLP._grad_ln_out(args, rest, tp_world_size)

        # Both dgrad GEMMs always run, so rowwise grad data is always needed;
        # columnwise data only feeds the wgrad GEMMs. Set both explicitly since
        # the quantizers persist across iterations.
        args.fc2_grad_output_quantizer.set_usage(rowwise=True, columnwise=ctx.requires_wgrad)

        if args.tensor_parallel and args.set_parallel_mode and args.sequence_parallel and not args.overlap_ag_dgrad:
            grad_output, _ = gather_along_dim(grad_output, args.tp_group)

        fc2_grad_output = args.fc2_grad_output_quantizer.quantize(grad_output)

        if fc2_weight_fp8 is None and is_fsdp2:
            fc2_weight_fp8 = _LayerNormMLP._requantize_weight(fc2_weight, args.fc2_weight_quantizer)

        fc2_dgrad_kwargs = {
            "usage": TensorUsage.LHS,
            "usage_b": TensorUsage.RHS,
            "out_dtype": args.activation_dtype,
        }

        if args.overlap_ag_dgrad:
            fc2_dgrad, fc2_grad_output = fc2_grad_output.allgather_matmul(
                fc2_weight_fp8, None, tp_world_size, args.tp_group, **fc2_dgrad_kwargs
            )
        else:
            # FC2's output is reduced across TP in forward, so its input gradient
            # is already local: no reduce-scatter here (that applies to FC1 only).
            fc2_dgrad = fc2_grad_output.matmul(fc2_weight_fp8, **fc2_dgrad_kwargs)

        # FSDP2: clear columnwise cache derived from FC2 all-gathered weight
        if is_fsdp2 and isinstance(fc2_weight_fp8, QuantizedTensorStorage):
            clear_columnwise_cache(fc2_weight_fp8)

        accumulate_wgrad = _LayerNormMLP._fp8_accumulate_wgrad(args)

        fc2_wgrad_kwargs = {
            "usage": TensorUsage.LHS_TRANS,
            "usage_b": TensorUsage.RHS,
            "out_dtype": fc2_weight.main_grad.dtype if accumulate_wgrad else args.activation_dtype,
        }
        fc2_wgrad = _LayerNormMLP._wgrad_now_or_deferred(
            args,
            lambda x, dy: _LayerNormMLP._fp8_weight_grad(fc2_weight, x, dy, accumulate_wgrad, fc2_wgrad_kwargs),
            fc2_input,
            fc2_grad_output,
            grad_output if fc2_bias is not None else None,
        )
        fc2_bias_grad = grad_output.sum(dim=0) if fc2_bias is not None else None

        dact = _LayerNormMLP._activation_backward(args, fc1_out, fc2_dgrad)

        args.fc1_grad_output_quantizer.set_usage(rowwise=True, columnwise=ctx.requires_wgrad)
        fc1_grad_output = args.fc1_grad_output_quantizer.quantize(dact)

        if fc1_weight_fp8 is None and is_fsdp2:
            fc1_weight_fp8 = _LayerNormMLP._requantize_weight(fc1_weight, args.fc1_weight_quantizer)

        fc1_dgrad_kwargs = {
            "usage": TensorUsage.LHS,
            "usage_b": TensorUsage.RHS,
            "out_dtype": args.activation_dtype,
        }

        if args.overlap_rs_dgrad:
            fc1_dgrad = fc1_grad_output.matmul_reduce_scatter(
                fc1_weight_fp8, None, tp_world_size, args.tp_group, **fc1_dgrad_kwargs
            )
        else:
            fc1_dgrad = fc1_grad_output.matmul(fc1_weight_fp8, **fc1_dgrad_kwargs)

        # FSDP2: clear columnwise cache derived from FC1 all-gathered weight
        if is_fsdp2 and isinstance(fc1_weight_fp8, QuantizedTensorStorage):
            clear_columnwise_cache(fc1_weight_fp8)

        fc1_dgrad, handle = _LayerNormMLP._launch_fc1_dgrad_comm(
            args, fc1_dgrad, tp_world_size, rs_fused=args.overlap_rs_dgrad
        )

        # FC1 weight gradient overlaps with the in-flight dgrad collective.
        fc1_wgrad_kwargs = {
            "usage": TensorUsage.LHS_TRANS,
            "usage_b": TensorUsage.RHS,
            "out_dtype": fc1_weight.main_grad.dtype if accumulate_wgrad else args.activation_dtype,
        }
        fc1_wgrad = _LayerNormMLP._wgrad_now_or_deferred(
            args,
            lambda x, dy: _LayerNormMLP._fp8_weight_grad(fc1_weight, x, dy, accumulate_wgrad, fc1_wgrad_kwargs),
            fc1_input,
            fc1_grad_output,
            dact if fc1_bias is not None else None,
        )
        fc1_bias_grad = dact.sum(dim=0) if fc1_bias is not None else None

        dx, dgamma, dbeta = _LayerNormMLP._norm_input_grad(
            args, fc1_dgrad, handle, grad_ln_out, inputmat, ln_weight, mu, rsigma
        )

        # Only set in fp8_forward when some input required grad.
        if getattr(ctx, "reduce_and_update_bwd_fp8_tensors", False):
            FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)

        return dgamma, dbeta, fc1_wgrad, fc1_bias_grad, fc2_wgrad, fc2_bias_grad, dx.reshape(ctx.inp_shape), None

    @staticmethod
    def _ln_out_is_gathered(args):
        """Return True when the returned norm output is all-gathered over the TP group."""
        return bool(
            args.return_layernorm_output
            and args.return_layernorm_output_gathered
            and args.tensor_parallel
            and args.sequence_parallel
        )

    @staticmethod
    def _activation_forward(args, fc1_out):
        """Apply ``ACTIVATION_FWD[args.activation]`` (with clamped_swiglu parameters)."""
        act_fn = ACTIVATION_FWD[args.activation]
        if args.activation == "clamped_swiglu" and args.activation_params:
            return act_fn(fc1_out, *args.activation_params)
        return act_fn(fc1_out)

    @staticmethod
    def _activation_backward(args, fc1_out, grad):
        """Apply ``ACTIVATION_BWD[args.activation]`` (with clamped_swiglu parameters)."""
        act_bwd_fn = ACTIVATION_BWD[args.activation]
        if args.activation == "clamped_swiglu" and args.activation_params:
            return act_bwd_fn(fc1_out, grad, *args.activation_params)
        return act_bwd_fn(fc1_out, grad)

    @staticmethod
    def _grad_ln_out(args, rest, tp_world_size):
        """Return the norm-output gradient as a 2-D local-row tensor, or None.

        When forward returned the all-gathered norm output, every rank received
        the full tensor, so the gradient of this rank's local rows is the sum of
        all ranks' gradients for those rows: a reduce-scatter.
        """
        if not (args.return_layernorm_output and len(rest) > 0 and rest[0] is not None):
            return None
        grad_ln_out = rest[0].reshape(-1, rest[0].shape[-1])
        if _LayerNormMLP._ln_out_is_gathered(args) and tp_world_size > 1:
            grad_ln_out, _ = reduce_scatter_along_dim(grad_ln_out, args.tp_group)
        return grad_ln_out

    @staticmethod
    def _launch_fc1_dgrad_comm(args, fc1_dgrad, tp_world_size, rs_fused):
        """Start the asynchronous TP reduction of FC1's dgrad.

        Returns ``(fc1_dgrad, handle)``; ``handle`` is None when no collective
        was launched. The returned tensor must not be read before
        ``handle.wait()``. ``rs_fused`` means the reduce-scatter already ran
        inside the dgrad GEMM.
        """
        handle = None
        if args.tensor_parallel and args.set_parallel_mode and tp_world_size > 1:
            if args.sequence_parallel and not rs_fused:
                fc1_dgrad, handle = reduce_scatter_along_dim(
                    fc1_dgrad, args.tp_group, async_op=True
                )
            elif not args.sequence_parallel:
                handle = torch.distributed.all_reduce(fc1_dgrad, group=args.tp_group, async_op=True)
        return fc1_dgrad, handle

    @staticmethod
    def _norm_input_grad(args, fc1_dgrad, handle, grad_ln_out, inputmat, ln_weight, mu, rsigma):
        """Finish the FC1 dgrad collective, add the norm-output grad, and backprop the norm."""
        # The async collective writes into fc1_dgrad; it must complete before
        # that buffer is read.
        if handle is not None:
            handle.wait()
        if grad_ln_out is not None:
            fc1_dgrad = fc1_dgrad + grad_ln_out
        return _norm_bwd(fc1_dgrad, inputmat, ln_weight, mu, rsigma, args.zero_centered_gamma, args.normalization)

    @staticmethod
    def _hp_weight_grad(args, weight, x, dy):
        """High-precision ``dy^T @ x``, accumulated into fp32 ``weight.main_grad`` when fused."""
        if args.fuse_wgrad_accumulation and weight.main_grad.dtype == torch.float32:
            matmul_add(weight.main_grad, x, dy)
            return _LayerNormMLP._create_grad_weight_placeholder(weight)
        return torch.matmul(dy.t(), x)

    @staticmethod
    def _fp8_weight_grad(weight, x, dy, accumulate, mm_kwargs):
        """FP8 weight-gradient GEMM, accumulated into ``weight.main_grad`` when ``accumulate``."""
        if accumulate:
            dy.matmul_add(weight.main_grad, x, **mm_kwargs)
            return _LayerNormMLP._create_grad_weight_placeholder(weight)
        return dy.matmul(x, **mm_kwargs)

    @staticmethod
    def _wgrad_now_or_deferred(args, wgrad_fn, x, dy, bias_grad_src):
        """Run ``wgrad_fn(x, dy)`` now, or queue it on ``args.wgrad_store``.

        When the store delays wgrad computation, the queued callable returns
        ``(weight_grad, bias_grad)``, where ``bias_grad`` is the row sum of
        ``bias_grad_src`` (None if ``bias_grad_src`` is None), and this method
        returns None.
        """
        if args.wgrad_store is not None and args.wgrad_store.delay_wgrad_compute():
            bias_grad_delayed = None if bias_grad_src is None else bias_grad_src.sum(dim=0)

            def deferred_wgrad(x_saved, dy_saved):
                return wgrad_fn(x_saved, dy_saved), bias_grad_delayed

            args.wgrad_store.put([x, dy], deferred_wgrad)
            return None
        return wgrad_fn(x, dy)

    @staticmethod
    def _fp8_accumulate_wgrad(args):
        """Whether FP8 wgrad GEMMs accumulate into ``main_grad`` (MXFP8 recipe only)."""
        return bool(
            args.fuse_wgrad_accumulation
            and not args.is_first_microbatch
            and args.module.fp8_meta["recipe"].mxfp8()
        )

    @staticmethod
    def _requantize_weight(weight, quantizer):
        """FSDP2 backward: reuse an already-quantized weight or quantize it with both usages."""
        if isinstance(weight, QuantizedTensorStorage):
            return weight
        quantizer.set_usage(rowwise=True, columnwise=True)
        return quantizer(weight)

    @staticmethod
    def _create_grad_weight_placeholder(weight):
        """Return a dummy wgrad after fused accumulation into ``weight.main_grad``.

        Returns None when ``weight`` lacks ``grad_added_to_main_grad``.
        Otherwise it marks that attribute True and returns a tensor shaped like
        ``main_grad`` in ``weight.dtype``. The tensor is zero-filled if
        ``weight.zero_out_wgrad`` is set and uninitialized otherwise.
        """
        if not hasattr(weight, "grad_added_to_main_grad"):
            return None
        alloc = torch.zeros if getattr(weight, "zero_out_wgrad", False) else torch.empty
        grad_weight = alloc(
            weight.main_grad.shape,
            dtype=weight.dtype,
            device=torch.npu.current_device(),
            requires_grad=False,
        )
        weight.grad_added_to_main_grad = True
        return grad_weight


class LayerNormMLP(TransformerEngineBaseModule):
    def __init__(
        self,
        hidden_size: int,
        ffn_hidden_size: int,
        eps: float = 1e-5,
        sequence_parallel: bool = False,
        return_layernorm_output: bool = False,
        return_layernorm_output_gathered: bool = True,
        zero_centered_gamma: bool = False,
        normalization: str = "LayerNorm",
        activation: str = "gelu",
        activation_params: Optional[tuple] = None,
        init_method: Optional[Callable] = None,
        output_layer_init_method: Optional[Callable] = None,
        fuse_wgrad_accumulation: bool = False,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        rng_tracker_name: Optional[str] = None,
        bias: bool = True,
        params_dtype: Optional[torch.dtype] = None,
        return_bias: bool = False,
        device: Union[torch.device, str] = "npu",
        ub_overlap_ag: bool = False,
        ub_overlap_rs: bool = False,
        ub_overlap_rs_dgrad: bool = False,
        ub_bulk_dgrad: bool = False,
        ub_bulk_wgrad: bool = False,
        ub_name: Optional[str] = None,
        delay_wgrad_compute: bool = False,
        symmetric_ar_type: Optional[str] = None,
        name: Optional[str] = None,
        normalization_eps: Optional[float] = None,
    ):
        """Create the norm + FC1 + activation + FC2 block and register its parameters.

        Args:
            hidden_size: size of the normalized input and of FC2's output.
            ffn_hidden_size: intermediate size before tensor-parallel splitting.
                Each rank holds ``ffn_hidden_size / tp_size`` of it. FC1's output
                is twice that for activations in ``GLU_VARIANTS``.
            eps: normalization epsilon. Ignored when ``normalization_eps`` is given.
            sequence_parallel: only takes effect when the TP size is > 1.
            return_layernorm_output, return_layernorm_output_gathered: stored and
                forwarded to the autograd function.
            zero_centered_gamma: initialize the norm weight to 0 instead of 1
                (presumably the kernel then scales by ``1 + gamma``; TODO confirm).
            normalization: ``"RMSNorm"`` gets no norm bias; any other value gets one.
            activation, activation_params: activation name and its extra arguments.
            init_method: FC1 weight initializer (default ``get_default_init_method()``).
            output_layer_init_method: FC2 weight initializer. Falls back to
                ``init_method``, then to the default.
            fuse_wgrad_accumulation: stored and forwarded to the autograd function.
            tp_group, tp_size: tensor-parallel group. ``tp_size`` is only used when
                ``tp_group`` is None; otherwise the group's world size is used.
            get_rng_state_tracker, rng_tracker_name: stored. The tracker is also
                passed along when registering the FC weights.
            bias: create FC1 and FC2 biases.
            return_bias: stored. ``self.apply_bias`` is ``bias and not return_bias``.
            params_dtype: dtype of every created parameter (None -> torch default).
            device: parameter device. A ``"meta"`` device defers initialization.
            ub_overlap_ag, ub_overlap_rs, ub_overlap_rs_dgrad: communication-overlap
                switches. Only enabled with both tensor and sequence parallelism.
            ub_bulk_dgrad, ub_name: accepted but not used by this constructor.
            ub_bulk_wgrad, delay_wgrad_compute: passed to ``WeightGradStore``. When
                wgrad is delayed, FC weights/biases get ``skip_backward_post_hook``.
            symmetric_ar_type: stored.
            name: forwarded to the base module.
            normalization_eps: overrides ``eps`` when not None.
        """
        super(LayerNormMLP, self).__init__(name)

        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
        self.use_bias = bias
        self.return_bias = return_bias
        self.apply_bias = bias and not return_bias
        self.get_rng_state_tracker = get_rng_state_tracker
        self.rng_tracker_name = rng_tracker_name
        self.return_layernorm_output = return_layernorm_output
        self.return_layernorm_output_gathered = return_layernorm_output_gathered
        self.zero_centered_gamma = zero_centered_gamma
        self.normalization = normalization
        self.activation = activation
        self.activation_params = activation_params or ()
        self.symmetric_ar_type = symmetric_ar_type
        self.eps = normalization_eps if normalization_eps is not None else eps

        self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)

        if tp_group is None:
            self.tp_size = tp_size
            if tp_size == 1:
                self.set_tensor_parallel_group(tp_group)
        else:
            self.tp_size = get_distributed_world_size(tp_group)
            self.set_tensor_parallel_group(tp_group)
        self.set_nccl_overlap_warning_if_tp()

        self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
        self.tensor_parallel = self.tp_size > 1
        self.set_parallel_mode = self.tensor_parallel

        size_per_partition = divide(self.ffn_hidden_size, self.tp_size) if self.tensor_parallel else self.ffn_hidden_size
        # GLU variants produce gate and value halves from a single FC1 GEMM.
        fc1_out_features = 2 * size_per_partition if activation in GLU_VARIANTS else size_per_partition

        # Overlap is only meaningful when activations are sharded along the sequence.
        sp_overlap_ok = self.set_parallel_mode and self.sequence_parallel
        self.overlap_ag_fprop = sp_overlap_ok and ub_overlap_ag
        self.overlap_rs_dgrad = sp_overlap_ok and ub_overlap_rs_dgrad
        self.overlap_rs_fprop = sp_overlap_ok and ub_overlap_rs
        self.overlap_ag_dgrad = sp_overlap_ok and ub_overlap_ag

        with_fp8_params = FP8GlobalStateManager.with_fp8_parameters()

        # With zero-centered gamma the stored weight is an offset from 1, so it
        # must start at 0. The registered init_fn is what reset_parameters()
        # applies (always on the deferred meta-device path, and presumably also
        # on the eager path). It must therefore agree with the initial value;
        # zeroing the tensor once here would not survive re-initialization.
        ln_weight_factory = torch.zeros if zero_centered_gamma else torch.ones
        ln_weight_tensor = ln_weight_factory(hidden_size, device=device, dtype=params_dtype)
        self.register_parameter(
            "ln_weight",
            torch.nn.Parameter(ln_weight_tensor),
            init_fn=init_method_constant(float(not zero_centered_gamma)),
        )

        if normalization != "RMSNorm":
            ln_bias_tensor = torch.zeros(hidden_size, device=device, dtype=params_dtype)
            self.register_parameter(
                "layer_norm_bias",
                torch.nn.Parameter(ln_bias_tensor),
                init_fn=init_method_constant(0.0),
            )
        else:
            self.layer_norm_bias = None

        fc1_weight_tensor = torch.empty(
            fc1_out_features, hidden_size, device=device, dtype=params_dtype
        )
        self.register_parameter(
            "fc1_weight",
            torch.nn.Parameter(fc1_weight_tensor),
            init_fn=init_method or get_default_init_method(),
            get_rng_state_tracker=get_rng_state_tracker,
            fp8_meta_index=FP8FwdTensorIdx.GEMM1_WEIGHT,
        )

        if self.use_bias:
            fc1_bias_tensor = torch.empty(fc1_out_features, device=device, dtype=params_dtype)
            self.register_parameter(
                "fc1_bias",
                torch.nn.Parameter(fc1_bias_tensor),
                init_fn=init_method_constant(0.0),
            )
        else:
            # Empty placeholder so the attribute always exists; forward() does
            # not hand it to the kernel when the module has no bias.
            self.fc1_bias = torch.Tensor().to(dtype=params_dtype, device=device)

        fc2_weight_tensor = torch.empty(
            hidden_size, size_per_partition, device=device, dtype=params_dtype
        )
        self.register_parameter(
            "fc2_weight",
            torch.nn.Parameter(fc2_weight_tensor),
            init_fn=output_layer_init_method or init_method or get_default_init_method(),
            get_rng_state_tracker=get_rng_state_tracker,
            fp8_meta_index=FP8FwdTensorIdx.GEMM2_WEIGHT,
        )

        if self.use_bias:
            fc2_bias_tensor = torch.empty(hidden_size, device=device, dtype=params_dtype)
            self.register_parameter(
                "fc2_bias",
                torch.nn.Parameter(fc2_bias_tensor),
                init_fn=init_method_constant(0.0),
            )
        else:
            self.fc2_bias = torch.Tensor().to(dtype=params_dtype, device=device)

        if with_fp8_params:
            self.init_fp8_metadata(num_gemms=2)

        # Compare the device *type* so both "meta" and torch.device("meta")
        # defer initialization.
        self.reset_parameters(defer_init=torch.device(device).type == "meta")

        if self.wgrad_store.delay_wgrad_compute():
            # These parameters receive their grads later via backward_dw().
            for param_name, param in self.named_parameters():
                if param_name in ("fc1_weight", "fc1_bias", "fc2_weight", "fc2_bias"):
                    param.skip_backward_post_hook = True

    @no_torch_dynamo()
    def forward(
        self,
        inp: torch.Tensor,
        is_first_microbatch: Optional[bool] = None,
        fp8_output: Optional[bool] = False,
        fp8_grad: Optional[bool] = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """Run norm -> FC1 -> activation -> FC2 on ``inp``.

        Args:
            inp: input activations (last dim presumably ``hidden_size``; TODO confirm).
            is_first_microbatch: forwarded to the autograd function.
            fp8_output: request FP8 quantizers for the GEMM outputs.
            fp8_grad: request FP8 quantizers for the input gradients.

        Returns:
            Without ``return_bias``: whatever ``_LayerNormMLP`` returns.
            With ``return_bias``, the FC2 bias (cast to the activation dtype, or
            None when the module has no bias) is returned as well:
            ``(out, fc2_bias)``. When ``return_layernorm_output`` is also set and
            the autograd function returned a tuple, the result is
            ``(out[0], fc2_bias, *out[1:])``.
        """
        is_grad_enabled = torch.is_grad_enabled()

        inp = self.prepare_forward(inp, num_gemms=2)

        skip_fp8_weight_update = None

        if is_grad_enabled:
            ln_mlp_fn = _LayerNormMLP.apply
            autograd_ctx = []
        else:
            # Call the raw forward with ctx=None so no autograd graph is recorded.
            ln_mlp_fn = _LayerNormMLP.forward
            autograd_ctx = [None]

        quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
        (
            fc1_input_quantizer,
            fc1_weight_quantizer,
            fc1_output_quantizer,
            fc1_grad_input_quantizer,
            fc1_grad_weight_quantizer,
            fc1_grad_output_quantizer,
            fc2_input_quantizer,
            fc2_weight_quantizer,
            fc2_output_quantizer,
            fc2_grad_input_quantizer,
            fc2_grad_weight_quantizer,
            fc2_grad_output_quantizer,
        ) = quantizers

        non_tensor_args = _LayerNormMLPNonTensorArgs(
            is_first_microbatch=is_first_microbatch,
            fp8=self.fp8,
            wgrad_store=self.wgrad_store,
            eps=self.eps,
            fc1_input_quantizer=fc1_input_quantizer,
            fc1_weight_quantizer=fc1_weight_quantizer,
            fc1_output_quantizer=fc1_output_quantizer,
            fc1_grad_input_quantizer=fc1_grad_input_quantizer,
            fc1_grad_weight_quantizer=fc1_grad_weight_quantizer,
            fc1_grad_output_quantizer=fc1_grad_output_quantizer,
            fc2_input_quantizer=fc2_input_quantizer,
            fc2_weight_quantizer=fc2_weight_quantizer,
            fc2_output_quantizer=fc2_output_quantizer,
            fc2_grad_input_quantizer=fc2_grad_input_quantizer,
            fc2_grad_weight_quantizer=fc2_grad_weight_quantizer,
            fc2_grad_output_quantizer=fc2_grad_output_quantizer,
            fuse_wgrad_accumulation=self.fuse_wgrad_accumulation,
            tp_group=self.tp_group,
            tp_size=self.tp_size,
            sequence_parallel=self.sequence_parallel,
            activation_dtype=self.activation_dtype,
            tensor_parallel=self.tensor_parallel,
            set_parallel_mode=self.set_parallel_mode,
            is_grad_enabled=is_grad_enabled,
            return_layernorm_output=self.return_layernorm_output,
            return_layernorm_output_gathered=self.return_layernorm_output_gathered,
            zero_centered_gamma=self.zero_centered_gamma,
            normalization=self.normalization,
            activation=self.activation,
            activation_params=self.activation_params,
            overlap_ag_fprop=self.overlap_ag_fprop,
            overlap_rs_fprop=self.overlap_rs_fprop,
            overlap_ag_dgrad=self.overlap_ag_dgrad,
            overlap_rs_dgrad=self.overlap_rs_dgrad,
            module=self,
            skip_fp8_weight_update=skip_fp8_weight_update,
            fp8_output=fp8_output,
            fsdp_group=self.fsdp_group,
            is_fsdp2=self.is_fsdp2,
        )

        # The FC1 bias sits before the activation, so the caller can never add
        # it afterwards: apply it whenever the module has biases. Only the FC2
        # bias is withheld (and returned below) when return_bias is requested.
        fc1_bias = self.fc1_bias if self.use_bias else None
        fc2_bias = self.fc2_bias if self.apply_bias else None

        try:
            out = ln_mlp_fn(
                *autograd_ctx,
                self.ln_weight,
                self.layer_norm_bias,
                self.fc1_weight,
                fc1_bias,
                self.fc2_weight,
                fc2_bias,
                inp,
                non_tensor_args,
            )
        finally:
            self.end_forward()

        if not self.return_bias:
            return out

        # The caller asked to add the FC2 bias itself, so hand it back.
        returned_bias = (
            cast_if_needed(self.fc2_bias, self.activation_dtype) if self.use_bias else None
        )
        if self.return_layernorm_output and isinstance(out, tuple):
            # NOTE(review): assumes the autograd function returns (output, ln_out, ...)
            # in this mode; confirm against _LayerNormMLP.forward.
            return (out[0], returned_bias, *out[1:])
        return out, returned_bias

    def reset_parameters(self, defer_init=False):
        """(Re)initialize parameter values and attach parallelism metadata.

        Value initialization is delegated to the base class. Unless
        ``defer_init`` is set, this method then:

        * copies ``self.sequence_parallel`` onto the norm weight, the norm bias
          (when present) and, if biases are used, the FC2 bias;
        * records tensor-model-parallel attributes: partition dim 0 for the FC1
          weight and bias, dim 1 for the FC2 weight (stride 1 everywhere).

        ``set_tensor_model_parallel_attributes`` asserts those attributes are not
        already present on the tensor.
        """
        super().reset_parameters(defer_init=defer_init)
        if defer_init:
            return

        norm_params = [self.ln_weight]
        if self.layer_norm_bias is not None:
            norm_params.append(self.layer_norm_bias)
        for norm_param in norm_params:
            norm_param.sequence_parallel = self.sequence_parallel

        # FC1 is split along its output rows, FC2 along its input columns.
        for weight, partition_dim in ((self.fc1_weight, 0), (self.fc2_weight, 1)):
            set_tensor_model_parallel_attributes(
                tensor=weight, is_parallel=True, dim=partition_dim, stride=1
            )

        if not self.use_bias:
            return
        set_tensor_model_parallel_attributes(self.fc1_bias, True, 0, 1)
        self.fc2_bias.sequence_parallel = self.sequence_parallel

    def _get_quantizers(self, fp8_output, fp8_grad, is_grad_enabled):
        """Collect the 12 quantizers consumed by ``_LayerNormMLP``.

        The result is ordered FC1 first, then FC2. Each group of six is
        (input, weight, output, grad_input, grad_weight, grad_output). Without
        FP8, every entry is None. Otherwise:

        * output quantizers are populated only when ``fp8_output`` is set;
        * grad-output quantizers are populated only when ``is_grad_enabled``;
        * grad-input quantizers additionally require ``fp8_grad``;
        * grad-weight quantizers are never populated.

        Side effect: the input and grad-output quantizers are marked
        ``internal``. They are also marked ``optimize_for_gemm`` when the
        module is not in tensor-parallel + sequence-parallel mode.
        """
        if not self.fp8:
            return [None] * 12

        fwd_quantizers = self.quantizers["scaling_fwd"]
        # The GEMM-layout optimization is only switched on when these tensors
        # are not sequence-parallel sharded.
        gemm_layout_ok = not (self.set_parallel_mode and self.sequence_parallel)

        def _configure(quantizer):
            # Shared setup for activation-side quantizers.
            quantizer.internal = True
            if gemm_layout_ok:
                quantizer.optimize_for_gemm = True
            return quantizer

        fc1_input_quantizer = _configure(fwd_quantizers[FP8FwdTensorIdx.GEMM1_INPUT])
        fc2_input_quantizer = _configure(fwd_quantizers[FP8FwdTensorIdx.GEMM2_INPUT])
        fc1_weight_quantizer, fc2_weight_quantizer = self._get_weight_quantizers()

        fc1_output_quantizer = fc2_output_quantizer = None
        if fp8_output:
            fc1_output_quantizer = fwd_quantizers[FP8FwdTensorIdx.GEMM1_OUTPUT]
            fc2_output_quantizer = fwd_quantizers[FP8FwdTensorIdx.GEMM2_OUTPUT]

        fc1_grad_weight_quantizer = fc2_grad_weight_quantizer = None
        fc1_grad_input_quantizer = fc2_grad_input_quantizer = None
        fc1_grad_output_quantizer = fc2_grad_output_quantizer = None
        if is_grad_enabled:
            bwd_quantizers = self.quantizers["scaling_bwd"]
            fc1_grad_output_quantizer = _configure(bwd_quantizers[FP8BwdTensorIdx.GRAD_OUTPUT1])
            fc2_grad_output_quantizer = _configure(bwd_quantizers[FP8BwdTensorIdx.GRAD_OUTPUT2])
            if fp8_grad:
                fc1_grad_input_quantizer = bwd_quantizers[FP8BwdTensorIdx.GRAD_INPUT1]
                fc2_grad_input_quantizer = bwd_quantizers[FP8BwdTensorIdx.GRAD_INPUT2]

        return (
            fc1_input_quantizer,
            fc1_weight_quantizer,
            fc1_output_quantizer,
            fc1_grad_input_quantizer,
            fc1_grad_weight_quantizer,
            fc1_grad_output_quantizer,
            fc2_input_quantizer,
            fc2_weight_quantizer,
            fc2_output_quantizer,
            fc2_grad_input_quantizer,
            fc2_grad_weight_quantizer,
            fc2_grad_output_quantizer,
        )

    def _get_weight_quantizers(self):
        """Return ``[fc1_weight_quantizer, fc2_weight_quantizer]``.

        Both entries are None when FP8 is off. Otherwise they are the
        ``scaling_fwd`` quantizers for the GEMM1/GEMM2 weights, each marked
        ``internal``.
        """
        if not self.fp8:
            return [None, None]
        fwd_quantizers = self.quantizers["scaling_fwd"]
        weight_quantizers = []
        for slot in (FP8FwdTensorIdx.GEMM1_WEIGHT, FP8FwdTensorIdx.GEMM2_WEIGHT):
            quantizer = fwd_quantizers[slot]
            quantizer.internal = True
            weight_quantizers.append(quantizer)
        return weight_quantizers

    def _get_weight_tensors(self):
        """Return ``[fc1_weight, fc2_weight]``, dequantized if compute is not quantized.

        When FP8 compute is on, or neither weight is a ``QuantizedTensor``, the
        weights are returned as-is. Otherwise a warning is emitted. Each
        ``QuantizedTensor`` is then replaced by its ``dequantize()`` result;
        plain tensors pass through untouched.
        """
        weights = [self.fc1_weight, self.fc2_weight]
        from ..quantized_tensor import QuantizedTensor

        has_quantized = any(isinstance(weight, QuantizedTensor) for weight in weights)
        if not has_quantized or self.fp8:
            return weights

        import warnings

        warnings.warn(
            "You are using quantized weights without quantized compute. "
            "Please make sure this is intentional."
        )
        return [
            weight.dequantize() if isinstance(weight, QuantizedTensor) else weight
            for weight in weights
        ]

    def need_backward_dw(self):
        """Report whether weight-gradient computation is deferred to ``backward_dw``.

        Returns False when there is no ``wgrad_store``. Otherwise it returns
        whatever ``self.wgrad_store.delay_wgrad_compute()`` returns.
        """
        store = self.wgrad_store
        if store is None:
            return False
        return store.delay_wgrad_compute()

    def backward_dw(self):
        if not self.need_backward_dw():
            return
        (fc2_wgrad, fc2_bias_grad, *_), tensor_list_fc2 = self.wgrad_store.pop()
        (fc1_wgrad, fc1_bias_grad, *_), _ = self.wgrad_store.pop()
        if self.use_bias:
            if self.fc2_bias.grad is None and fc2_bias_grad is not None:
                self.fc2_bias.grad = fc2_bias_grad.to(self.fc2_bias.dtype)
            if self.fc1_bias.grad is None and fc1_bias_grad is not None:
                self.fc1_bias.grad = fc1_bias_grad.to(self.fc1_bias.dtype)
        if not self.fuse_wgrad_accumulation:
            self.fc2_weight.grad = fc2_wgrad.to(self.fc2_weight.dtype)
            self.fc1_weight.grad = fc1_wgrad.to(self.fc1_weight.dtype)
        del fc2_wgrad, fc2_bias_grad, fc1_wgrad, fc1_bias_grad