/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of 
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, 
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#ifndef GE_GRAPH_MANAGER_GRAPH_CACHING_ALLOCATOR_H_
#define GE_GRAPH_MANAGER_GRAPH_CACHING_ALLOCATOR_H_

#include <atomic>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <map>
#include <mutex>
#include <unordered_map>
#include <unordered_set>

#include "graph/node.h"
#include "graph/manager/block_memory.h"
#include "graph/manager/graph_mem_allocator.h"
#include "runtime/mem.h"
#include "acl/acl_rt.h"

namespace ge {
// Tuning constants for the caching allocator declared below.
constexpr size_t kRoundBlockSize = 512U;         // all block sizes are rounded to at least 512 bytes
// Bin size units. NOTE(review): how these units are scaled (e.g. with kKByteSize / kMByteSize) to form the
// bin boundaries is decided in the implementation file, not in this header - confirm there before changing.
constexpr size_t kBinSizeUnit8 = 8U;
constexpr size_t kBinSizeUnit32 = 32U;
constexpr size_t kBinSizeUnit128 = 128U;
constexpr size_t kBinSizeUnit256 = 256U;
constexpr size_t kBinSizeUnit512 = 512U;

constexpr float64_t kSplitThreshold = 0.5;         // split when malloc size <= small block size * kSplitThreshold
// Byte-size multipliers.
constexpr size_t kKByteSize = 1024U;
constexpr size_t kMByteSize = 1048576U;   // 1024 * 1024
constexpr size_t kGByteSize = 1073741824U;   // 1024 * 1024 * 1024

// Number of entries in CachingAllocator::free_block_bins_.
constexpr uint32_t kNumBins = 7U;
// Log level selector; within this header it is only used as the argument of CachingAllocator::PrintStatics.
// NOTE(review): the explicit numeric values (including the gap before kEvent = 10) presumably mirror the
// level constants of the logging backend - confirm before renumbering.
enum class GeLogLevel : int32_t {
  kDebug = 0,
  kInfo = 1,
  kWarn = 2,
  kError = 3,
  kNull = 4,
  kEvent = 10
};

/// @ingroup ge_graph
/// @brief Caching memory allocator for one rtMemType_t.
///        Tracks allocated blocks by their memory pointer and keeps free blocks in kNumBins
///        size-based bins (see free_block_bins_). The type is non-copyable.
class CachingAllocator {
 public:
  /// @ingroup ge_graph
  /// @brief construct an allocator for the given memory type
  /// @param [in] memory_type memory type, stored in memory_type_
  explicit CachingAllocator(const rtMemType_t memory_type);

  // Copy construction is disabled.
  CachingAllocator(const CachingAllocator &) = delete;

  // Copy assignment is disabled (the deleted overload is lvalue-ref-qualified).
  CachingAllocator &operator=(const CachingAllocator &) & = delete;

  virtual ~CachingAllocator() = default;

  /// @ingroup ge_graph
  /// @brief caching allocator init
  /// @param [in] device_id device id (defaults to 0)
  /// @return Status of init
  Status Initialize(const uint32_t device_id = 0U);

  /// @ingroup ge_graph
  /// @brief memory allocator finalize, release cached memory
  /// @return void
  void Finalize();

  /// @ingroup ge_graph
  /// @brief malloc memory
  /// @param [in] size memory size
  /// @param [in] org_ptr try to reuse the same memory (defaults to nullptr)
  /// @param [in] device_id device id (defaults to 0)
  /// @return  memory address
  uint8_t *Malloc(size_t size, uint8_t *const org_ptr = nullptr, const uint32_t device_id = 0U);

  /// @ingroup ge_graph
  /// @brief free memory
  /// @param [in] memory_addr memory address ptr
  /// @param [in] device_id device id (defaults to 0)
  /// @return Status result of function
  Status Free(uint8_t *const memory_addr, const uint32_t device_id = 0U);

  /// @ingroup ge_graph
  /// @brief try to free memory when no memory is referenced
  /// @return void
  void TryFreeBlocks();

