/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
/*!
 * \file operation.h
 * \brief Declares the attribute-key name classes, OperandAttribute and the Operation class.
 */
#pragma once
#include <vector>
#include <string>
#include <sstream>
#include <map>
#include <unordered_set>
#include <variant>
#include <nlohmann/json.hpp>
#include "interface/inner/any.h"
#include "interface/inner/pre_def.h"
#include "tilefwk/tilefwk_op.h"
#include "interface/configs/config_manager.h"
#include "tilefwk/data_type.h"
#include "tilefwk/tile_shape.h"
#include "interface/utils/common.h"
#include "opcode.h"
#include "attribute.h"
#include "attr_holder.h"
#include "interface/tensor/logical_tensor.h"
#include "operation_common.h"
#include "ir/stmt.h"
using Json = nlohmann::json;
using namespace pypto;
namespace npu::tile_fwk {
// Default value of Operation::groupID_: the largest value representable in size_t.
constexpr size_t NON_GROUP = static_cast<size_t>(-1);
// Number of leading characters compared against "TILE_" by Operation's string-opcode constructor.
constexpr int32_t TILE_STR_PREFIX_LEN = 5;
// AICPU call constants. Their users are not visible in this header.
// NOTE(review): the three *_BIT values look like field widths of a packed 64-bit word
// (16 + 16 + 32) — confirm against the code that encodes/decodes AICPU calls.
#define AICPU_CALL_NUM_COPYOUT_RESOLVE 1
#define AICPU_CALL_NUM_BIT 16
#define AICPU_CALL_ARG_BIT 16
#define AICPU_CALL_TASK_BIT 32
// Name table of operation attribute keys.
// Only declarations live here; the string values are defined out of line and are not
// visible in this header. Presumably these are the `key` arguments passed to
// Operation::SetAttribute / Operation::Get*Attribute — confirm against the callers.
class OpAttributeKey {
public:
static const std::string aicpuCall;
static const std::string scalar;
static const std::string vectorScalar;
static const std::string dynScalar;
static const std::string isGlobalInput;
static const std::string seqNo;
static const std::string color;
static const std::string isCube;
static const std::string blockPadding;
static const std::string broadcastLastAxis;
static const std::string tilePadding;
static const std::string reshapePadding;
static const std::string shapePadded;
static const std::string needAlloc;
static const std::string dontTouch;
static const std::string tag;
static const std::string distTilingInfo;
static const std::string sameInOut;
static const std::string expandDims;
static const std::string inputCombineAxis;
static const std::string outputCombineAxis;
static const std::string inputCombineAxisDone;
static const std::string outputCombineAxisDone;
static const std::string inplaceIdx;
static const std::string inplaceInfo;
static const std::string cacheMode;
static const std::string panzBlockSize;
static const std::string requiresBoundaryCopy;
static const std::string excludeBufferReuse;
static const std::string bindTensor;
static const std::string startOffset;
static const std::string distOpAttr;
static const std::string isDistCopyOut;
static const std::string subBlockIdx;
static const std::string accumulate;
static const std::string indicesSize;
static const std::string brcbIdx;
static const std::string brcOperand;
static const std::string topkAlgo;
static const std::string quantFlag;
// Loop-group keys, followed by their "dyn"-prefixed counterparts.
static const std::string loopGroup;
static const std::string loopAxes;
static const std::string loopGroupStart;
static const std::string loopGroupEnd;
static const std::string dynloopGroup;
static const std::string dynloopAxes;
static const std::string dynloopGroupStart;
static const std::string dynloopGroupEnd;
static const std::string lastUse;
static const std::string isUpper;
static const std::string blockSize;
static const std::string transMode;
static const std::string workspaceBaseOffset;
static const std::string copyInMode;
static const std::string copyOutMode;
static const std::string localCopyLocalMode;
static const std::string copyIsNZ;
static const std::string scaleValue;
static const std::string rowPad;
static const std::string ownerRank;
static const std::string maxTileNum;
static const std::string perRankDataShape;
static const std::string perRankTileNum;
static const std::string totalTileNum;
static const std::string precisionType;
static const std::string perm;
static const std::string mxQuantMode;
static const std::string mxQuantAxis;
static const std::string mxQuantPerformanceMode;
static const std::string gmTensorParamIdxInCall;
static const std::string staticValidShape;
// "*HashOrder" / "*SubgraphCount" pairs. NOTE(review): these match the (hashOrderKey, countKey)
// parameter pair of Operation::SetHashOrderInfo / GetHashOrderInfo — presumably passed there; confirm.
static const std::string l1ReuseHashOrder;
static const std::string l1ReuseSubgraphCount;
static const std::string cubeMergeHashOrder;
static const std::string cubeMergeSubgraphCount;
static const std::string vecMergeHashOrder;
static const std::string vecMergeSubgraphCount;
static const std::string atomicAdd;
static const std::string splitMN;
static const std::string reduceCopyPreSubgraphId;
static const std::string postK;
static const std::string postM;
static const std::string postN;
// Filter / stride / dilation / padding keys.
// NOTE(review): paddingLeft/Right/Top/Bottom are also declared in ConvOpAttributeKey;
// whether both resolve to the same string value cannot be told from this header.
static const std::string filterH;
static const std::string filterW;
static const std::string strideH;
static const std::string strideW;
static const std::string dilationH;
static const std::string dilationW;
static const std::string paddingLeft;
static const std::string paddingRight;
static const std::string paddingTop;
static const std::string paddingBottom;
static const std::string padValue;
static const std::string repeatStride;
static const std::string repeatTime;
static const std::string wStride;
static const std::string srcGmConvValidShape;
static const std::string l0cValidMN;
static const std::string rmwMode;
static const std::string transDataOffset;
};
// Attribute-key strings shared by every translation unit that includes this header.
// `inline` (C++17) gives each constant one program-wide definition; a plain namespace-scope
// `const std::string` in a header has internal linkage, so every including TU would get its
// own copy of the string and its own dynamic initializer.
inline const std::string CONV_GROUPS_ATTR = "op_attr_groups";
inline const std::string FAKE_TRANS_IN_FORMAT_ATTR = "op_attr_fake_trans_in_format";
inline const std::string FAKE_TRANS_OUT_FORMAT_ATTR = "op_attr_fake_trans_out_format";
// Name table of convolution-related attribute keys.
// Declarations only; the string values are defined out of line (not visible here).
// NOTE(review): member spelling is inconsistent (`hsteP` vs `wstep`, `strideh`/`stridew` vs
// OpAttributeKey::strideH/strideW). The names are part of the public interface, so they are
// left untouched here; renaming would need a coordinated change in every user.
class ConvOpAttributeKey {
public:
static const std::string cin;
static const std::string cout;
static const std::string paddingLeft;
static const std::string paddingTop;
static const std::string paddingRight;
static const std::string paddingBottom;
static const std::string strideh;
static const std::string stridew;
static const std::string hposX;
static const std::string hsteP;
static const std::string wposX;
static const std::string wstep;
static const std::string hoffsetY;
static const std::string woffsetY;
static const std::string reluType;
static const std::string reluAlpha;
static const std::string clearFlag;
static const std::string hasAccFlag;
static const std::string hasEltFlag;
static const std::string hasBiasFlag;
static const std::string eltBrcbFlag;
static const std::string fmapSrcNum;
static const std::string eltMode;
static const std::string fmapC0;
};
// Name table of "Fixp" attribute keys (presumably fix-pipe — confirm).
// Declarations only; the string values are defined out of line (not visible here).
class FixpOpAttributeKey {
public:
static const std::string hStart;
static const std::string hEnd;
// Scalar quantisation parameters and the matching "has ... vector" flags.
static const std::string quantPreScalar;
static const std::string quantPostScalar;
static const std::string antiqScalar;
static const std::string hasQuantPreVector;
static const std::string hasQuantPostVector;
static const std::string hasAntiqVector;
// NOTE(review): presumably holds an FbBufferSpace value — confirm against the users.
static const std::string fbAddrSpace;
};
// Name table of pooling attribute keys.
// Declarations only; the string values are defined out of line (not visible here).
class PoolOpAttributeKey {
public:
static const std::string poolh;
static const std::string poolw;
};
// Name table of tensor attribute keys.
// Declarations only; the string value is defined out of line (not visible here).
class TensorAttributeKey {
public:
static const std::string tensorAddr;
};
// Buffer-space selector. Enumerators are numbered consecutively from 0; the values are
// spelled out so that reordering the list cannot silently renumber them.
enum class FbBufferSpace {
    QUANT_PRE = 0,
    RELU_PRE = 1,
    RELU_POST = 2,
    QUANT_POST = 3,
    ANTIQ_ELT = 4,
    ANTIQ_MTE2 = 5
};
// AIV core selector stored in Operation::MixSubgraphFields.
// UNSPECIFIED (-1) is the default; GetAIVCoreDict() names it "AIC".
// The underlying type is written out; `int` is already the default for a scoped enum.
enum class AIVCore : int {
    UNSPECIFIED = -1,
    AIV0 = 0,
    AIV1 = 1,
};
// Returns the shared AIVCore <-> name table.
// The table is a function-local static, built on first call and never modified afterwards.
// Note that AIVCore::UNSPECIFIED is named "AIC".
inline const BiMap<AIVCore>& GetAIVCoreDict()
{
    static const BiMap<AIVCore> coreNames{{
        {AIVCore::UNSPECIFIED, "AIC"},
        {AIVCore::AIV0, "AIV0"},
        {AIVCore::AIV1, "AIV1"},
    }};
    return coreNames;
}
// Per-operand record: an integer offset plus a flag word.
// A default-constructed record has offset -1 and no flags set.
struct OperandAttribute {
    int offset{-1};
    int attr{0};

