* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "ascir_common.h"
#include "symbolizer/symbolic_utils.h"
#include "graph/ascendc_ir/utils/asc_tensor_utils.h"
namespace af {
namespace ascir {
bool OnlySecondInputSupportScalar(const std::vector<bool> &is_scalar_list) {
GE_ASSERT_EQ(is_scalar_list.size(), 2UL);
return !is_scalar_list[0] && is_scalar_list[1];
}
[[nodiscard]] std::pair<std::vector<ge::DataType>, std::vector<ge::DataType>>
GetConversionFromDtypeMap(const AscNode &node, const std::map<ge::DataType, ge::DataType> &dtype_conversion_map) {
std::pair<std::vector<ge::DataType>, std::vector<ge::DataType>> conversion_dtype;
AscNodeInputs node_inputs = node.inputs;
AscNodeOutputs node_outputs = node.outputs;
for (size_t i = 0; i < node_inputs().size(); i++) {
auto it = dtype_conversion_map.find(node_inputs[i].attr.dtype);
if (it != dtype_conversion_map.end()) {
conversion_dtype.first.emplace_back(it->second);
} else {
conversion_dtype.first.emplace_back(node_inputs[i].attr.dtype);
}
}
for (size_t i = 0; i < node_outputs().size(); i++) {
auto it = dtype_conversion_map.find(node_outputs[i].attr.dtype);
if (it != dtype_conversion_map.end()) {
conversion_dtype.second.emplace_back(it->second);
} else {
conversion_dtype.second.emplace_back(node_outputs[i].attr.dtype);
}
}
return conversion_dtype;
}
bool IsAllVecAxisContinuous(const AscNode &node) {
AscNodeInputs node_inputs = node.inputs;
AscNodeOutputs node_outputs = node.outputs;
for (size_t i = 0; i < node_inputs().size(); i++) {
if (node_inputs[i].attr.vectorized_axis.size() == 1) {
continue;
}
auto &attr = node_inputs[i].attr;
for (size_t j = 1; j < attr.vectorized_axis.size(); j++) {
auto it = std::find(attr.axis.begin(), attr.axis.end(), attr.vectorized_axis[j]);
GE_ASSERT_TRUE(it != attr.axis.end(), "Incorrect axis ID in node: %s input %zu vectorized_axis: %zu",
node.GetName(), i, j);
auto axis_id = static_cast<uint64_t>(std::distance(attr.axis.begin(), it));
Expression cur_axis_stride = attr.repeats[axis_id] * attr.vectorized_strides[j];
if (SymbolicUtils::StaticCheckEq(cur_axis_stride, attr.vectorized_strides[j - 1]) != TriBool::kTrue) {
return false;
}
}
}
for (size_t i = 0; i < node_outputs().size(); i++) {
if (node_outputs[i].attr.vectorized_axis.size() == 1) {
continue;
}
auto &attr = node_outputs[i].attr;
for (size_t j = 1; j < attr.vectorized_axis.size(); j++) {
auto it = std::find(attr.axis.begin(), attr.axis.end(), attr.vectorized_axis[j]);
GE_ASSERT_TRUE(it != attr.axis.end(), "Incorrect axis ID in node: %s output %zu vectorized_axis: %zu",
node.GetName(), i, j);
auto axis_id = static_cast<uint64_t>(std::distance(attr.axis.begin(), it));
Expression cur_axis_stride = attr.repeats[axis_id] * attr.vectorized_strides[j];
if (SymbolicUtils::StaticCheckEq(cur_axis_stride, attr.vectorized_strides[j - 1]) != TriBool::kTrue) {
return false;
}
}
}
return true;
}
bool IsUBScalarTensor(const AscTensor &tensor) {
auto &attr = tensor.attr;
uint64_t axis_id = UINT64_MAX;
for (size_t i = 0; i < attr.vectorized_axis.size(); i++) {
auto it = std::find(attr.axis.begin(), attr.axis.end(), attr.vectorized_axis[i]);
GE_ASSERT_TRUE(it != attr.axis.end(), "Incorrect axis ID in vectorized_axis");
axis_id = static_cast<uint64_t>(std::distance(attr.axis.begin(), it));
if (SymbolicUtils::StaticCheckEq(attr.repeats[axis_id], Symbol(1)) != TriBool::kTrue) {
return false;
}
}
return true;
}
bool IsVectorizedAxisSupportBrc(const AscNode &node, size_t input_id,
const BroadcastCapability &broadcast_capability) {
AscNodeInputs node_inputs = node.inputs;
if ((IsUBScalarTensor(node_inputs[input_id]) || broadcast_capability.support_inline_brc) &&
(std::find(broadcast_capability.scalar_inputs.begin(), broadcast_capability.scalar_inputs.end(), input_id) !=
broadcast_capability.scalar_inputs.end())) {
return true;
}
return false;
}
Status ValidateInputTensorLoopAxis(const AscNode &node, size_t input_id, size_t input_axis_id) {
AscNodeInputs node_inputs = node.inputs;
AscNodeOutputs node_outputs = node.outputs;
auto input_attr = node_inputs[input_id].attr;
auto output_attr = node_outputs[0].attr;
auto it = std::find(output_attr.axis.begin(), output_attr.axis.end(), input_attr.axis[input_axis_id]);
GE_ASSERT_TRUE(it != output_attr.axis.end(), "Node %s[%s]: input tensor %zu loop axis %zu is not in output tensor "
"axis", node.GetTypePtr(), node.GetNamePtr(), input_id, input_axis_id);
auto output_axis_id = static_cast<uint64_t>(std::distance(output_attr.axis.begin(), it));
if ((SymbolicUtils::StaticCheckEq(output_attr.repeats[output_axis_id], input_attr.repeats[input_axis_id]) ==
TriBool::kTrue) || (SymbolicUtils::StaticCheckEq(input_attr.repeats[input_axis_id], Symbol(1)) ==
TriBool::kTrue)) {
return ge::SUCCESS;
} else if (SymbolicUtils::StaticCheckEq(output_attr.repeats[output_axis_id], input_attr.repeats[input_axis_id]) ==
TriBool::kUnknown) {
GELOGW("Node %s[%s]: input tensor %zu loop axis %zu repeat %s and output tensor 0 loop axis %zu repeat %s may not "
"be equal or broadcastable(relation cannot be determined)", node.GetTypePtr(), node.GetNamePtr(), input_id,
input_axis_id, input_attr.repeats[input_axis_id].Str().get(), output_axis_id,
output_attr.repeats[output_axis_id].Str().get());
return ge::SUCCESS;
}
GELOGE(ge::FAILED, "Node %s[%s]: input tensor %zu loop axis %zu repeat %s and output tensor 0 loop axis %zu repeat "
"%s are not equal or broadcastable", node.GetTypePtr(), node.GetNamePtr(), input_id, input_axis_id,
input_attr.repeats[input_axis_id].Str().get(), output_axis_id,
output_attr.repeats[output_axis_id].Str().get());
return ge::FAILED;
}
Status ValidateInputTensorVectorizedAxis(const AscNode &node, size_t input_id, size_t input_axis_id,
const BroadcastCapability &broadcast_capability) {
AscNodeInputs node_inputs = node.inputs;
AscNodeOutputs node_outputs = node.outputs;
auto input_attr = node_inputs[input_id].attr;
auto output_attr = node_outputs[0].attr;
auto it = std::find(output_attr.axis.begin(), output_attr.axis.end(), input_attr.axis[input_axis_id]);
GE_ASSERT_TRUE(it != output_attr.axis.end(), "Node %s[%s]: input tensor %zu vectorized_axis %zu is not in output "
"tensor axis", node.GetTypePtr(), node.GetNamePtr(), input_id, input_axis_id);
auto output_axis_id = static_cast<uint64_t>(std::distance(output_attr.axis.begin(), it));
if ((SymbolicUtils::StaticCheckEq(output_attr.repeats[output_axis_id], input_attr.repeats[input_axis_id]) ==
TriBool::kTrue) || IsVectorizedAxisSupportBrc(node, input_id, broadcast_capability)) {
return ge::SUCCESS;
} else if (SymbolicUtils::StaticCheckEq(output_attr.repeats[output_axis_id], input_attr.repeats[input_axis_id]) ==
TriBool::kUnknown) {
GELOGW("Node %s[%s]: input tensor %zu vectorized_axis %zu repeat: %s and output tensor 0 vectorized_axis %zu "
"repeat: %s may not be equal or broadcastable(relation cannot be determined)", node.GetTypePtr(),
node.GetNamePtr(), input_id, input_axis_id, input_attr.repeats[input_axis_id].Str().get(), output_axis_id,
output_attr.repeats[output_axis_id].Str().get());
return ge::SUCCESS;
}
GELOGE(ge::FAILED, "Node %s[%s]: input tensor %zu vectorized_axis %zu repeat: %s and output tensor 0 vectorized_axis "
"%zu repeat: %s are not equal or broadcastable", node.GetTypePtr(), node.GetNamePtr(), input_id, input_axis_id,
input_attr.repeats[input_axis_id].Str().get(), output_axis_id,
output_attr.repeats[output_axis_id].Str().get());
return ge::FAILED;
}
Status ValidateShapeConsistencyWithSingleOutput(const AscNode &node,
const BroadcastCapability &broadcast_capability) {
AscNodeInputs node_inputs = node.inputs;
AscNodeOutputs node_outputs = node.outputs;
GE_ASSERT_TRUE(!(node_outputs().size() != 1), "Node %s[%s]: output tensor size is not equal with 1",
node.GetTypePtr(), node.GetNamePtr());
GE_ASSERT_TRUE(!node_outputs[0].attr.vectorized_axis.empty(), "Node %s[%s]: output tensor has empty vectorized_axis",
node.GetTypePtr(), node.GetNamePtr());
std::vector<Expression> output_repeats = node_outputs[0].attr.repeats;
for (size_t i = 0; i < node_inputs().size(); i++) {
auto &input = node_inputs[i];
if (input.attr.vectorized_axis.empty()) {
continue;
}
for (size_t j = 0; j < input.attr.repeats.size(); j++) {
if (std::find(input.attr.vectorized_axis.begin(), input.attr.vectorized_axis.end(), input.attr.axis[j]) !=
input.attr.vectorized_axis.end()) {
GE_ASSERT_SUCCESS(ValidateInputTensorVectorizedAxis(node, i, j, broadcast_capability), "Node %s[%s]: input "
"tensor %zu axis %zu validate vectorized_axis consistency failed", node.GetTypePtr(),
node.GetNamePtr(), i, j);
} else {
GE_ASSERT_SUCCESS(ValidateInputTensorLoopAxis(node, i, j), "Node %s[%s]: input tensor %zu "
"%zu axis %zu validate loop axis consistency failed", node.GetTypePtr(),
node.GetNamePtr(), i, j);
}
}
}
return ge::SUCCESS;
}
bool IsNodeHasScalarInput(const AscNode &node) {
AscNodeInputs node_inputs = node.inputs;
for (size_t i = 0; i < node_inputs().size(); i++) {
if (node.GetInDataNodes().at(i)->GetType() == "Scalar") {
return true;
}
}
return false;
}
bool IsNodeFirstInputScalar(const AscNode &node) {
return node.GetInDataNodes().at(0)->GetType() == "Scalar";
}
}
}