* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef __CODEGEN_API_PARAM_H__
#define __CODEGEN_API_PARAM_H__
#include <cstddef>
#include <cstdint>
#include <variant>
#include <string>
#include <vector>
#include <map>
#include <memory>
#include <mutex>
#include <shared_mutex>
#include "ge_common/ge_api_types.h"
#include "ascir.h"
namespace codegen {
struct MergeAxesInfo {
std::vector<std::string> repeats;
std::vector<std::string> gm_strides;
std::vector<std::string> ub_strides;
};
struct DataCopyBaseParams {
std::string block_count;
std::string block_len;
std::string src_stride;
std::string dst_stride;
};
struct DataCopyLoopModeParams {
std::vector<std::string> loop_sizes;
std::vector<std::string> loop_src_strides;
std::vector<std::string> loop_dst_strides;
};
struct DmaSpecificParams {
MergeAxesInfo merge_axes_info;
DataCopyBaseParams data_copy_params;
DataCopyLoopModeParams loop_mode_params;
};
enum class ReducePattern {
kUnknown,
kAR,
kRA,
};
enum class ReduceMergeMode {
kUnknown,
kNone,
kCopy,
kMergeByElementwise,
};
struct ReduceMergedDims {
bool valid{false};
ge::Expression first{ge::Symbol(1U)};
ge::Expression last{ge::Symbol(1U)};
};
struct ReduceReuseInfo {
bool valid{false};
bool is_reuse_source{false};
};
struct ReduceSpecificParams {
bool valid{false};
std::string reduce_type;
ReducePattern pattern{ReducePattern::kUnknown};
ReduceMergedDims merged_dims;
ReduceMergeMode merge_mode{ReduceMergeMode::kUnknown};
ge::Expression merge_size{ge::Symbol(0U)};
ge::Expression merge_times{ge::Symbol(1U)};
ReduceReuseInfo reuse;
};
struct ReduceSpecificParamBuildInput {
std::string node_name;
std::string reduce_type;
std::vector<ge::Expression> input_repeats;
std::vector<ge::Expression> input_strides;
std::vector<ge::Expression> output_dims;
std::vector<ge::Expression> output_strides;
uint32_t dtype_size{0U};
ReducePattern pattern{ReducePattern::kUnknown};
bool need_multi_reduce{false};
ge::Expression merge_times{ge::Symbol(1U)};
ReduceReuseInfo reuse;
};
struct BroadcastSpecificParams {
std::string broadcast_type;
};
struct TransposeSpecificParams {
std::vector<std::string> output_dims;
std::vector<std::string> input_strides;
std::vector<std::string> output_strides;
};
struct ReduceMergedAxisPlan {
bool valid{false};
bool is_all_axis_reduce{false};
bool align_last_axis{false};
bool use_last_non_zero_stride{false};
bool use_zero_stride{false};
size_t aligned_axis_index{0U};
size_t last_non_zero_stride_index{0U};
std::vector<size_t> first_axis_indices;
std::vector<size_t> last_axis_indices;
};
struct ReduceMergedShape {
bool valid{false};
ge::Expression first{ge::Symbol(1U)};
ge::Expression last{ge::Symbol(1U)};
};
enum class ReduceMergedAxisAction {
kSkip,
kFirstAxis,
kLastAxis,
kAlignFirst,
kAlignLast,
kLastStride,
};
struct ReduceMergedAxisState {
bool is_first{true};
bool is_all_axis_reduce{true};
bool use_zero_stride{false};
size_t last_not_one_axis_index{static_cast<size_t>(-1)};
size_t last_non_zero_stride_index{static_cast<size_t>(-1)};
};
bool IsReduceMergedZeroTail(bool src_stride_zero, bool dst_stride_zero);
bool IsReduceMergedSameSide(bool current_dst_stride_zero, bool last_dst_stride_zero);
ReduceMergedAxisAction UpdateReduceMergedAxisState(bool src_stride_zero, bool dst_stride_zero,
bool last_not_one_dst_stride_zero, bool is_last_axis,
size_t axis_index, ReduceMergedAxisState &state);
ReduceMergedAxisPlan BuildReduceMergedAxisPlan(const std::vector<bool> &src_stride_zero,
const std::vector<bool> &dst_stride_zero);
ReduceMergedShape BuildReduceMergedShape(const std::vector<ge::Expression> &src_repeats,
const std::vector<ge::Expression> &src_strides,
const std::vector<ge::Expression> &dst_strides,
uint32_t dtype_size);
ge::Status BuildReduceSpecificParams(const ReduceSpecificParamBuildInput &input, ReduceSpecificParams ¶m);
struct CodegenApiParam;
using CodegenApiParamPtr = std::shared_ptr<CodegenApiParam>;
struct CodegenApiParam {
static ge::Status Register(af::AscNodePtr node, CodegenApiParamPtr api_param);
static CodegenApiParamPtr GetNodeApiParam(af::AscNodePtr node);
std::string api_name;
std::vector<std::string> template_params;
struct TensorParam {
TensorParam() = default;
TensorParam(const std::string& name, bool is_tensor, const std::string& offset)
: name(name), is_tensor(is_tensor), offset(offset) {}
std::string name;
bool is_tensor = true;
std::string offset;
};
std::vector<std::string> outer_loop_axes;
std::vector<std::string> api_pre_process;
std::vector<std::string> api_post_process;
std::vector<TensorParam> input_params;
std::vector<TensorParam> output_params;
std::string tmp_buf_name;
std::string cal_count;
using AnySpecificParams = std::variant<
std::monostate,
DmaSpecificParams,
ReduceSpecificParams,
BroadcastSpecificParams,
TransposeSpecificParams
>;
AnySpecificParams specific_params;
};
}
#endif