#ifndef __PULGIN_NATIVE_UTILS_FORMAT_HELPER__
#define __PULGIN_NATIVE_UTILS_FORMAT_HELPER__
#include <functional>
#include <unordered_map>

#include <ATen/ATen.h>

#include "torch_npu/csrc/framework/utils/NPUDefinition.h"
#include "torch_npu/csrc/core/NPUBridge.h"
namespace at_npu {
namespace native {
// Callable signature taking a storage shape, a base shape and an element size
// (in bytes, presumably) and producing a FormatShape.
// NOTE(review): by its name this presumably maps an NPU storage layout back to
// its base-format layout — confirm against the registered implementations.
// std::function requires <functional>, which this header now includes directly
// instead of relying on ATen to pull it in transitively.
using baseFormatConverter =
    std::function<FormatShape(c10::IntArrayRef storage_dims, c10::IntArrayRef base_dims, size_t itemsize)>;
// Static helpers for querying and converting the ACL storage format of NPU
// tensors. All per-format metadata lives in the private `info` table.
class FormatHelper {
public:
    // --- Format metadata queries -------------------------------------------
    // Whether the tensor's format is marked as padded in the format table
    // (presumably via FormatInfo::isPadded — confirm in the .cpp).
    static bool IsPadded(const at::Tensor *tensor);

    // Human-readable format name, by tensor or by raw format value.
    // NOTE(review): presumably points into FormatInfo::formatName; the caller
    // should not modify or free it.
    static char *GetFormatName(const at::Tensor &tensor);
    static char *GetFormatName(aclFormat format);

    // Base (un-padded) format associated with a tensor or a format value.
    static aclFormat GetBaseFormat(const at::Tensor &tensor);
    static aclFormat GetBaseFormat(aclFormat format);

    // Storage format currently recorded for the tensor.
    static aclFormat GetFormat(const at::Tensor &tensor);

    // Whether a format / tensor format is a base format type.
    static bool IsBaseFormatType(aclFormat format);
    static bool IsBaseFormatType(const at::Tensor &tensor);

    // --- Storage shape inference --------------------------------------------
    // Runs the registered shape-inference function for `format` on `ori_size`
    // and the element size of `dtype`; raises if none is registered.
    template <typename sizeType>
    static FormatShape GetStorageSizes(aclFormat format, sizeType ori_size, caffe2::TypeMeta dtype);
    static FormatShape GetStorageSizes(const torch_npu::NPUStorageDesc &desc);

    // --- Format casting ------------------------------------------------------
    // NOTE(review): "unsafe" by name — presumably rewrites format metadata
    // without converting data; confirm in the implementation.
    static at::Tensor& unsafe_format_cast(at::Tensor& self, int64_t self_format, int64_t result_format);

    // --- Operator-input checks (one overload per input container kind) -------
    static bool IsOpInputBaseFormat(const at::Tensor &tensor);
    static bool IsOpInputBaseFormat(const c10::optional<at::Tensor> &tensor);
    static bool IsOpInputBaseFormat(const c10::optional<at::TensorList> &tensors);
    static bool IsOpInputBaseFormat(const c10::List<c10::optional<at::Tensor>> &tensors);
    static bool IsOpInputBaseFormat(const at::TensorList &tensors);
    static bool IsOpInputBaseFormat(const at::ITensorListRef &tensors);

private:
    static bool IsPadded(aclFormat format);

    // Maps logical dims and an element size to a storage shape.
    using shapeInfer = std::function<FormatShape(c10::IntArrayRef dims, size_t itemsize)>;

    // One row of the format table. Field order is kept as-is because the table
    // may be built with positional aggregate initialization.
    struct FormatInfo_ {
        aclFormat format = ACL_FORMAT_ND;
        aclFormat baseFormat = ACL_FORMAT_ND;
        shapeInfer func = nullptr;           // empty => shape inference unsupported
        char formatName[30] = {0};
        bool isPadded = false;
    };
    using FormatInfo = FormatInfo_;

    // Format table keyed by ACL format, and (presumably) the routine that fills it.
    static std::unordered_map<aclFormat, FormatInfo> info;
    static std::unordered_map<aclFormat, FormatInfo> InitializeInfo();
};
/// Infers the storage shape of a tensor with logical sizes `ori_size` stored in `format`.
///
/// @param format   ACL format to look up in the format table; it must have a
///                 registered shape-inference function.
/// @param ori_size logical sizes; must be implicitly convertible to
///                 c10::IntArrayRef, the parameter type of the inference function.
/// @param dtype    element type; only dtype.itemsize() is forwarded.
/// @return the FormatShape produced by the format's shape-inference function.
/// @throws c10::Error (via AT_ERROR) when `format` is not in the table or has
///         no shape-inference function.
template <typename sizeType>
FormatShape FormatHelper::GetStorageSizes(aclFormat format, sizeType ori_size, caffe2::TypeMeta dtype)
{
    const auto itr = info.find(format);
    if (itr != info.end() && itr->second.func) {
        return itr->second.func(ori_size, dtype.itemsize());
    }
    // AT_ERROR joins its arguments with no separator, so the spaces around
    // "with shape" are required for a readable message.
    AT_ERROR("unsupport InferShape with format ", GetFormatName(format), " with shape ", ori_size);
    return {};  // unreachable: AT_ERROR throws; kept to silence missing-return diagnostics
}
}
}
#endif