* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef BASE_COMMON_OM2_CODEGEN_AST_AST_CONTEXT_H
#define BASE_COMMON_OM2_CODEGEN_AST_AST_CONTEXT_H
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <limits>
#include <new>
#include <stdexcept>
#include <type_traits>
#include <vector>
#include "securec.h"
#include "common/checker.h"
namespace ge {
/**
 * Holds a list of Block records together with a running memory-usage counter.
 * Allocate(), CreateNewBlock() and the destructor are only declared here; their
 * definitions live outside this header. Copy construction and copy assignment are disabled.
 */
class AstNodePool {
 public:
  AstNodePool() = default;
  ~AstNodePool();
  AstNodePool(const AstNodePool &) = delete;
  AstNodePool &operator=(const AstNodePool &) = delete;

  /// Requests mem_size bytes; defined out of line.
  /// NOTE(review): presumably returns nullptr on failure (AstContext checks for it) - confirm in the .cc file.
  uint8_t *Allocate(const size_t mem_size);

  /// Returns the current value of the memory-usage counter.
  size_t GetMemoryUsage() const {
    return total_mem_usage_;
  }

 private:
  /// Bookkeeping record for one block.
  /// NOTE(review): presumably data/length describe the buffer and offset is the bump position - confirm.
  struct Block {
    uint8_t *data;
    size_t length;
    size_t offset;
  };

  /// Defined out of line; the meaning of the returned size_t is not visible from this header.
  size_t CreateNewBlock(const size_t min_size);

  /// Rounds size up to a multiple of alignment using a bit mask.
  /// The mask trick only yields a correct result when alignment is a power of two.
  static size_t AlignTo(const size_t size, const size_t alignment) {
    const size_t mask = alignment - 1;
    return (size + mask) & ~mask;
  }

  std::vector<Block> blocks_;
  size_t total_mem_usage_ = 0UL;
};
/**
 * Non-owning view of a character sequence: a pointer plus an explicit length.
 * This class never copies or frees the characters it points to.
 */
class StringRef {
 public:
  /// Creates an empty view (null pointer, zero length).
  StringRef() : StringRef(nullptr, 0UL) {}

  /// Views a NUL-terminated string; a null pointer produces an empty view.
  explicit StringRef(const char_t *str) : data_(str), length_((str != nullptr) ? std::strlen(str) : 0UL) {}

  /// Views exactly len characters starting at str; no terminator is required.
  StringRef(const char_t *str, const size_t len) : data_(str), length_(len) {}

  const char_t *Data() const {
    return data_;
  }

  size_t Length() const {
    return length_;
  }

  bool Empty() const {
    return length_ == 0;
  }

  /// Byte-wise equality. Two zero-length views always compare equal, whatever their pointers;
  /// a non-empty view with a null pointer never equals anything.
  bool operator==(const StringRef &other) const {
    const size_t len = length_;
    if (len != other.length_) {
      return false;
    }
    if (len == 0) {
      return true;
    }
    const bool has_storage = (data_ != nullptr) && (other.data_ != nullptr);
    return has_storage && (std::memcmp(data_, other.data_, len) == 0);
  }

 private:
  const char_t *data_;
  size_t length_;
};
/**
 * Non-owning, read-only view over a contiguous sequence of T (pointer + element count).
 * The view never copies or frees the elements it refers to.
 *
 * begin()/end() make the view usable in range-based for loops and with standard algorithms;
 * they return the same pointer type exposed by the iterator alias.
 */
template <typename T>
class ArrayRef {
 public:
  using iterator = const T *;

  /// Creates an empty view (null pointer, zero elements).
  ArrayRef() : data_(nullptr), length_(0) {}

  /// Views len elements starting at start.
  ArrayRef(const T *start, const size_t len) : data_(start), length_(len) {}

  const T *Data() const {
    return data_;
  }

  size_t Size() const {
    return length_;
  }

  /// Pointer to the first element; equals end() for an empty view.
  iterator begin() const {
    return data_;
  }

  /// Pointer one past the last element.
  iterator end() const {
    return data_ + length_;
  }

  /// Bounds-checked element access.
  /// @throws std::out_of_range when index >= Size().
  const T &operator[](size_t index) const {
    if (index >= length_) {
      throw std::out_of_range("Index out of range");
    }
    return data_[index];
  }