    OperandAttribute() = default;
    // Intentionally not explicit: callers build records with `{offset, attr}`.
    OperandAttribute(int off_in, int attr_in = 0) : offset{off_in}, attr{attr_in} {}

    // True when the lowest bit of `attr` is set.
    bool isAtomic() const { return (attr & 0x1) != 0; }
};
class Function;
// One operation of a Function.
// It carries an Opcode, three operand lists (input, output, depend), control-edge sets,
// an optional typed OpAttribute, and keyed attributes through the AttrHolder base.
// Copy and move are deleted.
class Operation : public ir::TensorOpStmt, public std::enable_shared_from_this<Operation>, public AttrHolder {
public:
// Fields kept behind a lazily allocated pointer (see mixSubgraphFields_): they are only
// created the first time UpdateInternalSubgraphID() or SetAIVCore() is called.
struct MixSubgraphFields {
int internalSubgraphID{NOT_IN_SUBGRAPH};
AIVCore aivCore{AIVCore::UNSPECIFIED};
};
// Scope settings attached to an Operation: a scope id, two merge-permission flags and a
// cv-fuse id. Defaults: ids are -1 and both flags are false.
struct ScopeInfo {
    int scopeId{-1};
    bool allowParallelMerge{false};
    bool allowCrossScopeMerge{false};
    int cvFuseId{-1};

    ScopeInfo() = default;
    explicit ScopeInfo(int id) : scopeId(id) {}

    // Builds a ScopeInfo from a flat config laid out as
    //   [0] scopeId, [1] allowParallelMerge (non-zero = true), [2] allowCrossScopeMerge (non-zero = true).
    // cvFuseId keeps its default; entries past index 2 are ignored.
    // Throws std::out_of_range when `config` has fewer than three entries (the previous
    // unchecked operator[] access was undefined behaviour in that case).
    static ScopeInfo FromConfig(const std::vector<int64_t>& config)
    {
        ScopeInfo info;
        info.scopeId = static_cast<int>(config.at(0));
        info.allowParallelMerge = static_cast<bool>(config.at(1));
        info.allowCrossScopeMerge = static_cast<bool>(config.at(2));
        return info;
    }

