* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "attribute_group/attr_group_symbolic_desc.h"
#include "exe_graph/runtime/storage_shape.h"
#include "common/util/mem_utils.h"
#include "graph/optimize/symbolic/symbol_compute_context.h"
#include "graph/optimize/symbolic/infer_symbolic_shape/symbolic_infer_util.h"
#include "graph/optimize/symbolic/symbolic_kernel_factory.h"
#include "common/plugin/ge_make_unique_util.h"
#include "graph_metadef/common/ge_common/util.h"
#include "graph/optimize/symbolic/infer_symbolic_shape/symbolic_infer_util.h"
namespace ge {
namespace {
// Number of inputs a Select node is required to have: condition (0), x1 (1) and x2 (2).
const size_t kSelectInputNum = 3U;
// Number of outputs a Select node is required to have.
const size_t kSelectOutputNum = 1U;
// Looks up the symbolic value list of the input at `index`.
// Returns nullptr when the context has no symbol tensor for that input; otherwise returns
// whatever that tensor's GetSymbolicValue() yields. `context` is dereferenced without a null
// check, so the caller must pass a valid context.
const std::vector<ge::Expression> *GetInputSymbolicValue(const gert::InferSymbolComputeContext *context, size_t index) {
  auto input_tensor = context->GetInputSymbolTensor(index);
  if (input_tensor == nullptr) {
    return nullptr;
  }
  return input_tensor->GetSymbolicValue();
}
bool GetShapeValue(const std::vector<Expression> &symbol_shape, std::vector<int64_t> &shape_value) {
shape_value.reserve(symbol_shape.size());
for (const auto &expr : symbol_shape) {
int64_t value = 0;
GE_ASSERT_TRUE(expr.GetConstValue(value));
shape_value.emplace_back(value);
}
return true;
}
bool CalOutputValue(const std::vector<Expression> &condition, const std::vector<Expression> &x1,
const std::vector<Expression> &x2, std::vector<Expression> &output) {
output.reserve(condition.size());
for (size_t i = 0U; i < condition.size(); ++i) {
int32_t cond = 0;
GE_ASSERT_TRUE(condition[i].GetConstValue(cond));
if (cond > 0) {
output.emplace_back(x1[i]);
} else {
output.emplace_back(x2[i]);
}
}
return true;
}
// Returns true only when all three value lists are present (non-null) and non-empty.
// The lists are inspected in the order condition, x1, x2 and evaluation stops at the first
// one that is missing or empty.
bool CheckInputsValue(const std::vector<ge::Expression> *condition_value, const std::vector<ge::Expression> *x1_value,
                      const std::vector<ge::Expression> *x2_value) {
  const auto has_value = [](const std::vector<ge::Expression> *value) {
    return (value != nullptr) && (!value->empty());
  };
  return has_value(condition_value) && has_value(x1_value) && has_value(x2_value);
}
// Left-pads `src` with leading 1s until it has as many dimensions as `dst`.
// When `src` already has at least as many dimensions as `dst`, it is returned unchanged
// (it is never truncated).
std::vector<int64_t> AlignShape(const std::vector<int64_t> &src, const std::vector<int64_t> &dst) {
  const size_t pad_num = (dst.size() > src.size()) ? (dst.size() - src.size()) : 0U;
  // Start with the leading 1s, then append the original dimensions behind them.
  std::vector<int64_t> padded(pad_num, 1);
  padded.insert(padded.end(), src.begin(), src.end());
  return padded;
}
// Computes row-major (last dimension fastest) element strides for `shape`.
// The last stride is 1 and every earlier stride is the product of all later dimensions.
// An empty shape yields an empty stride list.
std::vector<int64_t> ComputeStrides(const std::vector<int64_t> &shape) {
  std::vector<int64_t> strides(shape.size(), 1);
  // Walk from the second-to-last dimension towards the front; `axis` is one past the
  // dimension whose size is being folded in, so the loop is skipped for rank 0 and rank 1.
  for (size_t axis = shape.size(); axis > 1U; --axis) {
    strides[axis - 2U] = strides[axis - 1U] * shape[axis - 1U];
  }
  return strides;
}
// Expands the flat, row-major `src_data` described by `src_shape` to `dst_shape`.
//
// `src_shape` is first left-padded with 1s to the rank of `dst_shape`; every padded source
// dimension must then be either 1 or equal to the matching destination dimension.
//
// src_data:  source elements, laid out row-major according to src_shape.
// src_shape: concrete dimensions of the source.
// dst_shape: concrete dimensions to broadcast to.
// dst_data:  receives the broadcast elements. It is overwritten by assignment when src_data
//            already holds exactly the destination element count, and appended to otherwise.
//
// Returns false (after logging a warning) when the ranks cannot be aligned, when a dimension is
// not broadcastable, or when src_data holds fewer elements than src_shape describes.
bool Broadcast(const std::vector<Expression> &src_data, const std::vector<int64_t> &src_shape, const std::vector<int64_t> &dst_shape,
               std::vector<Expression> &dst_data) {
  std::vector<int64_t> aligned_src = AlignShape(src_shape, dst_shape);
  // AlignShape only pads, so a source with a higher rank than the destination is rejected here.
  if (aligned_src.size() != dst_shape.size()) {
    GELOGW("Cannot broadcast, after aligned shape size is not equal, aligned_size: %zu , dst_size: %zu",
           aligned_src.size(), dst_shape.size());
    return false;
  }
  int64_t src_elements = 1;
  for (size_t i = 0U; i < dst_shape.size(); ++i) {
    if (aligned_src[i] != 1 && aligned_src[i] != dst_shape[i]) {
      // int64_t is not guaranteed to be `long long`; cast so the arguments match %lld.
      GELOGW("Dim: %zu cannot broadcast, src_value: %lld , dst_value: %lld", i,
             static_cast<long long>(aligned_src[i]), static_cast<long long>(dst_shape[i]));
      return false;
    }
    src_elements *= aligned_src[i];
  }
  int64_t total_elements = 1;
  for (auto dim : dst_shape) {
    total_elements *= dim;
  }
  // Nothing to expand when the source already has one element per destination position.
  if (static_cast<int64_t>(src_data.size()) == total_elements) {
    dst_data = src_data;
    return true;
  }
  // The gather loop below reads source indices up to src_elements - 1; refuse to continue if
  // the value list is shorter than its shape claims, otherwise the read would be out of bounds.
  if (static_cast<int64_t>(src_data.size()) < src_elements) {
    GELOGW("Cannot broadcast, src data size: %zu is less than src shape element num: %lld", src_data.size(),
           static_cast<long long>(src_elements));
    return false;
  }
  std::vector<int64_t> src_strides = ComputeStrides(aligned_src);
  std::vector<int64_t> dst_strides = ComputeStrides(dst_shape);
  GE_ASSERT_TRUE((aligned_src.size() == src_strides.size()) && (dst_shape.size() == dst_strides.size()));
  dst_data.reserve(total_elements);
  for (int64_t pos = 0; pos < total_elements; ++pos) {
    // Decompose the destination linear position into per-dimension indices and map each one
    // back onto the source: a broadcast dimension (size 1) always maps to index 0.
    int64_t src_linear = 0;
    for (size_t i = 0U; i < dst_shape.size(); ++i) {
      int64_t dim = aligned_src[i];
      int64_t idx = (pos / dst_strides[i]) % dst_shape[i];
      src_linear += (idx % dim) * src_strides[i];
    }
    dst_data.push_back(src_data[src_linear]);
  }
  return true;
}
}
// Symbolic kernel for Select: computes the output's symbolic value from the symbolic values of
// condition (input 0), x1 (input 1) and x2 (input 2).
//
// Steps:
//   1. Require exactly kSelectInputNum inputs and kSelectOutputNum outputs.
//   2. Require non-empty symbolic values for all three inputs and identical x1/x2 symbolic shapes.
//   3. Resolve the condition and x1 shapes to constants and broadcast the condition values to
//      the x1 shape.
//   4. Pick x1[i] or x2[i] per element and store the result, together with x1's shape, on the
//      output symbol tensor.
//
// Returns SUCCESS when the output value was computed, UNSUPPORTED when the inputs cannot be
// handled by this kernel (missing values/shapes, shape mismatch, non-constant shape, broadcast or
// selection failure). The GE_ASSERT_* checks cover conditions treated as errors (null context,
// wrong input/output count, null output tensor, allocation failure).
static graphStatus SelectSymbolicKernelCompute(gert::InferSymbolComputeContext *context) {
  GE_ASSERT_NOTNULL(context);
  GELOGD("Select Symbolic Kernel in, node %s[%s].", context->GetNodeName(), context->GetNodeType());
  GE_ASSERT_EQ(context->GetComputeNodeInputNum(), kSelectInputNum);
  GE_ASSERT_EQ(context->GetComputeNodeOutputNum(), kSelectOutputNum);
  const std::vector<ge::Expression> *condition_value = GetInputSymbolicValue(context, 0);
  const std::vector<ge::Expression> *x1_value = GetInputSymbolicValue(context, 1);
  const std::vector<ge::Expression> *x2_value = GetInputSymbolicValue(context, 2);
  // All three lists are dereferenced further down, so each of them (x1 included) must be checked.
  if (!CheckInputsValue(condition_value, x1_value, x2_value)) {
    GELOGW("Select not support, inputs symbol value is empty.");
    return UNSUPPORTED;
  }
  auto condition_shape = context->GetInputSymbolShape(0U);
  auto x1_shape = context->GetInputSymbolShape(1U);
  auto x2_shape = context->GetInputSymbolShape(2U);
  GE_UNSUPPORTED_IF_NULL(condition_shape);
  GE_UNSUPPORTED_IF_NULL(x1_shape);
  GE_UNSUPPORTED_IF_NULL(x2_shape);
  // Only the condition is broadcast; x1 and x2 must already share one shape.
  if (*x1_shape != *x2_shape) {
    GELOGW("Select not support, check inputs shape failed, x1_shape: %s, x2_shape: %s.",
           SymbolicInferUtil::VectorExpressionToStr(x1_shape->GetDims()).c_str(),
           SymbolicInferUtil::VectorExpressionToStr(x2_shape->GetDims()).c_str());
    return UNSUPPORTED;
  }
  std::vector<int64_t> cond_shape_value;
  std::vector<int64_t> x1_shape_value;
  if (!(GetShapeValue(condition_shape->GetDims(), cond_shape_value)) ||
      !(GetShapeValue(x1_shape->GetDims(), x1_shape_value))) {
    GELOGW("Select not support, get condition_shape value: %s or x1_shape value: %s failed.",
           SymbolicInferUtil::VectorExpressionToStr(condition_shape->GetDims()).c_str(),
           SymbolicInferUtil::VectorExpressionToStr(x1_shape->GetDims()).c_str());
    return UNSUPPORTED;
  }
  std::vector<Expression> after_bc_condition;
  if (!Broadcast(*condition_value, cond_shape_value, x1_shape_value, after_bc_condition)) {
    GELOGW("Select not support, broadcast failed.");
    return UNSUPPORTED;
  }
  std::vector<Expression> output_value;
  if (!(CalOutputValue(after_bc_condition, *x1_value, *x2_value, output_value))) {
    GELOGW("Select not support, cal output value failed.");
    return UNSUPPORTED;
  }
  auto output_desc = context->GetOutputSymbolTensor(0U);
  GE_ASSERT_NOTNULL(output_desc);
  auto symbolic_value_unique = ge::MakeUnique<std::vector<ge::Expression>>(output_value);
  // Do not report success with the output value left unset if the allocation failed.
  GE_ASSERT_NOTNULL(symbolic_value_unique);
  output_desc->SetSymbolicValue(std::move(symbolic_value_unique));
  output_desc->MutableOriginSymbolShape().MutableDims() = x1_shape->GetDims();
  // Log the locally computed value (the same content that was just stored) so that the log
  // statement never has to dereference the pointer returned by GetSymbolicValue().
  GELOGD("%s[%s] kernel success, output ori shape: %s, output value %s", context->GetNodeName(), context->GetNodeType(),
         SymbolicInferUtil::VectorExpressionToStr(output_desc->GetOriginSymbolShape().GetDims()).c_str(),
         SymbolicInferUtil::VectorExpressionToStr(output_value).c_str());
  return SUCCESS;
}
// Hooks SelectSymbolicKernelCompute up under the name Select; presumably this registers it with
// the symbolic kernel factory so it is invoked for Select nodes - confirm against the macro in
// symbolic_kernel_factory.h.
REGISTER_SYMBOLIC_KERNEL(Select, SelectSymbolicKernelCompute);
}