// Copyright (c) 2025 Huawei Technologies Co., Ltd
// All rights reserved.
//
// Licensed under the BSD 3-Clause License  (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"

namespace op_api {
// Short alias for the OpPreparation helper type; this file uses it to allocate output tensors
// (apply_tensor_without_format) and to convert between ACL data types and torch scalar types.
using npu_preparation = at_npu::native::OpPreparation;
// Shape factors used when sizing the outputs of npu_grouped_matmul_swiglu_quant_v2:
// MXFP_DIVISOR_SIZE is the divisor passed to CeilDiv for the MX scale tensor, and
// MXFP_MULTI_BASE_SIZE both halves N for the output width and is the trailing dim of the MX scale.
constexpr int64_t MXFP_DIVISOR_SIZE = 64;
constexpr int64_t MXFP_MULTI_BASE_SIZE = 2;

// Small named integers used in size/stride comparisons.
constexpr int64_t NUM_ONE = 1;
constexpr int64_t NUM_TWO = 2;

// Dimension indices and tensor ranks.
constexpr int64_t DIM_0 = 0;
constexpr int64_t DIM_1 = 1;
constexpr int64_t DIM_2 = 2;
constexpr int64_t DIM_3 = 3;
constexpr int64_t DIM_4 = 4;
constexpr int64_t DIM_5 = 5;

// Data-type codes that are compared against the result of c10_npu::GetAclDataType().
constexpr int64_t HIFLOAT8 = 34;
constexpr int64_t FLOAT8_E5M2 = 35;
constexpr int64_t FLOAT8_E4M3FN = 36;

// Replaces `y` with a newly allocated 2-D tensor of shape [dim_m, dim_n].
// The tensor is obtained from npu_preparation::apply_tensor_without_format using `options`.
void create_new_tensor(at::Tensor &y, size_t dim_m, size_t dim_n, c10::TensorOptions options)
{
    const auto shape_2d = op_infer::array_to_small_vector({dim_m, dim_n});
    y = npu_preparation::apply_tensor_without_format(shape_2d, options);
}

// Replaces `y` with a newly allocated 3-D tensor of shape [batch, dim_m, dim_n].
// The tensor is obtained from npu_preparation::apply_tensor_without_format using `options`.
void create_new_tensor_batch(at::Tensor &y, size_t batch, size_t dim_m, size_t dim_n,
                             const c10::TensorOptions &options)
{
    const auto shape_3d = op_infer::array_to_small_vector({batch, dim_m, dim_n});
    y = npu_preparation::apply_tensor_without_format(shape_3d, options);
}

// Returns true when `tensor` has the strides of a densely packed tensor whose last two dims were swapped:
//   - the penultimate dim has stride 1,
//   - the last dim has a stride equal to the penultimate dim's size,
//   - every leading (batch) dim has the stride of a dense layout over the dims to its right.
// Tensors with fewer than two dims always yield false.
bool is_transpose_last_two_dims(const at::Tensor &tensor)
{
    const int64_t rank = tensor.dim();
    if (rank < DIM_2) {
        return false;
    }

    const auto shape = tensor.sizes();
    const auto stride = tensor.strides();
    const int64_t col_axis = rank - DIM_1;
    const int64_t row_axis = rank - DIM_2;

    const bool inner_dims_swapped = (stride[row_axis] == NUM_ONE) && (stride[col_axis] == shape[row_axis]);
    if (!inner_dims_swapped) {
        return false;
    }

    // Match CANN's transposed contiguous-view check: batch strides must also follow the swapped last-two dims.
    int64_t dense_stride = shape[col_axis] * shape[row_axis];
    int64_t axis = rank - DIM_3;
    while (axis >= DIM_0) {
        if (stride[axis] != dense_stride) {
            return false;
        }
        dense_stride *= shape[axis];
        --axis;
    }
    return true;
}

// Derives the logical N of the matmul from the weight scale tensor.
//   - weight_scale must have at least 2 dims, otherwise TORCH_CHECK fails.
//   - Non-MX mode: N is the last dim of weight_scale.
//   - MX mode: weight_scale must be exactly 4-D and N is read from dim 2.
int64_t infer_nz_logical_n(const at::Tensor &weight_scale, bool is_mx_quant)
{
    const int64_t scale_rank = weight_scale.dim();
    TORCH_CHECK(scale_rank >= DIM_2, "The dim of weight_scale[0] should be greater than or equal to 2, "
                "but got ", scale_rank, OPS_ERROR(ErrCode::PARAM));
    if (!is_mx_quant) {
        return weight_scale.sizes().back();
    }

    TORCH_CHECK(scale_rank == DIM_4, "The dim of weight_scale[0] should be equal to 4 in MX quant mode, "
                "but got ", scale_rank, OPS_ERROR(ErrCode::PARAM));
    // Runtime callers normalize MX weight to the non-transposed logical layout before this op.
    // Thus 4D MX weightScale is [E, ceil(K/64), N, 2], and logical N is dim2.
    return weight_scale.size(DIM_2);
}

// Validates the arguments, allocates the two result tensors and launches the grouped-matmul + SwiGLU + quant
// aclnn kernel (aclnnGroupedMatmulSwigluQuantWeightNzV2 for NZ weights, aclnnGroupedMatmulSwigluQuantV2 otherwise).
//
// Inputs (as used by this function):
//   x                      at least 2-D; m is taken from x.size(0).
//   weight / weight_scale  tensor lists that must each hold exactly one tensor.
//   x_scale, group_list, smooth_scale, weight_assist_matrix, bias, tuning_config
//                          forwarded to the kernel (absent optionals become empty tensors / lists).
//   dequant_mode, quant_mode, group_list_type
//                          forwarded to the kernel, defaulting to 0.
//   dequant_dtype          torch dtype code; translated to an ACL/GE code before being forwarded.
//   quant_dtype            optional dtype of `output`. When absent, `output` is int8 in the non-MX path and
//                          takes x's scalar type in the MX path.
//   x_dtype, weight_dtype, weight_scale_dtype, x_scale_dtype
//                          optional dtype overrides attached to the matching tensors via wrappers.
//                          A present weight_scale_dtype selects the MX quant path.
//
// N is derived as follows: for a 5-D weight it comes from weight_scale (infer_nz_logical_n); otherwise it is
// weight.size(2) on Ascend950 or later and the last dim of weight_scale on earlier SoCs.
//
// Returns (output, output_scale):
//   non-MX: output [m, n / 2], output_scale float32 [m].
//   MX:     output [m, cols], output_scale [m, ceil(scale_cols / 64), 2] typed from weight_scale_dtype, where
//           cols / scale_cols depend on the MXFP4 packing rules handled below.
//
// Fails via TORCH_CHECK when the list sizes, ranks or dtype codes are not accepted, or when the NZ kernel is
// not available in the installed CANN.
std::tuple<at::Tensor, at::Tensor> npu_grouped_matmul_swiglu_quant_v2(
    const at::Tensor & x,
    const at::TensorList weight,
    const at::TensorList weight_scale,
    const at::Tensor & x_scale,
    const at::Tensor & group_list,
    const c10::optional<at::Tensor> & smooth_scale,
    const c10::optional<at::TensorList> weight_assist_matrix,
    const c10::optional<at::Tensor> & bias,
    c10::optional<int64_t> dequant_mode,
    c10::optional<int64_t> dequant_dtype,
    c10::optional<int64_t> quant_mode,
    c10::optional<int64_t> quant_dtype,
    c10::optional<int64_t> group_list_type,
    const c10::OptionalIntArrayRef tuning_config,
    c10::optional<int64_t> x_dtype,
    c10::optional<int64_t> weight_dtype,
    c10::optional<int64_t> weight_scale_dtype,
    c10::optional<int64_t> x_scale_dtype)
{
    TORCH_CHECK(weight.size() == NUM_ONE, "The size of weight should be 1, current size is ", weight.size(), OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(weight_scale.size() == NUM_ONE, "The size of weight_scale should be 1, current size is ",
                weight_scale.size(), OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(x.dim() >= DIM_2, "The x dim should be greater than or equal to 2, but the actual value is ", x.dim(),
                OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(!weight_scale[DIM_0].sizes().empty(), "The weight_scale[0] is empty.", OPS_ERROR(ErrCode::PARAM));

    // A weight counts as NZ when its NPU format is one of the FRACTAL_NZ variants or when it is 5-D.
    const auto weight_format = at_npu::native::custom_ops::get_npu_format(weight[DIM_0]);
    const bool is_weight_nz = weight_format == ACL_FORMAT_FRACTAL_NZ ||
                              weight_format == ACL_FORMAT_FRACTAL_NZ_C0_16 ||
                              weight[DIM_0].dim() == DIM_5;
    const bool weight_trans = is_transpose_last_two_dims(weight[DIM_0]);
    const bool is_mx_quant = weight_scale_dtype.has_value();
    const bool is_5d_nz = is_weight_nz && (weight[DIM_0].dim() == DIM_5);
    const bool is_ascend950_or_later = c10_npu::GetSocVersion() >= c10_npu::SocVersion::Ascend950;

    int n = 0;
    if (is_5d_nz) {
        n = static_cast<int>(infer_nz_logical_n(weight_scale[DIM_0], is_mx_quant));
    } else if (is_ascend950_or_later) {
        // sizes()[DIM_2] is not bounds-checked, so reject weights that have no dim 2 before reading it.
        TORCH_CHECK(weight[DIM_0].dim() > DIM_2, "weight[0] dim should be greater than 2, but the actual value is ",
                    weight[DIM_0].dim(), OPS_ERROR(ErrCode::PARAM));
        n = static_cast<int>(weight[DIM_0].sizes()[DIM_2]);
    } else {
        n = static_cast<int>(weight_scale[DIM_0].sizes().back());
    }
    const int m = static_cast<int>(x.size(DIM_0));

    const bool mxfp8w4_nz_input = is_weight_nz &&
                                  x.scalar_type() == at::kFloat8_e4m3fn &&
                                  weight_dtype.has_value() &&
                                  weight_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E2M1);

    if (x_dtype.has_value()) {
        TORCH_CHECK(x_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E1M2)
                 || x_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E2M1)
                 || x_dtype.value() == static_cast<int64_t>(c10_npu::DType::HIFLOAT8),
                    "The optional parameter x_dtype only supports torch_npu.float4_e2m1fn_x2, torch_npu.float4_e1m2fn_x2, torch_npu.hifloat8, or None, but the actual value is ",
                    c10_npu::CustomDataTypeToString(x_dtype.value()), "." + OPS_ERROR(ErrCode::VALUE));
    }
    if (weight_dtype.has_value()) {
        TORCH_CHECK(weight_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E1M2)
                 || weight_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E2M1)
                 || weight_dtype.value() == static_cast<int64_t>(c10_npu::DType::HIFLOAT8),
                    "The optional parameter weight_dtype only supports torch_npu.float4_e2m1fn_x2, torch_npu.float4_e1m2fn_x2, torch_npu.hifloat8, or None, but the actual value is ",
                    c10_npu::CustomDataTypeToString(weight_dtype.value()), "." + OPS_ERROR(ErrCode::VALUE));
    }

    // Outside the FP8-activation / FP4-NZ-weight combination, x_dtype and weight_dtype must be given together.
    if (!mxfp8w4_nz_input) {
        TORCH_CHECK(
            x_dtype.has_value() == weight_dtype.has_value(),
            "The optional parameter x_dtype and weight_dtype should both be torch_npu.float4_e2m1fn_x2, torch_npu.float4_e1m2fn_x2, "
            "torch_npu.hifloat8, or None.",
            OPS_ERROR(ErrCode::VALUE));
    }

    if (weight_scale_dtype.has_value()) {
        TORCH_CHECK(weight_scale_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT8_E8M0),
                    "The optional parameter weight_scale_dtype only supports float8_e8m0fnu or None, but the actual value is ",
                    c10_npu::CustomDataTypeToString(weight_scale_dtype.value()), "." + OPS_ERROR(ErrCode::VALUE));
    }
    if (x_scale_dtype.has_value()) {
        TORCH_CHECK(x_scale_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT8_E8M0),
                    "The optional parameter x_scale_dtype only supports float8_e8m0fnu or None, but the actual value is ",
                    c10_npu::CustomDataTypeToString(x_scale_dtype.value()), "." + OPS_ERROR(ErrCode::VALUE));
    }
    if (is_ascend950_or_later && dequant_dtype.has_value()) {
        TORCH_CHECK(dequant_dtype.value() == static_cast<int64_t>(c10::ScalarType::Float)
                    || dequant_dtype.value() == static_cast<int64_t>(c10::ScalarType::Char)
                    || dequant_dtype.value() == static_cast<int64_t>(c10::ScalarType::Half)
                    || dequant_dtype.value() == static_cast<int64_t>(c10::ScalarType::BFloat16),
                    "The optional parameter dequant_dtype only support torch.float32, torch.int8, torch.float16 and torch.bfloat16 ,but the actual value is ",
                    c10_npu::CustomDataTypeToString(dequant_dtype.value()), "." + OPS_ERROR(ErrCode::VALUE));
    }

    int64_t dequant_mode_real = dequant_mode.value_or(0);
    int64_t dequant_dtype_real = dequant_dtype.value_or(0);
    // Convert torch enum values to GE enum values. Keys 6 / 5 / 15 are the c10::ScalarType codes of
    // Float / Half / BFloat16; the mapped values are presumably the matching ACL codes - confirm against acl headers.
    // Codes missing from the table are forwarded unchanged.
    static const std::map<int64_t, int64_t> TorchToGeMap = {
        {6, 0},
        {5, 1},
        {15, 27}};
    const auto it = TorchToGeMap.find(dequant_dtype.value_or(0));
    if (it != TorchToGeMap.end()) {
        dequant_dtype_real = it->second;
    }
    int64_t quant_mode_real = quant_mode.value_or(0);
    int64_t group_list_type_real = group_list_type.value_or(0);
    auto weight_assist_matrix_real = weight_assist_matrix.value_or(at::TensorList());
    auto tuning_config_real = tuning_config.value_or(at::IntArrayRef{});
    auto bias_real = bias.value_or(at::Tensor());
    auto smooth_scale_real = smooth_scale.value_or(at::Tensor());

    const bool mxfp4_input = x_dtype.has_value() && weight_dtype.has_value() &&
                             (x_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E1M2) ||
                              x_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E2M1)) &&
                             (weight_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E1M2) ||
                              weight_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E2M1));

    // quant_dtype is optional: classify it only when present, so a missing value takes the default
    // branches below instead of throwing from optional::value().
    bool quant_is_fp8 = false;
    bool quant_is_hif8 = false;
    if (quant_dtype.has_value()) {
        const auto quant_acl_dtype = c10_npu::GetAclDataType(quant_dtype.value());
        quant_is_fp8 = quant_acl_dtype == FLOAT8_E5M2 || quant_acl_dtype == FLOAT8_E4M3FN;
        quant_is_hif8 = quant_acl_dtype == HIFLOAT8;
    }
    // Options for `output`: the requested quant dtype, or x's own scalar type when none was requested.
    // x.scalar_type() is used rather than x[0].scalar_type() so that an x with zero rows does not fail on indexing.
    const auto make_output_options = [&]() -> c10::TensorOptions {
        return x.options().dtype(quant_dtype.has_value()
                                     ? npu_preparation::convert_to_scalar_type(c10_npu::GetAclDataType(quant_dtype.value()))
                                     : x.scalar_type());
    };

    at::Tensor output;
    at::Tensor output_scale;
    if (!is_mx_quant) {
        if (quant_is_fp8 || quant_is_hif8) {
            create_new_tensor(output, m, n / MXFP_MULTI_BASE_SIZE, make_output_options());
        } else {
            // Any other (or no) quant dtype yields an int8 output.
            output = npu_preparation::apply_tensor_without_format({m, n / MXFP_MULTI_BASE_SIZE}, c10::dtype(c10::ScalarType::Char));
        }
        output_scale = npu_preparation::apply_tensor_without_format({m}, c10::dtype(c10::ScalarType::Float));
    } else {
        // In the MX path the dequant dtype is forwarded as the ACL code of the given torch dtype.
        if (dequant_dtype.has_value()) {
            dequant_dtype_real = static_cast<int64_t>(c10_npu::GetAclDataType(dequant_dtype.value()));
        }
        TORCH_CHECK(!weight[DIM_0].sizes().empty(), "weight[0] is empty.", OPS_ERROR(ErrCode::PARAM));
        if (!is_weight_nz) {
            // n was already derived above with the same SoC-dependent rule, so only the rank needs checking here.
            TORCH_CHECK(weight[DIM_0].dim() == DIM_3, "weight[0] dim should be equal to 3, but the actual value is ",
                        weight[DIM_0].dim(), OPS_ERROR(ErrCode::PARAM));
        }
        const c10::TensorOptions options_output = make_output_options();
        const c10::TensorOptions options = x.options().dtype(
            npu_preparation::convert_to_scalar_type(c10_npu::GetAclDataType(weight_scale_dtype.value())));

        // Default widths; the MXFP4 cases below adjust them.
        int64_t output_cols = n / MXFP_MULTI_BASE_SIZE;
        int64_t scale_cols = n / MXFP_MULTI_BASE_SIZE;
        if (mxfp4_input) {
            // In the non-NZ MXFP4 path, uint8 storage packs two FP4 values. For non-transposed weight,
            // n is the packed physical N from weight.shape[2], while CANN validates against logical N = n * 2.
            // FP4_IN_INT8 is declared outside this file.
            if (!weight_trans) {
                scale_cols = n * FP4_IN_INT8 / MXFP_MULTI_BASE_SIZE;
                if (quant_is_fp8) {
                    output_cols = (n / MXFP_MULTI_BASE_SIZE) * FP4_IN_INT8;
                }
            } else if (!quant_is_fp8) {
                output_cols = n / MXFP_MULTI_BASE_SIZE / NUM_TWO;
            }
        }
        create_new_tensor(output, m, output_cols, options_output);
        create_new_tensor_batch(output_scale, m, op_infer::CeilDiv(scale_cols, MXFP_DIVISOR_SIZE),
                                MXFP_MULTI_BASE_SIZE, options);
    }

    // Wrappers pair each tensor with the data type that should be reported for it: the caller's explicit
    // override when given, otherwise the tensor's own dtype.
    TensorWrapper x_wrapper = {x,
        x_dtype.has_value() ? c10_npu::GetAclDataType(x_dtype.value())
                            : npu_preparation::convert_to_acl_data_type(x.scalar_type())};
    TensorListWrapper weight_wrapper = {weight,
        weight_dtype.has_value() ? c10_npu::GetAclDataType(weight_dtype.value())
                                 : npu_preparation::convert_to_acl_data_type(weight[0].scalar_type())};
    TensorListWrapper weight_scale_wrapper = {weight_scale,
        weight_scale_dtype.has_value() ? c10_npu::GetAclDataType(weight_scale_dtype.value())
                                : (weight_scale.empty() ? aclDataType::ACL_FLOAT
                                : npu_preparation::convert_to_acl_data_type(weight_scale[0].scalar_type()))};
    TensorWrapper x_scale_wrapper = {x_scale,
        x_scale_dtype.has_value() ? c10_npu::GetAclDataType(x_scale_dtype.value())
                                : (!x_scale.numel() ? aclDataType::ACL_FLOAT
                                : npu_preparation::convert_to_acl_data_type(x_scale.scalar_type()))};
    // Without quant_dtype, report the dtype `output` was actually allocated with (int8 or x's dtype)
    // rather than a fixed ACL_FLOAT that would not match the buffer.
    TensorWrapper output_wrapper = {output,
        quant_dtype.has_value() ? c10_npu::GetAclDataType(quant_dtype.value())
                                : npu_preparation::convert_to_acl_data_type(output.scalar_type())};
    TensorWrapper output_scale_wrapper = {output_scale,
        weight_scale_dtype.has_value() ? aclDataType::ACL_FLOAT8_E8M0 : aclDataType::ACL_FLOAT};

    if (is_weight_nz) {
        static const bool is_weight_nz_available = check_aclnn_kernel_available("aclnnGroupedMatmulSwigluQuantWeightNzV2");
        TORCH_CHECK(is_weight_nz_available,
                    "Format of weight in npu_grouped_matmul is FRACTAL_NZ, current CANN version "
                    "do not support with this format. Please try to update the version of CANN."
                    + OPS_ERROR(ErrCode::PARAM));

        if (mxfp8w4_nz_input) {
            // FP8 activation with FP4 NZ weight: x and output are passed as plain tensors, without dtype wrappers.
            EXEC_NPU_CMD(
                aclnnGroupedMatmulSwigluQuantWeightNzV2,
                x,
                weight_wrapper,
                weight_scale_wrapper,
                weight_assist_matrix_real,
                bias_real,
                x_scale_wrapper,
                smooth_scale_real,
                group_list,
                dequant_mode_real,
                dequant_dtype_real,
                quant_mode_real,
                group_list_type_real,
                tuning_config_real,
                output,
                output_scale_wrapper);
        } else {
            // If the weight is not tagged FRACTAL_NZ, work on a clone whose storage descriptor is re-tagged
            // as FRACTAL_NZ with storage sizes equal to its current sizes; the caller's tensor is untouched.
            // NOTE(review): the clone is only consumed when weight_dtype is absent (see the argument
            // selection below), so it could be skipped otherwise - confirm before changing.
            at::Tensor weight_for_nz = weight[DIM_0];
            if (at_npu::native::custom_ops::get_npu_format(weight_for_nz) != ACL_FORMAT_FRACTAL_NZ) {
                weight_for_nz = weight_for_nz.clone();
                auto &desc = torch_npu::NPUBridge::GetNpuStorageImpl(weight_for_nz)->npu_desc_;
                desc.npu_format_ = ACL_FORMAT_FRACTAL_NZ;
                desc.storage_sizes_ = op_infer::array_to_small_vector(weight_for_nz.sizes());
            }
            c10::SmallVector<at::Tensor, 1> weight_nz_vec = {weight_for_nz};
            at::TensorList weight_nz_list(weight_nz_vec);
            TensorListWrapper weight_nz_wrapper = {weight_nz_list,
                weight_dtype.has_value() ? c10_npu::GetAclDataType(weight_dtype.value())
                                 : npu_preparation::convert_to_acl_data_type(weight[0].scalar_type())};
            EXEC_NPU_CMD(
                aclnnGroupedMatmulSwigluQuantWeightNzV2,
                x_wrapper,
                weight_dtype.has_value() ? weight_wrapper : weight_nz_wrapper,
                weight_scale_wrapper,
                weight_assist_matrix_real,
                bias_real,
                x_scale_wrapper,
                smooth_scale_real,
                group_list,
                dequant_mode_real,
                dequant_dtype_real,
                quant_mode_real,
                group_list_type_real,
                tuning_config_real,
                output_wrapper,
                output_scale_wrapper);
        }
    } else {
        EXEC_NPU_CMD(
            aclnnGroupedMatmulSwigluQuantV2,
            x_wrapper,
            weight_wrapper,
            weight_scale_wrapper,
            weight_assist_matrix_real,
            bias_real,
            x_scale_wrapper,
            smooth_scale_real,
            group_list,
            dequant_mode_real,
            dequant_dtype_real,
            quant_mode_real,
            group_list_type_real,
            tuning_config_real,
            output_wrapper,
            output_scale_wrapper);
    }
    return std::tuple<at::Tensor, at::Tensor>(output, output_scale);
}
}