* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file tensormap.cpp
* \brief
*/
#include "interface/tensor/tensormap.h"
#include <sstream>
#include <unordered_map>
#include <vector>
#include <memory>
#include <string>
#include <algorithm>
#include <set>
#include <numeric>
#include "tilefwk/tilefwk.h"
#include "interface/inner/tilefwk.h"
#include "interface/program/program.h"
#include "interface/function/function.h"
#include "interface/tensor/raw_tensor.h"
#include "interface/tensor/logical_tensor.h"
namespace npu::tile_fwk {
std::vector<int> tensor2shape(const std::shared_ptr<LogicalTensor>& pTensor)
{
    // Flattens the tensor's region into per-dimension [begin, end) pairs:
    // {begin0, end0, begin1, end1, ...}, with end = offset + shape.
    // The number of dimensions is taken from `offset`; values are narrowed to int,
    // which is the element type the overlap checker is fed with elsewhere in this file.
    const auto& offsets = pTensor->offset;
    const auto& extents = pTensor->shape;
    std::vector<int> bounds;
    bounds.reserve(offsets.size() * 2);
    for (size_t axis = 0; axis != offsets.size(); ++axis) {
        const int begin = static_cast<int>(offsets[axis]);
        bounds.push_back(begin);
        bounds.push_back(static_cast<int>(begin + extents[axis]));
    }
    return bounds;
}
bool Overlap(const std::shared_ptr<LogicalTensor>& t0, const std::shared_ptr<LogicalTensor>& t1)
{
return CalcOverlap(t0, t1) != OverlapStatus::NO_OVER_LAP;
}
bool TensorPtrComparator::operator()(
    const std::shared_ptr<LogicalTensor>& lhs, const std::shared_ptr<LogicalTensor>& rhs) const
{
    // Strict weak ordering on `magic`, largest magic first (descending order).
    return rhs->magic < lhs->magic;
}
/*
 * Registers `tobject` in the map, indexed by its raw tensor's rawmagic, in the per-rawmagic
 * overlap checker, and in the magic -> tensor reverse index.
 *
 * When `checkOverlap` is true and Find() reports any tensor overlapping `tobject`, the insert is
 * skipped (only an info log is emitted). Any tensor previously registered under the same magic
 * is unregistered first so the indices do not keep a stale copy.
 */
void TensorMap::Insert(std::shared_ptr<LogicalTensor> tobject, bool checkOverlap)
{
    FE_ASSERT(tobject != nullptr) << "tobject is nullptr.";
    // Find() below dereferences tobject->tensor, so this must be validated before the lookup,
    // otherwise a null raw tensor crashes instead of tripping the assertion.
    FE_ASSERT(tobject->tensor != nullptr) << "tobject->tensor is nullptr.";
    std::vector<int> tensorShape = tensor2shape(tobject);
    if (checkOverlap) {
        auto match = Find(tobject);
        if (!match.empty()) {
            FE_LOGI("Tensor %d is full covered in function %s", tobject->magic, belongTo.GetRawName().c_str());
            return;
        }
    }
    int rawmagic = tobject->tensor->rawmagic;
    auto& tensorList = tensorMap_[rawmagic];

    auto previous = inverseMap_.find(tobject->magic);
    if (previous != inverseMap_.end()) {
        const auto& rt = previous->second;
        int rrawmagic = rt->tensor->rawmagic;
        // Remove the previous entry for this magic from the target list (as before) and also from
        // the list of the raw tensor it was actually registered under. The two differ when the
        // raw tensor changed; without this the old list keeps a stale entry while the overlap
        // checker (cleaned under rrawmagic below) no longer knows about it. find() is used so no
        // empty bucket is created, since Find() treats an existing bucket as "known locally".
        tensorList.erase(rt);
        if (rrawmagic != rawmagic) {
            auto oldList = tensorMap_.find(rrawmagic);
            if (oldList != tensorMap_.end()) {
                oldList->second.erase(rt);
            }
        }
        std::vector<int> rtensorShape = tensor2shape(rt);
        overlapChecker_[rrawmagic].Erase(rtensorShape, rt);
    }

    if (checkOverlap) {
        if (belongTo.expandFunctionAccelerate) {
            // `match` is a copy, so erasing from the indices while walking it is safe.
            auto match = Find(tobject);
            for (const auto& dup : match) {
                if (dup->shape == tobject->shape && dup->offset == tobject->offset) {
                    tensorList.erase(dup);
                    std::vector<int> tShape = tensor2shape(dup);
                    overlapChecker_[dup->tensor->rawmagic].Erase(tShape, dup);
                }
            }
        } else {
            // Drop entries describing exactly the same region of the same raw tensor.
            for (auto it = tensorList.begin(); it != tensorList.end();) {
                if ((*it)->shape == tobject->shape && (*it)->offset == tobject->offset) {
                    it = tensorList.erase(it);
                } else {
                    ++it;
                }
            }
        }
    }
    tensorList.insert(tobject);
    overlapChecker_[rawmagic].Insert(tensorShape, tobject);
    inverseMap_[tobject->magic] = tobject;
}
void TensorMap::Erase(const std::shared_ptr<LogicalTensor>& ttensor)
{
if (tensorMap_.count(ttensor->tensor->rawmagic) != 0) {
tensorMap_[ttensor->tensor->rawmagic].erase(ttensor);
}
if (inverseMap_.count(ttensor->magic) != 0) {
inverseMap_.erase(ttensor->magic);
}
std::vector<int> tensorShape = tensor2shape(ttensor);
int rawmagic = ttensor->tensor->rawmagic;
overlapChecker_[rawmagic].Erase(tensorShape, ttensor);
}
void TensorMap::EraseRawMagic(int rawmagic)
{
tensorMap_.erase(rawmagic);
overlapChecker_.erase(rawmagic);
}
std::shared_ptr<LogicalTensor> TensorMap::GetTensorByMagic(int magic) const
{
    // Reverse lookup in the magic -> tensor index (Insert overwrites the slot, Erase removes it).
    // Yields nullptr when no tensor is registered under `magic`.
    const auto hit = inverseMap_.find(magic);
    if (hit == inverseMap_.end()) {
        return nullptr;
    }
    return hit->second;
}
std::shared_ptr<RawTensor> TensorMap::GetRawTensorByRawMagic(int rawMagic) const
{
    // Returns the raw tensor of the first logical tensor indexed under `rawMagic`,
    // or nullptr when that raw magic is unknown or its tensor list is empty.
    const auto bucket = tensorMap_.find(rawMagic);
    if (bucket != tensorMap_.end() && !bucket->second.empty()) {
        return (*bucket->second.begin())->tensor;
    }
    return nullptr;
}
void CalcShapeAndOffsetOfGroup(
const std::vector<std::shared_ptr<LogicalTensor>>& tensors, std::vector<int64_t>& resultOffset,
std::vector<int64_t>& resultShape)
{
resultOffset = tensors.front()->offset;
std::vector<int64_t> maximumOffset;
for (size_t i = 0; i < tensors.front()->offset.size(); i++) {
maximumOffset.emplace_back(tensors.front()->offset[i] + tensors.front()->shape[i]);
}
for (const auto& incast : tensors) {
for (size_t i = 0; i < incast->offset.size(); ++i) {
resultOffset[i] = std::min(resultOffset[i], incast->offset[i]);
maximumOffset[i] = std::max(maximumOffset[i], incast->offset[i] + incast->shape[i]);
}
}
for (size_t i = 0; i < resultOffset.size(); ++i) {
resultShape.emplace_back(maximumOffset[i] - resultOffset[i]);
}
};
void TensorMap::Reset()
{
    // Empties every index kept by the map: rawmagic -> tensors,
    // magic -> tensor, and rawmagic -> overlap checker (cleared in that order).
    tensorMap_.clear();
    inverseMap_.clear();
    overlapChecker_.clear();
}
/*
 * Returns the number of elements in the intersection of the regions covered by pTensor and
 * qTensor (per dimension: min(ends) - max(starts), multiplied across dimensions).
 *
 * Returns 0 when the regions do not intersect in some dimension, or when the ranks of the two
 * tensors' offset/shape vectors disagree. Arithmetic is done in 64 bits; the final value is
 * narrowed to int to keep the existing signature.
 */
int CalcOverlapSize(const std::shared_ptr<LogicalTensor>& pTensor, const std::shared_ptr<LogicalTensor>& qTensor)
{
    const size_t rank = pTensor->offset.size();
    if (pTensor->shape.size() != rank || qTensor->offset.size() != rank || qTensor->shape.size() != rank) {
        return 0;
    }
    int64_t volume = 1;
    for (size_t dim = 0; dim < rank; ++dim) {
        const int64_t pStart = pTensor->offset[dim];
        const int64_t pEnd = pStart + pTensor->shape[dim];
        const int64_t qStart = qTensor->offset[dim];
        const int64_t qEnd = qStart + qTensor->shape[dim];
        // For disjoint intervals the previous sort-based formula yielded the gap length;
        // the intersection length is clamped at zero instead.
        const int64_t edge = std::min(pEnd, qEnd) - std::max(pStart, qStart);
        if (edge <= 0) {
            return 0;
        }
        volume *= edge;
    }
    return static_cast<int>(volume);
}
OverlapStatus CalcOverlapByOffsetShape(
    const std::vector<int64_t>& pOffset, const std::vector<int64_t>& pShape, const std::vector<int64_t>& qOffset,
    const std::vector<int64_t>& qShape) noexcept
{
    // Classifies how region P (pOffset/pShape) relates to region Q (qOffset/qShape).
    // Regions of different rank, or with offset/shape rank mismatch, are treated as disjoint.
    const size_t rank = pOffset.size();
    const bool ranksAgree = qOffset.size() == rank && pShape.size() == qShape.size() && pShape.size() == rank;
    if (!ranksAgree) {
        return OverlapStatus::NO_OVER_LAP;
    }
    bool identical = true;
    bool pContainsQ = true;
    bool qContainsP = true;
    for (size_t axis = 0; axis < rank; ++axis) {
        // Inclusive bounds on this axis.
        const int64_t pLo = pOffset[axis];
        const int64_t pHi = pLo + pShape[axis] - 1;
        const int64_t qLo = qOffset[axis];
        const int64_t qHi = qLo + qShape[axis] - 1;
        if (qHi < pLo || pHi < qLo) {
            return OverlapStatus::NO_OVER_LAP;
        }
        identical = identical && pLo == qLo && pHi == qHi;
        pContainsQ = pContainsQ && pLo <= qLo && qHi <= pHi;
        qContainsP = qContainsP && qLo <= pLo && pHi <= qHi;
    }
    if (identical) {
        return OverlapStatus::PERFECTLY_MATCH;
    }
    if (pContainsQ) {
        return OverlapStatus::COVERED;
    }
    if (qContainsP) {
        return OverlapStatus::BE_COVERED;
    }
    return OverlapStatus::PARTIAL_OVERLAP;
}
OverlapStatus CalcOverlap(
    const std::shared_ptr<LogicalTensor>& pTensor, const std::shared_ptr<LogicalTensor>& qTensor, bool loose)
{
    // A missing tensor never overlaps anything.
    if (pTensor == nullptr || qTensor == nullptr) {
        return OverlapStatus::NO_OVER_LAP;
    }
    // In strict mode only views of the same raw tensor can overlap; `loose` compares
    // offsets/shapes regardless of the backing raw tensor (rawmagic is read only when strict).
    const bool comparable = loose || pTensor->tensor->rawmagic == qTensor->tensor->rawmagic;
    if (!comparable) {
        return OverlapStatus::NO_OVER_LAP;
    }
    return CalcOverlapByOffsetShape(pTensor->offset, pTensor->shape, qTensor->offset, qTensor->shape);
}
/*
 * Classifies how pTensor relates to a group of tensors.
 *
 * - Empty group: NO_OVER_LAP. Single member: the pairwise CalcOverlap result.
 * - Otherwise, any member disjoint from pTensor yields NO_OVER_LAP, and any member exactly
 *   matching pTensor yields BE_COVERED_BY_ALL. The remaining members' overlap volumes are summed
 *   and compared with pTensor's own volume:
 *     all members COVERED by pTensor -> PERFECTLY_MATCH_WITH_ALL if the sums agree, else COVERED_ALL;
 *     otherwise                      -> BE_COVERED_BY_ALL if the sums agree, else PARTIAL_OVERLAP_WITH_ALL.
 * The summation assumes group members do not overlap one another (overlaps are counted twice).
 */
OverlapStatus CalcOverlap(
    const std::shared_ptr<LogicalTensor>& pTensor, const std::vector<std::shared_ptr<LogicalTensor>>& pGroup,
    bool loose)
{
    if (pGroup.empty()) {
        return OverlapStatus::NO_OVER_LAP;
    }
    if (pGroup.size() == 1) {
        return CalcOverlap(pTensor, pGroup.front(), loose);
    }
    // Element counts are accumulated in 64 bits: one large tensor can exceed INT_MAX elements,
    // which overflowed the previous int accumulators.
    int64_t overlapSize = 0;
    bool coveredAll = true;
    for (const auto& other : pGroup) {
        const auto subStatus = CalcOverlap(pTensor, other, loose);
        coveredAll &= subStatus == OverlapStatus::COVERED;
        switch (subStatus) {
            case OverlapStatus::NO_OVER_LAP:
                return OverlapStatus::NO_OVER_LAP;
            case OverlapStatus::PERFECTLY_MATCH:
                return OverlapStatus::BE_COVERED_BY_ALL;
            case OverlapStatus::COVERED:
                overlapSize +=
                    std::accumulate(other->shape.begin(), other->shape.end(), int64_t{1}, std::multiplies<>());
                break;
            case OverlapStatus::PARTIAL_OVERLAP:
                overlapSize += CalcOverlapSize(pTensor, other);
                break;
            default:
                break;
        }
    }
    const int64_t pSize =
        std::accumulate(pTensor->shape.begin(), pTensor->shape.end(), int64_t{1}, std::multiplies<>());
    if (coveredAll) {
        if (pSize == overlapSize) {
            return OverlapStatus::PERFECTLY_MATCH_WITH_ALL;
        }
        // Every member lies inside pTensor. A summed size above pSize only means members overlap
        // each other; previously this case fell through to NO_OVER_LAP, which is never true here.
        return OverlapStatus::COVERED_ALL;
    }
    if (pSize == overlapSize) {
        return OverlapStatus::BE_COVERED_BY_ALL;
    }
    return OverlapStatus::PARTIAL_OVERLAP_WITH_ALL;
}
std::vector<std::shared_ptr<LogicalTensor>> TensorMap::Find(std::shared_ptr<LogicalTensor> ttensor)
{
std::vector<std::shared_ptr<LogicalTensor>> result;
auto it = tensorMap_.find(ttensor->tensor->rawmagic);
if (it == tensorMap_.end()) {
if (!belongTo.HasParent() ||
belongTo.IsFunctionTypeAndGraphType(FunctionType::STATIC, GraphType::EXECUTE_GRAPH)) {
return {};
}
return belongTo.Parent().GetTensorMap().Find(ttensor);
}
if (belongTo.expandFunctionAccelerate) {
std::vector<int> tensorShape = tensor2shape(ttensor);
int rawmagic = ttensor->tensor->rawmagic;
result = overlapChecker_[rawmagic].Find(tensorShape);
} else {
for (const auto& tensorPtr : it->second) {
bool overlap = Overlap(tensorPtr, ttensor);
if (!overlap) {
continue;
}
result.push_back(tensorPtr);
}
}
return result;
}
void TensorMap::ValidCheck() const
{
    // Currently a no-op: no consistency checks are performed on the indices.
}
}