* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "group_level_cache_gen.h"
#include "common/code_printer.h"
namespace att {
namespace cache {
/**
 * @brief Writes the FixedSizeHashMap template definition into the generated source.
 *
 * The template text itself comes from GenHashMapTemplate(); this function only forwards
 * it to the printer as a single AddLine call.
 *
 * @param code_printer Printer that receives the generated code.
 * @return ge::SUCCESS (there is no failure path here).
 */
ge::Status GroupLevelCacheGen::GenFixedSizeHashMapDef(ge::CodePrinter &code_printer) {
  std::string template_text = GenHashMapTemplate();
  code_printer.AddLine(template_text);
  return ge::SUCCESS;
}
/**
 * @brief Emits the `GroupLevelCache` type alias followed by a blank line.
 *
 * The alias instantiates FixedSizeHashMap with kInputShapeSize as the key size,
 * `cache_capacity` as the capacity argument and TilingDataCopy as the stored value.
 *
 * @param code_printer   Printer that receives the generated code.
 * @param cache_capacity Capacity value substituted into the template argument list.
 * @return ge::SUCCESS (there is no failure path here).
 */
ge::Status GroupLevelCacheGen::GenGroupCacheTypes(ge::CodePrinter &code_printer,
                                                  size_t cache_capacity) {
  std::string alias_line = "using GroupLevelCache = FixedSizeHashMap<kInputShapeSize, ";
  alias_line += std::to_string(cache_capacity);
  alias_line += ", TilingDataCopy>;";
  code_printer.AddLine(alias_line);
  // Separate the alias from whatever the next generator emits.
  code_printer.AddLine("");
  return ge::SUCCESS;
}
/**
 * @brief Emits the group-level (second-level) cache helpers into the generated source.
 *
 * Generated functions:
 *   - GroupCacheKeyToString: renders a cache key as a comma-separated list, used only for logging.
 *   - FindGroupCache: calls group_level_cache.Find(key). On a hit it passes the cached entry and
 *     `tiling_data` to GetScheduleGroupTilingData (presumably copying the cached tiling out;
 *     confirm against that helper) and returns true. On a miss it returns false.
 *   - SaveGroupCache: calls group_level_cache.Insert(key, data), logs the outcome and returns it.
 *
 * @param code_printer          Printer that receives the generated code.
 * @param tiling_data_type_name Type name spliced in as FindGroupCache's `tiling_data` parameter type.
 * @return ge::SUCCESS (there is no failure path here).
 */
ge::Status GroupLevelCacheGen::GenGroupCacheFunctions(ge::CodePrinter &code_printer,
                                                      const std::string &tiling_data_type_name) {
  // Key axes are joined with ',' in the emitted helper. Without a separator, distinct keys such
  // as {1, 23} and {12, 3} would both be logged as "123".
  // All three log lines share the same "key=[...]" layout so they can be grepped uniformly.
  std::string cache_decl = R"(
static inline std::string GroupCacheKeyToString(const std::array<uint32_t, kInputShapeSize> &key) {
std::string out;
for (size_t i = 0; i < key.size(); ++i) {
if (i != 0) {
out.append(",");
}
out.append(std::to_string(key[i]));
}
return out;
}
// 第二级:Group间缓存(通过参数传递)
static inline bool FindGroupCache(const std::array<uint32_t, kInputShapeSize> &key,
)" + tiling_data_type_name + R"(& tiling_data,
GroupLevelCache &group_level_cache) {
auto *result = group_level_cache.Find(key);
if (result != nullptr) {
OP_LOGI(OP_NAME, "[Group Cache] HIT! key=[%s]", GroupCacheKeyToString(key).c_str());
GetScheduleGroupTilingData(*result, tiling_data);
return true;
}
OP_LOGI(OP_NAME, "[Group Cache] MISS! key=[%s]", GroupCacheKeyToString(key).c_str());
return false;
}
// 保存到Group级缓存
static inline bool SaveGroupCache(const std::array<uint32_t, kInputShapeSize>& key,
const TilingDataCopy& data,
GroupLevelCache &group_level_cache) {
bool success = group_level_cache.Insert(key, data);
OP_LOGI(OP_NAME, "[Group Cache] SAVE %s: key=[%s], tiling_key=%u",
success ? "SUCCESS" : "FAILED", GroupCacheKeyToString(key).c_str(), data.tiling_key);
return success;
}
)";
  code_printer.AddLine(cache_decl);
  return ge::SUCCESS;
}
}
}