# coding=utf-8
# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
# Copyright (c) 2024, Bytedance Inc. All rights reserved.
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

# pylint: disable=redefined-builtin

import warnings

import torch
import torch.distributed

# The fused weight-gradient extension is treated as optional at import time: any
# exception raised while importing it is downgraded to a warning. When the import
# fails, the name `fused_weight_gradient_mlp_cuda` is never bound in this module,
# so the gradient-accumulation-fusion branch of the backward helper will raise
# NameError if it is reached.
# NOTE(review): the warning text mentions `npu_matmul_add_fp32`, not the module
# imported here — confirm the message is the intended one.
try:
    import fused_weight_gradient_mlp_cuda
except Exception:
    warnings.warn("failed to generate the npu_matmul_add_fp32")


from megatron.core.parallel_state import (
    get_global_memory_buffer,
    get_tensor_model_parallel_group,
    get_tensor_model_parallel_world_size,
)
from megatron.core.utils import prepare_input_tensors_for_wgrad_compute
from megatron.core.tensor_parallel.layers import dist_all_gather_func, dist_reduce_scatter_func
from mindspeed.core.qat.w4a16_fake_quantization import W4A16FakeQuantization
from mindspeed.core.qat.w8a16_fake_quantization import W8A16FakeQuantization

# Module-level aliases for the `.apply` callables of the two fake-quantization
# classes; the W4A16 / W8A16 forward wrappers below pass them to the shared
# forward helper as `fakequant_func`.
w4a16_fakequant_func = W4A16FakeQuantization.apply
w8a16_fakequant_func = W8A16FakeQuantization.apply


def _linear_with_grad_accumulation_and_async_qat_forward(
    ctx,
    input,
    weight,
    bias,
    gradient_accumulation_fusion,
    allreduce_dgrad,
    sequence_parallel,
    grad_output_buffer,
    wgrad_deferral_limit,
    fakequant_func,
):
    """Forward pass of a linear layer whose weight is fake-quantized first.

    The weight is passed through ``fakequant_func(weight, [1, 32], False)`` and
    the result is used for the matmul and saved (together with ``input``) for
    the backward pass. The remaining flags are only recorded on ``ctx``; they
    are consumed by the backward helper.
    NOTE(review): the meaning of the ``[1, 32]`` and ``False`` arguments is
    defined by the fake-quantization function and is not visible here.

    Args:
        ctx: Autograd context; receives the saved tensors and the flags below.
        input: Activation tensor. With ``sequence_parallel`` it is all-gathered
            along dim 0 across the tensor-model-parallel group before the matmul.
        weight: Weight tensor handed to ``fakequant_func``.
        bias: Optional bias added to the matmul result; ``None`` skips the add.
        gradient_accumulation_fusion: Stored on ``ctx`` for backward.
        allreduce_dgrad: Stored on ``ctx`` for backward.
        sequence_parallel: Whether to all-gather ``input`` before the matmul.
        grad_output_buffer: Stored on ``ctx`` for backward.
        wgrad_deferral_limit: Stored on ``ctx`` for backward.
        fakequant_func: Callable producing the fake-quantized weight.

    Returns:
        ``total_input @ quant_weight.t()`` plus ``bias`` when one is given.
    """
    fake_quantized = fakequant_func(weight, [1, 32], False)
    ctx.save_for_backward(input, fake_quantized)

    # Record everything the backward pass needs to know about this call.
    ctx.grad_output_buffer = grad_output_buffer
    ctx.wgrad_deferral_limit = wgrad_deferral_limit
    ctx.sequence_parallel = sequence_parallel
    ctx.allreduce_dgrad = allreduce_dgrad
    ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
    ctx.use_bias = bias is not None

    if not sequence_parallel:
        gathered = input
    else:
        # The gathered activation is world_size times larger along dim 0 and
        # lives in the shared "mpu" buffer rather than a fresh allocation.
        gathered_shape = list(input.size())
        gathered_shape[0] *= get_tensor_model_parallel_world_size()
        gathered = get_global_memory_buffer().get_tensor(gathered_shape, input.dtype, "mpu")
        dist_all_gather_func(gathered, input, group=get_tensor_model_parallel_group())

    projected = torch.matmul(gathered, fake_quantized.t())
    if bias is None:
        return projected
    return projected + bias


