#include <sys/types.h>
#include <iostream>
#include <torch/extension.h>

#include "third_party/acl/inc/acl/acl_base.h"
#include "third_party/acl/inc/acl/acl_rt.h"
#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h"

extern "C" {
// Bring the caching allocator's statistics struct into scope; it is the
// return type of my_get_device_stats below.
using c10_npu::NPUCachingAllocator::DeviceStats;
// Set to true by my_malloc and read back through check_custom_allocator_used(),
// so a caller can tell whether this allocator's malloc hook has ever run.
static bool useflag = false;

/**
 * Allocation hook: obtains `size` bytes of device memory via
 * aclrtMallocAlign32 with the ACL_MEM_MALLOC_HUGE_FIRST policy.
 *
 * Every call marks the custom allocator as used (see
 * check_custom_allocator_used), whether or not the allocation succeeds.
 *
 * @param size   Requested size in bytes. Negative values are rejected.
 * @param device Not used by this implementation.
 * @param stream Not used by this implementation.
 * @return The allocated pointer, or nullptr if the request was invalid or the
 *         ACL runtime reported a failure.
 */
void* my_malloc(ssize_t size, int device, aclrtStream stream)
{
    useflag = true;

    // A negative ssize_t would otherwise be converted to an enormous unsigned
    // size when handed to the ACL runtime.
    if (size < 0) {
        std::cerr << "alloc failed: negative size = " << size << std::endl;
        return nullptr;
    }

    // Start from nullptr so a failed allocation can never leak an
    // indeterminate pointer value to the caller.
    void* ptr = nullptr;
    const auto status = aclrtMallocAlign32(
        &ptr, static_cast<size_t>(size), aclrtMemMallocPolicy::ACL_MEM_MALLOC_HUGE_FIRST);
    if (status != 0) {  // ACL reports success as 0 (ACL_SUCCESS)
        std::cerr << "alloc failed: size = " << size << ", acl error = " << status << std::endl;
        return nullptr;
    }

    std::cout<<"alloc ptr = "<<ptr<<", size = "<<size<<std::endl;
    return ptr;
}

/**
 * Deallocation hook: logs the pointer and releases it with aclrtFree.
 *
 * @param ptr    Pointer to release. A null pointer is logged and then ignored.
 * @param size   Not used by this implementation.
 * @param device Not used by this implementation.
 * @param stream Not used by this implementation.
 */
void my_free(void* ptr, ssize_t size, int device, aclrtStream stream)
{
    std::cout<<"free ptr = "<<ptr<<std::endl;

    // Nothing to release; avoid handing a null pointer to the runtime.
    if (ptr == nullptr) {
        return;
    }

    // Surface runtime failures instead of silently dropping the status code.
    const auto status = aclrtFree(ptr);
    if (status != 0) {  // ACL reports success as 0 (ACL_SUCCESS)
        std::cerr << "free failed: ptr = " << ptr << ", acl error = " << status << std::endl;
    }
}

/**
 * Reports whether my_malloc has been invoked since the module was loaded.
 *
 * @return The current value of the file-level `useflag`.
 */
bool check_custom_allocator_used()
{
    const bool allocator_was_called = useflag;
    return allocator_was_called;
}

/**
 * Statistics hook: hands back a default-constructed DeviceStats.
 *
 * No counters are tracked by this allocator, so the result carries nothing
 * beyond whatever DeviceStats' default construction provides.
 *
 * @param device Not used by this implementation.
 * @return A freshly default-constructed DeviceStats.
 */
DeviceStats my_get_device_stats(int device)
{
    DeviceStats untracked_stats;
    return untracked_stats;
}

/**
 * Peak-statistics reset hook: there is no state to reset, so it only prints a
 * confirmation line to stdout and flushes it.
 *
 * @param device Not used by this implementation.
 */
void my_reset_peak_status(int device)
{
    static const char kConfirmation[] = "resetPeakStatus success!";
    std::cout << kConfirmation;
    std::cout << std::endl;
}
}

// Python bindings: exposes the allocation hooks and the usage probe under the
// extension's module name. Each function is registered with an empty docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("my_malloc", &my_malloc, "")
        .def("my_free", &my_free, "")
        .def("check_custom_allocator_used", &check_custom_allocator_used, "");
}