#include <limits>
#include <ATen/record_function.h>
#include "third_party/acl/inc/acl/acl_base.h"
#include "third_party/acl/inc/acl/acl_rt.h"
#include "torch_npu/csrc/aten/mirror/NPUMemoryOverlap.h"
#include "torch_npu/csrc/core/NPUBridge.h"
#include "torch_npu/csrc/core/NPUStorageImpl.h"
#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h"
#include "torch_npu/csrc/core/npu/NPUException.h"
#include "torch_npu/csrc/core/npu/NPUFunctions.h"
#include "torch_npu/csrc/core/npu/interface/AclInterface.h"
#include "torch_npu/csrc/core/npu/interface/AsyncTaskQueueInterface.h"
#include "torch_npu/csrc/core/npu/register/OptionRegister.h"
#include "torch_npu/csrc/core/npu/register/OptionsManager.h"
#include "torch_npu/csrc/framework/InferFormat.h"
#include "torch_npu/csrc/framework/contiguous/ReshapeOpt.h"
#include "torch_npu/csrc/framework/interface/AclOpCompileInterface.h"
#include "torch_npu/csrc/framework/interface/EnvVariables.h"
#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
#include "torch_npu/csrc/framework/utils/ForceJitCompileList.h"
#include "torch_npu/csrc/framework/utils/NpuUtils.h"
namespace {
// Small floating-point tolerance. NOTE(review): appears unused in this file — confirm before removing.
constexpr float EPSILON = 1e-6;
// Option key read by CalcuOpUtil::GetCubeMathType() through c10_npu::option::GetOption.
static const string CUBE_MATH_TYPE = "CUBE_MATH_TYPE";
// Compile-time guard: every at::ScalarType enumerator must be non-negative, because the enum
// value is used directly as an index into kATenScalarTypeToAclDataTypeTable.
#define ENUM_PAIR_FUNC(_1, n) static_assert(static_cast<int64_t>(at::ScalarType::n) >= 0, #n " is negative");
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(ENUM_PAIR_FUNC)
#undef ENUM_PAIR_FUNC
// X-macro listing (at::ScalarType, aclDataType) pairs. The table below is filled positionally from
// this list, so the entries must stay in at::ScalarType enum order; ACL_DT_UNDEFINED marks ATen
// types that have no ACL counterpart. The static_asserts that follow the table check that each
// pair landed at its ScalarType's index.
#define AT_ALL_SCALAR_TYPE_AND_ACL_DATATYPE_PAIR(_) \
_(at::ScalarType::Byte, ACL_UINT8) \
_(at::ScalarType::Char, ACL_INT8) \
_(at::ScalarType::Short, ACL_INT16) \
_(at::ScalarType::Int, ACL_INT32) \
_(at::ScalarType::Long, ACL_INT64) \
_(at::ScalarType::Half, ACL_FLOAT16) \
_(at::ScalarType::Float, ACL_FLOAT) \
_(at::ScalarType::Double, ACL_DOUBLE) \
_(at::ScalarType::ComplexHalf, ACL_COMPLEX32) \
_(at::ScalarType::ComplexFloat, ACL_COMPLEX64) \
_(at::ScalarType::ComplexDouble, ACL_COMPLEX128) \
_(at::ScalarType::Bool, ACL_BOOL) \
_(at::ScalarType::QInt8, ACL_DT_UNDEFINED) \
_(at::ScalarType::QUInt8, ACL_DT_UNDEFINED) \
_(at::ScalarType::QInt32, ACL_DT_UNDEFINED) \
_(at::ScalarType::BFloat16, ACL_BF16) \
_(at::ScalarType::QUInt4x2, ACL_DT_UNDEFINED) \
_(at::ScalarType::QUInt2x4, ACL_DT_UNDEFINED) \
_(at::ScalarType::Bits1x8, ACL_DT_UNDEFINED) \
_(at::ScalarType::Bits2x4, ACL_DT_UNDEFINED) \
_(at::ScalarType::Bits4x2, ACL_DT_UNDEFINED) \
_(at::ScalarType::Bits8, ACL_DT_UNDEFINED) \
_(at::ScalarType::Bits16, ACL_DT_UNDEFINED) \
_(at::ScalarType::Float8_e5m2, ACL_FLOAT8_E5M2) \
_(at::ScalarType::Float8_e4m3fn, ACL_FLOAT8_E4M3FN) \
_(at::ScalarType::Float8_e5m2fnuz, ACL_DT_UNDEFINED) \
_(at::ScalarType::Float8_e4m3fnuz, ACL_DT_UNDEFINED) \
_(at::ScalarType::UInt16, ACL_UINT16) \
_(at::ScalarType::UInt32, ACL_UINT32) \
_(at::ScalarType::UInt64, ACL_UINT64) \
_(at::ScalarType::UInt1, ACL_DT_UNDEFINED) \
_(at::ScalarType::UInt2, ACL_DT_UNDEFINED) \
_(at::ScalarType::UInt3, ACL_DT_UNDEFINED) \
_(at::ScalarType::UInt4, ACL_DT_UNDEFINED) \
_(at::ScalarType::UInt5, ACL_DT_UNDEFINED) \
_(at::ScalarType::UInt6, ACL_DT_UNDEFINED) \
_(at::ScalarType::UInt7, ACL_DT_UNDEFINED) \
_(at::ScalarType::Int1, ACL_DT_UNDEFINED) \
_(at::ScalarType::Int2, ACL_DT_UNDEFINED) \
_(at::ScalarType::Int3, ACL_DT_UNDEFINED) \
_(at::ScalarType::Int4, ACL_DT_UNDEFINED) \
_(at::ScalarType::Int5, ACL_DT_UNDEFINED) \
_(at::ScalarType::Int6, ACL_DT_UNDEFINED) \
_(at::ScalarType::Int7, ACL_DT_UNDEFINED) \
_(at::ScalarType::Float8_e8m0fnu, ACL_FLOAT8_E8M0) \
_(at::ScalarType::Undefined, ACL_DT_UNDEFINED) \
_(at::ScalarType::NumOptions, ACL_DT_UNDEFINED)
// Dense lookup table indexed by static_cast<int64_t>(at::ScalarType): one slot per enumerator,
// including NumOptions itself (hence the "+ 1").
constexpr aclDataType kATenScalarTypeToAclDataTypeTable[static_cast<int64_t>(at::ScalarType::NumOptions) + 1] = {
// Expand each pair to just its aclDataType, producing the initializer list in enum order.
#define DEFINE_ENUM(_1, n) n,
AT_ALL_SCALAR_TYPE_AND_ACL_DATATYPE_PAIR(DEFINE_ENUM)
#undef DEFINE_ENUM
};
// Compile-time guard: for every (ScalarType, aclDataType) pair in
// AT_ALL_SCALAR_TYPE_AND_ACL_DATATYPE_PAIR, the slot of kATenScalarTypeToAclDataTypeTable indexed
// by that ScalarType must hold that aclDataType, i.e. the pair list must stay in enum order.
#define ENUM_PAIR_FUNC(at_dtype, acl_dtype) \
    static_assert(kATenScalarTypeToAclDataTypeTable[static_cast<int64_t>(at_dtype)] == (acl_dtype), \
        #at_dtype " and " #acl_dtype " is not match any more, please check " \
        "AT_ALL_SCALAR_TYPE_AND_ACL_DATATYPE_PAIR and modify it");
AT_ALL_SCALAR_TYPE_AND_ACL_DATATYPE_PAIR(ENUM_PAIR_FUNC)
// Undefine the helper defined just above so it does not leak into the rest of the translation
// unit (DEFINE_ENUM was already undefined right after the table initializer).
#undef ENUM_PAIR_FUNC
// Dtype-name overrides consulted by CalcuOpUtil::ConvertToAclDataType(data_type, realDataType):
// when realDataType is non-empty, the ACL type is taken from this map by name instead of being
// derived from data_type.
static std::map<const std::string, const aclDataType> STRING_SCALAR_TYPE_TO_ACL_TYPE_MAP = {
    {"string", ACL_STRING},
    {"uint16", ACL_UINT16},
    {"uint64", ACL_UINT64},
    {"uint8", ACL_UINT8},
};
// Reverse mapping (aclDataType -> at::ScalarType) used by CalcuOpUtil::ConvertToScalarType.
// ACL types with no ATen equivalent (ACL_DT_UNDEFINED, ACL_STRING, ACL_INT4, ACL_UINT1) map to
// at::ScalarType::Undefined. The HiFloat8 / FP6 / FP4 formats map to at::ScalarType::Byte —
// NOTE(review): presumably so their raw bits can travel in a uint8 tensor; confirm with callers.
static std::unordered_map<const aclDataType, const at::ScalarType>
ACL_TYPE_TO_SCALAR_TYPE_MAP = {{ACL_DT_UNDEFINED, at::ScalarType::Undefined},
{ACL_FLOAT, at::ScalarType::Float},
{ACL_FLOAT16, at::ScalarType::Half},
{ACL_INT8, at::ScalarType::Char},
{ACL_INT32, at::ScalarType::Int},
{ACL_UINT8, at::ScalarType::Byte},
{ACL_INT16, at::ScalarType::Short},
{ACL_UINT16, at::ScalarType::UInt16},
{ACL_UINT32, at::ScalarType::UInt32},
{ACL_INT64, at::ScalarType::Long},
{ACL_UINT64, at::ScalarType::UInt64},
{ACL_DOUBLE, at::ScalarType::Double},
{ACL_BOOL, at::ScalarType::Bool},
{ACL_STRING, at::ScalarType::Undefined},
{ACL_COMPLEX64, at::ScalarType::ComplexFloat},
{ACL_COMPLEX128, at::ScalarType::ComplexDouble},
{ACL_BF16, at::ScalarType::BFloat16},
{ACL_INT4, at::ScalarType::Undefined},
{ACL_UINT1, at::ScalarType::Undefined},
{ACL_COMPLEX32, at::ScalarType::ComplexHalf},
{ACL_HIFLOAT8, at::ScalarType::Byte},
{ACL_FLOAT8_E5M2, at::ScalarType::Float8_e5m2},
{ACL_FLOAT8_E4M3FN, at::ScalarType::Float8_e4m3fn},
{ACL_FLOAT8_E8M0, at::ScalarType::Float8_e8m0fnu},
{ACL_FLOAT6_E3M2, at::ScalarType::Byte},
{ACL_FLOAT6_E2M3, at::ScalarType::Byte},
{ACL_FLOAT4_E2M1, at::ScalarType::Byte},
{ACL_FLOAT4_E1M2, at::ScalarType::Byte}};
// Forwards every argument to aclrtMemcpyAsync and hands its status straight back.
// NOTE(review): despite the name, no parameter validation happens here, and this helper appears
// to be unused in this file.
aclError AclrtMemcpyAsyncParamCheck(
    void *dst, size_t destMax, const void *src, size_t count, aclrtMemcpyKind kind, aclrtStream stream)
{
    return aclrtMemcpyAsync(dst, destMax, src, count, kind, stream);
}
// Forwards every argument to the synchronous aclrtMemcpy and hands its status straight back.
// Used by the CalcuOpUtil::AclrtMemcpyWithModeSwitch overloads.
// NOTE(review): despite the name, no parameter validation happens here.
aclError AclrtMemcpyParamCheck(void *dst, size_t destMax, const void *src, size_t count, aclrtMemcpyKind kind)
{
    return aclrtMemcpy(dst, destMax, src, count, kind);
}
}
namespace at_npu {
namespace native {
// Maps an ATen scalar type to its ACL data type through kATenScalarTypeToAclDataTypeTable.
// Raises an ErrCode::VALUE error when the enum value falls outside the table, and an
// ErrCode::NOT_SUPPORT error when the table entry is ACL_DT_UNDEFINED.
aclDataType CalcuOpUtil::ConvertToAclDataType(const at::ScalarType &data_type)
{
    // The table covers indices [0, NumOptions] inclusive.
    const int64_t last_index = static_cast<int64_t>(at::ScalarType::NumOptions);
    const int64_t dtype_index = static_cast<int64_t>(data_type);
    const bool in_table = (dtype_index >= 0) && (dtype_index <= last_index);
    TORCH_CHECK(in_table,
        "data_type enum value (", dtype_index, ") is out of range: [0, ", last_index, "]",
        OPS_ERROR(ErrCode::VALUE))

    const aclDataType acl_dtype = kATenScalarTypeToAclDataTypeTable[dtype_index];
    TORCH_CHECK(acl_dtype != ACL_DT_UNDEFINED,
        std::string(c10::toString(data_type)) + " has not been supported",
        OPS_ERROR(ErrCode::NOT_SUPPORT))
    return acl_dtype;
}
// Maps an ATen scalar type to its ACL data type, optionally overridden by a dtype name.
// data_type is always validated exactly as in the single-argument overload: a VALUE error if it
// is out of table range, and a NOT_SUPPORT error if it has no ACL counterpart.
// If realDataType is non-empty, the result is looked up by name in
// STRING_SCALAR_TYPE_TO_ACL_TYPE_MAP instead. An unknown name raises an ErrCode::NOT_SUPPORT error.
aclDataType CalcuOpUtil::ConvertToAclDataType(const at::ScalarType &data_type, const std::string &realDataType)
{
    // Reuse the single-argument overload so both overloads validate data_type identically.
    const aclDataType acl_dtype = CalcuOpUtil::ConvertToAclDataType(data_type);
    if (realDataType.empty()) {
        return acl_dtype;
    }
    // Use find() instead of operator[]. For an unknown name, operator[] would insert a
    // value-initialized aclDataType (enumerator value 0) into the shared static map and silently
    // return it.
    const auto iter = STRING_SCALAR_TYPE_TO_ACL_TYPE_MAP.find(realDataType);
    TORCH_CHECK(iter != STRING_SCALAR_TYPE_TO_ACL_TYPE_MAP.end(),
        "realDataType: " + realDataType + " has not been supported",
        OPS_ERROR(ErrCode::NOT_SUPPORT))
    return iter->second;
}
// Reads the element at tensor.data_ptr() into a c10::Scalar.
// Supported dtypes: Double, Long, Float, Int, Half. Any other dtype logs an error and passes
// ACL_ERROR_UNSUPPORTED_DATA_TYPE to NPU_CHECK_ERROR.
// NOTE(review): the value is dereferenced on the host, so the tensor presumably has to be a
// non-empty CPU tensor — confirm with callers.
c10::Scalar CalcuOpUtil::ConvertTensorToScalar(const at::Tensor &tensor)
{
    switch (tensor.scalar_type()) {
        case at::ScalarType::Double:
            return c10::Scalar(*static_cast<double *>(tensor.data_ptr()));
        case at::ScalarType::Long:
            return c10::Scalar(*static_cast<int64_t *>(tensor.data_ptr()));
        case at::ScalarType::Float:
            return c10::Scalar(*static_cast<float *>(tensor.data_ptr()));
        case at::ScalarType::Int:
            return c10::Scalar(*static_cast<int *>(tensor.data_ptr()));
        case at::ScalarType::Half:
            return c10::Scalar(*static_cast<c10::Half *>(tensor.data_ptr()));
        default:
            break;
    }
    ASCEND_LOGE("unsupport scalar type! ");
    NPU_CHECK_ERROR(ACL_ERROR_UNSUPPORTED_DATA_TYPE);
    // Reached only if NPU_CHECK_ERROR returns; yields a default-constructed Scalar.
    return c10::Scalar();
}
// Wraps a host scalar in a tensor, converts it to scalar_data_type, and uploads it to the
// current NPU device via CopyTensorHostToDevice.
at::Tensor CalcuOpUtil::CopyScalarToDevice(const c10::Scalar &cpu_scalar, at::ScalarType scalar_data_type)
{
    at::Tensor host_tensor = scalar_to_tensor(cpu_scalar).to(scalar_data_type);
    return CalcuOpUtil::CopyTensorHostToDevice(host_tensor);
}
// Copies a host tensor to the current NPU (PrivateUse1) device.
// The tensor is first staged in pinned memory, then copied with non_blocking = true and
// copy = true, keeping its dtype. Errors from c10_npu::GetDevice go through NPU_CHECK_ERROR.
at::Tensor CalcuOpUtil::CopyTensorHostToDevice(const at::Tensor &cpu_tensor)
{
    at::Tensor pinned = cpu_tensor.pin_memory();
    int device_id = 0;
    NPU_CHECK_ERROR(c10_npu::GetDevice(&device_id));
    const c10::Device target(c10::DeviceType::PrivateUse1, device_id);
    constexpr bool non_blocking = true;
    constexpr bool force_copy = true;
    return pinned.to(target, pinned.scalar_type(), non_blocking, force_copy);
}
// Enqueues an asynchronous copy between two tensor regions via c10_npu::queue::LaunchAsyncCopyTask.
// Each pair is (tensor, element offset). The offset is scaled by that tensor's itemsize() to get
// a byte address. A failing launch is reported through NPU_CHECK_ERROR; otherwise the function
// returns NPU_STATUS_SUCCESS.
NPUStatus CalcuOpUtil::AclrtMemcpyAsync(const std::pair<at::Tensor, int64_t> &dst,
    size_t dst_size,
    const std::pair<at::Tensor, int64_t> &src,
    size_t src_size,
    aclrtMemcpyKind kind)
{
    const at::Tensor &dst_tensor = dst.first;
    const at::Tensor &src_tensor = src.first;
    void *dst_ptr = static_cast<uint8_t *>(dst_tensor.data_ptr()) + dst.second * dst_tensor.itemsize();
    void *src_ptr = static_cast<uint8_t *>(src_tensor.data_ptr()) + src.second * src_tensor.itemsize();
    NPU_CHECK_ERROR(c10_npu::queue::LaunchAsyncCopyTask(dst_ptr, dst_size, src_ptr, src_size, kind));
    return NPU_STATUS_SUCCESS;
}
// Synchronous copy between two storages. Each pair is (storage, byte offset from the storage's
// data() address). Returns the aclrtMemcpy status.
aclError CalcuOpUtil::AclrtMemcpyWithModeSwitch(const StorageAndOffsetMemSizePair &dst,
    size_t dstMax,
    const StorageAndOffsetMemSizePair &src,
    size_t count,
    aclrtMemcpyKind kind)
{
    auto *dst_base = static_cast<uint8_t *>(const_cast<void *>(dst.first->data()));
    auto *src_base = static_cast<uint8_t *>(const_cast<void *>(src.first->data()));
    void *dst_addr = dst_base + dst.second;
    void *src_addr = src_base + src.second;
    return AclrtMemcpyParamCheck(dst_addr, dstMax, src_addr, count, kind);
}
// Synchronous copy from a raw source address into a storage region. dst is
// (storage, byte offset from the storage's data() address). Returns the aclrtMemcpy status.
aclError CalcuOpUtil::AclrtMemcpyWithModeSwitch(
    const StorageAndOffsetMemSizePair &dst, size_t dstMax, const void *src, size_t count, aclrtMemcpyKind kind)
{
    auto *dst_base = static_cast<uint8_t *>(const_cast<void *>(dst.first->data()));
    void *dst_addr = dst_base + dst.second;
    return AclrtMemcpyParamCheck(dst_addr, dstMax, src, count, kind);
}
// Synchronous copy from a storage region to a raw destination address. src is
// (storage, byte offset from the storage's data() address). Returns the aclrtMemcpy status.
aclError CalcuOpUtil::AclrtMemcpyWithModeSwitch(
    void *dst, size_t dstMax, const StorageAndOffsetMemSizePair &src, size_t count, aclrtMemcpyKind kind)
{
    auto *src_base = static_cast<uint8_t *>(const_cast<void *>(src.first->data()));
    void *src_addr = src_base + src.second;
    return AclrtMemcpyParamCheck(dst, dstMax, src_addr, count, kind);
}
// Enqueues an async copy between the data pointers of two tensors via
// c10_npu::queue::LaunchAsyncCopyTask and returns its status.
aclError CalcuOpUtil::LaunchAsyncCopyTaskWithModeSwitch(
    const at::Tensor &dst, size_t dstMax, const at::Tensor &src, size_t count, aclrtMemcpyKind kind)
{
    void *dst_addr = dst.data_ptr();
    void *src_addr = src.data_ptr();
    return c10_npu::queue::LaunchAsyncCopyTask(dst_addr, dstMax, src_addr, count, kind);
}
// Enqueues an async copy from a raw source address into the start of a storage via
// c10_npu::queue::LaunchAsyncCopyTask and returns its status.
aclError CalcuOpUtil::LaunchAsyncCopyTaskWithModeSwitch(
    const c10::StorageImpl &dst, size_t dstMax, void *src, size_t count, aclrtMemcpyKind kind)
{
    void *dst_addr = const_cast<void *>(dst.data());
    return c10_npu::queue::LaunchAsyncCopyTask(dst_addr, dstMax, src, count, kind);
}
// Returns the ACL storage format of an NPU tensor:
//  - the format stored in its NPU storage descriptor (npu_desc_.npu_format_) when
//    NpuUtils::check_match or NpuUtils::check_5d_5d_match accepts the tensor;
//  - otherwise ACL_FORMAT_ND when the tensor's data pointer is null;
//  - otherwise whatever InferFormat::GuessFormatWhenContiguous infers.
// Raises an ErrCode::TYPE error when the tensor is not on a PrivateUse1 (NPU) device.
int64_t CalcuOpUtil::GetTensorNpuFormat(const at::Tensor &tensor)
{
    TORCH_CHECK(tensor.device().type() == c10::DeviceType::PrivateUse1,
        "Expected all tensors to be on the same device. "
        "Expected NPU tensor, please check whether the input tensor "
        "device is correct.",
        OPS_ERROR(ErrCode::TYPE));

    const bool desc_matches = NpuUtils::check_match(&tensor) || NpuUtils::check_5d_5d_match(tensor);
    if (desc_matches) {
        const torch_npu::NPUStorageDesc &desc = torch_npu::NPUBridge::GetNpuStorageImpl(tensor)->npu_desc_;
        return desc.npu_format_;
    }
    if (tensor.data_ptr() == nullptr) {
        return ACL_FORMAT_ND;
    }
    return InferFormat::GuessFormatWhenContiguous(tensor);
}
// For every defined output, runs assert_no_internal_overlap on it and assert_no_partial_overlap
// against each input. Undefined outputs are skipped.
void CalcuOpUtil::CheckMemoryOverLaps(c10::ArrayRef<at::Tensor> inputs, c10::ArrayRef<at::Tensor> outputs)
{
    for (const at::Tensor &output : outputs) {
        if (output.defined()) {
            assert_no_internal_overlap(output);
            for (const at::Tensor &input : inputs) {
                assert_no_partial_overlap(output, input);
            }
        }
    }
}
// True when the tensor's impl is flagged as a wrapped number and the tensor is not an NPU tensor.
bool CalcuOpUtil::IsScalarWrappedToTensor(const at::Tensor &tensor)
{
    if (!tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
        return false;
    }
    return !torch_npu::utils::is_npu(tensor);
}
// Returns the scalar as a float: toFloat() for floating-point scalars, otherwise toInt()
// converted to float.
float CalcuOpUtil::GetScalarFloatValue(const c10::Scalar &scalar)
{
    return scalar.isFloatingPoint() ? scalar.toFloat() : static_cast<float>(scalar.toInt());
}
// Copies the elements of a borrowed IntArrayRef, in order, into an owning SmallVector.
c10::SmallVector<int64_t, SHAPE_SIZE> CalcuOpUtil::ConvertIntArrayRefToSmallVector(c10::IntArrayRef intArray)
{
    return c10::SmallVector<int64_t, SHAPE_SIZE>(intArray.begin(), intArray.end());
}
// Cube math-type codes that CalcuOpUtil::GetCubeMathType returns (as int8_t).
// NOTE(review): the numeric values presumably must match the codes expected by the ACL/aclnn
// operator interfaces — confirm against the CANN headers before renumbering.
using aclCubeMathType = enum : int8_t {
KEEP_DTYPE = 0,
ALLOW_FP32_DOWN_PRECISION = 1,
USE_FP16 = 2,
USE_HF32 = 3,
FORCE_GRP_ACC_FOR_FP32 = 4,
// Same value as FORCE_GRP_ACC_FOR_FP32: two names for code 4.
USE_FP32_ADD = 4,
};
// Lookup tables for CalcuOpUtil::GetCubeMathType. Both are only ever read (find/end), so they
// are declared const to rule out accidental mutation of this shared state.

// Key is the 2-bit code built by GetCubeMathType(bool allowHf32):
// bit 1 = allowHf32, bit 0 = env::IsAllowFP32ToFP16().
static const std::unordered_map<uint8_t, aclCubeMathType> ACL_CUBE_MATH_TYPE_MAP = {
    {0b00, KEEP_DTYPE}, {0b01, USE_FP16}, {0b10, USE_HF32}, {0b11, ALLOW_FP32_DOWN_PRECISION}};

// Key is the integer value of the CUBE_MATH_TYPE option read by GetCubeMathType(); values missing
// from this table make that function return -1.
static const std::unordered_map<uint8_t, aclCubeMathType> ACL_CUBE_MATH_TYPE_MAP_PASSTHROUGH = {
    {0b00, KEEP_DTYPE},
    {0b01, ALLOW_FP32_DOWN_PRECISION},
    {0b10, USE_FP16},
    {0b11, USE_HF32},
    {0b100, USE_FP32_ADD}
};
// Returns the cube math type requested through the CUBE_MATH_TYPE option. Returns -1 when the
// option is unset or empty, when its value is outside [0, 255], or when the value is not a key
// of ACL_CUBE_MATH_TYPE_MAP_PASSTHROUGH.
// A non-numeric or overflowing option value propagates std::stoi's exception
// (std::invalid_argument / std::out_of_range).
int8_t CalcuOpUtil::GetCubeMathType()
{
    auto option_key = c10_npu::option::GetOption(CUBE_MATH_TYPE);
    if (!option_key.has_value() || option_key.value().empty()) {
        return -1;
    }
    // Parse into a full-width int before narrowing. Casting std::stoi's result straight to
    // uint8_t wraps out-of-range settings onto valid codes (e.g. "256" -> KEEP_DTYPE,
    // "260" -> USE_FP32_ADD).
    const int requested = std::stoi(option_key.value());
    if (requested < 0 || requested > static_cast<int>(std::numeric_limits<uint8_t>::max())) {
        return -1;
    }
    auto iter = ACL_CUBE_MATH_TYPE_MAP_PASSTHROUGH.find(static_cast<uint8_t>(requested));
    if (iter == ACL_CUBE_MATH_TYPE_MAP_PASSTHROUGH.end()) {
        return -1;
    }
    return iter->second;
}
// Picks a cube math type from two switches: the allowHf32 argument and
// env::IsAllowFP32ToFP16(). They form a 2-bit key (bit 1 = allowHf32, bit 0 = FP32->FP16) that
// is looked up in ACL_CUBE_MATH_TYPE_MAP. ALLOW_FP32_DOWN_PRECISION is returned if the key is
// absent.
int8_t CalcuOpUtil::GetCubeMathType(bool allowHf32)
{
    const bool allowFp32ToFp16 = native::env::IsAllowFP32ToFP16();
    uint8_t key = allowFp32ToFp16 ? 0b01 : 0b00;
    if (allowHf32) {
        key |= 0b10;
    }
    const auto found = ACL_CUBE_MATH_TYPE_MAP.find(key);
    return (found != ACL_CUBE_MATH_TYPE_MAP.end()) ? found->second : ALLOW_FP32_DOWN_PRECISION;
}
// Maps an ACL data type back to an ATen scalar type via ACL_TYPE_TO_SCALAR_TYPE_MAP.
// Raises an ErrCode::NOT_SUPPORT error for ACL types missing from the map.
at::ScalarType CalcuOpUtil::ConvertToScalarType(const aclDataType data_type)
{
    const auto found = ACL_TYPE_TO_SCALAR_TYPE_MAP.find(data_type);
    TORCH_CHECK(found != ACL_TYPE_TO_SCALAR_TYPE_MAP.end(),
        std::string("aclDataType:") + std::to_string(data_type) + " has not been supported",
        OPS_ERROR(ErrCode::NOT_SUPPORT))
    return found->second;
}
}
}