#include <string.h>
#include "graph/types.h"
#include "aclnn_value_depend_test.h"
namespace {

// Identifier record passed to NnopbaseAddTilingId / NnopbaseReportApiInfo.
struct NnopbaseDfxId {
    uint32_t id;
    const char *funcName;
    bool hasReg;
};

// Data type / format pair describing a single tensor.
struct TensorDesc {
    ge::DataType dtype;
    ge::Format format;
};

// One supported combination: an array of input descriptors and an array of
// output descriptors, each with its element count.
struct SupportInfo {
    TensorDesc *inputsDesc;
    size_t inputsNum;
    TensorDesc *outputsDesc;
    size_t outputsNum;
};

// All supported combinations for one SoC entry.
struct OpSocSupportInfo {
    SupportInfo *supportInfo;
    size_t num;
};

// Per-SoC support tables for the operator.
struct OpSupportList {
    OpSocSupportInfo *socSupportInfo;
    size_t num;
};

// Numeric SoC identifiers; socSupportList below stores values from this enum.
enum SocType {
    SOC_VERSION_ASCEND910A = 1,
    SOC_VERSION_ASCEND910B = 2,
    SOC_VERSION_ASCEND910_93 = 3,
    SOC_VERSION_ASCEND950 = 4,
    SOC_VERSION_ASCEND310P = 5,
    SOC_VERSION_ASCEND310B = 6,
    SOC_VERSION_BS9SX1A = 7,
    SOC_VERSION_ASCEND610Lite = 8,
    SOC_VERSION_MC61AM21A = 10, // 9 is deprecated
    SOC_VERSION_MC62CM12A = 11,
    SOC_VERSION_BS9SX2A = 12,
    SOC_VERSION_ASCEND910_96 = 13,
    SOC_VERSION_KIRINX90 = 14,
    SOC_VERSION_KIRIN9030 = 15,
    SOC_VERSION_ASCEND350 = 16,
    SOC_VERSION_INVALID = 99
};

// Attribute element kinds accepted by the Nnopbase*AttrWithDtype declarations.
enum NnopbaseAttrDtype {
    kNnopbaseBool = 0U,
    kNnopbaseFloat,
    kNnopbaseInt,
    kNnopbaseString,
    kNnopbaseAttrEnd
};

// SoC selection, once by numeric id and once by name. Which of the two is
// registered is decided at run time by the GetWorkspaceSize entry points.
uint32_t socSupportList[] = {SOC_VERSION_ASCEND910A};
uint32_t socSupportListLen = 1;
static const char *socNameList[] = {"ascend910"};
static const size_t socNameListLen = 1;

// Input descriptor rows: four entries per combination.
// NOTE(review): entries presumably line up with inputs x, x1, x2, x3 by position - confirm.
TensorDesc inputDesc0_0[4] = {
    {ge::DT_FLOAT16, ge::FORMAT_ND},
    {ge::DT_FLOAT, ge::FORMAT_ND},
    {ge::DT_BOOL, ge::FORMAT_ND},
    {ge::DT_INT32, ge::FORMAT_ND}};
TensorDesc inputDesc0_1[4] = {
    {ge::DT_FLOAT, ge::FORMAT_ND},
    {ge::DT_FLOAT, ge::FORMAT_ND},
    {ge::DT_BOOL, ge::FORMAT_ND},
    {ge::DT_UINT64, ge::FORMAT_ND}};
TensorDesc inputDesc0_2[4] = {
    {ge::DT_FLOAT16, ge::FORMAT_ND},
    {ge::DT_FLOAT, ge::FORMAT_ND},
    {ge::DT_BOOL, ge::FORMAT_ND},
    {ge::DT_INT64, ge::FORMAT_ND}};
TensorDesc inputDesc0_3[4] = {
    {ge::DT_FLOAT, ge::FORMAT_ND},
    {ge::DT_FLOAT, ge::FORMAT_ND},
    {ge::DT_BOOL, ge::FORMAT_ND},
    {ge::DT_INT64, ge::FORMAT_ND}};

// Output descriptor rows: one entry per combination.
TensorDesc outputDesc0_0[1] = {{ge::DT_FLOAT16, ge::FORMAT_ND}};
TensorDesc outputDesc0_1[1] = {{ge::DT_FLOAT, ge::FORMAT_ND}};
TensorDesc outputDesc0_2[1] = {{ge::DT_FLOAT16, ge::FORMAT_ND}};
TensorDesc outputDesc0_3[1] = {{ge::DT_FLOAT, ge::FORMAT_ND}};

// Pair each input row with the output row of the same index.
SupportInfo list0_0 = {inputDesc0_0, 4, outputDesc0_0, 1};
SupportInfo list0_1 = {inputDesc0_1, 4, outputDesc0_1, 1};
SupportInfo list0_2 = {inputDesc0_2, 4, outputDesc0_2, 1};
SupportInfo list0_3 = {inputDesc0_3, 4, outputDesc0_3, 1};

// Assemble the table bottom-up: combinations -> single SoC entry -> full list.
// Each level copies from the objects declared above it, so declaration order matters.
SupportInfo supportInfo0[4] = {list0_0, list0_1, list0_2, list0_3};
OpSocSupportInfo socSupportInfo0 = {supportInfo0, 4};
OpSocSupportInfo opSocSupportList[1] = {socSupportInfo0};
OpSupportList supportList = {opSocSupportList, 1};

// Not referenced anywhere else in this file.
[[maybe_unused]] uint32_t NNOPBASE_ValueDependTest = 0U;

} // namespace
// Error-reporting hook invoked by the NNOPBASE_ASSERT_* macros below with a
// status code and the stringified expression that failed. Declared outside the
// extern "C" block (so it has C++ linkage) and defined elsewhere.
extern void NnopbaseOpLogE(const aclnnStatus code, const char *const expr);
#ifdef __cplusplus
extern "C" {
#endif
// Declarations of the nnopbase entry points this file may call; none of them is
// defined here. Functions tagged __attribute__((weak)) may be left unresolved at
// link time, which is why the entry points in this file compare such symbols
// against NULL before calling them.

// Executor space / executor acquisition.
extern aclnnStatus NnopbaseCreateExecutorSpace(void **space);
extern void *NnopbaseGetExecutor(void *space, const char *opType, char *inputsDesc, uint32_t inputNum,
char *outputsDesc, uint32_t outputNum, char *attrsDesc, uint32_t attrsNum);
// Registration of inputs and outputs by index.
extern aclnnStatus NnopbaseAddInput(void *executor, const aclTensor *tensor, const uint32_t index);
extern aclnnStatus NnopbaseAddIgnoreContinuesInput(void *executor,
const aclTensor *tensor, const uint32_t index);
extern aclnnStatus NnopbaseAddIntArrayInput(void *executor, const aclIntArray *array, const uint32_t index);
extern aclnnStatus NnopbaseAddBoolArrayInput(void *executor, const aclBoolArray *array, const uint32_t index);
extern aclnnStatus NnopbaseAddFloatArrayInput(void *executor, const aclFloatArray *array, const uint32_t index);
extern aclnnStatus NnopbaseAddOutput(void *executor, const aclTensor *tensor, const uint32_t index);
extern aclnnStatus NnopbaseAddDynamicInput(void *executor, const aclTensorList *tensor_list, const uint32_t index);
extern aclnnStatus __attribute__((weak)) NnopbaseAddIgnoreContiguousDynamicInput(void *executor, const aclTensorList *tensor_list, const uint32_t index);
extern aclnnStatus NnopbaseAddDynamicOutput(void *executor, const aclTensorList *tensor_list, const uint32_t index);
// Registration of attributes by index (not used by the entry points in this file).
extern aclnnStatus NnopbaseAddAttrWithDtype(void *executor, void *attrAddr, size_t attrLen, const size_t index, const NnopbaseAttrDtype dtype);
extern aclnnStatus NnopbaseAddIntArrayAttr(void *executor, const aclIntArray* array, const size_t index);
extern aclnnStatus NnopbaseAddFloatArrayAttr(void *executor, const aclFloatArray* array, const size_t index);
extern aclnnStatus NnopbaseAddBoolArrayAttr(void *executor, const aclBoolArray* array, const size_t index);
extern aclnnStatus NnopbaseAddArrayAttrWithDtype(void *executor, void *array, const size_t len, const size_t elementSize, const size_t index, const NnopbaseAttrDtype dtype);
// Timestamping, tiling id and API reporting; the timestamp returned by
// NnopbaseMsprofSysTime is later handed to NnopbaseReportApiInfo as beginTime.
extern uint64_t NnopbaseMsprofSysTime();
extern const char* __attribute__((weak)) NnopbaseGetSocName();
extern aclnnStatus NnopbaseAddTilingId(void *executor, NnopbaseDfxId *tilingId);
extern void NnopbaseReportApiInfo(const uint64_t beginTime, NnopbaseDfxId &dfxId);
// Two-step execution: workspace query, then launch on a stream.
extern aclnnStatus NnopbaseRunForWorkspace(void *executor, uint64_t *workspaceLen);
extern aclnnStatus NnopbaseRunWithWorkspace(void *executor, aclrtStream stream, void *workspace, uint64_t workspaceSize);
// Support-table registration: by numeric SoC id, or (weak) by SoC name.
extern aclnnStatus NnopbaseAddSupportList(void *executor, OpSupportList *list, uint32_t *socSupportList, size_t socSupportListLen);
extern aclnnStatus __attribute__((weak)) NnopbaseAddSocNameList(void *executor, OpSupportList *list, const char * const *socNameList, size_t socNameListLen);
extern aclnnStatus NnopbaseAddScalarInput(void *executor, const aclScalar *scalar, const uint32_t index, const int32_t srcIndex, const ge::DataType dtype);
extern aclnnStatus NnopbaseAddScalarListInput(void *executor, const aclScalarList *scalarList, const uint32_t index, const int32_t srcIndex, const ge::DataType dtype);
extern void NnopbaseAddOpTypeId(void *executor, const uint32_t opTypeId);
// Optional (weak) extras: parameter naming, match modes and the argument-match path.
extern aclnnStatus __attribute__((weak)) NnopbaseAddParamName(void *executor, const uint32_t index, const char *name, const bool isInput);
extern aclnnStatus __attribute__((weak)) NnopbaseSetFormatMatchMode(void *executor, const uint32_t mode);
extern aclnnStatus NnopbaseSetRef(void *executor, const size_t inputIrIdx, const size_t outputIrIdx);
extern void __attribute__((weak)) NnopbaseSetMatchArgsFlag(void *executor);
extern bool __attribute__((weak)) NnopbaseMatchArgs(void *executor, uint64_t *workspaceLen);
extern void __attribute__((weak)) NnopbaseSetParamCheckMode(void *executor, const uint32_t mode);
// Status codes returned by the entry points in this file.
#define ACLNN_SUCCESS 0
#define ACLNN_ERR_PARAM_NULLPTR 161001
#define ACLNN_ERR_PARAM_INVALID 161002

// Evaluates `v` exactly once. If the result is not ACLNN_SUCCESS, logs the
// status together with the source text of `v` through NnopbaseOpLogE and
// returns that status from the enclosing function.
#define NNOPBASE_ASSERT_OK_RETVAL(v)                      \
    do {                                                  \
        const aclnnStatus nnopbaseStatus_ = (v);          \
        if (nnopbaseStatus_ == ACLNN_SUCCESS) {           \
            break;                                        \
        }                                                 \
        NnopbaseOpLogE(nnopbaseStatus_, #v);              \
        return nnopbaseStatus_;                           \
    } while (false)

// If `v` compares equal to nullptr, logs "<v> != nullptr" through
// NnopbaseOpLogE and returns ACLNN_ERR_PARAM_NULLPTR from the enclosing function.
#define NNOPBASE_ASSERT_NOTNULL_RETVAL(v)                               \
    do {                                                                \
        if ((v) != nullptr) {                                           \
            break;                                                      \
        }                                                               \
        NnopbaseOpLogE(ACLNN_ERR_PARAM_NULLPTR, #v " != nullptr");      \
        return ACLNN_ERR_PARAM_NULLPTR;                                 \
    } while (false)
/**
 * Workspace-query entry point for the "ValueDependTest" op, with x1/x2/x3
 * supplied as float / bool / int arrays.
 *
 * Obtains an executor from a function-static executor space (created on the
 * first call), stores it in *executor, registers the tiling id, the four
 * inputs and the output, and then either takes the NnopbaseMatchArgs path
 * (when that weak symbol is present and returns true) or registers parameter
 * names and the SoC support table and calls NnopbaseRunForWorkspace.
 *
 * @param x              input 0 (tensor)
 * @param x1             input 1 (float array)
 * @param x2             input 2 (bool array)
 * @param x3             input 3 (int array)
 * @param out            output 0 (tensor)
 * @param workspaceSize  forwarded to NnopbaseMatchArgs / NnopbaseRunForWorkspace;
 *                       presumably receives the workspace length - not null-checked here
 * @param executor       receives the executor handle
 * @return ACLNN_ERR_PARAM_NULLPTR if x, x1, x2, x3, out, the obtained executor
 *         or `executor` is null; the first non-success status of any nnopbase
 *         call; ACLNN_SUCCESS otherwise.
 */
aclnnStatus aclnnValueDependTestGetWorkspaceSize(
    const aclTensor *x,
    const aclFloatArray *x1,
    const aclBoolArray *x2,
    const aclIntArray *x3,
    const aclTensor *out,
    uint64_t *workspaceSize,
    aclOpExecutor **executor)
{
    // Taken before any other work; handed to NnopbaseReportApiInfo as beginTime.
    const uint64_t beginTime = NnopbaseMsprofSysTime();

    static NnopbaseDfxId apiDfxId = {0x60000, __func__, false};
    static NnopbaseDfxId tilingId = {0x60000, "aclnnValueDependTestTiling", false};
    static void *executorSpace = nullptr;

    // One descriptor byte per input / output and none for attributes. The
    // meaning of the value 1 is defined by NnopbaseGetExecutor (not visible here).
    char inputKinds[] = {1, 1, 1, 1};
    char outputKinds[] = {1};
    char attrKinds[] = {};

    NNOPBASE_ASSERT_NOTNULL_RETVAL(x);
    NNOPBASE_ASSERT_NOTNULL_RETVAL(x1);
    NNOPBASE_ASSERT_NOTNULL_RETVAL(x2);
    NNOPBASE_ASSERT_NOTNULL_RETVAL(x3);
    NNOPBASE_ASSERT_NOTNULL_RETVAL(out);

    if (executorSpace == nullptr) {
        NNOPBASE_ASSERT_OK_RETVAL(NnopbaseCreateExecutorSpace(&executorSpace));
    }

    void *nnopExecutor = NnopbaseGetExecutor(
        executorSpace, "ValueDependTest",
        inputKinds, static_cast<uint32_t>(sizeof(inputKinds)),
        outputKinds, static_cast<uint32_t>(sizeof(outputKinds)),
        attrKinds, static_cast<uint32_t>(sizeof(attrKinds)));
    NNOPBASE_ASSERT_NOTNULL_RETVAL(nnopExecutor);
    NNOPBASE_ASSERT_NOTNULL_RETVAL(executor);
    *executor = reinterpret_cast<aclOpExecutor *>(nnopExecutor);

    NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddTilingId(*executor, &tilingId));
    if (NnopbaseSetMatchArgsFlag != nullptr) {
        NnopbaseSetMatchArgsFlag(*executor);
    }

    NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddInput(*executor, x, 0));
    NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddFloatArrayInput(*executor, x1, 1));
    NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddBoolArrayInput(*executor, x2, 2));
    NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddIntArrayInput(*executor, x3, 3));
    NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddOutput(*executor, out, 0));

    // Early exit: when the optional matcher is present and reports a match, the
    // remaining registration and the workspace run are skipped.
    if (NnopbaseMatchArgs != nullptr && NnopbaseMatchArgs(*executor, workspaceSize)) {
        NnopbaseReportApiInfo(beginTime, apiDfxId);
        return ACLNN_SUCCESS;
    }

    if (NnopbaseAddParamName != nullptr) {
        NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddParamName(*executor, 0, "x", true));
        NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddParamName(*executor, 1, "x1", true));
        NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddParamName(*executor, 2, "x2", true));
        NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddParamName(*executor, 3, "x3", true));
        NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddParamName(*executor, 0, "out", false));
    }

    // Fall back to the numeric SoC id list only when the name-based API is absent.
    if (NnopbaseAddSocNameList == nullptr) {
        NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddSupportList(*executor, &supportList, socSupportList, socSupportListLen));
    } else {
        NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddSocNameList(*executor, &supportList, socNameList, socNameListLen));
    }

    NNOPBASE_ASSERT_OK_RETVAL(NnopbaseRunForWorkspace(*executor, workspaceSize));
    NnopbaseReportApiInfo(beginTime, apiDfxId);
    return ACLNN_SUCCESS;
}
/**
 * Workspace-query entry point for the "ValueDependTest" op, with all four
 * inputs supplied as tensors.
 *
 * Follows the same sequence as the array-based entry point, except that x1, x2
 * and x3 are registered through NnopbaseAddInput. It keeps its own
 * function-static executor space.
 *
 * @param x              input 0
 * @param x1             input 1
 * @param x2             input 2
 * @param x3             input 3
 * @param out            output 0
 * @param workspaceSize  forwarded to NnopbaseMatchArgs / NnopbaseRunForWorkspace;
 *                       presumably receives the workspace length - not null-checked here
 * @param executor       receives the executor handle
 * @return ACLNN_ERR_PARAM_NULLPTR if x, x1, x2, x3, out, the obtained executor
 *         or `executor` is null; the first non-success status of any nnopbase
 *         call; ACLNN_SUCCESS otherwise.
 */
