from functools import wraps
from webbrowser import get
import os
import warnings
from typing import List, Optional, Callable, Any
import torch
import torch_npu
import torch.distributed
import torch.nn.functional as F
try:
import fused_weight_gradient_mlp_cuda
except Exception:
warnings.warn("failed to generate the npu_matmul_add_fp32")
from megatron.core import parallel_state, tensor_parallel, mpu
from megatron.core.parallel_state import (
get_global_memory_buffer,
get_tensor_model_parallel_group,
get_tensor_model_parallel_world_size,
)
from megatron.core.tensor_parallel.layers import set_tensor_model_parallel_attributes
from megatron.core.parallel_state import get_tensor_model_parallel_world_size
from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.tensor_parallel.utils import VocabUtility, divide, split_tensor_along_last_dim
from megatron.core.utils import prepare_input_tensors_for_wgrad_compute
from megatron.core.tensor_parallel.layers import dist_all_gather_func, dist_reduce_scatter_func
from .w4a16_fake_quantization import W4A16FakeQuantization
# Module-level alias for ``W4A16FakeQuantization.apply``; the forward function
# in this file passes the weight through it before the matmul.
w4a16_fakequant_func = W4A16FakeQuantization.apply
def linear_with_grad_accumulation_and_async_w4a16_forward(
    ctx,
    input,
    weight,
    bias,
    gradient_accumulation_fusion,
    allreduce_dgrad,
    sequence_parallel,
    grad_output_buffer,
    wgrad_deferral_limit,
):
    """Forward pass of a linear layer whose weight is fake-quantized first.

    The weight is passed through ``w4a16_fakequant_func`` and the result (not
    the weight that was passed in) is what gets saved for backward and used in
    the matmul.

    Args:
        ctx: Context object; receives the saved tensors and the flags below.
        input: Activation tensor. With ``sequence_parallel`` it is all-gathered
            along dim 0 across the tensor-model-parallel group.
        weight: Weight tensor; the product is ``total_input @ weight.t()``.
        bias: Optional bias added to the product; ``None`` skips the add.
        gradient_accumulation_fusion: Stored on ``ctx`` for the backward pass.
        allreduce_dgrad: Stored on ``ctx`` for the backward pass.
        sequence_parallel: When truthy, all-gather ``input`` before the matmul.
        grad_output_buffer: Stored on ``ctx`` for the backward pass.
        wgrad_deferral_limit: Stored on ``ctx`` for the backward pass.

    Returns:
        The matmul result, plus ``bias`` when one is given.
    """
    # NOTE(review): the meaning of ``[1, 32]`` and ``False`` is defined by
    # W4A16FakeQuantization (presumably a quantization group shape and a
    # mode flag) -- confirm against that class.
    weight = w4a16_fakequant_func(weight, [1, 32], False)
    ctx.save_for_backward(input, weight)

    # Everything the backward pass needs besides the saved tensors.
    ctx.use_bias = bias is not None
    ctx.sequence_parallel = sequence_parallel
    ctx.allreduce_dgrad = allreduce_dgrad
    ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
    ctx.grad_output_buffer = grad_output_buffer
    ctx.wgrad_deferral_limit = wgrad_deferral_limit

    if not sequence_parallel:
        gathered = input
    else:
        # Gather every rank's shard along dim 0 into the shared "mpu" buffer.
        tp_size = get_tensor_model_parallel_world_size()
        gathered_shape = list(input.size())
        gathered_shape[0] *= tp_size
        gathered = get_global_memory_buffer().get_tensor(gathered_shape, input.dtype, "mpu")
        dist_all_gather_func(gathered, input, group=get_tensor_model_parallel_group())

    result = torch.matmul(gathered, weight.t())
    return result if bias is None else result + bias
