#include "torch_npu/csrc/core/npu/register/OptionsManager.h"
#include "torch_npu/csrc/framework/FormatHelper.h"
#include "torch_npu/csrc/core/NPUBridge.h"
#include "torch_npu/csrc/core/NPUStorageImpl.h"
#include "torch_npu/csrc/framework/InferFormat.h"
namespace at_npu {
namespace native {
/**
 * Infer the format of a contiguous NPU tensor from its storage descriptor.
 *
 * @param tensor NPU tensor whose storage descriptor is inspected.
 * @return ACL_FORMAT_ND when the storage has no data pointer; ACL_FORMAT_NCHW when the
 *         descriptor's origin format is NCDHW but the tensor's current rank differs from
 *         the descriptor's base rank and is at most 4; otherwise the descriptor's origin format.
 */
aclFormat InferFormat::GuessFormatWhenContiguous(const at::Tensor &tensor)
{
    auto tensor_storage_impl = torch_npu::NPUBridge::GetNpuStorageImpl(tensor);
    // No backing data: report plain ND.
    if (tensor_storage_impl->data_ptr() == nullptr) {
        return ACL_FORMAT_ND;
    }
    // Only read from the descriptor, so bind by const reference instead of copying it
    // (it carries size containers such as base_sizes_).
    const auto &desc = tensor_storage_impl->npu_desc_;
    // Storage described as NCDHW but currently viewed with a different rank of <= 4 dims
    // is reported as NCHW.
    if (desc.origin_format_ == ACL_FORMAT_NCDHW &&
        tensor.sizes().size() != desc.base_sizes_.size() &&
        tensor.sizes().size() <= 4) {
        return ACL_FORMAT_NCHW;
    }
    return desc.origin_format_;
}
/**
 * Resolve a (base format, format) pair for a tensor of the given shape.
 *
 * @param size   Tensor sizes; only the rank (number of dims) is consulted.
 * @param format Requested format.
 * @return (NCHW, NCHW) for 4-D tensors whose format is ND or whose base format is NCDHW;
 *         (NCDHW, format) for tensors with more than 4 dims whose base format is NCDHW;
 *         otherwise (base format of `format`, format).
 */
std::tuple<aclFormat, aclFormat> InferFormat::GuessFormatUnit(const c10::IntArrayRef &size, aclFormat format)
{
    const aclFormat base = FormatHelper::GetBaseFormat(format);
    const auto rank = size.size();
    const bool baseIsNcdhw = (base == ACL_FORMAT_NCDHW);

    // 4-D tensors that are plain ND, or whose base format is NCDHW, collapse to NCHW.
    if (rank == 4 && (format == ACL_FORMAT_ND || baseIsNcdhw)) {
        return std::make_tuple(ACL_FORMAT_NCHW, ACL_FORMAT_NCHW);
    }
    // Higher-rank tensors with an NCDHW base keep NCDHW as the base format.
    if (rank > 4 && baseIsNcdhw) {
        return std::make_tuple(ACL_FORMAT_NCDHW, format);
    }
    return std::make_tuple(base, format);
}
/**
 * Pick a base format purely from tensor rank.
 *
 * @param size Tensor sizes; only the number of dims is used.
 * @return ACL_FORMAT_NCDHW for 5 dims, ACL_FORMAT_NCHW for 4 dims, ACL_FORMAT_ND otherwise.
 */
aclFormat InferFormat::GuessBaseFormat(const c10::IntArrayRef &size)
{
    switch (size.size()) {
        case 5:
            return ACL_FORMAT_NCDHW;
        case 4:
            return ACL_FORMAT_NCHW;
        default:
            return ACL_FORMAT_ND;
    }
}
/**
 * Decide the storage format actually used for a tensor of the given shape.
 *
 * @param size   Tensor sizes.
 * @param format Requested format.
 * @return ACL_FORMAT_ND (with a one-time warning) for FRACTAL_NZ with fewer than 2 dims;
 *         for plain NCDHW: NCHW at 4 dims, NCDHW at 5 dims, ND otherwise;
 *         ND for NCHW when the rank is not 4;
 *         ND for 0-D tensors and for 1-D single-element tensors whose base format is ND;
 *         otherwise `format` unchanged.
 */
aclFormat InferFormat::GuessStorageFormat(const c10::IntArrayRef &size, aclFormat format)
{
    // NZ is not created for tensors with fewer than two dims; fall back to ND.
    if (format == ACL_FORMAT_FRACTAL_NZ && size.size() < 2) {
        TORCH_WARN_ONCE("Cannot create tensor with NZ format while dim < 2, "
                        "tensor will be created with ND format.");
        return ACL_FORMAT_ND;
    }

    const auto rank = static_cast<int64_t>(size.size());
    const aclFormat base = FormatHelper::GetBaseFormat(format);

    // Requested format is NCDHW and is its own base: choose by rank.
    if (format == ACL_FORMAT_NCDHW && base == ACL_FORMAT_NCDHW) {
        switch (rank) {
            case 4:
                return ACL_FORMAT_NCHW;
            case 5:
                return ACL_FORMAT_NCDHW;
            default:
                return ACL_FORMAT_ND;
        }
    }

    // NCHW is only kept for 4-D tensors.
    if (format == ACL_FORMAT_NCHW && rank != 4) {
        return ACL_FORMAT_ND;
    }

    // Scalars and single-element 1-D tensors with an ND base are stored as ND.
    // size[0] is only read once rank == 1 has been established.
    const bool isScalar = (rank == 0);
    const bool isUnitNdVector = (rank == 1) && (size[0] == 1) && (base == ACL_FORMAT_ND);
    return (isScalar || isUnitNdVector) ? ACL_FORMAT_ND : format;
}
/**
 * Compute the storage sizes a tensor would need in its current NPU format.
 *
 * @param tensor NPU tensor; its format, base sizes and data type come from its storage descriptor.
 * @return Result of FormatHelper::GetStorageSizes for (format, base sizes, dtype). For ND tensors
 *         with fewer than 2 base dims, the base sizes are first padded with trailing 1s to 2 dims.
 */
FormatShape InferFormat::GuessStorageSizeWhenConvertFormat(const at::Tensor &tensor)
{
    auto format = FormatHelper::GetFormat(tensor);
    // Fetch the storage descriptor once and read both fields from it.
    const auto &desc = torch_npu::NPUBridge::GetNpuStorageImpl(tensor)->npu_desc_;
    // Intentional copy: the shape may be padded below without touching the descriptor.
    auto size = desc.base_sizes_;
    auto dtype = desc.data_type_;
    // Low-rank ND shapes are padded with trailing 1s up to 2 dims
    // (e.g. [] -> [1, 1], [n] -> [n, 1]).
    if (format == ACL_FORMAT_ND) {
        while (size.size() < 2) {
            size.emplace_back(1);
        }
    }
    return FormatHelper::GetStorageSizes(format, size, dtype);
}
/**
 * Report whether a tensor's base format is incompatible with a new shape.
 *
 * @param tensor NPU tensor whose base format is inspected.
 * @param size   The new sizes; only the rank is used.
 * @return true when the base format is NCHW and the new rank is >= 5, or when the base
 *         format is NCDHW and the new rank is not exactly 5; false otherwise.
 */
bool InferFormat::IsDefiniteTensorWhenMetaDataChanges(const at::Tensor &tensor, const c10::IntArrayRef &size)
{
    const auto base = FormatHelper::GetBaseFormat(tensor);
    const auto rank = size.size();
    const bool nchwTooDeep = (base == ACL_FORMAT_NCHW) && (rank >= 5);
    const bool ncdhwWrongRank = (base == ACL_FORMAT_NCDHW) && (rank != 5);
    return nchwTooDeep || ncdhwWrongRank;
}
}
}