/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
 * \file test_mix_subgraph_split.cpp
 * \brief Unit tests for the MixSubgraphSplit pass
 */
#include <gtest/gtest.h>
#include "symbolic_scalar_test_utils.h"
#include "interface/tensor/irbuilder.h"
#include "passes/block_graph_pass/mix_subgraph_split.h"
#include "computational_graph_builder.h"
#include "interface/tensor/irbuilder.h"
namespace npu {
namespace tile_fwk {
// Program ID constant. The tests visible in this part of the file register their Mix subgraph under a
// local `mixProgramId = 100` instead of referencing it — presumably the same value; TODO confirm usage.
constexpr uint64_t programId = 100;
// Edge length of the square {MS_NUM16, MS_NUM16} tensors built by the helper functions below.
constexpr int MS_NUM16 = 16;
// NOTE(review): the next two constants are not referenced in the visible part of this file;
// presumably they are consumed by tests further down — confirm before removing.
constexpr int MS_NUM3 = 3;
constexpr int MS_NUM10005 = 10005;
// Forward declarations of the verification helpers shared by the split tests.
// Their definitions follow the MixSubgraphSplitTest fixture.
namespace test_utils {
// Asserts that the pass returned SUCCESS and that rootFunc holds exactly programs 0 and 1.
void VerifyBasicChecks(Status status, Function& rootFunc);
// Checks name, function/graph type, program ID and leaf attribute of every program of rootFunc.
void VerifyProgramProperties(Function& rootFunc);
// Checks the call ops of rootFunc: count, liveness, attribute and referenced program ID.
void VerifyCallOpsAfterSplit(Function& rootFunc);
// Counts programs by leaf-attribute aivCore and compares the counts with the expected cube/vector numbers.
void VerifyScopeTypes(Function& rootFunc, int expectedCubeCount, int expectedVectorCount);
// Checks that the original mix function and call op are gone and that every program has a call op.
void VerifyCleanup(Function& rootFunc, Function* originalMixFunc, Operation* originalCallOp);
}
/**
 * Test fixture for the MixSubgraphSplit pass.
 *
 * Before each test the global Program and the config are reset, the compile stage is set to
 * CS_EXECUTE_GRAPH and the cost model is disabled.
 */
class MixSubgraphSplitTest : public ::testing::Test {
public:
    static void SetUpTestCase() {}

    static void TearDownTestCase() {}

    void SetUp() override
    {
        Program::GetInstance().Reset();
        config::Reset();
        config::SetHostOption(COMPILE_STAGE, CS_EXECUTE_GRAPH);
        config::SetPlatformConfig(KEY_ENABLE_COST_MODEL, false);
    }

