* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file main_block.cpp
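* \brief Collects shape vs. dynamic-valid-shape equality conditions and builds the runtime main-block select expression.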
*/
#include "main_block.h"
#include "codegen/codegen.h"
#include "tilefwk/platform.h"
#include "tilefwk/pypto_fwk_log.h"
namespace npu::tile_fwk {
/** \brief Default constructor; uses the compiler-generated implementation. */
MainBlockCondBulider::MainBlockCondBulider() = default;
void MainBlockCondBulider::AddUniqueCondition(const SymbolicScalar& newCond)
{
SymbolicScalar cond = SymbolicScalar(false);
std::string condStr = newCond.Dump();
if ((mainBlockStrSet_.find(condStr) != mainBlockStrSet_.end()) ||
(mainBlockStrSet_.find(cond.Dump()) != mainBlockStrSet_.end())) {
return;
}
mainBlockStrSet_.insert(condStr);
mainBlockCondGroup_.push_back(newCond);
}
bool MainBlockCondBulider::CheckShapeEquality(const Shape& shape, const std::vector<SymbolicScalar>& dynShape)
{
SymbolicScalar cond = SymbolicScalar(false);
if (shape.size() != dynShape.size()) {
AddUniqueCondition(cond);
return false;
}
for (uint32_t i = 0; i < shape.size(); i++) {
if (shape[i] == -1) {
continue;
}
cond = (shape[i] == dynShape[i]);
AddUniqueCondition(cond);
if (cond.IsImmediate() && (cond == 0)) {
return false;
}
}
return true;
}
/**
 * \brief Extract the shape and the dynamic valid shape from one COA argument list.
 *
 * The dimension count is derived from the list length: `dim = ceil((size - 1) / COA_INDEX_TYPE_COUNT)`.
 * The shape is read from `dim` entries starting at `COA_INDEX_DIM_BASE + dim`. The valid shape is
 * read from the remaining entries starting at `COA_INDEX_DIM_BASE + dim * (COA_INDEX_TYPE_COUNT - 1)`.
 *
 * The derived layout is validated before any element is read. A list whose length yields a negative
 * valid-shape count, or whose derived ranges run past the end of the list, is rejected instead of
 * being passed to reserve() or indexed out of bounds.
 *
 * \param argList       COA argument list to decode.
 * \param shape         output; receives the shape entries (left untouched on failure).
 * \param dynValidShape output; receives the valid-shape entries (left untouched on failure).
 * \return true when both outputs were filled; false when \p argList is too short or inconsistent.
 */
bool MainBlockCondBulider::GetValidShapeFromCoa(
    const std::vector<SymbolicScalar>& argList, Shape& shape, std::vector<SymbolicScalar>& dynValidShape)
{
    if (argList.empty() || (argList.size() <= COA_INDEX_TYPE_COUNT)) {
        MACHINE_LOGW("argList is invalid!");
        return false;
    }

    int dim = (argList.size() - 1 + COA_INDEX_TYPE_COUNT - 1) / COA_INDEX_TYPE_COUNT;
    int validShapeDim = argList.size() - 1 - dim * (COA_INDEX_TYPE_COUNT - 1);
    int coaIndex = COA_INDEX_DIM_BASE;

    // Validate every index range up front so malformed input cannot cause an oversized reserve()
    // or an out-of-range read of argList.
    const int argCount = static_cast<int>(argList.size());
    const int shapeBegin = coaIndex + dim;
    const int validBegin = coaIndex + dim * (COA_INDEX_TYPE_COUNT - 1);
    if (validShapeDim < 0 || shapeBegin < 0 || shapeBegin + dim > argCount ||
        validBegin + validShapeDim > argCount) {
        MACHINE_LOGW(
            "argList layout is inconsistent, size is %zu, dim is %d, validShapeDim is %d", argList.size(), dim,
            validShapeDim);
        return false;
    }

    shape.reserve(dim);
    dynValidShape.reserve(validShapeDim);
    for (int i = 0; i < dim; i++) {
        shape.push_back(argList[shapeBegin + i]);
    }
    for (int i = 0; i < validShapeDim; i++) {
        dynValidShape.push_back(argList[validBegin + i]);
    }
    return true;
}
/**
 * \brief Collect main-block conditions from every operand of every operation in a function.
 *
 * When the valid-shape optimization runtime option is not 1 and VF is not enabled (DAV_3510 arch
 * together with the KEY_ENABLE_VF pass config), a single constant-false condition is recorded.
 * Otherwise each input operand, then each output operand, has its shape compared with its dynamic
 * valid shape. Collection stops at the first operand for which CheckShapeEquality returns false.
 *
 * \param func function whose operations are inspected; must not be null (it is dereferenced).
 */
void MainBlockCondBulider::CollectCallopMainBlockConds(Function* func)
{
    // The pass config is only queried when the architecture matches (short-circuit).
    const bool vfEnabled = (Platform::Instance().GetSoc().GetNPUArch() == NPUArch::DAV_3510) &&
                           config::GetPassGlobalConfig(KEY_ENABLE_VF, false);
    const bool validShapeOptimize = (config::GetRuntimeOption<int64_t>(CFG_VALID_SHAPE_OPTIMIZE) == 1);
    if (!validShapeOptimize && !vfEnabled) {
        AddUniqueCondition(SymbolicScalar(false));
        return;
    }

    // Checks one operand list; returns false (after logging) at the first mismatching operand.
    auto operandsMatch = [&](auto& op, auto&& operands, const char* tag) -> bool {
        for (auto& operand : operands) {
            auto&& validShape = operand->GetDynValidShape();
            if (CheckShapeEquality(operand->shape, validShape)) {
                continue;
            }
            MACHINE_LOGW(
                "get mainBlock flag false, op code %s, %s shape is %s, validShape is %s", op.GetOpcodeStr().c_str(),
                tag, IntVecToStr(operand->shape).c_str(), IntVecToStr(validShape).c_str());
            return false;
        }
        return true;
    };

    for (auto& op : func->Operations()) {
        if (!operandsMatch(op, op.GetIOperands(), "iop")) {
            return;
        }
        if (!operandsMatch(op, op.GetOOperands(), "oop")) {
            return;
        }
    }
}
/**
 * \brief Tell whether a function contains a reshape-copy operation.
 *
 * \param func function whose operations are scanned; must not be null (it is dereferenced).
 * \return true if any operation has opcode OP_RESHAPE_COPY_OUT or OP_RESHAPE_COPY_IN.
 */
bool MainBlockCondBulider::CheckReshapeCopy(Function* func)
{
    for (auto& op : func->Operations()) {
        const auto opcode = op.GetOpcode();
        const bool isReshapeCopy =
            (opcode == Opcode::OP_RESHAPE_COPY_OUT) || (opcode == Opcode::OP_RESHAPE_COPY_IN);
        if (isReshapeCopy) {
            return true;
        }
    }
    return false;
}
void MainBlockCondBulider::CollectCoaMainBlockConds(const std::vector<std::vector<SymbolicScalar>>& argList, Function* func)
{
bool enableVF = Platform::Instance().GetSoc().GetNPUArch() == NPUArch::DAV_3510;
enableVF = enableVF && config::GetPassGlobalConfig(KEY_ENABLE_VF, false);
if ((config::GetRuntimeOption<int64_t>(CFG_VALID_SHAPE_OPTIMIZE) != 1 && !enableVF) || CheckReshapeCopy(func)) {
AddUniqueCondition(SymbolicScalar(false));
return;
}
for (const auto& iter : argList) {
Shape shape;
std::vector<SymbolicScalar> dynValidShape;
if (!GetValidShapeFromCoa(iter, shape, dynValidShape)) {
AddUniqueCondition(SymbolicScalar(false));
return;
}
auto cond = CheckShapeEquality(shape, dynValidShape);
if (!cond) {
MACHINE_LOGW(
"get mainBlock flag false, coa shape is %s, validShape is %s", IntVecToStr(shape).c_str(),
IntVecToStr(dynValidShape).c_str());
return;
}
}
}
/**
 * \brief Fold the collected conditions into one `RUNTIME_Select(cond, 1, 0)` expression.
 *
 * The select is built on a constant-false condition when:
 *  - no condition has been collected,
 *  - more conditions were collected than MAX_RUNTIME_AND_NESTING_DEPTH, or
 *  - any condition's dump contains "RUNTIME_GetTensorDataInt32".
 * Otherwise the conditions are chained as nested `RUNTIME_And` calls, starting from true.
 *
 * \return the resulting select expression.
 */
SymbolicScalar MainBlockCondBulider::BuildMainBlockExpression()
{
    SymbolicScalar selectFn("RUNTIME_Select");
    SymbolicScalar andFn("RUNTIME_And");

    if (mainBlockCondGroup_.empty()) {
        SymbolicScalar alwaysFalse = false;
        return selectFn(alwaysFalse, 1, 0);
    }

    if (mainBlockCondGroup_.size() > MAX_RUNTIME_AND_NESTING_DEPTH) {
        MACHINE_LOGW("runtimeAnd nesting depth too large (%zu), disable mainblock", mainBlockCondGroup_.size());
        return selectFn(false, 1, 0);
    }

    SymbolicScalar combined = true;
    for (const auto& condition : mainBlockCondGroup_) {
        const std::string dumped = condition.Dump();
        const bool readsTensorData = dumped.find("RUNTIME_GetTensorDataInt32") != std::string::npos;
        if (readsTensorData) {
            MACHINE_LOGW("AICPU does not support RUNTIME_GetTensorDataInt32");
            return selectFn(false, 1, 0);
        }
        combined = andFn(combined, condition);
    }
    return selectFn(combined, 1, 0);
}
/**
 * \brief Run code generation for the function when the main-block feature is active.
 *
 * Nothing happens unless the valid-shape optimization runtime option equals 1 or VF is enabled
 * (DAV_3510 arch together with the KEY_ENABLE_VF pass config). When active, a CodeGenCtx is built
 * with the "kernel_aicore" emit path and the function's `dynamicAlignedOps` setting, and
 * CodeGen::GenCode is invoked on the function.
 *
 * \param function function to generate code for; dereferenced only when the feature is active.
 */
void MainBlockCondBulider::Gencode(Function* function)
{
    // The pass config is only queried when the architecture matches (short-circuit).
    const bool vfEnabled = (Platform::Instance().GetSoc().GetNPUArch() == NPUArch::DAV_3510) &&
                           config::GetPassGlobalConfig(KEY_ENABLE_VF, false);
    const bool validShapeOptimize = (config::GetRuntimeOption<int64_t>(CFG_VALID_SHAPE_OPTIMIZE) == 1);
    if (!validShapeOptimize && !vfEnabled) {
        return;
    }

    bool isDynamicAligned = function->paramConfigs_.dynamicAlignedOps;
    npu::tile_fwk::CodeGenCtx mainBlockCtx("", config::GetEmitPath("kernel_aicore"), true, isDynamicAligned);
    npu::tile_fwk::CodeGen mainBlockCodeGen(mainBlockCtx);
    mainBlockCodeGen.GenCode(*function, {});
}
/** \brief Read-only access to the recorded conditions, in insertion order. */
const std::vector<SymbolicScalar>& MainBlockCondBulider::GetCondGroup() const
{
    return mainBlockCondGroup_;
}
/** \brief Read-only access to the set of dumped condition strings used for de-duplication. */
const std::unordered_set<std::string>& MainBlockCondBulider::GetCondStrSet() const
{
    return mainBlockStrSet_;
}
}