#include <ATen/NamedTensorUtils.h>
#include "torch_npu/csrc/aten/CustomFunctions.h"
#include "torch_npu/csrc/framework/utils/OpPreparation.h"
#include "torch_npu/csrc/aten/mirror/NPUTypeProperties.h"
#include "torch_npu/csrc/core/npu/GetCANNInfo.h"
#include "torch_npu/csrc/custom_dtype/Init.h"
#include "torch_npu/csrc/core/npu/NpuVariables.h"
#include "op_plugin/utils/OpUtils.h"
namespace op_plugin {
namespace utils {
// Upper bound (inclusive) accepted for each entry of a group-size list in check_and_get_group_size.
static const uint64_t GROUP_MAX = 65535UL;
// Number of entries a non-empty group-size list must contain.
static const size_t GROUP_DIM = 3;
// Left shift applied to the first group entry when the three entries are packed into one integer.
static const size_t OFFSET_32_BITS = 32;
// Left shift applied to the second group entry when the three entries are packed into one integer.
static const size_t OFFSET_16_BITS = 16;
// Translates an at::Reduction code into its lowercase name.
// None maps to "none", Mean maps to "mean"; any other value is reported as "sum".
std::string get_reduction_str(int64_t reduction)
{
    if (reduction == at::Reduction::None) {
        return "none";
    }
    if (reduction == at::Reduction::Mean) {
        return "mean";
    }
    return "sum";
}
// Renders a vector of integers as "[a, b, c]"; an empty vector yields "[]".
std::string get_vector_str(const std::vector<int64_t> &vec)
{
    std::string text = "[";
    // The separator is empty before the first element and ", " afterwards,
    // so no trailing separator is ever emitted.
    const char *separator = "";
    for (const int64_t value : vec) {
        text += separator;
        text += std::to_string(value);
        separator = ", ";
    }
    text += "]";
    return text;
}
// Converts a rotary mode name into its integer code:
//   "half" -> 0, "interleave" -> 1, "quarter" -> 2, "interleave-half" -> 3.
// Any other name raises an error through TORCH_CHECK. Previously control fell
// off the end of the function without returning a value, which is undefined behaviour.
int64_t get_rotary_mode(c10::string_view mode)
{
    if (mode == "half") {
        return 0;
    }
    if (mode == "interleave") {
        return 1;
    }
    if (mode == "quarter") {
        return 2;
    }
    if (mode == "interleave-half") {
        return 3;
    }
    TORCH_CHECK(false, "rotary mode should be half/interleave/quarter/interleave-half, but got ",
                std::string(mode.data(), mode.size()), OPS_ERROR(ErrCode::PARAM));
}
// Wraps a possibly negative dimension index.
// Non-negative `dim` is returned unchanged. A negative `dim` is shifted by
// `dim_post_expr`, where a non-positive `dim_post_expr` is treated as 1.
int64_t make_warp_dim(int64_t dim, int64_t dim_post_expr)
{
    if (dim >= 0) {
        return dim;
    }
    const int64_t extent = dim_post_expr > 0 ? dim_post_expr : 1;
    return dim + extent;
}
// Reports whether a 2-D or 3-D tensor looks like a view whose last two
// dimensions are swapped relative to its base sizes.
// Returns false for any other rank, or when the base-size rank differs from
// the tensor rank.
bool is_transpose_last_two_dims(const at::Tensor &tensor)
{
    if (tensor.dim() < 2 || tensor.dim() > 3) {
        return false;
    }
    int64_t numel = at_npu::native::NPUNativeFunctions::get_storage_size(tensor);
    int64_t dim1 = tensor.dim() - 1;
    int64_t dim2 = tensor.dim() - 2;
    c10::SmallVector<int64_t, 5> tensor_base_size = at_npu::native::OpPreparation::get_tensor_desc_base_sizes(tensor);
    // Check the rank of the base sizes BEFORE indexing into them: dim1/dim2 are
    // derived from tensor.dim() and would be out of range for a shorter vector.
    if (tensor_base_size.size() != static_cast<uint64_t>(tensor.dim())) {
        return false;
    }
    return tensor.stride(dim2) == 1 && tensor.stride(dim1) == tensor.size(dim2) &&
           tensor.size(dim1) == tensor_base_size[dim2] &&
           tensor.size(dim2) == tensor_base_size[dim1] &&
           tensor.numel() == numel;
}
// True when the NPU storage descriptor of `mat2` reports ACL_FORMAT_FRACTAL_NZ.
bool is_nz_format(const at::Tensor &mat2)
{
    const auto storage_format = torch_npu::NPUBridge::GetNpuStorageImpl(mat2)->npu_desc_.npu_format_;
    return storage_format == ACL_FORMAT_FRACTAL_NZ;
}
// True only when FormatHelper::IsOpInputBaseFormat holds for both tensors.
// `mat2` is not inspected when `self` already fails the check.
bool is_two_tensor_base_format(const at::Tensor &self, const at::Tensor &mat2)
{
    if (!at_npu::native::FormatHelper::IsOpInputBaseFormat(self)) {
        return false;
    }
    return at_npu::native::FormatHelper::IsOpInputBaseFormat(mat2);
}
// Detects the combination "self is not NZ, mat2 is NZ" for supported ranks.
// Accepted ranks: both 2-D, or both 3-D when the SoC version compares greater
// than SocVersion::Ascend910_9362.
bool is_nd_nz_format(const at::Tensor &self, const at::Tensor &mat2)
{
    constexpr int k2D = 2;
    constexpr int k3D = 3;
    const auto self_rank = self.dim();
    const auto mat2_rank = mat2.dim();
    if (!is_nz_format(mat2) || is_nz_format(self)) {
        return false;
    }
    // Both accepted cases require equal ranks.
    if (self_rank != mat2_rank) {
        return false;
    }
    if (self_rank == k2D) {
        return true;
    }
    return self_rank == k3D && c10_npu::GetSocVersion() > c10_npu::SocVersion::Ascend910_9362;
}
// Decides from the last two dimensions of `self` and `mat2` whether both
// inputs satisfy the inner-axis constraints checked below.
// Returns false when either tensor has fewer than two dimensions.
bool is_nd_to_nz_on_fly(const at::Tensor &self, const at::Tensor &mat2)
{
    static constexpr int64_t kInnerAxisMinLimit = 128;
    static constexpr int64_t kInnerAxisMaxLimit = 65535;
    if (self.dim() < 2 || mat2.dim() < 2) {
        return false;
    }

    const int64_t self_last = self.size(self.dim() - 1);
    const int64_t self_second_last = self.size(self.dim() - 2);
    const int64_t mat2_last = mat2.size(mat2.dim() - 1);
    const int64_t mat2_second_last = mat2.size(mat2.dim() - 2);

    // For a tensor whose last two dims are transposed, the inner axis is the
    // second-to-last dimension instead of the last one.
    const bool self_transposed = is_transpose_last_two_dims(self);
    const bool mat2_transposed = is_transpose_last_two_dims(mat2);
    const int64_t self_inner_axis = self_transposed ? self_second_last : self_last;
    const int64_t self_outer_axis = self_transposed ? self_last : self_second_last;
    const int64_t mat2_inner_axis = mat2_transposed ? mat2_second_last : mat2_last;
    const int64_t mat2_outer_axis = mat2_transposed ? mat2_last : mat2_second_last;

    // Small matrices (inner * outer within the max limit for both inputs) are always accepted.
    if (self_inner_axis * self_outer_axis <= kInnerAxisMaxLimit &&
        mat2_inner_axis * mat2_outer_axis <= kInnerAxisMaxLimit) {
        return true;
    }

    // An inner axis qualifies when it lies in [min, max], or is below min and a multiple of 16.
    const auto inner_axis_ok = [](int64_t inner_axis) -> bool {
        if (inner_axis >= kInnerAxisMinLimit) {
            return inner_axis <= kInnerAxisMaxLimit;
        }
        return (static_cast<uint64_t>(inner_axis) & 0xF) == 0;
    };
    return inner_axis_ok(self_inner_axis) && inner_axis_ok(mat2_inner_axis);
}
// True when the scalar equals one: exact comparison for integral scalars
// (bool excluded), tolerance of 1e-6 for floating-point scalars.
// Every other scalar kind yields false.
bool is_scalar_one(const c10::Scalar &scalar)
{
    if (scalar.isIntegral(false)) {
        return scalar.toInt() == 1;
    }
    if (!scalar.isFloatingPoint()) {
        return false;
    }
    return fabs(scalar.toFloat() - 1.0) < 1e-6;
}
// Extracts the scalar as a float: floating-point scalars via toFloat(),
// everything else via toInt() followed by a cast to float.
float get_scalar_float_value(const c10::Scalar &scalar)
{
    return scalar.isFloatingPoint() ? scalar.toFloat() : static_cast<float>(scalar.toInt());
}
// Copies the elements of an IntArrayRef, in order, into a SmallVector.
c10::SmallVector<int64_t, N> convert_array_to_vector(c10::IntArrayRef intArray)
{
    c10::SmallVector<int64_t, N> copied;
    for (const int64_t value : intArray) {
        copied.emplace_back(value);
    }
    return copied;
}
// Builds the list [0, 1, ..., self.dim() - 1].
c10::SmallVector<int64_t, N> get_dimlist_for_tensor(const at::Tensor &self)
{
    const int64_t rank = self.dim();
    c10::SmallVector<int64_t, N> dims = {};
    for (int64_t axis = 0; axis < rank; ++axis) {
        dims.emplace_back(axis);
    }
    return dims;
}
// Computes how much must be added so that (s_size + 2 * p_size - k_size)
// becomes a multiple of `stride`; returns 0 when it already is.
// Raises through TORCH_CHECK when `stride` is zero.
int64_t complete_pad(int64_t s_size, int64_t p_size, int64_t k_size, int64_t stride)
{
    const int64_t padded_size = s_size + p_size * 2;
    const int64_t remaining = padded_size - k_size;
    TORCH_CHECK(stride != 0, "CompletePad stride is zero!", OPS_ERROR(ErrCode::VALUE));
    const auto leftover = remaining % stride;
    return leftover == 0 ? 0 : stride - leftover;
}
// Returns scales[idx] when `scales` is provided, otherwise nullopt.
// Raises through TORCH_CHECK when idx is not a valid position in `scales`.
c10::optional<double> get_scale_value(c10::optional<c10::ArrayRef<double>> scales, int idx)
{
    if (!scales.has_value()) {
        return c10::nullopt;
    }
    // The index is compared as an unsigned value, so a negative idx is rejected
    // as well (it converts to a value larger than any valid size).
    const bool idx_in_range = static_cast<size_t>(idx) < scales->size();
    TORCH_CHECK(idx_in_range, "idx", idx, "is overrange scales->at(idx) ", scales->size(),
                OPS_ERROR(ErrCode::VALUE));
    return scales->at(idx);
}
// Result dtype for a division: the promoted type of both operands from
// at::native::result_type, replaced by float when that type is integral (bool included).
at::ScalarType get_divide_result_type(const at::Tensor& self, const at::Tensor& other)
{
    const at::ScalarType promoted = at::native::result_type(self, other);
    if (isIntegralType(promoted, true)) {
        return at::kFloat;
    }
    return promoted;
}
// Dtype used to carry out a division: the promoted type of both operands'
// scalar types, replaced by float when that type is integral (bool included) or double.
at::ScalarType get_divide_calculate_type(const at::Tensor &self, const at::Tensor &other)
{
    const at::ScalarType promoted = at_npu::native::result_type(self.scalar_type(), other.scalar_type());
    const bool use_float = isIntegralType(promoted, true) || promoted == at::kDouble;
    return use_float ? at::kFloat : promoted;
}
// Returns `self` converted to `calculate_type` (no dtype cast when it already
// has that dtype) and then passed through OpPreparation::CastBackToOriFormat.
at::Tensor get_cast_input(const at::Tensor& self, at::ScalarType calculate_type)
{
    const bool already_target_dtype = (self.dtype() == calculate_type);
    at::Tensor prepared =
        already_target_dtype ? self : at_npu::native::custom_ops::_npu_dtype_cast(self, calculate_type);
    // The explicit const-reference cast is kept so the same overload is selected as before.
    prepared = at_npu::native::OpPreparation::CastBackToOriFormat(static_cast<const at::Tensor&>(prepared));
    return prepared;
}
// Unifies the dimension names of all tensors in `tensor_list`.
// Returns an empty NameVector when no tensor carries names; otherwise starts
// from the first tensor's names and folds in each subsequent tensor with
// at::unify_from_right.
NameVector compute_names_npu(std::vector<at::Tensor> tensor_list)
{
    NameVector names;
    bool has_names = false;
    // Iterate by const reference: copying an at::Tensor per element only costs
    // a reference-count round trip and is never needed here.
    for (const auto &tensor : tensor_list) {
        if (tensor.has_names()) {
            has_names = true;
            break;
        }
    }
    if (!has_names) {
        return names;
    }
    for (const auto &tensor : tensor_list) {
        if (names.empty()) {
            names = tensor.names();
        } else {
            names = NameVector(at::unify_from_right(names, tensor.names()));
        }
    }
    return names;
}
// Scale factor used for resizing.
// A positive `scale` wins and yields 1 / scale. Otherwise the ratio
// input_size / output_size is used, or 0 when output_size is zero.
double compute_scale(int64_t input_size, int64_t output_size, double scale)
{
    if (scale > 0.0) {
        return 1.0 / scale;
    }
    if (output_size == 0) {
        return 0;
    }
    return static_cast<double>(input_size) / output_size;
}
// Validates dtypes for a foreach operator.
// 1. The tensor dtype must be supported for `tensorDtypeCategory`.
// 2. When neither `scalarDtype` nor `mapping` is given, that alone decides.
// 3. Supplying exactly one of the two is an error (TORCH_CHECK).
// 4. Otherwise the tensor/scalar combination is checked according to `inputType`;
//    an unknown `inputType` is an error (TORCH_CHECK).
bool check_dtype_foreach(at::ScalarType tensorDtype, ForeachTensorDtypeSupport tensorDtypeCategory, ForeachInputType inputType,
                         c10::optional<at::ScalarType> scalarDtype, c10::optional<ForeachMappingType> mapping)
{
    if (!check_foreach_tensor_dtype_spport(tensorDtype, tensorDtypeCategory)) {
        return false;
    }

    const bool has_scalar_dtype = scalarDtype.has_value();
    const bool has_mapping = mapping.has_value();
    if (!has_scalar_dtype && !has_mapping) {
        return true;
    }
    TORCH_CHECK(has_scalar_dtype && has_mapping, "Invalid scalarType Parm or ForeachMappingType Parm!",
                OPS_ERROR(ErrCode::PARAM));

    const at::ScalarType scalar_type = scalarDtype.value();
    const ForeachMappingType mapping_type = mapping.value();
    switch (inputType) {
        case ForeachInputType::TYPE_SCALAR:
            return check_mapping_between_tensor_and_scalar(tensorDtype, scalar_type, mapping_type);
        case ForeachInputType::TYPE_SCALARLIST:
            return check_mapping_between_tensor_and_scalar_list(tensorDtype, scalar_type, mapping_type);
        case ForeachInputType::TYPE_TENSOR:
            return true;
        default:
            TORCH_CHECK(false, "Invalid inputType Parm!", OPS_ERROR(ErrCode::PARAM));
    }
}
// Checks the tensor dtype against the set selected by `tensorDtypeCategory`:
//   BASE_DTYPE -> the base set, TO_INT32 -> the base set plus Int,
//   TO_INT -> the base-and-int set. Any other category is an error (TORCH_CHECK).
bool check_foreach_tensor_dtype_spport(at::ScalarType tensorDtype, ForeachTensorDtypeSupport tensorDtypeCategory)
{
    if (tensorDtypeCategory == ForeachTensorDtypeSupport::BASE_DTYPE) {
        return check_foreach_tensor_dtype_spport_base(tensorDtype);
    }
    if (tensorDtypeCategory == ForeachTensorDtypeSupport::TO_INT32) {
        return check_foreach_tensor_dtype_spport_base(tensorDtype) || (tensorDtype == at::ScalarType::Int);
    }
    if (tensorDtypeCategory == ForeachTensorDtypeSupport::TO_INT) {
        return check_foreach_tensor_dtype_spport_base_and_int(tensorDtype);
    }
    TORCH_CHECK(false, "Invalid ForeachTensorDtypeSupport Parm", OPS_ERROR(ErrCode::PARAM));
}
// Base dtype set for foreach operators: Half, Float and BFloat16.
bool check_foreach_tensor_dtype_spport_base(at::ScalarType tensorDtype)
{
    switch (tensorDtype) {
        case at::ScalarType::Half:
        case at::ScalarType::Float:
        case at::ScalarType::BFloat16:
            return true;
        default:
            return false;
    }
}
// Extended dtype set for foreach operators: Half, Float, BFloat16, Int, Char and Long.
bool check_foreach_tensor_dtype_spport_base_and_int(at::ScalarType tensorDtype)
{
    switch (tensorDtype) {
        case at::ScalarType::Half:
        case at::ScalarType::Float:
        case at::ScalarType::BFloat16:
        case at::ScalarType::Int:
        case at::ScalarType::Char:
        case at::ScalarType::Long:
            return true;
        default:
            return false;
    }
}
// A scalar dtype is accepted when it is an integral or a floating-point type.
bool check_foreach_scalar_dtype_spport(at::ScalarType scalarDtype)
{
    if (at::isIntegralType(scalarDtype)) {
        return true;
    }
    return at::isFloatingType(scalarDtype);
}
// Checks a tensor dtype against the dtype of a scalar list.
// Returns false for an unsupported scalar dtype. Only MAP_SCALARLIST_DEFAULT is
// a valid mapping (anything else is an error via TORCH_CHECK); under it the two
// dtypes must be both integral or both floating-point.
bool check_mapping_between_tensor_and_scalar_list(at::ScalarType tensorDtype, at::ScalarType scalarDtype, ForeachMappingType mapping)
{
    if (!check_foreach_scalar_dtype_spport(scalarDtype)) {
        return false;
    }
    TORCH_CHECK(mapping == ForeachMappingType::MAP_SCALARLIST_DEFAULT,
                "Invalid ForeachMappingType Parm Between Tensor And ScalarList", OPS_ERROR(ErrCode::PARAM));
    if (at::isIntegralType(scalarDtype) && at::isIntegralType(tensorDtype)) {
        return true;
    }
    return at::isFloatingType(scalarDtype) && at::isFloatingType(tensorDtype);
}
// Checks a tensor dtype against a single scalar dtype.
// Returns false for an unsupported scalar dtype.
//   MAP_SCALAR_DEFAULT: tensor must be non-integral and the scalar floating-point.
//   MAP_POW_SCALAR_AND_TENSOR: always accepted.
// Any other mapping is an error (TORCH_CHECK).
bool check_mapping_between_tensor_and_scalar(at::ScalarType tensorDtype, at::ScalarType scalarDtype, ForeachMappingType mapping)
{
    if (!check_foreach_scalar_dtype_spport(scalarDtype)) {
        return false;
    }
    if (mapping == ForeachMappingType::MAP_SCALAR_DEFAULT) {
        return !at::isIntegralType(tensorDtype) && at::isFloatingType(scalarDtype);
    }
    if (mapping == ForeachMappingType::MAP_POW_SCALAR_AND_TENSOR) {
        return true;
    }
    TORCH_CHECK(false, "Invalid ForeachMappingType Parm Between Tensor And Scalar!", OPS_ERROR(ErrCode::PARAM));
}
// Verifies that `weight`, and `bias` when it is defined, have the same tensor
// type as `input` (compared via TensorOptions::type_equal).
// A mismatch raises through TORCH_CHECK.
void check_input_same_type_as_parameters(
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& bias)
{
    const bool weight_matches = input.options().type_equal(weight.options());
    TORCH_CHECK(weight_matches,
                "Input type (", input.toString(), ") and weight type (", weight.toString(),
                ") should be the same" + OPS_ERROR(ErrCode::TYPE));

    // An undefined bias is skipped entirely.
    const bool bias_matches = !bias.defined() || (input.options().type_equal(bias.options()));
    TORCH_CHECK(bias_matches,
                "Input type (", input.toString(), ") and bias type (", bias.toString(),
                ") should be the same" + OPS_ERROR(ErrCode::TYPE));
}
// Bias-less overload: forwards to the three-argument version with an
// undefined tensor as the bias, so only the weight is compared.
void check_input_same_type_as_parameters(
    const at::Tensor& input,
    const at::Tensor& weight)
{
    const at::Tensor undefined_bias = at::Tensor();
    check_input_same_type_as_parameters(input, weight, undefined_bias);
}
// Reports IsGteCANNVersion("8.1.RC1", "CANN").
// The query is made once, on the first call, and the result is reused afterwards.
bool is_gte_cann_version_810rc1()
{
    static const bool kIsAtLeast = IsGteCANNVersion("8.1.RC1", "CANN");
    return kIsAtLeast;
}
// Reports IsGteCANNVersion("8.2.RC1", "CANN").
// The query is made once, on the first call, and the result is reused afterwards.
bool is_gte_cann_version_820rc1()
{
    static const bool kIsAtLeast = IsGteCANNVersion("8.2.RC1", "CANN");
    return kIsAtLeast;
}
// Reports IsGteCANNVersion("8.3.RC1", "CANN").
// The query is made once, on the first call, and the result is reused afterwards.
bool is_gte_cann_version_830rc1()
{
    static const bool kIsAtLeast = IsGteCANNVersion("8.3.RC1", "CANN");
    return kIsAtLeast;
}
// Reports IsGteCANNVersion("8.5.0", "CANN").
// The query is made once, on the first call, and the result is reused afterwards.
bool is_gte_cann_version_850()
{
    static const bool kIsAtLeast = IsGteCANNVersion("8.5.0", "CANN");
    return kIsAtLeast;
}
// Reports IsGteCANNVersion("8.5.1", "CANN").
// The query is made once, on the first call, and the result is reused afterwards.
bool is_gte_cann_version_851()
{
    static const bool kIsAtLeast = IsGteCANNVersion("8.5.1", "CANN");
    return kIsAtLeast;
}
// Reports IsGteCANNVersion("8.5.0.alpha003", "CANN").
// The query is made once, on the first call, and the result is reused afterwards.
bool is_gte_cann_version_850alpha003()
{
    static const bool kIsAtLeast = IsGteCANNVersion("8.5.0.alpha003", "CANN");
    return kIsAtLeast;
}
// Reports IsGteCANNVersion("9.0.0", "CANN").
// The query is made once, on the first call, and the result is reused afterwards.
bool is_gte_cann_version_900()
{
    static const bool kIsAtLeast = IsGteCANNVersion("9.0.0", "CANN");
    return kIsAtLeast;
}
// Returns the name of a dtype code. Codes recognised by c10_npu::IsCustomDType
// go through CustomDataTypeToString; all others are interpreted as at::ScalarType.
const std::string DTypeToString(int64_t input_type)
{
    if (c10_npu::IsCustomDType(input_type)) {
        return c10_npu::CustomDataTypeToString(input_type);
    }
    return c10::toString(static_cast<at::ScalarType>(input_type));
}
// Chooses the ACL dtype of the dynamic scales for the given quant mode.
//   QUANT_MODE_NO_QUANT: ACL_FLOAT when x is BFloat16/Half or no scales are
//                        given, otherwise the dtype of `scales`.
//   QUANT_MODE_MX:       the ACL dtype registered for FLOAT8_E8M0.
//   any other mode:      ACL_FLOAT.
aclDataType get_dynamic_scales_dtype(const at::Tensor &x, const c10::optional<at::Tensor> &scales, int64_t quant_mode)
{
    if (quant_mode == QuantMode::QUANT_MODE_NO_QUANT) {
        const bool x_is_half_precision = x.scalar_type() == at::kBFloat16 || x.scalar_type() == at::kHalf;
        if (x_is_half_precision || !scales.has_value()) {
            return ACL_FLOAT;
        }
        return at_npu::native::OpPreparation::convert_to_acl_data_type(scales.value().scalar_type());
    }
    if (quant_mode == QuantMode::QUANT_MODE_MX) {
        return c10_npu::GetAclDataType(c10_npu::DType::FLOAT8_E8M0);
    }
    return ACL_FLOAT;
}
// Computes the shape of the dynamic-scale output for the given quant mode.
//   NO_QUANT with scales: {a * scales.size(1)} (scales must be at least 2-D).
//   PERTOKEN:             {a}.
//   PERGROUP:             {a, ceil(h / 128)}.
//   MX:                   {a, ceil(h / 32) rounded up to an even number}.
//   anything else:        {a}.
std::vector<int64_t> get_dynamic_shape(const c10::optional<at::Tensor> &scales, int64_t quant_mode, int64_t a, int64_t h)
{
    const static int PER_GROUP_SIZE = 128;
    const static int MX_QUANT_SIZE = 32;
    const static int DIM_TWO = 2;

    if (quant_mode == QuantMode::QUANT_MODE_NO_QUANT && scales.has_value()) {
        TORCH_CHECK(scales.value().dim() >= DIM_TWO, "Expected scales to be at least 2D.", OPS_ERROR(ErrCode::PARAM));
        return {a * scales.value().sizes()[1]};
    }
    if (quant_mode == QuantMode::QUANT_MODE_PERTOKEN) {
        return {a};
    }
    if (quant_mode == QuantMode::QUANT_MODE_PERGROUP) {
        const int64_t group_count = (h + PER_GROUP_SIZE - 1) / PER_GROUP_SIZE;
        return {a, group_count};
    }
    if (quant_mode == QuantMode::QUANT_MODE_MX) {
        const int64_t block_count = (h + MX_QUANT_SIZE - 1) / MX_QUANT_SIZE;
        // (n + 1) / 2 * 2 rounds the block count up to the next even number.
        const int64_t even_block_count = (block_count + 1) / 2 * 2;
        return {a, even_block_count};
    }
    return {a};
}
// Validates a group-size list and packs it into a single integer.
// An empty list yields 0. Otherwise the list must hold exactly GROUP_DIM
// entries, each within [0, GROUP_MAX]; the result is
// (entry0 << 32) + (entry1 << 16) + entry2.
int64_t check_and_get_group_size(at::IntArrayRef group_size_list)
{
    if (group_size_list.empty()) {
        return 0;
    }
    const size_t group_dim = group_size_list.size();
    TORCH_CHECK(group_dim == GROUP_DIM, "group_sizes only support input with three elements, but got ", group_dim,
                OPS_ERROR(ErrCode::PARAM));

    const int64_t group_m = static_cast<int64_t>(group_size_list[0]);
    const int64_t group_n = static_cast<int64_t>(group_size_list[1]);
    const int64_t group_k = static_cast<int64_t>(group_size_list[2]);

    // Sign is tested before the unsigned comparison against GROUP_MAX.
    const auto within_range = [](int64_t value) -> bool {
        return value >= 0 && static_cast<uint64_t>(value) <= GROUP_MAX;
    };
    const bool all_within_range = within_range(group_m) && within_range(group_n) && within_range(group_k);
    TORCH_CHECK(all_within_range, "group param value must conform to range [0, 65535]", OPS_ERROR(ErrCode::VALUE));

    const uint64_t packed_m = static_cast<uint64_t>(group_m) << OFFSET_32_BITS;
    const uint64_t packed_n = static_cast<uint64_t>(group_n) << OFFSET_16_BITS;
    const uint64_t packed_k = static_cast<uint64_t>(group_k);
    return static_cast<int64_t>(packed_m + packed_n + packed_k);
}
// Selects the cube math type: the value from the no-argument
// OpPreparation::get_cube_math_type() is used when it is non-negative;
// otherwise the value derived from env::IsAllowMatmulHF32() is used.
int8_t get_cube_math_type_with_passthrough()
{
    const int8_t hf32_based_type =
        at_npu::native::OpPreparation::get_cube_math_type(at_npu::native::env::IsAllowMatmulHF32());
    const int8_t passthrough_type = at_npu::native::OpPreparation::get_cube_math_type();
    if (passthrough_type >= 0) {
        return passthrough_type;
    }
    return hf32_based_type;
}
}
}