/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of 
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, 
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#ifndef INC_EXTERNAL_ACL_ACL_RT_STUB_H_
#define INC_EXTERNAL_ACL_ACL_RT_STUB_H_

#include <stdint.h>
#include <stddef.h>
#include <vector>
#include <memory>
#include <mutex>
#include <set>
#include "mmpa/mmpa_api.h"
#include "acl/acl.h"
#include "acl/acl_base.h"
#include "acl/acl_dump.h"
#include "acl/acl_rt.h"
#include "common/ge_common/ge_types.h"
#include "graph/small_vector.h"
#include "graph/any_value.h"

namespace ge {
class AclRuntimeStub {
public:
  virtual ~AclRuntimeStub() = default;

  static AclRuntimeStub* GetInstance();
  void SetDeviceId(int64_t device_id) {
    device_id_ = device_id;
  }
  static void SetErrorResultApiName(const std::string &stub_api_name);
  static void SetInstance(const std::shared_ptr<AclRuntimeStub> &instance);

  static void Install(AclRuntimeStub*);
  static void UnInstall(AclRuntimeStub*);

  static void Reset() {
    instance_.reset();
  }

  virtual aclError aclrtRecordNotify(aclrtNotify notify, aclrtStream stream);
  virtual aclError aclrtBinaryGetFunctionByEntry(aclrtBinHandle binHandle,
                                                 uint64_t funcEntry,
                                                 aclrtFuncHandle *funcHandle);
  virtual aclError aclrtLaunchKernel(aclrtFuncHandle funcHandle,
                                     uint32_t blockDim,
                                     const void *argsData,
                                     size_t argsSize,
                                     aclrtStream stream);
  virtual aclError aclrtBinaryUnLoad(aclrtBinHandle binHandle);
  virtual aclError aclrtBinaryLoadFromFile(const char* binPath, aclrtBinaryLoadOptions *options,
      aclrtBinHandle *binHandle);
  virtual aclError aclrtBinaryLoadFromData(const void *data, size_t length,
      const aclrtBinaryLoadOptions *options, aclrtBinHandle *binHandle);
  virtual aclError aclrtLaunchKernelV2(aclrtFuncHandle funcHandle, uint32_t numBlocks,
      const void *argsData, size_t argsSize, aclrtLaunchKernelCfg *cfg, aclrtStream stream);
  virtual aclError aclrtRegisterCpuFunc(const aclrtBinHandle handle, const char *funcName,
      const char *kernelName, aclrtFuncHandle *funcHandle);
  virtual aclError aclrtBinaryGetFunction(const aclrtBinHandle binHandle, const char *kernelName,
      aclrtFuncHandle *funcHandle);
  virtual aclError aclrtLaunchKernelWithHostArgs(aclrtFuncHandle funcHandle, uint32_t numBlocks,
      aclrtStream stream, aclrtLaunchKernelCfg *cfg, void *hostArgs, size_t argsSize,
      aclrtPlaceHolderInfo *placeHolderArray, size_t placeHolderNum);
  virtual aclError aclrtStreamGetId(aclrtStream stream, int32_t *streamId);
  virtual aclError aclrtWaitAndResetNotify(aclrtNotify notify, aclrtStream stream, uint32_t timeout);
  virtual aclError aclrtSetDevice(int32_t deviceId);
  virtual aclError aclrtResetDevice(int32_t deviceId);
  virtual aclError aclrtCacheLastTaskExtendInfo(const char * const extendInfoPtr, const size_t infoSize);
  virtual aclError aclrtGetDevice(int32_t *deviceId);
  virtual aclError aclrtGetThreadLastTaskId(uint32_t *taskId);
  virtual aclError aclrtCreateContext(aclrtContext *context, int32_t deviceId);
  virtual aclError aclrtDestroyContext(aclrtContext context);
  virtual aclError aclrtSetCurrentContext(aclrtContext context);
  virtual aclError aclrtGetCurrentContext(aclrtContext *context);
  virtual aclError aclrtCreateEvent(aclrtEvent *event);
  virtual aclError aclrtDestroyEvent(aclrtEvent event);
  virtual aclError aclrtRecordEvent(aclrtEvent event, aclrtStream stream);
  virtual aclError aclrtQueryEventStatus(aclrtEvent event, aclrtEventRecordedStatus *status);
  virtual aclError aclrtCreateStream(aclrtStream *stream);
  virtual aclError aclrtCreateStreamWithConfig(aclrtStream *stream, uint32_t priority, uint32_t flag);
  virtual aclError aclrtDestroyStream(aclrtStream stream);
  virtual aclError aclrtStreamAbort(aclrtStream stream);
  virtual aclError aclrtSynchronizeStream(aclrtStream stream);
  virtual aclError aclrtSynchronizeStreamWithTimeout(aclrtStream stream, int32_t timeout);
  virtual aclError aclrtMalloc(void **devPtr, size_t size, aclrtMemMallocPolicy policy);
  virtual aclError aclrtMallocHost(void **hostPtr, size_t size);
  virtual aclError aclrtMallocWithCfg(void **devPtr, size_t size, aclrtMemMallocPolicy policy, aclrtMallocConfig *cfg);
  virtual aclError aclrtMallocHostWithCfg(void **hostPtr, size_t size, aclrtMallocConfig *cfg);
  virtual aclError aclrtMemset(void *devPtr, size_t maxCount, int32_t value, size_t count);
  virtual aclError aclrtFree(void *devPtr);
  virtual aclError aclrtFreeHost(void *devPtr);
  virtual aclError aclrtHostRegister(void *ptr, uint64_t size, aclrtHostRegisterType type, void **devPtr);
  virtual aclError aclrtHostUnregister(void *ptr);
  virtual aclError aclrtMemcpy(void *dst, size_t dest_max, const void *src, size_t count, aclrtMemcpyKind kind);
  virtual aclError aclrtMemcpyAsync(void *dst,
                    size_t dest_max,
                    const void *src,
                    size_t src_count,
                    aclrtMemcpyKind kind,
                    aclrtStream stream);
  virtual aclError aclrtMemcpyAsyncWithCondition(void *dst,
                          size_t destMax,
                          const void *src,
                          size_t count,
                          aclrtMemcpyKind kind,
                          aclrtStream stream);
  virtual aclError aclrtGetMemInfo(aclrtMemAttr attr, size_t *free_size, size_t *total);
  virtual const char* aclrtGetSocName();
  virtual aclError aclrtGetDeviceInfo(uint32_t deviceId, aclrtDevAttr attr, int64_t *value);
  virtual aclError aclrtGetPhyDevIdByLogicDevId(const int32_t logicDevId, int32_t *const phyDevId);
  virtual aclError aclrtMemcpyBatch(void **dsts, size_t *destMax, void **srcs, size_t *sizes, size_t numBatches,
                                    aclrtMemcpyBatchAttr *attrs, size_t *attrsIndexex, size_t numAttrs, size_t *failIndex);
  virtual aclError aclrtCheckArchCompatibility(const char *socVersion, int32_t *canCompatible);
  virtual aclError aclrtSetStreamFailureMode(aclrtStream stream, uint64_t mode);
  virtual aclError aclrtSetStreamAttribute(aclrtStream stream, aclrtStreamAttr attr, aclrtStreamAttrValue *value);
  virtual aclError aclrtActiveStream(aclrtStream activeStream, aclrtStream stream);
  virtual aclError aclrtCtxGetCurrentDefaultStream(aclrtStream *stream);
  virtual aclError aclrtDestroyLabel(aclrtLabel label);
  virtual aclError aclmdlRIDestroy(aclmdlRI modelRI);
  virtual aclError aclmdlRIUnbindStream(aclmdlRI modelRI, aclrtStream stream);
  virtual aclError aclrtDestroyLabelList(aclrtLabelList labelList);
  virtual aclError aclmdlRIBindStream(aclmdlRI modelRI, aclrtStream stream, uint32_t flag);
  virtual aclError aclmdlRIBuildEnd(aclmdlRI modelRI, void *reserve);
  virtual aclError aclrtPersistentTaskClean(aclrtStream stream);
  virtual aclError aclrtSetExceptionInfoCallback(aclrtExceptionInfoCallback callback);
  virtual uint32_t aclrtGetDeviceIdFromExceptionInfo(const aclrtExceptionInfo *info);
  virtual uint32_t aclrtGetErrorCodeFromExceptionInfo(const aclrtExceptionInfo *info);
  virtual aclError aclrtGetUserDevIdByLogicDevId(const int32_t logicDevId, int32_t *const userDevid);
  virtual aclError aclrtGetLogicDevIdByUserDevId(const int32_t userDevid, int32_t *const logicDevId);
  virtual aclError aclrtSetTsDevice(aclrtTsId tsId);
  virtual aclError aclrtGetDeviceCount(uint32_t *count);
  virtual aclError aclrtGetDeviceCapability(int32_t deviceId, aclrtDevFeatureType devFeatureType, int32_t *value);
  virtual aclError aclrtCreateEventWithFlag(aclrtEvent *event, uint32_t flag);
  virtual aclError aclrtResetEvent(aclrtEvent event, aclrtStream stream);
  virtual aclError aclrtSynchronizeEventWithTimeout(aclrtEvent event, int32_t timeout);
  virtual aclError aclrtMallocForTaskScheduler(void **devPtr, size_t size, aclrtMemMallocPolicy policy,
                                               aclrtMallocConfig *cfg);
  virtual aclError aclrtReserveMemAddress(void **virPtr, size_t size, size_t alignment, void *expectPtr,
                                          uint64_t flags);
  virtual aclError aclrtReleaseMemAddress(void *virPtr);
  virtual aclError aclrtMallocPhysical(aclrtDrvMemHandle *handle, size_t size, const aclrtPhysicalMemProp *prop,
                                       uint64_t flags);
  virtual aclError aclrtFreePhysical(aclrtDrvMemHandle handle);
  virtual aclError aclrtMapMem(void *virPtr, size_t size, size_t offset, aclrtDrvMemHandle handle, uint64_t flags);
  virtual aclError aclrtUnmapMem(void *virPtr);
  virtual aclError aclrtDestroyStreamForce(aclrtStream stream);
  virtual aclError aclrtStreamWaitEvent(aclrtStream stream, aclrtEvent event);
  virtual aclError aclrtStreamWaitEventWithTimeout(aclrtStream stream, aclrtEvent event, int32_t timeout);
  virtual aclError aclrtSetOpWaitTimeout(uint32_t timeout);
  virtual aclError aclrtSetDeviceSatMode(aclrtFloatOverflowMode mode);
  virtual aclError aclrtDeviceGetBareTgid(int32_t *pid);
  virtual aclError aclrtSetOpExecuteTimeOut(uint32_t timeout);
  virtual aclError aclrtSetOpExecuteTimeOutWithMs(uint32_t timeout);
  virtual aclError aclrtSetOpExecuteTimeOutV2(uint64_t timeout, uint64_t *actualTimeout);
  virtual aclError aclrtGetStreamAvailableNum(uint32_t *streamCount);
  virtual aclError aclrtSetStreamResLimit(aclrtStream stream, aclrtDevResLimitType type, uint32_t value);
  virtual aclError aclrtUseStreamResInCurrentThread(aclrtStream stream);
  virtual aclError aclrtUnuseStreamResInCurrentThread(aclrtStream stream);
  virtual aclError aclrtGetEventId(aclrtEvent event, uint32_t *eventId);
  virtual aclError aclrtCreateEventExWithFlag(aclrtEvent *event, uint32_t flag);
  virtual aclError aclrtGetEventAvailNum(uint32_t *eventCount);
  virtual aclError aclrtCreateLabel(aclrtLabel *label);
  virtual aclError aclrtSetLabel(aclrtLabel label, aclrtStream stream);
  virtual aclError aclrtCreateLabelList(aclrtLabel *labels, size_t num, aclrtLabelList *labelList);
  virtual aclError aclrtSwitchLabelByIndex(void *ptr, uint32_t maxValue, aclrtLabelList labelList, aclrtStream stream);
  virtual aclError aclrtSwitchStream(void *leftValue, aclrtCondition cond, void *rightValue,
                                     aclrtCompareDataType dataType, aclrtStream trueStream, aclrtStream falseStream,
                                     aclrtStream stream);
  virtual aclError aclmdlRIExecuteAsync(aclmdlRI modelRI, aclrtStream stream);
  virtual aclError aclmdlRIExecute(aclmdlRI modelRI, int32_t timeout);
  virtual aclError aclmdlRIBuildBegin(aclmdlRI *modelRI, uint32_t flag);
  virtual aclError aclmdlRIEndTask(aclmdlRI modelRI, aclrtStream stream);
  virtual aclError aclmdlRISetName(aclmdlRI modelRI, const char *name);
  virtual aclError aclmdlRIDebugJsonPrint(aclmdlRI modelRI, const char *path, uint32_t flags);
  virtual aclError aclrtCtxGetFloatOverflowAddr(void **overflowAddr);
  virtual aclError aclrtGetHardwareSyncAddr(void **addr);
  virtual aclError aclrtTaskUpdateAsync(aclrtStream taskStream, uint32_t taskId, aclrtTaskUpdateInfo *info,
                                        aclrtStream execStream);
  virtual aclError aclmdlRIAbort(aclmdlRI modelRI);
  virtual aclError aclrtProfTrace(void *userdata, int32_t length, aclrtStream stream);
  virtual aclError aclrtCreateNotify(aclrtNotify *notify, uint64_t flag);
  virtual aclError aclrtDestroyNotify(aclrtNotify notify);

