* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef INC_GRAPH_UTILS_MEM_UTILS_H_
#define INC_GRAPH_UTILS_MEM_UTILS_H_
#include <memory>
#include <type_traits>
#include "ascendc_ir/ascendc_ir_core/ascendc_ir.h"
namespace ge {
template<typename T>
void CheckAscTensorAttr(T &) {
static_assert(std::is_same<T, AscTensorAttr>::value, "Expected AscTensorAttr type");
}
class TQueConfig {
friend class MemUtils;
public:
template<typename... Args>
TQueConfig &BindTensors(Args &&...tensors) {
int dummy[] = { (CheckAscTensorAttr(std::forward<Args>(tensors)), 0)... };
(void)dummy;
int dummy1[] = { (tensors.que = queue_attr_, 0)... };
(void)dummy1;
int dummy2[] = { (tensors.buf.id = kIdNone, 0)... };
(void)dummy2;
int dummy3[] = { (tensors.mem.position = pos_, 0)... };
(void)dummy3;
int dummy4[] = { (tensors.mem.alloc_type = AllocType::kAllocTypeQueue, 0)... };
(void)dummy4;
return *this;
}
TQueConfig() = default;
private:
TQueConfig(const int64_t id, const ge::Position pos, const int64_t depth, const int64_t buf_num);
MemQueAttr queue_attr_{};
ge::Position pos_{ge::Position::kPositionInvalid};
};
class TBufConfig {
friend class MemUtils;
public:
template<typename... Args>
TBufConfig &BindTensors(Args &&...tensors) {
int dummy[] = { (CheckAscTensorAttr(std::forward<Args>(tensors)), 0)... };
(void)dummy;
int dummy1[] = { (tensors.buf = buf_attr_, 0)... };
(void)dummy1;
int dummy2[] = { (tensors.que.id = kIdNone, 0)... };
(void)dummy2;
int dummy3[] = { (tensors.mem.position = pos_, 0)... };
(void)dummy3;
int dummy4[] = { (tensors.mem.alloc_type = AllocType::kAllocTypeBuffer, 0)... };
(void)dummy4;
return *this;
}
TBufConfig() = default;
private:
TBufConfig(const int64_t id, const ge::Position pos);
MemBufAttr buf_attr_;
ge::Position pos_;
};
class MemUtils {
public:
static TQueConfig CreateTQueConfig(const ge::Position pos, const int64_t depth, const int64_t buf_num);
static TBufConfig CreateTBufConfig(const ge::Position pos);
template<typename... Args>
static void MergeScope(Args &&...tensors) {
int dummy[] = { (CheckAscTensorAttr(std::forward<Args>(tensors)), 0)... };
(void)dummy;
int dummy1[] = { (tensors.opt.merge_scope = scope_id_, 0)... };
(void)dummy1;
scope_id_++;
}
private:
static std::atomic<int64_t> gen_container_id_;
static std::atomic<int64_t> scope_id_;
};
}
#endif