* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "graph/cache_policy/compile_cache_desc.h"
#include <securec.h>
#include "common/checker.h"
#include "debug/ge_util.h"
namespace ge {
constexpr int32_t kAllShape = -2;
BinaryHolder::BinaryHolder(const BinaryHolder &other) {
if ((other.GetDataPtr() != nullptr) && (other.GetDataLen() != 0UL)) {
data_len_ = other.GetDataLen();
holder_ = ComGraphMakeUnique<uint8_t[]>(data_len_);
GE_CHECK_NOTNULL_JUST_RETURN(holder_);
const auto mem_ret = memcpy_s(holder_.get(), data_len_,
ge::PtrToPtr<const uint8_t, const void>(other.GetDataPtr()), data_len_);
if (mem_ret != EOK) {
data_len_ = 0U;
holder_ = nullptr;
GELOGE(ge::GRAPH_FAILED, "[BinaryHolder] memcpy failed.");
}
}
}
BinaryHolder::BinaryHolder(const uint8_t *const data, const size_t data_len) {
if ((data != nullptr) && (data_len != 0UL)) {
data_len_ = data_len;
holder_ = ComGraphMakeUnique<uint8_t[]>(data_len_);
GE_CHECK_NOTNULL_JUST_RETURN(holder_);
const auto mem_ret = memcpy_s(holder_.get(), data_len_,
ge::PtrToPtr<const uint8_t, const void>(data), data_len_);
if (mem_ret != EOK) {
data_len_ = 0U;
holder_ = nullptr;
GELOGE(ge::GRAPH_FAILED, "[BinaryHolder] memcpy failed.");
}
}
}
BinaryHolder &BinaryHolder::operator=(const BinaryHolder &other) {
if ((other.GetDataPtr() != nullptr) && (other.GetDataLen() != 0UL)) {
data_len_ = other.GetDataLen();
holder_ = ComGraphMakeUnique<uint8_t[]>(data_len_);
if (holder_ == nullptr) {
GELOGE(ge::GRAPH_FAILED, "[BinaryHolder] make unique failed.");
return *this;
}
const auto mem_ret = memcpy_s(holder_.get(), data_len_,
ge::PtrToPtr<const uint8_t, const void>(other.GetDataPtr()), data_len_);
if (mem_ret != EOK) {
data_len_ = 0U;
holder_ = nullptr;
GELOGE(ge::GRAPH_FAILED, "[BinaryHolder] memcpy failed.");
}
}
return *this;
}
BinaryHolder::BinaryHolder(BinaryHolder &&other) {
data_len_ = other.data_len_;
holder_ = std::move(other.holder_);
other.data_len_ = 0;
}
BinaryHolder &BinaryHolder::operator=(BinaryHolder &&other) {
data_len_ = other.data_len_;
holder_ = std::move(other.holder_);
other.data_len_ = 0;
return *this;
}
std::unique_ptr<BinaryHolder> BinaryHolder::createFrom(std::unique_ptr<uint8_t[]> &&ptr, size_t length) {
auto holder = ComGraphMakeUnique<BinaryHolder>();
if ((ptr != nullptr) && (holder != nullptr) && (length != 0UL)) {
holder->data_len_ = length;
holder->holder_ = std::move(ptr);
}
return holder;
}
const uint8_t *BinaryHolder::GetDataPtr() const noexcept {
if (holder_.get() != nullptr) {
return holder_.get();
}
return nullptr;
}
const size_t &BinaryHolder::GetDataLen() const noexcept {
return data_len_;
}
bool BinaryHolder::operator!=(const BinaryHolder &second) const {
if (this->GetDataLen() != second.GetDataLen()) {
return true;
}
const auto this_data = this->GetDataPtr();
const auto second_data = second.GetDataPtr();
if (((this_data == nullptr) && (second_data != nullptr)) ||
((this_data != nullptr) && (second_data == nullptr))) {
return true;
}
if ((this_data == nullptr) && (second_data == nullptr)) {
return false;
}
if (memcmp(this_data, second_data, this->GetDataLen()) != 0) {
return true;
}
return false;
}
Format TensorInfoArgs::GetFormat() const {
return format_;
}
Format TensorInfoArgs::GetOriginFormat() const {
return origin_format_;
}
DataType TensorInfoArgs::GetDataType() const {
return data_type_;
}
void TensorInfoArgs::SetShape(const std::vector<int64_t> &shape) {
shape_.clear();
for (const auto dim : shape) {
shape_.emplace_back(dim);
}
}
void TensorInfoArgs::SetShape(const SmallVector<int64_t, kDefaultDimsNum> &shape) {
shape_.clear();
shape_ = shape;
}
void TensorInfoArgs::SetOriginShape(const std::vector<int64_t> &origin_shape) {
origin_shape_.clear();
for (const auto dim : origin_shape) {
origin_shape_.emplace_back(dim);
}
}
void TensorInfoArgs::SetOriginShape(const SmallVector<int64_t, kDefaultDimsNum> &origin_shape) {
origin_shape_.clear();
origin_shape_ = origin_shape;
}
void TensorInfoArgs::SetShapeRange(const std::vector<std::pair<int64_t, int64_t>> &ranges) {
shape_range_.clear();
for (const auto &range : ranges) {
shape_range_.emplace_back(range);
}
}
bool TensorInfoArgs::IsUnknownShape() const {
return std::any_of(shape_.begin(), shape_.end(), [](const int64_t &dim) {
return (dim == UNKNOWN_DIM) || (dim == UNKNOWN_DIM_NUM);
});
}
bool TensorInfoArgs::operator!=(const TensorInfoArgs &second) const {
const bool ret = (this->format_ != second.format_) || (this->origin_format_ != second.origin_format_) ||
(this->data_type_ != second.data_type_) || (this->shape_ != second.shape_) ||
(this->origin_shape_ != second.origin_shape_) || (this->shape_range_ != second.shape_range_);
return ret;
}
bool TensorInfoArgs::IsTensorInfoMatch(const TensorInfoArgs &other) const {
const bool is_same = (this->format_ == other.format_) && (this->origin_format_ == other.origin_format_) &&
(this->data_type_ == other.data_type_);
if (!is_same) {
GELOGD("format or origin format or datatype is not matched");
return false;
}
return IsShapeInRange(other);
}
bool TensorInfoArgs::IsShapeInRange(const TensorInfoArgs &other) const {
if ((this->shape_.size() == 1U) && (this->shape_[0] == kAllShape)) {
GELOGD("current shape is -2");
return true;
}
const bool is_same_rank = (this->shape_.size() == other.shape_.size()) &&
(this->origin_shape_.size() == other.origin_shape_.size());
if (!is_same_rank) {
GELOGD("shape or origin shape is not same rank");
return false;
}
if (this->IsUnknownShape()) {
if (this->shape_.size() != this->shape_range_.size()) {
GELOGD("shape size %zu is not match shape rang size %zu", this->shape_.size(), this->shape_range_.size());
return false;
}
for (size_t i = 0U; i < this->shape_range_.size(); ++i) {
if (this->shape_range_[i].first > other.shape_[i]) {
GELOGD("shape range is not match, first is %" PRId64 ", other is %" PRId64 ", index is %zu",
this->shape_range_[i].first, other.shape_[i], i);
return false;
}
if (this->shape_range_[i].second == UNKNOWN_DIM) {
GELOGD("shape second is -1, index is %zu", i);
continue;
}
if (this->shape_range_[i].second < other.shape_[i]) {
GELOGD("shape range is not match, second is %" PRId64 ", other is %" PRId64 ", index is %zu",
this->shape_range_[i].second, other.shape_[i], i);
return false;
}
}
} else {
GELOGD("this is exact shape");
if ((this->shape_ != other.shape_) || (this->origin_shape_ != other.origin_shape_)) {
GELOGD("exact shape or origin shape is not matched");
return false;
}
}
return true;
}
size_t CompileCacheDesc::GetTensorInfoSize() {
return tensor_info_args_vec_.size();
}
TensorInfoArgs *CompileCacheDesc::MutableTensorInfo(size_t index) {
if (index >= tensor_info_args_vec_.size()) {
return nullptr;
}
return &tensor_info_args_vec_[index];
}
void CompileCacheDesc::AddBinary(const BinaryHolder &holder) {
other_desc_.emplace_back(holder);
}
void CompileCacheDesc::AddBinary(BinaryHolder &&holder) {
other_desc_.emplace_back(holder);
}
void CompileCacheDesc::SetOpType(const std::string &op_type) {
op_type_ = op_type;
return;
}
void CompileCacheDesc::AddTensorInfo(const TensorInfoArgs &tensor_info) {
tensor_info_args_vec_.emplace_back(tensor_info);
return;
}
void CompileCacheDesc::SetScopeId(const std::initializer_list<uint64_t> scope_id) {
scope_id_= scope_id;
return;
}
bool CompileCacheDesc::CheckWithoutTensorInfo(const CompileCacheDesc *first, const CompileCacheDesc *second) const {
if ((first->op_type_ != second->op_type_) ||
(first->tensor_info_args_vec_.size() != second->tensor_info_args_vec_.size())) {
GELOGD("op_type_ %s, %s is not match or size %zu %zu is not match",
first->op_type_.c_str(), second->op_type_.c_str(),
first->tensor_info_args_vec_.size(), second->tensor_info_args_vec_.size());
return false;
}
if (first->scope_id_ != second->scope_id_) {
GELOGD("scope id is not match");
return false;
}
if (first->other_desc_.size() != second->other_desc_.size()) {
GELOGD("other_desc_ size %zu, %zu is not match ", first->other_desc_.size(), second->other_desc_.size());
return false;
}
for (size_t i = 0U; i < first->other_desc_.size(); ++i) {
if (first->other_desc_[i].GetDataLen() != second->other_desc_[i].GetDataLen()) {
GELOGD("other_desc_ mem size %zu, %zu is not match ",
first->other_desc_[i].GetDataLen(), second->other_desc_[i].GetDataLen());
return false;
}
if ((first->other_desc_[i].GetDataPtr() == nullptr) || (second->other_desc_[i].GetDataPtr() == nullptr)) {
return false;
}
const auto cmp_ret = memcmp(first->other_desc_[i].GetDataPtr(),
second->other_desc_[i].GetDataPtr(), second->other_desc_[i].GetDataLen());
if (cmp_ret != 0) {
GELOGD("mem compare fail");
return false;
}
}
return true;
}
bool CompileCacheDesc::IsMatch(const CacheDescPtr &other) const {
const auto *second = dynamic_cast<const CompileCacheDesc *>(other.get());
GE_ASSERT_NOTNULL(second, "dynamic cast failed");
if (!CheckWithoutTensorInfo(this, second)) {
return false;
}
for (size_t i = 0U; i < this->tensor_info_args_vec_.size(); ++i) {
const auto &first_args = this->tensor_info_args_vec_[i];
const auto &second_args = second->tensor_info_args_vec_[i];
if (!first_args.IsTensorInfoMatch(second_args)) {
GELOGD("shape is not matched");
return false;
}
}
return true;
}
bool CompileCacheDesc::IsEqual(const CacheDescPtr &other) const {
const auto *second = dynamic_cast<const CompileCacheDesc *>(other.get());
GE_ASSERT_NOTNULL(second, "dynamic cast failed");
if (!CheckWithoutTensorInfo(this, second)) {
return false;
}
for (size_t i = 0U; i < this->tensor_info_args_vec_.size(); ++i) {
const auto &first_args = this->tensor_info_args_vec_[i];
const auto &second_args = second->tensor_info_args_vec_[i];
if (first_args != second_args) {
GELOGD("tensor info is not matched");
return false;
}
}
return true;
}
CacheHashKey CompileCacheDesc::GetCacheDescHash() const {
CacheHashKey hash_key = 0U;
for (const auto &arg : tensor_info_args_vec_) {
hash_key = HashUtils::MultiHash(hash_key, arg.GetFormat(), arg.GetOriginFormat(), arg.GetDataType());
}
hash_key = HashUtils::MultiHash(op_type_, hash_key);
return hash_key;
}
}