# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2026, Huawei Technologies Co., Ltd. All rights reserved.
#
# See LICENSE for license information.

"""Fusible operation for linear layer without bias."""

from __future__ import annotations
from collections.abc import Callable, Iterable
import contextlib
import math
from typing import Any, Optional

import torch

from ...constants import TensorUsage
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...distributed import (
    CudaRNGStatesTracker,
    gather_along_first_dim,
    reduce_scatter_along_first_dim,
)
from ...quantization import FP8GlobalStateManager, Recipe
from ...tensor import Quantizer
from ...tensor.float8_tensor import Float8Quantizer
from ...tensor.storage.float8_tensor_storage import Float8TensorStorage
from ...utils import (
    canonicalize_device,
    canonicalize_dtype,
    clear_tensor_data,
    devices_match,
)
from ..gemm import general_gemm
from ..gemm import general_gemm_add
from ..op import BasicOperation, OperationContext
from .._common import (
    get_accumulate_flag_in_param,
    get_dummy_wgrads_for_params,
    get_main_grad_from_param,
    is_quantized_tensor,
    maybe_dequantize,
)


def _wait_async(handle: Optional[Any]) -> None:
    """Block until an asynchronous communication handle has completed.

    Parameters
    ----------
    handle:
        Object exposing a ``wait()`` method, or ``None`` when there is
        nothing to wait for (in which case this is a no-op).
    """
    if handle is None:
        return
    handle.wait()


def _apply_gemm_options(
    result: torch.Tensor,
    *,
    alpha: Optional[float],
    beta: Optional[float],
    accumulate: bool,
    out: Optional[torch.Tensor],
) -> torch.Tensor:
    """Emulate the alpha/beta/accumulate/out options of a TE GEMM call.

    Parameters
    ----------
    result:
        Raw GEMM result.
    alpha:
        Scaling factor applied to ``result``. ``None`` and ``1.0`` mean
        "no scaling".
    beta:
        Scaling factor applied in-place to ``out`` before accumulation.
        Only consulted when ``accumulate`` is true. ``None`` and ``1.0``
        mean "no scaling".
    accumulate:
        If true, add the (scaled) result into ``out``; otherwise
        overwrite ``out`` with it.
    out:
        Optional destination tensor, modified in-place when provided.

    Returns
    -------
    torch.Tensor
        ``out`` when it was provided, otherwise the (scaled) result.
    """
    # Scaling factors that leave a tensor unchanged
    identity_factors = (None, 1.0)

    scaled = result if alpha in identity_factors else result * alpha
    if out is None:
        return scaled

    # Overwrite the destination
    if not accumulate:
        out.copy_(scaled)
        return out

    # Accumulate into the destination: out = beta * out + scaled
    if beta not in identity_factors:
        out.mul_(beta)
    out.add_(scaled)
    return out


