"""Fusible operation for linear layer without bias."""
from __future__ import annotations
from collections.abc import Callable, Iterable
import contextlib
import math
from typing import Any, Optional
import torch
from ...constants import TensorUsage
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...distributed import (
CudaRNGStatesTracker,
gather_along_first_dim,
reduce_scatter_along_first_dim,
)
from ...quantization import FP8GlobalStateManager, Recipe
from ...tensor import Quantizer
from ...tensor.float8_tensor import Float8Quantizer
from ...tensor.storage.float8_tensor_storage import Float8TensorStorage
from ...utils import (
canonicalize_device,
canonicalize_dtype,
clear_tensor_data,
devices_match,
)
from ..gemm import general_gemm
from ..gemm import general_gemm_add
from ..op import BasicOperation, OperationContext
from .._common import (
get_accumulate_flag_in_param,
get_dummy_wgrads_for_params,
get_main_grad_from_param,
is_quantized_tensor,
maybe_dequantize,
)
def _wait_async(handle: Optional[Any]) -> None:
"""Wait for asynchronous communication to finish, if needed"""
if handle is not None:
handle.wait()
def _apply_gemm_options(
result: torch.Tensor,
*,
alpha: Optional[float],
beta: Optional[float],
accumulate: bool,
out: Optional[torch.Tensor],
) -> torch.Tensor:
"""Apply the small subset of TE GEMM output options used by BasicLinear."""
if alpha not in (None, 1.0):
result = result * alpha
if out is None:
return result
if accumulate:
if beta not in (None, 1.0):
out.mul_(beta)
out.add_(result)
else:
out.copy_(result)
return out
class BasicLinear(BasicOperation):
"""Apply linear transformation: :math:`y = x A^T`."""
def __init__(
    self,
    in_features: int,
    out_features: int,
    *,
    device: Optional[torch.device | str] = None,
    dtype: Optional[torch.dtype] = None,
    tensor_parallel_mode: Optional[str] = None,
    tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
    sequence_parallel: bool = False,
    rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] = None,
    accumulate_into_main_grad: bool = False,
    userbuffers_options: Optional[dict[str, Any]] = None,
) -> None:
    """Create the operation and allocate its (possibly sharded) weight.

    ``in_features`` and ``out_features`` are the global weight dimensions.
    The locally stored weight has shape
    ``(local_out_features, local_in_features)``, as determined by
    ``_canonicalize_tensor_parallelism``. Unless the weight is allocated on
    the meta device, its values are initialized right away via
    ``reset_parameters``.

    Raises
    ------
    ValueError
        If the canonicalized dtype is not float32, float16, or bfloat16.
    """
    super().__init__()

    # Global weight dimensions
    self.in_features: int = in_features
    self.out_features: int = out_features

    # Resolve where and in which precision the weight is allocated
    device = canonicalize_device(device)
    dtype = canonicalize_dtype(dtype)
    supported_dtypes = (torch.float32, torch.float16, torch.bfloat16)
    if dtype not in supported_dtypes:
        raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")

    # Tensor-parallel layout and the resulting local weight dimensions
    (
        tp_mode,
        tp_group,
        tp_size,
        tp_sequence_parallel,
        local_in_features,
        local_out_features,
    ) = self._canonicalize_tensor_parallelism(
        mode=tensor_parallel_mode,
        process_group=tensor_parallel_group,
        sequence_parallel=sequence_parallel,
        in_features=in_features,
        out_features=out_features,
    )
    self.tensor_parallel_mode: Optional[str] = tp_mode
    self.tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = tp_group
    self.tensor_parallel_size: int = tp_size
    self.sequence_parallel: bool = tp_sequence_parallel
    self.local_in_features: int = local_in_features
    self.local_out_features: int = local_out_features

    # Whether the weight is stored in quantized form. The recipe state is
    # set up first because reset_parameters looks up the weight quantizer
    # when this flag is set.
    self._with_quantized_weight: bool = FP8GlobalStateManager.with_fp8_parameters()
    if self._with_quantized_weight:
        self.reset_recipe_state(recipe=FP8GlobalStateManager.get_fp8_recipe())

    # Allocate the local weight shard (values are filled in below)
    weight = torch.nn.Parameter(
        torch.empty(
            self.local_out_features,
            self.local_in_features,
            device=device,
            dtype=dtype,
        )
    )
    self.weight: torch.nn.Parameter
    self.register_parameter("weight", weight)

    # Optional factory for the RNG tracker used during initialization
    self._rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]]
    self._rng_state_tracker_function = rng_state_tracker_function

    # Meta-device weights are initialized later, before the first forward
    if weight.device.type != "meta":
        self.reset_parameters()

    # Whether the weight gradient is accumulated into the param's main grad
    self._accumulate_into_main_grad: bool = accumulate_into_main_grad

    # Stored as given; not otherwise used in this constructor
    self._userbuffers_options: Optional[dict[str, Any]] = userbuffers_options
