* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file operation_common.h
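* \brief Shared declarations for operation implementations: operand-check helpers, the tiled-function registry, and the OpSyncQueue record.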
*/
#ifndef INTERFACE_MAIN_OPERATION_COMMON_H
#define INTERFACE_MAIN_OPERATION_COMMON_H
#include "interface/utils/common.h"
#include "tilefwk/symbolic_scalar.h"
#include "tilefwk/tensor.h"
#include "interface/tensor/logical_tensor.h"
#include "interface/operation/opcode.h"
#include "tilefwk/error_code.h"
#include "tilefwk/platform.h"
#include "ir/span.h"
namespace npu::tile_fwk {
// CALL(n, args...) pastes n onto the identifier "Tensor" and calls the resulting name with args,
// e.g. CALL(Foo, a, b) expands to TensorFoo(a, b).
#define CALL(n, ...) Tensor##n(__VA_ARGS__)
// Same token pasting as CALL, but the expansion starts with `return`, so it returns the call's result
// from the enclosing function.
#define RETURN_CALL(n, ...) return Tensor##n(__VA_ARGS__)
// CHECK_OP(cond) is a single conditional expression:
//   - cond true  -> the expression yields 0 and the error branch is not evaluated;
//   - cond false -> an ErrorMessage is streamed with "CHECK FAILED: <text of cond>" and the string form of
//                   pypto::ir::Span::Current(), then assigned to a temporary Error built from
//                   __func__/__FILE__/__LINE__.
// The expansion deliberately ends in an open `<<` chain without a semicolon, so a use site can append more
// context: CHECK_OP(x > 0) << "x=" << x;
// NOTE(review): what Error::operator=(ErrorMessage) does with the message (log, throw, abort) is defined
// outside this header, presumably in tilefwk/error_code.h - confirm there.
#define CHECK_OP(cond) \
(cond) ? 0 : \
npu::tile_fwk::Error(__func__, __FILE__, __LINE__) = \
npu::tile_fwk::ErrorMessage() << "CHECK FAILED: " #cond << "\n" \
<< "location: " << pypto::ir::Span::Current().ToString() << "\n"
// Named 32-bit integer literals; each identifier spells out the value it holds.
constexpr int32_t NUM_VALUE_0{0};
constexpr int32_t NUM_VALUE_1{1};
constexpr int32_t NUM_VALUE_2{2};
constexpr int32_t NUM_VALUE_3{3};
constexpr int32_t NUM_VALUE_4{4};
constexpr int32_t NUM_VALUE_5{5};
constexpr int32_t NUM_VALUE_6{6};
constexpr int32_t NUM_VALUE_8{8};
constexpr int32_t NUM_VALUE_10{10};
constexpr int32_t NUM_VALUE_16{16};
constexpr int32_t NUM_VALUE_31{31};
constexpr int32_t NUM_VALUE_32{32};
constexpr int32_t NUM_VALUE_64{64};
constexpr int32_t NUM_VALUE_1024{1024};

// Named floating-point literals: one half, and a small epsilon (1e-9).
constexpr double NUM_VALUE_0_5{0.5};
constexpr double NUM_VALUE_EPS{1e-9};

// Lower and upper bounds on tensor dimension count (1 and 4).
// NOTE(review): presumably the defaults passed as minDim/maxDim to the dimension-range checks - confirm at call sites.
constexpr size_t MIN_TENSOR_DIM{1};
constexpr size_t MAX_TENSOR_DIM{4};
/**
 * Value type bundling three vectors that describe a tile: `shape`, `offset` and a symbolic `validShape`.
 *
 * NOTE(review): the vectors are presumably indexed per dimension; nothing here enforces that they
 * have equal length - confirm the invariant with the callers.
 */
struct TileInfo {
    std::vector<int64_t> shape;
    std::vector<int64_t> offset;
    std::vector<SymbolicScalar> validShape;

    /**
     * Builds value-initialized vectors of the requested lengths.
     *
     * @param shapeSize  number of elements in `shape`; `validShape` gets the same length
     * @param offsetSize number of elements in `offset`
     */
    TileInfo(size_t shapeSize, size_t offsetSize) : shape(shapeSize), offset(offsetSize), validShape(shapeSize) {}

    /**
     * Takes ownership of already-populated vectors.
     *
     * All three arguments are by-value sink parameters and are moved into the members, so passing
     * rvalues costs no element copies.
     *
     * @param aShape      values for `shape`
     * @param aOffset     values for `offset`
     * @param aValidShape values for `validShape`; defaults to an empty vector
     */
    TileInfo(std::vector<int64_t> aShape, std::vector<int64_t> aOffset, std::vector<SymbolicScalar> aValidShape = {})
        : shape(std::move(aShape)), offset(std::move(aOffset)), validShape(std::move(aValidShape))
    {}
};
// Pairs a Tensor with the TileInfo that accompanies it.
// `tensor` is a const member, so it can only be set at construction; this normally also makes
// Input non-assignable (the implicit copy/move assignment cannot write to a const member).
struct Input {
const Tensor tensor;
TileInfo tileInfo;
};
// ---- Operand validation helpers -------------------------------------------------------------------
// Declarations only: the definitions live outside this header.
// NOTE(review): every behavior note below is inferred from the name and signature alone. How a failed
// check is reported (CHECK_OP, exception, log) is not visible here - confirm against the definitions.
// The `opName` parameters are presumably used only to label diagnostics.

// Dynamic-shape check for the operands of `opCode`. Both arguments are taken by value.
void CheckTensorDynamicShape(const LogicalTensors iOperands, const Opcode opCode);
// Presumably returns the broadcast result shape of the two operands. The operands are taken by
// non-const reference, so the callee is able to modify the pointers - verify whether it does.
std::vector<int> GetBroadCastShape(LogicalTensorPtr& operand1, LogicalTensorPtr& operand2);
// Presumably returns the axis indices along which shape1 and shape2 require broadcasting.
std::vector<int> GetBroadcastAxes(const Shape& shape1, const Shape& shape2);
// Checks `axis` against `tensor`. `axis` is a non-const reference, so it may be rewritten by the callee
// (presumably to normalize a negative axis - TODO confirm).
void CheckAxisRange(const Tensor& tensor, int& axis);
// Presumably checks that the tensor's dimension count lies within [minDim, maxDim].
void CheckTensorDimRange(const LogicalTensorPtr& tensor, size_t minDim, size_t maxDim, const std::string& opName);
// Same kind of dimension-count check, applied to a raw destination shape vector.
void CheckDstShapeDimRange(const std::vector<int64_t>& shape, size_t minDim, size_t maxDim, const std::string& opName);
// Presumably checks that all tensors in the list have the same number of dimensions.
void CheckTensorsDimConsistency(const std::vector<LogicalTensorPtr>& tensors, const std::string& opName);
// Shape-size check on a tensor; the exact limit enforced is defined with the implementation.
void CheckTensorShapeSize(const LogicalTensorPtr& tensor, const std::string& opName);
// Shape-size check on a raw destination shape vector.
void CheckDstShapeSize(const std::vector<int64_t>& shape, const std::string& opName);
// Presumably checks that the tensors' shapes are either identical or broadcast-compatible.
void CheckTensorsShapeConsistencyOrBroadcast(const std::vector<LogicalTensorPtr>& tensors, const std::string& opName);
// Presumably checks that the tensor's data type is a member of `supportedTypes`.
void CheckTensorDataType(
const LogicalTensorPtr& tensor, const std::unordered_set<DataType>& supportedTypes, const std::string& opName);
// Overload taking the data type directly instead of a tensor.
void CheckTensorDataType(DataType dtype, const std::unordered_set<DataType>& supportedTypes, const std::string& opName);
// Presumably checks that the current NPU architecture is one of `supportedArches`.
void CheckSupportedNPUArch(const std::vector<NPUArch>& supportedArches, const std::string& opName);
// Data-type consistency checks, overloaded for: two tensors, a tensor and a scalar Element, and a list
// of tensors.
void CheckTensorsDataTypeConsistency(
const LogicalTensorPtr& tensor1, const LogicalTensorPtr& tensor2, const std::string& opName);
void CheckTensorsDataTypeConsistency(const LogicalTensorPtr& tensor, const Element& element, const std::string& opName);
void CheckTensorsDataTypeConsistency(const std::vector<LogicalTensorPtr>& tensors, const std::string& opName);
// Format consistency checks, overloaded for a pair of tensors and for a list of tensors.
void CheckTensorsFormatConsistency(
const LogicalTensorPtr& tensor1, const LogicalTensorPtr& tensor2, const std::string& opName);
void CheckTensorsFormatConsistency(const std::vector<LogicalTensorPtr>& tensors, const std::string& opName);
// Check for the two inputs of a binary operation.
// NOTE(review): presumably a bundle of the individual checks above - confirm which ones.
void CheckBinaryInputTensors(
const LogicalTensorPtr& tensor1, const LogicalTensorPtr& tensor2, const std::string& opName);
// Returns a reference to a data-type set; presumably one of the two arguments, chosen by the current
// architecture. Because a reference is returned, the sets passed in must outlive the use of the result
// if that is the case - avoid passing temporaries until confirmed.
const std::unordered_set<DataType>& GetSupportedDataTypesByArch(
const std::unordered_set<DataType>& a2a3Types, const std::unordered_set<DataType>& a5Types);
// Signature of a per-opcode tiling callback stored in TiledFuncRegistry.
// It receives the Function (mutable), the tile shape, the input and output operand lists, and the
// Operation, and returns nothing.
// NOTE(review): presumably the callback emits the tiled form of `op` into `function` - confirm with an
// implementation.
using TiledFuncType = std::function<void(
Function& function, const TileShape& tileShape, const std::vector<LogicalTensorPtr>& iOperand,
const std::vector<LogicalTensorPtr>& oOperand, const Operation& op)>;
/**
 * Table mapping an Opcode to its TiledFuncType callback.
 *
 * The constructor and destructor are private; the only instance is the function-local static
 * returned by GetInstance(). No locking is performed here, so concurrent registration and lookup
 * must be serialized by the caller if they can overlap.
 */
class TiledFuncRegistry {
private:
    TiledFuncRegistry() = default;
    ~TiledFuncRegistry() = default;

public:
    /** Returns the single registry instance, constructed on first use. */
    static TiledFuncRegistry& GetInstance()
    {
        static TiledFuncRegistry instance;
        return instance;
    }

    /**
     * Stores `func` under `opcode`, replacing any callback registered for it earlier.
     *
     * @param opcode key to register under
     * @param func   callback; taken by value and moved into the table
     */
    void RegisterTiledFunc(const Opcode opcode, TiledFuncType func) { tiledFuncs_[opcode] = std::move(func); }

    /**
     * Looks up the callback registered for `opcode`.
     *
     * @param opcode key to look up
     * @return a copy of the registered callback, or an empty (null) std::function when nothing is
     *         registered for `opcode`
     */
    TiledFuncType GetTiledFunc(const Opcode opcode) const
    {
        const auto it = tiledFuncs_.find(opcode);
        if (it == tiledFuncs_.end()) {
            return nullptr;
        }
        // Reuse the iterator from find() rather than hashing the key a second time via operator[].
        return it->second;
    }

private:
    std::unordered_map<Opcode, TiledFuncType> tiledFuncs_;
};
// REGISTER_OPERATION_TILED_FUNC(OpCoreStr, OpType, FuncName)
// Expands to a class named <OpCoreStr>TiledRegister whose constructor calls
// TiledFuncRegistry::GetInstance().RegisterTiledFunc(OpType, FuncName), followed by the declaration of a
// static object <OpCoreStr>_tiled_register of that class. The registration therefore runs when that
// static object is initialized.
// The expansion has no trailing semicolon; the use site supplies it.
// OpCoreStr must be unique per translation unit because it names both the class and the object.
// TiledFuncRegistry is referenced unqualified, so the macro has to be used where that name resolves
// (e.g. inside namespace npu::tile_fwk).
#define REGISTER_OPERATION_TILED_FUNC(OpCoreStr, OpType, FuncName) \
class OpCoreStr##TiledRegister { \
public: \
OpCoreStr##TiledRegister() { TiledFuncRegistry::GetInstance().RegisterTiledFunc(OpType, FuncName); } \
}; \
static OpCoreStr##TiledRegister OpCoreStr##_tiled_register
// Opaque declaration of the scoped enum AIVCore; its enumerators are defined outside this header.
// With no explicit base, a scoped enum's underlying type is int, which is what permits the
// `AIVCore x{-1}` member initializers and the int casts used by OpSyncQueue.
enum class AIVCore;
class OpSyncQueue {
public:
OpSyncQueue() {}
OpSyncQueue(
PipeType pipeId, PipeType trigPipeId, CoreType coreType, CoreType tirgCoreType, int evid, AIVCore setAivCore,
AIVCore waitAivCore)
: pipeId_(pipeId),
trigPipeId_(trigPipeId),
coreType_(coreType),
trigCoreType_(tirgCoreType),
eventId_(evid),
setAivCore_(setAivCore),
waitAivCore_(waitAivCore)
{}
OpSyncQueue(int bufid, const std::vector<int>& offset, CoreType coreType, CoreType tirgCoreType)
: coreType_(coreType), trigCoreType_(tirgCoreType), gMBufId(bufid), offset_(offset)
{}
PipeType pipeId_{PIPE_S};
PipeType trigPipeId_{PIPE_S};
CoreType coreType_{CoreType::AIV};
CoreType trigCoreType_{CoreType::AIV};
int eventId_{0};
AIVCore setAivCore_{-1};
AIVCore waitAivCore_{-1};
int gMBufId{0};
std::vector<int> offset_;
Json ToJson() const
{
Json j;
j["pipe_id"] = pipeId_;
j["trig_pipe"] = trigPipeId_;
j["core_type"] = static_cast<int>(coreType_);
j["tri_core_type"] = static_cast<int>(trigCoreType_);
j["event_id"] = eventId_;
j["set_aiv_core"] = static_cast<int>(setAivCore_);
j["wait_aiv_core"] = static_cast<int>(waitAivCore_);
j["gm_buf_id"] = gMBufId;
j["offset"] = offset_;
return j;
}
void FromJson(const Json& j)
{
pipeId_ = static_cast<PipeType>(j["pipe_id"].get<int>());
trigPipeId_ = static_cast<PipeType>(j["trig_pipe"].get<int>());
coreType_ = static_cast<CoreType>(j["core_type"].get<int>());
trigCoreType_ = static_cast<CoreType>(j["tri_core_type"].get<int>());
eventId_ = j["event_id"].get<int>();
setAivCore_ = static_cast<AIVCore>(j["set_aiv_core"].get<int>());
waitAivCore_ = static_cast<AIVCore>(j["wait_aiv_core"].get<int>());
gMBufId = j["gm_buf_id"].get<int>();
offset_ = j["offset"].get<std::vector<int>>();
}
std::string Dump() const
{
std::ostringstream oss;
oss << GetPipeTypeDict().Find(pipeId_) << "," << GetPipeTypeDict().Find(trigPipeId_) << "," << eventId_;
return oss.str();
}
};
}
#endif