  virtual aclError aclrtValueWait(void* devAddr, uint64_t value, uint32_t flag, aclrtStream stream);

  virtual aclError aclrtValueWrite(void* devAddr, uint64_t value, uint32_t flag, aclrtStream stream);

 private:
  static std::mutex mutex_;
  static std::shared_ptr<AclRuntimeStub> instance_;
  static thread_local AclRuntimeStub *fake_instance_;
  size_t reserve_mem_size_ = 200UL * 1024UL * 1024UL;
  std::mutex mtx_;
  int64_t device_id_{0L};
  std::vector<aclrtStream> model_bind_streams_;
  std::vector<aclrtStream> model_unbind_streams_;
  size_t input_mem_copy_batch_count_{0UL};
  int32_t cur_device_id = 0;
  int32_t batch_memcpy_device_id = 0;
};

class AclApiStub {
public:
  virtual ~AclApiStub() = default;

  static AclApiStub* GetInstance();

  static void SetInstance(const std::shared_ptr<AclApiStub> &instance) {
    instance_ = instance;
  }

  static void Install(AclApiStub*);
  static void UnInstall(AclApiStub*);

  static void Reset() {
    instance_.reset();
  }

  virtual aclError aclInit(const char *configPath);
  virtual aclError aclFinalize();
  virtual aclDataBuffer *aclCreateDataBuffer(void *data, size_t size);
  virtual aclTensorDesc *aclCreateTensorDesc(aclDataType dataType,
                                    int numDims,
                                    const int64_t *dims,
                                    aclFormat format);
  virtual void *aclGetDataBufferAddr(const aclDataBuffer *dataBuffer);
  virtual size_t aclGetDataBufferSizeV2(const aclDataBuffer *dataBuffer);
  virtual aclError aclGetTensorDescDimV2(const aclTensorDesc *desc, size_t index, int64_t *dimSize);
  virtual aclFormat aclGetTensorDescFormat(const aclTensorDesc *desc);
  virtual size_t aclGetTensorDescNumDims(const aclTensorDesc *desc);
  virtual aclDataType aclGetTensorDescType(const aclTensorDesc *desc);
  virtual aclError aclmdlAddDatasetBuffer(aclmdlDataset *dataset, aclDataBuffer *dataBuffer);
  virtual aclmdlConfigHandle *aclmdlCreateConfigHandle();
  virtual aclmdlDataset *aclmdlCreateDataset();
  virtual aclmdlDesc *aclmdlCreateDesc();
  virtual aclError aclmdlDestroyDataset(const aclmdlDataset *dataset);
  virtual aclError aclmdlDestroyDesc(aclmdlDesc *modelDesc);
  virtual aclError aclmdlExecute(uint32_t modelId, const aclmdlDataset *input, aclmdlDataset *output);
  virtual aclDataBuffer *aclmdlGetDatasetBuffer(const aclmdlDataset *dataset, size_t index);
  virtual aclTensorDesc *aclmdlGetDatasetTensorDesc(const aclmdlDataset *dataset, size_t index);
  virtual aclError aclmdlGetDesc(aclmdlDesc *modelDesc, uint32_t modelId);
  virtual aclError aclmdlLoadFromMem(const void *model,  size_t modelSize, uint32_t *modelId);
  virtual aclError aclmdlLoadWithConfig(const aclmdlConfigHandle *handle, uint32_t *modelId);
  virtual aclError aclmdlSetConfigOpt(aclmdlConfigHandle *handle, aclmdlConfigAttr attr,
                                    const void *attrValue, size_t valueSize);
  virtual aclError aclmdlSetDatasetTensorDesc(aclmdlDataset *dataset,
                                      aclTensorDesc *tensorDesc,
                                      size_t index);
  virtual aclError aclmdlSetExternalWeightAddress(aclmdlConfigHandle *handle, const char *weightFileName,
                                          void *devPtr, size_t size);
  virtual aclError aclmdlUnload(uint32_t modelId);
  virtual aclError aclmdlDestroyConfigHandle(aclmdlConfigHandle *handle);
  virtual void aclDestroyTensorDesc(const aclTensorDesc *desc);
  virtual aclError aclDestroyDataBuffer(const aclDataBuffer *dataBuffer);

private:
  static std::mutex mutex_;
  static std::shared_ptr<AclApiStub> instance_;
  static thread_local AclApiStub *fake_instance_;
  std::mutex mtx_;
};
}

#ifdef __cplusplus
extern "C" {
#endif

extern std::string g_acl_stub_mock;
extern std::string g_acl_stub_mock_v2;
extern std::string g_acl_stub_debug_json_last_file_path;

#ifdef __cplusplus
}
#endif

#endif // INC_EXTERNAL_ACL_ACL_RT_STUB_H_