import warnings
import torch
import torch.distributed
try:
import fused_weight_gradient_mlp_cuda
except Exception:
warnings.warn("failed to generate the npu_matmul_add_fp32")
from megatron.core.parallel_state import (
get_global_memory_buffer,
get_tensor_model_parallel_group,
get_tensor_model_parallel_world_size,
)
from megatron.core.utils import prepare_input_tensors_for_wgrad_compute
from megatron.core.tensor_parallel.layers import dist_all_gather_func, dist_reduce_scatter_func
from mindspeed.core.qat.w4a16_fake_quantization import W4A16FakeQuantization
from mindspeed.core.qat.w8a16_fake_quantization import W8A16FakeQuantization
# Module-level aliases for the `apply` entry points of the two fake-quantization
# ops, so the forward wrappers below can pass them around as plain callables.
# NOTE(review): presumably both classes are torch.autograd.Function subclasses;
# their definitions are not visible in this file -- confirm.
w4a16_fakequant_func = W4A16FakeQuantization.apply
w8a16_fakequant_func = W8A16FakeQuantization.apply
def _linear_with_grad_accumulation_and_async_qat_forward(
    ctx,
    input,
    weight,
    bias,
    gradient_accumulation_fusion,
    allreduce_dgrad,
    sequence_parallel,
    grad_output_buffer,
    wgrad_deferral_limit,
    fakequant_func,
):
    """Shared forward pass for the fake-quantized (QAT) linear layers.

    The weight is run through ``fakequant_func`` and the *quantized* result is
    what gets multiplied with the input and saved for the backward pass.

    Args:
        ctx: Autograd-style context. Receives ``save_for_backward(input,
            quant_weight)`` plus the flags the backward helper reads.
        input: Activation tensor. With ``sequence_parallel`` it is all-gathered
            along dim 0 across the tensor-model-parallel group first.
        weight: Weight tensor handed to ``fakequant_func``; the quantized
            result is transposed for the matmul.
        bias: Optional bias added to the output; ``None`` disables it.
        gradient_accumulation_fusion: Stored on ``ctx`` for the backward pass.
        allreduce_dgrad: Stored on ``ctx`` for the backward pass.
        sequence_parallel: When truthy, gather ``input`` before the matmul.
        grad_output_buffer: Stored on ``ctx`` for the backward pass.
        wgrad_deferral_limit: Stored on ``ctx`` for the backward pass.
        fakequant_func: Callable invoked as ``fakequant_func(weight, [1, 32],
            False)``. NOTE(review): the meaning of the two extra arguments is
            defined by the fake-quantization op, which is not visible here.

    Returns:
        ``matmul(total_input, quant_weight.t())``, plus ``bias`` when given.
    """
    quant_weight = fakequant_func(weight, [1, 32], False)
    ctx.save_for_backward(input, quant_weight)

    # Bookkeeping consumed by the matching backward helper.
    has_bias = bias is not None
    ctx.use_bias = has_bias
    ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
    ctx.allreduce_dgrad = allreduce_dgrad
    ctx.sequence_parallel = sequence_parallel
    ctx.wgrad_deferral_limit = wgrad_deferral_limit
    ctx.grad_output_buffer = grad_output_buffer

    if not sequence_parallel:
        total_input = input
    else:
        # Each rank holds a slice along dim 0; rebuild the full input in the
        # shared "mpu" buffer before the matmul.
        tp_world_size = get_tensor_model_parallel_world_size()
        gathered_shape = list(input.size())
        gathered_shape[0] *= tp_world_size
        total_input = get_global_memory_buffer().get_tensor(gathered_shape, input.dtype, "mpu")
        dist_all_gather_func(total_input, input, group=get_tensor_model_parallel_group())

    product = torch.matmul(total_input, quant_weight.t())
    return product + bias if has_bias else product
