#include "torch_npu/csrc/core/npu/npu_log.h"
#include "torch_npu/csrc/core/NPUBridge.h"
#include "torch_npu/csrc/core/npu/DeviceUtils.h"
#include "torch_npu/csrc/core/npu/NPUException.h"
#include "torch_npu/csrc/framework/FormatHelper.h"
namespace at_npu {
namespace native {
// File-local constants and forward declarations of the shape-inference helpers
// that are defined further down in this file.
namespace {
// Block edge length used when rounding dimensions up to a multiple of 16.
constexpr int BLOCKSIZE = 16;
// Divided by the (clamped) element size in InferShapeNDToNZ to get the innermost extent.
constexpr int BLOCKBYTES = 32;
// Innermost extents passed as the template argument of InferShapeNDToNZWithC0.
constexpr int C0_16 = 16;
constexpr int C0_32 = 32;
constexpr int C0_2 = 2;
constexpr int C0_4 = 4;
// Single-shape helpers: take the input dims and the element size in bytes.
FormatShape InferShapeLessTo4(c10::IntArrayRef dims, size_t itemsize);
FormatShape InferShape4To5(c10::IntArrayRef dims, size_t itemsize);
FormatShape InferShape5To4(c10::IntArrayRef dims, size_t itemsize);
FormatShape InferShapeNDToNZ(c10::IntArrayRef dims, size_t itemsize);
FormatShape InferShapeNDToNZC016(c10::IntArrayRef dims, size_t itemsize);
FormatShape InferShapeNDToNZC032(c10::IntArrayRef dims, size_t itemsize);
FormatShape InferShapeNDToNZC02(c10::IntArrayRef dims, size_t itemsize);
FormatShape InferShapeNDToNZC04(c10::IntArrayRef dims, size_t itemsize);
FormatShape InferShapeNDToZ(c10::IntArrayRef dims, size_t itemsize);
FormatShape InferShapeofNCHW(c10::IntArrayRef dims, size_t itemsize);
FormatShape InferShapeofND(c10::IntArrayRef dims, size_t itemsize);
// Two-shape helpers: take both the storage dims and the base dims.
FormatShape InferShapeNCHWToND(c10::IntArrayRef storage_dims, c10::IntArrayRef base_dims, size_t itemsize);
FormatShape InferShapeNCDHWToND(c10::IntArrayRef storage_dims, c10::IntArrayRef base_dims, size_t itemsize);
FormatShape InferShapeNDToNCHW(c10::IntArrayRef storage_dims, c10::IntArrayRef base_dims, size_t itemsize);
FormatShape InferShapeNDToNCDHW(c10::IntArrayRef storage_dims, c10::IntArrayRef base_dims, size_t itemsize);
// Helpers that require at least five input dims.
FormatShape InferShapeOfNDHWC(c10::IntArrayRef dims, size_t itemsize);
FormatShape InferShapeOfNCDHW(c10::IntArrayRef dims, size_t itemsize);
FormatShape InferShapeOfNDC1HWC0(c10::IntArrayRef dims, size_t itemsize);
FormatShape InferShapeOfFZ3D(c10::IntArrayRef dims, size_t itemsize);
FormatShape InferShapeofNHWC(c10::IntArrayRef dims, size_t itemsize);
}
// Builds the table that maps every supported aclFormat to its FormatInfo record:
// {format, base format, shape-inference function, printable name, isPadded flag}.
//
// Entries are written with standard brace initialisation instead of the
// GNU-only compound literal `(FormatInfo){...}` so the code does not rely on
// a compiler extension.
//
// NOTE(review): ACL_FORMAT_FRACTAL_NZ_C0_8 registers nullptr as its
// shape-inference function; confirm that no caller infers storage sizes for
// that format through this table.
std::unordered_map<aclFormat, FormatHelper::FormatInfo> FormatHelper::InitializeInfo()
{
    return {
        {ACL_FORMAT_NC1HWC0, FormatInfo{ACL_FORMAT_NC1HWC0, ACL_FORMAT_NCHW, InferShape4To5, "NC1HWC0", true}},
        {ACL_FORMAT_ND, FormatInfo{ACL_FORMAT_ND, ACL_FORMAT_ND, InferShapeofND, "ND", false}},
        {ACL_FORMAT_NCHW, FormatInfo{ACL_FORMAT_NCHW, ACL_FORMAT_NCHW, InferShapeofNCHW, "NCHW", false}},
        {ACL_FORMAT_NHWC, FormatInfo{ACL_FORMAT_NHWC, ACL_FORMAT_NHWC, InferShapeofNHWC, "NHWC", false}},
        {ACL_FORMAT_FRACTAL_NZ,
         FormatInfo{ACL_FORMAT_FRACTAL_NZ, ACL_FORMAT_ND, InferShapeNDToNZ, "FRACTAL_NZ", true}},
        {ACL_FORMAT_FRACTAL_Z, FormatInfo{ACL_FORMAT_FRACTAL_Z, ACL_FORMAT_NCHW, InferShapeNDToZ, "FRACTAL_Z", true}},
        {ACL_FORMAT_NDHWC, FormatInfo{ACL_FORMAT_NDHWC, ACL_FORMAT_NCDHW, InferShapeOfNDHWC, "NDHWC", false}},
        {ACL_FORMAT_NCDHW, FormatInfo{ACL_FORMAT_NCDHW, ACL_FORMAT_NCDHW, InferShapeOfNCDHW, "NCDHW", false}},
        {ACL_FORMAT_NDC1HWC0,
         FormatInfo{ACL_FORMAT_NDC1HWC0, ACL_FORMAT_NCDHW, InferShapeOfNDC1HWC0, "NDC1HWC0", true}},
        {ACL_FRACTAL_Z_3D, FormatInfo{ACL_FRACTAL_Z_3D, ACL_FORMAT_NCDHW, InferShapeOfFZ3D, "FRACTAL_Z_3D", true}},
        {ACL_FORMAT_FRACTAL_NZ_C0_16,
         FormatInfo{ACL_FORMAT_FRACTAL_NZ_C0_16, ACL_FORMAT_ND, InferShapeNDToNZC016, "FRACTAL_NZ_C0_16", true}},
        {ACL_FORMAT_FRACTAL_NZ_C0_32,
         FormatInfo{ACL_FORMAT_FRACTAL_NZ_C0_32, ACL_FORMAT_ND, InferShapeNDToNZC032, "FRACTAL_NZ_C0_32", true}},
        {ACL_FORMAT_FRACTAL_NZ_C0_2,
         FormatInfo{ACL_FORMAT_FRACTAL_NZ_C0_2, ACL_FORMAT_ND, InferShapeNDToNZC02, "FRACTAL_NZ_C0_2", true}},
        {ACL_FORMAT_FRACTAL_NZ_C0_4,
         FormatInfo{ACL_FORMAT_FRACTAL_NZ_C0_4, ACL_FORMAT_ND, InferShapeNDToNZC04, "FRACTAL_NZ_C0_4", true}},
        {ACL_FORMAT_FRACTAL_NZ_C0_8,
         FormatInfo{ACL_FORMAT_FRACTAL_NZ_C0_8, ACL_FORMAT_ND, nullptr, "FRACTAL_NZ_C0_8", true}},
    };
}
// Definition of the static format table; filled once from InitializeInfo()
// and consulted by the lookup helpers below (IsPadded, GetFormatName, GetBaseFormat).
std::unordered_map<aclFormat, FormatHelper::FormatInfo> FormatHelper::info = FormatHelper::InitializeInfo();
// Reports the isPadded flag for the storage format of the tensor pointed to.
// The pointer is dereferenced unconditionally, so it must not be null.
bool FormatHelper::IsPadded(const at::Tensor *tensor)
{
    return IsPadded(GetFormat(*tensor));
}
// Looks up `format` in the format table and returns its isPadded flag.
// Raises through AT_ERROR when the format is not registered.
bool FormatHelper::IsPadded(aclFormat format)
{
    const auto entry = info.find(format);
    if (entry == info.end()) {
        AT_ERROR("unknown format type:", format);
        return true;
    }
    return entry->second.isPadded;
}
// Returns the printable name stored in the format table for `format`.
// Raises through AT_ERROR when the format is not registered.
char *FormatHelper::GetFormatName(aclFormat format)
{
    auto entry = info.find(format);
    if (entry != info.end()) {
        return entry->second.formatName;
    }
    AT_ERROR("unknown format type:", format);
    return nullptr;
}
// Returns the printable name of the storage format recorded for `tensor`.
char *FormatHelper::GetFormatName(const at::Tensor &tensor)
{
    return GetFormatName(GetFormat(tensor));
}
// Returns the base format registered for the storage format of `tensor`.
aclFormat FormatHelper::GetBaseFormat(const at::Tensor &tensor)
{
    const aclFormat storage_format = torch_npu::NPUBridge::GetNpuStorageImplDesc(tensor).npu_format_;
    return GetBaseFormat(storage_format);
}
// Returns the base format that the format table associates with `format`.
// Raises through AT_ERROR when the format is not registered.
aclFormat FormatHelper::GetBaseFormat(aclFormat format)
{
    auto entry = info.find(format);
    if (entry != info.end()) {
        return entry->second.baseFormat;
    }
    AT_ERROR("unknown format type:", format);
    return ACL_FORMAT_ND;
}
// Reads the npu_format_ field from the tensor's NPU storage descriptor.
aclFormat FormatHelper::GetFormat(const at::Tensor &tensor)
{
    const auto &storage_desc = torch_npu::NPUBridge::GetNpuStorageImplDesc(tensor);
    return storage_desc.npu_format_;
}
// A format is a base format when the table maps it to itself.
bool FormatHelper::IsBaseFormatType(aclFormat format)
{
    const aclFormat base = GetBaseFormat(format);
    return base == format;
}
// Checks whether the storage format recorded for `tensor` is its own base format.
bool FormatHelper::IsBaseFormatType(const at::Tensor &tensor)
{
    return IsBaseFormatType(GetFormat(tensor));
}
// Unpacks a storage descriptor (format, base sizes, data type) and forwards
// the pieces to the GetStorageSizes overload that takes them individually.
FormatShape FormatHelper::GetStorageSizes(const torch_npu::NPUStorageDesc &desc)
{
    auto storage_format = desc.npu_format_;
    auto base_sizes = desc.base_sizes_;
    auto element_type = desc.data_type_;
    return GetStorageSizes(storage_format, base_sizes, element_type);
}
// Returns true when `tensor` can be treated as a base-format operator input:
// either it is not an NPU tensor, its storage impl is not NPUStorageImpl, or
// its storage format is one of ND / NCHW / NHWC / NCDHW.
bool FormatHelper::IsOpInputBaseFormat(const at::Tensor &tensor)
{
    if (!torch_npu::utils::is_npu(tensor)) {
        return true;
    }
    // Only inspect the descriptor when the storage impl really is an NPUStorageImpl.
    if (typeid(*tensor.storage().unsafeGetStorageImpl()) != typeid(torch_npu::NPUStorageImpl)) {
        return true;
    }
    switch (torch_npu::NPUBridge::GetNpuStorageImplDesc(tensor).npu_format_) {
        case ACL_FORMAT_ND:
        case ACL_FORMAT_NCHW:
        case ACL_FORMAT_NHWC:
        case ACL_FORMAT_NCDHW:
            return true;
        default:
            return false;
    }
}
// An absent optional counts as base format; otherwise defer to the tensor overload.
bool FormatHelper::IsOpInputBaseFormat(const c10::optional<at::Tensor> &tensor)
{
    return !tensor.has_value() || IsOpInputBaseFormat(tensor.value());
}
// True when every element of the list passes IsOpInputBaseFormat (vacuously true when empty).
bool FormatHelper::IsOpInputBaseFormat(const c10::List<c10::optional<at::Tensor>> &tensors)
{
    return std::all_of(tensors.begin(), tensors.end(),
                       [](const auto &tensor) { return IsOpInputBaseFormat(tensor); });
}
// An absent list counts as base format; otherwise every tensor in it must pass
// IsOpInputBaseFormat.
bool FormatHelper::IsOpInputBaseFormat(const c10::optional<at::TensorList> &tensors)
{
    if (tensors.has_value()) {
        const auto &list = tensors.value();
        return std::all_of(list.begin(), list.end(),
                           [](const auto &tensor) { return IsOpInputBaseFormat(tensor); });
    }
    return true;
}
// True when every tensor in the list passes IsOpInputBaseFormat (vacuously true when empty).
bool FormatHelper::IsOpInputBaseFormat(const at::TensorList &tensors)
{
    return std::all_of(tensors.begin(), tensors.end(),
                       [](const auto &tensor) { return IsOpInputBaseFormat(tensor); });
}
// Materializes the list reference and checks that every referenced tensor
// passes IsOpInputBaseFormat.
bool FormatHelper::IsOpInputBaseFormat(const at::ITensorListRef &tensors)
{
    auto materialized = tensors.materialize();
    return std::all_of(materialized.begin(), materialized.end(),
                       [](const auto &tensor) { return IsOpInputBaseFormat(tensor.get()); });
}
namespace {
// Expands a shape of rank 0..4 to exactly four entries.
//   rank 4      -> dims copied unchanged
//   rank 0..3   -> a leading 1, then dims, then trailing 1s
// `itemsize` is not used. Rank above 4 is rejected.
FormatShape InferShapeLessTo4(c10::IntArrayRef dims, size_t itemsize)
{
    AT_ASSERT(dims.size() <= 4, "input dim > 4 when InferShapeLessTo4", OPS_ERROR(ErrCode::PARAM));
    const size_t rank = dims.size();
    if (rank > 4) {
        AT_ERROR("dims of NCHW shape should not be greater than 4, which is ", dims.size());
    }
    FormatShape res;
    res.resize(4);
    for (size_t k = 0; k < 4; ++k) {
        res[k] = 1;
    }
    // Shapes shorter than 4-D are written starting at slot 1; slot 0 stays 1.
    const size_t offset = (rank == 4) ? 0 : 1;
    for (size_t k = 0; k < rank; ++k) {
        res[offset + k] = dims[k];
    }
    return res;
}
// Requires exactly four dims and returns them unchanged. `itemsize` is not used.
FormatShape InferShapeofNHWC(c10::IntArrayRef dims, size_t itemsize)
{
    AT_ASSERT(dims.size() == 4, "input dim should be equal to 4 when InferShapeofNHWC", OPS_ERROR(ErrCode::PARAM));
    FormatShape res;
    for (const auto extent : dims) {
        res.emplace_back(extent);
    }
    return res;
}
// Produces the five-entry shape {d0, ceil(d1 / 16), d2, d3, 16} from a 4-D shape.
// Shapes with fewer than four dims are first expanded by InferShapeLessTo4;
// shapes with more than four dims are logged as an error and only the first
// four entries are used.
FormatShape InferShape4To5(c10::IntArrayRef dims, size_t itemsize)
{
    const size_t rank = dims.size();
    if (rank < 4) {
        ASCEND_LOGD("infershape4to5 but input dim < 4");
        return InferShape4To5(InferShapeLessTo4(dims, itemsize), itemsize);
    }
    if (rank > 4) {
        ASCEND_LOGE("infershape4to5 but input dim > 4");
    }
    FormatShape res;
    res.emplace_back(dims[0]);
    res.emplace_back((dims[1] + BLOCKSIZE - 1) / BLOCKSIZE);
    res.emplace_back(dims[2]);
    res.emplace_back(dims[3]);
    res.emplace_back(BLOCKSIZE);
    return res;
}
// Returns {d0, d1 rounded up to a multiple of 16, d2, d3}.
// Reads dims[0..3] without checking the rank; `itemsize` is not used.
FormatShape InferShape5To4(c10::IntArrayRef dims, size_t itemsize)
{
    FormatShape res;
    res.resize(4);
    res[0] = dims[0];
    res[1] = ((dims[1] + BLOCKSIZE - 1) / BLOCKSIZE) * BLOCKSIZE;
    res[2] = dims[2];
    res[3] = dims[3];
    return res;
}
// Computes the shape {batch..., ceil(cols / c0), ceil(rows / 16), 16, c0}, where
// rows/cols are the last two dims (shapes of rank < 2 are padded with trailing
// 1s first) and c0 = 32 / min(itemsize, 2). `itemsize` must be non-zero.
FormatShape InferShapeNDToNZ(c10::IntArrayRef dims, size_t itemsize)
{
    FormatShape padded(dims.begin(), dims.end());
    while (padded.size() < 2) {
        padded.emplace_back(1);
    }
    const size_t rowAxis = padded.size() - 2;

    // Leading (batch) dims pass through unchanged.
    FormatShape res;
    for (size_t k = 0; k < rowAxis; ++k) {
        res.emplace_back(padded[k]);
    }

    AT_ASSERT(itemsize != 0, "dtype itemsize should not be 0", OPS_ERROR(ErrCode::PARAM));
    // Element sizes above 2 bytes are treated as 2 bytes when deriving c0.
    const size_t elemBytes = (itemsize > 2) ? 2 : itemsize;
    const auto c0 = BLOCKBYTES / elemBytes;

    res.emplace_back((padded[rowAxis + 1] + c0 - 1) / c0);
    res.emplace_back((padded[rowAxis] + BLOCKSIZE - 1) / BLOCKSIZE);
    res.emplace_back(BLOCKSIZE);
    res.emplace_back(c0);
    return res;
}
// Same layout computation as InferShapeNDToNZ, but with the innermost extent
// fixed at compile time to C0Value instead of being derived from itemsize:
// {batch..., ceil(cols / C0Value), ceil(rows / 16), 16, C0Value}.
// `itemsize` is only checked to be non-zero.
template <int C0Value>
FormatShape InferShapeNDToNZWithC0(c10::IntArrayRef dims, size_t itemsize)
{
    FormatShape padded(dims.begin(), dims.end());
    while (padded.size() < 2) {
        padded.emplace_back(1);
    }
    const size_t rowAxis = padded.size() - 2;

    // Leading (batch) dims pass through unchanged.
    FormatShape res;
    for (size_t k = 0; k < rowAxis; ++k) {
        res.emplace_back(padded[k]);
    }

    AT_ASSERT(itemsize != 0, "dtype itemsize should not be 0", OPS_ERROR(ErrCode::PARAM));

    res.emplace_back((padded[rowAxis + 1] + C0Value - 1) / C0Value);
    res.emplace_back((padded[rowAxis] + BLOCKSIZE - 1) / BLOCKSIZE);
    res.emplace_back(BLOCKSIZE);
    res.emplace_back(C0Value);
    return res;
}
// InferShapeNDToNZWithC0 instantiated with an innermost extent of 16.
FormatShape InferShapeNDToNZC016(c10::IntArrayRef dims, size_t itemsize)
{
    return InferShapeNDToNZWithC0<C0_16>(dims, itemsize);
}
// InferShapeNDToNZWithC0 instantiated with an innermost extent of 32.
FormatShape InferShapeNDToNZC032(c10::IntArrayRef dims, size_t itemsize)
{
    return InferShapeNDToNZWithC0<C0_32>(dims, itemsize);
}
// InferShapeNDToNZWithC0 instantiated with an innermost extent of 2.
FormatShape InferShapeNDToNZC02(c10::IntArrayRef dims, size_t itemsize)
{
    return InferShapeNDToNZWithC0<C0_2>(dims, itemsize);
}
// InferShapeNDToNZWithC0 instantiated with an innermost extent of 4.
FormatShape InferShapeNDToNZC04(c10::IntArrayRef dims, size_t itemsize)
{
    return InferShapeNDToNZWithC0<C0_4>(dims, itemsize);
}
// Produces {ceil(d1 / 16) * d2 * d3, ceil(d0 / 16), 16, 16}.
// Shapes with fewer than four dims are first expanded by InferShapeLessTo4.
FormatShape InferShapeNDToZ(c10::IntArrayRef dims, size_t itemsize)
{
    if (dims.size() < 4) {
        return InferShapeNDToZ(InferShapeLessTo4(dims, itemsize), itemsize);
    }
    FormatShape res;
    res.resize(4);
    res[0] = (dims[1] + BLOCKSIZE - 1) / BLOCKSIZE * dims[2] * dims[3];
    res[1] = (dims[0] + BLOCKSIZE - 1) / BLOCKSIZE;
    res[2] = BLOCKSIZE;
    res[3] = BLOCKSIZE;
    return res;
}
// Maps a 4-D storage shape back to a shape with the same rank as `base_dims`.
//
// storage_dims  storage shape; expanded to 4-D via InferShapeLessTo4 when its
//               rank is not 4
// base_dims     only its rank is used:
//                 0 -> result of InferShapeLessTo4({1})
//                 1 -> {C}        (N, H and W must be 1)
//                 2 -> {C, H}     (N and W must be 1)
//                 3 -> {C, H, W}  (N must be 1)
//                 4 -> the 4-D storage shape unchanged
//               any other rank raises through AT_ERROR
// itemsize      forwarded to InferShapeLessTo4
FormatShape InferShapeNCHWToND(c10::IntArrayRef storage_dims, c10::IntArrayRef base_dims, size_t itemsize)
{
    // Hold the normalized shape in an owning container. Keeping only an
    // IntArrayRef to the temporary returned by InferShapeLessTo4 would leave
    // a dangling view once that statement finishes.
    FormatShape cur_storage_dims;
    if (storage_dims.size() != 4) {
        cur_storage_dims = InferShapeLessTo4(storage_dims, itemsize);
    } else {
        cur_storage_dims = storage_dims;
    }
    AT_ASSERT(
        cur_storage_dims.size() == 4,
        "input dim num not equal 4 when InferShapeNCHWToND",
        OPS_ERROR(ErrCode::PARAM));
    if (base_dims.size() == 0) {
        FormatShape temp_dims;
        temp_dims.emplace_back(1);
        return InferShapeLessTo4(temp_dims, itemsize);
    }
    FormatShape res;
    res.resize(4);
    switch (base_dims.size()) {
        case 1:
            res.resize(1);
            res[0] = cur_storage_dims[1];
            AT_ASSERT(
                cur_storage_dims[0] == 1,
                "reshape type RESHAPE_TYPE_C erase dim N must be 1",
                OPS_ERROR(ErrCode::PARAM));
            AT_ASSERT(
                cur_storage_dims[2] == 1,
                "reshape type RESHAPE_TYPE_C erase dim H must be 1",
                OPS_ERROR(ErrCode::PARAM));
            AT_ASSERT(
                cur_storage_dims[3] == 1,
                "reshape type RESHAPE_TYPE_C erase dim W must be 1",
                OPS_ERROR(ErrCode::PARAM));
            break;
        case 2:
            res.resize(2);
            res[0] = cur_storage_dims[1];
            res[1] = cur_storage_dims[2];
            AT_ASSERT(
                cur_storage_dims[0] == 1,
                "reshape type RESHAPE_TYPE_CH erase dim N must be 1",
                OPS_ERROR(ErrCode::PARAM));
            AT_ASSERT(
                cur_storage_dims[3] == 1,
                "reshape type RESHAPE_TYPE_CH erase dim W must be 1",
                OPS_ERROR(ErrCode::PARAM));
            break;
        case 3:
            res.resize(3);
            res[0] = cur_storage_dims[1];
            res[1] = cur_storage_dims[2];
            res[2] = cur_storage_dims[3];
            AT_ASSERT(
                cur_storage_dims[0] == 1,
                "reshape type RESHAPE_TYPE_CHW erase dim N must be 1",
                OPS_ERROR(ErrCode::PARAM));
            break;
        case 4:
            res = cur_storage_dims;
            return res;
        default:
            AT_ERROR("unknown reshape type:");
    }
    return res;
}
// Expands `base_dims` to a 4-D shape via InferShapeLessTo4.
// Both `storage_dims` and `base_dims` must have rank <= 4; `storage_dims` is
// otherwise unused.
FormatShape InferShapeNDToNCHW(c10::IntArrayRef storage_dims, c10::IntArrayRef base_dims, size_t itemsize)
{
    AT_ASSERT(storage_dims.size() <= 4, "input storage dim not less than 4", OPS_ERROR(ErrCode::PARAM));
    // This check is on base_dims, so the message names the base shape rather
    // than repeating the storage-shape text.
    AT_ASSERT(base_dims.size() <= 4, "input base dim not less than 4", OPS_ERROR(ErrCode::PARAM));
    return InferShapeLessTo4(base_dims, itemsize);
}
// Requires a 5-D storage shape and returns a copy of it.
// `base_dims` and `itemsize` are not used.
FormatShape InferShapeNDToNCDHW(c10::IntArrayRef storage_dims, c10::IntArrayRef base_dims, size_t itemsize)
{
    AT_ASSERT(
        storage_dims.size() == 5,
        "ND [",
        storage_dims,
        "] failed to convert to NCDHW",
        OPS_ERROR(ErrCode::PARAM));
    return FormatShape(storage_dims.begin(), storage_dims.end());
}
// Copies the storage shape and asserts that it has exactly five dims.
// `base_dims` and `itemsize` are not used.
FormatShape InferShapeNCDHWToND(c10::IntArrayRef storage_dims, c10::IntArrayRef base_dims, size_t itemsize)
{
    FormatShape res(storage_dims.begin(), storage_dims.end());
    AT_ASSERT(res.size() == 5, "input dim num not equal 5 when InferShapeNCDHWToND", OPS_ERROR(ErrCode::PARAM));
    return res;
}
// Reorders the first five dims {d0, d1, d2, d3, d4} into {d0, d2, d3, d4, d1}.
// Raises through AT_ERROR when fewer than five dims are given.
FormatShape InferShapeOfNDHWC(c10::IntArrayRef dims, size_t itemsize)
{
    if (dims.size() < 5) {
        AT_ERROR("dim (", dims, ") cannot convert to NDHWC");
    }
    FormatShape res;
    res.emplace_back(dims[0]);
    res.emplace_back(dims[2]);
    res.emplace_back(dims[3]);
    res.emplace_back(dims[4]);
    res.emplace_back(dims[1]);
    return res;
}
// Returns the first five dims unchanged.
// Raises through AT_ERROR when fewer than five dims are given.
FormatShape InferShapeOfNCDHW(c10::IntArrayRef dims, size_t itemsize)
{
    if (dims.size() < 5) {
        AT_ERROR("dim (", dims, ") cannot convert to NCDHW");
    }
    FormatShape res;
    for (size_t axis = 0; axis < 5; ++axis) {
        res.emplace_back(dims[axis]);
    }
    return res;
}
// Produces the six-entry shape {d0, d2, ceil(d1 / 16), d3, d4, 16}.
// Raises through AT_ERROR when fewer than five dims are given.
FormatShape InferShapeOfNDC1HWC0(c10::IntArrayRef dims, size_t itemsize)
{
    if (dims.size() < 5) {
        AT_ERROR("dim (", dims, ") cannot convert to NDC1HWC0");
    }
    FormatShape res;
    res.emplace_back(dims[0]);
    res.emplace_back(dims[2]);
    res.emplace_back((dims[1] + BLOCKSIZE - 1) / BLOCKSIZE);
    res.emplace_back(dims[3]);
    res.emplace_back(dims[4]);
    res.emplace_back(BLOCKSIZE);
    return res;
}
// Produces {d2 * ceil(d1 / 16) * d3 * d4, ceil(d0 / 16), 16, 16}.
// Raises through AT_ERROR when fewer than five dims are given.
FormatShape InferShapeOfFZ3D(c10::IntArrayRef dims, size_t itemsize)
{
    if (dims.size() < 5) {
        AT_ERROR("dim (", dims, ") cannot convert to FZ_3D");
    }
    const int64_t blocksOfDim1 = (dims[1] + BLOCKSIZE - 1) / BLOCKSIZE;
    const int64_t blocksOfDim0 = (dims[0] + BLOCKSIZE - 1) / BLOCKSIZE;
    const int64_t fused = dims[2] * blocksOfDim1 * dims[3] * dims[4];

    FormatShape res;
    res.emplace_back(fused);
    res.emplace_back(blocksOfDim0);
    res.emplace_back(BLOCKSIZE);
    res.emplace_back(BLOCKSIZE);
    return res;
}
// Shapes of rank < 5 are expanded to 4-D via InferShapeLessTo4; shapes of
// rank >= 5 are copied unchanged via InferShapeofND.
FormatShape InferShapeofNCHW(c10::IntArrayRef dims, size_t itemsize)
{
    const bool fitsIn4D = dims.size() < 5;
    return fitsIn4D ? InferShapeLessTo4(dims, itemsize) : InferShapeofND(dims, itemsize);
}
// Returns an element-by-element copy of `dims`. `itemsize` is not used.
FormatShape InferShapeofND(c10::IntArrayRef dims, size_t itemsize)
{
    return FormatShape(dims.begin(), dims.end());
}
}
// Rewrites the storage descriptor of `self` in place for two supported
// format pairs and returns `self`:
//   ND -> NC1HWC0 : storage sizes become InferShape4To5(self.sizes(), itemsize)
//   NC1HWC0 -> ND : storage sizes are reset to the descriptor's base sizes
// Any other pair leaves the descriptor untouched. Only the descriptor fields
// are modified here.
at::Tensor &FormatHelper::unsafe_format_cast(at::Tensor &self, int64_t self_format, int64_t result_format)
{
    torch_npu::NPUStorageDesc &desc = torch_npu::NPUBridge::GetNpuStorageImpl(self)->npu_desc_;
    const bool ndToNc1hwc0 = (self_format == ACL_FORMAT_ND) && (result_format == ACL_FORMAT_NC1HWC0);
    const bool nc1hwc0ToNd = (self_format == ACL_FORMAT_NC1HWC0) && (result_format == ACL_FORMAT_ND);

    if (ndToNc1hwc0) {
        auto itemsize = desc.data_type_.itemsize();
        desc.storage_sizes_ = InferShape4To5(self.sizes(), itemsize);
        desc.npu_format_ = ACL_FORMAT_NC1HWC0;
    } else if (nc1hwc0ToNd) {
        desc.storage_sizes_ = desc.base_sizes_;
        desc.npu_format_ = ACL_FORMAT_ND;
    }
    return self;
}
}
}