aclnnStatus aclnnValueDependTestTensorGetWorkspaceSize(
    const aclTensor *x,
    const aclTensor *x1,
    const aclTensor *x2,
    const aclTensor *x3,
    const aclTensor *out,
    uint64_t *workspaceSize,
    aclOpExecutor **executor)
{
    // Taken before any other work; handed to NnopbaseReportApiInfo as beginTime.
    const uint64_t beginTime = NnopbaseMsprofSysTime();

    static NnopbaseDfxId apiDfxId = {0x60000, __func__, false};
    static NnopbaseDfxId tilingId = {0x60000, "aclnnValueDependTestTiling", false};
    static void *executorSpace = nullptr;

    // One descriptor byte per input / output and none for attributes. The
    // meaning of the value 1 is defined by NnopbaseGetExecutor (not visible here).
    char inputKinds[] = {1, 1, 1, 1};
    char outputKinds[] = {1};
    char attrKinds[] = {};

    NNOPBASE_ASSERT_NOTNULL_RETVAL(x);
    NNOPBASE_ASSERT_NOTNULL_RETVAL(x1);
    NNOPBASE_ASSERT_NOTNULL_RETVAL(x2);
    NNOPBASE_ASSERT_NOTNULL_RETVAL(x3);
    NNOPBASE_ASSERT_NOTNULL_RETVAL(out);

    if (executorSpace == nullptr) {
        NNOPBASE_ASSERT_OK_RETVAL(NnopbaseCreateExecutorSpace(&executorSpace));
    }

    void *nnopExecutor = NnopbaseGetExecutor(
        executorSpace, "ValueDependTest",
        inputKinds, static_cast<uint32_t>(sizeof(inputKinds)),
        outputKinds, static_cast<uint32_t>(sizeof(outputKinds)),
        attrKinds, static_cast<uint32_t>(sizeof(attrKinds)));
    NNOPBASE_ASSERT_NOTNULL_RETVAL(nnopExecutor);
    NNOPBASE_ASSERT_NOTNULL_RETVAL(executor);
    *executor = reinterpret_cast<aclOpExecutor *>(nnopExecutor);

    NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddTilingId(*executor, &tilingId));
    if (NnopbaseSetMatchArgsFlag != nullptr) {
        NnopbaseSetMatchArgsFlag(*executor);
    }

    NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddInput(*executor, x, 0));
    NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddInput(*executor, x1, 1));
    NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddInput(*executor, x2, 2));
    NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddInput(*executor, x3, 3));
    NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddOutput(*executor, out, 0));

    // Early exit: when the optional matcher is present and reports a match, the
    // remaining registration and the workspace run are skipped.
    if (NnopbaseMatchArgs != nullptr && NnopbaseMatchArgs(*executor, workspaceSize)) {
        NnopbaseReportApiInfo(beginTime, apiDfxId);
        return ACLNN_SUCCESS;
    }

    if (NnopbaseAddParamName != nullptr) {
        NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddParamName(*executor, 0, "x", true));
        NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddParamName(*executor, 1, "x1", true));
        NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddParamName(*executor, 2, "x2", true));
        NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddParamName(*executor, 3, "x3", true));
        NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddParamName(*executor, 0, "out", false));
    }

    // Fall back to the numeric SoC id list only when the name-based API is absent.
    if (NnopbaseAddSocNameList == nullptr) {
        NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddSupportList(*executor, &supportList, socSupportList, socSupportListLen));
    } else {
        NNOPBASE_ASSERT_OK_RETVAL(NnopbaseAddSocNameList(*executor, &supportList, socNameList, socNameListLen));
    }

    NNOPBASE_ASSERT_OK_RETVAL(NnopbaseRunForWorkspace(*executor, workspaceSize));
    NnopbaseReportApiInfo(beginTime, apiDfxId);
    return ACLNN_SUCCESS;
}
/**
 * Execution entry point: forwards the executor, stream and workspace to
 * NnopbaseRunWithWorkspace and reports the call through NnopbaseReportApiInfo.
 *
 * @param workspace      workspace buffer, forwarded unchanged
 * @param workspaceSize  size of `workspace`, forwarded unchanged
 * @param executor       executor handle, forwarded unchanged (not null-checked here)
 * @param stream         stream, forwarded unchanged
 * @return the status from NnopbaseRunWithWorkspace if it is not ACLNN_SUCCESS
 *         (no API report is made in that case); ACLNN_SUCCESS otherwise.
 */
aclnnStatus aclnnValueDependTest(
    void *workspace,
    uint64_t workspaceSize,
    aclOpExecutor *executor,
    aclrtStream stream)
{
    // Taken before the run; handed to NnopbaseReportApiInfo as beginTime.
    const uint64_t beginTime = NnopbaseMsprofSysTime();
    static NnopbaseDfxId apiDfxId = {0x60000, __func__, false};

    NNOPBASE_ASSERT_OK_RETVAL(NnopbaseRunWithWorkspace(executor, stream, workspace, workspaceSize));

    NnopbaseReportApiInfo(beginTime, apiDfxId);
    return ACLNN_SUCCESS;
}
#ifdef __cplusplus
}
#endif