#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
#include "op_plugin/utils/OpAdapter.h"
namespace op_api {
using npu_preparation = at_npu::native::OpPreparation;
const int64_t INT4_NUMS_IN_INT32 = 8;
const int64_t WEIGHT_SHAPE_SIZE = 2;
const int64_t CUBE_BLOCK_SIZE = 16;
const int64_t C0_SIZE_INT32 = 8;
void convert_to_int4_pack(const std::vector<int32_t>& weight_data, std::vector<int32_t>& weight_int4pack_data)
{
size_t n = weight_int4pack_data.size();
for (size_t i = 0; i < n; ++i) {
uint32_t a = static_cast<uint32_t>(weight_data[i * 8]);
uint32_t b = static_cast<uint32_t>(weight_data[i * 8 + 1]);
uint32_t c = static_cast<uint32_t>(weight_data[i * 8 + 2]);
uint32_t d = static_cast<uint32_t>(weight_data[i * 8 + 3]);
uint32_t e = static_cast<uint32_t>(weight_data[i * 8 + 4]);
uint32_t f = static_cast<uint32_t>(weight_data[i * 8 + 5]);
uint32_t g = static_cast<uint32_t>(weight_data[i * 8 + 6]);
uint32_t h = static_cast<uint32_t>(weight_data[i * 8 + 7]);
weight_int4pack_data[i] = (a & 0xF) | (b & 0xF) << 4 | (c & 0xF) << 8 | (d & 0xF) << 12 |
(e & 0xF) << 16 | (f & 0xF) << 20 | (g & 0xF) << 24 | (h & 0xF) << 28;
}
}
void trans_nd_to_nz(std::vector<int32_t>& weight_array, uint64_t k, uint64_t n)
{
uint64_t k1 = (k + CUBE_BLOCK_SIZE - 1) / CUBE_BLOCK_SIZE;
int64_t weight_nz_size = op_infer::CeilDiv(k, CUBE_BLOCK_SIZE) *
op_infer::CeilDiv(n, C0_SIZE_INT32) * CUBE_BLOCK_SIZE * C0_SIZE_INT32;
std::vector<int32_t> weight_nz_array(weight_nz_size, 0);
for (size_t idx = 0; idx < weight_array.size(); ++idx) {
size_t idx_k = idx / n;
size_t idx_n = idx % n;
size_t idx_k0 = idx_k % CUBE_BLOCK_SIZE;
size_t idx_k1 = idx_k / CUBE_BLOCK_SIZE;
size_t idx_n0 = idx_n % C0_SIZE_INT32;
size_t idx_n1 = idx_n / C0_SIZE_INT32;
weight_nz_array[idx_n1 * k1 * CUBE_BLOCK_SIZE * C0_SIZE_INT32 + idx_k1 * CUBE_BLOCK_SIZE * C0_SIZE_INT32 +
idx_k0 * C0_SIZE_INT32 + idx_n0] = weight_array[idx];
}
weight_array = weight_nz_array;
}
inline void int4pack_params_check(const at::Tensor &weight)
{
TORCH_CHECK(weight.is_contiguous(), "weight should be contiguous", OPS_ERROR(ErrCode::PARAM));
auto weight_dim_num = weight.dim();
TORCH_CHECK(weight_dim_num == WEIGHT_SHAPE_SIZE, "weight shape only support dim num 2, but it is ",
weight_dim_num, OPS_ERROR(ErrCode::PARAM));
auto weight_dtype = weight.dtype();
TORCH_CHECK(weight_dtype == at::kInt, "weight dtype only support int32, but it is ", weight_dtype,
OPS_ERROR(ErrCode::TYPE));
auto weight_first_dim = weight.size(weight_dim_num - 2);
auto weight_last_dim = weight.size(weight_dim_num - 1);
TORCH_CHECK(weight_first_dim > 0 && weight_last_dim > 0, "weight dim should be greater than 0",
OPS_ERROR(ErrCode::PARAM))
TORCH_CHECK(weight_last_dim % INT4_NUMS_IN_INT32 == 0, "weight last dim should be the multiple of 8, but it is ",
weight_last_dim, OPS_ERROR(ErrCode::PARAM));
}
at::Tensor npu_convert_weight_to_int4pack(const at::Tensor &weight, int64_t inner_k_tiles)
{
int4pack_params_check(weight);
auto weight_dim_num = weight.dim();
auto weight_first_dim = weight.size(weight_dim_num - 2);
auto weight_last_dim = weight.size(weight_dim_num - 1);
int64_t weight_format = at_npu::native::custom_ops::get_npu_format(weight);
at::Tensor weight_nd;
bool is_weight_nz = (weight_format == ACL_FORMAT_FRACTAL_NZ);
if (is_weight_nz) {
weight_nd = at_npu::native::custom_ops::npu_format_cast(weight, ACL_FORMAT_ND);
}
at::Tensor weight_cpu = is_weight_nz ? weight_nd.to("cpu") : weight.to("cpu");
std::vector<int32_t> weight_data(weight_cpu.data_ptr<int32_t>(), weight_cpu.data_ptr<int32_t>() + weight_cpu.numel());
std::vector<int32_t> weight_int4pack_data(weight_data.size() / INT4_NUMS_IN_INT32, 0);
std::vector<int64_t> weight_int4pack_shape = {weight_first_dim, weight_last_dim / INT4_NUMS_IN_INT32};
convert_to_int4_pack(weight_data, weight_int4pack_data);
if (is_weight_nz) {
trans_nd_to_nz(weight_int4pack_data, weight_first_dim, weight_last_dim / INT4_NUMS_IN_INT32);
weight_int4pack_shape = {op_infer::CeilDiv(weight_last_dim / INT4_NUMS_IN_INT32, C0_SIZE_INT32),
op_infer::CeilDiv(weight_first_dim, CUBE_BLOCK_SIZE), CUBE_BLOCK_SIZE, C0_SIZE_INT32};
}
c10::TensorOptions options_cpu = weight_cpu.options().dtype(at::kInt);
at::Tensor weight_int4_pack_cpu = at::from_blob(weight_int4pack_data.data(), weight_int4pack_shape,
options_cpu);
auto output_size = op_infer::array_to_small_vector(weight_int4pack_shape);
c10::TensorOptions options = weight.options().dtype(at::kInt);
at::Tensor result = npu_preparation::apply_tensor_without_format(output_size, options);
if (is_weight_nz) {
int64_t nbytes = result.numel() * result.element_size();
c10_npu::NPUStream stream = c10_npu::getCurrentNPUStream();
OPS_CHECK_ERROR(c10_npu::acl::AclrtSynchronizeStreamWithTimeout(stream));
NPU_CHECK_ERROR(aclrtMemcpy(const_cast<void*>(result.storage().unsafeGetStorageImpl()->data()), nbytes,
weight_int4_pack_cpu.storage().unsafeGetStorageImpl()->data(), nbytes, ACL_MEMCPY_HOST_TO_DEVICE));
auto &out_desc = torch_npu::NPUBridge::GetNpuStorageImplDesc(result);
out_desc.npu_format_ = ACL_FORMAT_FRACTAL_NZ;
out_desc.origin_format_ = ACL_FORMAT_ND;
out_desc.base_sizes_ = {weight_first_dim, weight_last_dim / INT4_NUMS_IN_INT32};
out_desc.base_strides_ = {weight_last_dim / INT4_NUMS_IN_INT32, 1};
result.set_(result.storage(), 0, {weight_first_dim, weight_last_dim / INT4_NUMS_IN_INT32},
{weight_last_dim / INT4_NUMS_IN_INT32, 1});
} else {
result.copy_(weight_int4_pack_cpu);
}
return result;
}
}