* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "reduce_api_call_base.h"
#include <algorithm>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "ascir_node_param/ascir_node_param.h"
#include "attr_utils.h"
#include "ascir_ops.h"
#include "common/checker.h"
#include "common/ge_common/debug/log.h"
#include "common_utils.h"
#include "graph/ascendc_ir/utils/asc_tensor_utils.h"
#include "api_call/utils/api_call_factory.h"
#include "api_call/utils/api_call_utils.h"
namespace reduce_base {
using namespace codegen;
using namespace af::ops;
using namespace af::ascir_op;
using namespace ascgen_utils;
namespace {
struct ReduceShadowCheckResult {
std::vector<std::string> mismatch_fields;
std::vector<std::string> mismatch_details;
};
struct ReduceLegacyBehavior {
codegen::ReducePattern pattern{codegen::ReducePattern::kUnknown};
bool need_multi_reduce{false};
bool has_reuse{false};
bool is_reuse_source{false};
};
struct ReduceShadowBuildContext {
af::AscNodePtr node;
std::string api_name;
const TPipe *tpipe{nullptr};
const Tensor *input{nullptr};
const Tensor *output{nullptr};
ascir::AxisId axis_id{-1};
};
struct ReduceShadowCheckContext {
std::string api_name;
ReduceShadowBuildContext build;
ReduceLegacyBehavior legacy;
};
void LogMissingAscirNodeParams(const ReduceShadowCheckContext &ctx) {
const auto &node = ctx.build.node;
if (node == nullptr) {
GELOGW("[ASCIR_PARAM_TRACE] Codegen missing api[%s], node is null.", ctx.api_name.c_str());
return;
}
const auto op_desc = node->GetOpDesc();
const auto owner_graph = node->GetOwnerComputeGraph();
const auto node_name = node->GetName();
const auto node_type = node->GetType();
const auto graph_name = owner_graph == nullptr ? std::string("<null>") : owner_graph->GetName();
GELOGW("[ASCIR_PARAM_TRACE] Codegen missing api[%s], node[%s], type[%s], graph[%s], node_ptr[%p], "
"op_desc_ptr[%p], owner_graph_ptr[%p].",
ctx.api_name.c_str(), node_name.c_str(), node_type.c_str(), graph_name.c_str(),
static_cast<void *>(node.get()), static_cast<void *>(op_desc.get()), static_cast<void *>(owner_graph.get()));
}
std::string BoolToString(bool value) { return value ? "true" : "false"; }
std::string ExprToString(const ge::Expression &expr) {
return expr.IsValid() ? std::string(expr.Str().get()) : std::string("<invalid>");
}
std::string ReducePatternToString(codegen::ReducePattern pattern) {
switch (pattern) {
case codegen::ReducePattern::kUnknown:
return "Unknown";
case codegen::ReducePattern::kAR:
return "AR";
case codegen::ReducePattern::kRA:
return "RA";
default:
return "Invalid(" + std::to_string(static_cast<int32_t>(pattern)) + ")";
}
}
std::string ReduceMergeModeToString(codegen::ReduceMergeMode mode) {
switch (mode) {
case codegen::ReduceMergeMode::kUnknown:
return "Unknown";
case codegen::ReduceMergeMode::kNone:
return "None";
case codegen::ReduceMergeMode::kCopy:
return "Copy";
case codegen::ReduceMergeMode::kMergeByElementwise:
return "MergeByElementwise";
default:
return "Invalid(" + std::to_string(static_cast<int32_t>(mode)) + ")";
}
}
void AddMismatchDetail(const std::string &field, const std::string &lhs_name, const std::string &lhs_value,
const std::string &rhs_name, const std::string &rhs_value,
ReduceShadowCheckResult &result) {
std::stringstream ss;
ss << field << "{" << lhs_name << "=" << lhs_value << "," << rhs_name << "=" << rhs_value << "}";
result.mismatch_details.emplace_back(ss.str());
}
void AddMismatch(bool matched, const std::string &field, const std::string &parser_value,
const std::string &shadow_value, ReduceShadowCheckResult &result) {
if (matched) {
return;
}
result.mismatch_fields.emplace_back(field);
AddMismatchDetail(field, "parser", parser_value, "shadow", shadow_value, result);
}
bool FindAxisIndex(const std::vector<ascir::AxisId> &axis_ids, ascir::AxisId axis_id, size_t &axis_index) {
const auto iter = std::find(axis_ids.begin(), axis_ids.end(), axis_id);
if (iter == axis_ids.end()) {
return false;
}
axis_index = static_cast<size_t>(std::distance(axis_ids.begin(), iter));
return true;
}
std::vector<ge::Expression> GetVectorizedRepeats(const Tensor &tensor) {
std::vector<ge::Expression> repeats;
repeats.reserve(tensor.vectorized_axis.size());
for (const auto axis_id : tensor.vectorized_axis) {
size_t axis_index = 0U;
const auto repeat = FindAxisIndex(tensor.axis, axis_id, axis_index) ? tensor.axis_size[axis_index] : ge::Symbol(1U);
repeats.emplace_back(repeat);
}
return repeats;
}
std::vector<ge::Expression> GetOutputDims(const Tensor &tensor) {
return tensor.axis_size.empty() ? GetVectorizedRepeats(tensor) : tensor.axis_size;
}
codegen::ReducePattern GetCodegenReducePattern(const Tensor &output) {
if (output.vectorized_strides.empty()) {
return codegen::ReducePattern::kUnknown;
}
return output.vectorized_strides.back() == af::sym::kSymbolZero ? codegen::ReducePattern::kAR
: codegen::ReducePattern::kRA;
}
ge::Status FillReduceReuseSource(const af::AscNodePtr &node, codegen::ReduceReuseInfo &reuse) {
GE_ASSERT_NOTNULL(node);
auto node_in_anchor = node->GetInDataAnchor(0);
GE_ASSERT_NOTNULL(node_in_anchor);
auto peer_out_anchor = node_in_anchor->GetPeerOutAnchor();
GE_ASSERT_NOTNULL(peer_out_anchor);
const auto &in_node = std::dynamic_pointer_cast<af::AscNode>(peer_out_anchor->GetOwnerNode());
GE_ASSERT_NOTNULL(in_node);
reuse.valid = true;
reuse.is_reuse_source = in_node->GetOutAllNodes().size() == 1UL;
return ge::SUCCESS;
}
ge::Status BuildReduceInput(const ReduceShadowBuildContext &ctx, codegen::ReduceSpecificParamBuildInput &input) {
GE_ASSERT_NOTNULL(ctx.tpipe);
GE_ASSERT_NOTNULL(ctx.input);
GE_ASSERT_NOTNULL(ctx.output);
input.node_name = ctx.node == nullptr ? ctx.api_name : ctx.node->GetName();
input.reduce_type = ctx.api_name;
input.input_repeats = GetVectorizedRepeats(*ctx.input);
input.input_strides = ctx.input->vectorized_strides;
input.output_dims = GetOutputDims(*ctx.output);
input.output_strides = ctx.output->vectorized_strides;
const auto dtype_size = ge::GetSizeByDataType(ctx.input->dtype);
GE_ASSERT_TRUE(dtype_size > 0, "Invalid reduce input dtype size, dtype[%d].", static_cast<int32_t>(ctx.input->dtype));
input.dtype_size = static_cast<uint32_t>(dtype_size);
input.pattern = GetCodegenReducePattern(*ctx.output);
input.need_multi_reduce = IsNeedMultiReduce(ctx.tpipe->tiler, *ctx.input, *ctx.output, ctx.axis_id);
input.merge_times = input.need_multi_reduce ? ctx.tpipe->tiler.GetAxis(ctx.axis_id).size_expr : ge::Symbol(1U);
if (ctx.node != nullptr) {
GE_ASSERT_SUCCESS(FillReduceReuseSource(ctx.node, input.reuse));
}
return ge::SUCCESS;
}
std::string JoinMismatchFields(const std::vector<std::string> &fields) {
std::stringstream ss;
for (size_t i = 0U; i < fields.size(); ++i) {
if (i != 0U) {
ss << ",";
}
ss << fields[i];
}
return ss.str();
}
std::string JoinMismatchDetails(const std::vector<std::string> &details) {
std::stringstream ss;
for (size_t i = 0U; i < details.size(); ++i) {
if (i != 0U) {
ss << ";";
}
ss << details[i];
}
return ss.str();
}
ge::Status BuildCodegenShadowReduceSpecificParams(const ReduceShadowBuildContext &ctx,
codegen::ReduceSpecificParams ¶ms) {
codegen::ReduceSpecificParamBuildInput input;
GE_ASSERT_SUCCESS(BuildReduceInput(ctx, input));
GE_ASSERT_SUCCESS(codegen::BuildReduceSpecificParams(input, params));
return ge::SUCCESS;
}
ReduceShadowCheckContext BuildReduceShadowCheckContext(const ReduceCodegenShadowCheckInput &input) {
ReduceShadowCheckContext ctx;
ctx.api_name = input.api_name;
ctx.build.node = input.node;
ctx.build.api_name = input.api_name;
ctx.build.tpipe = input.tpipe;
ctx.build.input = input.input;
ctx.build.output = input.output;
ctx.build.axis_id = input.axis_id;
if (input.output != nullptr) {
ctx.legacy.pattern = GetCodegenReducePattern(*input.output);
}
if (input.tpipe != nullptr && input.input != nullptr && input.output != nullptr) {
ctx.legacy.need_multi_reduce = IsNeedMultiReduce(input.tpipe->tiler, *input.input, *input.output, input.axis_id);
}
ctx.legacy.has_reuse = input.has_reuse;
ctx.legacy.is_reuse_source = input.is_reuse_source;
return ctx;
}
void CompareReduceSpecificParams(const codegen::ReduceSpecificParams &parser_params,
const codegen::ReduceSpecificParams &shadow_params, ReduceShadowCheckResult &result) {
AddMismatch(parser_params.valid == shadow_params.valid, "valid", BoolToString(parser_params.valid),
BoolToString(shadow_params.valid), result);
AddMismatch(parser_params.reduce_type == shadow_params.reduce_type, "reduce_type", parser_params.reduce_type,
shadow_params.reduce_type, result);
AddMismatch(parser_params.pattern == shadow_params.pattern, "pattern", ReducePatternToString(parser_params.pattern),
ReducePatternToString(shadow_params.pattern), result);
AddMismatch(parser_params.merge_mode == shadow_params.merge_mode, "merge_mode",
ReduceMergeModeToString(parser_params.merge_mode), ReduceMergeModeToString(shadow_params.merge_mode),
result);
AddMismatch(ExpressEq(parser_params.merge_size, shadow_params.merge_size), "merge_size",
ExprToString(parser_params.merge_size), ExprToString(shadow_params.merge_size), result);
AddMismatch(ExpressEq(parser_params.merge_times, shadow_params.merge_times), "merge_times",
ExprToString(parser_params.merge_times), ExprToString(shadow_params.merge_times), result);
AddMismatch(parser_params.reuse.valid == shadow_params.reuse.valid, "reuse.valid",
BoolToString(parser_params.reuse.valid), BoolToString(shadow_params.reuse.valid), result);
AddMismatch(parser_params.reuse.is_reuse_source == shadow_params.reuse.is_reuse_source, "reuse.source",
BoolToString(parser_params.reuse.is_reuse_source), BoolToString(shadow_params.reuse.is_reuse_source),
result);
AddMismatch(parser_params.merged_dims.valid == shadow_params.merged_dims.valid, "merged_dims.valid",
BoolToString(parser_params.merged_dims.valid), BoolToString(shadow_params.merged_dims.valid), result);
if (parser_params.merged_dims.valid && shadow_params.merged_dims.valid) {
AddMismatch(ExpressEq(parser_params.merged_dims.first, shadow_params.merged_dims.first), "merged_dims.first",
ExprToString(parser_params.merged_dims.first), ExprToString(shadow_params.merged_dims.first), result);
AddMismatch(ExpressEq(parser_params.merged_dims.last, shadow_params.merged_dims.last), "merged_dims.last",
ExprToString(parser_params.merged_dims.last), ExprToString(shadow_params.merged_dims.last), result);
}
}
void CompareReduceLegacyBehavior(const codegen::ReduceSpecificParams &parser_params, const ReduceLegacyBehavior &legacy,
ReduceShadowCheckResult &result) {
if (parser_params.pattern != legacy.pattern) {
result.mismatch_fields.emplace_back("legacy.pattern");
AddMismatchDetail("legacy.pattern", "parser", ReducePatternToString(parser_params.pattern), "legacy",
ReducePatternToString(legacy.pattern), result);
}
const bool need_multi_reduce = parser_params.merge_mode != codegen::ReduceMergeMode::kNone;
if (need_multi_reduce != legacy.need_multi_reduce) {
result.mismatch_fields.emplace_back("legacy.merge_mode");
AddMismatchDetail(
"legacy.merge_mode", "parser",
BoolToString(need_multi_reduce) + ",parser.merge_mode=" + ReduceMergeModeToString(parser_params.merge_mode),
"legacy", BoolToString(legacy.need_multi_reduce), result);
}
if (legacy.has_reuse) {
if (!parser_params.reuse.valid) {
result.mismatch_fields.emplace_back("legacy.reuse.valid");
AddMismatchDetail("legacy.reuse.valid", "parser", BoolToString(parser_params.reuse.valid), "legacy", "true",
result);
}
if (parser_params.reuse.is_reuse_source != legacy.is_reuse_source) {
result.mismatch_fields.emplace_back("legacy.reuse.source");
AddMismatchDetail("legacy.reuse.source", "parser", BoolToString(parser_params.reuse.is_reuse_source), "legacy",
BoolToString(legacy.is_reuse_source), result);
}
}
}
}
static void ReplaceSS(std::string& str, const std::string& oldSubStr, const std::string& newSubStr)
{
size_t pos = 0;
while ((pos = str.find(oldSubStr, pos)) != std::string::npos) {
str.replace(pos, oldSubStr.length(), newSubStr);
pos += newSubStr.length();
}
return;
}
static void ReplaceSSWithSwappingFirstAndLast(std::string firstAndFirst_actual, std::string lastAndLast_actual, const bool &isAllAxisReduce, std::stringstream &ss)
{
if (isAllAxisReduce) {
ReplaceSS(firstAndFirst_actual, "first", "last");
ReplaceSS(lastAndLast_actual, "last", "first");
}
ss << firstAndFirst_actual << ";\n" << lastAndLast_actual << ";\n";
return;
}
size_t GetAxesNumExceptZeroTail(const Tensor &src, const Tensor &dst)
{
size_t num_axes = src.vectorized_axis.size();
for (; num_axes > 0; num_axes--) {
if (src.vectorized_strides[num_axes - 1] != 0 || dst.vectorized_strides[num_axes - 1] != 0) {
break;
}
}
return num_axes;
}
返回AR或者RA的fist_size和last_size;
为了避免由于存在src[i].axis_size=1(此时strides=0)导致误判为R轴,所以在遍历过程中过滤掉了src[i].axis_size为1的情况;
用last_not_1_axis_size_index来记录上一次的axis_size != 1的位置。
*/
void ReduceMergedSizeCodeGen(const TPipe &tpipe, std::stringstream &ss, const Tensor &src, const Tensor &dst,
bool is_tail) {
std::stringstream first;
std::stringstream first_actual;
std::stringstream last;
std::stringstream last_actual;
first << "{\nuint32_t " << (is_tail ? "first_tail" : "first") << " = 1";
first_actual << "uint32_t " << (is_tail ? "first_tail_actual" : "first_actual") << " = 1";
last << "uint32_t " << (is_tail ? "last_tail" : "last") << " = 1";
last_actual << "uint32_t " << (is_tail ? "last_tail_actual" : "last_actual") << " = 1";
std::string dtype_name;
Tensor::DtypeName(dst.dtype, dtype_name);
bool is_first = true;
const size_t num_axes = GetAxesNumExceptZeroTail(src, dst);
ascir::SizeExpr lastNonZeroStride = Zero;
size_t last_not_1_axis_size_index = 0xFFFFFFFF;
bool isAllAxisReduce = true;
for (size_t i = 0; i < num_axes; ++i) {
isAllAxisReduce = isAllAxisReduce && (dst.vectorized_strides[i] == 0);
const auto axis = tpipe.tiler.GetAxis(src.vectorized_axis[i]);
const auto axis_size = tpipe.tiler.AxisSize(src.vectorized_axis[i]);
if (i == num_axes - 1U) {
if (is_first && !isAllAxisReduce) {
last << " * " << KernelUtils::SizeAlign() << "(" << axis_size << ", 32/sizeof(" << dtype_name << "))";
last_actual << " * " << KernelUtils::SizeAlign() << "(" << axis.actual_size << ", 32/sizeof(" << dtype_name << "))";
} else if (is_first && isAllAxisReduce) {
first << " * " << KernelUtils::SizeAlign() << "(" << axis_size << ", 32/sizeof(" << dtype_name << "))";
first_actual << " * " << KernelUtils::SizeAlign() << "(" << axis.actual_size << ", 32/sizeof(" << dtype_name << "))";
} else {
last << " * " << tpipe.tiler.Size(lastNonZeroStride);
last_actual << " * " << tpipe.tiler.Size(lastNonZeroStride);
}
break;
}
if (src.vectorized_strides[i] == 0 && dst.vectorized_strides[i] == 0) {
continue;
}
if (is_first && last_not_1_axis_size_index != 0xFFFFFFFF) {
is_first = !((dst.vectorized_strides[i] == 0 && dst.vectorized_strides[last_not_1_axis_size_index] != 0) ||
(dst.vectorized_strides[i] != 0 && dst.vectorized_strides[last_not_1_axis_size_index] == 0));
}
if (!is_first) {
if (src.vectorized_strides[i] != Zero) {
lastNonZeroStride = src.vectorized_strides[i];
}
last << " * " << axis_size;
last_actual << " * " << axis.actual_size;
} else {
first << " * " << axis_size;
first_actual << " * " << axis.actual_size;
last_not_1_axis_size_index = i;
}
}
ReplaceSSWithSwappingFirstAndLast(first.str() + ";\n" + first_actual.str(), last.str() + ";\n" + last_actual.str(), isAllAxisReduce, ss);
}
bool IsNeedMultiReduce(const Tiler &tiler, const Tensor &input, const Tensor &output, ascir::AxisId axis_id) {
int64_t total_count = 0;
int64_t valid_count = 0;
std::function<void(ascir::AxisId)> recursive_functor = [&tiler, &input, &output, &total_count,
&valid_count, &recursive_functor](ascir::AxisId id) {
Axis axis = tiler.GetAxis(id);
auto pos = std::find(output.axis.begin(), output.axis.end(), id);
if (pos != output.axis.end()) {
size_t diff = pos - output.axis.begin();
total_count++;
valid_count = output.axis_strides[diff] == Zero && input.axis_strides[diff] != Zero ? valid_count + 1 : valid_count;
return;
}
for (size_t i = 0; i < axis.from.size(); i++) {
auto from_axis = tiler.GetAxis(axis.from[i]);
bool need_recursive = from_axis.type != Axis::Type::kAxisTypeOriginal;
auto pos = std::find(output.axis.begin(), output.axis.end(), axis.from[i]);
if (pos != output.axis.end()) {
size_t diff = pos - output.axis.begin();
total_count++;
valid_count = output.axis_strides[diff] == Zero && input.axis_strides[diff] != Zero ? valid_count + 1 : valid_count;
return;
}
if (need_recursive) {
for (size_t j = 0; j < from_axis.from.size(); j++) {
recursive_functor(from_axis.from[j]);
}
}
}
};
recursive_functor(axis_id);
return total_count == valid_count;
}
void ReduceMeanCodeGen(std::string &dtype_name, const TPipe &tpipe, const Tensor &src, const Tensor &dst,
std::stringstream &ss) {
std::set<ascir::AxisId> r_from_axis;
for (size_t i = 0; i < dst.axis_strides.size(); i++) {
if ((src.axis_strides[i] != 0 || src.axis_size[i] != 1) && dst.axis_strides[i] == 0) {
auto axis_id = dst.axis[i];
std::function<void(int)> collect_original_axes = [&tpipe, &r_from_axis, &collect_original_axes](int current_axis_id) {
auto axis = tpipe.tiler.GetAxis(current_axis_id);
if (axis.type == ascir::Axis::Type::kAxisTypeOriginal) {
r_from_axis.insert(current_axis_id);
return;
}
for (auto from_axis_id : axis.from) {
collect_original_axes(from_axis_id);
}
};
collect_original_axes(axis_id);
}
}
ss << "const float dimr_recip = 1.0f / (";
uint32_t count = 0;
for (auto axis_id : r_from_axis) {
if (count++ == 0) {
ss << tpipe.tiler.AxisSize(axis_id);
} else {
ss << " * " << tpipe.tiler.AxisSize(axis_id);
}
}
ss << ");" << std::endl;
ss << "Muls(" << dst << ", " << dst << ", " << "dimr_recip, " << KernelUtils::SizeAlign() << "(" << "reduce_dim_a" << ", 32 / sizeof(" << dtype_name << ")));" << std::endl;
return;
}
void GetIsArAndPattern(const Tensor &y, bool &isAr, std::string &reduce_pattern)
{
isAr = (y.vectorized_strides.back() == 0);
std::unordered_map<bool, std::string> reduce_pattern_map = {{true, "AscendC::Pattern::Reduce::AR"},
{false, "AscendC::Pattern::Reduce::RA"}};
reduce_pattern = reduce_pattern_map[isAr];
return;
}
void CheckReduceSpecificParamsForCodegen(const ReduceCodegenShadowCheckInput &input) {
const auto ctx = BuildReduceShadowCheckContext(input);
const auto params = ascir_param::GetAscirNodeParams(ctx.build.node);
if (params == nullptr) {
LogMissingAscirNodeParams(ctx);
return;
}
if (params->status != ascir_param::ParamBuildStatus::kBuilt) {
GELOGW("[ASCIR_PARAM] AscirNodeParams status is not built, api[%s].", ctx.api_name.c_str());
return;
}
const auto *parser_params = ascir_param::GetSpecificParams<ascir_param::ReduceNodeParams>(*params);
if (parser_params == nullptr) {
GELOGW("[ASCIR_PARAM] Reduce parser params missing, api[%s].", ctx.api_name.c_str());
return;
}
codegen::ReduceSpecificParams shadow_params;
const auto shadow_status = BuildCodegenShadowReduceSpecificParams(ctx.build, shadow_params);
if (shadow_status != ge::SUCCESS) {
GELOGW("[ASCIR_PARAM] Build codegen reduce shadow params failed, api[%s].", ctx.api_name.c_str());
return;
}
ReduceShadowCheckResult result;
AddMismatch(params->api_name == ctx.api_name, "api_name", params->api_name, ctx.api_name, result);
CompareReduceSpecificParams(parser_params->canonical_params, shadow_params, result);
CompareReduceLegacyBehavior(parser_params->canonical_params, ctx.legacy, result);
if (!result.mismatch_fields.empty()) {
GELOGW("[ASCIR_PARAM] Reduce param shadow check mismatch, api[%s], fields[%s], details[%s].", ctx.api_name.c_str(),
JoinMismatchFields(result.mismatch_fields).c_str(), JoinMismatchDetails(result.mismatch_details).c_str());
}
}
bool IsTilerLastReduceAxis(const Tensor &tensor) {
int count = 0;
for (auto stride : tensor.vectorized_strides) {
if (stride == 0) {
count++;
}
}
return count == 1;
}
void ReduceInitCodeGen(const Tensor &x, const Tensor &y, const int &type_value, std::stringstream &ss, const TPipe &tpipe, const std::string &dtype_name)
{
if (x.isAr) {
std::string is_last_axis_str = IsTilerLastReduceAxis(y) ? "true" : "false";
ss << "ReduceInit<" << dtype_name << ", " << type_value << ", " << is_last_axis_str << ">("
<< x << ", " << "first_actual" << ", last" << ", last_actual, " << tpipe.tiler.GetAxis(x.vectorized_axis.back()).actual_size
<< ");" << std::endl;
ss << "AscendC::PipeBarrier<PIPE_V>();" << std::endl;
}
return;
}
void ReduceDimACodeGen(const Tensor &x, const std::string &apiName, std::stringstream &ss)
{
if (apiName == "Mean") {
if (x.isAr) {
ss << "reduce_dim_a = first_actual;" << std::endl;
} else {
ss << "reduce_dim_a = last_actual;" << std::endl;
}
}
return;
}
void GenLastTwoRAxisSizeProductCode(const Tensor &x, const Tensor &y,
const TPipe &tpipe, std::stringstream &ss) {
std::vector<std::pair<ascir::AxisId, size_t>> r_axes;
for (size_t i = 0; i < x.axis.size(); ++i) {
bool is_r_axis = (y.axis_strides[i] == Zero && x.axis_strides[i] != Zero);
if (is_r_axis) {
r_axes.push_back({x.axis[i], i});
}
}
if (r_axes.size() >= 2) {
ascir::AxisId last_r_axis = r_axes[r_axes.size() - 1].first;
ascir::AxisId second_last_r_axis = r_axes[r_axes.size() - 2].first;
ss << "// 最后两个R轴大小的乘积,作为每个核处理的R轴块大小" << std::endl;
ss << "int64_t r_axis_block_size = "
<< tpipe.tiler.AxisSize(last_r_axis)
<< " * "
<< tpipe.tiler.AxisSize(second_last_r_axis)
<< ";" << std::endl;
} else if (r_axes.size() == 1) {
ss << "// 只有一个R轴,使用其大小作为块大小" << std::endl;
ss << "int64_t r_axis_block_size = " << tpipe.tiler.AxisSize(r_axes[0].first) << ";" << std::endl;
} else {
ss << "// 没有R轴,使用默认值" << std::endl;
ss << "int64_t r_axis_block_size = 1;" << std::endl;
}
}
Status GetDtypeNameForReduce(const std::string &api_name, const Tensor &x, const Tensor &y, std::string &dtype_name) {
if (api_name == "ArgMax" || api_name == "ArgMaxMultiRPhase1" || api_name == "ArgMaxMultiRPhase2") {
GE_CHK_STATUS_RET(Tensor::DtypeName(x.dtype, dtype_name), "Codegen get data type:%d failed", static_cast<int32_t>(x.dtype));
} else {
GE_CHK_STATUS_RET(Tensor::DtypeName(y.dtype, dtype_name), "Codegen get data type:%d failed", static_cast<int32_t>(y.dtype));
}
return ge::SUCCESS;
}
void GenAccumulatedOffsetDeclForArgMax(const std::string &api_name, const Tensor &x, const Tensor &y,
const TPipe &tpipe, std::stringstream &ss) {
if (api_name == "ArgMax") {
ss << "static int64_t accumulated_offset = 0;" << std::endl;
} else if (api_name == "ArgMaxMultiRPhase1") {
GenLastTwoRAxisSizeProductCode(x, y, tpipe, ss);
ss << "static int64_t accumulated_offset = 0;" << std::endl;
}
}
}