@classmethod
def _canonicalize_tensor_parallelism(
cls,
*,
mode: Optional[str],
process_group: Optional[torch.distributed.ProcessGroup],
sequence_parallel: bool,
in_features: int,
out_features: int,
) -> tuple[
Optional[str],
Optional[torch.distributed.ProcessGroup],
int,
bool,
int,
int,
]:
"""Check configuration for tensor parallelism."""
if mode is None:
group_size = 1
else:
group_size = torch.distributed.get_world_size(process_group)
if group_size == 1:
mode = None
process_group = None
sequence_parallel = False
local_in_features = in_features
local_out_features = out_features
if mode is None:
pass
elif mode == "column":
if out_features % group_size != 0:
raise ValueError(
"Invalid configuration for tensor parallelism "
f"({mode=}, {out_features=}, {group_size=})"
)
local_out_features //= group_size
elif mode == "row":
if in_features % group_size != 0:
raise ValueError(
"Invalid configuration for tensor parallelism "
f"({mode=}, {in_features=}, {group_size=})"
)
local_in_features //= group_size
else:
raise ValueError(
"Supported modes for tensor parallelism are "
f'`None`, "row", and "column" (got {mode=})'
)
return (
mode,
process_group,
group_size,
sequence_parallel,
local_in_features,
local_out_features,
)
def num_quantizers(self, mode: str) -> int:
    """Return how many quantizers this operation needs for ``mode``.

    The forward pass needs two (this class looks up forward indices 0 and 1
    for the input and the weight) and the backward pass needs one (the grad
    output). Any other mode needs none.
    """
    quantizers_per_mode = {"forward": 2, "backward": 1}
    return quantizers_per_mode.get(mode, 0)
def reset_parameters(self) -> None:
    """Materialize the weight if needed and initialize its values.

    The weight is initialized with ``kaiming_uniform_`` (``a=sqrt(5)``),
    inside the forked RNG context when an RNG tracker function was
    provided. If the operation stores a quantized weight, the freshly
    initialized tensor is quantized before being stored as the parameter.

    Raises
    ------
    RuntimeError
        If a quantized weight is required but no weight quantizer exists.
    """
    param = self.weight

    # A meta-device weight is materialized on the device that
    # canonicalize_device(None) resolves to
    target_device = param.device
    if target_device.type == "meta":
        target_device = canonicalize_device(None)

    # Initialize into a plain buffer when the current weight is quantized or
    # lives on a different device than the target
    if is_quantized_tensor(param):
        param = torch.empty(
            param.size(),
            dtype=param.dtype,
            device=target_device,
        )
    elif not devices_match(param.device, target_device):
        param = torch.empty_like(param, device=target_device)

    # Fill in values, under the tracked RNG state when one is configured
    tracker_fn = self._rng_state_tracker_function
    rng_context = contextlib.nullcontext() if tracker_fn is None else tracker_fn().fork()
    with rng_context:
        torch.nn.init.kaiming_uniform_(param, a=math.sqrt(5))

    # Quantize if the weight is stored in quantized form
    if self._with_quantized_weight:
        weight_quantizer = self.get_quantizer("forward", 1)
        if weight_quantizer is None:
            raise RuntimeError(
                "Tried to quantize weight with deferred initialization "
                "due to meta device, but no quantizer was available. "
                "This is most likely because the weight was initialized "
                "within quantized_model_init, but the forward pass was not "
                "performed within autocast."
            )
        # Column-wise data is only requested when grads are enabled
        weight_quantizer.set_usage(
            rowwise=True,
            columnwise=torch.is_grad_enabled(),
        )
        weight_quantizer.internal = False
        with torch.no_grad():
            param = weight_quantizer(param)

    # Store the result as the module parameter
    if not isinstance(param, torch.nn.Parameter):
        param = torch.nn.Parameter(param)
    self.weight = param
def pre_first_fuser_forward(self) -> None:
    """Finish deferred initialization before the first fused forward pass.

    A weight that is still on the meta device has not been initialized yet,
    so it is materialized and initialized here.
    """
    super().pre_first_fuser_forward()
    weight_is_deferred = self.weight.device.type == "meta"
    if weight_is_deferred:
        self.reset_parameters()