  bool Empty() const {
    return length_ == 0;
  }

 protected:
  const T *data_;
  size_t length_;
};
/**
 * Variant of ArrayRef<T> that is constructed from a non-const pointer and whose
 * operator[] returns a writable reference to the element.
 */
template <typename T>
class MutableArrayRef : public ArrayRef<T> {
 public:
  /// Views length elements starting at mutable_data.
  MutableArrayRef(T *mutable_data, size_t length) : ArrayRef<T>(mutable_data, length) {}

  /// Bounds-checked writable element access.
  /// @throws std::out_of_range when index is not smaller than the stored length.
  T &operator[](size_t index) const {
    const bool in_bounds = index < this->length_;
    if (!in_bounds) {
      throw std::out_of_range("Index out of range");
    }
    // The base class keeps a const pointer; the view was built from a T*, so constness is cast away here.
    T *const writable = const_cast<T *>(this->data_);
    return writable[index];
  }
};
/**
 * Allocation front end built on an embedded AstNodePool.
 * Every buffer handed out by this class comes from node_pool_.Allocate().
 * The context can be neither copied nor moved.
 */
class AstContext {
 public:
  AstContext() = default;
  ~AstContext() = default;
  AstContext(const AstContext &) = delete;
  AstContext &operator=(const AstContext &) = delete;
  AstContext(AstContext &&) = delete;
  AstContext &operator=(AstContext &&) = delete;

  /// Forwards to the pool and returns whatever it returns, as an untyped pointer.
  void *Allocate(const size_t size) {
    return node_pool_.Allocate(size);
  }

  /// Returns the pool's memory-usage counter.
  size_t GetMemoryUsage() const {
    return node_pool_.GetMemoryUsage();
  }

  /**
   * Allocates storage for count elements of T and value-initializes each one.
   * T must be a pointer type or StringRef, and must be trivially destructible.
   * @param count number of elements requested.
   * @return a view over the new elements; an empty view (nullptr, 0) when count is 0,
   *         when count * sizeof(T) does not fit in size_t, or when the pool returns nullptr.
   */
  template <typename T>
  MutableArrayRef<T> AllocateMutableArray(const size_t count) {
    static_assert(std::is_pointer<T>::value || std::is_same<std::decay_t<T>, StringRef>::value,
                  "ArrayRef element type must be a pointer or StringRef.");
    static_assert(std::is_trivially_destructible<T>::value, "Array elements T must be trivially destructible");
    if (count == 0) {
      return MutableArrayRef<T>(nullptr, 0);
    }
    // Reject counts whose byte size would wrap around size_t: a wrapped product would request a
    // too-small buffer and the construction loop below would then write past its end.
    if (count > (std::numeric_limits<size_t>::max() / sizeof(T))) {
      return MutableArrayRef<T>(nullptr, 0);
    }
    void *mem = Allocate(count * sizeof(T));
    if (mem == nullptr) {
      return MutableArrayRef<T>(nullptr, 0);
    }
    // NOTE(review): assumes the pool returns storage suitably aligned for T - confirm in AstNodePool::Allocate.
    T *data = static_cast<T *>(mem);
    for (size_t i = 0; i < count; ++i) {
      (void)new (&data[i]) T();
    }
    return MutableArrayRef<T>(data, count);
  }

  /**
   * Copies a NUL-terminated string into pool memory.
   * @param s source string; may be nullptr.
   * @return a default (empty) StringRef when s is nullptr or the allocation fails;
   *         StringRef("", 0) when s is the empty string; otherwise a view of length strlen(s)
   *         over the copy, which is followed by a terminating '\0'.
   */
  StringRef CopyString(const char_t *s) {
    if (!s) {
      return {};
    }
    const size_t len = std::strlen(s);
    if (len == 0UL) {
      // No allocation is needed for an empty string; the view points at a string literal.
      return StringRef("", 0UL);
    }
    // One extra byte for the terminator written below.
    const auto dest = static_cast<char_t *>(Allocate(len + 1));
    if (dest == nullptr) {
      return {};
    }
    // GE_ASSERT_EOK comes from common/checker.h; its behavior on a failed copy is defined there.
    GE_ASSERT_EOK(memcpy_s(dest, len, s, len));
    dest[len] = '\0';
    return {dest, len};
  }

 private:
  AstNodePool node_pool_;
};
}
#endif