def linear_with_grad_accumulation_and_async_w4a16_backward(ctx, grad_output):
    """Backward pass for the W4A16 fake-quantized linear layer.

    Computes the input gradient as ``grad_output @ weight`` using the weight
    saved by the forward pass, optionally all-reduces or reduce-scatters it
    across the tensor-model-parallel group, and either computes the weight
    gradient now or defers it by appending ``grad_output`` to
    ``ctx.grad_output_buffer``.

    Args:
        ctx: Context populated by the forward pass. Reads ``saved_tensors``
            (input, weight), ``use_bias``, ``grad_output_buffer``,
            ``wgrad_deferral_limit``, ``sequence_parallel``,
            ``allreduce_dgrad`` and ``gradient_accumulation_fusion``.
        grad_output: Gradient with respect to the forward output.

    Returns:
        A 9-tuple ``(grad_input, grad_weight, grad_bias, None, ...)``. With
        sequence parallelism the first element is the reduce-scattered shard
        of the input gradient. ``grad_weight`` is ``None`` or a placeholder
        tensor when gradient-accumulation fusion is enabled.

    Raises:
        RuntimeError: If the weight gradient would be deferred while
            ``gradient_accumulation_fusion`` is disabled, or if
            ``weight.main_grad`` has an unsupported dtype for the fused kernel.
    """
    saved_input, weight = ctx.saved_tensors
    use_bias = ctx.use_bias
    grad_output_buffer = ctx.grad_output_buffer
    wgrad_deferral_limit = ctx.wgrad_deferral_limit

    # Defer the weight gradient while a buffer is supplied and is either
    # unbounded (limit == 0) or still below the limit.
    defer_wgrad = grad_output_buffer is not None and (
        wgrad_deferral_limit == 0 or len(grad_output_buffer) < wgrad_deferral_limit
    )
    if defer_wgrad:
        # Without fusion the non-fused branch below needs ``total_input``,
        # which is never built on the deferred path. Previously that surfaced
        # as an UnboundLocalError after the buffer had already been mutated
        # and collectives launched; fail early and clearly instead.
        if not ctx.gradient_accumulation_fusion:
            raise RuntimeError(
                "Deferring the weight-gradient computation (grad_output_buffer is set) "
                "requires gradient_accumulation_fusion to be enabled"
            )
        grad_output_buffer.append(grad_output)
    wgrad_compute = not defer_wgrad

    if wgrad_compute:
        if ctx.sequence_parallel:
            # Re-gather the full input asynchronously; it is waited on after
            # the dgrad matmul below has been issued.
            world_size = get_tensor_model_parallel_world_size()
            dim_size = list(saved_input.size())
            dim_size[0] = dim_size[0] * world_size
            all_gather_buffer = get_global_memory_buffer().get_tensor(
                dim_size, saved_input.dtype, "mpu"
            )
            handle = dist_all_gather_func(
                all_gather_buffer,
                saved_input,
                group=get_tensor_model_parallel_group(),
                async_op=True,
            )
            total_input = all_gather_buffer
        else:
            total_input = saved_input

    grad_input = grad_output.matmul(weight)

    if ctx.sequence_parallel and wgrad_compute:
        handle.wait()

    if wgrad_compute:
        grad_output, total_input = prepare_input_tensors_for_wgrad_compute(
            grad_output, total_input
        )

    if ctx.allreduce_dgrad:
        # Asynchronous all-reduce of the input gradient; waited on just
        # before returning.
        handle = torch.distributed.all_reduce(
            grad_input, group=get_tensor_model_parallel_group(), async_op=True
        )

    if ctx.sequence_parallel:
        assert not ctx.allreduce_dgrad
        dim_size = list(saved_input.size())
        sub_grad_input = torch.empty(
            dim_size,
            dtype=saved_input.dtype,
            device=torch.cuda.current_device(),
            requires_grad=False,
        )
        # Asynchronous reduce-scatter into a shard shaped like the local input.
        handle = dist_reduce_scatter_func(
            sub_grad_input, grad_input, group=get_tensor_model_parallel_group(), async_op=True
        )

    if ctx.gradient_accumulation_fusion:
        # NOTE(review): ``weight`` is the tensor the forward pass saved (the
        # output of the fake-quantization call); this branch assumes it
        # carries ``main_grad`` -- confirm.
        if wgrad_compute:
            if weight.main_grad.dtype == torch.float32:
                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
                    total_input, grad_output, weight.main_grad
                )
            elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
                    total_input, grad_output, weight.main_grad
                )
            else:
                raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")

        if hasattr(weight, 'grad_added_to_main_grad'):
            # Return a placeholder shaped like main_grad: zero-filled when
            # ``weight.zero_out_wgrad`` is truthy, otherwise uninitialised.
            make_placeholder = (
                torch.zeros if getattr(weight, 'zero_out_wgrad', False) else torch.empty
            )
            grad_weight = make_placeholder(
                weight.main_grad.shape,
                dtype=saved_input.dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )
            weight.grad_added_to_main_grad = True
        else:
            grad_weight = None
    else:
        # Reaching here implies wgrad_compute is True (checked above), so
        # ``total_input`` is bound.
        grad_weight = grad_output.t().matmul(total_input)

    grad_bias = grad_output.sum(dim=0) if use_bias else None

    if ctx.sequence_parallel:
        handle.wait()
        return sub_grad_input, grad_weight, grad_bias, None, None, None, None, None, None

    if ctx.allreduce_dgrad:
        handle.wait()

    return grad_input, grad_weight, grad_bias, None, None, None, None, None, None