    void TearDown() override {}

protected:
    /**
     * Builds a static block-graph function named "mix_func_illegal" under rootFunc.
     *
     * The function holds three copy ops: a COPY_OUT and a COPY_IN tagged isCube in internal
     * subgraph 1, and a COPY_IN on AIV0 in internal subgraph 0 that sits between them in the data flow
     * (subgraph 1 -> subgraph 0 -> subgraph 1).
     * NOTE(review): judging by the name this presumably models an illegal mix layout — confirm.
     *
     * @param rootFunc    parent function handed to the Function constructor
     * @param tensorShape shape used for every tensor and for the copy attributes
     * @return the newly built function
     */
    std::shared_ptr<Function> BuildMixFunction(Function* rootFunc, std::vector<int64_t>& tensorShape)
    {
        // All tensors of this fixture share dtype and shape.
        auto newTensor = [&tensorShape]() {
            return npu::tile_fwk::IRBuilder().CreateTensorVar(
                DT_FP32, tensorShape, CreateTestConstIntVector(tensorShape));
        };

        auto illegalMix =
            std::make_shared<Function>(Program::GetInstance(), "mix_func_illegal", "mix_func_illegal", rootFunc);
        illegalMix->SetGraphType(GraphType::BLOCK_GRAPH);
        illegalMix->SetFunctionType(FunctionType::STATIC);

        auto source = newTensor();
        auto sink = newTensor();
        auto relayA = newTensor();
        auto relayB = newTensor();

        // The relay tensors are registered both as incasts and as outcasts.
        auto& incasts = illegalMix->inCasts_;
        incasts.push_back(source);
        incasts.push_back(relayA);
        incasts.push_back(relayB);
        auto& outcasts = illegalMix->outCasts_;
        outcasts.push_back(sink);
        outcasts.push_back(relayA);
        outcasts.push_back(relayB);

        auto shapeArg = OpImmediate::Specified(tensorShape);
        std::vector<int64_t> zeroOffset = {0, 0};
        auto offsetArg = OpImmediate::Specified(zeroOffset);
        std::vector<OpImmediate> noExtraArgs;

        // Cube side, internal subgraph 1: source -> relayA.
        auto& cubeStore =
            IRBuilder().CreateTensorOpStmt(*illegalMix, Opcode::OP_COPY_OUT, {source}, {relayA});
        cubeStore.SetOpAttribute(
            std::make_shared<CopyOpAttribute>(MemoryType::MEM_UB, offsetArg, shapeArg, shapeArg, noExtraArgs));
        cubeStore.SetOOpAtt(0, 0);
        cubeStore.UpdateInternalSubgraphID(1);
        cubeStore.SetAttr(OpAttributeKey::isCube, true);

        // Cube side, internal subgraph 1: relayB -> sink.
        auto& cubeLoad = IRBuilder().CreateTensorOpStmt(*illegalMix, Opcode::OP_COPY_IN, {relayB}, {sink});
        cubeLoad.SetOpAttribute(
            std::make_shared<CopyOpAttribute>(offsetArg, MemoryType::MEM_UB, shapeArg, shapeArg, noExtraArgs));
        cubeLoad.SetIOpAtt(0, 0);
        cubeLoad.UpdateInternalSubgraphID(1);
        cubeLoad.SetAttr(OpAttributeKey::isCube, true);

        // Vector side (AIV0), internal subgraph 0: relayA -> relayB.
        auto& vectorLoad =
            IRBuilder().CreateTensorOpStmt(*illegalMix, Opcode::OP_COPY_IN, {relayA}, {relayB});
        vectorLoad.SetOpAttribute(
            std::make_shared<CopyOpAttribute>(offsetArg, MemoryType::MEM_UB, shapeArg, shapeArg, noExtraArgs));
        vectorLoad.SetIOpAtt(0, 0);
        vectorLoad.UpdateInternalSubgraphID(0);
        vectorLoad.SetAIVCore(AIVCore::AIV0);

        return illegalMix;
    }
};
namespace test_utils {
/**
 * Verifies the outcome of a one-mix-subgraph split at the coarsest level.
 *
 * Aborts (fatal failure) when status is not SUCCESS; otherwise expects rootFunc to own exactly
 * two programs, stored under the keys 0 and 1.
 */
void VerifyBasicChecks(Status status, Function& rootFunc)
{
    ASSERT_EQ(status, SUCCESS) << "MixSubgraphSplit should succeed";

    auto& programMap = rootFunc.programs_;
    EXPECT_EQ(programMap.size(), 2) << "Should have 2 programs after split (originally 1 Mix, split to 2 leaves)";

    const auto notFound = programMap.end();
    EXPECT_NE(programMap.find(0), notFound) << "Should have program ID 0";
    EXPECT_NE(programMap.find(1), notFound) << "Should have program ID 1";
}
/**
 * Verifies every program registered in rootFunc after the split.
 *
 * Each program must be non-null, carry "leaf" in its raw name, be a STATIC block graph, report the
 * program ID it is stored under, and own a leaf attribute whose mixId and mixResourceType are set.
 * A null program or a missing leaf attribute aborts this helper.
 */
void VerifyProgramProperties(Function& rootFunc)
{
    for (const auto& [id, leafFunc] : rootFunc.programs_) {
        ASSERT_NE(leafFunc, nullptr) << "Function should not be null";

        const std::string rawName = leafFunc->GetRawName();
        const bool nameHasLeaf = rawName.find("leaf") != std::string::npos;
        EXPECT_TRUE(nameHasLeaf) << "Function name should contain 'leaf' suffix: " << rawName;

        EXPECT_EQ(leafFunc->GetFunctionType(), FunctionType::STATIC);
        EXPECT_EQ(leafFunc->GetGraphType(), GraphType::BLOCK_GRAPH);
        EXPECT_EQ(leafFunc->GetProgramId(), id) << "Function's program ID should match map key";

        auto attr = leafFunc->GetLeafFuncAttribute();
        ASSERT_NE(attr, nullptr) << "LeafFuncAttribute should be set";
        // An all-ones mixId is treated as the "unassigned" marker.
        EXPECT_NE(attr->mixId, static_cast<uint64_t>(-1)) << "mixId should be assigned";
        EXPECT_NE(attr->mixResourceType, MixResourceType::UNKNOWN) << "mixResourceType should be set";
    }
}
/**
 * Verifies the call ops of rootFunc after a one-mix-subgraph split.
 *
 * Expects two call ops. Each one must be non-null, not deleted and carry a CallOpAttribute; when
 * the attribute has invoke info, the referenced program ID must be 0 or 1, must exist in
 * rootFunc.programs_, and wrapId must be set. A null call op or a missing attribute aborts this helper.
 */
void VerifyCallOpsAfterSplit(Function& rootFunc)
{
    auto callOps = rootFunc.GetCallopList();
    EXPECT_EQ(callOps.size(), 2) << "Should have 2 call ops after split (1 original * 2 components)";

    auto& programMap = rootFunc.programs_;
    for (auto* op : callOps) {
        ASSERT_NE(op, nullptr) << "CallOp should not be null";
        EXPECT_FALSE(op->IsDeleted()) << "CallOp should not be deleted";

        auto attr = dynamic_cast<CallOpAttribute*>(op->GetOpAttribute().get());
        ASSERT_NE(attr, nullptr) << "CallOpAttribute should exist";

        // Call ops without invoke info are not inspected any further.
        if (!attr->invokeInfo_) {
            continue;
        }

        const uint64_t calleeId = attr->invokeInfo_->GetProgramId();
        EXPECT_TRUE(calleeId == 0 || calleeId == 1) << "CallOp's program ID should be 0 or 1, got: " << calleeId;
        EXPECT_NE(programMap.find(calleeId), programMap.end())
            << "CallOp references non-existent program ID: " << calleeId;
        EXPECT_NE(attr->wrapId, static_cast<uint64_t>(-1)) << "wrapId should be set";
    }
}
/**
 * Counts the programs of rootFunc by the aivCore value of their leaf attribute and compares the
 * counts with the expectations.
 *
 * A program whose aivCore is UNSPECIFIED is counted as a cube component; AIV0 or AIV1 is counted as
 * a vector component. Programs without a leaf attribute are not counted.
 *
 * @param rootFunc            function whose programs_ map is inspected
 * @param expectedCubeCount   expected number of cube components
 * @param expectedVectorCount expected number of vector components
 */
void VerifyScopeTypes(Function& rootFunc, int expectedCubeCount, int expectedVectorCount)
{
    auto& programs = rootFunc.programs_;
    int cubeCount = 0;
    int vectorCount = 0;
    for (const auto& [progId, func] : programs) {
        // A fatal failure in another helper does not stop this one from running, so the entry must
        // be checked here before it is dereferenced.
        ASSERT_NE(func, nullptr) << "Program " << progId << " should not be null";
        auto leafAttr = func->GetLeafFuncAttribute();
        if (!leafAttr) {
            continue;
        }
        if (leafAttr->aivCore == AIVCore::UNSPECIFIED) {
            cubeCount++;
        } else if (leafAttr->aivCore == AIVCore::AIV0 || leafAttr->aivCore == AIVCore::AIV1) {
            vectorCount++;
        }
    }
    EXPECT_EQ(cubeCount, expectedCubeCount) << "Cube component count mismatch";
    EXPECT_EQ(vectorCount, expectedVectorCount) << "Vector component count mismatch";
}
/**
 * Verifies that the split removed the original artifacts and left a consistent root function.
 *
 * Expects that originalMixFunc is no longer a value of rootFunc.programs_, that originalCallOp is
 * no longer in the call-op list, and that every remaining program ID is referenced by the invoke
 * info of at least one call op.
 */
void VerifyCleanup(Function& rootFunc, Function* originalMixFunc, Operation* originalCallOp)
{
    auto& programMap = rootFunc.programs_;

    bool mixFuncLeftBehind = false;
    for (const auto& entry : programMap) {
        if (entry.second == originalMixFunc) {
            mixFuncLeftBehind = true;
            break;
        }
    }
    EXPECT_FALSE(mixFuncLeftBehind) << "Original mix function should be removed";

    auto callOps = rootFunc.GetCallopList();
    bool callOpLeftBehind = false;
    for (auto* op : callOps) {
        if (op == originalCallOp) {
            callOpLeftBehind = true;
            break;
        }
    }
    EXPECT_FALSE(callOpLeftBehind) << "Original callOp should be deleted";

    // Collect the program IDs that the surviving call ops point at.
    std::set<uint64_t> invokedIds;
    for (auto* op : callOps) {
        auto attr = dynamic_cast<CallOpAttribute*>(op->GetOpAttribute().get());
        if (attr && attr->invokeInfo_) {
            invokedIds.insert(attr->invokeInfo_->GetProgramId());
        }
    }

    for (const auto& entry : programMap) {
        const auto& id = entry.first;
        EXPECT_NE(invokedIds.find(id), invokedIds.end())
            << "Program ID " << id << " should have corresponding callOp";
    }
}
}
/**
 * Runs the full set of checks for a root function whose single Mix subgraph was split into one
 * cube and one vector component.
 *
 * The helpers use ASSERT_* internally, and a fatal assertion only returns from the helper that
 * raised it. Each call is therefore wrapped in ASSERT_NO_FATAL_FAILURE so that verification stops
 * at the first fatal failure instead of running the remaining helpers on a state already known to
 * be broken (for example a null program entry).
 *
 * @param status          result returned by MixSubgraphSplit::RunOnFunction
 * @param rootFunc        root function that was processed
 * @param originalMixFunc the mix function that existed before the split
 * @param originalCallOp  the call op that invoked the mix function before the split
 */
void VerifyBasicSplitResult(Status status, Function& rootFunc, Function* originalMixFunc, Operation* originalCallOp)
{
    using namespace test_utils;
    ASSERT_NO_FATAL_FAILURE(VerifyBasicChecks(status, rootFunc));
    ASSERT_NO_FATAL_FAILURE(VerifyProgramProperties(rootFunc));
    ASSERT_NO_FATAL_FAILURE(VerifyCallOpsAfterSplit(rootFunc));
    ASSERT_NO_FATAL_FAILURE(VerifyScopeTypes(rootFunc, 1, 1));
    ASSERT_NO_FATAL_FAILURE(VerifyCleanup(rootFunc, originalMixFunc, originalCallOp));
}
/**
 * Adds an OP_CALL with three inputs and one output to the root function and points it at the mix
 * program.
 *
 * The call attribute receives the callee hash, invoke info bound to mixProgramId and 36 constant
 * linear arguments (4 groups of 9 values, group g holding g*10 .. g*10+8).
 *
 * @return reference to the created call op
 */
Operation& CreateCallOp(
    std::shared_ptr<Function>& rootFuncPtr, const uint64_t mixProgramId, const FunctionHash& mixFuncHash)
{
    std::vector<int64_t> tensorShape = {MS_NUM16, MS_NUM16};
    auto newTensor = [&tensorShape]() {
        return npu::tile_fwk::IRBuilder().CreateTensorVar(
            DT_FP32, tensorShape, CreateTestConstIntVector(tensorShape));
    };

    auto firstInput = newTensor();
    auto secondInput = newTensor();
    auto thirdInput = newTensor();
    auto result = newTensor();
    auto& callOp = IRBuilder().CreateTensorOpStmt(
        *rootFuncPtr, Opcode::OP_CALL, {firstInput, secondInput, thirdInput}, {result});

    auto attr = std::make_shared<CallOpAttribute>();
    auto invocation = std::make_shared<SubfuncInvokeInfoTy>();
    invocation->UpdateProgramSubgraphId(mixProgramId);
    attr->SetCalleeHash(mixFuncHash);
    attr->invokeInfo_ = invocation;

    std::vector<SymbolicScalar> args;
    for (int group = 0; group < 4; ++group) {
        const int base = group * 10;
        for (int slot = 0; slot < 9; ++slot) {
            args.push_back(IRBuilder().CreateConstInt(static_cast<int64_t>(slot + base)));
        }
    }
    attr->linearArgList_ = args;

    callOp.SetOpAttribute(attr);
    callOp.UpdateSubgraphID(mixProgramId);
    return callOp;
}
/**
 * Appends the vector part (internal subgraph 1, AIV0) to the mix function.
 *
 * Takes the first output of the function's last existing operation, adds incast3 to it, applies
 * SQRT and copies the result out to outcast1. The function must already contain at least one
 * operation.
 */
void CreateVectorScope(
    std::shared_ptr<Function>& mixFuncPtr, std::shared_ptr<LogicalTensor>& incast3,
    std::shared_ptr<LogicalTensor>& outcast1)
{
    std::vector<int64_t> tensorShape = {MS_NUM16, MS_NUM16};

    // Output of the most recently created op feeds the vector chain.
    auto existingOps = mixFuncPtr->Operations(false);
    const auto lastIndex = existingOps.size() - 1;
    auto cubeResult = existingOps[lastIndex].GetOOperands()[0];

    auto sumTensor =
        npu::tile_fwk::IRBuilder().CreateTensorVar(DT_FP32, tensorShape, CreateTestConstIntVector(tensorShape));
    auto rootTensor =
        npu::tile_fwk::IRBuilder().CreateTensorVar(DT_FP32, tensorShape, CreateTestConstIntVector(tensorShape));

    auto shapeArg = OpImmediate::Specified(tensorShape);
    std::vector<int64_t> zeroOffset = {0, 0};
    auto offsetArg = OpImmediate::Specified(zeroOffset);
    std::vector<OpImmediate> noExtraArgs;

    auto& addOp = IRBuilder().CreateTensorOpStmt(*mixFuncPtr, Opcode::OP_ADD, {cubeResult, incast3}, {sumTensor});
    addOp.SetIOpAtt(1, 5);
    addOp.UpdateInternalSubgraphID(1);
    addOp.SetAIVCore(AIVCore::AIV0);

    auto& sqrtOp = IRBuilder().CreateTensorOpStmt(*mixFuncPtr, Opcode::OP_SQRT, {sumTensor}, {rootTensor});
    sqrtOp.UpdateInternalSubgraphID(1);
    sqrtOp.SetAIVCore(AIVCore::AIV0);

    auto& storeOp = IRBuilder().CreateTensorOpStmt(*mixFuncPtr, Opcode::OP_COPY_OUT, {rootTensor}, {outcast1});
    storeOp.SetOpAttribute(
        std::make_shared<CopyOpAttribute>(MemoryType::MEM_UB, offsetArg, shapeArg, shapeArg, noExtraArgs));
    storeOp.SetOOpAtt(0, 0);
    storeOp.UpdateInternalSubgraphID(1);
    storeOp.SetAIVCore(AIVCore::AIV0);
}
/**
 * Appends the cube part (internal subgraph 0, tagged isCube) to the mix function.
 *
 * Copies incast1 and incast2 in and multiplies the two copies with OP_A_MUL_B. The multiply is the
 * last op created here, so its output is what CreateVectorScope picks up afterwards.
 */
void CreateCubeScope(
    std::shared_ptr<Function>& mixFuncPtr, std::shared_ptr<LogicalTensor>& incast1,
    std::shared_ptr<LogicalTensor>& incast2)
{
    std::vector<int64_t> tensorShape = {MS_NUM16, MS_NUM16};
    auto newTensor = [&tensorShape]() {
        return npu::tile_fwk::IRBuilder().CreateTensorVar(
            DT_FP32, tensorShape, CreateTestConstIntVector(tensorShape));
    };
    auto lhs = newTensor();
    auto rhs = newTensor();
    auto product = newTensor();

    auto shapeArg = OpImmediate::Specified(tensorShape);
    std::vector<int64_t> zeroOffset = {0, 0};
    auto offsetArg = OpImmediate::Specified(zeroOffset);
    std::vector<OpImmediate> noExtraArgs;

    auto& loadLhs = IRBuilder().CreateTensorOpStmt(*mixFuncPtr, Opcode::OP_COPY_IN, {incast1}, {lhs});
    loadLhs.SetOpAttribute(
        std::make_shared<CopyOpAttribute>(offsetArg, MemoryType::MEM_UB, shapeArg, shapeArg, noExtraArgs));
    loadLhs.SetIOpAtt(0, 0);
    loadLhs.UpdateInternalSubgraphID(0);
    loadLhs.SetAttr(OpAttributeKey::isCube, true);

    auto& loadRhs = IRBuilder().CreateTensorOpStmt(*mixFuncPtr, Opcode::OP_COPY_IN, {incast2}, {rhs});
    loadRhs.SetOpAttribute(
        std::make_shared<CopyOpAttribute>(offsetArg, MemoryType::MEM_UB, shapeArg, shapeArg, noExtraArgs));
    loadRhs.SetIOpAtt(0, 0);
    loadRhs.UpdateInternalSubgraphID(0);
    loadRhs.SetAttr(OpAttributeKey::isCube, true);

    auto& multiply = IRBuilder().CreateTensorOpStmt(*mixFuncPtr, Opcode::OP_A_MUL_B, {lhs, rhs}, {product});
    multiply.UpdateInternalSubgraphID(0);
    multiply.SetAttr(OpAttributeKey::isCube, true);
}
/**
 * Fills the mix function with three incasts, one outcast, a cube scope and a vector scope, then
 * computes its hash, stores the hash in mixFuncHash and registers the function in the program's
 * function cache.
 */
void SetupMixSubgraphStructure(std::shared_ptr<Function>& mixFuncPtr, FunctionHash& mixFuncHash)
{
    std::vector<int64_t> tensorShape = {MS_NUM16, MS_NUM16};
    auto newTensor = [&tensorShape]() {
        return npu::tile_fwk::IRBuilder().CreateTensorVar(
            DT_FP32, tensorShape, CreateTestConstIntVector(tensorShape));
    };
    auto cubeInputA = newTensor();
    auto cubeInputB = newTensor();
    auto vectorInput = newTensor();
    auto finalOutput = newTensor();

    auto& incasts = mixFuncPtr->inCasts_;
    incasts.push_back(cubeInputA);
    incasts.push_back(cubeInputB);
    incasts.push_back(vectorInput);
    mixFuncPtr->outCasts_.push_back(finalOutput);

    // The cube scope must be created first: the vector scope consumes the last op's output.
    CreateCubeScope(mixFuncPtr, cubeInputA, cubeInputB);
    CreateVectorScope(mixFuncPtr, vectorInput, finalOutput);

    mixFuncPtr->ComputeHash();
    mixFuncHash = mixFuncPtr->GetFunctionHash();
    Program::GetInstance().GetFunctionCache().Insert(mixFuncHash, *mixFuncPtr);
}
std::shared_ptr<Function> CreateMixSubgraph(std::shared_ptr<Function>& rootFuncPtr, FunctionHash& mixFuncHash)
{
auto mixFuncPtr =
std::make_shared<Function>(Program::GetInstance(), "test_mix_func", "test_mix_func", rootFuncPtr.get());
mixFuncPtr->SetGraphType(GraphType::BLOCK_GRAPH);
mixFuncPtr->SetFunctionType(FunctionType::STATIC);
const uint64_t mixProgramId = 100;
rootFuncPtr->programs_[mixProgramId] = mixFuncPtr.get();
SetupMixSubgraphStructure(mixFuncPtr, mixFuncHash);
return mixFuncPtr;
}
std::shared_ptr<Function> CreateTestRootFunction()
{
auto rootFuncPtr = std::make_shared<Function>(Program::GetInstance(), "test_root", "test_root", nullptr);
rootFuncPtr->rootFunc_ = rootFuncPtr.get();
return rootFuncPtr;
}
/**
 * Splits a root function that calls one Mix subgraph (one cube scope, one vector scope) and runs
 * the basic result verification on it.
 */
TEST_F(MixSubgraphSplitTest, TestSingleMixSubgraphBasicSplit)
{
    const uint64_t kMixProgramId = 100;

    auto root = CreateTestRootFunction();
    FunctionHash calleeHash;
    auto mixFunc = CreateMixSubgraph(root, calleeHash);
    auto& mixCall = CreateCallOp(root, kMixProgramId, calleeHash);

    MixSubgraphSplit pass;
    const Status result = pass.RunOnFunction(*root);

    VerifyBasicSplitResult(result, *root, mixFunc.get(), &mixCall);
}
void CreateCallOpForNonMix(
std::shared_ptr<Function>& rootFuncPtr, uint64_t programIdx, FunctionHash hash, const std::vector<int64_t>& shape)
{
auto callInTensor = npu::tile_fwk::IRBuilder().CreateTensorVar(DT_FP32, shape, CreateTestConstIntVector(shape));
auto callOutTensor = npu::tile_fwk::IRBuilder().CreateTensorVar(DT_FP32, shape, CreateTestConstIntVector(shape));
auto& callOp = IRBuilder().CreateTensorOpStmt(*rootFuncPtr, Opcode::OP_CALL, {callInTensor}, {callOutTensor});
auto callAttr = std::make_shared<CallOpAttribute>();
auto invokeInfo = std::make_shared<SubfuncInvokeInfoTy>();
invokeInfo->UpdateProgramSubgraphId(programIdx);
callAttr->SetCalleeHash(hash);
callAttr->invokeInfo_ = invokeInfo;
std::vector<SymbolicScalar> linearArgs;
for (int argIdx = 0; argIdx < 2; argIdx++) {
for (int j = 0; j < 9; j++) {
linearArgs.push_back(IRBuilder().CreateConstInt(static_cast<int64_t>(j + argIdx * 10)));
}
}
callAttr->linearArgList_ = linearArgs;
callOp.SetOpAttribute(callAttr);
callOp.UpdateSubgraphID(programIdx);
}
void CreateNonMixFunctions(
std::shared_ptr<Function>& rootFuncPtr, std::vector<std::shared_ptr<Function>>& nonMixFunctions,
const std::vector<uint64_t>& nonMixProgramIds)
{
for (int i = 0; i < 2; i++) {
auto nonMixFunc = std::make_shared<Function>(
Program::GetInstance(), "test_non_mix_" + std::to_string(i), "test_non_mix_" + std::to_string(i),
rootFuncPtr.get());
nonMixFunc->SetGraphType(GraphType::BLOCK_GRAPH);
nonMixFunc->SetFunctionType(FunctionType::STATIC);
std::vector<int64_t> shape = {8, 8};
auto incastTensor = npu::tile_fwk::IRBuilder().CreateTensorVar(DT_FP32, shape, CreateTestConstIntVector(shape));
auto outcastTensor = npu::tile_fwk::IRBuilder().CreateTensorVar(DT_FP32, shape, CreateTestConstIntVector(shape));
nonMixFunc->inCasts_.push_back(incastTensor);
nonMixFunc->outCasts_.push_back(outcastTensor);
auto internalTensor1 = npu::tile_fwk::IRBuilder().CreateTensorVar(DT_FP32, shape, CreateTestConstIntVector(shape));
auto& copyInOp = IRBuilder().CreateTensorOpStmt(*nonMixFunc, Opcode::OP_COPY_IN, {incastTensor}, {internalTensor1});
copyInOp.SetIOpAtt(0, 0);
auto shapeImme = OpImmediate::Specified(shape);
std::vector<int64_t> offsetVec = {0, 0};
auto offsetImme = OpImmediate::Specified(offsetVec);
std::vector<OpImmediate> emptyVec;
auto copyInAttr =
std::make_shared<CopyOpAttribute>(offsetImme, MemoryType::MEM_UB, shapeImme, shapeImme, emptyVec);
copyInOp.SetOpAttribute(copyInAttr);
auto internalTensor2 = npu::tile_fwk::IRBuilder().CreateTensorVar(DT_FP32, shape, CreateTestConstIntVector(shape));
auto& expOp = IRBuilder().CreateTensorOpStmt(*nonMixFunc, Opcode::OP_EXP, {internalTensor1}, {internalTensor2});
(void)expOp;
auto& copyOutOp = IRBuilder().CreateTensorOpStmt(*nonMixFunc, Opcode::OP_COPY_OUT, {internalTensor2}, {outcastTensor});
copyOutOp.SetOOpAtt(0, 0);
auto copyOutAttr =
std::make_shared<CopyOpAttribute>(MemoryType::MEM_UB, offsetImme, shapeImme, shapeImme, emptyVec);
copyOutOp.SetOpAttribute(copyOutAttr);
nonMixFunc->ComputeHash();
FunctionHash hash = nonMixFunc->GetFunctionHash();
Program::GetInstance().GetFunctionCache().Insert(hash, *nonMixFunc);
rootFuncPtr->programs_[nonMixProgramIds[i]] = nonMixFunc.get();
nonMixFunctions.push_back(nonMixFunc);
CreateCallOpForNonMix(rootFuncPtr, nonMixProgramIds[i], hash, shape);
}
}
void CreateCallOpsForMixFunction(
std::shared_ptr<Function>& rootFuncPtr, uint64_t programIdx, FunctionHash hash, const std::vector<int64_t>& shape,
int mixIdx)
{
int callOpCount = (mixIdx % 2 == 0) ? 1 : 2;
for (int callIdx = 0; callIdx < callOpCount; callIdx++) {
auto callInTensor1 = npu::tile_fwk::IRBuilder().CreateTensorVar(DT_FP32, shape, CreateTestConstIntVector(shape));
auto callInTensor2 = npu::tile_fwk::IRBuilder().CreateTensorVar(DT_FP32, shape, CreateTestConstIntVector(shape));
auto callOutTensor = npu::tile_fwk::IRBuilder().CreateTensorVar(DT_FP32, shape, CreateTestConstIntVector(shape));
auto& callOp = IRBuilder().CreateTensorOpStmt(*rootFuncPtr, Opcode::OP_CALL, {callInTensor1, callInTensor2}, {callOutTensor});
auto callAttr = std::make_shared<CallOpAttribute>();
auto invokeInfo = std::make_shared<SubfuncInvokeInfoTy>();
invokeInfo->UpdateProgramSubgraphId(programIdx);
callAttr->SetCalleeHash(hash);
callAttr->invokeInfo_ = invokeInfo;
std::vector<SymbolicScalar> linearArgs;
for (int i = 0; i < 3; i++) {
for (int j = 0; j < 9; j++) {
linearArgs.push_back(IRBuilder().CreateConstInt(static_cast<int64_t>(j + i * 10)));
}
}
callAttr->linearArgList_ = linearArgs;
callOp.SetOpAttribute(callAttr);
callOp.UpdateSubgraphID(programIdx);
}
}
void CreateAdditionalScopes(std::shared_ptr<Function>& mixFunc, int componentCount, const std::vector<int64_t>& shape)
{
auto operations = mixFunc->Operations(false);
auto lastTensor = operations.back().GetOOperands()[0];
for (int compIdx = 1; compIdx < componentCount; compIdx++) {
auto newTensor = npu::tile_fwk::IRBuilder().CreateTensorVar(DT_FP32, shape, CreateTestConstIntVector(shape));
Opcode opcode = (compIdx % 2 == 0) ? Opcode::OP_NEG : Opcode::OP_SQRT;
if (compIdx % 2 == 0) {
auto& cubeOp = IRBuilder().CreateTensorOpStmt(*mixFunc, opcode, {lastTensor}, {newTensor});
cubeOp.UpdateInternalSubgraphID(compIdx);
cubeOp.SetAttr(OpAttributeKey::isCube, true);
} else {
auto& vectorOp = IRBuilder().CreateTensorOpStmt(*mixFunc, opcode, {lastTensor}, {newTensor});
vectorOp.UpdateInternalSubgraphID(compIdx);
vectorOp.SetAIVCore(AIVCore::AIV0);
}
lastTensor = newTensor;
}
auto& copyOut = IRBuilder().CreateTensorOpStmt(*mixFunc, Opcode::OP_COPY_OUT, {lastTensor}, {mixFunc->outCasts_[0]});
copyOut.UpdateInternalSubgraphID(componentCount - 1);
copyOut.SetOOpAtt(0, 0);
auto shapeImme = OpImmediate::Specified(shape);
std::vector<int64_t> offsetVec = {0, 0};
auto offsetImme = OpImmediate::Specified(offsetVec);
std::vector<OpImmediate> emptyVec;
auto copyOutAttr =
std::make_shared<CopyOpAttribute>(MemoryType::MEM_UB, offsetImme, shapeImme, shapeImme, emptyVec);
copyOut.SetOpAttribute(copyOutAttr);
if ((componentCount - 1) % 2 == 0) {
copyOut.SetAttr(OpAttributeKey::isCube, true);
} else {
copyOut.SetAIVCore(AIVCore::AIV0);
}
}
void CreateMixFunctions(
std::shared_ptr<Function>& rootFuncPtr, std::vector<std::shared_ptr<Function>>& mixFunctions,
const std::vector<uint64_t>& mixProgramIds, const std::vector<int>& componentCounts)
{
for (int mixIdx = 0; mixIdx < 3; mixIdx++) {
auto mixFunc = std::make_shared<Function>(
Program::GetInstance(), "test_mix_" + std::to_string(mixIdx), "test_mix_" + std::to_string(mixIdx),
rootFuncPtr.get());
mixFunc->SetGraphType(GraphType::BLOCK_GRAPH);
mixFunc->SetFunctionType(FunctionType::STATIC);
std::vector<int64_t> shape = {16, 16};
auto incast1 = npu::tile_fwk::IRBuilder().CreateTensorVar(DT_FP32, shape, CreateTestConstIntVector(shape));
auto incast2 = npu::tile_fwk::IRBuilder().CreateTensorVar(DT_FP32, shape, CreateTestConstIntVector(shape));
auto outcast = npu::tile_fwk::IRBuilder().CreateTensorVar(DT_FP32, shape, CreateTestConstIntVector(shape));
mixFunc->inCasts_.push_back(incast1);
mixFunc->inCasts_.push_back(incast2);
mixFunc->outCasts_.push_back(outcast);
auto cubeTensor1 = npu::tile_fwk::IRBuilder().CreateTensorVar(DT_FP32, shape, CreateTestConstIntVector(shape));
auto cubeTensor2 = npu::tile_fwk::IRBuilder().CreateTensorVar(DT_FP32, shape, CreateTestConstIntVector(shape));
auto cubeTensor3 = npu::tile_fwk::IRBuilder().CreateTensorVar(DT_FP32, shape, CreateTestConstIntVector(shape));
auto cubeOutput = npu::tile_fwk::IRBuilder().CreateTensorVar(DT_FP32, shape, CreateTestConstIntVector(shape));
auto shapeImme = OpImmediate::Specified(shape);
std::vector<int64_t> offsetVec = {0, 0};
auto offsetImme = OpImmediate::Specified(offsetVec);
std::vector<OpImmediate> emptyVec;
auto& copyIn1 = IRBuilder().CreateTensorOpStmt(*mixFunc, Opcode::OP_COPY_IN, {incast1}, {cubeTensor1});
copyIn1.UpdateInternalSubgraphID(0);
copyIn1.SetIOpAtt(0, 0);
copyIn1.SetAttr(OpAttributeKey::isCube, true);
auto copyIn1Attr =
std::make_shared<CopyOpAttribute>(offsetImme, MemoryType::MEM_UB, shapeImme, shapeImme, emptyVec);
copyIn1.SetOpAttribute(copyIn1Attr);
auto& copyIn2 = IRBuilder().CreateTensorOpStmt(*mixFunc, Opcode::OP_COPY_IN, {incast2}, {cubeTensor2});
copyIn2.UpdateInternalSubgraphID(0);
copyIn2.SetIOpAtt(0, 0);
copyIn2.SetAttr(OpAttributeKey::isCube, true);
auto copyIn2Attr =
std::make_shared<CopyOpAttribute>(offsetImme, MemoryType::MEM_UB, shapeImme, shapeImme, emptyVec);
copyIn2.SetOpAttribute(copyIn2Attr);
auto& cubeMul = IRBuilder().CreateTensorOpStmt(*mixFunc, Opcode::OP_A_MUL_B, {cubeTensor1, cubeTensor2}, {cubeTensor3});
cubeMul.UpdateInternalSubgraphID(0);
cubeMul.SetAttr(OpAttributeKey::isCube, true);
auto& cubeExp = IRBuilder().CreateTensorOpStmt(*mixFunc, Opcode::OP_EXP, {cubeTensor3}, {cubeOutput});
cubeExp.UpdateInternalSubgraphID(0);
cubeExp.SetAttr(OpAttributeKey::isCube, true);
CreateAdditionalScopes(mixFunc, componentCounts[mixIdx], shape);
mixFunc->ComputeHash();
FunctionHash hash = mixFunc->GetFunctionHash();
Program::GetInstance().GetFunctionCache().Insert(hash, *mixFunc);
rootFuncPtr->programs_[mixProgramIds[mixIdx]] = mixFunc.get();
mixFunctions.push_back(mixFunc);
CreateCallOpsForMixFunction(rootFuncPtr, mixProgramIds[mixIdx], hash, shape, mixIdx);
}
}
void VerifyMultipleMixSplitResults(
std::shared_ptr<Function>& rootFuncPtr, Status status, const std::vector<std::shared_ptr<Function>>& mixFunctions,
const std::vector<std::shared_ptr<Function>>& nonMixFunctions, const std::vector<int>& componentCounts)
{
EXPECT_EQ(status, SUCCESS) << "Multiple mix subgraphs split should succeed";
size_t expectedNewProgramCount = 2;
for (int count : componentCounts) {
expectedNewProgramCount += count;
}
auto& programs = rootFuncPtr->programs_;
EXPECT_EQ(programs.size(), expectedNewProgramCount)
<< "Program count mismatch. Expected: " << expectedNewProgramCount << ", Actual: " << programs.size();
uint64_t expectedMaxId = expectedNewProgramCount - 1;
for (uint64_t i = 0; i <= expectedMaxId; i++) {
EXPECT_NE(programs.find(i), programs.end()) << "Missing continuous program ID: " << i;
}
for (size_t i = 0; i < nonMixFunctions.size(); i++) {
bool found = false;
for (const auto& [progId, func] : programs) {
if (func == nonMixFunctions[i].get()) {
EXPECT_LT(progId, 2) << "Non-mix function should have ID < 2";
found = true;
break;
}
}
EXPECT_TRUE(found) << "Non-mix function " << i << " not found after split";
}
int totalSplitFunctions = 0;
for (const auto& [progId, func] : programs) {
(void)progId;
auto leafAttr = func->GetLeafFuncAttribute();
if (leafAttr && leafAttr->mixId != -1) {
totalSplitFunctions++;
}
}
EXPECT_EQ(totalSplitFunctions, 2 + 3 + 4)
<< "Should have " << (2 + 3 + 4) << " split functions from 3 mix subgraphs";
auto newCallOps = rootFuncPtr->GetCallopList();
size_t expectedNewCallOpCount = 2 * 1 + 1 * 2 + 2 * 3 + 1 * 4;
EXPECT_EQ(newCallOps.size(), expectedNewCallOpCount)
<< "CallOp count mismatch. Expected: " << expectedNewCallOpCount << ", Actual: " << newCallOps.size();
for (const auto& mixFunc : mixFunctions) {
bool stillExists = false;
for (const auto& [progId, func] : programs) {
(void)progId;
if (func == mixFunc.get()) {
stillExists = true;
break;
}
}
EXPECT_FALSE(stillExists) << "Original mix function should be removed";
}
}
* 测试多个Mix子图拆分处理
*/
TEST_F(MixSubgraphSplitTest, TestMultipleMixSubgraphsSplit)
{
auto rootFuncPtr =
std::make_shared<Function>(Program::GetInstance(), "test_root_multi", "test_root_multi", nullptr);
rootFuncPtr->rootFunc_ = rootFuncPtr.get();
std::vector<std::shared_ptr<Function>> mixFunctions;
std::vector<std::shared_ptr<Function>> nonMixFunctions;
std::vector<uint64_t> mixProgramIds = {100, 101, 102};
std::vector<uint64_t> nonMixProgramIds = {200, 201};
CreateNonMixFunctions(rootFuncPtr, nonMixFunctions, nonMixProgramIds);
std::vector<int> componentCounts = {2, 3, 4};
CreateMixFunctions(rootFuncPtr, mixFunctions, mixProgramIds, componentCounts);
MixSubgraphSplit splitter;
Status status = splitter.RunOnFunction(*rootFuncPtr);
VerifyMultipleMixSplitResults(rootFuncPtr, status, mixFunctions, nonMixFunctions, componentCounts);
}
* 测试rootFunction无Mix子图时的处理逻辑
*/
TEST_F(MixSubgraphSplitTest, TestNoMixSubgraphScenario)
{
auto rootFuncPtr =
std::make_shared<Function>(Program::GetInstance(), "test_root_no_mix", "test_root_no_mix", nullptr);
rootFuncPtr->rootFunc_ = rootFuncPtr.get();
std::vector<std::shared_ptr<Function>> nonMixFunctions;
std::vector<uint64_t> programIds = {10, 20, 30};
for (int i = 0; i < 3; i++) {
auto func = std::make_shared<Function>(
Program::GetInstance(), "test_func_" + std::to_string(i), "test_func_" + std::to_string(i),
rootFuncPtr.get());
func->SetGraphType(GraphType::BLOCK_GRAPH);
func->SetFunctionType(FunctionType::STATIC);
std::vector<int64_t> shape = {8, 8};
auto inputTensor = npu::tile_fwk::IRBuilder().CreateTensorVar(DT_FP32, shape, CreateTestConstIntVector(shape));
auto outputTensor = npu::tile_fwk::IRBuilder().CreateTensorVar(DT_FP32, shape, CreateTestConstIntVector(shape));
func->inCasts_.push_back(inputTensor);
func->outCasts_.push_back(outputTensor);
auto& expOp = IRBuilder().CreateTensorOpStmt(*func, Opcode::OP_EXP, {inputTensor}, {outputTensor});
(void)expOp;
func->ComputeHash();
FunctionHash hash = func->GetFunctionHash();
Program::GetInstance().GetFunctionCache().Insert(hash, *func);
rootFuncPtr->programs_[programIds[i]] = func.get();
nonMixFunctions.push_back(func);
auto& callOp = IRBuilder().CreateTensorOpStmt(*rootFuncPtr, Opcode::OP_CALL, {}, {});
auto callAttr = std::make_shared<CallOpAttribute>();
auto invokeInfo = std::make_shared<SubfuncInvokeInfoTy>();
invokeInfo->UpdateProgramSubgraphId(programIds[i]);
callAttr->SetCalleeHash(hash);
callAttr->invokeInfo_ = invokeInfo;
callOp.SetOpAttribute(callAttr);
}
auto originalPrograms = rootFuncPtr->programs_;
auto originalCallOps = rootFuncPtr->GetCallopList();
size_t originalProgramCount = originalPrograms.size();
size_t originalCallOpCount = originalCallOps.size();
MixSubgraphSplit splitter;
Status status = splitter.RunOnFunction(*rootFuncPtr);
EXPECT_EQ(status, SUCCESS) << "Should succeed even with no mix subgraphs";
for (const auto& func : nonMixFunctions) {
bool isMix = splitter.IsMixSubgraph(*func);
EXPECT_FALSE(isMix) << "Non-mix function should not be identified as mix";
}
auto& newPrograms = rootFuncPtr->programs_;
EXPECT_EQ(newPrograms.size(), originalProgramCount) << "Program count should not change when no mix subgraphs";
auto newCallOps = rootFuncPtr->GetCallopList();
EXPECT_EQ(newCallOps.size(), originalCallOpCount) << "CallOp count should not change when no mix subgraphs";
}
void CreateExternalMixFunction(
std::shared_ptr<Function>& externalRootFuncPtr, std::shared_ptr<Function>& externalMixFuncPtr)
{
externalRootFuncPtr = std::make_shared<Function>(Program::GetInstance(), "external_root", "external_root", nullptr);
externalRootFuncPtr->rootFunc_ = externalRootFuncPtr.get();
const uint64_t externalMixProgramId = 999;
externalMixFuncPtr =
std::make_shared<Function>(Program::GetInstance(), "external_mix", "external_mix", externalRootFuncPtr.get());
externalMixFuncPtr->SetGraphType(GraphType::BLOCK_GRAPH);
externalMixFuncPtr->SetFunctionType(FunctionType::STATIC);
externalRootFuncPtr->programs_[externalMixProgramId] = externalMixFuncPtr.get();
std::vector<int64_t> shape = {16, 16};
auto incast1 = npu::tile_fwk::IRBuilder().CreateTensorVar(DT_FP32, shape, CreateTestConstIntVector(shape));
auto incast2 = npu::tile_fwk::IRBuilder().CreateTensorVar(DT_FP32, shape, CreateTestConstIntVector(shape));
auto outcast1 = npu::tile_fwk::IRBuilder().CreateTensorVar(DT_FP32, shape, CreateTestConstIntVector(shape));
externalMixFuncPtr->inCasts_.push_back(incast1);
externalMixFuncPtr->inCasts_.push_back(incast2);
externalMixFuncPtr->outCasts_.push_back(outcast1);
for (int compIdx = 0; compIdx < 3; compIdx++) {
auto inputTensor = npu::tile_fwk::IRBuilder().CreateTensorVar(DT_FP32, shape, CreateTestConstIntVector(shape));
auto outputTensor = npu::tile_fwk::IRBuilder().CreateTensorVar(DT_FP32, shape, CreateTestConstIntVector(shape));
Opcode opcode = Opcode::OP_EXP;
auto& op = IRBuilder().CreateTensorOpStmt(*externalMixFuncPtr, opcode, {inputTensor}, {outputTensor});
op.UpdateInternalSubgraphID(compIdx);
if (compIdx % 2 == 0) {
op.SetAttr(OpAttributeKey::isCube, true);
} else {
op.SetAIVCore((compIdx == 1) ? AIVCore::AIV0 : AIVCore::AIV1);
}
}
externalMixFuncPtr->ComputeHash();
FunctionHash mixFuncHash = externalMixFuncPtr->GetFunctionHash();
Program::GetInstance().GetFunctionCache().Insert(mixFuncHash, *externalMixFuncPtr);
}
// Verifies the state after MixSubgraphSplit was run on a root function whose single call op
// targets a mix function registered under a different (external) root function.
//
// Checks performed:
//   - the pass returned FAILED;
//   - the external function is still classified as a mix subgraph by the splitter;
//   - the tested root function has no programs and still exactly one call op, which is the
//     original cross-function call op;
//   - the external root still holds exactly one program, stored under id 999 and pointing at
//     the unchanged external mix function.
//
// splitter            pass instance that was run (queried via IsMixSubgraph).
// rootFuncPtr         root function the pass was run on.
// externalRootFuncPtr root function that owns the external mix function.
// externalMixFuncPtr  the external mix function itself.
// crossCallOp         the call op that was inserted into rootFuncPtr before the pass ran.
// status              status returned by the pass.
void VerifyCrossFunctionResults(
    MixSubgraphSplit& splitter, std::shared_ptr<Function>& rootFuncPtr, std::shared_ptr<Function>& externalRootFuncPtr,
    std::shared_ptr<Function>& externalMixFuncPtr, Operation* crossCallOp, Status status)
{
    EXPECT_EQ(status, FAILED) << "Fresh external mix function without preceding processing should fail";
    bool isMix = splitter.IsMixSubgraph(*externalMixFuncPtr);
    EXPECT_TRUE(isMix) << "External mix function should be identified as mix";

    auto& programs = rootFuncPtr->programs_;
    EXPECT_EQ(programs.size(), 0);
    auto newCallOps = rootFuncPtr->GetCallopList();
    EXPECT_EQ(newCallOps.size(), 1) << "Should only have original call op on failure";

    auto& externalPrograms = externalRootFuncPtr->programs_;
    EXPECT_EQ(externalPrograms.size(), 1) << "External root function's programs should remain unchanged";
    // 999 must match the program id used by CreateExternalMixFunction.
    auto it = externalPrograms.find(999);
    EXPECT_NE(it, externalPrograms.end()) << "External mix function should still exist in external root";
    // EXPECT_NE is non-fatal, so the iterator must be re-checked before it is dereferenced;
    // dereferencing end() would be undefined behaviour instead of a clean test failure.
    if (it != externalPrograms.end()) {
        EXPECT_EQ(it->second, externalMixFuncPtr.get()) << "External mix function pointer should be unchanged";
    }

    const bool originalCallOpExists =
        std::find(newCallOps.begin(), newCallOps.end(), crossCallOp) != newCallOps.end();
    EXPECT_TRUE(originalCallOpExists) << "Original cross-function callOp should still exist on failure";
}
* Tests the special handling applied when a mix subgraph is called across functions.
*/
TEST_F(MixSubgraphSplitTest, TestCrossFunctionMixSubgraph)
{
    // Root function the pass is run on; it has no parent and is its own root.
    auto callerRoot = std::make_shared<Function>(Program::GetInstance(), "test_root_cross", "test_root_cross", nullptr);
    callerRoot->rootFunc_ = callerRoot.get();

    // Mix function that lives under a different root and is reachable only through the function cache.
    std::shared_ptr<Function> foreignRoot;
    std::shared_ptr<Function> foreignMix;
    CreateExternalMixFunction(foreignRoot, foreignMix);
    FunctionHash calleeHash = foreignMix->GetFunctionHash();

    // Single call op in the caller root whose callee hash refers to the foreign mix function.
    auto& callStmt = IRBuilder().CreateTensorOpStmt(*callerRoot, Opcode::OP_CALL, {}, {});
    auto callAttr = std::make_shared<CallOpAttribute>();
    auto invokeInfo = std::make_shared<SubfuncInvokeInfoTy>();
    callAttr->SetCalleeHash(calleeHash);
    callAttr->invokeInfo_ = invokeInfo;
    callStmt.SetOpAttribute(callAttr);

    MixSubgraphSplit splitPass;
    Status runStatus = splitPass.RunOnFunction(*callerRoot);

    VerifyCrossFunctionResults(splitPass, callerRoot, foreignRoot, foreignMix, &callStmt, runStatus);
}
// Exercises the depend-operand / depend-op bookkeeping between operations and tensors and its
// effect on Function::SortOperations:
//   1. adding t4 as a depend operand of Copyin2 is expected to sort Alloc2 before Copyin2;
//   2. removing the dependency again is expected to sort Alloc2 after Copyin2;
//   3. erasing a deleted op is expected to drop it from the tensor's depend-op set.
// Preconditions that are dereferenced afterwards use ASSERT_* so that a broken setup stops the
// test instead of running into undefined behaviour.
TEST_F(MixSubgraphSplitTest, TestDependOperand)
{
    ComputationalGraphBuilder subGraph;
    std::vector<std::string> tensorNames{"t1", "t2", "t3", "t4", "t5", "t6"};
    std::vector<MemoryType> tensorMemTypes{MemoryType::MEM_DEVICE_DDR, MemoryType::MEM_DEVICE_DDR, MemoryType::MEM_UB,
                                           MemoryType::MEM_UB,         MemoryType::MEM_UB,         MemoryType::MEM_UB};
    std::vector<Opcode> opCodes{Opcode::OP_UB_ALLOC, Opcode::OP_UB_ALLOC, Opcode::OP_UB_ALLOC, Opcode::OP_UB_ALLOC,
                                Opcode::OP_COPY_IN,  Opcode::OP_COPY_IN,  Opcode::OP_ADD};
    // The i-th entries of opCodes / ioperands / ooperands / opNames are passed together to AddOps.
    std::vector<std::vector<std::string>> ioperands{{}, {}, {}, {}, {"t1"}, {"t2"}, {"t4", "t5"}};
    std::vector<std::vector<std::string>> ooperands{{"t3"}, {"t4"}, {"t5"}, {"t6"}, {"t3", "t4"}, {"t5"}, {"t6"}};
    std::vector<std::string> opNames{"Alloc1", "Alloc2", "Alloc3", "Alloc4", "Copyin1", "Copyin2", "Add1"};
    ASSERT_EQ(subGraph.AddTensors(DataType::DT_FP32, {128, 128}, tensorMemTypes, tensorNames, 0), true);
    ASSERT_EQ(subGraph.AddOps(opCodes, ioperands, ooperands, opNames, true), true);

    // Everything looked up here is dereferenced below, so a missing entry must abort the test.
    Function* function = subGraph.GetFunction();
    ASSERT_TRUE(function != nullptr);
    Operation* copyin2 = subGraph.GetOp("Copyin2");
    ASSERT_TRUE(copyin2 != nullptr);
    Operation* add1 = subGraph.GetOp("Add1");
    ASSERT_TRUE(add1 != nullptr);
    Operation* alloc2 = subGraph.GetOp("Alloc2");
    ASSERT_TRUE(alloc2 != nullptr);
    std::shared_ptr<LogicalTensor> tensor4 = subGraph.GetTensor("t4");
    ASSERT_TRUE(tensor4 != nullptr);
    std::shared_ptr<LogicalTensor> tensor3 = subGraph.GetTensor("t3");
    ASSERT_TRUE(tensor3 != nullptr);

    // Operation side: register depend operands.
    copyin2->AddDependOperand(tensor4);
    add1->AddDependOperand(tensor3);
    // The size is asserted first because front() on an empty container is undefined behaviour.
    ASSERT_EQ(copyin2->GetDependOperandSize(), 1);
    EXPECT_EQ(copyin2->GetDependOperands().front()->GetMagic(), MS_NUM3);

    // Tensor side: adding the same op twice is expected to leave a single entry.
    tensor4->AddDependOp(copyin2);
    tensor4->AddDependOp(copyin2);
    ASSERT_EQ(tensor4->GetDependOps().size(), 1);
    tensor3->AddDependOp(add1);
    auto dependOp = *(tensor4->GetDependOps().begin());
    ASSERT_TRUE(dependOp != nullptr);
    EXPECT_EQ(dependOp->GetOpMagic(), MS_NUM10005);
    EXPECT_EQ(tensor4->HasDependOp(copyin2), true);

    // With the dependency in place Alloc2 must be ordered before Copyin2.
    function->SortOperations();
    auto sortedOpList = function->Operations().DuplicatedOpList();
    auto alloc2Iter = std::find(sortedOpList.begin(), sortedOpList.end(), alloc2);
    auto copyin2Iter = std::find(sortedOpList.begin(), sortedOpList.end(), copyin2);
    // Both ops must be present; otherwise an end() iterator would make the order check pass vacuously.
    ASSERT_TRUE(alloc2Iter != sortedOpList.end());
    ASSERT_TRUE(copyin2Iter != sortedOpList.end());
    EXPECT_LT(alloc2Iter - sortedOpList.begin(), copyin2Iter - sortedOpList.begin());

    // Remove the dependency on both sides; Alloc2 is then expected to be ordered after Copyin2.
    copyin2->EraseDependTensor(tensor4);
    add1->EraseDependTensor(tensor3);
    tensor4->RemoveDependOp(copyin2);
    tensor3->RemoveDependOp(add1);
    function->SortOperations();
    auto sortedOpList2 = function->Operations().DuplicatedOpList();
    auto alloc2Iter2 = std::find(sortedOpList2.begin(), sortedOpList2.end(), alloc2);
    auto copyin2Iter2 = std::find(sortedOpList2.begin(), sortedOpList2.end(), copyin2);
    ASSERT_TRUE(alloc2Iter2 != sortedOpList2.end());
    ASSERT_TRUE(copyin2Iter2 != sortedOpList2.end());
    EXPECT_GT(alloc2Iter2 - sortedOpList2.begin(), copyin2Iter2 - sortedOpList2.begin());

    // Re-add the dependency, then delete the op: erasing it must also clear the tensor's depend ops.
    copyin2->AddDependOperand(tensor4);
    tensor4->AddDependOp(copyin2);
    copyin2->SetAsDeleted();
    function->EraseOperations();
    EXPECT_EQ(tensor4->GetDependOps().size(), 0);
}
}
}