def _linear_with_grad_accumulation_and_async_qat_backward(ctx, grad_output):
    """Shared backward pass for the fake-quantized (QAT) linear layers.

    Computes the input gradient against the saved quantized weight and,
    unless the weight-gradient computation is deferred, the weight gradient.
    Tensor-parallel collectives are launched asynchronously so they overlap
    with the local matmuls.

    Args:
        ctx: Context populated by the matching forward helper
            (``saved_tensors``, ``use_bias``, ``grad_output_buffer``,
            ``wgrad_deferral_limit``, ``sequence_parallel``,
            ``allreduce_dgrad``, ``gradient_accumulation_fusion``).
        grad_output: Gradient of the loss w.r.t. the forward output.

    Returns:
        A 9-tuple ``(grad_input, grad_weight, grad_bias, None, ..., None)``.
        With sequence parallelism the first entry is the reduce-scattered
        input gradient. ``grad_weight`` is ``None`` or a placeholder tensor
        when gradient-accumulation fusion is active.

    Raises:
        RuntimeError: If fusion is enabled and ``main_grad`` has a dtype other
            than float32, float16 or bfloat16.
    """
    saved_input, quant_weight = ctx.saved_tensors
    use_bias = ctx.use_bias
    deferred_grads = ctx.grad_output_buffer
    deferral_limit = ctx.wgrad_deferral_limit
    seq_parallel = ctx.sequence_parallel
    allreduce_dgrad = ctx.allreduce_dgrad
    fuse_accumulation = ctx.gradient_accumulation_fusion

    pending = None  # handle of the most recently launched async collective

    # With a buffer present, grad_output is stashed for later wgrad
    # computation as long as the limit is 0 (unbounded) or not yet reached.
    defer_wgrad = deferred_grads is not None and (
        deferral_limit == 0 or len(deferred_grads) < deferral_limit
    )
    if defer_wgrad:
        deferred_grads.append(grad_output)
    compute_wgrad = not defer_wgrad

    if compute_wgrad:
        if seq_parallel:
            # Start re-gathering the full input; the dgrad matmul below runs
            # while this collective is in flight.
            tp_world_size = get_tensor_model_parallel_world_size()
            gathered_shape = list(saved_input.size())
            gathered_shape[0] *= tp_world_size
            total_input = get_global_memory_buffer().get_tensor(
                gathered_shape, saved_input.dtype, "mpu"
            )
            pending = dist_all_gather_func(
                total_input, saved_input, group=get_tensor_model_parallel_group(), async_op=True
            )
        else:
            total_input = saved_input

    grad_input = grad_output.matmul(quant_weight)

    if compute_wgrad:
        if seq_parallel:
            pending.wait()
        grad_output, total_input = prepare_input_tensors_for_wgrad_compute(grad_output, total_input)

    if allreduce_dgrad:
        pending = torch.distributed.all_reduce(
            grad_input, group=get_tensor_model_parallel_group(), async_op=True
        )

    if seq_parallel:
        assert not allreduce_dgrad
        sub_grad_input = torch.empty(
            list(saved_input.size()),
            dtype=saved_input.dtype,
            device=torch.cuda.current_device(),
            requires_grad=False,
        )
        pending = dist_reduce_scatter_func(
            sub_grad_input, grad_input, group=get_tensor_model_parallel_group(), async_op=True
        )

    if not fuse_accumulation:
        # NOTE(review): `total_input` is only assigned when wgrad is computed
        # now; a deferred wgrad without fusion would fail here -- confirm that
        # callers never combine the two.
        grad_weight = grad_output.t().matmul(total_input)
    else:
        if compute_wgrad:
            # NOTE(review): `main_grad` is read from the saved fake-quantized
            # tensor, not the original weight; this relies on that attribute
            # being present on the fake-quant output -- confirm.
            accum_dtype = quant_weight.main_grad.dtype
            if accum_dtype == torch.float32:
                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
                    total_input, grad_output, quant_weight.main_grad
                )
            elif accum_dtype in (torch.float16, torch.bfloat16):
                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
                    total_input, grad_output, quant_weight.main_grad
                )
            else:
                raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")

        if hasattr(quant_weight, 'grad_added_to_main_grad'):
            # Hand back a placeholder of main_grad's shape: zero-filled when
            # `zero_out_wgrad` is set, otherwise left uninitialized.
            allocate = torch.zeros if getattr(quant_weight, 'zero_out_wgrad', False) else torch.empty
            grad_weight = allocate(
                quant_weight.main_grad.shape,
                dtype=saved_input.dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )
            quant_weight.grad_added_to_main_grad = True
        else:
            grad_weight = None

    grad_bias = grad_output.sum(dim=0) if use_bias else None

    # Make sure the outstanding dgrad collective has finished before returning.
    if seq_parallel:
        pending.wait()
        dgrad = sub_grad_input
    else:
        if allreduce_dgrad:
            pending.wait()
        dgrad = grad_input

    return dgrad, grad_weight, grad_bias, None, None, None, None, None, None
def linear_with_grad_accumulation_and_async_w4a16_forward(
    ctx,
    input,
    weight,
    bias,
    gradient_accumulation_fusion,
    allreduce_dgrad,
    sequence_parallel,
    grad_output_buffer,
    wgrad_deferral_limit,
):
    """Forward pass of the linear layer using the W4A16 fake-quantization op.

    Forwards every argument unchanged to the shared QAT forward helper,
    selecting ``w4a16_fakequant_func`` as the weight quantizer, and returns
    the helper's output.
    """
    forward_args = (
        ctx,
        input,
        weight,
        bias,
        gradient_accumulation_fusion,
        allreduce_dgrad,
        sequence_parallel,
        grad_output_buffer,
        wgrad_deferral_limit,
    )
    return _linear_with_grad_accumulation_and_async_qat_forward(
        *forward_args, fakequant_func=w4a16_fakequant_func
    )
def linear_with_grad_accumulation_and_async_w4a16_backward(ctx, grad_output):
    """Backward pass paired with the W4A16 forward.

    Delegates to the shared QAT backward helper and returns its gradient
    tuple unchanged.
    """
    grads = _linear_with_grad_accumulation_and_async_qat_backward(ctx, grad_output)
    return grads
def linear_with_grad_accumulation_and_async_w8a16_forward(
    ctx,
    input,
    weight,
    bias,
    gradient_accumulation_fusion,
    allreduce_dgrad,
    sequence_parallel,
    grad_output_buffer,
    wgrad_deferral_limit,
):
    """Forward pass of the linear layer using the W8A16 fake-quantization op.

    Forwards every argument unchanged to the shared QAT forward helper,
    selecting ``w8a16_fakequant_func`` as the weight quantizer, and returns
    the helper's output.
    """
    forward_args = (
        ctx,
        input,
        weight,
        bias,
        gradient_accumulation_fusion,
        allreduce_dgrad,
        sequence_parallel,
        grad_output_buffer,
        wgrad_deferral_limit,
    )
    return _linear_with_grad_accumulation_and_async_qat_forward(
        *forward_args, fakequant_func=w8a16_fakequant_func
    )
def linear_with_grad_accumulation_and_async_w8a16_backward(ctx, grad_output):
    """Backward pass paired with the W8A16 forward.

    Delegates to the shared QAT backward helper and returns its gradient
    tuple unchanged.
    """
    grads = _linear_with_grad_accumulation_and_async_qat_backward(ctx, grad_output)
    return grads