#include <string.h>
#include "graph/types.h"
#include "test_add_tik2.h"

#ifdef ACLNN_WITH_BINARY
#include <vector>
#include <tuple>
#include <map>
#include "graph/ascend_string.h"
#include "AddTik2_op_resource.h"
// Resource-table types used only when the op is built together with its
// binaries (ACLNN_WITH_BINARY). OP_RES is a pair of byte pointers —
// presumably the (begin, end) of an embedded blob; TODO confirm against the
// generated AddTik2_op_resource.h.
using OP_HOST_FUNC_HANDLE = std::vector<void *>;
using OP_RES = std::tuple<const uint8_t *, const uint8_t *>;
using OP_BINARY_RES = std::vector<OP_RES>;
using OP_RUNTIME_KB_RES = std::vector<OP_RES>;
// Resource map keyed by an AscendString (presumably the op type name — verify).
using OP_RESOURCES = std::map<ge::AscendString,
    std::tuple<OP_HOST_FUNC_HANDLE, OP_BINARY_RES, OP_RUNTIME_KB_RES>>;
// Variant of OP_RESOURCES whose binaries are additionally grouped by a second
// AscendString key (presumably a SoC name, judging by the type name — verify).
using OP_SOC_RESOURCES = std::map<ge::AscendString, std::tuple<OP_HOST_FUNC_HANDLE,
    std::map<ge::AscendString, OP_BINARY_RES>, OP_RUNTIME_KB_RES>>;
namespace op {
// Computes a numeric op type id from an op name and its resource table. Defined
// outside this file; testAddTik2GetWorkspaceSize caches the result and passes it
// to NnopbaseAddOpTypeId.
extern uint32_t GenOpTypeId(const char *op_name, const OP_RESOURCES &op_resources);
extern uint32_t GenOpTypeId(const char *op_name, const OP_SOC_RESOURCES &op_resources);
}
#endif

