* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "reg_where_api_call.h"
#include <sstream>
#include "attr_utils.h"
#include "ascir_ops.h"
#include "common_utils.h"
#include "common/ge_common/debug/log.h"
#include "graph/ascendc_ir/utils/asc_tensor_utils.h"
#include "common/checker.h"
#include "api_call/utils/api_call_factory.h"
#include "api_call/utils/api_call_utils.h"
namespace codegen {
using namespace std;
using namespace af::ops;
using namespace af::ascir_op;
using namespace ascgen_utils;
/**
 * @brief Binds the three Where operands and the result tensor to the caller's out-pointers.
 *
 * The operands are taken positionally: inputs[0] -> x1, inputs[1] -> x2, inputs[2] -> x3 and
 * outputs[0] -> y. The constant / UB-scalar flags of x2 and x3 are written to the debug log.
 *
 * @param inputs  input tensors; at least three entries are required.
 * @param outputs output tensors; at least one entry is required.
 * @param x1 [out] receives the address of inputs[0].
 * @param x2 [out] receives the address of inputs[1].
 * @param x3 [out] receives the address of inputs[2].
 * @param y  [out] receives the address of outputs[0].
 * @return ge::SUCCESS once all four pointers are bound; the GE_ASSERT_TRUE failure status when too few
 *         tensors were supplied (the out-pointers are left untouched in that case).
 */
Status WhereRegApiCall::PrepareInputsAndOutputs(const std::vector<std::reference_wrapper<const Tensor>> &inputs,
                                                const std::vector<std::reference_wrapper<const Tensor>> &outputs,
                                                const Tensor *&x1, const Tensor *&x2, const Tensor *&x3,
                                                const Tensor *&y) const {
  constexpr size_t kX1Index = 0U;
  constexpr size_t kX2Index = 1U;
  constexpr size_t kX3Index = 2U;
  constexpr size_t kRequiredInputNum = 3U;

  // Indexing a too-short vector is undefined behaviour, so reject malformed operand lists up front.
  GE_ASSERT_TRUE(inputs.size() >= kRequiredInputNum, "Where needs at least %zu inputs, but got %zu",
                 kRequiredInputNum, inputs.size());
  GE_ASSERT_TRUE(!outputs.empty(), "Where needs at least one output, but got none");

  x1 = &inputs[kX1Index].get();
  x2 = &inputs[kX2Index].get();
  x3 = &inputs[kX3Index].get();
  y = &outputs[0].get();

  GELOGD("x2, is_constant:%d, is_ub_scalar:%d, need_gen_get_value_of_ub_scalar:%d",
         static_cast<int32_t>(x2->is_constant),
         static_cast<int32_t>(x2->is_ub_scalar),
         static_cast<int32_t>(x2->need_gen_get_value_of_ub_scalar));
  GELOGD("x3, is_constant:%d, is_ub_scalar:%d, need_gen_get_value_of_ub_scalar:%d",
         static_cast<int32_t>(x3->is_constant),
         static_cast<int32_t>(x3->is_ub_scalar),
         static_cast<int32_t>(x3->need_gen_get_value_of_ub_scalar));
  return ge::SUCCESS;
}
/**
 * @brief Computes the loop parameters for the Where call from the tensors that take part in axis merging.
 *
 * x1 and y always take part. x2 / x3 take part only when they are neither constant nor flagged
 * need_gen_get_value_of_ub_scalar. The merge status produced by GenerateVectorizedAxisMergeStatus is
 * then stored into @p param through SaveApiLoopAxisParams.
 *
 * NOTE(review): the filter here tests is_constant / need_gen_get_value_of_ub_scalar, while the
 * Generate*Case helpers pick the strides index based on IsAnyScalar(). Confirm both predicates agree,
 * otherwise param.inputs_strides[1] may not belong to the tensor the caller expects.
 *
 * @param x1    first operand.
 * @param x2    second operand.
 * @param x3    third operand.
 * @param y     result tensor.
 * @param tpipe pipe context forwarded to GenerateVectorizedAxisMergeStatus.
 * @param param [out] receives the loop axis parameters.
 * @return ge::SUCCESS on success; the GE_ASSERT_TRUE failure status when the merge status cannot be built.
 */
Status WhereRegApiCall::GenerateLoopParams(const Tensor &x1, const Tensor &x2, const Tensor &x3, const Tensor &y,
                                           const TPipe &tpipe, ApiLoopParams &param) const {
  constexpr size_t kMaxUbInputNum = 3U;

  std::vector<Tensor> ub_inputs;
  ub_inputs.reserve(kMaxUbInputNum);
  ub_inputs.push_back(x1);
  const bool x2_takes_part = !(x2.is_constant || x2.need_gen_get_value_of_ub_scalar);
  if (x2_takes_part) {
    ub_inputs.push_back(x2);
  }
  const bool x3_takes_part = !(x3.is_constant || x3.need_gen_get_value_of_ub_scalar);
  if (x3_takes_part) {
    ub_inputs.push_back(x3);
  }

  std::vector<Tensor> ub_outputs;
  ub_outputs.push_back(y);

  VectorizedAxisLoopMergeStatus merge_info;
  const bool merged = GenerateVectorizedAxisMergeStatus(ub_inputs, ub_outputs, merge_info, tpipe);
  GE_ASSERT_TRUE(merged, "GenerateVectorizedAxisMergeStatus failed");
  SaveApiLoopAxisParams(merge_info, param);
  return ge::SUCCESS;
}
/**
 * @brief Emits the single, un-looped `<api_name_>(...)` call.
 *
 * The generated statement has the shape
 *   api(y[off_y], x1[off_x1], <x2>, <x3>, <x1.actual_size>);
 * where every offset is tpipe.tiler.TensorVectorizedOffset(current_axis, tensor). An operand for which
 * IsAnyScalar() is true is emitted as its pre-rendered scalar text (@p x2_scalar / @p x3_scalar) instead
 * of the "tensor[offset]" form.
 *
 * @param tpipe        pipe context providing the tiler.
 * @param current_axis axes used to compute the vectorized offsets.
 * @param x1           first operand; also supplies the element count via actual_size.
 * @param x2           second operand.
 * @param x3           third operand.
 * @param y            result tensor.
 * @param x2_scalar    text emitted for x2 when it is a scalar.
 * @param x3_scalar    text emitted for x3 when it is a scalar.
 * @param ss           [out] stream the statement is appended to.
 * @return always ge::SUCCESS.
 */
Status WhereRegApiCall::GenerateNoLoopCase(const TPipe &tpipe, const std::vector<ascir::AxisId> &current_axis,
                                           const Tensor &x1, const Tensor &x2, const Tensor &x3, const Tensor &y,
                                           const std::string &x2_scalar, const std::string &x3_scalar,
                                           std::stringstream &ss) const {
  // Appends "<tensor>[<vectorized offset>], " for one tensor operand.
  const auto emit_tensor_operand = [&](const Tensor &tensor) {
    ss << tensor << "[" << tpipe.tiler.TensorVectorizedOffset(current_axis, tensor) << "], ";
  };

  const bool x2_is_scalar_scene = x2.IsAnyScalar();
  const bool x3_is_scalar_scene = x3.IsAnyScalar();

  ss << this->api_name_ << "(";
  emit_tensor_operand(y);
  emit_tensor_operand(x1);
  if (x2_is_scalar_scene) {
    ss << x2_scalar << ", ";
  } else {
    emit_tensor_operand(x2);
  }
  if (x3_is_scalar_scene) {
    ss << x3_scalar << ", ";
  } else {
    emit_tensor_operand(x3);
  }
  // '\n' yields the same text as std::endl; the flush is meaningless on a string stream.
  ss << x1.actual_size << ");" << '\n';
  return ge::SUCCESS;
}
/**
 * @brief Emits the looped `<api_name_><true, true>(...)` call used when both x2 and x3 are scalars.
 *
 * The generated statement has the shape
 *   api<true, true>(y[out_off], x1[in0_off], x2_blk[0], x3_blk[0],
 *                   {last_outer_repeat, actual_count}, {out_stride, 1}, {in_stride, 1}, {out_stride, 1});
 * The offsets come from CalcInnerOffset over every stride except the last one ("0" when only one stride
 * exists). With a single outer repeat the statement goes straight to @p ss; otherwise it is handed to
 * CreateComputeNodeOuterFor together with @p ss.
 *
 * @param tpipe pipe context providing the tiler.
 * @param param loop parameters; needs a non-empty outer_repeats, outputs_strides[0] and inputs_strides[0].
 * @param x1    first operand (strides taken from param.inputs_strides[0]).
 * @param y     result tensor (strides taken from param.outputs_strides[0]).
 * @param scalar_local_blk_tensor_name_x2 name of the block tensor emitted for x2.
 * @param scalar_local_blk_tensor_name_x3 name of the block tensor emitted for x3.
 * @param ss    [out] stream the generated code is appended to.
 * @return ge::SUCCESS on success; the GE_ASSERT_TRUE failure status when @p param lacks a required entry.
 */
Status WhereRegApiCall::GenerateBothScalarCase(const TPipe &tpipe, const ApiLoopParams &param,
                                               const Tensor &x1, const Tensor &y,
                                               const std::string &scalar_local_blk_tensor_name_x2,
                                               const std::string &scalar_local_blk_tensor_name_x3,
                                               std::stringstream &ss) const {
  // Every container indexed below is validated first: an empty stride list would otherwise evaluate
  // begin() - 1, and an empty outer_repeats would be indexed at size() - 1.
  GE_ASSERT_TRUE(!param.outer_repeats.empty(), "outer_repeats must not be empty");
  GE_ASSERT_TRUE(!param.outputs_strides.empty() && !param.outputs_strides[0].empty(),
                 "strides of output y are missing");
  GE_ASSERT_TRUE(!param.inputs_strides.empty() && !param.inputs_strides[0].empty(),
                 "strides of input x1 are missing");

  // Offset built from all strides but the innermost one; a single stride means offset "0".
  const auto inner_offset_of = [&](const auto &strides) -> std::string {
    if (strides.size() == 1U) {
      return "0";
    }
    std::vector<ascir::SizeExpr> inner_strides(strides.begin(), strides.end() - 1);
    return CalcInnerOffset(tpipe, inner_strides);
  };
  const std::string output_inner_offset = inner_offset_of(param.outputs_strides[0]);
  const std::string input0_inner_offset = inner_offset_of(param.inputs_strides[0]);

  std::stringstream call;
  call << this->api_name_ << "<true, true>(" << y << "[" << output_inner_offset << "], "
       << x1 << "[" << input0_inner_offset << "], "
       << scalar_local_blk_tensor_name_x2 << "[0], "
       << scalar_local_blk_tensor_name_x3 << "[0], "
       << "{static_cast<uint16_t>(" << param.outer_repeats.back() << "), static_cast<uint16_t>("
       << tpipe.tiler.ActualSize(param.cal_count) << ")}, "
       << "{static_cast<uint16_t>(" << tpipe.tiler.Size(param.output_second_to_last_stride)
       << "), static_cast<uint16_t>(1)}, "
       << "{static_cast<uint16_t>(" << tpipe.tiler.Size(param.input_second_to_last_stride)
       << "), static_cast<uint16_t>(1)}, "
       // NOTE(review): the last pair reuses the output stride, as in the original - confirm this is intended.
       << "{static_cast<uint16_t>(" << tpipe.tiler.Size(param.output_second_to_last_stride)
       << "), static_cast<uint16_t>(1)});" << '\n';

  if (param.outer_repeats.size() == 1U) {
    ss << call.str();
  } else {
    CreateComputeNodeOuterFor(param.outer_repeats, call, ss, 0);
  }
  return ge::SUCCESS;
}
/**
 * @brief Emits the looped `<api_name_><true, false>(...)` call used when only x2 is a scalar.
 *
 * The generated statement has the shape
 *   api<true, false>(y[out_off], x1[in0_off], x2_blk[0], x3[in1_off],
 *                    {last_outer_repeat, actual_count}, {out_stride, 1}, {in_stride, 1}, {out_stride, 1});
 * The offsets come from CalcInnerOffset over every stride except the last one ("0" when only one stride
 * exists). With a single outer repeat the statement goes straight to @p ss; otherwise it is handed to
 * CreateComputeNodeOuterFor together with @p ss.
 *
 * @param tpipe pipe context providing the tiler.
 * @param param loop parameters; needs a non-empty outer_repeats, outputs_strides[0] and
 *              inputs_strides[0..1].
 * @param x1    first operand (strides taken from param.inputs_strides[0]).
 * @param x3    third operand (strides taken from param.inputs_strides[1]).
 * @param y     result tensor (strides taken from param.outputs_strides[0]).
 * @param scalar_local_blk_tensor_name_x2 name of the block tensor emitted for x2.
 * @param ss    [out] stream the generated code is appended to.
 * @return ge::SUCCESS on success; the GE_ASSERT_TRUE failure status when @p param lacks a required entry.
 */
Status WhereRegApiCall::GenerateX2ScalarCase(const TPipe &tpipe, const ApiLoopParams &param,
                                             const Tensor &x1, const Tensor &x3, const Tensor &y,
                                             const std::string &scalar_local_blk_tensor_name_x2,
                                             std::stringstream &ss) const {
  constexpr size_t kX1StridesIndex = 0U;
  // NOTE(review): index 1 is x3 only if x2 contributed no entry to inputs_strides - confirm with the
  // code that fills ApiLoopParams.
  constexpr size_t kX3StridesIndex = 1U;

  // Every container indexed below is validated first: an empty stride list would otherwise evaluate
  // begin() - 1, and a missing entry would be an out-of-range read.
  GE_ASSERT_TRUE(!param.outer_repeats.empty(), "outer_repeats must not be empty");
  GE_ASSERT_TRUE(!param.outputs_strides.empty() && !param.outputs_strides[0].empty(),
                 "strides of output y are missing");
  GE_ASSERT_TRUE(param.inputs_strides.size() > kX3StridesIndex, "expect strides for x1 and x3, but got %zu entries",
                 param.inputs_strides.size());
  GE_ASSERT_TRUE(!param.inputs_strides[kX1StridesIndex].empty() && !param.inputs_strides[kX3StridesIndex].empty(),
                 "strides of input x1 or x3 are empty");

  // Offset built from all strides but the innermost one; a single stride means offset "0".
  const auto inner_offset_of = [&](const auto &strides) -> std::string {
    if (strides.size() == 1U) {
      return "0";
    }
    std::vector<ascir::SizeExpr> inner_strides(strides.begin(), strides.end() - 1);
    return CalcInnerOffset(tpipe, inner_strides);
  };
  const std::string output_inner_offset = inner_offset_of(param.outputs_strides[0]);
  const std::string input0_inner_offset = inner_offset_of(param.inputs_strides[kX1StridesIndex]);
  const std::string input2_inner_offset = inner_offset_of(param.inputs_strides[kX3StridesIndex]);

  std::stringstream call;
  call << this->api_name_ << "<true, false>(" << y << "[" << output_inner_offset << "], "
       << x1 << "[" << input0_inner_offset << "], "
       << scalar_local_blk_tensor_name_x2 << "[0], "
       << x3 << "[" << input2_inner_offset << "], "
       << "{static_cast<uint16_t>(" << param.outer_repeats.back() << "), static_cast<uint16_t>("
       << tpipe.tiler.ActualSize(param.cal_count) << ")}, "
       << "{static_cast<uint16_t>(" << tpipe.tiler.Size(param.output_second_to_last_stride)
       << "), static_cast<uint16_t>(1)}, "
       << "{static_cast<uint16_t>(" << tpipe.tiler.Size(param.input_second_to_last_stride)
       << "), static_cast<uint16_t>(1)}, "
       // NOTE(review): the last pair reuses the output stride, as in the original - confirm this is intended.
       << "{static_cast<uint16_t>(" << tpipe.tiler.Size(param.output_second_to_last_stride)
       << "), static_cast<uint16_t>(1)});" << '\n';

  if (param.outer_repeats.size() == 1U) {
    ss << call.str();
  } else {
    CreateComputeNodeOuterFor(param.outer_repeats, call, ss, 0);
  }
  return ge::SUCCESS;
}
/**
 * @brief Emits the looped `<api_name_><false, true>(...)` call used when only x3 is a scalar.
 *
 * The generated statement has the shape
 *   api<false, true>(y[out_off], x1[in0_off], x2[in1_off], x3_blk[0],
 *                    {last_outer_repeat, actual_count}, {out_stride, 1}, {in_stride, 1}, {out_stride, 1});
 * The offsets come from CalcInnerOffset over every stride except the last one ("0" when only one stride
 * exists). With a single outer repeat the statement goes straight to @p ss; otherwise it is handed to
 * CreateComputeNodeOuterFor together with @p ss.
 *
 * @param tpipe pipe context providing the tiler.
 * @param param loop parameters; needs a non-empty outer_repeats, outputs_strides[0] and
 *              inputs_strides[0..1].
 * @param x1    first operand (strides taken from param.inputs_strides[0]).
 * @param x2    second operand (strides taken from param.inputs_strides[1]).
 * @param y     result tensor (strides taken from param.outputs_strides[0]).
 * @param scalar_local_blk_tensor_name_x3 name of the block tensor emitted for x3.
 * @param ss    [out] stream the generated code is appended to.
 * @return ge::SUCCESS on success; the GE_ASSERT_TRUE failure status when @p param lacks a required entry.
 */
Status WhereRegApiCall::GenerateX3ScalarCase(const TPipe &tpipe, const ApiLoopParams &param,
                                             const Tensor &x1, const Tensor &x2, const Tensor &y,
                                             const std::string &scalar_local_blk_tensor_name_x3,
                                             std::stringstream &ss) const {
  constexpr size_t kX1StridesIndex = 0U;
  constexpr size_t kX2StridesIndex = 1U;

  // Every container indexed below is validated first: an empty stride list would otherwise evaluate
  // begin() - 1, and a missing entry would be an out-of-range read.
  GE_ASSERT_TRUE(!param.outer_repeats.empty(), "outer_repeats must not be empty");
  GE_ASSERT_TRUE(!param.outputs_strides.empty() && !param.outputs_strides[0].empty(),
                 "strides of output y are missing");
  GE_ASSERT_TRUE(param.inputs_strides.size() > kX2StridesIndex, "expect strides for x1 and x2, but got %zu entries",
                 param.inputs_strides.size());
  GE_ASSERT_TRUE(!param.inputs_strides[kX1StridesIndex].empty() && !param.inputs_strides[kX2StridesIndex].empty(),
                 "strides of input x1 or x2 are empty");

  // Offset built from all strides but the innermost one; a single stride means offset "0".
  const auto inner_offset_of = [&](const auto &strides) -> std::string {
    if (strides.size() == 1U) {
      return "0";
    }
    std::vector<ascir::SizeExpr> inner_strides(strides.begin(), strides.end() - 1);
    return CalcInnerOffset(tpipe, inner_strides);
  };
  const std::string output_inner_offset = inner_offset_of(param.outputs_strides[0]);
  const std::string input0_inner_offset = inner_offset_of(param.inputs_strides[kX1StridesIndex]);
  const std::string input1_inner_offset = inner_offset_of(param.inputs_strides[kX2StridesIndex]);

  std::stringstream call;
  call << this->api_name_ << "<false, true>(" << y << "[" << output_inner_offset << "], "
       << x1 << "[" << input0_inner_offset << "], "
       << x2 << "[" << input1_inner_offset << "], "
       << scalar_local_blk_tensor_name_x3 << "[0], "
       << "{static_cast<uint16_t>(" << param.outer_repeats.back() << "), static_cast<uint16_t>("
       << tpipe.tiler.ActualSize(param.cal_count) << ")}, "
       << "{static_cast<uint16_t>(" << tpipe.tiler.Size(param.output_second_to_last_stride)
       << "), static_cast<uint16_t>(1)}, "
       << "{static_cast<uint16_t>(" << tpipe.tiler.Size(param.input_second_to_last_stride)
       << "), static_cast<uint16_t>(1)}, "
       // NOTE(review): the last pair reuses the output stride, as in the original - confirm this is intended.
       << "{static_cast<uint16_t>(" << tpipe.tiler.Size(param.output_second_to_last_stride)
       << "), static_cast<uint16_t>(1)});" << '\n';

  if (param.outer_repeats.size() == 1U) {
    ss << call.str();
  } else {
    CreateComputeNodeOuterFor(param.outer_repeats, call, ss, 0);
  }
  return ge::SUCCESS;
}
/**
 * @brief Emits the looped `<api_name_><false, false>(...)` call used when neither x2 nor x3 is a scalar.
 *
 * The generated statement has the shape
 *   api<false, false>(y[out_off], x1[in0_off], x2[in1_off], x3[in2_off],
 *                     {last_outer_repeat, actual_count}, {out_stride, 1}, {in_stride, 1}, {out_stride, 1});
 * The offsets come from CalcInnerOffset over every stride except the last one ("0" when only one stride
 * exists). With a single outer repeat the statement goes straight to @p ss; otherwise it is handed to
 * CreateComputeNodeOuterFor together with @p ss.
 *
 * @param tpipe pipe context providing the tiler.
 * @param param loop parameters; needs a non-empty outer_repeats, outputs_strides[0] and
 *              inputs_strides[0..2].
 * @param x1    first operand (strides taken from param.inputs_strides[0]).
 * @param x2    second operand (strides taken from param.inputs_strides[1]).
 * @param x3    third operand (strides taken from param.inputs_strides[2]).
 * @param y     result tensor (strides taken from param.outputs_strides[0]).
 * @param ss    [out] stream the generated code is appended to.
 * @return ge::SUCCESS on success; the GE_ASSERT_TRUE failure status when @p param lacks a required entry.
 */
Status WhereRegApiCall::GenerateNormalCase(const TPipe &tpipe, const ApiLoopParams &param,
                                           const Tensor &x1, const Tensor &x2, const Tensor &x3, const Tensor &y,
                                           std::stringstream &ss) const {
  constexpr size_t kX1StridesIndex = 0U;
  constexpr size_t kX2StridesIndex = 1U;
  constexpr size_t kX3StridesIndex = 2U;

  // Every container indexed below is validated first: an empty stride list would otherwise evaluate
  // begin() - 1, and a missing entry would be an out-of-range read.
  GE_ASSERT_TRUE(!param.outer_repeats.empty(), "outer_repeats must not be empty");
  GE_ASSERT_TRUE(!param.outputs_strides.empty() && !param.outputs_strides[0].empty(),
                 "strides of output y are missing");
  GE_ASSERT_TRUE(param.inputs_strides.size() > kX3StridesIndex,
                 "expect strides for x1, x2 and x3, but got %zu entries", param.inputs_strides.size());
  GE_ASSERT_TRUE(!param.inputs_strides[kX1StridesIndex].empty() && !param.inputs_strides[kX2StridesIndex].empty() &&
                 !param.inputs_strides[kX3StridesIndex].empty(),
                 "strides of input x1, x2 or x3 are empty");

  // Offset built from all strides but the innermost one; a single stride means offset "0".
  const auto inner_offset_of = [&](const auto &strides) -> std::string {
    if (strides.size() == 1U) {
      return "0";
    }
    std::vector<ascir::SizeExpr> inner_strides(strides.begin(), strides.end() - 1);
    return CalcInnerOffset(tpipe, inner_strides);
  };
  const std::string output_inner_offset = inner_offset_of(param.outputs_strides[0]);
  const std::string input0_inner_offset = inner_offset_of(param.inputs_strides[kX1StridesIndex]);
  const std::string input1_inner_offset = inner_offset_of(param.inputs_strides[kX2StridesIndex]);
  const std::string input2_inner_offset = inner_offset_of(param.inputs_strides[kX3StridesIndex]);

  std::stringstream call;
  call << this->api_name_ << "<false, false>(" << y << "[" << output_inner_offset << "], "
       << x1 << "[" << input0_inner_offset << "], "
       << x2 << "[" << input1_inner_offset << "], "
       << x3 << "[" << input2_inner_offset << "], "
       << "{static_cast<uint16_t>(" << param.outer_repeats.back() << "), static_cast<uint16_t>("
       << tpipe.tiler.ActualSize(param.cal_count) << ")}, "
       << "{static_cast<uint16_t>(" << tpipe.tiler.Size(param.output_second_to_last_stride)
       << "), static_cast<uint16_t>(1)}, "
       << "{static_cast<uint16_t>(" << tpipe.tiler.Size(param.input_second_to_last_stride)
       << "), static_cast<uint16_t>(1)}, "
       // NOTE(review): the last pair reuses the output stride, as in the original - confirm this is intended.
       << "{static_cast<uint16_t>(" << tpipe.tiler.Size(param.output_second_to_last_stride)
       << "), static_cast<uint16_t>(1)});" << '\n';

  if (param.outer_repeats.size() == 1U) {
    ss << call.str();
  } else {
    CreateComputeNodeOuterFor(param.outer_repeats, call, ss, 0);
  }
  return ge::SUCCESS;
}
/**
 * @brief Entry point: renders the Where API call for the given operands into @p result.
 *
 * Steps:
 *  1. bind x1/x2/x3/y from @p inputs / @p outputs;
 *  2. compute the loop parameters;
 *  3. require x2 and x3 to share the same dtype name;
 *  4. dispatch on the loop shape and on which of x2/x3 report IsAnyScalar():
 *     - no outer repeats           -> GenerateNoLoopCase
 *     - x2 and x3 scalar           -> GenerateBothScalarCase
 *     - only x2 scalar             -> GenerateX2ScalarCase
 *     - only x3 scalar             -> GenerateX3ScalarCase
 *     - neither scalar             -> GenerateNormalCase
 *
 * In the looped scalar cases a scalar operand is referenced through a block tensor name: the tensor's
 * own name, prefixed with "local_blk_tensor_of_" when IsConstScalar() is true.
 *
 * @param tpipe        pipe context providing the tiler.
 * @param current_axis axes used by the no-loop case to compute offsets.
 * @param inputs       operand tensors in the order x1, x2, x3.
 * @param outputs      result tensors; outputs[0] is used.
 * @param result       [out] receives the generated code; only written on success.
 * @return ge::SUCCESS on success, otherwise the status of the first failing step.
 */
Status WhereRegApiCall::Generate(const TPipe &tpipe, const std::vector<ascir::AxisId> &current_axis,
                                 const std::vector<std::reference_wrapper<const Tensor>> &inputs,
                                 const std::vector<std::reference_wrapper<const Tensor>> &outputs,
                                 std::string &result) const {
  const Tensor *x1 = nullptr;
  const Tensor *x2 = nullptr;
  const Tensor *x3 = nullptr;
  const Tensor *y = nullptr;
  GE_CHK_STATUS_RET(PrepareInputsAndOutputs(inputs, outputs, x1, x2, x3, y));
  // Defensive: everything below dereferences these pointers.
  GE_ASSERT_TRUE((x1 != nullptr) && (x2 != nullptr) && (x3 != nullptr) && (y != nullptr),
                 "Where operands were not bound");

  ApiLoopParams param;
  GE_CHK_STATUS_RET(GenerateLoopParams(*x1, *x2, *x3, *y, tpipe, param));

  const bool x2_is_scalar_scene = x2->IsAnyScalar();
  const bool x3_is_scalar_scene = x3->IsAnyScalar();

  std::string x2_dtype_name;
  std::string x3_dtype_name;
  GE_CHK_STATUS_RET(Tensor::DtypeName(x2->dtype, x2_dtype_name),
                    "Codegen get data type:%d failed", static_cast<int32_t>(x2->dtype));
  GE_CHK_STATUS_RET(Tensor::DtypeName(x3->dtype, x3_dtype_name),
                    "Codegen get data type:%d failed", static_cast<int32_t>(x3->dtype));
  GE_ASSERT_TRUE(x2_dtype_name == x3_dtype_name, "x2_dtype_name:%s, x3_dtype_name:%s",
                 x2_dtype_name.c_str(), x3_dtype_name.c_str());

  // Scalar text for the no-loop call: a cast of the UB scalar variable when its value has to be fetched,
  // otherwise the tensor's own Str() rendering.
  const std::string x2_scalar =
      x2->need_gen_get_value_of_ub_scalar ? ("(" + x2_dtype_name + ")" + x2->ub_scalar_name) : x2->Str();
  const std::string x3_scalar =
      x3->need_gen_get_value_of_ub_scalar ? ("(" + x3_dtype_name + ")" + x3->ub_scalar_name) : x3->Str();

  // Name under which a scalar operand is addressed inside the looped calls.
  const auto blk_tensor_name_of = [](const Tensor &tensor) -> std::string {
    return tensor.IsConstScalar() ? ("local_blk_tensor_of_" + tensor.name) : tensor.name;
  };

  std::stringstream ss;
  if (param.outer_repeats.empty()) {
    GE_CHK_STATUS_RET(GenerateNoLoopCase(tpipe, current_axis, *x1, *x2, *x3, *y, x2_scalar, x3_scalar, ss));
  } else if (x2_is_scalar_scene && x3_is_scalar_scene) {
    const std::string scalar_local_blk_tensor_name_x2 = blk_tensor_name_of(*x2);
    const std::string scalar_local_blk_tensor_name_x3 = blk_tensor_name_of(*x3);
    GE_CHK_STATUS_RET(GenerateBothScalarCase(tpipe, param, *x1, *y, scalar_local_blk_tensor_name_x2,
                                             scalar_local_blk_tensor_name_x3, ss));
  } else if (x2_is_scalar_scene) {
    const std::string scalar_local_blk_tensor_name_x2 = blk_tensor_name_of(*x2);
    GE_CHK_STATUS_RET(GenerateX2ScalarCase(tpipe, param, *x1, *x3, *y, scalar_local_blk_tensor_name_x2, ss));
  } else if (x3_is_scalar_scene) {
    const std::string scalar_local_blk_tensor_name_x3 = blk_tensor_name_of(*x3);
    GE_CHK_STATUS_RET(GenerateX3ScalarCase(tpipe, param, *x1, *x2, *y, scalar_local_blk_tensor_name_x3, ss));
  } else {
    GE_CHK_STATUS_RET(GenerateNormalCase(tpipe, param, *x1, *x2, *x3, *y, ss));
  }
  result = ss.str();
  return ge::SUCCESS;
}
// File-scope registrar object: constructs ApiCallRegister<WhereRegApiCall> with the key "WhereRegApiCall"
// during static initialization of this translation unit.
// NOTE(review): presumably this is what makes WhereRegApiCall creatable by name through the api-call
// factory - confirm against api_call/utils/api_call_factory.h.
static ApiCallRegister<WhereRegApiCall> register_where_reg_api_call("WhereRegApiCall");
}