/**
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "gtest/gtest.h"
#include "machine/host/expr_generator.h"
#include "machine/host/backend.h"
#include <cstdio>
#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <iterator>
#include <sstream>
#include <string>
#include <system_error>
#include <unordered_map>
#include <vector>
#include <unistd.h>
namespace npu::tile_fwk {
namespace test {
class TestExprBatchGenerator : public testing::Test {
protected:
void SetUp() override
{
testDir_ = "expr_generator_temp_" + std::to_string(getpid());
mkdir(testDir_.c_str(), 0755);
}
void TearDown() override
{
std::string cmd = "rm -rf " + testDir_;
ASSERT(system(cmd.c_str()) == 0);
}
std::string testDir_;
};
// Returns true when `path` can be opened for reading with std::ifstream.
// Given internal linkage (`static`) so this generic helper name cannot clash at
// link time with a same-named helper in another test translation unit.
static bool FileExists(const std::string& path)
{
    const std::ifstream file(path);
    return file.good();
}
// Smoke test: builds generators for three (devRootKey, expression count)
// combinations and only requires that construction/destruction completes.
// NOTE(review): despite the name, no batch split is asserted here; consider
// checking the computed batch count if ExprBatchGenerator exposes it.
TEST_F(TestExprBatchGenerator, CalculateBatches)
{
    ExprBatchGenerator genKey1(testDir_, 1, 1000);
    ExprBatchGenerator genKey2(testDir_, 2, 2500);
    ExprBatchGenerator genKey3(testDir_, 3, 500);
}
// HeaderFileBegin + HeaderFileEnd are expected to leave
// control_flow_expr_table.h in the scratch directory, containing a
// `#pragma once` guard and the npu::tile_fwk namespace.
TEST_F(TestExprBatchGenerator, HeaderFileGeneration)
{
    ExprBatchGenerator gen(testDir_, 1, 100);
    std::ostringstream headerStream;
    gen.HeaderFileBegin(headerStream);
    gen.HeaderFileEnd(headerStream);

    const std::string tablePath = testDir_ + "/control_flow_expr_table.h";
    ASSERT_TRUE(FileExists(tablePath));

    std::ifstream in(tablePath);
    const std::string text{std::istreambuf_iterator<char>{in}, std::istreambuf_iterator<char>{}};
    ASSERT_NE(text.find("#pragma once"), std::string::npos);
    ASSERT_NE(text.find("namespace npu::tile_fwk"), std::string::npos);
}
// Calling HeaderFileBegin alone is expected to write `merge.link` into the
// scratch directory; the test checks that it contains "SECTIONS" and ".pypto".
TEST_F(TestExprBatchGenerator, LinkScriptGeneration)
{
    ExprBatchGenerator gen(testDir_, 1, 100);
    std::ostringstream headerStream;
    gen.HeaderFileBegin(headerStream);

    const std::string linkPath = testDir_ + "/merge.link";
    ASSERT_TRUE(FileExists(linkPath));

    std::ifstream in(linkPath);
    const std::string text{std::istreambuf_iterator<char>{in}, std::istreambuf_iterator<char>{}};
    ASSERT_NE(text.find("SECTIONS"), std::string::npos);
    ASSERT_NE(text.find(".pypto"), std::string::npos);
}
// Builds the call expression RUNTIME_GetInputData(test_tensor, 0) and checks
// that CheckExprDependCore follows the per-tensor flag in
// tensorNameToDependCore: true when the tensor is flagged, false once the flag
// is cleared (valDependMap is emptied between the two queries).
TEST_F(TestExprBatchGenerator, CheckExprDependCoreTest)
{
    std::string tensorName = "test_tensor";
    ValDependTensorMeta meta;
    meta.tensorNameToDependCore[tensorName] = true;

    // Operands of the T_MOP_CALL node: callee symbol first, then its arguments.
    std::vector<RawSymbolicScalarPtr> callOperands;
    callOperands.push_back(RawSymbolicSymbol::Create("RUNTIME_GetInputData"));
    callOperands.push_back(RawSymbolicSymbol::Create(tensorName));
    callOperands.push_back(RawSymbolicImmediate::Create(0));
    RawSymbolicScalarPtr callExpr =
        std::make_shared<RawSymbolicExpression>(SymbolicOpcode::T_MOP_CALL, callOperands);

    ASSERT_TRUE(SymbolicExpressionTable::CheckExprDependCore(
        callExpr, meta.tensorNameToDependCore, meta.valDependMap));

    meta.tensorNameToDependCore[tensorName] = false;
    meta.valDependMap.clear();
    ASSERT_FALSE(SymbolicExpressionTable::CheckExprDependCore(
        callExpr, meta.tensorNameToDependCore, meta.valDependMap));
}
// Feeds kExprCount immediate expressions to GenerateBatchFile and expects them
// to be split across exactly two batch source files, each containing
// RUNTIME_SetExpr. Both batch entry points (SetExprBatch_1_0 / _1_1) must be
// referenced from the control-flow output and declared in the header output.
TEST_F(TestExprBatchGenerator, BatchFileGeneration)
{
    constexpr int kExprCount = 1500;
    ExprBatchGenerator generator(testDir_, 1, kExprCount);
    std::ostringstream controlFlowOss;
    std::ostringstream exprHeaderOss;
    std::vector<std::string> exprSrcFiles;
    SymbolicExpressionTable exprTable;
    OrderedSet<RawSymbolicScalarPtr> expressions;
    for (int i = 0; i < kExprCount; ++i) {
        RawSymbolicScalarPtr expr = RawSymbolicImmediate::Create(i);
        expressions.Insert(expr);
    }
    generator.GenerateBatchFile(
        &exprTable, controlFlowOss, exprHeaderOss, "test_exp.h", expressions, exprSrcFiles, 1, 1);

    // Compare against an unsigned literal: size() is size_t, and a plain `2`
    // triggers -Wsign-compare inside gtest's comparison helper.
    ASSERT_EQ(exprSrcFiles.size(), 2u);
    for (const auto& filePath : exprSrcFiles) {
        ASSERT_TRUE(FileExists(filePath)) << filePath;
        std::ifstream batchFile(filePath);
        const std::string fileContent{
            std::istreambuf_iterator<char>{batchFile}, std::istreambuf_iterator<char>{}};
        ASSERT_NE(fileContent.find("RUNTIME_SetExpr"), std::string::npos) << filePath;
    }

    const std::string controlFlowContent = controlFlowOss.str();
    ASSERT_NE(controlFlowContent.find("SetExprBatch_1_0"), std::string::npos);
    ASSERT_NE(controlFlowContent.find("SetExprBatch_1_1"), std::string::npos);

    const std::string headerContent = exprHeaderOss.str();
    ASSERT_NE(headerContent.find("SetExprBatch_1_0"), std::string::npos);
    ASSERT_NE(headerContent.find("SetExprBatch_1_1"), std::string::npos);
}
namespace {
// Leaf node holding the integer constant `v`.
RawSymbolicScalarPtr ImmNode(int64_t v)
{
    return RawSymbolicImmediate::Create(v);
}

// Leaf node referring to the symbol named `n`.
RawSymbolicScalarPtr SymNode(const std::string& n)
{
    return RawSymbolicSymbol::Create(n);
}

// Binary node `a + b`.
RawSymbolicScalarPtr AddNode(const RawSymbolicScalarPtr& a, const RawSymbolicScalarPtr& b)
{
    return std::make_shared<RawSymbolicExpression>(
        SymbolicOpcode::T_BOP_ADD, std::vector<RawSymbolicScalarPtr>{a, b});
}

// Binary node `a * b`.
RawSymbolicScalarPtr MulNode(const RawSymbolicScalarPtr& a, const RawSymbolicScalarPtr& b)
{
    return std::make_shared<RawSymbolicExpression>(
        SymbolicOpcode::T_BOP_MUL, std::vector<RawSymbolicScalarPtr>{a, b});
}

// Reads the whole file at `path`; yields an empty string if it cannot be opened.
std::string ReadFile(const std::string& path)
{
    std::ifstream f(path);
    return std::string((std::istreambuf_iterator<char>(f)), std::istreambuf_iterator<char>());
}

// Inserts `exprs` (in order) into an OrderedSet, runs GenerateBatchFile on them
// with a generator sized to the set, and returns the text of every generated
// batch source file concatenated in the order the generator listed them.
// Returns "" when no file was produced.
std::string RunGen(const std::string& dir, int devRootKey, const std::vector<RawSymbolicScalarPtr>& exprs)
{
    SymbolicExpressionTable table;
    OrderedSet<RawSymbolicScalarPtr> set;
    for (const auto& e : exprs) {
        set.Insert(e);
    }
    ExprBatchGenerator gen(dir, devRootKey, set.size());
    std::ostringstream cfOss;
    std::ostringstream headerOss;
    std::vector<std::string> srcFiles;
    gen.GenerateBatchFile(&table, cfOss, headerOss, "test_exp.h", set, srcFiles, 1, devRootKey);
    // Read every batch, not just the first, so occurrence counts stay correct
    // if the expressions ever spill into more than one file.
    std::string body;
    for (const auto& path : srcFiles) {
        body += ReadFile(path);
    }
    return body;
}

// Counts non-overlapping occurrences of `pat` in `text`. An empty pattern
// yields 0; without this guard the search position would never advance and
// the loop would not terminate.
size_t Count(const std::string& text, const std::string& pat)
{
    if (pat.empty()) {
        return 0;
    }
    size_t cnt = 0;
    size_t pos = 0;
    while ((pos = text.find(pat, pos)) != std::string::npos) {
        cnt++;
        pos += pat.size();
    }
    return cnt;
}
}
// a+0, a+1, ..., a+4 differ only in one immediate that grows by 1. The
// generated code is expected to contain a single loop (bound "<= 4", step
// "+= 1") wrapping one RUNTIME_SetExpr.
TEST_F(TestExprBatchGenerator, FoldSingleDiff)
{
    constexpr int64_t runLength = 5;
    std::vector<RawSymbolicScalarPtr> run;
    run.reserve(runLength);
    for (int64_t offset = 0; offset != runLength; ++offset) {
        run.push_back(AddNode(SymNode("a"), ImmNode(offset)));
    }

    const std::string generated = RunGen(testDir_, 10, run);
    EXPECT_EQ(Count(generated, "for (int64_t sym_expr_loop_k_"), 1u);
    EXPECT_NE(generated.find("<= 4"), std::string::npos);
    EXPECT_NE(generated.find("+= 1"), std::string::npos);
    EXPECT_EQ(Count(generated, "RUNTIME_SetExpr"), 1u);
}
// (a + 4s) * (b + 100s) for s = 1..4: two immediates change in lockstep with
// different strides. The generated code is expected to contain one counted
// loop ("< 4", "k_0++") wrapping one RUNTIME_SetExpr.
TEST_F(TestExprBatchGenerator, FoldMultiDiffLockstep)
{
    std::vector<RawSymbolicScalarPtr> run;
    for (int64_t step = 1; step <= 4; ++step) {
        RawSymbolicScalarPtr lhs = AddNode(SymNode("a"), ImmNode(4 * step));
        RawSymbolicScalarPtr rhs = AddNode(SymNode("b"), ImmNode(100 * step));
        run.push_back(MulNode(lhs, rhs));
    }

    const std::string generated = RunGen(testDir_, 11, run);
    EXPECT_EQ(Count(generated, "for (int64_t sym_expr_loop_k_"), 1u);
    EXPECT_NE(generated.find("< 4"), std::string::npos);
    EXPECT_NE(generated.find("k_0++"), std::string::npos);
    EXPECT_EQ(Count(generated, "RUNTIME_SetExpr"), 1u);
}
// Two cases that must NOT be folded into a loop:
//  1. a run of only two arithmetic-progression expressions (too short);
//  2. four structurally identical expressions (no varying immediate).
// Each expression must then get its own RUNTIME_SetExpr.
TEST_F(TestExprBatchGenerator, NoFoldBelowMinLenOrEqualExprs)
{
    std::vector<RawSymbolicScalarPtr> shortRun;
    shortRun.push_back(AddNode(SymNode("a"), ImmNode(4)));
    shortRun.push_back(AddNode(SymNode("a"), ImmNode(8)));
    const std::string shortBody = RunGen(testDir_, 12, shortRun);
    EXPECT_EQ(Count(shortBody, "for (int64_t sym_expr_loop_k_"), 0u);
    EXPECT_EQ(Count(shortBody, "RUNTIME_SetExpr"), 2u);

    std::vector<RawSymbolicScalarPtr> repeated;
    for (int copy = 0; copy < 4; ++copy) {
        repeated.push_back(AddNode(SymNode("a"), ImmNode(4)));
    }
    const std::string repeatedBody = RunGen(testDir_, 13, repeated);
    EXPECT_EQ(Count(repeatedBody, "for (int64_t sym_expr_loop_k_"), 0u);
    EXPECT_EQ(Count(repeatedBody, "RUNTIME_SetExpr"), 4u);
}
// Sequence: a+4, a+8, a+12 | zzz*7 | b+1, b+2, b+3. The unrelated middle
// expression breaks the first run, so two loops are expected (one per run).
// With the standalone expression, that gives three RUNTIME_SetExpr in total.
TEST_F(TestExprBatchGenerator, RunBreaksAndMixed)
{
    std::vector<RawSymbolicScalarPtr> mixed;
    for (int64_t k = 1; k <= 3; ++k) {
        mixed.push_back(AddNode(SymNode("a"), ImmNode(4 * k)));
    }
    // Unrelated expression separating the two foldable runs.
    mixed.push_back(MulNode(SymNode("zzz"), ImmNode(7)));
    for (int64_t k = 1; k <= 3; ++k) {
        mixed.push_back(AddNode(SymNode("b"), ImmNode(k)));
    }

    const std::string generated = RunGen(testDir_, 14, mixed);
    EXPECT_EQ(Count(generated, "for (int64_t sym_expr_loop_k_"), 2u);
    EXPECT_EQ(Count(generated, "RUNTIME_SetExpr"), 3u);
}
}
}