* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
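/*!
* \file schedule_ooo_check.cpp
* \brief Pre-pass (DoPreCheck) and post-pass (DoPostCheck) consistency checks for the OoOSchedule block-graph pass.
*/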
#include "schedule_ooo_checker.h"
#include "passes/block_graph_pass/schedule_ooo/schedule_ooo.h"
#include "interface/function/function.h"
#include "interface/tensor/logical_tensor.h"
#include "tilefwk/tilefwk.h"
#include "interface/inner/tilefwk.h"
#include "interface/program/program.h"
#include "passes/pass_utils/parallel_tool.h"
#include "passes/pass_utils/subgraph_utils.h"
#include "passes/pass_log/pass_log.h"
#include "passes/pass_utils/graph_utils.h"
#include "tilefwk/error_code.h"
#ifndef MODULE_NAME
#define MODULE_NAME "OoOSchedule"
#endif
namespace npu {
namespace tile_fwk {
// Pre-check for one operand tensor of a block-graph op.
// Logs an error and returns false when
//   * the tensor's original memory type differs from its to-be memory type, or
//   * the tensor is not a subgraph boundary tensor and its memId is still -1.
// Boundary tensors (SubgraphUtils::IsBoundary) are exempt from the memId requirement.
bool OoOScheduleChecker::PreCheckTensorInfo(const LogicalTensorPtr tensor)
{
    const bool memTypeConsistent = tensor->GetMemoryTypeOriginal() == tensor->GetMemoryTypeToBe();
    if (!memTypeConsistent) {
        APASS_LOG_ERROR_C(TensorErr::TENSOR_INVALID_MEMORY_TYPE, Elements::Tensor, "Tensor[%d] memorytypeOriginal is not equal to memorytypeTobe, OoOSchedule Precheck failed!", tensor->GetMagic());
        return false;
    }
    const bool memIdRequired = !SubgraphUtils::IsBoundary(tensor);
    if (memIdRequired && tensor->memoryrange.memId == -1) {
        APASS_LOG_ERROR_C(TensorErr::TENSOR_MEMORY_ALLOCATION, Elements::Tensor, "Tensor[%d] memId does not exist, OoOSchedule Precheck failed!", tensor->GetMagic());
        return false;
    }
    return true;
}
// Pre-check for a single block-graph op.
// Fails (logs and returns false) when:
//   * any input or output tensor fails PreCheckTensorInfo,
//   * the op is an OP_CALL (call ops are not allowed in a block graph here),
//   * the op is ASSEMBLE/RESHAPE/VIEW/NOP, its first output lives below
//     MEM_DEVICE_DDR, and some MEM_UB input has a different memId than that output.
// For ASSEMBLE/RESHAPE/VIEW/NOP a latency other than 1 only produces a warning.
bool OoOScheduleChecker::PreCheckOpInfo(const Operation* op)
{
    for (const auto& inTensor : op->GetIOperands()) {
        if (!PreCheckTensorInfo(inTensor)) {
            return false;
        }
    }
    for (const auto& outTensor : op->GetOOperands()) {
        if (!PreCheckTensorInfo(outTensor)) {
            return false;
        }
    }
    if (op->GetOpcode() == Opcode::OP_CALL) {
        APASS_LOG_ERROR_C(OperationErr::OP_SPECIAL_CONSTRAINT, Elements::Operation, "Block graph has call op, OoOSchedule Precheck failed!");
        return false;
    }
    const auto opcode = op->GetOpcode();
    const bool isViewLike = opcode == Opcode::OP_ASSEMBLE || opcode == Opcode::OP_RESHAPE ||
                            opcode == Opcode::OP_VIEW || opcode == Opcode::OP_NOP;
    if (!isViewLike) {
        return true;
    }
    if (op->GetLatency() != 1) {
        APASS_LOG_WARN_F(
            Elements::Operation, "%s[%d] Op latency is not 1, OoOSchedule Precheck warning!",
            op->GetOpcodeStr().c_str(), op->GetOpMagic());
    }
    // An op of these kinds (notably OP_NOP) may have no output at all; indexing
    // outputs[0] unconditionally would be undefined behaviour.
    const auto& outputs = op->GetOOperands();
    if (outputs.empty() || outputs[0]->GetMemoryTypeOriginal() >= MemoryType::MEM_DEVICE_DDR) {
        return true;
    }
    // Local (below DDR) output: every UB input must share the output's memId.
    const int memId = outputs[0]->memoryrange.memId;
    for (const auto& inTensor : op->GetIOperands()) {
        if (inTensor->memoryrange.memId != memId &&
            inTensor->GetMemoryTypeOriginal() == MemoryType::MEM_UB) {
            APASS_LOG_ERROR_C(TensorErr::TENSOR_INVALID_MEMORY_TYPE, Elements::Operation, "%s[%d] input output tensors memId does not match, OoOSchedule Precheck failed!", op->GetOpcodeStr().c_str(), op->GetOpMagic());
            return false;
        }
    }
    return true;
}
// Runs the OoOSchedule pre-check over every subgraph in function.rootFunc_->programs_.
// For each non-empty subgraph it rejects null ops, a lone ASSEMBLE/RESHAPE/VIEW op
// whose output is a local (below MEM_DEVICE_DDR) buffer, and any op failing
// PreCheckOpInfo. It also records the set of operand tensors per subgraph in
// tensorListBeforePass_, indexed by the subgraph's position in programs_.
// Returns SUCCESS, or FAILED on the first violation.
Status OoOScheduleChecker::DoPreCheck(Function& function)
{
    APASS_LOG_INFO_F(Elements::Operation, "Start OoOSchedule Precheck.");
    int programSize = function.rootFunc_->programs_.size();
    // Rebuild the snapshot from scratch so entries from an earlier run can never
    // be mistaken for this run's data (empty subgraphs leave their slot untouched).
    tensorListBeforePass_.clear();
    tensorListBeforePass_.resize(programSize);
    int programIdx = 0;
    for (auto& program : function.rootFunc_->programs_) {
        // Claim the slot up front: the index must advance for EVERY subgraph,
        // including empty ones, so it matches the programIdx DoPostCheck uses.
        const int currentIdx = programIdx++;
        APASS_LOG_INFO_F(Elements::Operation, "Subgraph[%zu] OoOSchedule Precheck begin.", program.first);
        auto opList = program.second->Operations().DuplicatedOpList();
        if (opList.empty()) {
            APASS_LOG_INFO_F(Elements::Operation, "Operation List is empty!");
            APASS_LOG_INFO_F(Elements::Operation, "Subgraph[%zu] OoOSchedule Precheck end.", program.first);
            continue;
        }
        for (auto& op : opList) {
            if (op == nullptr) {
                APASS_LOG_ERROR_C(OperationErr::OP_NULL_POINTER, Elements::Operation, "Operation is nullptr, OoOSchedule Precheck failed!");
                return FAILED;
            }
        }
        if (opList.size() == 1) {
            const auto& single = opList.front();
            const auto opcode = single->GetOpcode();
            const bool isViewLike = opcode == Opcode::OP_ASSEMBLE || opcode == Opcode::OP_RESHAPE ||
                                    opcode == Opcode::OP_VIEW;
            // A lone view-like op cannot allocate its local output buffer itself.
            if (isViewLike) {
                const auto& outputs = single->GetOOperands();
                if (!outputs.empty() && outputs[0]->GetMemoryTypeOriginal() < MemoryType::MEM_DEVICE_DDR) {
                    APASS_LOG_ERROR_C(OperationErr::OP_SPECIAL_CONSTRAINT, Elements::Operation, "Single Op: localBuffer does not have alloc, OoOSchedule Precheck failed!");
                    return FAILED;
                }
            }
        }
        std::unordered_set<LogicalTensorPtr> tensorList;
        for (auto& op : opList) {
            if (!PreCheckOpInfo(op)) {
                APASS_LOG_ERROR_C(OperationErr::OP_SPECIAL_CONSTRAINT, Elements::Operation, "PreCheckOpInfo failed; Please check the PreCheckOpInfo method.");
                return FAILED;
            }
            APASS_LOG_INFO_F(
                Elements::Operation, "Before OoOSchedule op: %s, %d.", op->GetOpcodeStr().c_str(),
                op->GetOpMagic());
            const auto& ioperands = op->GetIOperands();
            const auto& ooperands = op->GetOOperands();
            tensorList.insert(ioperands.begin(), ioperands.end());
            tensorList.insert(ooperands.begin(), ooperands.end());
        }
        tensorListBeforePass_[currentIdx] = tensorList;
        APASS_LOG_INFO_F(Elements::Operation, "Subgraph[%zu] OoOSchedule Precheck end.", program.first);
    }
    APASS_LOG_INFO_F(Elements::Operation, "OoOSchedule Precheck completed successfully!");
    return SUCCESS;
}
// Returns false (after logging) when op's magic is already present in `opSet`.
// NOTE(review): `opSet` is received by value (per the header declaration), so this
// is only a membership test against the caller's snapshot. It cannot record the
// magic for later calls; callers that want uniqueness must track magics themselves.
bool OoOScheduleChecker::PostCheckOpMagic(std::set<int> opSet, const Operation* op, const int programIdx)
{
    const auto magic = op->GetOpMagic();
    if (opSet.count(magic) == 0) {
        return true;
    }
    APASS_LOG_ERROR_F(
        Elements::Operation, "Program %d: %d opmagic is not unique, OoOSchedule Postcheck failed!", programIdx,
        magic);
    return false;
}
bool OoOScheduleChecker::PostCheckNewOpConnection(
const std::vector<Operation*> opListBeforePass, const std::vector<int> opMagicListBeforePass, const Operation* op,
const int programIdx)
{
auto it = std::find(opMagicListBeforePass.begin(), opMagicListBeforePass.end(), op->GetOpMagic());
if (it == opMagicListBeforePass.end()) {
return true;
}
int index = std::distance(opMagicListBeforePass.begin(), it);
auto& opBefore = opListBeforePass[index];
auto inTensorsBefore = opBefore->GetIOperands();
std::vector<std::set<Operation*, LogicalTensor::CompareOp>> opBeforeIncast;
for (auto& inTensorBefore : inTensorsBefore) {
opBeforeIncast.emplace_back(inTensorBefore->GetProducers());
}
auto inTensorsAfter = op->GetIOperands();
int shape = inTensorsAfter.size();
for (int i = 0; i < shape; i++) {
auto opAfterIncast = inTensorsAfter[i]->GetProducers();
std::set<Operation*, LogicalTensor::CompareOp> beforeHasAfterNot;
std::set<Operation*, LogicalTensor::CompareOp> beforeNotAfterHas;
std::set_difference(
opBeforeIncast[i].begin(), opBeforeIncast[i].end(), opAfterIncast.begin(), opAfterIncast.end(),
std::inserter(beforeHasAfterNot, beforeHasAfterNot.begin()));
std::set_difference(
opAfterIncast.begin(), opAfterIncast.end(), opBeforeIncast[i].begin(), opBeforeIncast[i].end(),
std::inserter(beforeNotAfterHas, beforeNotAfterHas.begin()));
if (beforeHasAfterNot.empty() && beforeNotAfterHas.empty()) {
continue;
}
std::vector<Operation*> copyins;
for (auto& opNew : beforeNotAfterHas) {
if (opNew->GetOpcode() != Opcode::OP_COPY_IN) {
APASS_LOG_ERROR_F(
Elements::Operation,
"Program %d: %d op's successors include unexpected op %s, OoOSchedule Postcheck failed!",
programIdx, op->GetOpMagic(), opNew->GetOpcodeStr().c_str());
return false;
}
copyins.emplace_back(opNew);
}
std::vector<Operation*> copyouts;
for (auto& copyin : copyins) {
auto opPtr = *(copyin->GetIOperands()[0]->GetProducers().begin());
if (opPtr->GetOpcode() != Opcode::OP_COPY_OUT) {
APASS_LOG_ERROR_F(
Elements::Operation,
"Program %d: %d op's successors include unexpected op %s, OoOSchedule Postcheck failed!",
programIdx, copyin->GetOpMagic(), opPtr->GetOpcodeStr().c_str());
return false;
}
}
std::set<Operation*, LogicalTensor::CompareOp> mainres;
for (auto& copyout : copyouts) {
mainres.insert(*(copyout->GetIOperands()[0]->GetProducers()).begin());
}
if (mainres != beforeHasAfterNot) {
std::set<Operation*, LogicalTensor::CompareOp> difference;
std::set_difference(
beforeHasAfterNot.begin(), beforeHasAfterNot.end(), mainres.begin(), mainres.end(),
std::inserter(difference, difference.begin()));
for (auto& dif : difference) {
APASS_LOG_ERROR_F(
Elements::Operation,
"Program %d: %d op is not found after OoOSchedule, OoOSchedule Postcheck failed!", programIdx,
dif->GetOpMagic());
return false;
}
}
}
return true;
}
// Post-check for ASSEMBLE/RESHAPE/VIEW ops: when the first output tensor lives
// below MEM_DEVICE_DDR it must not carry a true `needAlloc` attribute.
// Any other op, or a view-like op without outputs, passes.
// Returns false (after logging) on violation.
bool OoOScheduleChecker::PostCheckSpecialOp(const Operation* op)
{
    const auto opcode = op->GetOpcode();
    const bool isViewLike = opcode == Opcode::OP_ASSEMBLE || opcode == Opcode::OP_RESHAPE ||
                            opcode == Opcode::OP_VIEW;
    if (!isViewLike) {
        return true;
    }
    // Guard the [0] access: nothing in this function establishes a non-empty output list.
    const auto& outputs = op->GetOOperands();
    if (outputs.empty() || outputs[0]->GetMemoryTypeOriginal() >= MemoryType::MEM_DEVICE_DDR) {
        return true;
    }
    bool needAlloc = false;
    if (outputs[0]->GetAttr(OpAttributeKey::needAlloc, needAlloc) && needAlloc) {
        APASS_LOG_ERROR_F(
            Elements::Operation,
            "ASSEMBLE/RESHAPE/VIEW op output tensor has alloc attribute, OoOSchedule Postcheck failed!");
        return false;
    }
    return true;
}
// Returns false (after logging) when the tensor's magic is already in `tensorSet`.
// NOTE(review): `tensorSet` is received by value (per the header declaration), so
// this only tests membership in the caller's snapshot and cannot record the magic;
// callers must maintain the set of seen magics themselves.
bool OoOScheduleChecker::PostCheckTensorMagic(
    std::set<int> tensorSet, const LogicalTensorPtr tensor, const int programIdx)
{
    const auto magic = tensor->GetMagic();
    const bool alreadySeen = tensorSet.count(magic) != 0;
    if (!alreadySeen) {
        return true;
    }
    APASS_LOG_ERROR_F(
        Elements::Operation, "Program %d: %d tensormagic is not unique, OoOSchedule Postcheck failed!", programIdx,
        magic);
    return false;
}
bool OoOScheduleChecker::PostCheckLocalTensor(const LogicalTensorPtr tensor, const int programIdx)
{
MemoryType memType = tensor->GetMemoryTypeOriginal();
if (memType == MemoryType::MEM_UB || memType == MemoryType::MEM_L1 || memType == MemoryType::MEM_L0A ||
memType == MemoryType::MEM_L0B || memType == MemoryType::MEM_L0C) {
int memoryRange = tensor->memoryrange.end - tensor->memoryrange.start;
if (memoryRange == 0) {
APASS_LOG_ERROR_F(
Elements::Operation, "Program %d: %d tensor memory range is 0, OoOSchedule Postcheck failed!",
programIdx, tensor->GetMagic());
return false;
}
int tensorshape = 1;
for (auto num : tensor->GetShape()) {
tensorshape *= num;
}
int tensorsize = tensorshape * BytesOf(tensor->Datatype());
if (memoryRange < tensorsize) {
APASS_LOG_ERROR_F(
Elements::Operation, "Program %d: %d tensor memory range < tensor size, OoOSchedule Postcheck failed!",
programIdx, tensor->GetMagic());
return false;
}
}
return true;
}
bool OoOScheduleChecker::PostCheckGlobalTensor(const LogicalTensorPtr tensor, const int programIdx)
{
MemoryType memType = tensor->GetMemoryTypeOriginal();
if (memType >= MemoryType::MEM_DEVICE_DDR && !(SubgraphUtils::IsBoundary(tensor))) {
if (tensor->memoryrange.memId == -1) {
APASS_LOG_ERROR_F(
Elements::Operation, "Program %d: %d global tensor memid is -1, OoOSchedule Postcheck failed!",
programIdx, tensor->GetMagic());
return false;
}
}
return true;
}
// Post-check that the tensor's dynValidShape_ is populated.
// Returns false (after logging) when it is empty.
bool OoOScheduleChecker::PostCheckDynValidShape(const LogicalTensorPtr tensor, const int programIdx)
{
    const bool hasValidShape = !tensor->dynValidShape_.empty();
    if (hasValidShape) {
        return true;
    }
    APASS_LOG_ERROR_F(
        Elements::Operation, "Program %d: %d Dyn validshape is empty, OoOSchedule Postcheck failed!", programIdx,
        tensor->GetMagic());
    return false;
}
// Post-check for tensors that appear after the pass but not before it (matched by
// magic between tensorListBeforePass_ and tensorListAfterPass_ at `programIdx`).
// Each new tensor must either be what GraphUtils::GetTensorByMagic returns for its
// magic, or have at least one match from GraphUtils::FindOverlappedTensors. A new
// boundary tensor must also have a non-empty shape.
// Returns false (after logging) on violation, or when the snapshots for
// `programIdx` are missing (e.g. DoPreCheck was not run).
bool OoOScheduleChecker::PostCheckNewTensor(std::pair<const int, Function*> program, const int programIdx)
{
    const bool snapshotsReady = programIdx >= 0 &&
                                static_cast<size_t>(programIdx) < tensorListBeforePass_.size() &&
                                static_cast<size_t>(programIdx) < tensorListAfterPass_.size();
    if (!snapshotsReady) {
        APASS_LOG_ERROR_F(
            Elements::Operation, "Program %d: tensor snapshot is missing, OoOSchedule Postcheck failed!",
            programIdx);
        return false;
    }
    std::unordered_set<int> tensorMagicBeforePass;
    for (const auto& tensor : tensorListBeforePass_[programIdx]) {
        tensorMagicBeforePass.insert(tensor->GetMagic());
    }
    std::vector<LogicalTensorPtr> newTensors;
    for (const auto& tensor : tensorListAfterPass_[programIdx]) {
        if (tensorMagicBeforePass.count(tensor->GetMagic()) == 0) {
            newTensors.emplace_back(tensor);
        }
    }
    for (const auto& newtensor : newTensors) {
        bool existFlag = GraphUtils::GetTensorByMagic(*(program.second), newtensor->GetMagic()) == newtensor;
        // The overlap search is only needed when the direct magic lookup fails.
        if (!existFlag) {
            existFlag = !GraphUtils::FindOverlappedTensors(*(program.second), newtensor).empty();
        }
        if (!existFlag) {
            APASS_LOG_ERROR_F(
                Elements::Operation,
                "Program %d: %d new tensor does not exist in tensormap, OoOSchedule Postcheck failed!", programIdx,
                newtensor->GetMagic());
            return false;
        }
        if ((newtensor->shape.size() == 0) && (SubgraphUtils::IsBoundary(newtensor))) {
            APASS_LOG_ERROR_F(
                Elements::Operation, "Program %d: %d new tensor shape is null, OoOSchedule Postcheck failed!",
                programIdx, newtensor->GetMagic());
            return false;
        }
    }
    return true;
}
// Runs the per-tensor post-checks in order: magic, local-buffer range, global
// memId, dynamic valid shape. Logs which check failed and returns FAILED at the
// first failure; returns SUCCESS when all pass.
Status OoOScheduleChecker::PostCheckTensor(
    const LogicalTensorPtr& tensor, const std::set<int>& tensorSet, int programIdx)
{
    Status result = FAILED;
    if (!PostCheckTensorMagic(tensorSet, tensor, programIdx)) {
        APASS_LOG_ERROR_F(
            Elements::Tensor, "PostCheckTensorMagic failed; Please check the PostCheckTensorMagic method.");
    } else if (!PostCheckLocalTensor(tensor, programIdx)) {
        APASS_LOG_ERROR_F(
            Elements::Tensor, "PostCheckLocalTensor failed; Please check the PostCheckLocalTensor method.");
    } else if (!PostCheckGlobalTensor(tensor, programIdx)) {
        APASS_LOG_ERROR_F(
            Elements::Tensor, "PostCheckGlobalTensor failed; Please check the PostCheckGlobalTensor method.");
    } else if (!PostCheckDynValidShape(tensor, programIdx)) {
        APASS_LOG_ERROR_F(
            Elements::Tensor, "PostCheckDynValidShape failed; Please check the PostCheckDynValidShape method.");
    } else {
        result = SUCCESS;
    }
    return result;
}
// Post-check for one subgraph after OoOSchedule.
// Verifies, for the subgraph at position `programIdx` in programs_:
//   * no op is null and op magics are unique,
//   * each op passes PostCheckNewOpConnection (against oriFunctions_[programIdx])
//     and PostCheckSpecialOp,
//   * operand tensor magics are unique and every tensor passes PostCheckTensor,
//   * new tensors pass PostCheckNewTensor.
// Stores the subgraph's operand tensors in tensorListAfterPass_[programIdx].
// Returns SUCCESS for an empty subgraph; FAILED on the first violation.
Status OoOScheduleChecker::PostCheckSubGraph(const std::pair<uint64_t, Function*>& program, int programIdx)
{
    auto opList = program.second->Operations().DuplicatedOpList();
    if (opList.empty()) {
        APASS_LOG_INFO_F(
            Elements::Operation, "Operation List is empty! \nSubgraph[%zu] OoOSchedule Postcheck end.", program.first);
        return SUCCESS;
    }
    for (auto& op : opList) {
        if (op == nullptr) {
            APASS_LOG_ERROR_F(Elements::Operation, "Operation is nullptr, OoOSchedule Postcheck failed!");
            return FAILED;
        }
    }
    // Both per-subgraph containers are indexed by programIdx below; guard against
    // a missing pre-pass snapshot instead of indexing out of range.
    const bool stateReady = programIdx >= 0 && static_cast<size_t>(programIdx) < oriFunctions_.size() &&
                            oriFunctions_[programIdx] != nullptr &&
                            static_cast<size_t>(programIdx) < tensorListAfterPass_.size();
    if (!stateReady) {
        APASS_LOG_ERROR_F(
            Elements::Operation,
            "Program %d: original function or tensor list is missing, OoOSchedule Postcheck failed!", programIdx);
        return FAILED;
    }
    auto opListBeforePass = oriFunctions_[programIdx]->Operations().DuplicatedOpList();
    std::vector<int> opMagicListBeforePass;
    opMagicListBeforePass.reserve(opListBeforePass.size());
    for (auto& op : opListBeforePass) {
        opMagicListBeforePass.emplace_back(op->GetOpMagic());
    }
    std::unordered_set<LogicalTensorPtr> tensorList;
    // PostCheckOpMagic / PostCheckTensorMagic receive their std::set by value, so
    // whatever they insert is discarded on return and duplicates were never
    // detected. The seen-magic sets are therefore maintained here.
    std::set<int> opSet;
    for (auto& op : opList) {
        if (!opSet.insert(op->GetOpMagic()).second) {
            APASS_LOG_ERROR_F(
                Elements::Operation, "Program %d: %d opmagic is not unique, OoOSchedule Postcheck failed!", programIdx,
                op->GetOpMagic());
            return FAILED;
        }
        if (!PostCheckNewOpConnection(opListBeforePass, opMagicListBeforePass, op, programIdx)) {
            APASS_LOG_ERROR_F(
                Elements::Operation,
                "PostCheckNewOpConnection failed; Please check the PostCheckNewOpConnection method.");
            return FAILED;
        }
        if (!PostCheckSpecialOp(op)) {
            APASS_LOG_ERROR_F(
                Elements::Operation, "PostCheckSpecialOp failed; Please check the PostCheckSpecialOp method.");
            return FAILED;
        }
        const auto& ioperands = op->GetIOperands();
        const auto& ooperands = op->GetOOperands();
        tensorList.insert(ioperands.begin(), ioperands.end());
        tensorList.insert(ooperands.begin(), ooperands.end());
    }
    tensorListAfterPass_[programIdx] = tensorList;
    std::set<int> tensorSet;
    // Uniqueness is enforced via tensorSet above; PostCheckTensor gets an empty set
    // (as before) so its magic sub-check does not re-flag the current tensor.
    const std::set<int> noSeenMagics{};
    for (auto& tensor : tensorList) {
        if (!tensorSet.insert(tensor->GetMagic()).second) {
            APASS_LOG_ERROR_F(
                Elements::Operation, "Program %d: %d tensormagic is not unique, OoOSchedule Postcheck failed!",
                programIdx, tensor->GetMagic());
            return FAILED;
        }
        if (PostCheckTensor(tensor, noSeenMagics, programIdx) != SUCCESS) {
            APASS_LOG_ERROR_F(Elements::Tensor, "PostCheckTensor failed; Please check the PostCheckTensor method.");
            return FAILED;
        }
    }
    if (!PostCheckNewTensor(program, programIdx)) {
        APASS_LOG_ERROR_F(Elements::Tensor, "PostCheckNewTensor failed; Please check the PostCheckNewTensor method.");
        return FAILED;
    }
    return SUCCESS;
}
// Runs the OoOSchedule post-check over every subgraph in function.rootFunc_->programs_,
// passing each subgraph's position in programs_ as programIdx (the same indexing
// DoPreCheck uses for tensorListBeforePass_ and PostCheckSubGraph uses for
// oriFunctions_). Returns SUCCESS, or FAILED at the first failing subgraph.
Status OoOScheduleChecker::DoPostCheck(Function& function)
{
    APASS_LOG_INFO_F(Elements::Operation, "Start OoOSchedule Postcheck.");
    int programSize = function.rootFunc_->programs_.size();
    // Rebuild from scratch: an empty subgraph leaves its slot untouched, so stale
    // entries from an earlier run must not survive.
    tensorListAfterPass_.clear();
    tensorListAfterPass_.resize(programSize);
    int programIdx = 0;
    for (auto& program : function.rootFunc_->programs_) {
        APASS_LOG_INFO_F(Elements::Operation, "Subgraph[%zu] OoOSchedule Postcheck begin.", program.first);
        if (PostCheckSubGraph(program, programIdx) != SUCCESS) {
            // Identify the subgraph by its key, consistent with the begin/end logs.
            APASS_LOG_ERROR_F(Elements::Operation, "Subgraph[%zu] OoOSchedule Postcheck failed!", program.first);
            return FAILED;
        }
        programIdx++;
        APASS_LOG_INFO_F(Elements::Operation, "Subgraph[%zu] OoOSchedule Postcheck end.", program.first);
    }
    APASS_LOG_INFO_F(Elements::Operation, "OoOSchedule Postcheck completed successfully!");
    return SUCCESS;
}
// Stores the pre-pass subgraph functions. PostCheckSubGraph looks them up as
// oriFunctions_[programIdx], where programIdx is the subgraph's position in programs_.
void OoOScheduleChecker::SetOriFunctions(const std::vector<Function*>& oriFunctions)
{
    oriFunctions_ = oriFunctions;
}
}
}