def _linear_with_grad_accumulation_and_async_qat_backward(ctx, grad_output):
    """Backward pass shared by the W4A16 and W8A16 fake-quantized linear wrappers.

    Reads ``(input, quant_weight)`` from ``ctx.saved_tensors`` and the flags
    stored on ``ctx`` by the matching forward helper, then:

    * computes ``grad_input = grad_output @ quant_weight``;
    * with ``ctx.allreduce_dgrad``, all-reduces ``grad_input`` asynchronously
      over the tensor-model-parallel group;
    * with ``ctx.sequence_parallel``, all-gathers ``input`` (only when the
      weight gradient is computed now) and reduce-scatters ``grad_input``
      into a tensor shaped like ``input``;
    * produces the weight gradient either through the fused accumulation
      kernels (writing into ``quant_weight.main_grad``) or as
      ``grad_output.t() @ total_input``;
    * optionally defers the weight gradient by appending ``grad_output`` to
      ``ctx.grad_output_buffer`` instead of computing it.

    Args:
        ctx: Autograd context populated by the forward helper.
        grad_output: Gradient of the loss with respect to the forward output.

    Returns:
        A 9-tuple: the input gradient (the reduce-scattered tensor under
        sequence parallelism, otherwise ``grad_input``), ``grad_weight``,
        ``grad_bias`` (``None`` when the forward had no bias), followed by six
        ``None`` placeholders.

    Raises:
        RuntimeError: if gradient accumulation fusion is on and
            ``quant_weight.main_grad`` is not float32, float16 or bfloat16.

    NOTE(review): ``main_grad`` / ``grad_added_to_main_grad`` are read from the
    saved fake-quantized tensor, not from the original weight passed to
    forward. Whether that tensor carries these attributes depends on the
    fake-quantization function, which is not visible here — confirm.
    """
    input, quant_weight = ctx.saved_tensors
    use_bias = ctx.use_bias
    grad_output_buffer = ctx.grad_output_buffer
    wgrad_deferral_limit = ctx.wgrad_deferral_limit

    handle = None
    wgrad_compute = True
    # Weight-gradient deferral: when a buffer is supplied, stash grad_output in it
    # and skip the wgrad computation. A limit of 0 means "always defer"; otherwise
    # deferral stops once the buffer already holds `wgrad_deferral_limit` entries.
    if grad_output_buffer is not None:
        if wgrad_deferral_limit == 0 or len(grad_output_buffer) < wgrad_deferral_limit:
            grad_output_buffer.append(grad_output)
            wgrad_compute = False

    # `total_input` (the activation used for the weight gradient) is only bound
    # inside this branch, i.e. only when the weight gradient is computed now.
    if wgrad_compute:
        if ctx.sequence_parallel:
            world_size = get_tensor_model_parallel_world_size()
            dim_size = list(input.size())
            dim_size[0] = dim_size[0] * world_size

            all_gather_buffer = get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
            handle = dist_all_gather_func(
                all_gather_buffer, input, group=get_tensor_model_parallel_group(), async_op=True
            )

            # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
            # gather is scheduled before the input gradient computation
            total_input = all_gather_buffer
        else:
            total_input = input
    # Input gradient; issued after the async all-gather was launched above.
    grad_input = grad_output.matmul(quant_weight)

    # The gathered activation must be complete before it is used for wgrad.
    if ctx.sequence_parallel and wgrad_compute:
        handle.wait()

    if wgrad_compute:
        # NOTE(review): presumably reshapes both tensors into the layout the wgrad
        # GEMM expects; the helper is defined in megatron.core.utils — confirm.
        grad_output, total_input = prepare_input_tensors_for_wgrad_compute(grad_output, total_input)

    if ctx.allreduce_dgrad:
        # Asynchronous all-reduce
        handle = torch.distributed.all_reduce(grad_input, group=get_tensor_model_parallel_group(), async_op=True)
        # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
        # all-reduce is scheduled before the weight gradient computation

    if ctx.sequence_parallel:
        # The two communication modes are mutually exclusive.
        assert not ctx.allreduce_dgrad
        # Each rank receives a slice shaped like its local (un-gathered) input.
        dim_size = list(input.size())
        sub_grad_input = torch.empty(
            dim_size, dtype=input.dtype, device=torch.cuda.current_device(), requires_grad=False
        )
        # reduce_scatter
        handle = dist_reduce_scatter_func(
            sub_grad_input, grad_input, group=get_tensor_model_parallel_group(), async_op=True
        )
        # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
        # reduce scatter is scheduled before the weight gradient computation

    if ctx.gradient_accumulation_fusion:
        if wgrad_compute:
            # The fused kernels accumulate directly into main_grad; the kernel is
            # picked by main_grad's dtype.
            if quant_weight.main_grad.dtype == torch.float32:
                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, quant_weight.main_grad)
            elif quant_weight.main_grad.dtype in (torch.float16, torch.bfloat16):
                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, quant_weight.main_grad)
            else:
                raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")

        if hasattr(quant_weight, 'grad_added_to_main_grad'):
            # When overlap_grad_reduce is True, need to ensure that backward hooks
            # are all run on the main backprop thread to prevent deadlocks. Setup
            # dummy grad_weight tensor to prevent backward hooks from being run
            # in a background thread.
            if getattr(quant_weight, 'zero_out_wgrad', False):
                grad_weight = torch.zeros(
                    quant_weight.main_grad.shape,
                    dtype=input.dtype,
                    device=torch.cuda.current_device(),
                    requires_grad=False,
                )
            else:
                grad_weight = torch.empty(
                    quant_weight.main_grad.shape,
                    dtype=input.dtype,
                    device=torch.cuda.current_device(),
                    requires_grad=False,
                )
            quant_weight.grad_added_to_main_grad = True
        else:
            grad_weight = None
    else:
        # NOTE(review): `total_input` is unbound here when the weight gradient was
        # deferred (wgrad_compute is False), so deferral combined with
        # gradient_accumulation_fusion=False raises UnboundLocalError — confirm
        # callers never use that combination.
        grad_weight = grad_output.t().matmul(total_input)
    # NOTE(review): when wgrad was deferred, grad_output has not been passed through
    # prepare_input_tensors_for_wgrad_compute, so dim 0 may mean something different
    # here than on the non-deferred path — confirm.
    grad_bias = grad_output.sum(dim=0) if use_bias else None

    if ctx.sequence_parallel:
        # Wait for the reduce-scatter launched above before returning its output.
        handle.wait()
        # Need to return None's as gradient has to flow for all the input arguments
        # provided during forward
        return sub_grad_input, grad_weight, grad_bias, None, None, None, None, None, None

    if ctx.allreduce_dgrad:
        # Wait for the all-reduce of grad_input launched above.
        handle.wait()

    return grad_input, grad_weight, grad_bias, None, None, None, None, None, None