    void SetCvFuseId(int id) { cvFuseId = id; }
};
friend class Function;
// Operand lists. Public data; also reachable through GetIOperands / GetOOperands /
// GetDependOperands.
LogicalTensors iOperand;
LogicalTensors oOperand;
LogicalTensors dependOperand;
// NOTE(review): opmagic, programFuncMagic_ and queueType have no in-class initializer;
// presumably the out-of-line constructor sets them — confirm.
int opmagic;
int programFuncMagic_;
int outcastRefcount{0};
int cycles{0};
int cycleStart{0};
int cycleEnd{0};
int cubeDepId{-1};
std::vector<int> inParamLocation_;
std::vector<int> outParamLocation_;
OpSyncQueue syncQueue_;
QueueType queueType;
// Primary constructor; defined out of line.
Operation(Function& cur, Opcode opcode, LogicalTensors iOperands, LogicalTensors oOperands, int opMagic = -1);

// Operand-less form: delegates with empty input and output lists.
Operation(Function& cur, Opcode opcode) : Operation(cur, opcode, {}, {}) {}

// Builds the operation from an opcode name (resolved with FindOpcode) and marks it as a
// tile op when the name starts with "TILE_".
Operation(Function& cur, const std::string& op, const LogicalTensors& input, const LogicalTensors& output)
    : Operation(cur, FindOpcode(op), input, output)
{
    // compare() looks at the first min(TILE_STR_PREFIX_LEN, op.size()) characters in place,
    // which matches the substr()-then-== test without building a temporary string.
    const bool hasTilePrefix = op.compare(0, TILE_STR_PREFIX_LEN, "TILE_") == 0;
    if (hasTilePrefix) {
        isTileOp_ = true;
    }
}

// Operations are neither copyable nor movable.
Operation(const Operation& other) = delete;
Operation(Operation&& other) = delete;
Operation& operator=(const Operation& other) = delete;
Operation& operator=(Operation&& other) = delete;
// Returns the stored Function pointer (function_).
Function* BelongTo() const
{
    return function_;
}

// Read-only view of the public queueType member.
const QueueType& GetQueueType() const
{
    return queueType;
}

// Read-only view of the public syncQueue_ member.
const OpSyncQueue& GetSyncQueue() const
{
    return syncQueue_;
}

// Read-only view of the tile shape.
const TileShape& GetTileShape() const
{
    return tileShape_;
}
// Replaces the tile shape with a copy of `newTileShape`.
// The argument is taken by const reference: the previous `const TileShape` by-value parameter
// copied the shape once at the call site and again in the assignment, and being const it
// could not be moved from either. All existing call expressions still compile unchanged.
void UpdateTileShape(const TileShape& newTileShape) { tileShape_ = newTileShape; }
// Mutable access to the tile shape, for callers that edit it in place.
TileShape& GetTileShapeForSetting() { return tileShape_; }
// Keyed attribute accessors, one getter per value type plus a SetAttribute overload set.
// All of them are defined out of line except the `int` overload below.
[[nodiscard]] std::string GetStringAttribute(const std::string& key) const;
void SetAttribute(const std::string& key, const std::string& value);
[[nodiscard]] bool GetBoolAttribute(const std::string& key) const;
void SetAttribute(const std::string& key, bool value);
[[nodiscard]] int64_t GetIntAttribute(const std::string& key) const;
void SetAttribute(const std::string& key, int64_t value);
// `int` convenience overload: widens the value and forwards to the int64_t overload.
void SetAttribute(const std::string& key, int value) { SetAttribute(key, static_cast<int64_t>(value)); }
[[nodiscard]] Element GetElementAttribute(const std::string& key) const;
std::vector<Element> GetVectorElementAttribute(const std::string& key) const;
void SetAttribute(const std::string& key, Element value);
// Reads the integer-vector attribute stored under `key` and returns it as std::vector<T>.
// The attribute is fetched as std::vector<int64_t> through GetAttr; for any other integral T
// each element is converted with static_cast<T>, so values outside T's range are narrowed.
// GetAttr's result is not inspected: whatever it leaves in the vector is what gets returned.
template <typename T = int64_t>
std::vector<T> GetVectorIntAttribute(const std::string& key) const
{
    static_assert(std::is_integral_v<T>);
    std::vector<int64_t> stored;
    GetAttr(key, stored);
    if constexpr (std::is_same_v<T, int64_t>) {
        return stored;
    } else {
        // The element count is known up front, so allocate once instead of growing per element.
        std::vector<T> converted;
        converted.reserve(stored.size());
        for (const int64_t element : stored) {
            converted.emplace_back(static_cast<T>(element));
        }
        return converted;
    }
}
// Stores an integer vector under `key`.
// Attributes are held as std::vector<int64_t>: an int64_t vector is passed to SetAttr as-is,
// any other integral element type is widened element by element first.
template <typename T = int64_t>
void SetAttribute(const std::string& key, const std::vector<T>& value)
{
    static_assert(std::is_integral_v<T>);
    if constexpr (std::is_same_v<T, int64_t>) {
        SetAttr(key, value);
    } else {
        // The element count is known up front, so allocate once instead of growing per element.
        std::vector<int64_t> nvalue;
        nvalue.reserve(value.size());
        // Iterate by value: `auto&` cannot bind to the proxy/prvalue elements of a const
        // std::vector<bool>, which made T = bool fail to compile.
        for (const T element : value) {
            nvalue.emplace_back(static_cast<int64_t>(element));
        }
        SetAttr(key, nvalue);
    }
}
// Remaining keyed attribute accessors (CastMode, SymbolicScalar, vectors of SymbolicScalar
// and Element). Defined out of line.
[[nodiscard]] CastMode GetCastModeAttribute(const std::string& key) const;
void SetAttribute(const std::string& key, CastMode value);
[[nodiscard]] SymbolicScalar GetSymbolicScalarAttribute(const std::string& key) const;
void SetAttribute(const std::string& key, const SymbolicScalar& value);
[[nodiscard]] std::vector<SymbolicScalar> GetVectorSymbolicScalarAttribute(const std::string& key) const;
void SetAttribute(const std::string& key, const std::vector<SymbolicScalar>& value);
void SetAttribute(const std::string& key, const std::vector<Element>& value);
// Forwards to AttrHolder-style HasAttr(key).
[[nodiscard]] bool HasAttribute(const std::string& key) const { return HasAttr(key); }
[[nodiscard]] std::map<std::string, npu::tile_fwk::Any> GetAllAttribute() const;
// JSON serialisation entry points. `dumpTensor` defaults to true; its exact effect is
// implemented out of line. LoadJson receives the target Function, a tensor lookup table
// keyed by int, and the JSON object of one operation.
Json DumpJson(bool dumpTensor = true) const;
static std::shared_ptr<Operation> LoadJson(
Function& cur, const std::unordered_map<int, std::shared_ptr<LogicalTensor>>& tensorDict, const Json& opDump);
private:
// Helpers that each write one part of the JSON dump into `opDump` (defined out of line).
void DumpOperandsJson(Json& opDump, bool dumpTensor) const;
void DumpCalleeHashJson(Json& opDump) const;
void DumpLocationJson(Json& opDump) const;
void DumpParamLocationJson(Json& opDump) const;
void DumpCallOpInfoJson(Json& opDump) const;
void DumpTileInfoJson(Json& opDump) const;
void DumpAttributesJson(Json& opDump) const;
// Helpers that each read one part of an operation back from `opDump` (defined out of line).
// LoadOperandsFromJson is static and returns its results through the two output vectors.
static void LoadOperandsFromJson(
const Json& opDump, const std::unordered_map<int, std::shared_ptr<LogicalTensor>>& tensorDict,
std::vector<std::shared_ptr<LogicalTensor>>& ioperands, std::vector<std::shared_ptr<LogicalTensor>>& ooperands);
void LoadLocationFromJson(const Json& opDump);
void LoadBasicInfoFromJson(const Json& opDump);
void LoadTileInfoFromJson(const Json& opDump);
void LoadOpAttributeFromJson(const Json& opDump, Opcode opcode);
void LoadExtraInfoFromJson(const Json& opDump);
public:
// Text dumps (defined out of line). DumpSSA takes an optional prefix, empty by default.
[[nodiscard]] std::string DumpSSA(const std::string& prefix = "") const;
[[nodiscard]] std::string Dump() const;
// Returns the public opmagic member.
[[nodiscard]] int GetOpMagic() const
{
    return opmagic;
}

// Input operand list, read-only and mutable views.
[[nodiscard]] const LogicalTensors& GetIOperands() const
{
    return iOperand;
}
LogicalTensors& GetIOperands()
{
    return iOperand;
}

// Output operand list, read-only and mutable views.
[[nodiscard]] const LogicalTensors& GetOOperands() const
{
    return oOperand;
}
LogicalTensors& GetOOperands()
{
    return oOperand;
}

// Depend operand list, read-only and mutable views.
[[nodiscard]] const LogicalTensors& GetDependOperands() const
{
    return dependOperand;
}
LogicalTensors& GetDependOperands()
{
    return dependOperand;
}

// Defined out of line.
void AddDependOperand(LogicalTensorPtr dependoperand);

// Element counts of the three operand lists.
size_t GetInputOperandSize() const
{
    return iOperand.size();
}
size_t GetOutputOperandSize() const
{
    return oOperand.size();
}
size_t GetDependOperandSize() const
{
    return dependOperand.size();
}
// Operand lookup and replacement; all defined out of line.
// Indexed access to a single input / output operand.
LogicalTensorPtr GetInputOperand(const size_t index) const;
LogicalTensorPtr GetOutputOperand(const size_t index) const;
// Position of a tensor in the input / output list. The value returned when the tensor is
// not present is decided by the out-of-line definition.
int GetIOperandIndex(const LogicalTensorPtr& ioperand) const;
int GetOOperandIndex(const LogicalTensorPtr& ooperand) const;
// Replace by identity (origin -> new) or by position (index -> new).
void ReplaceInputOperand(const LogicalTensorPtr& originInput, const LogicalTensorPtr& newInput);
void ReplaceOutputOperand(const LogicalTensorPtr& originOutput, const LogicalTensorPtr& newOutput);
void UpdateInputOperand(const size_t index, const std::shared_ptr<LogicalTensor>& newInput);
void UpdateOutputOperand(const size_t index, const std::shared_ptr<LogicalTensor>& newOutput);
// Sets of operations consuming / producing this operation's tensors (defined out of line).
std::unordered_set<Operation*> ConsumerOps() const;
std::unordered_set<Operation*> ProducerOps() const;
// Strict weak ordering of operations by ascending op magic.
// Both pointers must be non-null; they are dereferenced unconditionally.
class OperationComparator {
public:
    bool operator()(const Operation* lhs, const Operation* rhs) const
    {
        const int lhsMagic = lhs->GetOpMagic();
        const int rhsMagic = rhs->GetOpMagic();
        return lhsMagic < rhsMagic;
    }
};
// Consumer / producer sets ordered by op magic through OperationComparator (defined out of line).
// NOTE(review): std::set is used here but this header does not include <set> directly;
// it currently relies on a transitive include.
std::set<Operation*, OperationComparator> ConsumerOpsOrdered() const;
std::set<Operation*, OperationComparator> ProducerOpsOrdered() const;
// "ByToken" variants; how they differ from ConsumerOps/ProducerOps is decided out of line.
std::unordered_set<Operation*> ConsumerOpsByToken() const;
std::unordered_set<Operation*> ProducerOpsByToken() const;
// Read-only view of the incoming control-edge set.
[[nodiscard]] const std::unordered_set<Operation*>& GetInCtrlOperations() const
{
    return inputCtrlOps;
}

// Read-only view of the outgoing control-edge set.
[[nodiscard]] const std::unordered_set<Operation*>& GetOutCtrlOperations() const
{
    return outputCtrlOps;
}

// Empties this operation's incoming set only; nothing else is touched here.
void ClearInCtrlOperations()
{
    inputCtrlOps.clear();
}

// Empties this operation's outgoing set only; nothing else is touched here.
void ClearOutCtrlOperations()
{
    outputCtrlOps.clear();
}
// Scope settings of this operation (public data member plus accessors).
ScopeInfo scopeInfo_;

// Overwrites only the scope id; the other ScopeInfo fields keep their values.
void SetScopeId(int scopeId)
{
    scopeInfo_.scopeId = scopeId;
}

// Replaces the whole ScopeInfo record.
void SetScopeInfo(const ScopeInfo& info)
{
    scopeInfo_ = info;
}

const ScopeInfo& GetScopeInfo() const
{
    return scopeInfo_;
}

// Field-wise getters over scopeInfo_.
int GetScopeId() const
{
    return scopeInfo_.scopeId;
}
bool GetAllowParallelMerge() const
{
    return scopeInfo_.allowParallelMerge;
}
bool GetAllowCrossScopeMerge() const
{
    return scopeInfo_.allowCrossScopeMerge;
}
int GetCvFuseId() const
{
    return scopeInfo_.cvFuseId;
}
// Control-edge editing; defined out of line.
// NOTE(review): whether these also update the peer operation's opposite set cannot be told
// from this header — confirm in the .cpp before relying on it.
void AddInCtrlOperation(Operation& operation);
void RemoveInCtrlOperation(Operation& operation);
void AddOutCtrlOperation(Operation& operation);
void RemoveOutCtrlOperation(Operation& operation);
// Creates a clone of this operation in `func` with the given operand lists (defined out of
// line); the returned reference refers to the clone.
Operation& CloneOperation(
Function& func, const LogicalTensors& iOperandList, const LogicalTensors& oOperandList) const;
// Opcode as text; `appendTile` defaults to false and its effect is implemented out of line.
[[nodiscard]] std::string GetOpcodeStr(bool appendTile = false) const;
// Core type of this operation; the member defaults to CoreType::MIX.
[[nodiscard]] CoreType GetCoreType() const
{
    return coreType_;
}

void SetCoreType(CoreType ct)
{
    coreType_ = ct;
}
// All defined out of line.
[[nodiscard]] std::string GetCoreTypeStr() const;
// ComputeHash is non-const while ComputeHashOrderless is const.
// NOTE(review): presumably ComputeHash caches its result in operationHash_ — confirm.
unsigned long ComputeHash();
unsigned long ComputeHashOrderless() const;
// Opcode classification predicates.
[[nodiscard]] bool IsCall() const;
[[nodiscard]] bool IsNOP() const;
// Connectivity predicates.
bool IsIsolatedOp() const;
bool OnlyHasCtrlEdgeToOp(Operation& op) const;
// Typed op attribute; may be null when none has been set.
const std::shared_ptr<OpAttribute>& GetOpAttribute() const
{
    return opAttribute_;
}

// Mutable overload: assigning through the returned reference bypasses the type checks
// performed by SetOpAttribute().
std::shared_ptr<OpAttribute>& GetOpAttribute()
{
    return opAttribute_;
}
// Stores `attr` as the typed op attribute and then asserts that its dynamic type fits the
// current opcode:
//   - opcodes listed in copyOpAttrOpTypes       -> CopyOpAttribute
//   - OP_VIEW                                   -> ViewOpAttribute
//   - OP_ASSEMBLE                               -> AssembleOpAttribute
//   - OP_ASSEMBLE_SSA / OP_ATOMIC_RMW           -> AssembleOpAttribute or CopyOpAttribute
//   - OP_BLOCK_CALL / OP_CALL                   -> CallOpAttribute
//   - OP_CONVERT                                -> ConvertOpAttribute
//   - every other opcode                        -> no attribute (nullptr)
// The assignment happens before the check, so the attribute is already stored when an
// assertion fires.
void SetOpAttribute(const std::shared_ptr<OpAttribute>& attr)
{
    opAttribute_ = attr;

    // Opcodes whose attribute must be a CopyOpAttribute. Built once, never modified.
    static const std::unordered_set<Opcode> copyOpAttrOpTypes{
        Opcode::OP_L1_COPY_IN,
        Opcode::OP_L1_COPY_OUT,
        Opcode::OP_COPY_IN,
        Opcode::OP_L0C_TO_L1,
        Opcode::OP_L0C_COPY_UB,
        Opcode::OP_L0C_COPY_UB_DUAL_DST,
        Opcode::OP_L1_TO_BT,
        Opcode::OP_L1_TO_FIX_QUANT_PRE,
        Opcode::OP_L1_TO_L0A,
        Opcode::OP_L1_TO_L0B,
        Opcode::OP_L1_TO_L0_AT,
        Opcode::OP_L1_TO_L0_BT,
        Opcode::OP_UB_COPY_L1,
        Opcode::OP_COPY_OUT,
        Opcode::OP_RESHAPE_COPY_IN,
        Opcode::OP_RESHAPE_COPY_OUT,
        Opcode::OP_INDEX_OUTCAST,
        Opcode::OP_INDEX_PUT,
        Opcode::OP_INDEX_ADD,
        Opcode::OP_TRANSPOSE_MOVEIN,
        Opcode::OP_TRANSPOSE_MOVEOUT,
        Opcode::OP_NCHW2NC1HWC0,
        Opcode::OP_NCHW2Fractal_Z,
        Opcode::OP_NC1HWC02NCHW,
        Opcode::OP_NCDHW2NDC1HWC0,
        Opcode::OP_NCDHW2FRACTAL_Z_3D,
        Opcode::OP_NDC1HWC02NCDHW,
        Opcode::OP_FFN_SCHED,
        Opcode::OP_FFN_BATCHING,
        Opcode::OP_FFN_COMBINEINFO,
        Opcode::OP_FFN_VALIDCNT,
        Opcode::OP_SHMEM_PUT,
        Opcode::OP_SHMEM_STORE,
        Opcode::OP_SHMEM_SIGNAL,
        Opcode::OP_SHMEM_GET,
        Opcode::OP_SHMEM_LOAD,
        Opcode::OP_SHMEM_SET,
        Opcode::OP_MOE_DISTRIBUTED_COMBINE_SEND,
        Opcode::OP_MOE_DISTRIBUTED_COMBINE_RECEIVE,
        Opcode::OP_GATHER_IN_UB,
        Opcode::OP_COPY_TO_LOCAL_EXPERT,
        Opcode::OP_L1_COPY_IN_A_SCALE,
        Opcode::OP_L1_COPY_IN_B_SCALE,
        Opcode::OP_L1_TO_L0A_SCALE,
        Opcode::OP_L1_TO_L0B_SCALE,
        Opcode::OP_L1_COPY_IN_CONV,
        Opcode::OP_L0C_COPY_OUT,
        Opcode::OP_L0C_COPY_OUT_CONV,
        Opcode::OP_L1_RESHAPE_COPY_IN,
        Opcode::OP_L0C_RESHAPE_COPY_OUT};

    const bool expectsCopyAttr = copyOpAttrOpTypes.find(opcode_) != copyOpAttrOpTypes.end();
    if (expectsCopyAttr) {
        FE_ASSERT(std::dynamic_pointer_cast<CopyOpAttribute>(opAttribute_) != nullptr);
        return;
    }

    switch (opcode_) {
        case Opcode::OP_VIEW: {
            FE_ASSERT(std::dynamic_pointer_cast<ViewOpAttribute>(opAttribute_) != nullptr);
            break;
        }
        case Opcode::OP_ASSEMBLE: {
            FE_ASSERT(std::dynamic_pointer_cast<AssembleOpAttribute>(opAttribute_) != nullptr);
            break;
        }
        case Opcode::OP_ASSEMBLE_SSA:
        case Opcode::OP_ATOMIC_RMW: {
            FE_ASSERT(
                std::dynamic_pointer_cast<AssembleOpAttribute>(opAttribute_) != nullptr ||
                std::dynamic_pointer_cast<CopyOpAttribute>(opAttribute_) != nullptr);
            break;
        }
        case Opcode::OP_BLOCK_CALL:
        case Opcode::OP_CALL: {
            FE_ASSERT(std::dynamic_pointer_cast<CallOpAttribute>(opAttribute_) != nullptr);
            break;
        }
        case Opcode::OP_CONVERT: {
            FE_ASSERT(std::dynamic_pointer_cast<ConvertOpAttribute>(opAttribute_) != nullptr);
            break;
        }
        default: {
            // Any opcode not handled above must not carry a typed attribute.
            FE_ASSERT(opAttribute_ == nullptr);
            break;
        }
    }
}
// Builds an AssembleOpAttribute from the given offsets and installs it through
// SetOpAttribute(). Asserts first that the opcode is OP_ASSEMBLE, OP_ASSEMBLE_SSA or
// OP_ATOMIC_RMW. `toDynOffset` defaults to an empty list.
void SetAssembleOpAttribute(
    const std::vector<int64_t>& toOffset, const std::vector<SymbolicScalar>& toDynOffset = {})
{
    FE_ASSERT(
        opcode_ == Opcode::OP_ASSEMBLE || opcode_ == Opcode::OP_ASSEMBLE_SSA || opcode_ == Opcode::OP_ATOMIC_RMW);
    auto assembleAttr = std::make_shared<AssembleOpAttribute>(toOffset, toDynOffset);
    SetOpAttribute(assembleAttr);
}
// Positional operand replacement; defined out of line.
// NOTE(review): overlaps by signature with UpdateInputOperand/UpdateOutputOperand — the
// difference, if any, lives in the definitions.
void ReplaceIOperand(size_t index, std::shared_ptr<LogicalTensor> newTensor);
void ReplaceOOperand(size_t index, std::shared_ptr<LogicalTensor> newTensor);
std::string GetCalleeMagicName() const
{
FE_ASSERT(IsCall());
return std::static_pointer_cast<CallOpAttribute>(opAttribute_)->GetCalleeMagicName();
}
const std::string& GetCalleeBracketName() const
{
return std::static_pointer_cast<CallOpAttribute>(opAttribute_)->GetCalleeBracketName();
}
// Returns the callee hash stored in the CallOpAttribute.
// Asserts that this is a call (IsCall()) or an OP_BLOCK_CALL, and that a CallOpAttribute is
// actually attached: the dynamic cast yields nullptr when no attribute has been set yet or
// when it has another type, and that pointer used to be dereferenced unchecked.
// The null assertion mirrors the one in GetSubFuncInvokeInfo().
const FunctionHash& GetCalleeHash() const
{
    FE_ASSERT(IsCall() || opcode_ == Opcode::OP_BLOCK_CALL);
    auto callop = std::dynamic_pointer_cast<CallOpAttribute>(opAttribute_);
    FE_ASSERT(callop != nullptr);
    return callop->GetCalleeHash();
}
// Operand removal / replacement by tensor identity; defined out of line.
void EraseInput(const std::shared_ptr<LogicalTensor>& input);
void EraseDependTensor(const std::shared_ptr<LogicalTensor>& dependTensor);
// Note the argument order: (new, old) — the opposite of ReplaceInputOperand/ReplaceOutputOperand,
// which take (origin, new).
void ReplaceInput(const std::shared_ptr<LogicalTensor>& newInput, const std::shared_ptr<LogicalTensor>& oldInput);
void ReplaceOutput(
const std::shared_ptr<LogicalTensor>& newOutput, const std::shared_ptr<LogicalTensor>& oldOutput);
Opcode GetOpcode() const
{
    return opcode_;
}

// Changes the opcode and keeps the base-class (TensorOpStmt) opcode string in sync with it.
void SetOpCode(Opcode opcode)
{
    opcode_ = opcode;
    TensorOpStmt::opcode_ = OpcodeManager::Inst().GetOpcodeStr(opcode);
}

// Latency; the member defaults to 1.
int GetLatency() const
{
    return latency_;
}
void UpdateLatency(int latency)
{
    latency_ = latency;
}

// Remaining time; the member defaults to INVALID_TIME.
int GetRemainingTime() const
{
    return remainingTime_;
}
void UpdateRemainingTime(int remainingTime)
{
    remainingTime_ = remainingTime;
}

// Subgraph id; the member defaults to NOT_IN_SUBGRAPH.
int GetSubgraphID() const
{
    return subgraphID_;
}

// Internal subgraph id, or NOT_IN_SUBGRAPH while the lazily created mix-subgraph fields
// have not been allocated.
int GetInternalSubgraphID() const
{
    if (!mixSubgraphFields_) {
        return NOT_IN_SUBGRAPH;
    }
    return mixSubgraphFields_->internalSubgraphID;
}

void UpdateSubgraphID(int subgraphID)
{
    subgraphID_ = subgraphID;
}
// Stores a hash-order string under `hashOrderKey` and a subgraph count (as int64_t) under
// `countKey`.
void SetHashOrderInfo(
    const std::string& hashOrderKey, const std::string& countKey, const std::string& hashOrder,
    size_t subgraphCount)
{
    SetAttr(hashOrderKey, hashOrder);
    SetAttr(countKey, static_cast<int64_t>(subgraphCount));
}

// Reads back the pair written by SetHashOrderInfo().
// The locals start as "" and 0 and GetAttr's result is not inspected, so those defaults are
// returned for whatever GetAttr leaves untouched.
std::pair<std::string, size_t> GetHashOrderInfo(const std::string& hashOrderKey, const std::string& countKey) const
{
    std::string order;
    int64_t subgraphCount = 0;
    GetAttr(hashOrderKey, order);
    GetAttr(countKey, subgraphCount);
    return std::pair<std::string, size_t>(order, static_cast<size_t>(subgraphCount));
}
void UpdateInternalSubgraphID(int internalSubgraphID)
{
ensureMixSubgraphFields();
mixSubgraphFields_->internalSubgraphID = internalSubgraphID;
}
auto GroupID() const { return groupID_; }
void SetGroupID(size_t groupID) const { groupID_ = groupID; }
void SetSemanticLabel(std::shared_ptr<SemanticLabel> label)
{
    semanticLabel_ = label;
}

// Label text, or a shared empty string when no semantic label is attached.
const std::string& GetSemanticLabelStr() const
{
    static const std::string noLabel;
    if (semanticLabel_) {
        return semanticLabel_->label;
    }
    return noLabel;
}

// May be null.
std::shared_ptr<SemanticLabel> GetSemanticLabel() const
{
    return semanticLabel_;
}

// Deleted flag: only toggles isDeleted_, nothing is released here.
void SetAsDeleted()
{
    isDeleted_ = true;
}
void SetAsNotDeleted()
{
    isDeleted_ = false;
}
[[nodiscard]] bool IsDeleted() const
{
    return isDeleted_;
}
// AIV core, or AIVCore::UNSPECIFIED while the mix-subgraph fields have not been allocated.
[[nodiscard]] AIVCore GetAIVCore() const
{
    if (!mixSubgraphFields_) {
        return AIVCore::UNSPECIFIED;
    }
    return mixSubgraphFields_->aivCore;
}

// AIV core as text, looked up in GetAIVCoreDict().
// Returns "UNKNOWN" while the mix-subgraph fields have not been allocated. This differs
// from an allocated UNSPECIFIED value, which the dictionary names "AIC".
[[nodiscard]] std::string GetAIVCoreStr() const
{
    if (mixSubgraphFields_) {
        return GetAIVCoreDict().Find(mixSubgraphFields_->aivCore);
    }
    return "UNKNOWN";
}

// Sets the AIV core, allocating the mix-subgraph fields on first use.
void SetAIVCore(AIVCore aivCore)
{
    ensureMixSubgraphFields();
    mixSubgraphFields_->aivCore = aivCore;
}
// Defined out of line.
void SetSubFuncInvokeInfo(const SubfuncInvokeInfoTy& invokeInfo);

// Mutable access to the invoke info held by the CallOpAttribute.
// Asserts that a CallOpAttribute is attached; invokeInfo_ itself is dereferenced unchecked.
SubfuncInvokeInfoTy& GetSubFuncInvokeInfo()
{
    auto callAttr = std::dynamic_pointer_cast<CallOpAttribute>(opAttribute_);
    FE_ASSERT(callAttr != nullptr);
    auto& invokeInfo = *(callAttr->invokeInfo_);
    return invokeInfo;
}

// Defined out of line.
int GetProgramId();
bool IsNeedStackGM() const;
// Offset of input operand `pos`, or -1 while no input operand attributes exist.
// `pos` is not range-checked once the list is non-empty.
int GetIOpAttrOffset(int pos) const
{
    if (iOpAttr_.empty()) {
        return -1;
    }
    return iOpAttr_[pos].offset;
}

// Offset of output operand `pos`, or -1 while no output operand attributes exist.
// `pos` is not range-checked once the list is non-empty.
int GetOOpAttrOffset(int pos) const
{
    if (oOpAttr_.empty()) {
        return -1;
    }
    return oOpAttr_[pos].offset;
}

// Attribute record of input operand `pos`, or a default record (offset -1, attr 0) while
// the list is empty. `pos` is not range-checked once the list is non-empty.
OperandAttribute GetIOpAttr(int pos) const
{
    if (iOpAttr_.empty()) {
        return OperandAttribute();
    }
    return iOpAttr_[pos];
}

// Attribute record of output operand `pos`, or a default record (offset -1, attr 0) while
// the list is empty. `pos` is not range-checked once the list is non-empty.
OperandAttribute GetOOpAttr(int pos) const
{
    if (oOpAttr_.empty()) {
        return OperandAttribute();
    }
    return oOpAttr_[pos];
}
// Sets the offset of input operand `pos`.
// The attribute list is created on first use with one default record per input operand.
// A `pos` outside [0, list size) is ignored, the same policy SetOOpAtt() applies; the
// previous unchecked write was an out-of-bounds access for such positions (including the
// case of an operation without input operands, where the list stays empty).
void SetIOpAtt(int pos, int offset)
{
    if (iOpAttr_.empty()) {
        iOpAttr_.resize(iOperand.size());
    }
    if (pos >= 0 && static_cast<size_t>(pos) < iOpAttr_.size()) {
        iOpAttr_[pos].offset = offset;
    }
}
// Sets the offset and flag word of output operand `pos` (`attr` defaults to 0).
// The attribute list is created on first use with one default record per output operand.
// A `pos` outside [0, list size) is ignored. The lower bound is now checked as well: a
// negative `pos` used to pass the `pos < size` test and write before the start of the list.
void SetOOpAtt(int pos, int offset, int attr = 0)
{
    if (oOpAttr_.empty()) {
        oOpAttr_.resize(oOperand.size());
    }
    if (pos >= 0 && static_cast<size_t>(pos) < oOpAttr_.size()) {
        oOpAttr_[pos] = {offset, attr};
    }
}
// Whole-list views of the input / output operand attributes; either may be empty.
const std::vector<OperandAttribute>& GetIOpAttr() const
{
    return iOpAttr_;
}
const std::vector<OperandAttribute>& GetOOpAttr() const
{
    return oOpAttr_;
}

// Replaces both operand attribute lists. No size check against the operand lists is made.
void SetOperandAttr(const std::vector<OperandAttribute>& iAttr, const std::vector<OperandAttribute>& oAttr)
{
    iOpAttr_ = iAttr;
    oOpAttr_ = oAttr;
}

// Defined out of line.
std::vector<std::reference_wrapper<SymbolicScalar>> GetDynamicAttributeList();

// Span stored with this operation.
const ir::Span& GetSpan() const
{
    return span_;
}
void SetSpan(ir::Span span)
{
    span_ = span;
}

// Comment strings attached to this operation, read-only and mutable views.
const std::vector<std::string>& GetCommentList() const
{
    return commentList_;
}
std::vector<std::string>& GetCommentList()
{
    return commentList_;
}
private:
// Initialisation helpers; defined out of line (presumably called from the constructor — confirm).
void InitCoreTypeAndTileShape(Opcode opcode);
void InitTensorGraphMetadata();
void InitLatency(Opcode opcode);
Opcode opcode_{Opcode::OP_UNKNOWN};
int subgraphID_{NOT_IN_SUBGRAPH};
// Set to true by the string-opcode constructor when the opcode name starts with "TILE_".
bool isTileOp_{false};
TileShape tileShape_;
// Typed attribute; validated against opcode_ in SetOpAttribute().
std::shared_ptr<OpAttribute> opAttribute_;
unsigned long operationHash_{0};
int latency_{1};
// Per-operand attribute records; left empty until SetIOpAtt/SetOOpAtt/SetOperandAttr fill them.
std::vector<OperandAttribute> iOpAttr_;
std::vector<OperandAttribute> oOpAttr_;
int remainingTime_{INVALID_TIME};
CoreType coreType_{CoreType::MIX};
// Control edges, stored as non-owning raw pointers.
std::unordered_set<Operation*> inputCtrlOps;
std::unordered_set<Operation*> outputCtrlOps;
// mutable so that SetGroupID() can be const.
mutable size_t groupID_{NON_GROUP};
bool isDeleted_{false};
ir::Span span_;
std::shared_ptr<SemanticLabel> semanticLabel_;
// NOTE(review): no in-class initializer; presumably set by the out-of-line constructor from
// its `Function& cur` argument — confirm.
Function* function_;
std::vector<std::string> commentList_;
// Allocated lazily by ensureMixSubgraphFields(); null until first written.
std::unique_ptr<MixSubgraphFields> mixSubgraphFields_;
// Allocates the mix-subgraph fields with their default values if they do not exist yet;
// does nothing otherwise.
void ensureMixSubgraphFields()
{
    if (mixSubgraphFields_) {
        return;
    }
    mixSubgraphFields_ = std::make_unique<MixSubgraphFields>();
}
};
// Shared-ownership handle to an Operation.
using OperationPtr = std::shared_ptr<Operation>;
// Comparator over raw Operation pointers; the ordering is defined out of line.
// NOTE(review): Operation::OperationComparator (ordering by op magic) also exists — whether
// the two agree cannot be told from this header.
struct OperationCmp {
bool operator()(const Operation* lhs, const Operation* rhs) const;
};
}