#include <vector>
#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
#include "torch_npu/csrc/framework/utils/InternalFormatOpAdapter.h"
namespace op_api {
using npu_preparation = at_npu::native::OpPreparation;
constexpr int64_t DIM_ZERO = 0;
constexpr int64_t DIM_ONE = 1;
constexpr int64_t DIM_TWO = 2;
constexpr int64_t SHAPE_SIZE_TWO = 2;
const int64_t B4_NUMS_IN_B8 = 2;
static bool is_nz_format(const at::Tensor &x2)
{
auto* npu_impl = torch_npu::NPUBridge::GetNpuStorageImpl(x2);
TORCH_CHECK(npu_impl != nullptr, "Invalid NPU tensor implementation.", OPS_ERROR(ErrCode::PTR));
const torch_npu::NPUStorageDesc &tensor_desc = npu_impl->npu_desc_;
return tensor_desc.npu_format_ == ACL_FORMAT_FRACTAL_NZ;
}
at::Tensor npu_dual_level_quant_matmul(const at::Tensor &x1, const at::Tensor &x2, const at::Tensor &x1_level0_scale,
const at::Tensor &x2_level0_scale, const at::Tensor &x1_level1_scale,
const at::Tensor &x2_level1_scale, const c10::optional<at::Tensor> &bias,
int64_t output_dtype)
{
TORCH_CHECK(c10_npu::GetSocVersion() >= c10_npu::SocVersion::Ascend950,
"This interface is supported only on the Ascend950 platform and after.", OPS_ERROR(ErrCode::PARAM));
auto x1_dim_num = x1.dim();
TORCH_CHECK(x1_dim_num == SHAPE_SIZE_TWO, "The shape dimension of x1 is only supported to be 2, but it is ",
x1_dim_num, OPS_ERROR(ErrCode::PARAM));
auto x2_dim_num = x2.dim();
TORCH_CHECK(x2_dim_num == SHAPE_SIZE_TWO, "The shape dimension of x2 is only supported to be 2, but it is ",
x2_dim_num, OPS_ERROR(ErrCode::PARAM));
bool is_weight_nz = is_nz_format(x2);
TORCH_CHECK(is_weight_nz, "x2 only supports NZ format.", OPS_ERROR(ErrCode::PARAM));
auto x1_k_dim = x1.size(DIM_ONE) * B4_NUMS_IN_B8;
auto x2_k_dim = x2.size(DIM_ONE) * B4_NUMS_IN_B8;
TORCH_CHECK(x1_k_dim == x2_k_dim, "The k of x1 and x2 should be equal. but x1_k_dim is ", x1_k_dim,
", x2_k_dim is ", x2_k_dim, OPS_ERROR(ErrCode::PARAM));
auto out_dim_num = x1_dim_num;
auto output_size = op_infer::array_to_small_vector(x1.sizes());
output_size[DIM_ZERO] = x1.size(DIM_ZERO);
output_size[DIM_ONE] = x2.size(DIM_ZERO);
const at::Tensor &bias_v = bias.value_or(at::Tensor());
int64_t l0_group_size = 512;
int64_t l1_group_size = 32;
TensorWrapper x1_wrapper =
make_wrapper(x1, c10::optional<int64_t>(static_cast<int64_t>(c10_npu::DType::FLOAT4_E2M1)));
TensorWrapper x1_level1_scale_wrapper =
make_wrapper(x1_level1_scale, c10::optional<int64_t>(static_cast<int64_t>(c10_npu::DType::FLOAT8_E8M0)));
TensorWrapper x2_level1_scale_wrapper =
make_wrapper(x2_level1_scale, c10::optional<int64_t>(static_cast<int64_t>(c10_npu::DType::FLOAT8_E8M0)));
c10::TensorOptions options;
aclDataType output_acltype = c10_npu::GetAclDataType(output_dtype);
options = x1.options().dtype(npu_preparation::convert_to_scalar_type(output_acltype));
at::Tensor result = npu_preparation::apply_tensor_without_format(output_size, options);
bool transpose_x1 = false;
bool transpose_x2 = true;
EXEC_NPU_CMD(aclnnDualLevelQuantMatmulWeightNz, x1_wrapper, x2, x1_level0_scale, x2_level0_scale,
x1_level1_scale_wrapper, x2_level1_scale_wrapper, bias_v, transpose_x1, transpose_x2, l0_group_size,
l1_group_size, result);
return result;
}
}