def pre_fuser_forward(self, *, requires_grad: bool) -> None:
    """Configure quantizer usages ahead of a fused forward pass.

    Nothing is configured unless FP8 is enabled. Row-wise usage is always
    requested. Column-wise usage for the input and grad output quantizers
    is requested only when the weight gradient is needed and the recipe has
    no backward override; for the weight quantizer it follows
    ``requires_grad``.
    """
    super().pre_fuser_forward(requires_grad=requires_grad)
    if not FP8GlobalStateManager.is_fp8_enabled():
        return

    # Column-wise usage for the tensors that feed the weight-grad GEMM
    weight_requires_grad = requires_grad and self.weight.requires_grad
    backward_override = FP8GlobalStateManager.get_fp8_recipe().backward_override
    columnwise_usage = weight_requires_grad if backward_override is None else False

    # Look up the three quantizers, then configure them
    input_q = self.get_quantizer("forward", 0)
    weight_q = self.get_quantizer("forward", 1)
    grad_output_q = self.get_quantizer("backward", 0)
    input_q.set_usage(rowwise=True, columnwise=columnwise_usage)
    weight_q.set_usage(rowwise=True, columnwise=requires_grad)
    grad_output_q.set_usage(rowwise=True, columnwise=columnwise_usage)
def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None:
    """Reset quantization state for ``recipe`` and apply linear-specific settings.

    After delegating to the base class, this configures the input, weight
    and grad output quantizers (internal-tensor flags, GEMM optimization
    hints, current-scaling options) and, if the weight is already a
    quantized tensor, refreshes the quantizer attached to it.

    Instance attributes are read with ``getattr`` and a default because the
    constructor calls this method before ``weight`` has been registered,
    and the method already treats the other attributes it reads as
    possibly unset.
    """
    super().reset_recipe_state(recipe=recipe)

    # Read the parallel configuration once, tolerating missing attributes
    # in the same way as the rest of this method
    tensor_parallel_mode = getattr(self, "tensor_parallel_mode", None)
    sequence_parallel = getattr(self, "sequence_parallel", False)

    # Input and grad output quantizers produce internal tensors. The GEMM
    # optimization hint is skipped for the tensor that is all-gathered
    # under sequence parallelism (input for column mode, grad output for
    # row mode).
    input_quantizer = self.get_quantizer("forward", 0)
    grad_output_quantizer = self.get_quantizer("backward", 0)
    if input_quantizer is not None:
        input_quantizer.internal = True
        if not (tensor_parallel_mode == "column" and sequence_parallel):
            input_quantizer.optimize_for_gemm = True
    if grad_output_quantizer is not None:
        grad_output_quantizer.internal = True
        if not (tensor_parallel_mode == "row" and sequence_parallel):
            grad_output_quantizer.optimize_for_gemm = True

    # The weight quantizer produces internal tensors only when the
    # quantized weight is not stored as the module parameter
    weight_quantizer = self.get_quantizer("forward", 1)
    weight = getattr(self, "weight", None)
    if weight_quantizer is not None:
        weight_quantizer.internal = not (
            FP8GlobalStateManager.with_fp8_parameters()
            or getattr(self, "_with_quantized_weight", False)
            or is_quantized_tensor(weight)
        )

    # Recipe-specific configuration
    if recipe is not None:
        if recipe.float8_current_scaling():
            input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
            input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
            weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale
            weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_weight.amax_epsilon
            grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
            grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_bwd_grad.amax_epsilon
            # With sequence parallelism, the amax of the sharded tensor is
            # reduced over the tensor-parallel group
            if sequence_parallel:
                if tensor_parallel_mode == "column":
                    input_quantizer.with_amax_reduction = True
                    input_quantizer.amax_reduction_group = self.tensor_parallel_group
                elif tensor_parallel_mode == "row":
                    grad_output_quantizer.with_amax_reduction = True
                    grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group

    # Keep an already-quantized weight in sync with the new quantizer,
    # carrying over the usages of the quantizer it had before
    if weight_quantizer is not None and is_quantized_tensor(weight):
        if weight._quantizer is not None:
            weight_quantizer.set_usage(
                rowwise=weight._quantizer.rowwise_usage,
                columnwise=weight._quantizer.columnwise_usage,
            )
        weight.update_quantizer(weight_quantizer.copy())
