#include <vector>
#include "aclnn/acl_meta.h"
#include "exe_graph/runtime/op_execute_context.h"
#include "exe_graph/runtime/tensor.h"
#include "register/op_impl_registry.h"

// Error-logging hook used by the FALLBACK_ASSERT_* macros in this file: it receives the failing
// status code and the stringified expression that was being checked. It is declared before the
// extern "C" block that follows and, unlike the Nnopbase* helpers declared there, is not weak.
extern void NnopbaseOpLogE(const aclnnStatus code, const char *const expr);

#ifdef __cplusplus
extern "C" {
#endif

// Helper entry points supplied by an external library. All of them are declared weak; the execute
// function in this file checks NnopbaseGetOpApiFunc against NULL before using any of them.

// Looks up an op API entry point by name; callers cast the returned pointer to the matching
// function-pointer type (see AclnnGetWorkspaceSizeFunc / AclnnFunc below).
extern void* __attribute__((weak)) NnopbaseGetOpApiFunc(const char *funcName);

// Converters taking gert tensors and returning acl* handles. The execute function in this file
// releases every handle it obtains from these with a NnopbaseDestroy* call.
extern aclTensor* __attribute__((weak)) NnopbaseConvertTensor(const gert::Tensor* tensor);
extern aclTensorList* __attribute__((weak)) NnopbaseConvertTensorList(std::vector<const gert::Tensor*> &tenserList);
extern aclBoolArray* __attribute__((weak)) NnopbaseCovertBoolArray(const gert::Tensor* tensor);
extern aclIntArray* __attribute__((weak)) NnopbaseCovertIntArray(const gert::Tensor* tensor);
extern aclFloatArray* __attribute__((weak)) NnopbaseCovertFloatArray(const gert::Tensor* tensor);
extern aclScalar* __attribute__((weak)) NnopbaseConvertScalar(const gert::Tensor* tensor);
extern aclScalarList* __attribute__((weak)) NnopbaseConvertScalarList(const gert::Tensor* tensor);

// Converters taking list-typed attribute values (gert::TypedContinuousVector) and returning acl arrays.
extern aclIntArray* __attribute__((weak)) NnopbaseCovertIntArrayAttr(const gert::TypedContinuousVector<int64_t> *arr);
extern aclBoolArray* __attribute__((weak)) NnopbaseCovertBoolArrayAttr(const gert::TypedContinuousVector<bool> *arr);
extern aclFloatArray* __attribute__((weak)) NnopbaseCovertFloatArrayAttr(const gert::TypedContinuousVector<float> *arr);

// Release functions, one per acl handle type returned by the converters above.
extern void __attribute__((weak)) NnopbaseDestroyTensor(const aclTensor *tensor);
extern void __attribute__((weak)) NnopbaseDestroyTensorList(const aclTensorList *tensorList);
extern void __attribute__((weak)) NnopbaseDestroyScalar(const aclScalar *scalar);
extern void __attribute__((weak)) NnopbaseDestroyScalarList(const aclScalarList *scalar);
extern void __attribute__((weak)) NnopbaseDestroyIntArray(const aclIntArray *array);
extern void __attribute__((weak)) NnopbaseDestroyBoolArray(const aclBoolArray *array);
extern void __attribute__((weak)) NnopbaseDestroyFloatArray(const aclFloatArray *array);

// Signature of the entry point looked up as "aclnnFallBackTestGetWorkspaceSize". The argument order
// is: seven inputs, seven attributes, two outputs, then the out-parameters for the workspace size
// and the executor.
using AclnnGetWorkspaceSizeFunc = aclnnStatus (*)(aclTensor *, const aclTensorList *, const aclFloatArray *, const aclBoolArray *, const aclIntArray *, const aclScalar *, const aclScalarList *, int64_t, double, const aclBoolArray *, const aclFloatArray *, const aclIntArray *, char *, bool, aclTensorList *, aclTensor *, uint64_t *, aclOpExecutor **);
// Signature of the entry point looked up as "aclnnFallBackTest": workspace pointer, workspace size,
// executor, stream.
using AclnnFunc = aclnnStatus (*)(void *, uint64_t, aclOpExecutor *, aclrtStream);

// Status codes used by the FALLBACK_ASSERT_* macros in this file:
//  - ACLNN_SUCCESS is the value an aclnnStatus is compared against to decide whether a call succeeded.
//  - ACLNN_ERR_PARAM_NULLPTR is the code passed to NnopbaseOpLogE when a pointer check fails.
#define ACLNN_SUCCESS  0
#define ACLNN_ERR_PARAM_NULLPTR 161001

// Evaluates `v` exactly once as an aclnnStatus. If the result differs from ACLNN_SUCCESS, the status
// and the stringified expression are passed to NnopbaseOpLogE and the *enclosing function* returns
// ge::GRAPH_FAILED; otherwise execution continues after the macro.
#define FALLBACK_ASSERT_OK_RETVAL(v)                                    \
    do {                                                                \
        const aclnnStatus fallback_chk_status_ = (v);                   \
        if (fallback_chk_status_ == ACLNN_SUCCESS) {                    \
            break;                                                      \
        }                                                               \
        NnopbaseOpLogE(fallback_chk_status_, #v);                       \
        return ge::GRAPH_FAILED;                                        \
    } while (false)

// Checks that `v` is not a null pointer. On a null value, ACLNN_ERR_PARAM_NULLPTR and the text
// "<expr> != nullptr" are passed to NnopbaseOpLogE and the *enclosing function* returns
// ge::GRAPH_FAILED; otherwise execution continues after the macro.
#define FALLBACK_ASSERT_NOTNULL_RETVAL(v)                               \
    do {                                                                \
        if ((v) != nullptr) {                                           \
            break;                                                      \
        }                                                               \
        NnopbaseOpLogE(ACLNN_ERR_PARAM_NULLPTR, #v " != nullptr");      \
        return ge::GRAPH_FAILED;                                        \
    } while (false)

namespace fallback {
/**
 * Host-side execute function registered for the FallBackTest operator.
 *
 * Steps:
 *   1. Fetch input tensors, output tensors and attributes from the execute context.
 *   2. Resolve "aclnnFallBackTestGetWorkspaceSize" and "aclnnFallBackTest" through
 *      NnopbaseGetOpApiFunc (cached in function-local statics after the first call).
 *   3. Convert the gert tensors / attribute vectors into acl* handles.
 *   4. Query the workspace size, allocate the workspace if one is needed, and launch the op
 *      on the context's stream.
 *   5. Release every acl* handle and the workspace, whether or not step 4 succeeded.
 *
 * @param host_api_ctx Execute context supplying tensors, attributes, workspace and stream.
 * @return ge::GRAPH_SUCCESS when both aclnn calls succeed; ge::GRAPH_FAILED when a required
 *         pointer is null, NnopbaseGetOpApiFunc is unavailable, an entry point cannot be
 *         resolved, workspace allocation fails, or an aclnn call returns a non-success status.
 */
ge::graphStatus FallBackTestHostExecuteFunc(gert::OpExecuteContext* host_api_ctx) {
    FALLBACK_ASSERT_NOTNULL_RETVAL(host_api_ctx);

    // ---- Inputs -------------------------------------------------------------------------
    auto x1 = host_api_ctx->GetOptionalInputTensor(0);

    // Dynamic input 1: collect tensors until the context reports no more entries.
    std::vector<const gert::Tensor*> x2;
    for (size_t index_1 = 0U;; ++index_1) {
        auto val = host_api_ctx->GetDynamicInputTensor(1, index_1);
        if (val == nullptr) {break;}
        x2.push_back(val);
    }

    auto x3 = host_api_ctx->GetRequiredInputTensor(2);
    FALLBACK_ASSERT_NOTNULL_RETVAL(x3);
    auto x4 = host_api_ctx->GetRequiredInputTensor(3);
    FALLBACK_ASSERT_NOTNULL_RETVAL(x4);
    auto x5 = host_api_ctx->GetRequiredInputTensor(4);
    FALLBACK_ASSERT_NOTNULL_RETVAL(x5);
    auto x6 = host_api_ctx->GetOptionalInputTensor(5);
    auto x7 = host_api_ctx->GetOptionalInputTensor(6);

    // ---- Outputs ------------------------------------------------------------------------
    // NOTE(review): output indices 1 and 2 are kept exactly as in the original code; confirm
    // them against the operator definition.
    std::vector<const gert::Tensor*> y1;
    for (size_t outIndex_1 = 0U;; ++outIndex_1) {
        // The original declared `var` here but tested/pushed an undeclared `val`.
        auto val = host_api_ctx->GetDynamicOutputTensor(1, outIndex_1);
        if (val == nullptr) {break;}
        y1.push_back(val);
    }

    auto y2 = host_api_ctx->GetRequiredOutputTensor(2);
    FALLBACK_ASSERT_NOTNULL_RETVAL(y2);

    // ---- Attributes (read in declaration order) -----------------------------------------
    auto attrs = host_api_ctx->GetAttrs();
    FALLBACK_ASSERT_NOTNULL_RETVAL(attrs);
    size_t attrIndex = 0U;
    const int64_t *bias0 = attrs->GetAttrPointer<int64_t>(attrIndex++);
    // bias0 is dereferenced when calling GetWorkspaceSize below, so it must not be null.
    FALLBACK_ASSERT_NOTNULL_RETVAL(bias0);
    const float *bias1 = attrs->GetAttrPointer<float>(attrIndex++);
    FALLBACK_ASSERT_NOTNULL_RETVAL(bias1);
    const gert::TypedContinuousVector<bool> *bias2 = attrs->GetAttrPointer<gert::TypedContinuousVector<bool>>(attrIndex++);
    const gert::TypedContinuousVector<float> *bias3 = attrs->GetAttrPointer<gert::TypedContinuousVector<float>>(attrIndex++);
    const gert::TypedContinuousVector<int64_t> *bias4 = attrs->GetAttrPointer<gert::TypedContinuousVector<int64_t>>(attrIndex++);
    FALLBACK_ASSERT_NOTNULL_RETVAL(bias4);
    const char *bias5 = attrs->GetAttrPointer<char>(attrIndex++);
    const bool *bias6 = attrs->GetAttrPointer<bool>(attrIndex++);
    // bias6 is dereferenced below as well.
    FALLBACK_ASSERT_NOTNULL_RETVAL(bias6);

    // ---- Resolve the aclnn entry points -------------------------------------------------
    // NnopbaseGetOpApiFunc is a weak symbol; a null address means the providing library is absent.
    if (NnopbaseGetOpApiFunc == nullptr) {return ge::GRAPH_FAILED;}
    static AclnnGetWorkspaceSizeFunc aclnnFallBackTestGetWorkspaceSize =
        reinterpret_cast<AclnnGetWorkspaceSizeFunc>(NnopbaseGetOpApiFunc("aclnnFallBackTestGetWorkspaceSize"));
    FALLBACK_ASSERT_NOTNULL_RETVAL(aclnnFallBackTestGetWorkspaceSize);
    static AclnnFunc aclnnFallBackTest =
        reinterpret_cast<AclnnFunc>(NnopbaseGetOpApiFunc("aclnnFallBackTest"));
    FALLBACK_ASSERT_NOTNULL_RETVAL(aclnnFallBackTest);

    // ---- Convert to acl handles ---------------------------------------------------------
    // No early return is allowed between here and the release section below, otherwise the
    // handles would leak.
    aclTensor *x1_tensor = NnopbaseConvertTensor(x1);
    aclTensorList *x2_tensorList = NnopbaseConvertTensorList(x2);
    aclFloatArray *x3_tensor = NnopbaseCovertFloatArray(x3);
    aclBoolArray *x4_tensor = NnopbaseCovertBoolArray(x4);
    aclIntArray *x5_tensor = NnopbaseCovertIntArray(x5);
    aclScalar *x6_scalar = NnopbaseConvertScalar(x6);
    aclScalarList *x7_scalarList = NnopbaseConvertScalarList(x7);
    aclTensorList *y1_tensorList = NnopbaseConvertTensorList(y1);
    aclTensor *y2_tensor = NnopbaseConvertTensor(y2);
    aclBoolArray *bias2_attr = NnopbaseCovertBoolArrayAttr(bias2);
    aclFloatArray *bias3_attr = NnopbaseCovertFloatArrayAttr(bias3);
    aclIntArray *bias4_attr = NnopbaseCovertIntArrayAttr(bias4);

    // ---- Launch -------------------------------------------------------------------------
    // The launch runs inside a lambda so that the FALLBACK_ASSERT_* macros return from the
    // lambda only; control always reaches the release section afterwards.
    const ge::graphStatus status = [&]() -> ge::graphStatus {
        uint64_t workspaceSize = 0;
        aclOpExecutor *executor = nullptr;
        auto ret = aclnnFallBackTestGetWorkspaceSize(x1_tensor, x2_tensorList, x3_tensor, x4_tensor, x5_tensor, x6_scalar, x7_scalarList, *bias0, *bias1, bias2_attr, bias3_attr, bias4_attr, const_cast<char *>(bias5), *bias6, y1_tensorList, y2_tensor, &workspaceSize, &executor);
        FALLBACK_ASSERT_OK_RETVAL(ret);
        void *workspace = nullptr;
        if (workspaceSize > 0) {
            workspace = host_api_ctx->MallocWorkspace(workspaceSize);
            FALLBACK_ASSERT_NOTNULL_RETVAL(workspace);
        }
        auto stream = host_api_ctx->GetStream();
        ret = aclnnFallBackTest(workspace, workspaceSize, executor, stream);
        FALLBACK_ASSERT_OK_RETVAL(ret);
        return ge::GRAPH_SUCCESS;
    }();

    // ---- Release ------------------------------------------------------------------------
    NnopbaseDestroyTensor(x1_tensor);
    NnopbaseDestroyTensorList(x2_tensorList);
    NnopbaseDestroyFloatArray(x3_tensor);
    NnopbaseDestroyBoolArray(x4_tensor);
    NnopbaseDestroyIntArray(x5_tensor);
    NnopbaseDestroyScalar(x6_scalar);
    NnopbaseDestroyScalarList(x7_scalarList);
    NnopbaseDestroyTensorList(y1_tensorList);
    NnopbaseDestroyTensor(y2_tensor);
    // The original called NnopbaseCovertBoolArrayAttr here, which never released bias2_attr.
    NnopbaseDestroyBoolArray(bias2_attr);
    NnopbaseDestroyFloatArray(bias3_attr);
    NnopbaseDestroyIntArray(bias4_attr);
    host_api_ctx->FreeWorkspace();
    return status;
}

// Registers the execute function for FallBackTest and marks inputs 2..6 (x3..x7) as host inputs.
IMPL_OP(FallBackTest).OpExecuteFunc(FallBackTestHostExecuteFunc).HostInputs({2, 3, 4, 5, 6});
} // namespace fallback

#ifdef __cplusplus
}
#endif