namespace {
// Identifier record handed to the nnopbase DFX hooks (NnopbaseAddTilingId,
// NnopbaseReportApiInfo). NOTE(review): hasReg is presumably updated by
// nnopbase once the id has been registered — confirm.
typedef struct {
    uint32_t id;
    const char *funcName;
    bool hasReg;
} NnopbaseDfxId;
// One (dtype, format) pair describing a single tensor slot.
typedef struct {
    ge::DataType dtype;
    ge::Format format;
} TensorDesc;
// One supported combination of input and output tensor descriptions.
typedef struct {
    TensorDesc *inputsDesc;
    size_t inputsNum;
    TensorDesc *outputsDesc;
    size_t outputsNum;
} SupportInfo;
// All supported combinations for one SoC entry.
typedef struct {
    SupportInfo *supportInfo;
    size_t num;
} OpSocSupportInfo;
// Table of per-SoC entries. NOTE(review): entry i presumably corresponds to
// socSupportList[i] / socNameList[i] (all three currently have 3 entries) —
// confirm against nnopbase before reordering any of them.
typedef struct {
    OpSocSupportInfo *socSupportInfo;
    size_t num;
} OpSupportList;
// SoC identifiers. Values stored in socSupportList are passed as raw uint32_t
// to NnopbaseAddSupportList, so the numeric values cross the C boundary.
enum SocType {
    SOC_VERSION_ASCEND910A = 1,
    SOC_VERSION_ASCEND910B = 2,
    SOC_VERSION_ASCEND910_93 = 3,
    SOC_VERSION_ASCEND950 = 4,
    SOC_VERSION_ASCEND310P = 5,
    SOC_VERSION_ASCEND310B = 6,
    SOC_VERSION_BS9SX1A = 7,
    SOC_VERSION_ASCEND610Lite = 8,
    SOC_VERSION_MC61AM21A = 10, // 9 is deprecated
    SOC_VERSION_MC62CM12A = 11,
    SOC_VERSION_BS9SX2A = 12,
    SOC_VERSION_ASCEND910_96 = 13,
    SOC_VERSION_KIRINX90 = 14,
    SOC_VERSION_KIRIN9030 = 15,
    SOC_VERSION_ASCEND350 = 16,
    SOC_VERSION_INVALID = 99
};
// Attribute value kinds passed to NnopbaseAddAttrWithDtype /
// NnopbaseAddArrayAttrWithDtype.
enum NnopbaseAttrDtype {
    kNnopbaseBool = 0U,
    kNnopbaseFloat,
    kNnopbaseInt,
    kNnopbaseString,
    kNnopbaseAttrEnd
};

// Compile-time element count of a C array. Every length field below is derived
// from its array with this helper instead of a hand-copied literal, so editing a
// table can never leave a stale count (which would make nnopbase read past the
// end of the array).
template <typename T, size_t N>
constexpr size_t CountOf(const T (&)[N])
{
    return N;
}

uint32_t socSupportList[] = {SOC_VERSION_ASCEND310P, SOC_VERSION_ASCEND910A, SOC_VERSION_INVALID};
uint32_t socSupportListLen = static_cast<uint32_t>(CountOf(socSupportList));

static const char *socNameList[] = {"ascend310p", "ascend910", "ascendxxx"};
static const size_t socNameListLen = CountOf(socNameList);

// --- SoC entry 0: one supported combination ---
TensorDesc inputDesc0_0[] =
    {{ge::DT_FLOAT16, ge::FORMAT_ND},
     {ge::DT_FLOAT16, ge::FORMAT_ND}};
TensorDesc outputDesc0_0[] =
    {{ge::DT_FLOAT16, ge::FORMAT_ND},
     {ge::DT_FLOAT16, ge::FORMAT_ND}};
SupportInfo list0_0 = {inputDesc0_0, CountOf(inputDesc0_0), outputDesc0_0, CountOf(outputDesc0_0)};
SupportInfo supportInfo0[] = {list0_0};
OpSocSupportInfo socSupportInfo0 = {supportInfo0, CountOf(supportInfo0)};

// --- SoC entry 1: two supported combinations ---
TensorDesc inputDesc1_0[] =
    {{ge::DT_FLOAT16, ge::FORMAT_ND},
     {ge::DT_FLOAT16, ge::FORMAT_ND}};
TensorDesc inputDesc1_1[] =
    {{ge::DT_FLOAT, ge::FORMAT_ND},
     {ge::DT_FLOAT, ge::FORMAT_NCHW}};
TensorDesc outputDesc1_0[] =
    {{ge::DT_FLOAT16, ge::FORMAT_ND},
     {ge::DT_FLOAT16, ge::FORMAT_ND}};
TensorDesc outputDesc1_1[] =
    {{ge::DT_FLOAT, ge::FORMAT_NCL},
     {ge::DT_FLOAT, ge::FORMAT_ND}};
SupportInfo list1_0 = {inputDesc1_0, CountOf(inputDesc1_0), outputDesc1_0, CountOf(outputDesc1_0)};
SupportInfo list1_1 = {inputDesc1_1, CountOf(inputDesc1_1), outputDesc1_1, CountOf(outputDesc1_1)};
SupportInfo supportInfo1[] = {list1_0, list1_1};
OpSocSupportInfo socSupportInfo1 = {supportInfo1, CountOf(supportInfo1)};

// --- SoC entry 2: two supported combinations ---
TensorDesc inputDesc2_0[] =
    {{ge::DT_FLOAT16, ge::FORMAT_ND},
     {ge::DT_FLOAT16, ge::FORMAT_ND}};
TensorDesc inputDesc2_1[] =
    {{ge::DT_FLOAT, ge::FORMAT_ND},
     {ge::DT_FLOAT, ge::FORMAT_NCHW}};
TensorDesc outputDesc2_0[] =
    {{ge::DT_FLOAT16, ge::FORMAT_ND},
     {ge::DT_FLOAT16, ge::FORMAT_ND}};
TensorDesc outputDesc2_1[] =
    {{ge::DT_FLOAT, ge::FORMAT_NCL},
     {ge::DT_FLOAT, ge::FORMAT_ND}};
SupportInfo list2_0 = {inputDesc2_0, CountOf(inputDesc2_0), outputDesc2_0, CountOf(outputDesc2_0)};
SupportInfo list2_1 = {inputDesc2_1, CountOf(inputDesc2_1), outputDesc2_1, CountOf(outputDesc2_1)};
SupportInfo supportInfo2[] = {list2_0, list2_1};
OpSocSupportInfo socSupportInfo2 = {supportInfo2, CountOf(supportInfo2)};

OpSocSupportInfo opSocSupportList[] = {socSupportInfo0, socSupportInfo1, socSupportInfo2};
OpSupportList supportList = {opSocSupportList, CountOf(opSocSupportList)};

// Not referenced elsewhere in this file. NOTE(review): presumably kept for
// symbol/tooling purposes by the generator — confirm before removing.
[[maybe_unused]] uint32_t NNOPBASE_AddTik2 = 0U;
} // namespace

