* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file graph_utils.cpp
* \brief
*/
#include "graph_utils.h"
#include "pass_utils.h"
#include "interface/tensor/irbuilder.h"
#include "passes/pass_utils/pass_operation_utils.h"
namespace npu {
namespace tile_fwk {
// Assigns dynamic valid shapes to the outputs of `newOp`.
// When `outDynShape` is empty, the shapes are derived by the registered infer-shape function.
// Otherwise outDynShape[i] is applied to output operand i.
void GraphUtils::SetDynShape(Operation* newOp, const std::vector<std::vector<SymbolicScalar>>& outDynShape)
{
    if (outDynShape.empty()) {
        InferShapeRegistry::GetInstance().CallInferShapeFunc(newOp);
        return;
    }
    const auto& outputs = newOp->GetOOperands();
    // Bound by both sizes: a caller passing fewer shapes than the op has outputs
    // must not make us index past the end of outDynShape (undefined behavior).
    for (size_t i = 0; i < outputs.size() && i < outDynShape.size(); ++i) {
        outputs[i]->UpdateDynValidShape(outDynShape[i]);
    }
}
// Creates an operation of type `opCode` in `function` with the given operands.
// Its output valid shapes are then set via SetDynShape (inferred when `outDynShape` is empty).
Operation& GraphUtils::AddDynOperation(
    Function& function, const Opcode opCode, LogicalTensors iOperands, const LogicalTensors& oOperands,
    const std::vector<std::vector<SymbolicScalar>>& outDynShape)
{
    auto& created = PassOperationUtils::AddOperation(
        function, opCode, std::move(iOperands), oOperands, nullptr, ir::Span::Unknown(), false);
    SetDynShape(&created, outDynShape);
    return created;
}
// Builds an OP_ASSEMBLE op from `assemble.input` to `assemble.output`.
// If an origin op is given, its scope info and attributes are copied first.
// The assemble attribute and the output dynamic shapes are applied afterwards.
Operation& GraphUtils::AddAssembleOperation(
    Function& function, const AssembleOp& assemble, const std::vector<std::vector<SymbolicScalar>>& outDynShape)
{
    IRBuilder irBuilder;
    auto& assembleOp =
        irBuilder.CreateTensorOpStmt(function, Opcode::OP_ASSEMBLE, {assemble.input}, {assemble.output});
    const auto& origin = assemble.originOp;
    if (origin != nullptr) {
        assembleOp.SetScopeInfo(origin->GetScopeInfo());
        assembleOp.CopyAttrFrom(*origin, "");
    }
    // Must follow CopyAttrFrom so the assemble-specific attribute is the one that sticks.
    SetAssembleAttr(assembleOp, assemble);
    SetDynShape(&assembleOp, outDynShape);
    return assembleOp;
}
// Adds an OP_RESHAPE op from `iOperand` to `oOperand`, optionally inheriting scope/attrs from
// `reshapeOp.originOpPtr`.
// - With an explicit `outDynShape`: that shape is stored as the op's validShape attribute and
//   pushed to `oOperand`.
// - Otherwise: shape inference runs. If it leaves no (or an empty) validShape attribute, the
//   output tensor's current dynamic valid shape is recorded instead.
Operation& GraphUtils::AddReshapeOperation(
    Function& function, const LogicalTensorPtr iOperand, const LogicalTensorPtr& oOperand, const ReshapeOp& reshapeOp,
    const std::vector<SymbolicScalar>& outDynShape)
{
    auto& reshape = PassOperationUtils::AddOperation(
        function, Opcode::OP_RESHAPE, {iOperand}, {oOperand}, nullptr, ir::Span::Unknown(), false);
    const auto& origin = reshapeOp.originOpPtr;
    if (origin != nullptr) {
        reshape.SetScopeInfo(origin->GetScopeInfo());
        reshape.CopyAttrFrom(*origin, "");
    }
    if (!outDynShape.empty()) {
        reshape.SetAttribute(OP_ATTR_PREFIX + "validShape", outDynShape);
        oOperand->UpdateDynValidShape(outDynShape);
        return reshape;
    }
    InferShapeRegistry::GetInstance().CallInferShapeFunc(&reshape);
    std::vector<SymbolicScalar> inferred;
    const bool hasInferredShape = reshape.GetAttr(OP_ATTR_PREFIX + "validShape", inferred) && !inferred.empty();
    if (!hasInferredShape) {
        reshape.SetAttribute(OP_ATTR_PREFIX + "validShape", oOperand->GetDynValidShape());
    }
    return reshape;
}
// Copies the dynamic valid shape of `srcTensor` onto `dstTensor`.
void GraphUtils::CopyDynStatus(const LogicalTensorPtr& dstTensor, const LogicalTensorPtr& srcTensor)
{
    const auto& srcValidShape = srcTensor->GetDynValidShape();
    dstTensor->UpdateDynValidShape(srcValidShape);
}
// For a view op whose input comes from an in-cast or whose output goes to an out-cast,
// makes sure the ViewOpAttribute carries a dynamic from-offset.
// If none is set, one is built from the static from-offset as constant scalars.
// Ops without a ViewOpAttribute are left untouched.
void GraphUtils::UpdateViewAttr(Function& function, Operation& op)
{
    auto viewAttribute = dynamic_cast<ViewOpAttribute*>(op.GetOpAttribute().get());
    if (viewAttribute == nullptr) {
        // Not a view attribute (or no attribute at all): nothing to update, and
        // dereferencing the failed cast would be undefined behavior.
        return;
    }
    const auto& inputs = op.GetIOperands();
    const auto& outputs = op.GetOOperands();
    // Guard front() on empty operand lists.
    LogicalTensorPtr input = inputs.empty() ? LogicalTensorPtr() : inputs.front();
    // Fixed: the output must come from the output operands, not the inputs.
    LogicalTensorPtr output = outputs.empty() ? LogicalTensorPtr() : outputs.front();
    const bool touchesBoundary = (input != nullptr && function.IsFromInCast(input)) ||
                                 (output != nullptr && function.IsFromOutCast(output));
    if (!touchesBoundary || !viewAttribute->GetFromDynOffset().empty()) {
        return;
    }
    std::vector<int64_t> fromOffset = viewAttribute->GetFromOffset();
    std::vector<SymbolicScalar> fromDynOffset = CommonUtils::CreateConstIntVector(fromOffset);
    viewAttribute->SetFromOffset(fromOffset, fromDynOffset);
}
// Replaces the attribute of `op` with a new AssembleOpAttribute.
// The attribute is built from `assemble.from`/`assemble.toOffset`.
// Its from-valid-shape is taken from the assemble input tensor.
void GraphUtils::SetAssembleAttr(Operation& op, const AssembleOp& assemble)
{
    auto attribute = std::make_shared<AssembleOpAttribute>(assemble.from, assemble.toOffset);
    auto inputValidShape = assemble.input->GetDynValidShape();
    attribute->SetFromDynValidShape(inputValidShape);
    op.SetOpAttribute(attribute);
}
// True when the current SoC reports the DAV_3510 NPU architecture.
bool GraphUtils::IsCVMixPlatform()
{
    const auto arch = Platform::Instance().GetSoc().GetNPUArch();
    return arch == NPUArch::DAV_3510;
}
// Collects every non-null logical tensor of `function` whose raw tensor has `rawmagic == rawMagic`.
// The search covers in-casts, out-casts, op outputs and op inputs, in that order.
TensorSet GraphUtils::GetTensorsByRawMagic(Function& function, int64_t rawMagic)
{
    TensorSet matches;
    const auto collect = [&matches, rawMagic](const auto& tensors) {
        for (const auto& candidate : tensors) {
            if (!candidate || !candidate->tensor) {
                continue;
            }
            if (candidate->tensor->rawmagic == rawMagic) {
                matches.insert(candidate);
            }
        }
    };
    collect(function.inCasts_);
    collect(function.outCasts_);
    for (auto& op : function.Operations(false)) {
        collect(op.GetOOperands());
    }
    for (auto& op : function.Operations(false)) {
        collect(op.GetIOperands());
    }
    return matches;
}
// Returns the raw tensor of the first logical tensor (in TensorSet iteration order) that has `rawMagic`.
// Returns null when no such tensor exists.
std::shared_ptr<RawTensor> GraphUtils::GetRawTensorByRawMagic(Function& function, int64_t rawMagic)
{
    const auto matches = GetTensorsByRawMagic(function, rawMagic);
    const auto first = matches.begin();
    if (first == matches.end() || *first == nullptr) {
        return nullptr;
    }
    return (*first)->tensor;
}
// Collects every non-null logical tensor of `function` whose raw tensor has
// `actualRawmagic == actualRawMagic`.
// The search covers in-casts, out-casts, op outputs and op inputs, in that order.
TensorSet GraphUtils::GetTensorsByActualRawMagic(Function& function, int64_t actualRawMagic)
{
    TensorSet found;
    const auto sharesActualRaw = [actualRawMagic](const auto& candidate) {
        return candidate && candidate->tensor && candidate->tensor->actualRawmagic == actualRawMagic;
    };
    for (const auto& t : function.inCasts_) {
        if (sharesActualRaw(t)) {
            found.insert(t);
        }
    }
    for (const auto& t : function.outCasts_) {
        if (sharesActualRaw(t)) {
            found.insert(t);
        }
    }
    for (auto& op : function.Operations(false)) {
        for (const auto& t : op.GetOOperands()) {
            if (sharesActualRaw(t)) {
                found.insert(t);
            }
        }
    }
    for (auto& op : function.Operations(false)) {
        for (const auto& t : op.GetIOperands()) {
            if (sharesActualRaw(t)) {
                found.insert(t);
            }
        }
    }
    return found;
}
// Finds tensors sharing `tensor`'s raw tensor whose region overlaps it (excluding `tensor` itself by magic).
// If `function` holds no tensor with that raw magic, the search moves up through parent functions.
// It stops at the first scope that has candidates, at a function without a parent, or at a
// STATIC/EXECUTE_GRAPH function.
std::vector<LogicalTensorPtr> GraphUtils::FindOverlappedTensors(Function& function, const LogicalTensorPtr& tensor)
{
    if (!tensor || !tensor->tensor) {
        return {};
    }
    const auto rawMagic = tensor->tensor->rawmagic;
    Function* scope = &function;
    auto candidates = GraphUtils::GetTensorsByRawMagic(*scope, rawMagic);
    while (candidates.empty()) {
        const bool canAscend = scope->HasParent() &&
            !scope->IsFunctionTypeAndGraphType(FunctionType::STATIC, GraphType::EXECUTE_GRAPH);
        if (!canAscend) {
            return {};
        }
        scope = &scope->Parent();
        candidates = GraphUtils::GetTensorsByRawMagic(*scope, rawMagic);
    }
    std::vector<LogicalTensorPtr> overlapped;
    for (const auto& candidate : candidates) {
        const bool usable = candidate && candidate->tensor && candidate->magic != tensor->magic;
        if (usable && CalcOverlap(candidate, tensor) != OverlapStatus::NO_OVER_LAP) {
            overlapped.push_back(candidate);
        }
    }
    return overlapped;
}
// Returns the first non-null logical tensor whose magic equals `magic`, or null if none matches.
// Search order: in-casts, out-casts, every op's outputs, then every op's inputs.
LogicalTensorPtr GraphUtils::GetTensorByMagic(Function& function, int magic)
{
    const auto firstWithMagic = [magic](const auto& tensors) -> LogicalTensorPtr {
        for (const auto& candidate : tensors) {
            if (candidate && candidate->GetMagic() == magic) {
                return candidate;
            }
        }
        return nullptr;
    };
    if (auto hit = firstWithMagic(function.inCasts_)) {
        return hit;
    }
    if (auto hit = firstWithMagic(function.outCasts_)) {
        return hit;
    }
    for (auto& op : function.Operations(false)) {
        if (auto hit = firstWithMagic(op.GetOOperands())) {
            return hit;
        }
    }
    for (auto& op : function.Operations(false)) {
        if (auto hit = firstWithMagic(op.GetIOperands())) {
            return hit;
        }
    }
    return nullptr;
}
// Gathers every non-null logical tensor referenced by `function`.
// The sources are in-casts, out-casts, op outputs and op inputs, inserted in that order.
TensorSet GraphUtils::GetAllTensors(Function& function)
{
    TensorSet everything;
    const auto absorb = [&everything](const auto& tensors) {
        for (const auto& item : tensors) {
            if (!item) {
                continue;
            }
            everything.insert(item);
        }
    };
    absorb(function.inCasts_);
    absorb(function.outCasts_);
    for (auto& op : function.Operations(false)) {
        absorb(op.GetOOperands());
    }
    for (auto& op : function.Operations(false)) {
        absorb(op.GetIOperands());
    }
    return everything;
}
}
}