@staticmethod
def _functional_forward(
    input: torch.Tensor,
    weight: torch.Tensor,
    *,
    alpha: float = 1.0,
    bias: Optional[torch.Tensor] = None,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    out: Optional[torch.Tensor] = None,
    beta: Optional[float] = None,
    accumulate_into_out: bool = False,
    tensor_parallel_mode: Optional[str] = None,
    tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
    sequence_parallel: bool = False,
    with_quantized_compute: bool = False,
    backward_override: Optional[str] = None,
    input_quantizer: Optional[Quantizer] = None,
    weight_quantizer: Optional[Quantizer] = None,
    output_quantizer: Optional[Quantizer] = None,
    input_requires_grad: bool = True,
    weight_requires_grad: bool = True,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Functional API for forward pass.

    Parameters
    ----------
    input:
        Input tensor.
    weight:
        Weight tensor.
    alpha:
        Scale applied to the GEMM result (skipped for ``None`` and ``1.0``).
    bias:
        Forwarded to ``general_gemm``.
    device:
        Not referenced in this function body.
    dtype:
        Compute dtype. When ``None`` it is taken from ``out`` if provided,
        otherwise from ``weight``.
    out:
        Optional output buffer.
    beta:
        Scale applied to the existing contents of ``out`` when accumulating.
    accumulate_into_out:
        Add the result into ``out`` instead of overwriting it.
    tensor_parallel_mode, tensor_parallel_group, sequence_parallel:
        Tensor-parallel configuration. Column mode with sequence parallelism
        all-gathers the input along its first dim; row mode reduce-scatters
        (with sequence parallelism) or all-reduces the output.
    with_quantized_compute:
        Quantize the input and weight before the GEMM.
    backward_override:
        When not ``None``, column-wise quantized data is not requested and
        the tensors returned for backward are not switched to column-wise
        usage.
    input_quantizer, weight_quantizer, output_quantizer:
        Quantizers for the respective tensors.
    input_requires_grad, weight_requires_grad:
        Which gradients the backward pass will need.

    Returns
    -------
    tuple
        ``(y, x_local, w)``: the output, the local input to keep for
        backward (``None`` if ``weight_requires_grad`` is false), and the
        weight to keep for backward (``None`` if ``input_requires_grad`` is
        false).

    Raises
    ------
    ValueError
        On an unsupported or mismatched dtype, a missing quantizer, or an
        unsupported combination of output options.
    RuntimeError
        If the output quantizer is not a ``Float8Quantizer``.
    """
    # Check datatype
    if dtype is None:
        if out is not None and isinstance(out, torch.Tensor):
            dtype = out.dtype
        elif weight is not None and isinstance(weight, torch.Tensor):
            dtype = weight.dtype
        else:
            raise ValueError(
                "Could not infer dtype from weight nor out and dtype was not provided"
            )
    if dtype not in (torch.float32, torch.float16, torch.bfloat16):
        raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")
    if out is not None and out.dtype != dtype:
        raise ValueError(f"Output tensor has invalid dtype (expected {dtype}, got {out.dtype})")

    # Prepare input tensor
    # x_local is the local shard (kept for backward); x is the GEMM operand,
    # which is the gathered tensor when an all-gather is needed.
    x_local = input
    x = None
    x_async = None
    with_x_all_gather = tensor_parallel_mode == "column" and sequence_parallel
    if with_quantized_compute:
        if input_quantizer is None:
            raise ValueError("Missing quantizer for input tensor")
        input_quantizer.set_usage(
            rowwise=True,
            columnwise=weight_requires_grad and backward_override is None,
        )
        if with_x_all_gather:
            # The gathered tensor is only used for the forward GEMM, so
            # column-wise data is not requested for it
            input_quantizer.set_usage(columnwise=False)
            x, x_async = gather_along_first_dim(
                x_local,
                tensor_parallel_group,
                async_op=True,
                quantizer=input_quantizer,
            )
        else:
            if not is_quantized_tensor(x_local):
                x_local = input_quantizer(x_local)
            x = x_local
    else:
        x_local = maybe_dequantize(x_local, dtype)
        if with_x_all_gather:
            x, x_async = gather_along_first_dim(
                x_local,
                tensor_parallel_group,
                async_op=True,
            )
        else:
            x = x_local

    # Prepare weight tensor
    w = weight
    if not with_quantized_compute:
        w = maybe_dequantize(w, dtype)
    elif with_quantized_compute and not is_quantized_tensor(w):
        if weight_quantizer is None:
            raise ValueError("Missing quantizer for weight tensor")
        weight_quantizer.set_usage(
            rowwise=True,
            columnwise=input_requires_grad and backward_override is None,
        )
        w = weight_quantizer(w)

    # Check output tensor and decide whether an output quantizer applies
    y = out
    if y is None:
        if not with_quantized_compute:
            output_quantizer = None
        if tensor_parallel_mode == "row":
            output_quantizer = None
    elif is_quantized_tensor(y):
        if not with_quantized_compute:
            raise ValueError("Output tensor is quantized, but quantized compute is not enabled")
        if tensor_parallel_mode == "row":
            raise ValueError(
                "Output tensor is quantized, "
                "but row tensor parallelism does not support quantized output"
            )
        if output_quantizer is None:
            output_quantizer = getattr(y, "_quantizer", None)
        if output_quantizer is None:
            raise ValueError("Output tensor is quantized, but quantizer was not provided")
    else:
        output_quantizer = None
    # NOTE(review): output_quantizer is validated and configured here but is
    # not passed to general_gemm below - confirm this is intended.
    if output_quantizer is not None:
        if not isinstance(output_quantizer, Float8Quantizer):
            raise RuntimeError(
                "Attempting to generate quantized output tensor with unsupported quantizer"
            )
        output_quantizer.set_usage(rowwise=True, columnwise=False)

    # Check if accumulating into output tensor
    if accumulate_into_out:
        if y is None:
            raise ValueError(
                "Attempted to accumulate into output tensor without providing output tensor"
            )
        if tensor_parallel_mode == "row":
            raise ValueError(
                "Accumulating into output tensor is not supported with row tensor parallelism"
            )

    # Make sure the input all-gather has finished before the GEMM
    _wait_async(x_async)
    x_async = None

    # Perform GEMM, then apply scaling / output-buffer options
    y = general_gemm(
        x,
        w,
        usage_a=TensorUsage.LHS,
        usage_b=TensorUsage.RHS_TRANS,
        out_dtype=dtype,
        bias=bias,
    )
    y = _apply_gemm_options(
        y,
        alpha=alpha,
        beta=beta,
        accumulate=accumulate_into_out,
        out=out,
    )

    # Reduce tensor-parallel output if needed
    if tensor_parallel_mode == "row":
        if sequence_parallel:
            y, _ = reduce_scatter_along_first_dim(y, tensor_parallel_group)
        else:
            torch.distributed.all_reduce(y, group=tensor_parallel_group)

    # Prepare weight tensor for backward pass
    # Only a weight that was quantized in this call (w is not weight) is
    # switched to column-wise usage; the caller's tensor is left as-is.
    if input_requires_grad:
        if (
            w is not weight
            and with_quantized_compute
            and is_quantized_tensor(w)
            and backward_override is None
        ):
            w.update_usage(rowwise_usage=False, columnwise_usage=True)
    else:
        w = None

    # Prepare input tensor for backward pass
    if weight_requires_grad:
        if (
            with_quantized_compute
            and is_quantized_tensor(x_local)
            and backward_override is None
        ):
            # A Float8TensorStorage input that will be all-gathered keeps
            # its current usage
            if not (isinstance(x_local, Float8TensorStorage) and with_x_all_gather):
                x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
    else:
        x_local = None

    return y, x_local, w
@staticmethod
def _functional_backward(
    grad_output: torch.Tensor,
    input: Optional[torch.Tensor],
    weight: Optional[torch.Tensor],
    *,
    grad_input_alpha: Optional[float] = None,
    input_requires_grad: bool = True,
    grad_weight_alpha: Optional[float] = None,
    weight_requires_grad: bool = True,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    grad_weight: Optional[torch.Tensor] = None,
    grad_weight_beta: Optional[float] = None,
    accumulate_into_grad_weight: bool = False,
    grad_input: Optional[torch.Tensor] = None,
    grad_input_beta: Optional[float] = None,
    accumulate_into_grad_input: bool = False,
    tensor_parallel_mode: Optional[str] = None,
    tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
    sequence_parallel: bool = False,
    with_quantized_compute: bool = False,
    input_quantizer: Optional[Quantizer] = None,
    weight_quantizer: Optional[Quantizer] = None,
    grad_output_quantizer: Optional[Quantizer] = None,
    grad_input_quantizer: Optional[Quantizer] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Functional API for backward pass.

    Parameters
    ----------
    grad_output:
        Gradient with respect to the forward output.
    input:
        Forward input. Required when ``weight_requires_grad`` is set.
    weight:
        Weight tensor. Required when ``input_requires_grad`` is set.
    grad_input_alpha, grad_weight_alpha:
        Scales applied to the respective GEMM results (skipped for ``None``
        and ``1.0``).
    input_requires_grad, weight_requires_grad:
        Which gradients to compute.
    device:
        Not referenced in this function body.
    dtype:
        Compute dtype. When ``None`` it is taken from ``weight`` if that is
        a tensor, otherwise from ``grad_output``.
    grad_weight, grad_input:
        Optional output buffers for the respective gradients.
    grad_weight_beta, grad_input_beta:
        Scales applied to the existing buffer contents when accumulating.
    accumulate_into_grad_weight, accumulate_into_grad_input:
        Add the result into the buffer instead of overwriting it.
    tensor_parallel_mode, tensor_parallel_group, sequence_parallel:
        Tensor-parallel configuration. Row mode with sequence parallelism
        all-gathers the grad output; column mode with sequence parallelism
        all-gathers the input and reduce-scatters the grad input; column
        mode without it all-reduces the grad input.
    with_quantized_compute:
        Quantize GEMM operands.
    input_quantizer, weight_quantizer, grad_output_quantizer, grad_input_quantizer:
        Quantizers for the respective tensors.

    Returns
    -------
    tuple
        ``(dx, dw)``: the input gradient (``None`` if
        ``input_requires_grad`` is false) and the weight gradient (``None``
        if ``weight_requires_grad`` is false).

    Raises
    ------
    ValueError
        On an unsupported dtype, a missing tensor or quantizer, or an
        unsupported combination of output options.
    RuntimeError
        If the grad input quantizer is not a ``Float8Quantizer``.
    """
    # Check datatype
    if dtype is None:
        if isinstance(weight, torch.Tensor):
            dtype = weight.dtype
        elif isinstance(grad_output, torch.Tensor):
            dtype = grad_output.dtype
    dtype = canonicalize_dtype(dtype)
    if dtype not in (torch.float32, torch.float16, torch.bfloat16):
        raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")

    # Prepare grad output tensor
    # dy_local is the local shard; dy is the GEMM operand, which is the
    # gathered tensor when an all-gather is needed.
    dy_local = grad_output
    dy = None
    dy_async = None
    with_dy_all_gather = tensor_parallel_mode == "row" and sequence_parallel
    if with_quantized_compute:
        if grad_output_quantizer is None:
            raise ValueError("Missing quantizer for grad output tensor")
        # Row-wise data feeds the dgrad GEMM, column-wise the wgrad GEMM
        grad_output_quantizer.set_usage(
            rowwise=input_requires_grad,
            columnwise=weight_requires_grad,
        )
        if with_dy_all_gather:
            dy, dy_async = gather_along_first_dim(
                dy_local,
                tensor_parallel_group,
                async_op=True,
                quantizer=grad_output_quantizer,
            )
        else:
            if not is_quantized_tensor(dy_local):
                dy_local = grad_output_quantizer(dy_local)
            else:
                dy_local.update_usage(
                    rowwise_usage=input_requires_grad,
                    columnwise_usage=weight_requires_grad,
                )
            dy = dy_local
    else:
        dy_local = maybe_dequantize(dy_local, dtype)
        if with_dy_all_gather:
            dy, dy_async = gather_along_first_dim(
                dy_local,
                tensor_parallel_group,
                async_op=True,
            )
        else:
            dy = dy_local

    # Prepare input tensor (only needed for the weight gradient)
    x = None
    x_async = None
    if weight_requires_grad:
        if input is None:
            raise ValueError("Input tensor is required to compute weight grad")
        x_local = input
        with_x_all_gather = tensor_parallel_mode == "column" and sequence_parallel
        if with_quantized_compute:
            if input_quantizer is None:
                raise ValueError("Missing quantizer for input tensor")
            input_quantizer.set_usage(rowwise=False, columnwise=True)
            if with_x_all_gather:
                x, x_async = gather_along_first_dim(
                    x_local,
                    tensor_parallel_group,
                    async_op=True,
                    quantizer=input_quantizer,
                )
            else:
                if is_quantized_tensor(x_local):
                    x_local.update_usage(columnwise_usage=True)
                else:
                    x_local = input_quantizer(x_local)
                x = x_local
        else:
            x_local = maybe_dequantize(x_local, dtype)
            if with_x_all_gather:
                x, x_async = gather_along_first_dim(
                    x_local,
                    tensor_parallel_group,
                    async_op=True,
                )
            else:
                x = x_local

    # Compute grad input
    dx = None
    dx_async = None
    if input_requires_grad:

        # Prepare weight tensor
        if weight is None:
            raise ValueError("Weight tensor is required to compute input grad")
        w = weight
        if with_quantized_compute:
            if is_quantized_tensor(w):
                w.update_usage(columnwise_usage=True)
            else:
                if weight_quantizer is None:
                    raise ValueError("Missing quantizer for weight tensor")
                weight_quantizer.set_usage(columnwise=True)
                w = weight_quantizer(w)
        else:
            w = maybe_dequantize(w, dtype)

        # Make sure the grad output all-gather has finished
        _wait_async(dy_async)
        dy_async = None

        # Check grad input tensor and decide whether a quantizer applies
        dx = grad_input
        if dx is None:
            if not with_quantized_compute:
                grad_input_quantizer = None
            if tensor_parallel_mode == "column":
                grad_input_quantizer = None
        elif is_quantized_tensor(dx):
            if not with_quantized_compute:
                raise ValueError(
                    "Grad input tensor is quantized, but quantized compute is not enabled"
                )
            if tensor_parallel_mode == "column":
                raise ValueError(
                    "Grad input tensor is quantized, "
                    "but column tensor parallelism does not support quantized grad input"
                )
            if grad_input_quantizer is None:
                grad_input_quantizer = getattr(dx, "_quantizer", None)
            if grad_input_quantizer is None:
                raise ValueError(
                    "Grad input tensor is quantized, but quantizer was not provided"
                )
        else:
            grad_input_quantizer = None
        # NOTE(review): grad_input_quantizer is validated here but is not
        # passed to general_gemm below - confirm this is intended.
        if grad_input_quantizer is not None:
            if not isinstance(grad_input_quantizer, Float8Quantizer):
                raise RuntimeError(
                    "Attempting to generate quantized grad input tensor "
                    "with unsupported quantizer"
                )

        # Check if accumulating into grad input tensor
        if accumulate_into_grad_input:
            if dx is None:
                raise ValueError(
                    "Attempted to accumulate into grad input tensor "
                    "without providing grad input tensor"
                )
            if tensor_parallel_mode == "column":
                raise ValueError(
                    "Accumulating into grad input tensor "
                    "is not supported with column tensor parallelism"
                )

        # Perform dgrad GEMM, then apply scaling / output-buffer options
        dx = general_gemm(
            dy,
            w,
            usage_a=TensorUsage.LHS,
            usage_b=TensorUsage.RHS,
            out_dtype=dtype,
        )
        dx = _apply_gemm_options(
            dx,
            alpha=grad_input_alpha,
            beta=grad_input_beta,
            accumulate=accumulate_into_grad_input,
            out=grad_input,
        )

        # Reduce tensor-parallel grad input if needed
        # The reduction is launched asynchronously and waited on at the end
        # of this function, after the wgrad GEMM has been issued.
        if tensor_parallel_mode == "column":
            if sequence_parallel:
                dx, dx_async = reduce_scatter_along_first_dim(
                    dx,
                    tensor_parallel_group,
                    async_op=True,
                )
            else:
                dx_async = torch.distributed.all_reduce(
                    dx,
                    group=tensor_parallel_group,
                    async_op=True,
                )

    # Compute grad weight
    dw = None
    if weight_requires_grad:

        # Make sure the input and grad output all-gathers have finished
        _wait_async(x_async)
        _wait_async(dy_async)
        x_async = None
        dy_async = None

        # Check grad weight tensor
        # A provided buffer determines the output dtype of the wgrad GEMM
        dw = grad_weight
        dw_dtype = dtype
        if dw is None:
            if accumulate_into_grad_weight:
                raise ValueError(
                    "Attempted to accumulate into grad weight tensor "
                    "without providing grad weight tensor"
                )
        else:
            dw_dtype = dw.dtype

        # Perform wgrad GEMM
        if (
            accumulate_into_grad_weight
            and grad_weight_alpha in (None, 1.0)
            and grad_weight_beta in (None, 1.0)
        ):
            # Unscaled accumulation uses general_gemm_add; its return value
            # is not used, so it presumably updates dw in place - TODO confirm
            general_gemm_add(
                dw,
                dy,
                x,
                usage_a=TensorUsage.LHS_TRANS,
                usage_b=TensorUsage.RHS,
                out_dtype=dw_dtype,
            )
        else:
            dw = general_gemm(
                dy,
                x,
                usage_a=TensorUsage.LHS_TRANS,
                usage_b=TensorUsage.RHS,
                out_dtype=dw_dtype,
            )
            dw = _apply_gemm_options(
                dw,
                alpha=grad_weight_alpha,
                beta=grad_weight_beta,
                accumulate=accumulate_into_grad_weight,
                out=grad_weight,
            )

    # Wait for any communication that is still outstanding
    _wait_async(dy_async)
    _wait_async(x_async)
    _wait_async(dx_async)

    return dx, dw
def op_forward(
    self,
    ctx: OperationContext,
    input_: torch.Tensor,
    *,
    prev_op_grad_output_quantizer: Optional[Quantizer],
    next_op_input_quantizer: Optional[Quantizer],
    **kwargs: Any,
) -> torch.Tensor:
    """Run the forward pass and stash what the backward pass needs on ``ctx``.

    Parameters
    ----------
    ctx:
        Operation context. Its ``requires_grad`` flag controls whether
        state is saved for backward.
    input_:
        Input tensor.
    prev_op_grad_output_quantizer:
        Used as this operation's grad input quantizer in backward.
    next_op_input_quantizer:
        Used as this operation's output quantizer.
    **kwargs:
        Accepted for interface compatibility; not used here.

    Returns
    -------
    torch.Tensor
        Output of the linear transformation.
    """
    # Check which grads are required
    input_requires_grad = ctx.requires_grad
    weight_requires_grad = ctx.requires_grad and self.weight.requires_grad

    # Quantizers
    input_quantizer = self.get_quantizer("forward", 0)
    weight_quantizer = self.get_quantizer("forward", 1)
    output_quantizer = next_op_input_quantizer
    grad_output_quantizer = self.get_quantizer("backward", 0)
    grad_input_quantizer = prev_op_grad_output_quantizer
    with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
    if with_quantized_compute:
        backward_override = FP8GlobalStateManager.get_fp8_recipe().backward_override
    else:
        backward_override = None

    # Compute dtype
    # The autocast state must be queried for the same device type whose
    # autocast dtype is read. Without an argument, is_autocast_enabled
    # reports the legacy default device type instead.
    autocast_device_type = "npu"
    if torch.is_autocast_enabled(autocast_device_type):
        dtype = torch.get_autocast_dtype(autocast_device_type)
    else:
        dtype = self.weight.dtype

    # Linear forward
    output, x_local, w = BasicLinear._functional_forward(
        input=input_,
        weight=self.weight,
        dtype=dtype,
        tensor_parallel_mode=self.tensor_parallel_mode,
        tensor_parallel_group=self.tensor_parallel_group,
        sequence_parallel=self.sequence_parallel,
        with_quantized_compute=with_quantized_compute,
        backward_override=backward_override,
        input_quantizer=input_quantizer,
        weight_quantizer=weight_quantizer,
        output_quantizer=output_quantizer,
        input_requires_grad=input_requires_grad,
        weight_requires_grad=weight_requires_grad,
    )

    # Save state for backward pass
    if ctx.requires_grad:
        if backward_override == "high_precision":
            # Keep the original tensors rather than the ones prepared by
            # the forward pass
            saved_input = input_ if weight_requires_grad else None
            saved_weight = self.weight if input_requires_grad else None
        else:
            saved_input = x_local
            saved_weight = w
        if is_cpu_offload_enabled():
            mark_activation_offload(saved_input)
        ctx.save_for_backward(saved_input, saved_weight)
        # Any backward override disables quantized compute in backward
        ctx.with_quantized_compute = with_quantized_compute and backward_override is None
        ctx.backward_override = backward_override
        ctx.input_quantizer = input_quantizer
        ctx.weight_quantizer = weight_quantizer
        ctx.grad_output_quantizer = grad_output_quantizer
        ctx.grad_input_quantizer = grad_input_quantizer
        ctx.dtype = dtype
        ctx.input_requires_grad = input_requires_grad
        ctx.weight_requires_grad = weight_requires_grad

    return output
def op_backward(
    self,
    ctx: OperationContext,
    grad_output: torch.Tensor,
) -> tuple[torch.Tensor, Iterable[Optional[torch.Tensor]]]:
    """Run the backward pass using the state saved by ``op_forward``.

    When the operation was configured to accumulate into the parameter's
    main grad and the weight gradient is required, the main grad buffer is
    passed as the weight-grad output buffer. If the result was accumulated
    into it, a dummy weight grad is returned in place of the real one.

    Returns
    -------
    tuple
        The input gradient and a one-element list holding the weight
        gradient.
    """
    saved_input, saved_weight = ctx.saved_tensors

    # Decide whether the weight grad goes into the parameter's main grad
    fusion_requested = self._accumulate_into_main_grad
    wgrad_buffer = None
    accumulate_wgrad = False
    if ctx.weight_requires_grad and fusion_requested:
        param = self.weight
        main_grad = get_main_grad_from_param(param, op_label="BasicLinear")
        accumulate_wgrad = get_accumulate_flag_in_param(param)
        wgrad_buffer = main_grad.detach()

    # Linear backward
    grad_input, grad_weight = BasicLinear._functional_backward(
        grad_output=grad_output,
        input=saved_input,
        weight=saved_weight,
        input_requires_grad=ctx.input_requires_grad,
        weight_requires_grad=ctx.weight_requires_grad,
        dtype=ctx.dtype,
        grad_weight=wgrad_buffer,
        accumulate_into_grad_weight=accumulate_wgrad,
        tensor_parallel_mode=self.tensor_parallel_mode,
        tensor_parallel_group=self.tensor_parallel_group,
        sequence_parallel=self.sequence_parallel,
        with_quantized_compute=ctx.with_quantized_compute,
        input_quantizer=ctx.input_quantizer,
        weight_quantizer=ctx.weight_quantizer,
        grad_output_quantizer=ctx.grad_output_quantizer,
        grad_input_quantizer=ctx.grad_input_quantizer,
    )

    # The saved input is no longer needed
    clear_tensor_data(saved_input)

    # The real gradient already lives in the main grad buffer, so hand back
    # a dummy weight grad instead
    if accumulate_wgrad:
        grad_weight = get_dummy_wgrads_for_params([self.weight])[0]

    return grad_input, [grad_weight]