// Error logger used by the NNOPBASE_ASSERT_* macros: receives the failing status
// code and the stringized expression that produced it. Defined outside this file.
extern void NnopbaseOpLogE(const aclnnStatus code, const char *const expr);

#ifdef __cplusplus
extern "C" {
#endif

// nnopbase runtime entry points, declared here with C linkage and defined in the
// nnopbase library. Declarations marked __attribute__((weak)) may be absent at
// link time, in which case the symbol's address is null; callers in this file
// test such functions against NULL before calling them.

// Executor lifecycle: create the shared executor space and fetch an executor
// configured by the per-parameter descriptor arrays.
extern aclnnStatus NnopbaseCreateExecutorSpace(void **space);
extern void *NnopbaseGetExecutor(void *space, const char *opType, char *inputsDesc, uint32_t inputNum,
                                 char *outputsDesc, uint32_t outputNum, char *attrsDesc, uint32_t attrsNum);
// Binding of inputs/outputs to an executor slot identified by `index`.
extern aclnnStatus NnopbaseAddInput(void *executor, const aclTensor *tensor, const uint32_t index);
extern aclnnStatus NnopbaseAddIgnoreContinuesInput(void *executor,
                                                   const aclTensor *tensor, const uint32_t index);
extern aclnnStatus NnopbaseAddIntArrayInput(void *executor, const aclIntArray *array, const uint32_t index);
extern aclnnStatus NnopbaseAddBoolArrayInput(void *executor, const aclBoolArray *array, const uint32_t index);
extern aclnnStatus NnopbaseAddFloatArrayInput(void *executor, const aclFloatArray *array, const uint32_t index);
extern aclnnStatus NnopbaseAddOutput(void *executor, const aclTensor *tensor, const uint32_t index);
extern aclnnStatus NnopbaseAddDynamicInput(void *executor, const aclTensorList *tensor_list, const uint32_t index);
extern aclnnStatus __attribute__((weak)) NnopbaseAddIgnoreContiguousDynamicInput(void *executor, const aclTensorList *tensor_list, const uint32_t index);
extern aclnnStatus NnopbaseAddDynamicOutput(void *executor, const aclTensorList *tensor_list, const uint32_t index);
// Attribute binding: scalar values by address + byte length, arrays either from
// acl*Array objects or from a raw buffer of `len` elements of `elementSize` bytes.
extern aclnnStatus NnopbaseAddAttrWithDtype(void *executor, void *attrAddr, size_t attrLen, const size_t index, const NnopbaseAttrDtype dtype);
extern aclnnStatus NnopbaseAddIntArrayAttr(void *executor, const aclIntArray* array, const size_t index);
extern aclnnStatus NnopbaseAddFloatArrayAttr(void *executor, const aclFloatArray* array, const size_t index);
extern aclnnStatus NnopbaseAddBoolArrayAttr(void *executor, const aclBoolArray* array, const size_t index);
extern aclnnStatus NnopbaseAddArrayAttrWithDtype(void *executor, void *array, const size_t len, const size_t elementSize, const size_t index, const NnopbaseAttrDtype dtype);
// Profiling / DFX helpers: timestamp source and API/tiling id reporting.
extern uint64_t NnopbaseMsprofSysTime();
extern const char* __attribute__((weak)) NnopbaseGetSocName();
extern aclnnStatus NnopbaseAddTilingId(void *executor, NnopbaseDfxId *tilingId);
extern void NnopbaseReportApiInfo(const uint64_t beginTime, NnopbaseDfxId &dfxId);
// Execution: compute the workspace size, then launch with a caller workspace.
extern aclnnStatus NnopbaseRunForWorkspace(void *executor, uint64_t *workspaceLen);
extern aclnnStatus NnopbaseRunWithWorkspace(void *executor, aclrtStream stream, void *workspace, uint64_t workspaceSize);
// Supported dtype/format tables, keyed either by SocType ids or by SoC names.
extern aclnnStatus NnopbaseAddSupportList(void *executor, OpSupportList *list, uint32_t *socSupportList, size_t socSupportListLen);
extern aclnnStatus __attribute__((weak)) NnopbaseAddSocNameList(void *executor, OpSupportList *list, const char * const *socNameList, size_t socNameListLen);
extern aclnnStatus NnopbaseAddScalarInput(void *executor, const aclScalar *scalar, const uint32_t index, const int32_t srcIndex, const ge::DataType dtype);
extern aclnnStatus NnopbaseAddScalarListInput(void *executor, const aclScalarList *scalarList, const uint32_t index, const int32_t srcIndex, const ge::DataType dtype);
// Misc executor configuration hooks (several optional via weak linkage).
extern void NnopbaseAddOpTypeId(void *executor, const uint32_t opTypeId);
extern aclnnStatus __attribute__((weak)) NnopbaseAddParamName(void *executor, const uint32_t index, const char *name, const bool isInput);
extern aclnnStatus __attribute__((weak)) NnopbaseSetFormatMatchMode(void *executor, const uint32_t mode);
extern aclnnStatus NnopbaseSetRef(void *executor, const size_t inputIrIdx, const size_t outputIrIdx);
extern void __attribute__((weak)) NnopbaseSetMatchArgsFlag(void *executor);
extern bool __attribute__((weak)) NnopbaseMatchArgs(void *executor, uint64_t *workspaceLen);
extern void __attribute__((weak)) NnopbaseSetParamCheckMode(void *executor, const uint32_t mode);

// aclnn status codes returned by the wrappers in this file.
#define ACLNN_SUCCESS  0
#define ACLNN_ERR_PARAM_NULLPTR 161001
#define ACLNN_ERR_PARAM_INVALID 161002

// Evaluates the aclnnStatus expression `v` exactly once. Anything other than
// ACLNN_SUCCESS is logged through NnopbaseOpLogE together with the stringized
// expression, and then returned from the enclosing function.
#define NNOPBASE_ASSERT_OK_RETVAL(v)                                    \
    do {                                                                \
        const aclnnStatus nnopbaseRetStatus_ = (v);                     \
        if (nnopbaseRetStatus_ == ACLNN_SUCCESS) {                      \
            break;                                                      \
        }                                                               \
        NnopbaseOpLogE(nnopbaseRetStatus_, #v);                         \
        return nnopbaseRetStatus_;                                      \
    } while (false)

// Returns ACLNN_ERR_PARAM_NULLPTR from the enclosing function (after logging
// "<expr> != nullptr") when the pointer expression `v` is null.
#define NNOPBASE_ASSERT_NOTNULL_RETVAL(v)                               \
    do {                                                                \
        if (nullptr == (v)) {                                           \
            NnopbaseOpLogE(ACLNN_ERR_PARAM_NULLPTR, #v " != nullptr");  \
            return ACLNN_ERR_PARAM_NULLPTR;                             \
        }                                                               \
    } while (false)

/**
 * First phase of the AddTik2 aclnn call: validates the caller's arguments,
 * acquires an nnopbase executor, binds inputs/attributes/outputs to it and
 * computes the workspace size.
 *
 * @param x1Optional     Optional input 0; may be nullptr.
 * @param x2             Tensor-list input 1; must not be nullptr.
 * @param bias0..bias6   Attributes 0..6. bias2Optional..bias5Optional fall back
 *                       to the built-in defaults below when nullptr.
 * @param bias00..bias66 Attributes 7..13; the pointer-typed ones (bias22,
 *                       bias33, bias44, bias55) must not be nullptr.
 * @param yOut           Output 0; must not be nullptr.
 * @param y2Out          Tensor-list output 1; must not be nullptr.
 * @param workspaceSize  Receives the workspace size computed by nnopbase.
 *                       NOTE(review): not null-checked here — presumably
 *                       validated by nnopbase; confirm.
 * @param executor       Receives the configured executor; must not be nullptr.
 * @return ACLNN_SUCCESS; ACLNN_ERR_PARAM_NULLPTR when a required pointer is
 *         null; otherwise the first failing status from an nnopbase call.
 */
aclnnStatus testAddTik2GetWorkspaceSize(
    const aclTensor *x1Optional,
    const aclTensorList *x2,
    int64_t bias0,
    double bias1,
    const aclBoolArray *bias2Optional,
    const aclFloatArray *bias3Optional,
    const aclIntArray *bias4Optional,
    char *bias5Optional,
    bool bias6,
    int64_t bias00,
    double bias11,
    const aclBoolArray *bias22,
    const aclFloatArray *bias33,
    const aclIntArray *bias44,
    char *bias55,
    bool bias66,
    const aclTensor *yOut,
    const aclTensorList *y2Out,
    uint64_t *workspaceSize,
    aclOpExecutor **executor)
{
    uint64_t timeStamp = NnopbaseMsprofSysTime();
#ifdef ACLNN_WITH_BINARY
    static uint32_t AddTik2OpTypeId = op::GenOpTypeId("AddTik2", AddTik2_RESOURCES);
#endif
    static NnopbaseDfxId dfxId = {0x60000, __func__, false};
    static NnopbaseDfxId tilingId = {0x60000, "testAddTik2Tiling", false};
    // Executor space shared by all calls, created lazily on first use.
    // NOTE(review): the lazy creation is not synchronized — confirm callers
    // cannot race on the very first call.
    static void *executorSpace = nullptr;
    const char *opType = "AddTik2";
    // One descriptor byte per IR parameter. Observed here: 0 for the optional
    // input and attrs 0..6, 1 for yOut and attrs 7..13, 2 for the tensor-list
    // x2/y2Out — presumably 0 = optional, 1 = required, 2 = dynamic; confirm
    // against nnopbase.
    char inputDesc[] = {0, 2};
    char outputDesc[] = {1, 2};
    char attrDesc[] = {0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1};

    // Validate every required caller pointer before creating the executor
    // space or acquiring an executor, so a rejected call never leaves an
    // acquired, partially populated executor behind.
    NNOPBASE_ASSERT_NOTNULL_RETVAL(x2);
    NNOPBASE_ASSERT_NOTNULL_RETVAL(yOut);
    NNOPBASE_ASSERT_NOTNULL_RETVAL(y2Out);
    NNOPBASE_ASSERT_NOTNULL_RETVAL(bias22);
    NNOPBASE_ASSERT_NOTNULL_RETVAL(bias33);
    NNOPBASE_ASSERT_NOTNULL_RETVAL(bias44);
    NNOPBASE_ASSERT_NOTNULL_RETVAL(bias55);
    NNOPBASE_ASSERT_NOTNULL_RETVAL(executor);

    if (executorSpace == nullptr) {
        NNOPBASE_ASSERT_OK_RETVAL(NnopbaseCreateExecutorSpace(&executorSpace));
    }
    void *nnopExecutor = NnopbaseGetExecutor(executorSpace, opType, inputDesc, sizeof(inputDesc) / sizeof(char),
                                             outputDesc, sizeof(outputDesc) / sizeof(char), attrDesc,
                                             sizeof(attrDesc) / sizeof(char));
    NNOPBASE_ASSERT_NOTNULL_RETVAL(nnopExecutor);
    *executor = reinterpret_cast<aclOpExecutor *>(nnopExecutor);
    NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddTilingId(*executor, &tilingId));
    // Weakly linked hooks: a null address means the linked nnopbase does not
    // provide them, so they are skipped.
    if (NnopbaseSetMatchArgsFlag != nullptr) {
        NnopbaseSetMatchArgsFlag(*executor);
    }
#ifdef ACLNN_WITH_BINARY
    NnopbaseAddOpTypeId(*executor, AddTik2OpTypeId);
#endif
    if (NnopbaseSetFormatMatchMode != nullptr) {
        NnopbaseSetFormatMatchMode(*executor, 1);
    }

    // Inputs.
    NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddInput(*executor, x1Optional, 0));
    NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddDynamicInput(*executor, x2, 1));

    // Attributes 0..6 (optional ones use the defaults below when omitted).
    NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddAttrWithDtype(*executor, static_cast<void*>(&bias0), sizeof(int64_t), 0, kNnopbaseInt));
    // Float attributes are passed to nnopbase as 32-bit float.
    float tmp1 = static_cast<float>(bias1);
    NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddAttrWithDtype(*executor, static_cast<void*>(&tmp1), sizeof(float), 1, kNnopbaseFloat));
    if (bias2Optional) {
        NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddBoolArrayAttr(*executor, bias2Optional, 2));
    } else {
        static bool bias2OptionalDef[] = {true, false};
        static size_t bias2OptionalLen = 2;
        NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddArrayAttrWithDtype(*executor, static_cast<void*>(bias2OptionalDef), bias2OptionalLen, sizeof(bool), 2, kNnopbaseBool));
    }
    if (bias3Optional) {
        NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddFloatArrayAttr(*executor, bias3Optional, 3));
    } else {
        static float bias3OptionalDef[] = {0.1f, 0.2f};
        static size_t bias3OptionalLen = 2;
        NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddArrayAttrWithDtype(*executor, static_cast<void*>(bias3OptionalDef), bias3OptionalLen, sizeof(float), 3, kNnopbaseFloat));
    }
    if (bias4Optional) {
        NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddIntArrayAttr(*executor, bias4Optional, 4));
    } else {
        static int64_t bias4OptionalDef[] = {1, 2};
        static size_t bias4OptionalLen = 2;
        NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddArrayAttrWithDtype(*executor, static_cast<void*>(bias4OptionalDef), bias4OptionalLen, sizeof(int64_t), 4, kNnopbaseInt));
    }
    if (bias5Optional) {
        NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddAttrWithDtype(*executor, static_cast<void*>(bias5Optional), strlen(bias5Optional) + 1, 5, kNnopbaseString));
    } else {
        // A char array (not `char *` pointing at a literal): binding a string
        // literal to a non-const `char *` is ill-formed since C++11, and the
        // nnopbase API takes a non-const `void *`.
        static char bias5OptionalDef[] = "ssss";
        NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddAttrWithDtype(*executor, static_cast<void*>(bias5OptionalDef), strlen(bias5OptionalDef) + 1, 5, kNnopbaseString));
    }
    NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddAttrWithDtype(*executor, static_cast<void*>(&bias6), sizeof(bool), 6, kNnopbaseBool));

    // Attributes 7..13 (required; pointer arguments validated above).
    NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddAttrWithDtype(*executor, static_cast<void*>(&bias00), sizeof(int64_t), 7, kNnopbaseInt));
    float tmp8 = static_cast<float>(bias11);
    NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddAttrWithDtype(*executor, static_cast<void*>(&tmp8), sizeof(float), 8, kNnopbaseFloat));
    NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddBoolArrayAttr(*executor, bias22, 9));
    NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddFloatArrayAttr(*executor, bias33, 10));
    NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddIntArrayAttr(*executor, bias44, 11));
    NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddAttrWithDtype(*executor, static_cast<void*>(bias55), strlen(bias55) + 1, 12, kNnopbaseString));
    NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddAttrWithDtype(*executor, static_cast<void*>(&bias66), sizeof(bool), 13, kNnopbaseBool));

    // Outputs.
    NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddOutput(*executor, yOut, 0));
    NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddDynamicOutput(*executor, y2Out, 1));

    // When the (weak) NnopbaseMatchArgs reports a match, the remaining setup
    // (param names, support list, RunForWorkspace) is skipped.
    if (NnopbaseMatchArgs != nullptr) {
        if (NnopbaseMatchArgs(*executor, workspaceSize)) {
            NnopbaseReportApiInfo(timeStamp, dfxId);
            return ACLNN_SUCCESS;
        }
    }
    if (NnopbaseAddParamName != nullptr) {
        NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddParamName(*executor, 0, "x1Optional", true));
        NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddParamName(*executor, 1, "x2", true));
        NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddParamName(*executor, 0, "yOut", false));
        NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddParamName(*executor, 1, "y2Out", false));
    }
    // Prefer the SoC-name keyed support list when the runtime provides it;
    // otherwise fall back to the numeric SocType ids.
    if (NnopbaseAddSocNameList != nullptr) {
        NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddSocNameList(*executor, &supportList, socNameList, socNameListLen));
    } else {
        NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddSupportList(*executor, &supportList, socSupportList, socSupportListLen));
    }
    NNOPBASE_ASSERT_OK_RETVAL(NnopbaseRunForWorkspace(*executor, workspaceSize));
    NnopbaseReportApiInfo(timeStamp, dfxId);
    return ACLNN_SUCCESS;
}

/**
 * Second phase of the AddTik2 aclnn call: runs `executor` (presumably the one
 * produced by testAddTik2GetWorkspaceSize) on `stream` with the caller-supplied
 * workspace, then reports API timing info.
 *
 * @return ACLNN_SUCCESS, or the failing status from NnopbaseRunWithWorkspace.
 */
aclnnStatus testAddTik2(
    void *workspace,
    uint64_t workspaceSize,
    aclOpExecutor *executor,
    aclrtStream stream)
{
    // Sample the clock before launching so the reported interval covers it.
    const uint64_t startTime = NnopbaseMsprofSysTime();
    static NnopbaseDfxId runDfxId = {0x60000, __func__, false};

    NNOPBASE_ASSERT_OK_RETVAL(NnopbaseRunWithWorkspace(executor, stream, workspace, workspaceSize));

    NnopbaseReportApiInfo(startTime, runDfxId);
    return ACLNN_SUCCESS;
}

#ifdef __cplusplus
}
#endif