* Copyright (c) 2025-2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file merge_view_assemble_utils.cpp
* \brief utils of view and assemble operation merging
*/
#include "merge_view_assemble_utils.h"
#include "interface/tensor/irbuilder.h"
#include "interface/operation/attribute.h"
#include "passes/pass_utils/dead_operation_eliminate.h"
#include "passes/pass_utils/infer_shape_utils.h"
#include "passes/pass_log/pass_log.h"
#define MODULE_NAME "MergeViewAssembleUtils"
namespace npu::tile_fwk {
Status MergeViewAssembleUtils::MergeViewAssemble(Function& function)
{
MergeViewAssembleUtils MergeViewAssembleUtils;
Status status = MergeViewAssembleUtils.Process(function);
return status;
}
Status MergeViewAssembleUtils::Process(Function& function)
{
Status status = Initialize();
if (status != SUCCESS) {
APASS_LOG_ERROR_F(Elements::Function, "MergeViewAssembleUtils initialization failed.");
return status;
}
status = ProcessOperations(function);
if (status != SUCCESS) {
APASS_LOG_ERROR_F(Elements::Function, "Processing operations failed.");
return status;
}
status = CleanUp(function);
if (status != SUCCESS) {
APASS_LOG_ERROR_F(Elements::Function, "Cleanup phase failed.");
return status;
}
return SUCCESS;
}
Status MergeViewAssembleUtils::Initialize()
{
visitedOp_.clear();
viewOpToAppend_.clear();
assembleOpToAppend_.clear();
return SUCCESS;
}
Status MergeViewAssembleUtils::ProcessOperations(Function& function)
{
for (auto& op : function.Operations()) {
if (visitedOp_.count(op.GetOpMagic()) != 0) {
continue;
}
Status processStatus = SUCCESS;
std::vector<Operation*> chain;
if (op.GetOpcode() == Opcode::OP_VIEW) {
processStatus = MergeViewChain(function, op, chain);
} else if (op.GetOpcode() == Opcode::OP_ASSEMBLE) {
processStatus = MergeAssembleChain(function, op, chain);
}
if (processStatus != SUCCESS) {
APASS_LOG_ERROR_F(Elements::Function, "ProcessOperations failed.");
return processStatus;
}
}
Status status = AppendMergedViewOperations(function);
if (status != SUCCESS) {
APASS_LOG_ERROR_F(Elements::Function, "AppendMergedViewOperations phase failed.");
return status;
}
status = AppendMergedAssembleOperations(function);
if (status != SUCCESS) {
APASS_LOG_ERROR_F(Elements::Function, "AppendMergedAssembleOperations phase failed.");
return FAILED;
}
return status;
}
Status MergeViewAssembleUtils::AppendMergedViewOperations(Function& function)
{
for (auto& viewOp : viewOpToAppend_) {
auto attr =
std::make_shared<ViewOpAttribute>(viewOp.offset, viewOp.toType, viewOp.dynOffset, viewOp.dynValidShape);
if (!attr) {
APASS_LOG_ERROR_F(Elements::Function, "Failed to create ViewOpAttribute.");
return FAILED;
}
auto& mergedViewOp =
irBuilder_.CreateTensorOpStmt(function, Opcode::OP_VIEW, {viewOp.input}, {viewOp.output}, viewOp.span);
mergedViewOp.SetScopeInfo(viewOp.scopeInfo);
mergedViewOp.SetOpAttribute(attr);
if (viewOp.hasCopyInMode) {
mergedViewOp.SetAttr("op_attr_copy_in_mode", viewOp.copyInModeValue);
}
viewOp.output->UpdateDynValidShape(viewOp.dynValidShape);
}
return SUCCESS;
}
Status MergeViewAssembleUtils::AppendMergedAssembleOperations(Function& function)
{
for (const auto& assembleOp : assembleOpToAppend_) {
auto attr = std::make_shared<AssembleOpAttribute>(assembleOp.offset, assembleOp.dynOffset);
if (!attr) {
return FAILED;
}
auto& mergedAssembleOp = irBuilder_.CreateTensorOpStmt(
function, Opcode::OP_ASSEMBLE, {assembleOp.input}, {assembleOp.output}, assembleOp.span);
mergedAssembleOp.SetScopeInfo(assembleOp.scopeInfo);
mergedAssembleOp.SetOpAttribute(attr);
}
return SUCCESS;
}
Status MergeViewAssembleUtils::CleanUp(Function& function)
{
Status status = EraseRedundantAssemble(function);
if (status != SUCCESS) {
APASS_LOG_ERROR_F(Elements::Function, "EraseRedundantAssemble failed.");
return status;
}
DeadOperationEliminator eliminator;
eliminator.EliminateOperation(function, false, false);
function.SortOperations(SortOperationsMode::LIGHTWEIGHT);
return SUCCESS;
}
ir::Span MergeViewAssembleUtils::GetFirstSpan(const std::vector<Operation*>& chain)
{
ir::Span firstSpan;
for (auto* op : chain) {
auto loc = op->GetSpan();
if (!loc.IsUnknown()) {
firstSpan = loc;
break;
}
}
return firstSpan;
}
Operation::ScopeInfo MergeViewAssembleUtils::GetChainScopeInfo(const std::vector<Operation*>& chain)
{
for (auto* op : chain) {
if (op->GetScopeId() != -1) {
return op->GetScopeInfo();
}
}
return Operation::ScopeInfo();
}
Status MergeViewAssembleUtils::MergeViewChain(
Function& function, Operation& operation, std::vector<Operation*>& chain, int effectiveScopeId)
{
auto viewOpAttribute = std::dynamic_pointer_cast<ViewOpAttribute>(operation.GetOpAttribute());
InitOperationChain(operation, chain);
int newScopeId = operation.GetScopeId();
if (effectiveScopeId == -1 && newScopeId != -1) {
effectiveScopeId = newScopeId;
}
auto consumers = function.FindConsumers(operation);
bool chainEnd = true;
Status status = ProcessConsumerChain(function, consumers, chain, chainEnd, effectiveScopeId);
if (status != SUCCESS) {
return status;
}
if (chainEnd && chain.size() > 1) {
return ProcessChainEnd(function, chain);
}
return SUCCESS;
}
void MergeViewAssembleUtils::InitOperationChain(Operation& operation, std::vector<Operation*>& chain)
{
visitedOp_.insert(operation.opmagic);
chain.emplace_back(&operation);
}
Status MergeViewAssembleUtils::ProcessConsumerChain(
Function& function, const std::set<Operation*, LogicalTensor::CompareOp>& consumers, std::vector<Operation*>& chain,
bool& chainEnd, int effectiveScopeId)
{
if (consumers.empty()) {
return SUCCESS;
}
Operation* currentOp = chain.back();
auto currentViewAttr = std::dynamic_pointer_cast<ViewOpAttribute>(currentOp->GetOpAttribute());
if (!currentViewAttr) {
APASS_LOG_ERROR_F(Elements::Function, "Failed to get current view attribute.");
return FAILED;
}
MemoryType currentMemType = currentViewAttr->GetTo();
for (auto& op : consumers) {
if (!op) {
return FAILED;
}
if (op->GetOpcode() == Opcode::OP_VIEW) {
auto viewOpAttribute = std::dynamic_pointer_cast<ViewOpAttribute>(op->GetOpAttribute());
if (viewOpAttribute == nullptr) {
APASS_LOG_ERROR_F(Elements::Function, "View operation has null viewOpAttribute.");
return FAILED;
}
auto memoryTo = viewOpAttribute->GetTo();
bool canMerge = false;
if (currentMemType == MemoryType::MEM_UNKNOWN || currentMemType == memoryTo) {
canMerge = true;
}
if (canMerge) {
int consumerScopeId = op->GetScopeId();
if (effectiveScopeId != -1 && consumerScopeId != -1 && effectiveScopeId != consumerScopeId) {
chainEnd = true;
continue;
}
chainEnd = false;
Status status = MergeViewChain(function, *op, chain, effectiveScopeId);
if (status != SUCCESS) {
return status;
}
chain.pop_back();
} else {
chainEnd = true;
}
}
}
return SUCCESS;
}
Status MergeViewAssembleUtils::ProcessChainEnd(Function& function, std::vector<Operation*>& chain)
{
(void)function;
Operation* startOp = chain.front();
Operation* endOp = chain.back();
if (startOp->iOperand.empty()) {
APASS_LOG_ERROR_F(Elements::Function, "First operation in chain has no input operands.");
return FAILED;
}
if (endOp->oOperand.empty()) {
APASS_LOG_ERROR_F(Elements::Function, "Last operation in chain has no output operands.");
return FAILED;
}
auto& startTensor = startOp->iOperand.front();
auto& endTensor = endOp->oOperand.front();
if (!startTensor) {
APASS_LOG_ERROR_F(Elements::Function, "Null input tensor found for first operation in chain.");
return FAILED;
}
if (!endTensor) {
APASS_LOG_ERROR_F(Elements::Function, "Null output tensor found for last operation in chain.");
return FAILED;
}
std::vector<int64_t> newOffset;
std::vector<SymbolicScalar> newDynOffset;
std::vector<SymbolicScalar> newDynValidShape;
Status status = CalculateMergedOffsets(chain, newOffset, newDynOffset, newDynValidShape);
if (status != SUCCESS) {
return status;
}
ir::Span firstSpan = GetFirstSpan(chain);
Operation::ScopeInfo chainScopeInfo = GetChainScopeInfo(chain);
RecordMergedViewOperation(
endOp, startTensor, endTensor, newOffset, newDynOffset, newDynValidShape, firstSpan, chainScopeInfo);
endOp->oOperand.clear();
return SUCCESS;
}
Status MergeViewAssembleUtils::CalculateMergedOffsets(
const std::vector<Operation*>& chain, std::vector<int64_t>& newOffset, std::vector<SymbolicScalar>& newDynOffset,
std::vector<SymbolicScalar>& newDynValidShape)
{
for (size_t i = 0; i < chain.size(); ++i) {
const auto& view = chain[i];
if (!view) {
APASS_LOG_ERROR_F(Elements::Function, "Null view operation in chain.");
return FAILED;
}
auto viewOpAttribute = std::dynamic_pointer_cast<ViewOpAttribute>(view->GetOpAttribute());
if (!viewOpAttribute) {
APASS_LOG_ERROR_F(Elements::Function, "Failed to get ViewOpAttribute.");
return FAILED;
}
if (i == 0) {
newOffset = viewOpAttribute->GetFromOffset();
newDynOffset = viewOpAttribute->GetFromDynOffset();
if (!viewOpAttribute->GetToDynValidShape().empty()) {
newDynValidShape = viewOpAttribute->GetToDynValidShape();
}
continue;
}
auto ret = TensorOffset::Add(
newOffset, newDynOffset, viewOpAttribute->GetFromOffset(), viewOpAttribute->GetFromDynOffset());
if (!ret.first.empty()) {
newOffset = ret.first;
newDynOffset = ret.second;
}
if (!viewOpAttribute->GetToDynValidShape().empty()) {
newDynValidShape = viewOpAttribute->GetToDynValidShape();
continue;
}
newDynValidShape = GetViewValidShape(
newDynValidShape, viewOpAttribute->GetFromOffset(), viewOpAttribute->GetFromDynOffset(),
view->GetOOperands()[0]->GetShape());
}
return SUCCESS;
}
void MergeViewAssembleUtils::RecordMergedViewOperation(
Operation* lastViewOp, const std::shared_ptr<LogicalTensor>& startTensor,
const std::shared_ptr<LogicalTensor>& endTensor, const std::vector<int64_t>& newOffset,
const std::vector<SymbolicScalar>& newDynOffset, const std::vector<SymbolicScalar>& newDynValidShape,
const ir::Span& span, const Operation::ScopeInfo& scopeInfo)
{
auto lastViewAttr = std::dynamic_pointer_cast<ViewOpAttribute>(lastViewOp->GetOpAttribute());
if (!lastViewAttr) {
return;
}
int64_t copyInModeValue = 0;
bool hasCopyInMode = lastViewOp->GetAttr<int64_t>("op_attr_copy_in_mode", copyInModeValue);
endTensor->GetProducers().clear();
viewOpToAppend_.emplace_back(ViewOp{
startTensor, endTensor, newOffset, newDynOffset, newDynValidShape, lastViewAttr->GetTo(), hasCopyInMode,
std::move(copyInModeValue), span, scopeInfo});
}
Status MergeViewAssembleUtils::MergeAssembleChain(
Function& function, Operation& operation, std::vector<Operation*>& chain, int effectiveScopeId)
{
InitAssembleChain(operation, chain);
int newScopeId = operation.GetScopeId();
if (effectiveScopeId == -1 && newScopeId != -1) {
effectiveScopeId = newScopeId;
}
bool chainEnd = false;
bool hasAssembleConsumer = false;
if (assembleWithoutAssembleConsumer_.count(operation.opmagic) == 0) {
auto consumers = function.FindConsumers(operation);
chainEnd = consumers.empty();
Status status =
ProcessAssembleConsumers(function, consumers, chain, chainEnd, hasAssembleConsumer, effectiveScopeId);
if (status != SUCCESS) {
return status;
}
if (!hasAssembleConsumer) {
assembleWithoutAssembleConsumer_.insert(operation.opmagic);
}
} else {
chainEnd = true;
}
if (chainEnd && chain.size() > 1) {
Status status = ProcessAssembleChainEnd(function, chain, operation);
if (status != SUCCESS) {
return status;
}
}
chain.pop_back();
return SUCCESS;
}
void MergeViewAssembleUtils::InitAssembleChain(Operation& operation, std::vector<Operation*>& chain)
{
visitedOp_.insert(operation.opmagic);
chain.emplace_back(&operation);
}
Status MergeViewAssembleUtils::ProcessAssembleConsumers(
Function& function, const std::set<Operation*, LogicalTensor::CompareOp>& consumers, std::vector<Operation*>& chain,
bool& chainEnd, bool& hasAssembleConsumer, int effectiveScopeId)
{
if (consumers.empty()) {
return SUCCESS;
}
for (auto& op : consumers) {
if (!op) {
APASS_LOG_ERROR_F(Elements::Function, "Null consumer operation found.");
return FAILED;
}
if (op->GetOpcode() == Opcode::OP_ASSEMBLE) {
int consumerScopeId = op->GetScopeId();
if (effectiveScopeId != -1 && consumerScopeId != -1 && effectiveScopeId != consumerScopeId) {
chainEnd = true;
continue;
}
hasAssembleConsumer = true;
Status status = MergeAssembleChain(function, *op, chain, effectiveScopeId);
if (status != SUCCESS) {
APASS_LOG_ERROR_F(Elements::Function, "Run MergeAssembleChain failed.");
return status;
}
continue;
}
chainEnd = true;
}
return SUCCESS;
}
Status MergeViewAssembleUtils::ProcessAssembleChainEnd(
Function& function, std::vector<Operation*>& chain, Operation& operation)
{
(void)function;
if (chain.front()->iOperand.empty() || chain.back()->oOperand.empty()) {
APASS_LOG_ERROR_F(Elements::Function, "Invalid chain operations.");
return FAILED;
}
auto& startTensor = chain.front()->iOperand.front();
auto& endTensor = chain.back()->oOperand.front();
if (!startTensor || !endTensor) {
APASS_LOG_ERROR_F(Elements::Function, "Null tensor found in chain.");
return FAILED;
}
auto [newOffset, newDynOffset] = CalculateAssembleOffsets(chain, startTensor->offset.size());
ir::Span firstSpan = GetFirstSpan(chain);
Operation::ScopeInfo chainScopeInfo = GetChainScopeInfo(chain);
RecordAssembleOperation(startTensor, endTensor, newOffset, newDynOffset, firstSpan, chainScopeInfo);
operation.SetAsDeleted();
return SUCCESS;
}
std::pair<std::vector<int64_t>, std::vector<SymbolicScalar>> MergeViewAssembleUtils::CalculateAssembleOffsets(
const std::vector<Operation*>& chain, size_t offsetSize)
{
std::vector<int64_t> newOffset(offsetSize, 0);
std::vector<SymbolicScalar> newDynOffset;
for (size_t i = 0; i < chain.size(); ++i) {
const auto& assemble = chain[i];
if (!assemble) {
return {};
}
auto assembleOpAttribute = std::dynamic_pointer_cast<AssembleOpAttribute>(assemble->GetOpAttribute());
if (!assembleOpAttribute) {
return {};
}
if (i == 0) {
newOffset = assembleOpAttribute->GetToOffset();
newDynOffset = assembleOpAttribute->GetToDynOffset();
continue;
}
auto ret = TensorOffset::Add(
newOffset, newDynOffset, assembleOpAttribute->GetToOffset(), assembleOpAttribute->GetToDynOffset());
if (!ret.first.empty()) {
newOffset = ret.first;
newDynOffset = ret.second;
}
}
return {newOffset, newDynOffset};
}
void MergeViewAssembleUtils::RecordAssembleOperation(
const std::shared_ptr<LogicalTensor>& input, const std::shared_ptr<LogicalTensor>& output,
const std::vector<int64_t>& offset, const std::vector<SymbolicScalar>& dynOffset, const ir::Span& span,
const Operation::ScopeInfo& scopeInfo)
{
assembleOpToAppend_.emplace_back(AssembleOp{input, output, offset, dynOffset, span, scopeInfo});
}
Status MergeViewAssembleUtils::EraseRedundantAssemble(Function& function) const
{
std::unordered_set<Operation*> redundantAssembles;
for (auto& op : function.Operations(false)) {
if (op.GetOpcode() != Opcode::OP_ASSEMBLE) {
continue;
}
if (op.iOperand.empty()) {
APASS_LOG_ERROR_F(Elements::Function, "Assemble operation with no input operands.");
return FAILED;
}
if (op.iOperand.front()->GetProducers().empty()) {
redundantAssembles.emplace(&op);
}
}
for (auto& ele : redundantAssembles) {
if (!ele) {
continue;
}
ele->SetAsDeleted();
}
function.EraseOperations(true, false);
return SUCCESS;
}
}