#include <ATen/NamedTensorUtils.h>
#include "torch_npu/csrc/aten/CustomFunctions.h"
#include "torch_npu/csrc/framework/utils/OpPreparation.h"
#include "torch_npu/csrc/aten/mirror/NPUTypeProperties.h"
#include "torch_npu/csrc/core/npu/GetCANNInfo.h"
#include "op_plugin/utils/OpUtils.h"
namespace op_plugin {
namespace utils {
// Tells whether the rotary mode is the "half" layout.
// rotary_mode: must be "half" or "interleave"; any other value fails the TORCH_CHECK.
// Returns true for "half" and false for "interleave".
bool is_neox_style(std::string rotary_mode)
{
    // Both alternatives are tested with ==; a disjunction of != comparisons is
    // satisfied by every string and would let unsupported modes through.
    TORCH_CHECK(rotary_mode == "half" || rotary_mode == "interleave",
                "rotary_mode only support half or interleave", OPS_ERROR(ErrCode::VALUE));
    return rotary_mode == "half";
}
// Translates a reduction code into its textual name.
// at::Reduction::None -> "none", at::Reduction::Mean -> "mean",
// every other value -> "sum".
std::string get_reduction_str(int64_t reduction)
{
    if (reduction == at::Reduction::None) {
        return "none";
    }
    if (reduction == at::Reduction::Mean) {
        return "mean";
    }
    return "sum";
}
// Converts a rotary mode name into its integer code:
//   "half" -> 0, "interleave" -> 1, "quarter" -> 2, "interleave-half" -> 3.
// Any other name fails the TORCH_CHECK, so every path either returns a value
// or raises (a non-void function must not run off its end).
int64_t get_rotary_mode(c10::string_view mode)
{
    if (mode == "half") {
        return 0;
    }
    if (mode == "interleave") {
        return 1;
    }
    if (mode == "quarter") {
        return 2;
    }
    TORCH_CHECK(mode == "interleave-half",
                "rotary mode only support half, interleave, quarter or interleave-half, but got ",
                std::string(mode.data(), mode.size()), OPS_ERROR(ErrCode::VALUE));
    return 3;
}
// Wraps a possibly negative dimension index.
// dim:           index to wrap; non-negative values are returned unchanged.
// dim_post_expr: number of dimensions; values <= 0 are treated as 1.
// Returns dim itself when dim >= 0, otherwise dim plus the effective
// dimension count. No range validation is performed.
int64_t make_warp_dim(int64_t dim, int64_t dim_post_expr)
{
    if (dim >= 0) {
        return dim;
    }
    const int64_t effective_rank = (dim_post_expr > 0) ? dim_post_expr : 1;
    return dim + effective_rank;
}
// Checks whether a 2-D or 3-D tensor looks like a view whose last two
// dimensions are swapped relative to the base sizes recorded in its
// tensor description.
// Returns false for any other rank, or when any of the stride / size /
// element-count comparisons below fails.
bool is_transpose_last_two_dims(const at::Tensor &tensor)
{
    if (tensor.dim() < 2 || tensor.dim() > 3) {
        return false;
    }
    int64_t numel = at_npu::native::NPUNativeFunctions::get_storage_size(tensor);
    int64_t dim1 = tensor.dim() - 1;
    int64_t dim2 = tensor.dim() - 2;
    c10::SmallVector<int64_t, 5> tensor_base_size = at_npu::native::OpPreparation::get_tensor_desc_base_sizes(tensor);
    // The base sizes must have the same rank as the tensor. This is verified
    // before tensor_base_size is indexed with dim1/dim2 so the lookups below
    // cannot read past the end of the vector.
    if (tensor_base_size.size() != static_cast<uint64_t>(tensor.dim())) {
        return false;
    }
    return tensor.stride(dim2) == 1 && tensor.stride(dim1) == tensor.size(dim2) &&
           tensor.size(dim1) == tensor_base_size[dim2] &&
           tensor.size(dim2) == tensor_base_size[dim1] &&
           tensor.numel() == numel;
}
// Returns true when the NPU storage description of mat2 records the
// ACL_FORMAT_FRACTAL_NZ format.
bool is_nz_format(const at::Tensor &mat2)
{
    const auto storage_format = torch_npu::NPUBridge::GetNpuStorageImpl(mat2)->npu_desc_.npu_format_;
    return storage_format == ACL_FORMAT_FRACTAL_NZ;
}
// Returns true only when FormatHelper::IsOpInputBaseFormat accepts both
// tensors; mat2 is not inspected if self already fails.
bool is_two_tensor_base_format(const at::Tensor &self, const at::Tensor &mat2)
{
    if (!at_npu::native::FormatHelper::IsOpInputBaseFormat(self)) {
        return false;
    }
    return at_npu::native::FormatHelper::IsOpInputBaseFormat(mat2);
}
// Returns true when both tensors are 2-D, mat2 is stored in NZ format and
// self is not.
bool is_nd_nz_format(const at::Tensor &self, const at::Tensor &mat2)
{
    const bool both_are_matrices = (self.dim() == 2) && (mat2.dim() == 2);
    if (!both_are_matrices) {
        return false;
    }
    if (!is_nz_format(mat2)) {
        return false;
    }
    return !is_nz_format(self);
}
// Decides, from the sizes of the last two dimensions of self and mat2,
// whether the pair satisfies the inner-axis limits checked below.
// Tensors with fewer than two dimensions always yield false.
// When is_transpose_last_two_dims() reports a tensor as transposed, its
// inner and outer axes are taken in swapped order.
bool is_nd_to_nz_on_fly(const at::Tensor &self, const at::Tensor &mat2)
{
    const static int64_t kInnerAxisMinLimit = 128;
    const static int64_t kInnerAxisMaxLimit = 65535;
    if (self.dim() < 2 || mat2.dim() < 2) {
        return false;
    }

    const int64_t self_last = self.size(self.dim() - 1);
    const int64_t self_second_last = self.size(self.dim() - 2);
    const int64_t mat2_last = mat2.size(mat2.dim() - 1);
    const int64_t mat2_second_last = mat2.size(mat2.dim() - 2);

    const bool self_transposed = is_transpose_last_two_dims(self);
    const int64_t self_inner_axis = self_transposed ? self_second_last : self_last;
    const int64_t self_outer_axis = self_transposed ? self_last : self_second_last;

    const bool mat2_transposed = is_transpose_last_two_dims(mat2);
    const int64_t mat2_inner_axis = mat2_transposed ? mat2_second_last : mat2_last;
    const int64_t mat2_outer_axis = mat2_transposed ? mat2_last : mat2_second_last;

    // Both trailing matrices are small enough: accept without the per-axis test.
    if (self_inner_axis * self_outer_axis <= kInnerAxisMaxLimit &&
        mat2_inner_axis * mat2_outer_axis <= kInnerAxisMaxLimit) {
        return true;
    }

    // An inner axis qualifies when it lies in [min, max], or when it is below
    // the minimum and its low four bits are zero (a multiple of 16).
    const auto inner_axis_ok = [](int64_t axis) -> bool {
        if (axis >= kInnerAxisMinLimit) {
            return axis <= kInnerAxisMaxLimit;
        }
        return (static_cast<uint64_t>(axis) & 0xF) == 0;
    };
    return inner_axis_ok(self_inner_axis) && inner_axis_ok(mat2_inner_axis);
}
// Returns true when the scalar represents the value one.
// Integral scalars (bool excluded) must equal 1 exactly; floating-point
// scalars must be within 1e-6 of 1.0. Every other scalar kind yields false.
bool is_scalar_one(const c10::Scalar &scalar)
{
    if (scalar.isIntegral(false)) {
        // Read the full 64-bit value: a checked narrowing to int would raise
        // for large integers instead of simply reporting "not one".
        return scalar.toLong() == 1;
    }
    if (scalar.isFloatingPoint()) {
        // Compare in double so that magnitudes beyond the float range are
        // answered with false rather than failing the float conversion.
        return fabs(scalar.toDouble() - 1.0) < 1e-6;
    }
    return false;
}
// Returns the scalar's value as a float.
// Floating-point scalars are read through toFloat(); all other scalars are
// read as a 64-bit integer and then converted to float.
float get_scalar_float_value(const c10::Scalar &scalar)
{
    if (scalar.isFloatingPoint()) {
        return scalar.toFloat();
    }
    // toLong() keeps the whole int64 range; going through int first would
    // reject integers that do not fit in 32 bits even though float can
    // represent their magnitude.
    return static_cast<float>(scalar.toLong());
}
// Copies the elements of an IntArrayRef, in order, into an owning SmallVector.
c10::SmallVector<int64_t, N> convert_array_to_vector(c10::IntArrayRef intArray)
{
    c10::SmallVector<int64_t, N> values;
    for (const int64_t element : intArray) {
        values.push_back(element);
    }
    return values;
}
// Builds the list of dimension indices 0, 1, ..., self.dim() - 1.
// The result is empty for a zero-dimensional tensor.
c10::SmallVector<int64_t, N> get_dimlist_for_tensor(const at::Tensor &self)
{
    c10::SmallVector<int64_t, N> dims;
    const int64_t rank = self.dim();
    for (int64_t axis = 0; axis < rank; ++axis) {
        dims.push_back(axis);
    }
    return dims;
}
// Computes how much extra padding is needed so that
// (s_size + 2 * p_size - k_size) becomes a multiple of stride.
// Returns 0 when it already is; otherwise stride minus the remainder.
// A zero stride fails the TORCH_CHECK.
int64_t complete_pad(int64_t s_size, int64_t p_size, int64_t k_size, int64_t stride)
{
    TORCH_CHECK(stride != 0, "CompletePad stride is zero!", OPS_ERROR(ErrCode::VALUE));
    const int64_t padded_size = s_size + p_size * 2;
    const int64_t remainder = (padded_size - k_size) % stride;
    return (remainder == 0) ? 0 : (stride - remainder);
}
// Fetches scales[idx] when a scale list is supplied.
// scales: optional list of scale factors.
// idx:    position to read; must satisfy 0 <= idx < scales->size().
// Returns c10::nullopt when scales is absent, otherwise the element at idx.
// An out-of-range idx fails the TORCH_CHECK.
c10::optional<double> get_scale_value(c10::optional<c10::ArrayRef<double>> scales, int idx)
{
    if (!scales) {
        return c10::nullopt;
    }
    // Negative indices are rejected explicitly rather than through an implicit
    // signed-to-unsigned conversion in the size comparison.
    TORCH_CHECK(idx >= 0 && static_cast<size_t>(idx) < scales->size(),
                "idx ", idx, " is overrange scales->at(idx) ", scales->size(),
                OPS_ERROR(ErrCode::VALUE));
    return scales->at(static_cast<size_t>(idx));
}
// Result dtype for dividing self by other: the promoted type from
// at::native::result_type, replaced by float when that type is integral
// (bool included).
at::ScalarType get_divide_result_type(const at::Tensor& self, const at::Tensor& other)
{
    const at::ScalarType promoted = at::native::result_type(self, other);
    if (isIntegralType(promoted, true)) {
        return at::kFloat;
    }
    return promoted;
}
// Dtype used to carry out a division: the promoted type of the two scalar
// types from at_npu::native::result_type, replaced by float when that type
// is integral (bool included) or double.
at::ScalarType get_divide_calculate_type(const at::Tensor &self, const at::Tensor &other)
{
    const at::ScalarType promoted = at_npu::native::result_type(self.scalar_type(), other.scalar_type());
    const bool use_float = isIntegralType(promoted, true) || promoted == at::kDouble;
    if (use_float) {
        return at::kFloat;
    }
    return promoted;
}
// Returns self converted to calculate_type (via npu_dtype_cast, skipped when
// the dtype already matches) and then passed through
// OpPreparation::CastBackToOriFormat.
at::Tensor get_cast_input(const at::Tensor& self, at::ScalarType calculate_type)
{
    at::Tensor converted = self;
    if (!(self.dtype() == calculate_type)) {
        converted = at_npu::native::custom_ops::npu_dtype_cast(self, calculate_type);
    }
    converted = at_npu::native::OpPreparation::CastBackToOriFormat(converted);
    return converted;
}
// Computes the unified dimension names for a list of tensors.
// Returns an empty NameVector when no tensor in the list has names.
// Otherwise the names are folded left to right: while the accumulator is
// empty it takes the current tensor's names, afterwards it is merged with
// them through at::unify_from_right.
NameVector compute_names_npu(std::vector<at::Tensor> tensor_list)
{
    NameVector names;
    bool has_names = false;
    // Iterate by const reference: copying an at::Tensor handle per element
    // costs a reference-count update that nothing here needs.
    for (const auto &tensor : tensor_list) {
        if (tensor.has_names()) {
            has_names = true;
            break;
        }
    }
    if (!has_names) {
        return names;
    }
    for (const auto &tensor : tensor_list) {
        if (names.empty()) {
            names = tensor.names();
        } else {
            names = NameVector(at::unify_from_right(names, tensor.names()));
        }
    }
    return names;
}
// Scale factor used to map output coordinates back to input coordinates.
// A positive scale takes precedence and yields 1 / scale. Otherwise the
// ratio input_size / output_size is returned, or 0 when output_size is 0.
double compute_scale(int64_t input_size, int64_t output_size, double scale)
{
    if (scale > 0.0) {
        return 1.0 / scale;
    }
    if (output_size == 0) {
        return 0;
    }
    return static_cast<double>(input_size) / output_size;
}
// Validates the dtypes involved in a foreach operation.
// 1. The tensor dtype must be accepted for tensorDtypeCategory, else false.
// 2. With neither scalarDtype nor mapping supplied, the tensor check alone
//    decides (true). Supplying exactly one of them fails the TORCH_CHECK.
// 3. With both supplied, the result depends on inputType: scalar and
//    scalar-list inputs are delegated to their mapping checks, tensor input
//    is accepted, and any other inputType fails the TORCH_CHECK.
bool check_dtype_foreach(at::ScalarType tensorDtype, ForeachTensorDtypeSupport tensorDtypeCategory, ForeachInputType inputType,
                         c10::optional<at::ScalarType> scalarDtype, c10::optional<ForeachMappingType> mapping)
{
    if (!check_foreach_tensor_dtype_spport(tensorDtype, tensorDtypeCategory)) {
        return false;
    }

    const bool scalar_given = scalarDtype.has_value();
    const bool mapping_given = mapping.has_value();
    if (!scalar_given && !mapping_given) {
        return true;
    }
    TORCH_CHECK(scalar_given && mapping_given,
                "Invalid scalarType Parm or ForeachMappingType Parm!", OPS_ERROR(ErrCode::PARAM));

    const at::ScalarType scalar_type = scalarDtype.value();
    const ForeachMappingType mapping_type = mapping.value();
    if (inputType == ForeachInputType::TYPE_SCALAR) {
        return check_mapping_between_tensor_and_scalar(tensorDtype, scalar_type, mapping_type);
    }
    if (inputType == ForeachInputType::TYPE_SCALARLIST) {
        return check_mapping_between_tensor_and_scalar_list(tensorDtype, scalar_type, mapping_type);
    }
    TORCH_CHECK(inputType == ForeachInputType::TYPE_TENSOR,
                "Invalid inputType Parm!", OPS_ERROR(ErrCode::PARAM));
    return true;
}
// Checks a tensor dtype against the rule selected by tensorDtypeCategory:
//   BASE_DTYPE -> the base dtype set,
//   TO_INT32   -> the base dtype set or Int,
//   TO_INT     -> the base-and-int dtype set.
// Any other category fails the TORCH_CHECK.
bool check_foreach_tensor_dtype_spport(at::ScalarType tensorDtype, ForeachTensorDtypeSupport tensorDtypeCategory)
{
    if (tensorDtypeCategory == ForeachTensorDtypeSupport::BASE_DTYPE) {
        return check_foreach_tensor_dtype_spport_base(tensorDtype);
    }
    if (tensorDtypeCategory == ForeachTensorDtypeSupport::TO_INT32) {
        return check_foreach_tensor_dtype_spport_base(tensorDtype) || (tensorDtype == at::ScalarType::Int);
    }
    TORCH_CHECK(tensorDtypeCategory == ForeachTensorDtypeSupport::TO_INT,
                "Invalid ForeachTensorDtypeSupport Parm", OPS_ERROR(ErrCode::PARAM));
    return check_foreach_tensor_dtype_spport_base_and_int(tensorDtype);
}
// Returns true for the base dtype set: Half, Float and BFloat16.
bool check_foreach_tensor_dtype_spport_base(at::ScalarType tensorDtype)
{
    switch (tensorDtype) {
        case at::ScalarType::Half:
        case at::ScalarType::Float:
        case at::ScalarType::BFloat16:
            return true;
        default:
            return false;
    }
}
// Returns true for Half, Float, BFloat16, Int, Char and Long.
bool check_foreach_tensor_dtype_spport_base_and_int(at::ScalarType tensorDtype)
{
    switch (tensorDtype) {
        case at::ScalarType::Half:
        case at::ScalarType::Float:
        case at::ScalarType::BFloat16:
        case at::ScalarType::Int:
        case at::ScalarType::Char:
        case at::ScalarType::Long:
            return true;
        default:
            return false;
    }
}
// Returns true when the scalar dtype is an integral type (bool excluded)
// or a floating-point type.
bool check_foreach_scalar_dtype_spport(at::ScalarType scalarDtype)
{
    // includeBool is passed explicitly: the one-argument isIntegralType
    // overload is deprecated and means exactly includeBool = false.
    return at::isIntegralType(scalarDtype, /*includeBool=*/false) || at::isFloatingType(scalarDtype);
}
// Checks that a scalar-list dtype is compatible with the tensor dtype.
// Returns false when the scalar dtype itself is unsupported. For
// MAP_SCALARLIST_DEFAULT the two dtypes must both be integral (bool
// excluded) or both be floating point. Any other mapping fails the
// TORCH_CHECK.
bool check_mapping_between_tensor_and_scalar_list(at::ScalarType tensorDtype, at::ScalarType scalarDtype, ForeachMappingType mapping)
{
    if (!check_foreach_scalar_dtype_spport(scalarDtype)) {
        return false;
    }
    TORCH_CHECK(mapping == ForeachMappingType::MAP_SCALARLIST_DEFAULT,
                "Invalid ForeachMappingType Parm Between Tensor And ScalarList", OPS_ERROR(ErrCode::PARAM));
    // includeBool is passed explicitly: the one-argument isIntegralType
    // overload is deprecated and means exactly includeBool = false.
    const bool both_integral = at::isIntegralType(scalarDtype, /*includeBool=*/false) &&
                               at::isIntegralType(tensorDtype, /*includeBool=*/false);
    const bool both_floating = at::isFloatingType(scalarDtype) && at::isFloatingType(tensorDtype);
    return both_integral || both_floating;
}
// Checks that a single scalar dtype is compatible with the tensor dtype.
// Returns false when the scalar dtype itself is unsupported.
//   MAP_SCALAR_DEFAULT        -> tensor dtype must not be integral (bool
//                                excluded) and the scalar must be floating.
//   MAP_POW_SCALAR_AND_TENSOR -> always accepted.
// Any other mapping fails the TORCH_CHECK.
bool check_mapping_between_tensor_and_scalar(at::ScalarType tensorDtype, at::ScalarType scalarDtype, ForeachMappingType mapping)
{
    if (!check_foreach_scalar_dtype_spport(scalarDtype)) {
        return false;
    }
    switch (mapping) {
        case ForeachMappingType::MAP_SCALAR_DEFAULT:
            // includeBool is passed explicitly: the one-argument overload is
            // deprecated and means exactly includeBool = false.
            return !at::isIntegralType(tensorDtype, /*includeBool=*/false) && at::isFloatingType(scalarDtype);
        case ForeachMappingType::MAP_POW_SCALAR_AND_TENSOR:
            return true;
        default:
            TORCH_CHECK(false, "Invalid ForeachMappingType Parm Between Tensor And Scalar!", OPS_ERROR(ErrCode::PARAM));
    }
}
// Verifies that weight, and bias when it is defined, have options that are
// type_equal to those of input. A mismatch fails the corresponding
// TORCH_CHECK; an undefined bias is skipped.
void check_input_same_type_as_parameters(
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& bias)
{
    TORCH_CHECK(input.options().type_equal(weight.options()),
                "Input type (", input.toString(), ") and weight type (", weight.toString(),
                ") should be the same" + OPS_ERROR(ErrCode::TYPE));
    if (!bias.defined()) {
        return;
    }
    TORCH_CHECK(input.options().type_equal(bias.options()),
                "Input type (", input.toString(), ") and bias type (", bias.toString(),
                ") should be the same" + OPS_ERROR(ErrCode::TYPE));
}
// Two-argument form: checks input against weight only by forwarding to the
// three-argument overload with an undefined bias tensor.
void check_input_same_type_as_parameters(
    const at::Tensor& input,
    const at::Tensor& weight)
{
    const at::Tensor undefined_bias;
    check_input_same_type_as_parameters(input, weight, undefined_bias);
}
// Returns IsGteCANNVersion("8.1.RC1", "CANN"); the query runs once and its
// result is kept in a function-local static.
bool is_gte_cann_version_810rc1()
{
    const static bool cached_result = IsGteCANNVersion("8.1.RC1", "CANN");
    return cached_result;
}
// Returns IsGteCANNVersion("8.2.RC1", "CANN"); the query runs once and its
// result is kept in a function-local static.
bool is_gte_cann_version_820rc1()
{
    const static bool cached_result = []() -> bool {
        return IsGteCANNVersion("8.2.RC1", "CANN");
    }();
    return cached_result;
}
// Returns IsGteCANNVersion("8.3.RC1", "CANN"); the query runs once and its
// result is kept in a function-local static.
bool is_gte_cann_version_830rc1()
{
    const static bool cached_result = []() -> bool {
        return IsGteCANNVersion("8.3.RC1", "CANN");
    }();
    return cached_result;
}
}
}