class BasicLinear(BasicOperation):
    """Apply linear transformation: :math:`y = x A^T`."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        *,
        device: Optional[torch.device | str] = None,
        dtype: Optional[torch.dtype] = None,
        tensor_parallel_mode: Optional[str] = None,
        tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
        sequence_parallel: bool = False,
        rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] = None,
        accumulate_into_main_grad: bool = False,
        userbuffers_options: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()

        # Weight tensor dimensions
        self.in_features: int = in_features
        self.out_features: int = out_features

        # Weight tensor attributes
        device = canonicalize_device(device)
        dtype = canonicalize_dtype(dtype)
        if dtype not in (torch.float32, torch.float16, torch.bfloat16):
            raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")

        # Tensor parallel configuration
        self.tensor_parallel_mode: Optional[str]
        self.tensor_parallel_group: Optional[torch.distributed.ProcessGroup]
        self.tensor_parallel_size: int
        self.sequence_parallel: bool
        self.local_in_features: int
        self.local_out_features: int
        (
            self.tensor_parallel_mode,
            self.tensor_parallel_group,
            self.tensor_parallel_size,
            self.sequence_parallel,
            self.local_in_features,
            self.local_out_features,
        ) = self._canonicalize_tensor_parallelism(
            mode=tensor_parallel_mode,
            process_group=tensor_parallel_group,
            sequence_parallel=sequence_parallel,
            in_features=in_features,
            out_features=out_features,
        )

        # Initialize recipe state if needed for natively quantized weight
        self._with_quantized_weight: bool = FP8GlobalStateManager.with_fp8_parameters()
        if self._with_quantized_weight:
            self.reset_recipe_state(recipe=FP8GlobalStateManager.get_fp8_recipe())

        # Initialize parameters if needed
        weight = torch.empty(
            self.local_out_features,
            self.local_in_features,
            device=device,
            dtype=dtype,
        )
        weight = torch.nn.Parameter(weight)
        self.weight: torch.nn.Parameter
        self.register_parameter("weight", weight)
        self._rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]]
        self._rng_state_tracker_function = rng_state_tracker_function
        if weight.device.type != "meta":
            self.reset_parameters()

        # Whether to accumulate weight gradient into main_grad
        self._accumulate_into_main_grad: bool = accumulate_into_main_grad

        # Userbuffers options
        self._userbuffers_options: Optional[dict[str, Any]] = userbuffers_options

    @classmethod
    def _canonicalize_tensor_parallelism(
        cls,
        *,
        mode: Optional[str],
        process_group: Optional[torch.distributed.ProcessGroup],
        sequence_parallel: bool,
        in_features: int,
        out_features: int,
    ) -> tuple[
        Optional[str],
        Optional[torch.distributed.ProcessGroup],
        int,
        bool,
        int,
        int,
    ]:
        """Check configuration for tensor parallelism."""

        # Tensor-parallel group size
        if mode is None:
            group_size = 1
        else:
            group_size = torch.distributed.get_world_size(process_group)

        # Disable tensor parallelism if not needed
        if group_size == 1:
            mode = None
            process_group = None
            sequence_parallel = False

        # Determine local tensor dims
        local_in_features = in_features
        local_out_features = out_features
        if mode is None:
            pass
        elif mode == "column":
            # Distribute output tensor
            if out_features % group_size != 0:
                raise ValueError(
                    "Invalid configuration for tensor parallelism "
                    f"({mode=}, {out_features=}, {group_size=})"
                )
            local_out_features //= group_size
        elif mode == "row":
            # Distribute input tensor
            if in_features % group_size != 0:
                raise ValueError(
                    "Invalid configuration for tensor parallelism "
                    f"({mode=}, {in_features=}, {group_size=})"
                )
            local_in_features //= group_size
        else:
            raise ValueError(
                "Supported modes for tensor parallelism are "
                f'`None`, "row", and "column" (got {mode=})'
            )

        return (
            mode,
            process_group,
            group_size,
            sequence_parallel,
            local_in_features,
            local_out_features,
        )

    def num_quantizers(self, mode: str) -> int:
        if mode == "forward":
            return 2
        if mode == "backward":
            return 1
        return 0

    def reset_parameters(self) -> None:
        """Initialize parameter buffers and values"""

        # Parameter device
        weight = self.weight
        device = weight.device
        if device.type == "meta":
            device = canonicalize_device(None)

        # Allocate buffer if needed
        if is_quantized_tensor(weight):
            weight = torch.empty(
                weight.size(),
                dtype=weight.dtype,
                device=device,
            )
        elif not devices_match(weight.device, device):
            weight = torch.empty_like(weight, device=device)

        # Initialize values
        init_context = contextlib.nullcontext()
        if self._rng_state_tracker_function is not None:
            init_context = self._rng_state_tracker_function().fork()
        with init_context:
            torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))

        # Quantize if needed
        if self._with_quantized_weight:
            quantizer = self.get_quantizer("forward", 1)
            if quantizer is None:
                raise RuntimeError(
                    "Tried to quantize weight with deferred initialization "
                    "due to meta device, but no quantizer was available. "
                    "This is most likely because the weight was initialized "
                    "within quantized_model_init, but the forward pass was not "
                    "performed within autocast."
                )
            quantizer.set_usage(
                rowwise=True,
                columnwise=torch.is_grad_enabled(),
            )
            quantizer.internal = False
            with torch.no_grad():
                weight = quantizer(weight)

        # Save updated parameter
        if not isinstance(weight, torch.nn.Parameter):
            weight = torch.nn.Parameter(weight)
        self.weight = weight

    def pre_first_fuser_forward(self) -> None:
        super().pre_first_fuser_forward()
        if self.weight.device.type == "meta":
            self.reset_parameters()

    def pre_fuser_forward(self, *, requires_grad: bool) -> None:
        super().pre_fuser_forward(requires_grad=requires_grad)
        if FP8GlobalStateManager.is_fp8_enabled():
            # Configure quantizer usages
            weight_requires_grad = requires_grad and self.weight.requires_grad
            columnwise_usage = weight_requires_grad
            if FP8GlobalStateManager.get_fp8_recipe().backward_override is not None:
                columnwise_usage = False
            input_quantizer = self.get_quantizer("forward", 0)
            weight_quantizer = self.get_quantizer("forward", 1)
            grad_output_quantizer = self.get_quantizer("backward", 0)
            input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
            weight_quantizer.set_usage(rowwise=True, columnwise=requires_grad)
            grad_output_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)

    def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None:
        """Reset recipe state and reconfigure this op's quantizers.

        After delegating to the base class, this marks the input and
        grad output quantizers as internal, decides whether the weight
        quantizer is internal, applies FP8 current-scaling options from
        the recipe, and refreshes the quantizer attached to an already
        quantized weight tensor.

        Parameters
        ----------
        recipe:
            Quantization recipe, or ``None``. Recipe-specific options
            are only applied when it is not ``None`` and
            ``recipe.float8_current_scaling()`` is true.
        """
        super().reset_recipe_state(recipe=recipe)

        # Configure input/grad output quantizers
        # Note: These tensors are only used internally. If there is no
        # tensor-parallel communication, they are only used for GEMM.
        input_quantizer = self.get_quantizer("forward", 0)
        grad_output_quantizer = self.get_quantizer("backward", 0)
        if input_quantizer is not None:
            input_quantizer.internal = True
            # Skipped when the input is all-gathered (column tensor
            # parallelism with sequence parallelism)
            if not (self.tensor_parallel_mode == "column" and self.sequence_parallel):
                input_quantizer.optimize_for_gemm = True
        if grad_output_quantizer is not None:
            grad_output_quantizer.internal = True
            # Skipped when the grad output is all-gathered (row tensor
            # parallelism with sequence parallelism)
            if not (self.tensor_parallel_mode == "row" and self.sequence_parallel):
                grad_output_quantizer.optimize_for_gemm = True

        # Configure weight quantizer
        # Note: This function may be called in base class constructor,
        # before basic linear attrs have been set.
        weight_quantizer = self.get_quantizer("forward", 1)
        weight = getattr(self, "weight", None)
        if weight_quantizer is not None:
            # Determine if quantized weight is exposed as parameter
            weight_quantizer.internal = not (
                FP8GlobalStateManager.with_fp8_parameters()
                or getattr(self, "_with_quantized_weight", False)
                or is_quantized_tensor(weight)
            )

        # Recipe-specific configuration
        # Note: This function may be called in base class constructor,
        # before any basic linear attrs have been set.
        # NOTE(review): the three quantizers are dereferenced without
        # None checks in this branch; presumably a current-scaling
        # recipe guarantees they exist -- confirm.
        if recipe is not None:
            if recipe.float8_current_scaling():
                input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
                input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
                weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale
                weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_weight.amax_epsilon
                grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
                grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_bwd_grad.amax_epsilon
                # With sequence parallelism, enable amax reduction over
                # the tensor-parallel group for the tensor that is
                # gathered across ranks (input for column mode, grad
                # output for row mode)
                if getattr(self, "sequence_parallel", False):
                    tensor_parallel_mode = getattr(self, "tensor_parallel_mode", None)
                    if tensor_parallel_mode == "column":
                        input_quantizer.with_amax_reduction = True
                        input_quantizer.amax_reduction_group = self.tensor_parallel_group
                    elif tensor_parallel_mode == "row":
                        grad_output_quantizer.with_amax_reduction = True
                        grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group

        # Update quantizer in quantized weight tensor
        if weight_quantizer is not None and is_quantized_tensor(weight):
            if weight._quantizer is not None:
                # Preserve existing usages in weight tensor. Even if a
                # usage is currently unnecessary, the weight tensor
                # may be used elsewhere.
                weight_quantizer.set_usage(
                    rowwise=weight._quantizer.rowwise_usage,
                    columnwise=weight._quantizer.columnwise_usage,
                )
            # The weight tensor receives its own copy of the quantizer
            weight.update_quantizer(weight_quantizer.copy())

    @staticmethod
    def _functional_forward(
        input: torch.Tensor,  # pylint: disable=redefined-builtin
        weight: torch.Tensor,
        *,
        alpha: float = 1.0,
        bias: Optional[torch.Tensor] = None,
        device: Optional[torch.device] = None,  # pylint: disable=unused-argument
        dtype: Optional[torch.dtype] = None,
        out: Optional[torch.Tensor] = None,
        beta: Optional[float] = None,
        accumulate_into_out: bool = False,
        tensor_parallel_mode: Optional[str] = None,
        tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
        sequence_parallel: bool = False,
        with_quantized_compute: bool = False,
        backward_override: Optional[str] = None,
        input_quantizer: Optional[Quantizer] = None,
        weight_quantizer: Optional[Quantizer] = None,
        output_quantizer: Optional[Quantizer] = None,
        input_requires_grad: bool = True,
        weight_requires_grad: bool = True,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Functional API for forward pass.

        Parameters
        ----------
        input:
            Local input tensor.
        weight:
            Weight tensor.
        alpha:
            Scaling factor applied to the GEMM result.
        bias:
            Optional bias, passed through to ``general_gemm``.
        device:
            Unused.
        dtype:
            Compute/output dtype. If ``None`` it is taken from ``out``,
            else from ``weight``. Must be float32, float16 or bfloat16.
        out:
            Optional output tensor that is overwritten or accumulated
            into.
        beta:
            Scaling factor applied to ``out`` before accumulation.
        accumulate_into_out:
            Add the result into ``out`` instead of overwriting it. Not
            supported with row tensor parallelism.
        tensor_parallel_mode:
            ``None``, ``"column"`` or ``"row"``.
        tensor_parallel_group:
            Process group for tensor-parallel communication.
        sequence_parallel:
            With ``"column"`` mode the input is all-gathered along its
            first dim; with ``"row"`` mode the output is reduce-scattered
            along its first dim instead of all-reduced.
        with_quantized_compute:
            Quantize input and weight before the GEMM; otherwise they are
            dequantized to ``dtype`` if needed.
        backward_override:
            When not ``None``, column-wise usages for the backward pass
            are not requested.
        input_quantizer, weight_quantizer:
            Quantizers for input and weight. Required with quantized
            compute (the weight quantizer only if ``weight`` is not
            already quantized).
        output_quantizer:
            Quantizer for the output. NOTE(review): it is validated and
            its usage is configured here, but it is not forwarded to
            ``general_gemm`` in this function -- confirm that is
            intended.
        input_requires_grad, weight_requires_grad:
            Whether the backward pass will need the input grad / weight
            grad.

        Returns
        -------
        tuple
            ``(output, saved_input, saved_weight)``. ``saved_input`` is
            the local input prepared for the backward pass, or ``None``
            if ``weight_requires_grad`` is false. ``saved_weight`` is the
            weight as used in the GEMM, or ``None`` if
            ``input_requires_grad`` is false.

        Raises
        ------
        ValueError
            On an unsupported or mismatched dtype, a missing quantizer,
            or an unsupported ``out`` / accumulation configuration.
        RuntimeError
            If the output quantizer is not a ``Float8Quantizer``.
        """

        # Check datatype
        if dtype is None:
            if out is not None and isinstance(out, torch.Tensor):
                dtype = out.dtype
            elif weight is not None and isinstance(weight, torch.Tensor):
                dtype = weight.dtype
            else:
                raise ValueError(
                    "Could not infer dtype from weight nor out and dtype was not provided"
                )
        if dtype not in (torch.float32, torch.float16, torch.bfloat16):
            raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")
        if out is not None and out.dtype != dtype:
            raise ValueError(f"Output tensor has invalid dtype (expected {dtype}, got {out.dtype})")

        # Check input tensor
        # x_local is the per-rank input, x is the GEMM operand (gathered
        # if needed), x_async is the handle of a pending all-gather
        x_local = input
        x = None
        x_async = None
        with_x_all_gather = tensor_parallel_mode == "column" and sequence_parallel
        if with_quantized_compute:
            if input_quantizer is None:
                raise ValueError("Missing quantizer for input tensor")
            input_quantizer.set_usage(
                rowwise=True,
                columnwise=weight_requires_grad and backward_override is None,
            )
            if with_x_all_gather:
                # Gathered input only carries row-wise data
                input_quantizer.set_usage(columnwise=False)
                x, x_async = gather_along_first_dim(
                    x_local,
                    tensor_parallel_group,
                    async_op=True,
                    quantizer=input_quantizer,
                )
            else:
                if not is_quantized_tensor(x_local):
                    x_local = input_quantizer(x_local)
                x = x_local
        else:
            x_local = maybe_dequantize(x_local, dtype)

            if with_x_all_gather:
                x, x_async = gather_along_first_dim(
                    x_local,
                    tensor_parallel_group,
                    async_op=True,
                )
            else:
                x = x_local

        # Check weight tensor
        w = weight
        if not with_quantized_compute:
            w = maybe_dequantize(w, dtype)
        elif with_quantized_compute and not is_quantized_tensor(w):
            if weight_quantizer is None:
                raise ValueError("Missing quantizer for weight tensor")
            weight_quantizer.set_usage(
                rowwise=True,
                columnwise=input_requires_grad and backward_override is None,
            )
            w = weight_quantizer(w)

        # Check output tensor
        # The output quantizer is dropped whenever a quantized output
        # cannot be produced (no quantized compute, row tensor
        # parallelism, or an unquantized ``out`` tensor)
        y = out
        if y is None:
            if not with_quantized_compute:
                output_quantizer = None
            if tensor_parallel_mode == "row":
                output_quantizer = None
        elif is_quantized_tensor(y):
            if not with_quantized_compute:
                raise ValueError("Output tensor is quantized, but quantized compute is not enabled")
            if tensor_parallel_mode == "row":
                raise ValueError(
                    "Output tensor is quantized, "
                    "but row tensor parallelism does not support quantized output"
                )
            if output_quantizer is None:
                output_quantizer = getattr(y, "_quantizer", None)
            if output_quantizer is None:
                raise ValueError("Output tensor is quantized, but quantizer was not provided")
        else:
            output_quantizer = None
        if output_quantizer is not None:
            if not isinstance(output_quantizer, Float8Quantizer):
                raise RuntimeError(
                    "Attempting to generate quantized output tensor with unsupported quantizer"
                )
            output_quantizer.set_usage(rowwise=True, columnwise=False)

        # Check if accumulating into output tensor
        if accumulate_into_out:
            if y is None:
                raise ValueError(
                    "Attempted to accumulate into output tensor without providing output tensor"
                )
            if tensor_parallel_mode == "row":
                raise ValueError(
                    "Accumulating into output tensor is not supported with row tensor parallelism"
                )

        # Synchronize communication for input
        _wait_async(x_async)
        x_async = None

        # Perform GEMM
        # alpha/beta/accumulate/out are applied afterwards by
        # _apply_gemm_options rather than inside the GEMM call
        y = general_gemm(
            x,
            w,
            usage_a=TensorUsage.LHS,
            usage_b=TensorUsage.RHS_TRANS,
            out_dtype=dtype,
            bias=bias,
        )
        y = _apply_gemm_options(
            y,
            alpha=alpha,
            beta=beta,
            accumulate=accumulate_into_out,
            out=out,
        )

        # Reduce tensor-parallel output if needed
        if tensor_parallel_mode == "row":
            if sequence_parallel:
                y, _ = reduce_scatter_along_first_dim(y, tensor_parallel_group)
            else:
                torch.distributed.all_reduce(y, group=tensor_parallel_group)

        # Prepare weight tensor for backward pass
        if input_requires_grad:
            # Only touch usages of a weight that was quantized in this
            # call, never the caller's weight tensor
            if (
                w is not weight
                and with_quantized_compute
                and is_quantized_tensor(w)
                and backward_override is None
            ):
                w.update_usage(rowwise_usage=False, columnwise_usage=True)
        else:
            w = None

        # Prepare input tensor for backward pass
        if weight_requires_grad:
            if (
                with_quantized_compute
                and is_quantized_tensor(x_local)
                and backward_override is None
            ):
                if not (isinstance(x_local, Float8TensorStorage) and with_x_all_gather):
                    # FP8 does not support all-gather of transpose data
                    x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
        else:
            x_local = None

        return y, x_local, w

    @staticmethod
    def _functional_backward(
        grad_output: torch.Tensor,
        input: Optional[torch.Tensor],  # pylint: disable=redefined-builtin
        weight: Optional[torch.Tensor],
        *,
        grad_input_alpha: Optional[float] = None,
        input_requires_grad: bool = True,
        grad_weight_alpha: Optional[float] = None,
        weight_requires_grad: bool = True,
        device: Optional[torch.device] = None,  # pylint: disable=unused-argument
        dtype: Optional[torch.dtype] = None,
        grad_weight: Optional[torch.Tensor] = None,
        grad_weight_beta: Optional[float] = None,
        accumulate_into_grad_weight: bool = False,
        grad_input: Optional[torch.Tensor] = None,
        grad_input_beta: Optional[float] = None,
        accumulate_into_grad_input: bool = False,
        tensor_parallel_mode: Optional[str] = None,
        tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
        sequence_parallel: bool = False,
        with_quantized_compute: bool = False,
        input_quantizer: Optional[Quantizer] = None,
        weight_quantizer: Optional[Quantizer] = None,
        grad_output_quantizer: Optional[Quantizer] = None,
        grad_input_quantizer: Optional[Quantizer] = None,
    ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Functional API for backward pass.

        Parameters
        ----------
        grad_output:
            Local gradient w.r.t. the forward output.
        input:
            Local input saved from the forward pass. Required when
            ``weight_requires_grad`` is true.
        weight:
            Weight tensor. Required when ``input_requires_grad`` is true.
        grad_input_alpha, grad_weight_alpha:
            Scaling factors applied to the dgrad / wgrad GEMM results.
        input_requires_grad, weight_requires_grad:
            Which gradients to compute.
        device:
            Unused.
        dtype:
            Compute dtype. If ``None`` it is taken from ``weight``, else
            from ``grad_output``, and then canonicalized. Must be
            float32, float16 or bfloat16.
        grad_weight, grad_input:
            Optional destination tensors that are overwritten or
            accumulated into. The wgrad GEMM output dtype follows
            ``grad_weight`` when it is provided.
        grad_weight_beta, grad_input_beta:
            Scaling factors applied to the destination tensors before
            accumulation.
        accumulate_into_grad_weight, accumulate_into_grad_input:
            Accumulate into the destination tensors instead of
            overwriting them. Accumulating into the grad input is not
            supported with column tensor parallelism.
        tensor_parallel_mode:
            ``None``, ``"column"`` or ``"row"``.
        tensor_parallel_group:
            Process group for tensor-parallel communication.
        sequence_parallel:
            With ``"row"`` mode the grad output is all-gathered along its
            first dim; with ``"column"`` mode the input is all-gathered
            and the grad input is reduce-scattered instead of
            all-reduced.
        with_quantized_compute:
            Quantize GEMM operands; otherwise they are dequantized to
            ``dtype`` if needed.
        input_quantizer, weight_quantizer, grad_output_quantizer:
            Quantizers for the respective tensors, required with
            quantized compute when the tensor must be quantized here.
        grad_input_quantizer:
            Quantizer for the grad input; only validated in this
            function.

        Returns
        -------
        tuple
            ``(grad_input, grad_weight)``. Each entry is ``None`` when
            the corresponding ``*_requires_grad`` flag is false.

        Raises
        ------
        ValueError
            On an unsupported dtype, a missing tensor or quantizer, or an
            unsupported grad input / accumulation configuration.
        RuntimeError
            If the grad input quantizer is not a ``Float8Quantizer``.
        """

        # Check datatype
        if dtype is None:
            if isinstance(weight, torch.Tensor):
                dtype = weight.dtype
            elif isinstance(grad_output, torch.Tensor):
                dtype = grad_output.dtype
        dtype = canonicalize_dtype(dtype)
        if dtype not in (torch.float32, torch.float16, torch.bfloat16):
            raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")

        # Check grad output tensor
        # dy_local is the per-rank grad output, dy is the GEMM operand
        # (gathered if needed), dy_async is the handle of a pending
        # all-gather
        dy_local = grad_output
        dy = None
        dy_async = None
        with_dy_all_gather = tensor_parallel_mode == "row" and sequence_parallel
        if with_quantized_compute:
            if grad_output_quantizer is None:
                raise ValueError("Missing quantizer for grad output tensor")
            # Row-wise data feeds the dgrad GEMM, column-wise data the
            # wgrad GEMM
            grad_output_quantizer.set_usage(
                rowwise=input_requires_grad,
                columnwise=weight_requires_grad,
            )
            if with_dy_all_gather:
                dy, dy_async = gather_along_first_dim(
                    dy_local,
                    tensor_parallel_group,
                    async_op=True,
                    quantizer=grad_output_quantizer,
                )
            else:
                if not is_quantized_tensor(dy_local):
                    dy_local = grad_output_quantizer(dy_local)
                else:
                    dy_local.update_usage(
                        rowwise_usage=input_requires_grad,
                        columnwise_usage=weight_requires_grad,
                    )
                dy = dy_local
        else:
            dy_local = maybe_dequantize(dy_local, dtype)

            if with_dy_all_gather:
                dy, dy_async = gather_along_first_dim(
                    dy_local,
                    tensor_parallel_group,
                    async_op=True,
                )
            else:
                dy = dy_local

        # Check input tensor
        # Only needed for the wgrad GEMM. The all-gather is launched
        # here and waited on just before that GEMM.
        x = None
        x_async = None
        if weight_requires_grad:
            if input is None:
                raise ValueError("Input tensor is required to compute weight grad")
            x_local = input
            with_x_all_gather = tensor_parallel_mode == "column" and sequence_parallel
            if with_quantized_compute:
                if input_quantizer is None:
                    raise ValueError("Missing quantizer for input tensor")
                input_quantizer.set_usage(rowwise=False, columnwise=True)
                if with_x_all_gather:
                    x, x_async = gather_along_first_dim(
                        x_local,
                        tensor_parallel_group,
                        async_op=True,
                        quantizer=input_quantizer,
                    )
                else:
                    if is_quantized_tensor(x_local):
                        x_local.update_usage(columnwise_usage=True)
                    else:
                        x_local = input_quantizer(x_local)
                    x = x_local
            else:
                x_local = maybe_dequantize(x_local, dtype)

                if with_x_all_gather:
                    x, x_async = gather_along_first_dim(
                        x_local,
                        tensor_parallel_group,
                        async_op=True,
                    )
                else:
                    x = x_local

        # Compute grad input
        dx = None
        dx_async = None
        if input_requires_grad:
            # Check weight tensor
            if weight is None:
                raise ValueError("Weight tensor is required to compute input grad")
            w = weight
            if with_quantized_compute:
                if is_quantized_tensor(w):
                    w.update_usage(columnwise_usage=True)
                else:
                    if weight_quantizer is None:
                        raise ValueError("Missing quantizer for weight tensor")
                    weight_quantizer.set_usage(columnwise=True)
                    w = weight_quantizer(w)
            else:
                w = maybe_dequantize(w, dtype)

            # Synchronize tensor-parallel communication
            _wait_async(dy_async)
            dy_async = None

            # Check grad input tensor
            # The grad input quantizer is dropped whenever a quantized
            # grad input cannot be produced (no quantized compute, column
            # tensor parallelism, or an unquantized grad input tensor)
            dx = grad_input
            if dx is None:
                if not with_quantized_compute:
                    grad_input_quantizer = None
                if tensor_parallel_mode == "column":
                    grad_input_quantizer = None
            elif is_quantized_tensor(dx):
                if not with_quantized_compute:
                    raise ValueError(
                        "Grad input tensor is quantized, but quantized compute is not enabled"
                    )
                if tensor_parallel_mode == "column":
                    raise ValueError(
                        "Grad input tensor is quantized, "
                        "but column tensor parallelism does not support quantized grad input"
                    )
                if grad_input_quantizer is None:
                    grad_input_quantizer = getattr(dx, "_quantizer", None)
                if grad_input_quantizer is None:
                    raise ValueError(
                        "Grad input tensor is quantized, but quantizer was not provided"
                    )
            else:
                grad_input_quantizer = None
            if grad_input_quantizer is not None:
                if not isinstance(grad_input_quantizer, Float8Quantizer):
                    raise RuntimeError(
                        "Attempting to generate quantized grad input tensor "
                        "with unsupported quantizer"
                    )

            # Check if accumulating into grad input tensor
            if accumulate_into_grad_input:
                if dx is None:
                    raise ValueError(
                        "Attempted to accumulate into grad input tensor "
                        "without providing grad input tensor"
                    )
                if tensor_parallel_mode == "column":
                    raise ValueError(
                        "Accumulating into grad input tensor "
                        "is not supported with column tensor parallelism"
                    )

            # Perform dgrad GEMM
            # alpha/beta/accumulate/out are applied afterwards by
            # _apply_gemm_options rather than inside the GEMM call
            dx = general_gemm(
                dy,
                w,
                usage_a=TensorUsage.LHS,
                usage_b=TensorUsage.RHS,
                out_dtype=dtype,
            )
            dx = _apply_gemm_options(
                dx,
                alpha=grad_input_alpha,
                beta=grad_input_beta,
                accumulate=accumulate_into_grad_input,
                out=grad_input,
            )

            # Reduce tensor-parallel grad input if needed
            # Launched asynchronously; the handle is waited on at the end
            # of the function, after the wgrad GEMM
            if tensor_parallel_mode == "column":
                if sequence_parallel:
                    dx, dx_async = reduce_scatter_along_first_dim(
                        dx,
                        tensor_parallel_group,
                        async_op=True,
                    )
                else:
                    dx_async = torch.distributed.all_reduce(
                        dx,
                        group=tensor_parallel_group,
                        async_op=True,
                    )

        # Compute grad weight
        dw = None
        if weight_requires_grad:
            # Synchronize tensor-parallel communication
            _wait_async(x_async)
            _wait_async(dy_async)
            x_async = None
            dy_async = None

            # Check grad weight tensor
            dw = grad_weight
            dw_dtype = dtype
            if dw is None:
                if accumulate_into_grad_weight:
                    raise ValueError(
                        "Attempted to accumulate into grad weight tensor "
                        "without providing grad weight tensor"
                    )
            else:
                dw_dtype = dw.dtype

            # Perform wgrad GEMM
            # Plain accumulation (no alpha/beta scaling) goes through
            # general_gemm_add with dw as its first argument; its return
            # value is not used. NOTE(review): presumably it accumulates
            # into dw in-place -- confirm against general_gemm_add.
            if (
                accumulate_into_grad_weight
                and grad_weight_alpha in (None, 1.0)
                and grad_weight_beta in (None, 1.0)
            ):
                general_gemm_add(
                    dw,
                    dy,
                    x,
                    usage_a=TensorUsage.LHS_TRANS,
                    usage_b=TensorUsage.RHS,
                    out_dtype=dw_dtype,
                )
            else:
                dw = general_gemm(
                    dy,
                    x,
                    usage_a=TensorUsage.LHS_TRANS,
                    usage_b=TensorUsage.RHS,
                    out_dtype=dw_dtype,
                )
                dw = _apply_gemm_options(
                    dw,
                    alpha=grad_weight_alpha,
                    beta=grad_weight_beta,
                    accumulate=accumulate_into_grad_weight,
                    out=grad_weight,
                )

        # Clean up and return grads
        # Any communication still in flight is finished before returning
        _wait_async(dy_async)
        _wait_async(x_async)
        _wait_async(dx_async)
        return dx, dw

    def op_forward(
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        *,
        prev_op_grad_output_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
        **kwargs: Any,
    ) -> torch.Tensor:
        """Apply the linear forward pass and save state for backward.

        Parameters
        ----------
        ctx: OperationContext
            Context that receives the saved tensors and the settings
            read back by ``op_backward``.
        input_: torch.Tensor
            Input tensor.
        prev_op_grad_output_quantizer: Quantizer, optional
            Stored in ``ctx`` as the quantizer for the input grad.
        next_op_input_quantizer: Quantizer, optional
            Passed to the forward GEMM as the output quantizer.
        **kwargs: Any
            Accepted for interface compatibility; not used here.

        Returns
        -------
        torch.Tensor
            Output tensor from ``BasicLinear._functional_forward``.

        """
        # Check which grads are required
        input_requires_grad = ctx.requires_grad
        weight_requires_grad = ctx.requires_grad and self.weight.requires_grad

        # Quantizers
        input_quantizer = self.get_quantizer("forward", 0)
        weight_quantizer = self.get_quantizer("forward", 1)
        output_quantizer = next_op_input_quantizer
        grad_output_quantizer = self.get_quantizer("backward", 0)
        grad_input_quantizer = prev_op_grad_output_quantizer
        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
        if with_quantized_compute:
            backward_override = FP8GlobalStateManager.get_fp8_recipe().backward_override
        else:
            backward_override = None

        # Get autocast dtype if needed
        # Note: Query the enabled flag and the dtype for the same
        # device type. Calling torch.is_autocast_enabled without a
        # device type is deprecated and reports the CUDA autocast
        # state, which would not match the "npu" dtype query below.
        if torch.is_autocast_enabled("npu"):
            dtype = torch.get_autocast_dtype("npu")
        else:
            dtype = self.weight.dtype

        # Linear forward
        output, x_local, w = BasicLinear._functional_forward(
            input=input_,
            weight=self.weight,
            dtype=dtype,
            tensor_parallel_mode=self.tensor_parallel_mode,
            tensor_parallel_group=self.tensor_parallel_group,
            sequence_parallel=self.sequence_parallel,
            with_quantized_compute=with_quantized_compute,
            backward_override=backward_override,
            input_quantizer=input_quantizer,
            weight_quantizer=weight_quantizer,
            output_quantizer=output_quantizer,
            input_requires_grad=input_requires_grad,
            weight_requires_grad=weight_requires_grad,
        )

        # Save state for backward pass
        if ctx.requires_grad:
            if backward_override == "high_precision":
                # Keep the original operands rather than the tensors
                # returned by the forward GEMM, and only the ones that
                # the required grads depend on.
                saved_input = input_ if weight_requires_grad else None
                saved_weight = self.weight if input_requires_grad else None
            else:
                saved_input = x_local
                saved_weight = w
            if is_cpu_offload_enabled():
                # No special CPU offloading logic is needed for weights. saved_weight is
                # either self.weight (nn.Parameter, auto-excluded from offload) or a
                # workspace freshly created each forward pass.
                mark_activation_offload(saved_input)
            ctx.save_for_backward(saved_input, saved_weight)
            # Backward only uses quantized compute when no backward
            # override is configured in the recipe.
            ctx.with_quantized_compute = with_quantized_compute and backward_override is None
            ctx.backward_override = backward_override
            ctx.input_quantizer = input_quantizer
            ctx.weight_quantizer = weight_quantizer
            ctx.grad_output_quantizer = grad_output_quantizer
            ctx.grad_input_quantizer = grad_input_quantizer
            ctx.dtype = dtype
            ctx.input_requires_grad = input_requires_grad
            ctx.weight_requires_grad = weight_requires_grad

        return output

    def op_backward(
        self,
        ctx: OperationContext,
        grad_output: torch.Tensor,
    ) -> tuple[torch.Tensor, Iterable[Optional[torch.Tensor]]]:
        """Apply the linear backward pass.

        Parameters
        ----------
        ctx: OperationContext
            Context holding the tensors and settings saved by
            ``op_forward``.
        grad_output: torch.Tensor
            Gradient w.r.t. the forward output.

        Returns
        -------
        torch.Tensor
            Input grad from ``BasicLinear._functional_backward``.
        Iterable of torch.Tensor or None
            Single-entry list with the weight grad. When the grad was
            accumulated into the param's main grad, the entry is the
            dummy tensor from ``get_dummy_wgrads_for_params``.

        """
        # Tensors stashed by the forward pass
        saved_input, saved_weight = ctx.saved_tensors

        # Megatron-LM wgrad fusion
        # Note: When enabled, the wgrad GEMM is pointed at the param's
        # main grad tensor so it can write into it directly.
        fuse_into_main_grad = self._accumulate_into_main_grad
        wgrad_buffer = None
        accumulate_wgrad = False
        if ctx.weight_requires_grad and fuse_into_main_grad:
            param = self.weight
            param_main_grad = get_main_grad_from_param(param, op_label="BasicLinear")
            accumulate_wgrad = get_accumulate_flag_in_param(param)
            wgrad_buffer = param_main_grad.detach()

        # Compute grads
        backward_kwargs = dict(
            grad_output=grad_output,
            input=saved_input,
            weight=saved_weight,
            input_requires_grad=ctx.input_requires_grad,
            weight_requires_grad=ctx.weight_requires_grad,
            dtype=ctx.dtype,
            grad_weight=wgrad_buffer,
            accumulate_into_grad_weight=accumulate_wgrad,
            tensor_parallel_mode=self.tensor_parallel_mode,
            tensor_parallel_group=self.tensor_parallel_group,
            sequence_parallel=self.sequence_parallel,
            with_quantized_compute=ctx.with_quantized_compute,
            input_quantizer=ctx.input_quantizer,
            weight_quantizer=ctx.weight_quantizer,
            grad_output_quantizer=ctx.grad_output_quantizer,
            grad_input_quantizer=ctx.grad_input_quantizer,
        )
        dgrad, wgrad = BasicLinear._functional_backward(**backward_kwargs)

        # The saved input is no longer needed once grads are computed
        clear_tensor_data(saved_input)

        # Megatron-LM wgrad fusion
        # Note: The real wgrad was accumulated into the main grad, so
        # hand back a dummy tensor in its place.
        if accumulate_wgrad:
            wgrad = get_dummy_wgrads_for_params([self.weight])[0]

        return dgrad, [wgrad]