#include <torch/extension.h>
#include <torch_npu/csrc/framework/utils/RandomOpAdapter.h>
#include <torch_npu/csrc/framework/utils/OpAdapter.h>
#include <torch_npu/csrc/core/npu/NPUFormat.h>
#include <torch_npu/csrc/include/ops.h>
#include "inc/aclnn_common.h"
const static int MIN_DIM = 2;
const static int X_MAX_DIM = 8;
using npu_preparation = at_npu::native::OpPreparation;
namespace op_infer {
constexpr int SIZE = 8;
c10::SmallVector<int64_t, SIZE> array_to_small_vector(c10::IntArrayRef shape)
{
c10::SmallVector<int64_t, SIZE> small_shape;
for (size_t i = 0; i < shape.size(); ++i) {
small_shape.emplace_back(shape[i]);
}
return small_shape;
}
}
at::Tensor npu_ffn(const at::Tensor &x, const at::Tensor &weight1, const at::Tensor &weight2,
std::string activation, c10::optional<at::Tensor> expert_tokens, c10::optional<at::Tensor> expert_tokens_index,
const c10::optional<at::Tensor> &bias1, const c10::optional<at::Tensor> &bias2,
const c10::optional<at::Tensor> &scale, const c10::optional<at::Tensor> &offset,
const c10::optional<at::Tensor> &deq_scale1, const c10::optional<at::Tensor> &deq_scale2,
const c10::optional<at::Tensor> &antiquant_scale1, const c10::optional<at::Tensor> &antiquant_scale2,
const c10::optional<at::Tensor> &antiquant_offset1, const c10::optional<at::Tensor> &antiquant_offset2,
c10::optional<int64_t> inner_precise, c10::optional<at::ScalarType> output_dtype)
{
auto weight1_dim_num = weight1.dim();
auto weight2_dim_num = weight2.dim();
auto x_dim_num = x.dim();
TORCH_CHECK(x_dim_num >= MIN_DIM && x_dim_num <= X_MAX_DIM, "x shape dims should be 2~8, but it is ", x_dim_num);
auto x_k_dim = x.size(x.dim() - 1);
auto wight1_k_dim = weight1.size(weight1.dim() - 2);
TORCH_CHECK(x_k_dim == wight1_k_dim, "The k of x and weight should be equal. but x_k_dim is ",
x_k_dim, ", wight1_k_dim is ", wight1_k_dim);
TORCH_CHECK(!(expert_tokens.has_value() && expert_tokens_index.has_value()),
"expert_tokens and expert_tokens_index should not have the value simultaneously.");
char *activation_ptr = const_cast<char *>(activation.data());
const at::Tensor &bias1_real = bias1.value_or(at::Tensor());
const at::Tensor &bias2_real = bias2.value_or(at::Tensor());
const at::Tensor &scale_real = scale.value_or(at::Tensor());
const at::Tensor &offset_real = offset.value_or(at::Tensor());
const at::Tensor &deq_scale1_real = deq_scale1.value_or(at::Tensor());
const at::Tensor &deq_scale2_real = deq_scale2.value_or(at::Tensor());
const at::Tensor &antiquant_scale1_real = antiquant_scale1.value_or(at::Tensor());
const at::Tensor &antiquant_scale2_real = antiquant_scale2.value_or(at::Tensor());
const at::Tensor &antiquant_offset1_real = antiquant_offset1.value_or(at::Tensor());
const at::Tensor &antiquant_offset2_real = antiquant_offset2.value_or(at::Tensor());
auto output_size = op_infer::array_to_small_vector(x.sizes());
output_size[x.dim() - 1] = weight2.size(weight2.dim() - 1);
c10::TensorOptions options = x.options().dtype(x.scalar_type());
if (x.scalar_type() == at::kChar && weight1.scalar_type() == at::kChar && weight2.scalar_type() == at::kChar) {
options = x.options().dtype(output_dtype.value_or(at::kHalf));
}
at::Tensor result = at::empty(output_size, options);
int64_t inner_precise_val = inner_precise.has_value() ? inner_precise.value() : 0;
bool tokens_index_flag = expert_tokens_index.has_value();
const at::Tensor &expert_tokens_real = expert_tokens.has_value() ? expert_tokens.value() :
expert_tokens_index.value_or(at::Tensor());
ACLNN_CMD(aclnnFFNV3, x, weight1, weight2, expert_tokens_real, bias1_real, bias2_real,
scale_real, offset_real, deq_scale1_real, deq_scale2_real, antiquant_scale1_real, antiquant_scale2_real,
antiquant_offset1_real, antiquant_offset2_real, activation_ptr, inner_precise_val, tokens_index_flag, result);
return result;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("npu_ffn", &npu_ffn, "npu_ffn");
}