* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef ATC_OPCOMPILER_INC_TENSOR_ENGINE_TBE_OP_INFO_H_
#define ATC_OPCOMPILER_INC_TENSOR_ENGINE_TBE_OP_INFO_H_
#include <map>
#include <vector>
#include <memory>
#include "graph/node.h"
#include "tensor_engine/fusion_types.h"
#include "tensor_engine/fusion_constants.h"
#include "tensor_engine/tbe_attr_value.h"
#include "tensor_engine/tbe_op_param.h"
namespace te {
class TbeOpInfo : public std::enable_shared_from_this<TbeOpInfo> {
public:
TbeOpInfo(const std::string &name,
const std::string &moduleName,
const std::string &opType,
const std::string &engineType)
: op_name_(name),
op_module_name_(moduleName),
op_type_(opType),
engine_type_(engineType),
core_type_(engineType),
op_L1_space_(-1),
op_unknown_shape_(false),
op_use_int64_(false),
is_dynamic_impl_(false),
build_type_(ACCURATELY_BUILD),
dynamicRankType_(DynamicRankType::NOT_SUPPORT),
rangeLimitType_(RangeLimitType::UNLIMITED),
isSingleOpScene_(false),
isOriSingleOpScene_(false),
needPreCompile_(false),
opJitCompile_(JitCompileType::DEFAULT),
vectorCoreType_(VectorCoreType::DEFAULT),
ubFusionSpaceSize_(-1)
{
}
~TbeOpInfo()
{
}
void GetName(std::string& name) const
{
name = op_name_;
}
const std::string& GetName() const
{
return op_name_;
}
void GetRealName(std::string& realName) const
{
realName = op_real_name_;
}
const std::string& GetRealName() const
{
return op_real_name_;
}
void SetRealName(const std::string& realName)
{
op_real_name_ = realName;
}
void GetOpType(std::string &optype) const
{
optype = op_type_;
}
const std::string& GetOpType() const
{
return op_type_;
}
void GetOpFileName(std::string& fileName) const
{
fileName = op_file_name_;
}
const std::string& GetOpFileName() const
{
return op_file_name_;
}
void SetOpFileName(const std::string &fileName)
{
op_file_name_ = fileName;
}
void GetOpFuncName(std::string& funcName) const
{
funcName = op_func_name_;
}
const std::string& GetOpFuncName() const
{
return op_func_name_;
}
void SetOpFuncName(const std::string &funcName)
{
op_func_name_ = funcName;
}
const std::string& GetEngineType() const
{
return engine_type_;
}
void SetOpCoreType(const std::string &coreType)
{
core_type_ = coreType;
}
void GetOpCoreType(std::string &coreType) const
{
coreType = core_type_;
}
const std::string& GetOpCoreType() const
{
return core_type_;
}
void GetCheckSupportedFunc(std::string& checkName) const
{
checkName = op_check_supported_name_;
}
const std::string& GetCheckSupportedFunc() const
{
return op_check_supported_name_;
}
void SetCheckSupportedFunc(const std::string &checkName)
{
op_check_supported_name_ = checkName;
}
* @brief: normalize function name
* @param [in] funcName op function name
* @return [out] bool succ or fail for normalizing
*/
void NormalizeFuncName(std::string& funcName) const;
bool GetModuleName(std::string& moduleName) const;
void GetFuncName(std::string& funcName) const;
void GetOpImplSwitch(std::string& opImplSwitch) const
{
opImplSwitch = opImplSwitch_;
}
const std::string& GetOpImplSwitch() const
{
return opImplSwitch_;
}
void SetOpImplSwitch(const std::string& opImplSwitch)
{
opImplSwitch_ = opImplSwitch;
}
void SetOpsPathNamePrefix(const std::string &opsPathNamePrefix) {
opsPathNamePrefix_ = opsPathNamePrefix;
}
const std::string& GetOpsPathNamePrefix() const
{
return opsPathNamePrefix_;
}
JitCompileType GetOpJitCompile() const
{
return opJitCompile_;
}
void SetOpJitCompile(const JitCompileType &jitCompile)
{
opJitCompile_ = jitCompile;
}
void AddInput(const TbeOpParam& param)
{
inputs_.push_back(param);
}
void GetInputs(std::vector<TbeOpParam>& inputs) const
{
inputs = inputs_;
}
const std::vector<TbeOpParam>& GetInputs() const
{
return inputs_;
}
std::vector<TbeOpParam>& MutableInputs()
{
return inputs_;
}
void SetInputs(const std::vector<TbeOpParam> &inputs)
{
inputs_.assign(inputs.begin(), inputs.end());
}
void AddOutput(const TbeOpParam& param)
{
outputs_.push_back(param);
}
void GetOutputs(std::vector<TbeOpParam> &outputs) const
{
outputs = outputs_;
}
const std::vector<TbeOpParam>& GetOutputs() const
{
return outputs_;
}
std::vector<TbeOpParam>& MutableOutputs()
{
return outputs_;
}
void SetOutputs(const std::vector<TbeOpParam> &outputs)
{
outputs_.assign(outputs.begin(), outputs.end());
}
void AddAttrValue(const TbeAttrValue& value)
{
attr_values_.push_back(value);
}
void GetAttrValues(std::vector<TbeAttrValue>& attrValues) const
{
attrValues = attr_values_;
}
const std::vector<TbeAttrValue>& GetAttrValues() const
{
return attr_values_;
}
void SetPrivateAttrs(std::vector<TbeAttrValue>& privateAttrValues)
{
private_attr_values_ = privateAttrValues;
}
const std::vector<TbeAttrValue>& GetPrivateAttrValues() const
{
return private_attr_values_;
}
void GetPattern(std::string& pattern) const
{
pattern = pattern_;
}
const std::string& GetPattern() const
{
return pattern_;
}
void SetPattern(const std::string &pattern)
{
pattern_ = pattern;
}
const std::string& GetOpImplMode() const
{
return op_impl_mode_;
}
void SetOpImplMode(const std::string& opImplMode)
{
op_impl_mode_ = opImplMode;
}
void ClearOpImplMode()
{
op_impl_mode_.clear();
}
void GetExtraParams(std::string& extraParams) const
{
extraParams = extra_params_;
}
const std::string& GetExtraParams() const
{
return extra_params_;
}
void SetExtraParams(const std::string& extraParams)
{
extra_params_ = extraParams;
}
const std::string& GetHashedExtraParams() const
{
return hashed_extra_params_;
}
void SetHashedExtraParams(const std::string& hasedExtraParams)
{
hashed_extra_params_ = hasedExtraParams;
}
bool IsDynamicImpl() const
{
return is_dynamic_impl_;
}
void SetDynamicImpl(const bool& is_dynamic_impl)
{
is_dynamic_impl_ = is_dynamic_impl;
}
void GetKernelName(std::string& kernelName) const
{
kernelName = kernel_name_;
}
const std::string& GetKernelName() const
{
return kernel_name_;
}
void SetKernelName(const std::string &kernelName)
{
kernel_name_ = kernelName;
}
void GetL1Space(int64_t& l1Space) const
{
l1Space = op_L1_space_;
}
int64_t GetL1Space() const
{
return op_L1_space_;
}
void SetL1Space(const int64_t l1Space)
{
op_L1_space_ = l1Space;
}
void SetIsUnknownShape(const bool unknown)
{
op_unknown_shape_ = unknown;
}
bool GetIsUnknownShape() const
{
return op_unknown_shape_;
}
void SetFlagUseInt64(const bool opUseInt64)
{
op_use_int64_ = opUseInt64;
}
bool GetFlagUseInt64() const
{
return op_use_int64_;
}
void GetOptions(std::map<std::string, std::string>& options) const
{
options = options_;
}
const std::map<std::string, std::string>& GetOptions() const
{
return options_;
}
void SetOptions(const std::map<std::string, std::string> &options)
{
options_ = options;
}
bool GetOption(const std::string& key, std::string &value) const
{
auto optionsFind = options_.find(key);
if (optionsFind != options_.end()) {
value = optionsFind->second;
return true;
}
return false;
}
void SetOption(const std::string& key, const std::string &value)
{
options_[key] = value;
}
void SetBuildType(BUILD_TYPE build_type)
{
build_type_ = build_type;
}
BUILD_TYPE GetBuildType() const
{
return build_type_;
}
void SetRangeLimitType(const RangeLimitType &rangeLimitType)
{
rangeLimitType_ = rangeLimitType;
}
RangeLimitType GetRangeLimitType() const
{
return rangeLimitType_;
}
void SetNode(const ge::NodePtr &nodePtr)
{
nodePtr_ = nodePtr;
}
void GetNode(ge::NodePtr &nodePtr) const
{
nodePtr = nodePtr_;
}
const ge::NodePtr& GetNode() const
{
return nodePtr_;
}
bool IsSupportDynamicRank() const
{
return dynamicRankType_ != DynamicRankType::NOT_SUPPORT;
}
DynamicRankType GetDynamicRankType() const
{
return dynamicRankType_;
}
void SetDynamicRankType(const DynamicRankType &dynamicRankType)
{
dynamicRankType_ = dynamicRankType;
}
void SetOpStorePattern(const std::string &opStorePattern)
{
opStorePattern_ = opStorePattern;
}
void GetOpStorePattern(std::string &opStorePattern) const
{
opStorePattern = opStorePattern_;
}
const std::string& GetOpStorePattern() const
{
return opStorePattern_;
}
void SetPrebuildPattern(const std::string &prebuildPattern)
{
prebuildPattern_ = prebuildPattern;
}
void GetPrebuildPattern(std::string &prebuildPattern) const
{
prebuildPattern = prebuildPattern_;
}
const std::string& GetPrebuildPattern() const
{
return prebuildPattern_;
}
const std::string& GetOpDebugConfig() const
{
return opDebugConfig_;
}
void SetOpDebugConfig(const std::string &opDebugConfig)
{
opDebugConfig_ = opDebugConfig;
}
void SetSingleOpSceneFlag(const bool &isSingleOpScene)
{
isSingleOpScene_ = isSingleOpScene;
}
void GetSingleOpSceneFlag(bool &isSingleOpScene) const
{
isSingleOpScene = isSingleOpScene_;
}
bool IsSingleOpScene() const
{
return isSingleOpScene_;
}
void SetOriSingleOpSceneFlag(const bool &isOriSingleOpScene)
{
isOriSingleOpScene_ = isOriSingleOpScene;
}
void GetOriSingleOpSceneFlag(bool &isOriSingleOpScene) const
{
isOriSingleOpScene = isOriSingleOpScene_;
}
bool IsOriSingleOpScene() const
{
return isOriSingleOpScene_;
}
bool IsNeedPreCompile() const
{
return needPreCompile_;
}
void SetNeedPreCompile(const bool needPreCompile)
{
needPreCompile_ = needPreCompile;
}
VectorCoreType GetVectorCoreType() const
{
return vectorCoreType_;
}
void SetVectorCoreType(const VectorCoreType vectorCoreType)
{
vectorCoreType_ = vectorCoreType;
}
void SetCustAicNum(const std::string &custAicNum)
{
cust_aic_num_ = custAicNum;
}
const std::string& GetCustAicNum() const
{
return cust_aic_num_;
}
void SetCustAivNum(const std::string &custAivNum)
{
cust_aiv_num_ = custAivNum;
}
const std::string& GetCustAivNum() const
{
return cust_aiv_num_;
}
int64_t GetUBSpaceSize() const
{
return ubFusionSpaceSize_;
}
void SetUBSpaceSize(const int64_t size)
{
ubFusionSpaceSize_ = size;
}
bool operator==(TbeOpInfo& rObject);
private:
std::string op_name_;
std::string op_real_name_;
std::string op_module_name_;
std::string op_type_;
std::string engine_type_;
std::string op_file_name_;
std::string op_func_name_;
std::string op_check_supported_name_;
std::string core_type_;
std::string opImplSwitch_;
std::string opsPathNamePrefix_;
std::string kernel_name_;
std::string pattern_;
std::string op_impl_mode_;
std::string extra_params_;
std::string hashed_extra_params_;
int64_t op_L1_space_;
bool op_unknown_shape_;
bool op_use_int64_;
bool is_dynamic_impl_;
BUILD_TYPE build_type_;
std::vector<TbeOpParam> inputs_;
std::vector<TbeOpParam> outputs_;
std::vector<TbeAttrValue> attr_values_;
std::vector<TbeAttrValue> private_attr_values_;
std::map<std::string, std::string> options_;
DynamicRankType dynamicRankType_;
RangeLimitType rangeLimitType_;
std::string opStorePattern_;
std::string prebuildPattern_;
std::string opDebugConfig_;
bool isSingleOpScene_;
bool isOriSingleOpScene_;
bool needPreCompile_;
JitCompileType opJitCompile_;
VectorCoreType vectorCoreType_;
ge::NodePtr nodePtr_{nullptr};
int64_t ubFusionSpaceSize_;
std::string cust_aic_num_;
std::string cust_aiv_num_;
};
using TbeOpInfoPtr = std::shared_ptr<TbeOpInfo>;
using ConstTbeOpInfoPtr = std::shared_ptr<const TbeOpInfo>;
}
#endif