* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "defalut_reg_func.h"
namespace af {
namespace ascir {
Expression GetCompareNormalNoLoopTmpSize(const AscNode &node, const std::vector<uint32_t> &axis_ids) {
auto node_inputs = node.inputs;
uint8_t typeSize = node_inputs[0].attr.dtype == ge::DataType::DT_FLOAT16 ? 2 : 4;
EXPECT_SYMBOL_NE(Symbol(typeSize), Symbol(0));
Expression compare_out_size = Symbol(256 / 8 / typeSize) * node_inputs[0].attr.repeats[axis_ids[0]];
Expression select_out_size = Symbol(128) *
((node_inputs[0].attr.repeats[axis_ids[0]] * Symbol(2) + Symbol(typeSize - 1)) / Symbol(typeSize));
return compare_out_size + select_out_size;
}
std::vector<std::unique_ptr<TmpBufDesc>> GetCompareNormalTmpSize(const AscNode &node, const std::string &mode) {
auto node_inputs = node.inputs;
std::vector<uint32_t> axis_ids = {UINT32_MAX, UINT32_MAX};
uint32_t vec_last_axis_pos_in_axis = std::find(node_inputs[0].attr.axis.begin(), node_inputs[0].attr.axis.end(),
node_inputs[0].attr.vectorized_axis.back()) - node_inputs[0].attr.axis.begin();
std::vector<std::unique_ptr<TmpBufDesc>> tmp_buf_descs;
for (auto vec_axis : node_inputs[0].attr.vectorized_axis) {
auto pos = std::find(node_inputs[0].attr.axis.begin(), node_inputs[0].attr.axis.end(), vec_axis);
GE_ASSERT_TRUE(pos != node_inputs[0].attr.axis.end(), "Incorrect axis ID in vectorized_axis");
uint32_t axis_id = static_cast<uint32_t>(pos - node_inputs[0].attr.axis.begin());
if (axis_id < axis_ids[1]) {
if (axis_id < axis_ids[0]) {
axis_ids[1] = axis_ids[0];
axis_ids[0] = axis_id;
} else {
axis_ids[1] = axis_id;
}
}
}
if (axis_ids[1] != UINT32_MAX) {
if (node_inputs[0].attr.dtype == ge::DataType::DT_FLOAT16 ||
node_inputs[0].attr.dtype == ge::DataType::DT_FLOAT ||
(node_inputs[0].attr.dtype == ge::DataType::DT_INT32 && (mode == "EQ" || mode == "NE"))) {
GELOGD("[GetInputSize] axis id is: %u, %u", axis_ids[0], axis_ids[1]);
GELOGD("[GetInputSize] inputs[0].repeat is: %s, %s", node_inputs[0].attr.repeats[axis_ids[0]].Str().get(),
node_inputs[0].attr.repeats[axis_ids[1]].Str().get());
GELOGD("[GetInputSize] inputs[0].vectorized_strides is: %s, %s",
node_inputs[0].attr.vectorized_strides[0].Str().get(),
node_inputs[0].attr.vectorized_strides[1].Str().get());
Expression input_size = Symbol(256) * node_inputs[0].attr.repeats[axis_ids[1]] *
node_inputs[0].attr.vectorized_strides[1];
Expression total_size = input_size + Symbol(32) * node_inputs[0].attr.repeats[axis_ids[1]] *
node_inputs[0].attr.vectorized_strides[1];
return GetTmpBuffer(sym::Max(total_size, GetCompareNormalNoLoopTmpSize(node, axis_ids)));
} else if (node_inputs[0].attr.dtype == ge::DataType::DT_INT32) {
GELOGD("[GetInputSize] axis id is: %u", axis_ids[1]);
GELOGD("[GetInputSize] inputs[1].repeat is: %s", node_inputs[0].attr.repeats[axis_ids[0]].Str().get());
GELOGD("[GetInputSize] inputs[1].vectorized_strides is: %s",
node_inputs[0].attr.vectorized_strides[1].Str().get());
Expression input_size = node_inputs[0].attr.repeats[axis_ids[0]] *
node_inputs[0].attr.vectorized_strides[1];
Expression total_size = input_size * Symbol(256);
return GetTmpBuffer(total_size);
} else if (node_inputs[0].attr.dtype == ge::DataType::DT_INT64) {
if (mode == "EQ" || mode == "NE") {
GELOGD("[GetInputSize] axis id is: %u", axis_ids[1]);
GELOGD("[GetInputSize] inputs[1].repeat is: %s", node_inputs[0].attr.repeats[axis_ids[0]].Str().get());
GELOGD("[GetInputSize] inputs[1].vectorized_strides is: %s",
node_inputs[0].attr.vectorized_strides[1].Str().get());
Expression input_size = node_inputs[0].attr.repeats[axis_ids[0]] *
node_inputs[0].attr.vectorized_strides[1];
Expression total_size = input_size * Symbol(256) * Symbol(2) + Symbol(256);
return GetTmpBuffer(total_size);
} else {
Expression last_axis_size = sym::Align(node_inputs[0].attr.repeats[vec_last_axis_pos_in_axis], 32);
GELOGD("[GetInputSize] last axis size is: %s", last_axis_size.Str().get());
GELOGD("[GetInputSize] axis id is: %u", axis_ids[1]);
GELOGD("[GetInputSize] node_inputs[0].attr.repeats[axis_ids[0]] is: %s",
node_inputs[0].attr.repeats[axis_ids[0]].Str().get());
GELOGD("[GetInputSize] node_inputs[0].attr.vectorized_strides[1] is: %s",
node_inputs[0].attr.vectorized_strides[1].Str().get());
Expression input_size = node_inputs[0].attr.repeats[axis_ids[0]] *
node_inputs[0].attr.vectorized_strides[1];
GELOGD("[GetInputSize] input_size is: %s", input_size.Str().get());
Expression total_size = input_size * last_axis_size * Symbol(8) * Symbol(5);
TmpBufDesc desc = {total_size, -1};
tmp_buf_descs.emplace_back(std::make_unique<TmpBufDesc>(desc));
return tmp_buf_descs;
}
} else {
return CalcDefaultTmpSize(node);
}
} else {
GELOGD("[GetInputSize] axis id is: %u", axis_ids[0]);
GELOGD("[GetInputSize] inputs[0].repeat is: %s", node_inputs[0].attr.repeats[axis_ids[0]].Str().get());
GELOGD("[GetInputSize] inputs[0].vectorized_strides is: %s",
node_inputs[0].attr.vectorized_strides[0].Str().get());
Expression input_size = node_inputs[0].attr.repeats[axis_ids[0]] * node_inputs[0].attr.vectorized_strides[0];
Expression total_size = Symbol(ge::GetSizeByDataType(node_inputs[0].attr.dtype)) * input_size;
if (node_inputs[0].attr.dtype == ge::DataType::DT_INT64) {
total_size = total_size * Symbol(5);
}
return GetTmpBuffer(total_size);
}
}
std::vector<std::unique_ptr<TmpBufDesc>> CalcGeTmpSize(const AscNode &node) {
return GetCompareNormalTmpSize(node, "GE");
}
std::vector<std::unique_ptr<TmpBufDesc>> CalcEqTmpSize(const AscNode &node) {
return GetCompareNormalTmpSize(node, "EQ");
}
std::vector<std::unique_ptr<TmpBufDesc>> CalcNeTmpSize(const AscNode &node) {
return GetCompareNormalTmpSize(node, "NE");
}
std::vector<std::unique_ptr<TmpBufDesc>> CalcGtTmpSize(const AscNode &node) {
return GetCompareNormalTmpSize(node, "GT");
}
std::vector<std::unique_ptr<TmpBufDesc>> CalcLeTmpSize(const AscNode &node) {
return GetCompareNormalTmpSize(node, "LE");
}
std::vector<std::unique_ptr<TmpBufDesc>> CalcLtTmpSize(const AscNode &node) {
return GetCompareNormalTmpSize(node, "LT");
}
}
}