  /// @ingroup ge_graph
  /// @brief try to free memory after stream synchronize
  /// @param [in] stream stream handle (NOTE(review): presumably the stream that is synchronized - confirm in
  ///             the implementation)
  /// @return Status result of function
  Status FreeBlocksAfterSynchronize(aclrtStream const stream);

  /// @ingroup ge_graph
  /// @brief Set whether the allocator is binding to a stream
  /// @param [in] bind_stream whether the allocator is binding to a stream
  /// @return void
  void SetBindStream(const bool bind_stream);
 private:

  /// @ingroup ge_graph
  /// @brief extend cache by size
  /// @param [in] size memory size
  /// @param [in] device_id device id
  /// @return Status result of function
  Status TryExtendCache(const size_t size, const uint32_t device_id);

  /// @ingroup ge_graph
  /// @brief find free block by size
  /// @param [in] size memory size
  /// @param [in] org_ptr pointer passed through from Malloc (try to reuse the same memory)
  /// @param [in] device_id device id
  /// @return block ptr
  Block *FindFreeBlock(const size_t size, uint8_t *const org_ptr, const uint32_t device_id);

  /// @ingroup ge_graph
  /// @brief get the right bin based on size
  /// @param [in] size original malloc size
  /// @return block bin
  BlockBin *GetBlockBin(const size_t size) const;

  /// @ingroup ge_graph
  /// @brief add memory to right bin based on size
  /// @param [in] ptr memory ptr
  /// @param [in] size memory size
  /// @param [in] device_id device id
  /// @return Status result of function
  Status AddToBlockBin(uint8_t *const ptr, const size_t size, const uint32_t device_id);

  /// @ingroup ge_graph
  /// @brief free block to right bin
  /// @param [in] block block ptr
  /// @return void
  void FreeBlock(Block *const block) const;

  /// @ingroup ge_graph
  /// @brief free all cached blocks to right bin and release the memory when memory is not enough
  /// @return free cached memory size
  size_t FreeCachedBlocks();

  /// @ingroup ge_graph
  /// @brief free allocated and cached blocks and release the memory when process exit
  /// @return void
  void FreeBlocks();

  /// @ingroup ge_graph
  /// @brief free block bins when process exit
  /// @return void
  void FreeBlockBins();

  /// @ingroup ge_graph
  /// @brief If a split block is freed, try merging with the original block
  /// @param [inout] dst dest block ptr
  /// @param [in] src src block ptr
  /// @param [out] bin block bin
  /// @return void
  void MergeBlocks(Block *const dst, Block *const src, BlockBin &bin) const;

  /// @ingroup ge_graph
  /// @brief If the allocated memory size is too much smaller than the memory block, try to split the memory block
  /// @param [in] block original block
  /// @param [in] size allocated memory size
  /// @param [in] bin block bin
  /// @param [in] device_id device id
  /// @return split block ptr
  Block *SplitBlock(Block &block, const size_t size, BlockBin &bin, const uint32_t device_id) const;

  /// @ingroup ge_graph
  /// @brief print the memory info in pool
  /// @param [in] ge_log_level log level (defaults to GeLogLevel::kInfo)
  /// @return void
  void PrintStatics(const GeLogLevel ge_log_level = GeLogLevel::kInfo);

 private:
  // memory type passed to the constructor
  rtMemType_t memory_type_;
  // whether the allocator is binding to a stream; see SetBindStream
  bool bind_stream_ = false;

  // device memory allocator
  MemoryAllocator *memory_allocator_ = nullptr;

  // lock around all operations
  mutable std::recursive_mutex mutex_;

  // allocated blocks by memory pointer
  std::unordered_map<uint8_t *, Block *> allocated_blocks_;

  // block bins by different block size
  BlockBin *free_block_bins_[kNumBins] = {};

  // memory malloced from device
  // NOTE(review): the meaning of key and value (presumably size -> count) is not visible in this header - confirm
  std::map<size_t, size_t> malloced_memory_;

  // user call Malloc total counts
  std::atomic<size_t> called_malloc_counts_{0U};

  // user call Free total counts
  std::atomic<size_t> called_free_counts_{0U};
};
}  // namespace ge
#endif  // GE_GRAPH_MANAGER_GRAPH_CACHING_ALLOCATOR_H_