* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef ACL_UTILS_ACL_OP_MAP_H
#define ACL_UTILS_ACL_OP_MAP_H
#include <map>
#include <mutex>
#include <vector>
#include <climits>
#include "types/acl_op_inner.h"
#include "utils/attr_utils.h"
#include "utils/hash_utils.h"
#include <chrono>
#include <unordered_map>
namespace acl {
/**
 * @brief Mutex-guarded cache of operator entries keyed by an operator hash ("seed").
 *
 * Entries whose seeds collide share one bucket (a std::vector<T>). cnt tracks the total number of
 * cached entries; once an insertion pushes it above maxOpNum, the entry with the smallest timestamp
 * is evicted and handed back to the caller. A timestamp equal to ULLONG_MAX acts as a sentinel:
 * Updatetimestamp leaves it unchanged and inserting such an entry never triggers aging.
 *
 * @tparam T handle used through operator->; the implementation reads ->timestamp, ->opType, ->seq,
 *           ->inputDescArr and ->outputDescArr. Presumably a smart pointer to a model definition —
 *           TODO confirm at the instantiation site.
 */
template<typename T>
class AclOpMap {
public:
    // Caches entry for a static op; if an equivalent entry is already cached, refreshes its timestamp instead.
    aclError Insert(const AclOp &op, const T &entry, T &agingT, bool &isDeduplicate);
    // Dynamic-shape counterpart of Insert; hashing and matching are keyed by entry->seq.
    aclError InsertDynamic(const AclOp &op, const T &entry, T &agingT, bool &isDeduplicate);
    // Finds the cached entry matching op; ACL_ERROR_OP_NOT_FOUND when there is none.
    aclError Get(const AclOp &op, T &entry, const bool needUpdateTimestamp = false);
    // Dynamic-shape counterpart of Get; hashing and matching are keyed by seq.
    aclError GetDynamic(const AclOp &op, T &entry, const uint64_t seq, const bool needUpdateTimestamp = false);

    /**
     * @brief Sets the entry count above which an insertion evicts the oldest cached entry.
     *
     * Takes mutex_ because AddMemAndAging reads maxOpNum while holding that lock (from Insert /
     * InsertDynamic); writing it unlocked would race with concurrent insertions.
     */
    void SetMaxOpNum(const uint64_t maxNum)
    {
        const std::lock_guard<std::mutex> lk(mutex_);
        maxOpNum = maxNum;
    }

private:
    // Evicts the entry with the smallest timestamp into agingT. Caller holds mutex_.
    aclError Aging(T &agingT);
    // Refreshes entry->timestamp unless it holds the ULLONG_MAX sentinel.
    void Updatetimestamp(T &entry) const;
    // Returns true (and refreshes the cached entry) if an equivalent entry already sits under seed.
    bool Deduplicate(const AclOp &op, const aclopAttr *const attr, const T &entry,
        const size_t seed, const bool isDynamic);
    // Appends entry under seed and evicts the oldest entry when the count exceeds maxOpNum.
    aclError AddMemAndAging(const T &entry, T &agingT, const size_t seed);
    // Checks op's host-memory tensor data against the value ranges recorded in entry.
    bool CheckValueRange(const AclOp &op, const T &entry) const;
    // Checks that op declares the same tensor counts and value ranges as entry.
    bool HasSameValueRange(const AclOp &op, const T &entry) const;