def linear_with_grad_accumulation_and_async_w4a16_forward(
    ctx,
    input,
    weight,
    bias,
    gradient_accumulation_fusion,
    allreduce_dgrad,
    sequence_parallel,
    grad_output_buffer,
    wgrad_deferral_limit,
):
    """Forward pass of the W4A16 fake-quantized linear layer.

    Forwards every argument unchanged to the shared QAT forward helper and
    selects ``w4a16_fakequant_func`` as the weight fake-quantizer.

    Returns:
        Whatever the shared forward helper returns for these arguments.
    """
    output = _linear_with_grad_accumulation_and_async_qat_forward(
        ctx=ctx,
        input=input,
        weight=weight,
        bias=bias,
        gradient_accumulation_fusion=gradient_accumulation_fusion,
        allreduce_dgrad=allreduce_dgrad,
        sequence_parallel=sequence_parallel,
        grad_output_buffer=grad_output_buffer,
        wgrad_deferral_limit=wgrad_deferral_limit,
        fakequant_func=w4a16_fakequant_func,
    )
    return output


def linear_with_grad_accumulation_and_async_w4a16_backward(ctx, grad_output):
    """Backward pass of the W4A16 fake-quantized linear layer.

    Delegates to the shared QAT backward helper and returns its result as is.
    """
    gradients = _linear_with_grad_accumulation_and_async_qat_backward(ctx, grad_output)
    return gradients


def linear_with_grad_accumulation_and_async_w8a16_forward(
    ctx,
    input,
    weight,
    bias,
    gradient_accumulation_fusion,
    allreduce_dgrad,
    sequence_parallel,
    grad_output_buffer,
    wgrad_deferral_limit,
):
    """Forward pass of the W8A16 fake-quantized linear layer.

    Forwards every argument unchanged to the shared QAT forward helper and
    selects ``w8a16_fakequant_func`` as the weight fake-quantizer.

    Returns:
        Whatever the shared forward helper returns for these arguments.
    """
    output = _linear_with_grad_accumulation_and_async_qat_forward(
        ctx=ctx,
        input=input,
        weight=weight,
        bias=bias,
        gradient_accumulation_fusion=gradient_accumulation_fusion,
        allreduce_dgrad=allreduce_dgrad,
        sequence_parallel=sequence_parallel,
        grad_output_buffer=grad_output_buffer,
        wgrad_deferral_limit=wgrad_deferral_limit,
        fakequant_func=w8a16_fakequant_func,
    )
    return output


def linear_with_grad_accumulation_and_async_w8a16_backward(ctx, grad_output):
    """Backward pass of the W8A16 fake-quantized linear layer.

    Delegates to the shared QAT backward helper and returns its result as is.
    """
    gradients = _linear_with_grad_accumulation_and_async_qat_backward(ctx, grad_output)
    return gradients