* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "tiling_cache_code_gen.h"
#include "common/code_printer.h"
namespace att {
namespace cache {
// Emits the namespace-scope constants used by the generated tiling cache:
// the input shape arity (from input_vars_size), the operator-level cache
// capacity and the load-factor threshold read by the generated Insert().
// A trailing blank line separates the constants from whatever follows.
void TilingCacheCodeGen::GenConstantDefs(ge::CodePrinter &code_printer, size_t input_vars_size) {
  const std::string input_shape_line =
      "constexpr size_t kInputShapeSize = " + std::to_string(input_vars_size) + ";";
  const std::string constant_lines[] = {
      "// ATT缓存相关常量",
      input_shape_line,
      "constexpr size_t kOperatorCacheCapacity = 24; // 算子级缓存容量",
      "constexpr double kLoadFactorThreshold = 0.8; // 负载因子阈值",
      "",
  };
  for (const std::string &line : constant_lines) {
    code_printer.AddLine(line);
  }
}
// Assembles the complete `FixedSizeHashMap<KEY_SIZE, CAPACITY, VALUE_TYPE>`
// class template: private layout and helpers, the constructor, then the
// public API, closed by "};".
std::string TilingCacheCodeGen::GenHashMapTemplate() {
  std::string class_text = "template <size_t KEY_SIZE, size_t CAPACITY, typename VALUE_TYPE>\n";
  class_text += "class FixedSizeHashMap {\n";
  class_text += GenHashMapClassStructure();
  class_text += GenHashMapConstructor();
  class_text += GenHashMapPublicMethods();
  class_text += "};\n";
  return class_text;
}
// Emits the private section of the generated hash map. It contains the
// Key/Value aliases, the bucket state enum, the Bucket record, the bucket
// array with its element counter, and the Hash() and FindIndex() helpers
// produced by their own generators.
std::string TilingCacheCodeGen::GenHashMapClassStructure() {
  std::string section;
  // Appends one generated source line followed by its newline.
  const auto emit_line = [&section](const char *line) {
    section += line;
    section += '\n';
  };
  emit_line("private:");
  emit_line(" using Key = std::array<uint32_t, KEY_SIZE>;");
  emit_line(" using Value = VALUE_TYPE;");
  emit_line("");
  emit_line(" enum BucketState { kEmpty, kOccupied, kDeleted };");
  emit_line(" struct Bucket {");
  emit_line(" Key key;");
  emit_line(" Value value;");
  emit_line(" BucketState state;");
  emit_line(" };");
  emit_line("");
  emit_line(" std::array<Bucket, CAPACITY> buckets;");
  emit_line(" size_t size_ = 0;");
  emit_line("");
  emit_line(" // Hash - 大驼峰命名");
  section += GenHashFunction();
  emit_line("");
  emit_line(" // FindIndex - 大驼峰命名");
  section += GenFindIndexFunction();
  emit_line("");
  return section;
}
// Emits "public:" followed by the generated default constructor, which
// zeroes size_ and marks every bucket as kEmpty.
std::string TilingCacheCodeGen::GenHashMapConstructor() {
  static const char *const kCtorLines[] = {
      "public:",
      " FixedSizeHashMap() : size_(0) {",
      " for (size_t i = 0; i < CAPACITY; ++i) {",
      " buckets[i].state = kEmpty;",
      " }",
      " }",
      "",
  };
  std::string ctor_text;
  for (const char *line : kCtorLines) {
    ctor_text.append(line);
    ctor_text.push_back('\n');
  }
  return ctor_text;
}
// Emits the two generated Find() overloads.
// - The non-const overload returns a pointer to the stored value when
//   FindIndex() locates an occupied bucket for the key, otherwise nullptr.
// - The const overload forwards to the non-const one through a const_cast.
std::string TilingCacheCodeGen::GenFindMethod() {
  return " // Find - 大驼峰命名\n"
         " Value* Find(const Key &key) {\n"
         " size_t index = FindIndex(key);\n"
         " if (index < CAPACITY && buckets[index].state == kOccupied) {\n"
         " return &buckets[index].value;\n"
         " }\n"
         " return nullptr;\n"
         " }\n"
         "\n"
         " const Value* Find(const Key &key) const {\n"
         " return const_cast<FixedSizeHashMap*>(this)->Find(key);\n"
         " }\n";
}
// Emits the generated FixedSizeHashMap::Insert(key, value). The generated code
// behaves as follows.
//  1. If FindIndex() finds the key, the stored value is overwritten and true
//     is returned. This is done before the load-factor check because an update
//     does not grow the map. Previously an update was refused once size_
//     reached the threshold.
//  2. Otherwise, if size_ >= CAPACITY * kLoadFactorThreshold, false is returned.
//  3. Otherwise the key is stored in the first bucket on its linear-probe
//     sequence that is not kOccupied, i.e. either kEmpty or a kDeleted
//     tombstone left by Erase().
//     - Previously only kEmpty buckets were accepted. After enough
//       Insert/Erase churn no kEmpty bucket remained, and every new key was
//       rejected even though size_ was small.
//     - Reusing the first tombstone is safe: FindIndex() has already walked
//       this probe chain and established the key is not present, and the
//       chosen slot lies on that chain, so later lookups will reach it.
std::string TilingCacheCodeGen::GenInsertMethod() {
  std::stringstream ss;
  ss << " // Insert - 大驼峰命名\n";
  ss << " bool Insert(const Key &key, const Value &value) {\n";
  ss << " size_t index = FindIndex(key);\n";
  ss << " if (index < CAPACITY) {\n";
  ss << " buckets[index].value = value;\n";
  ss << " return true;\n";
  ss << " }\n";
  ss << " if (size_ >= CAPACITY * kLoadFactorThreshold) {\n";
  ss << " return false; // 80%容量阈值\n";
  ss << " }\n";
  ss << " size_t hash = Hash(key) % CAPACITY;\n";
  ss << " for (size_t i = 0; i < CAPACITY; ++i) {\n";
  ss << " index = (hash + i) % CAPACITY;\n";
  // Accept kEmpty and kDeleted buckets alike so tombstones are recycled.
  ss << " if (buckets[index].state != kOccupied) {\n";
  ss << " buckets[index].key = key;\n";
  ss << " buckets[index].value = value;\n";
  ss << " buckets[index].state = kOccupied;\n";
  ss << " size_++;\n";
  ss << " return true;\n";
  ss << " }\n";
  ss << " }\n";
  ss << " return false;\n";
  ss << " }\n";
  return ss.str();
}
// Emits the generated Erase(key). When FindIndex() locates an occupied bucket
// for the key, the bucket is marked kDeleted, size_ is decremented and true is
// returned. Otherwise the map is untouched and false is returned.
std::string TilingCacheCodeGen::GenEraseMethod() {
  return " // Erase - 大驼峰命名\n"
         " bool Erase(const Key &key) {\n"
         " size_t index = FindIndex(key);\n"
         " if (index < CAPACITY && buckets[index].state == kOccupied) {\n"
         " buckets[index].state = kDeleted;\n"
         " size_--;\n"
         " return true;\n"
         " }\n"
         " return false;\n"
         " }\n";
}
// Emits three generated members:
// - Clear(), which marks every bucket kEmpty and resets size_;
// - Size(), which returns size_;
// - Empty(), which reports size_ == 0.
std::string TilingCacheCodeGen::GenClearAndSizeMethods() {
  static const char *const kMemberLines[] = {
      " // Clear - 大驼峰命名",
      " void Clear() {",
      " for (auto& bucket : buckets) {",
      " bucket.state = kEmpty;",
      " }",
      " size_ = 0;",
      " }",
      "",
      " size_t Size() const { return size_; }",
      " bool Empty() const { return size_ == 0; }",
  };
  std::string members;
  for (const char *line : kMemberLines) {
    members.append(line);
    members.push_back('\n');
  }
  return members;
}
// Concatenates the generated public API in the order Find, Insert, Erase,
// then Clear/Size/Empty, with a blank line between consecutive groups.
std::string TilingCacheCodeGen::GenHashMapPublicMethods() {
  std::string api = GenFindMethod();
  api += '\n';
  api += GenInsertMethod();
  api += '\n';
  api += GenEraseMethod();
  api += '\n';
  api += GenClearAndSizeMethods();
  return api;
}
// Emits the generated Hash(key). It folds every uint32_t element of the key
// into a size_t accumulator with a boost-style hash_combine step, using the
// 0x9e3779b9 mixing constant.
std::string TilingCacheCodeGen::GenHashFunction() {
  return " size_t Hash(const Key &key) const {\n"
         " size_t hash = 0;\n"
         " for (const auto& value : key) {\n"
         " constexpr uint32_t kHashPrime = 0x9e3779b9; // 黄金比例的整数表示,用于hash混合\n"
         " hash ^= value + kHashPrime + (hash << 6) + (hash >> 2);\n"
         " }\n"
         " return hash;\n"
         " }\n";
}
// Emits the generated FindIndex(key), a linear probe that starts at
// Hash(key) % CAPACITY.
// - It returns the index of the occupied bucket holding the key.
// - It returns CAPACITY as the "not found" sentinel when it reaches a kEmpty
//   bucket or wraps back to the start.
// - kDeleted buckets are stepped over so probe chains stay intact after Erase().
std::string TilingCacheCodeGen::GenFindIndexFunction() {
  static const char *const kProbeLines[] = {
      " size_t FindIndex(const Key &key) const {",
      " size_t hash = Hash(key) % CAPACITY;",
      " size_t start = hash;",
      " do {",
      " if (buckets[hash].state == kEmpty) {",
      " return CAPACITY;",
      " } else if (buckets[hash].state == kOccupied && buckets[hash].key == key) {",
      " return hash;",
      " }",
      " hash = (hash + 1) % CAPACITY;",
      " } while (hash != start);",
      " return CAPACITY;",
      " }",
  };
  std::string probe_text;
  for (const char *line : kProbeLines) {
    probe_text.append(line);
    probe_text.push_back('\n');
  }
  return probe_text;
}
}
}