    using HashMap = std::unordered_map<size_t, std::vector<T>>;
    HashMap hashMap_;                           // seed -> entries sharing that seed
    mutable std::mutex mutex_;                  // guards hashMap_, cnt and maxOpNum
    uint64_t cnt{0U};                           // total number of entries across all buckets
    uint64_t maxOpNum{DEFAULT_MAX_OPQUEUE_NUM}; // aging threshold for cnt
};
/**
 * @brief Evicts the entry with the smallest timestamp from the cache.
 *
 * Called with mutex_ held (its only caller, AddMemAndAging, runs under the lock taken in
 * Insert / InsertDynamic). Ties are resolved in favour of the first entry visited. Because the
 * comparison is strict, entries stamped with the ULLONG_MAX sentinel are never selected.
 * A bucket emptied by the eviction is removed from the map.
 *
 * @param[out] agingT receives the evicted entry; left untouched when nothing is evictable.
 * @return ACL_SUCCESS in every case.
 */
template<typename T>
aclError AclOpMap<T>::Aging(T &agingT)
{
    // ULLONG_MAX rather than ULONG_MAX: timestamps are uint64_t, but unsigned long is only 32 bits
    // on ILP32/LLP64 targets, where the smaller sentinel would hide every larger timestamp from aging.
    uint64_t timestampMin = static_cast<uint64_t>(ULLONG_MAX);
    typename HashMap::iterator itHashMapMin = hashMap_.end();
    typename std::vector<T>::iterator itVectorMin{};
    bool found = false;
    for (auto hashMapIter = hashMap_.begin(); hashMapIter != hashMap_.end(); ++hashMapIter) {
        auto &bucket = hashMapIter->second;
        for (auto vecIter = bucket.begin(); vecIter != bucket.end(); ++vecIter) {
            if ((*vecIter)->timestamp < timestampMin) {
                timestampMin = (*vecIter)->timestamp;
                itHashMapMin = hashMapIter;
                itVectorMin = vecIter;
                found = true;
            }
        }
    }
    if (!found) {
        ACL_LOG_DEBUG("AclOpMap::Aging IN, cannot find minimum value");
        return ACL_SUCCESS;
    }
    // %llu + explicit cast: %lu does not match uint64_t where unsigned long is 32 bits.
    ACL_LOG_INFO("AclOpMap::Aging IN, type = %s, digest = %zu, timestamp is %llu",
        (*itVectorMin)->opType.c_str(), itHashMapMin->first, static_cast<unsigned long long>(timestampMin));
    agingT = *itVectorMin;
    (void)itHashMapMin->second.erase(itVectorMin);
    --cnt;
    if (itHashMapMin->second.empty()) {
        ACL_LOG_INFO("AclOpMap::After delete model, hash map empty while seed is %zu. delete seed in HashMap",
            itHashMapMin->first);
        (void)hashMap_.erase(itHashMapMin);
    }
    return ACL_SUCCESS;
}
/**
 * @brief Restamps entry with attr_utils::GetCurrentTimestamp().
 *
 * An entry whose timestamp equals the ULLONG_MAX sentinel is left as-is, so it keeps that value.
 */
template<typename T>
void AclOpMap<T>::Updatetimestamp(T &entry) const
{
    const bool isSentinel = (entry->timestamp == static_cast<uint64_t>(ULLONG_MAX));
    if (!isSentinel) {
        entry->timestamp = attr_utils::GetCurrentTimestamp();
    }
}
/**
 * @brief Tells whether op declares the same tensor counts and per-tensor value ranges as entry.
 *
 * @return false if op.numInputs / op.numOutputs differ from entry's descriptor counts, or if any
 *         input/output value range differs per attr_utils::IsSameValueRange; true otherwise.
 */
template<typename T>
bool AclOpMap<T>::HasSameValueRange(const AclOp &op, const T &entry) const
{
    const size_t inputCount = entry->inputDescArr.size();
    const size_t outputCount = entry->outputDescArr.size();
    const bool countsMatch = (static_cast<size_t>(op.numInputs) == inputCount) &&
        (static_cast<size_t>(op.numOutputs) == outputCount);
    if (!countsMatch) {
        return false;
    }
    for (size_t idx = 0U; idx < inputCount; ++idx) {
        ACL_LOG_INFO("the input [%zu] needs to check value range", idx);
        const bool sameRange = attr_utils::IsSameValueRange(entry->inputDescArr[idx].valueRange,
            op.inputDesc[idx]->valueRange);
        if (!sameRange) {
            return false;
        }
    }
    for (size_t idx = 0U; idx < outputCount; ++idx) {
        ACL_LOG_INFO("the output [%zu] needs to check value range", idx);
        const bool sameRange = attr_utils::IsSameValueRange(entry->outputDescArr[idx].valueRange,
            op.outputDesc[idx]->valueRange);
        if (!sameRange) {
            return false;
        }
    }
    return true;
}
/**
 * @brief Looks for an already-cached entry equivalent to op under seed.
 *
 * Static path: a candidate matches if hash_utils::CheckModelAndAttrMatch accepts it.
 * Dynamic path: a candidate must first pass HasSameValueRange, then
 * hash_utils::CheckModelAndAttrMatchDynamic keyed by entry->seq.
 * The first matching candidate has its timestamp refreshed.
 *
 * @return true if a matching cached entry exists (entry itself is not stored); false otherwise.
 */
template<typename T>
bool AclOpMap<T>::Deduplicate(const AclOp &op, const aclopAttr *const attr, const T &entry,
    const size_t seed, const bool isDynamic)
{
    const auto bucketIt = hashMap_.find(seed);
    if (bucketIt == hashMap_.end()) {
        return false;
    }
    for (auto &cached : bucketIt->second) {
        if (isDynamic) {
            if (HasSameValueRange(op, cached) &&
                hash_utils::CheckModelAndAttrMatchDynamic(op, attr, cached, entry->seq)) {
                Updatetimestamp(cached);
                ACL_LOG_DEBUG("Find same dynamic op_desc in Hashmap, seed %zu", seed);
                return true;
            }
        } else if (hash_utils::CheckModelAndAttrMatch(op, attr, cached)) {
            Updatetimestamp(cached);
            ACL_LOG_DEBUG("Find same static op_desc in Hashmap, seed %zu", seed);
            return true;
        }
    }
    return false;
}
/**
 * @brief Appends entry to the bucket for seed, then evicts the oldest entry if the cache is over capacity.
 *
 * Aging is skipped when entry carries the ULLONG_MAX timestamp sentinel or when the updated count
 * is still within maxOpNum.
 *
 * @param[out] agingT receives the evicted entry when aging runs and finds one.
 * @return ACL_SUCCESS, or Aging's result when aging runs.
 */
template<typename T>
aclError AclOpMap<T>::AddMemAndAging(const T &entry, T &agingT, const size_t seed)
{
    auto &bucket = hashMap_[seed];
    bucket.emplace_back(entry);
    ACL_LOG_INFO("AclOpMap::Insert op into HashMap success, seed = %zu", seed);
    cnt++;
    const bool isSentinel = (entry->timestamp == static_cast<uint64_t>(ULLONG_MAX));
    const bool withinLimit = (cnt <= maxOpNum);
    if (!isSentinel && !withinLimit) {
        ACL_LOG_INFO("AclOpMap::time stamp is %lu, cnt is %lu, maxOpNum is %lu, start aging",
            entry->timestamp, cnt, maxOpNum);
        return Aging(agingT);
    }
    ACL_LOG_INFO("AclOpMap::AddMemAndAging in, cnt is %lu, maxOpNum is %lu, no need aging", cnt, maxOpNum);
    return ACL_SUCCESS;
}
/**
 * @brief Verifies op's host-memory tensor data against the value ranges recorded in entry.
 *
 * Only descriptors that entry marks as host-memory (IsHostMemTensor()) and that carry a non-empty
 * valueRange are checked, via attr_utils::ValueRangeCheck on op's buffer at the same position.
 *
 * op.inputs / op.outputs are indexed by entry's descriptor positions. An entry found under the same
 * seed (e.g. on a hash collision) can describe more tensors than op supplies; such a position is
 * rejected here instead of being read out of bounds.
 * NOTE(review): assumes op.inputs holds op.numInputs elements and op.outputs holds op.numOutputs,
 * as HasSameValueRange assumes for inputDesc/outputDesc — confirm against AclOp.
 *
 * @return true if every checked tensor passes its range check; false otherwise.
 */
template<typename T>
bool AclOpMap<T>::CheckValueRange(const AclOp &op, const T &entry) const
{
    const size_t opInputNum = static_cast<size_t>(op.numInputs);
    for (size_t i = 0U; i < entry->inputDescArr.size(); ++i) {
        auto &desc = entry->inputDescArr[i];
        if ((!desc.IsHostMemTensor()) || desc.valueRange.empty()) {
            continue;
        }
        ACL_LOG_INFO("the input [%zu] needs to check value range", i);
        if (i >= opInputNum) {
            ACL_LOG_DEBUG("ValueRangeCheck input [%zu] exceeds op input num %zu", i, opInputNum);
            return false;
        }
        if (!attr_utils::ValueRangeCheck(desc.valueRange, op.inputs[i], desc.dataType)) {
            ACL_LOG_DEBUG("ValueRangeCheck input is not match");
            return false;
        }
    }
    const size_t opOutputNum = static_cast<size_t>(op.numOutputs);
    for (size_t i = 0U; i < entry->outputDescArr.size(); ++i) {
        auto &desc = entry->outputDescArr[i];
        if ((!desc.IsHostMemTensor()) || desc.valueRange.empty()) {
            continue;
        }
        ACL_LOG_INFO("the output [%zu] needs to check value range", i);
        if (i >= opOutputNum) {
            ACL_LOG_DEBUG("ValueRangeCheck output [%zu] exceeds op output num %zu", i, opOutputNum);
            return false;
        }
        if (!attr_utils::ValueRangeCheck(desc.valueRange, op.outputs[i], desc.dataType)) {
            ACL_LOG_DEBUG("ValueRangeCheck output is not match");
            return false;
        }
    }
    return true;
}
/**
 * @brief Caches entry for a static op unless an equivalent entry is already cached.
 *
 * First, op's const data is saved, via attr_utils::SaveConstToAttr, into op.opAttr (through
 * const_cast). When op has no attribute, it goes into a local placeholder instead. Next, the
 * attribute digest and the seed (hash_utils::GetAclOpHash) are computed. Finally, under mutex_,
 * Deduplicate is tried and the entry is stored (with possible aging) only if no duplicate exists.
 *
 * @param op operator description; a non-null op.opAttr is modified in place.
 * @param entry entry to cache.
 * @param[out] agingT receives the entry evicted by aging, if any.
 * @param[out] isDeduplicate true when an equivalent entry was already cached (entry is not stored).
 * @return ACL_SUCCESS; ACL_ERROR_FAILURE if saving const data or hashing fails; otherwise an error
 *         from AddMemAndAging (presumably returned by ACL_REQUIRES_OK — confirm macro semantics).
 */
template<typename T>
aclError AclOpMap<T>::Insert(const AclOp &op, const T &entry, T &agingT, bool &isDeduplicate)
{
    ACL_LOG_DEBUG("AclOpMap::Insert IN, op = %s", op.DebugString().c_str());
    aclopAttr emptyAttr;
    // Without a caller-provided attribute, const data is collected into the local placeholder.
    auto opAttr = (op.opAttr != nullptr) ? op.opAttr : &emptyAttr;
    if (!attr_utils::SaveConstToAttr(op, const_cast<aclopAttr *>(opAttr))) {
        ACL_LOG_ERROR("[Save][ConstData]save const data buffer to attr fail");
        return ACL_ERROR_FAILURE;
    }
    size_t digest = attr_utils::AttrMapToDigest(opAttr->Attrs());

    size_t seed = 0U;
    const auto hashRet = hash_utils::GetAclOpHash(op, opAttr, digest, seed);
    if (hashRet != ACL_SUCCESS) {
        ACL_LOG_ERROR("[Check][GetAclOpHash]GetAclOpHash failed, seed = %zu, op = %s",
            seed, op.DebugString().c_str());
        return ACL_ERROR_FAILURE;
    }

    {
        const std::lock_guard<std::mutex> guard(mutex_);
        isDeduplicate = Deduplicate(op, opAttr, entry, seed, false);
        if (!isDeduplicate) {
            ACL_REQUIRES_OK(AddMemAndAging(entry, agingT, seed));
        }
        ACL_LOG_INFO("AclOpMap::Insert success, seed = %zu, op = %s", seed, op.DebugString().c_str());
    }
    return ACL_SUCCESS;
}
/**
 * @brief Caches entry for a dynamic-shape op unless an equivalent entry is already cached.
 *
 * Same flow as Insert, with three differences: the seed comes from
 * hash_utils::GetAclOpHashDynamic keyed by entry->seq, and duplicate detection (Deduplicate with
 * isDynamic = true) also requires identical value ranges.
 * Log messages name InsertDynamic / GetAclOpHashDynamic, so failures on this path are not
 * misattributed to the static Insert path.
 *
 * @param op operator description; a non-null op.opAttr is modified in place via const_cast.
 * @param entry entry to cache; entry->seq selects the dynamic hashing/matching.
 * @param[out] agingT receives the entry evicted by aging, if any.
 * @param[out] isDeduplicate true when an equivalent entry was already cached (entry is not stored).
 * @return ACL_SUCCESS; ACL_ERROR_FAILURE if saving const data or hashing fails; otherwise an error
 *         from AddMemAndAging (presumably returned by ACL_REQUIRES_OK — confirm macro semantics).
 */
template<typename T>
aclError AclOpMap<T>::InsertDynamic(const AclOp &op, const T &entry, T &agingT, bool &isDeduplicate)
{
    ACL_LOG_DEBUG("AclOpMap::InsertDynamic IN, op = %s", op.DebugString().c_str());
    aclopAttr emptyAttr;
    // Without a caller-provided attribute, const data is collected into the local placeholder.
    auto opAttr = (op.opAttr != nullptr) ? op.opAttr : &emptyAttr;
    if (!attr_utils::SaveConstToAttr(op, const_cast<aclopAttr *>(opAttr))) {
        ACL_LOG_ERROR("[Save][ConstData]save const data buffer to attr fail");
        return ACL_ERROR_FAILURE;
    }
    size_t digest = attr_utils::AttrMapToDigest(opAttr->Attrs());

    size_t seed = 0U;
    if (hash_utils::GetAclOpHashDynamic(op, opAttr, digest, seed, entry->seq) != ACL_SUCCESS) {
        ACL_LOG_ERROR("[Check][GetAclOpHashDynamic]GetAclOpHashDynamic failed, seed = %zu, op = %s",
            seed, op.DebugString().c_str());
        return ACL_ERROR_FAILURE;
    }

    {
        const std::lock_guard<std::mutex> lk(mutex_);
        isDeduplicate = Deduplicate(op, opAttr, entry, seed, true);
        if (!isDeduplicate) {
            ACL_REQUIRES_OK(AddMemAndAging(entry, agingT, seed));
        }
        ACL_LOG_INFO("AclOpMap::InsertDynamic success, seed = %zu, op = %s", seed, op.DebugString().c_str());
    }
    return ACL_SUCCESS;
}
/**
 * @brief Finds the cached entry matching a static op.
 *
 * The digest is op.opAttr->GetDigest(), read before const data is saved into the attribute, or 0
 * when op has no attribute. Candidates under the seed must pass CheckValueRange and then
 * hash_utils::CheckModelAndAttrMatch; the first one that passes is returned.
 * NOTE(review): Insert instead hashes AttrMapToDigest(Attrs()) after SaveConstToAttr; presumably
 * the two agree — confirm, otherwise lookups can miss.
 *
 * @param op operator to look up; a non-null op.opAttr is modified in place via const_cast.
 * @param[out] entry receives the matched cached entry.
 * @param needUpdateTimestamp when true, the matched entry's timestamp is refreshed.
 * @return ACL_SUCCESS on a match; ACL_ERROR_OP_NOT_FOUND if no candidate passes both checks;
 *         ACL_ERROR_INVALID_PARAM if saving const data fails; otherwise the hashing error
 *         (presumably returned by ACL_REQUIRES_OK — confirm macro semantics).
 */
template<typename T>
aclError AclOpMap<T>::Get(const AclOp &op, T &entry, const bool needUpdateTimestamp)
{
    aclopAttr emptyAttr;
    size_t digest = 0U;
    auto opAttr = op.opAttr;
    if (opAttr == nullptr) {
        opAttr = &emptyAttr;
    } else {
        // Digest is taken before SaveConstToAttr touches the attribute.
        digest = op.opAttr->GetDigest();
    }
    if (!attr_utils::SaveConstToAttr(op, const_cast<aclopAttr *>(opAttr))) {
        ACL_LOG_ERROR("[Save][ConstData]save const data buffer to attr fail");
        return ACL_ERROR_INVALID_PARAM;
    }

    size_t seed = 0U;
    ACL_REQUIRES_OK(hash_utils::GetAclOpHash(op, opAttr, digest, seed));

    const std::lock_guard<std::mutex> guard(mutex_);
    const auto bucketIt = hashMap_.find(seed);
    if (bucketIt == hashMap_.end()) {
        return ACL_ERROR_OP_NOT_FOUND;
    }
    for (auto &candidate : bucketIt->second) {
        if (!CheckValueRange(op, candidate) || !hash_utils::CheckModelAndAttrMatch(op, opAttr, candidate)) {
            continue;
        }
        ACL_LOG_INFO("Get aclOp from aclOpMap success! seed = %zu, aclOp = %s", seed,
            op.DebugString().c_str());
        entry = candidate;
        if (needUpdateTimestamp) {
            Updatetimestamp(candidate);
        }
        return ACL_SUCCESS;
    }
    ACL_LOG_DEBUG("Get aclOp from aclOpMap failed due to CheckValueRange failed! seed = %zu, aclOp = %s",
        seed, op.DebugString().c_str());
    return ACL_ERROR_OP_NOT_FOUND;
}
/**
 * @brief Finds the cached entry matching a dynamic-shape op.
 *
 * Same flow as Get, except that the seed comes from hash_utils::GetAclOpHashDynamic and matching
 * uses hash_utils::CheckModelAndAttrMatchDynamic, both keyed by seq.
 *
 * @param op operator to look up; a non-null op.opAttr is modified in place via const_cast.
 * @param[out] entry receives the matched cached entry.
 * @param seq key forwarded to the dynamic hash and match helpers.
 * @param needUpdateTimestamp when true, the matched entry's timestamp is refreshed.
 * @return ACL_SUCCESS on a match; ACL_ERROR_OP_NOT_FOUND if no candidate passes both checks;
 *         ACL_ERROR_INVALID_PARAM if saving const data fails; otherwise the hashing error
 *         (presumably returned by ACL_REQUIRES_OK — confirm macro semantics).
 */
template<typename T>
aclError AclOpMap<T>::GetDynamic(const AclOp &op, T &entry, const uint64_t seq, const bool needUpdateTimestamp)
{
    aclopAttr emptyAttr;
    size_t digest = 0U;
    auto opAttr = op.opAttr;
    if (opAttr == nullptr) {
        opAttr = &emptyAttr;
    } else {
        // Digest is taken before SaveConstToAttr touches the attribute.
        digest = op.opAttr->GetDigest();
    }
    if (!attr_utils::SaveConstToAttr(op, const_cast<aclopAttr *>(opAttr))) {
        ACL_LOG_ERROR("[Save][ConstData]save const data buffer to attr fail");
        return ACL_ERROR_INVALID_PARAM;
    }

    size_t seed = 0U;
    ACL_REQUIRES_OK(hash_utils::GetAclOpHashDynamic(op, opAttr, digest, seed, seq));

    const std::lock_guard<std::mutex> guard(mutex_);
    const auto bucketIt = hashMap_.find(seed);
    if (bucketIt == hashMap_.end()) {
        return ACL_ERROR_OP_NOT_FOUND;
    }
    for (auto &candidate : bucketIt->second) {
        if (!CheckValueRange(op, candidate) ||
            !hash_utils::CheckModelAndAttrMatchDynamic(op, opAttr, candidate, seq)) {
            continue;
        }
        ACL_LOG_INFO("Get aclOp from aclOpMap success! seed = %zu, aclOp = %s", seed,
            op.DebugString().c_str());
        entry = candidate;
        if (needUpdateTimestamp) {
            Updatetimestamp(candidate);
        }
        return ACL_SUCCESS;
    }
    ACL_LOG_DEBUG("Get aclOp from aclOpMap failed due to CheckValueRange failed! seed = %zu, aclOp = %s",
        seed, op.DebugString().c_str());
    return ACL_ERROR_OP_NOT_FOUND;
}
}
#endif