* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file test_logical.cpp
* \brief Unit tests for the LogicalOr/LogicalAnd and LogicalOrScalarExtend/LogicalAndScalarExtend APIs.
*/
#include "tikicpulib.h"
#include "kernel_operator.h"
#include <kernel_tpipe.h>
#include <algorithm>
#include <cstdlib>
#include <vector>
#include "gtest/gtest.h"
using namespace AscendC;
#include "utils.h"
#include "logical.h"
/**
 * \brief Copies the first \p size elements of the host buffer \p gm into \p local,
 *        one element per SetValue call.
 */
template<class T>
void GmToUb(LocalTensor<T>& local, T* gm, int size) {
    int idx = 0;
    while (idx < size) {
        local.SetValue(idx, gm[idx]);
        ++idx;
    }
}
/**
 * \brief Copies the first \p size elements of \p local back into the host buffer \p gm,
 *        reading each element with GetValue.
 */
template<class T>
void UbToGm(T* gm, LocalTensor<T>& local, int size) {
    T *dst = gm;
    for (int idx = 0; idx < size; ++idx) {
        *dst++ = local.GetValue(idx);
    }
}
/**
 * \brief Reserves the VECCALC buffers used by the two-tensor LogicalOr/LogicalAnd test.
 *
 * Buffers are reserved in this order: x1 and x2 (sizeof(T) * size bytes each), y (size bytes),
 * then an 8 KiB scratch area. The resulting tensors are returned through the reference parameters.
 * NOTE(review): `pipe` is a local that is destroyed on return while the tensors still refer to
 * buffers it initialised; presumably harmless in the CPU simulator — confirm.
 */
template<typename T>
void InitParams(LocalTensor<T> &l_x1, LocalTensor<T> &l_x2, LocalTensor<uint8_t> &l_y,
                LocalTensor<uint8_t> &l_tmp, uint32_t size) {
    constexpr int tmp_size = 8 * 1024;
    const auto inputBytes = sizeof(T) * size;
    const auto outputBytes = sizeof(uint8_t) * size;

    TPipe pipe;
    TBuf<TPosition::VECCALC> bufIn1;
    TBuf<TPosition::VECCALC> bufIn2;
    TBuf<TPosition::VECCALC> bufOut;
    TBuf<TPosition::VECCALC> bufScratch;

    // Order of InitBuffer calls is kept as-is: it determines where each buffer lands.
    pipe.InitBuffer(bufIn1, inputBytes);
    pipe.InitBuffer(bufIn2, inputBytes);
    pipe.InitBuffer(bufOut, outputBytes);
    pipe.InitBuffer(bufScratch, tmp_size);

    l_x1 = bufIn1.Get<T>();
    l_x2 = bufIn2.Get<T>();
    l_y = bufOut.Get<uint8_t>();
    l_tmp = bufScratch.Get<uint8_t>();
}
/**
 * \brief End-to-end check of LogicalOr / LogicalAnd for element type T.
 *
 * Fills two host buffers with random 0/1 values, runs the kernel lambda below through
 * ICPU_RUN_KF, and counts output elements that differ from a host-computed reference
 * (max(a, b) for "or", a * b for "and").
 *
 * \param logical  Operation to test: "or" or "and". Any other value fails the test immediately.
 * \param size     Number of elements in x1, x2 and y.
 */
template<typename T>
void TestLogicalCommon (const std::string &logical, uint32_t size) {
    ASSERT_TRUE(logical == "or" || logical == "and");
    auto *x1 = static_cast<T*>(GmAlloc(sizeof(T) * size));
    auto *x2 = static_cast<T*>(GmAlloc(sizeof(T) * size));
    // y holds one uint8_t per element (sizing it by sizeof(T) over-allocated for wider T).
    auto *y = static_cast<uint8_t*>(GmAlloc(sizeof(uint8_t) * size));
    // Sentinel that is neither 0 nor 1: an element the kernel never writes shows up as a
    // mismatch instead of whatever happened to be in freshly allocated memory.
    std::fill(y, y + size, static_cast<uint8_t>(0xFF));
    // std::vector instead of `uint8_t expect[size]`, which is a VLA (non-standard C++).
    std::vector<uint8_t> expect(size);
    for (uint32_t i = 0; i < size; ++i) {
        T tmp1 = std::rand() % 2;
        x1[i] = tmp1;
        T tmp2 = std::rand() % 2;
        x2[i] = tmp2;
        if (logical == "or") {
            expect[i] = std::max(tmp1, tmp2);
        } else {
            expect[i] = tmp1 * tmp2;
        }
    }
    auto kernel = [](uint32_t size, T *x1, T *x2, uint8_t *y, const std::string &logical) {
        LocalTensor<T> l_x1;
        LocalTensor<T> l_x2;
        LocalTensor<uint8_t> l_y;
        LocalTensor<uint8_t> l_tmp;
        InitParams(l_x1, l_x2, l_y, l_tmp, size);
        GmToUb(l_x1, x1, size);
        GmToUb(l_x2, x2, size);
        // Seeds l_y with the 0xFF sentinel so unwritten outputs are detected after copy-back.
        GmToUb(l_y, y, size);
        if (logical == "or") {
            LogicalOr(l_y, l_x1, l_x2, l_tmp, size);
        } else {
            LogicalAnd(l_y, l_x1, l_x2, l_tmp, size);
        }
        UbToGm(y, l_y, size);
    };
    AscendC::SetKernelMode(KernelMode::AIV_MODE);
    ICPU_RUN_KF(kernel, 1, size, x1, x2, y, logical);
    int diff_count = 0;
    for (uint32_t i = 0; i < size; ++i) {
        if (y[i] != expect[i]) {
            diff_count++;
        }
    }
    EXPECT_EQ(diff_count, 0);
    GmFree(x1);
    GmFree(x2);
    GmFree(y);
}
/**
 * \brief Reserves the VECCALC buffers used by the tensor-with-scalar logical test.
 *
 * Buffers are reserved in this order: x1 (sizeof(T) * size bytes), y (size bytes), then an
 * 8 KiB scratch area. The resulting tensors are returned through the reference parameters.
 * NOTE(review): `pipe` is a local that is destroyed on return while the tensors still refer to
 * buffers it initialised; presumably harmless in the CPU simulator — confirm.
 */
template<typename T>
void InitParamsScalarExtend(LocalTensor<T> &l_x1, LocalTensor<uint8_t> &l_y,
                            LocalTensor<uint8_t> &l_tmp, uint32_t size) {
    constexpr int tmp_size = 8 * 1024;
    const auto inputBytes = sizeof(T) * size;
    const auto outputBytes = sizeof(uint8_t) * size;

    TPipe pipe;
    TBuf<TPosition::VECCALC> bufIn;
    TBuf<TPosition::VECCALC> bufOut;
    TBuf<TPosition::VECCALC> bufScratch;

    // Order of InitBuffer calls is kept as-is: it determines where each buffer lands.
    pipe.InitBuffer(bufIn, inputBytes);
    pipe.InitBuffer(bufOut, outputBytes);
    pipe.InitBuffer(bufScratch, tmp_size);

    l_x1 = bufIn.Get<T>();
    l_y = bufOut.Get<uint8_t>();
    l_tmp = bufScratch.Get<uint8_t>();
}
/**
 * \brief End-to-end check of LogicalOrScalarExtend / LogicalAndScalarExtend for element type T.
 *
 * Fills x1 with random 0/1 values, draws one random 0/1 scalar x2, runs the kernel lambda
 * below through ICPU_RUN_KF, and counts output elements that differ from a host-computed
 * reference (max(a, x2) for "orScalarExtend", a * x2 for "andScalarExtend").
 *
 * \param logical  "orScalarExtend" or "andScalarExtend". Any other value fails the test immediately.
 * \param size     Number of elements in x1 and y.
 */
template<typename T>
void TestLogicalCommonScalarExtend (const std::string &logical, uint32_t size) {
    ASSERT_TRUE(logical == "orScalarExtend" || logical == "andScalarExtend");
    auto *x1 = static_cast<T*>(GmAlloc(sizeof(T) * size));
    // y holds one uint8_t per element (sizing it by sizeof(T) over-allocated for wider T).
    auto *y = static_cast<uint8_t*>(GmAlloc(sizeof(uint8_t) * size));
    // Sentinel that is neither 0 nor 1: an element the kernel never writes shows up as a
    // mismatch instead of whatever happened to be in freshly allocated memory.
    std::fill(y, y + size, static_cast<uint8_t>(0xFF));
    // std::vector instead of `uint8_t expect[size]`, which is a VLA (non-standard C++).
    std::vector<uint8_t> expect(size);
    // One random 0/1 scalar shared by every element; drawn before x1 as in the original.
    T x2 = std::rand() % 2;
    for (uint32_t i = 0; i < size; ++i) {
        T tmp1 = std::rand() % 2;
        x1[i] = tmp1;
        if (logical == "orScalarExtend") {
            expect[i] = std::max(tmp1, x2);
        } else {
            expect[i] = tmp1 * x2;
        }
    }
    auto kernel = [](uint32_t size, T *x1, T x2, uint8_t *y, const std::string &logical) {
        LocalTensor<T> l_x1;
        LocalTensor<uint8_t> l_y;
        LocalTensor<uint8_t> l_tmp;
        InitParamsScalarExtend(l_x1, l_y, l_tmp, size);
        GmToUb(l_x1, x1, size);
        // Seeds l_y with the 0xFF sentinel so unwritten outputs are detected after copy-back.
        GmToUb(l_y, y, size);
        if (logical == "orScalarExtend") {
            LogicalOrScalarExtend(l_y, l_x1, x2, l_tmp, size);
        } else {
            LogicalAndScalarExtend(l_y, l_x1, x2, l_tmp, size);
        }
        UbToGm(y, l_y, size);
    };
    AscendC::SetKernelMode(KernelMode::AIV_MODE);
    ICPU_RUN_KF(kernel, 1, size, x1, x2, y, logical);
    int diff_count = 0;
    for (uint32_t i = 0; i < size; ++i) {
        if (y[i] != expect[i]) {
            diff_count++;
        }
    }
    EXPECT_EQ(diff_count, 0);
    GmFree(x1);
    GmFree(y);
}
/**
 * \brief Reserves the VECCALC buffers for the unaligned-size two-tensor logical test.
 *
 * Same layout as the aligned variant except the scratch area is 16928 bytes:
 * x1 and x2 (sizeof(T) * size bytes each), y (size bytes), then scratch.
 * NOTE(review): the origin of 16928 is not visible here — presumably the scratch requirement
 * for the sizes exercised by Test_1_blk_unalign; confirm.
 * NOTE(review): `pipe` is a local that is destroyed on return while the tensors still refer to
 * buffers it initialised; presumably harmless in the CPU simulator — confirm.
 */
template<typename T>
void InitParamsUnalign(LocalTensor<T> &l_x1, LocalTensor<T> &l_x2, LocalTensor<uint8_t> &l_y,
                       LocalTensor<uint8_t> &l_tmp, uint32_t size) {
    constexpr int tmp_size = 16928;
    const auto inputBytes = sizeof(T) * size;
    const auto outputBytes = sizeof(uint8_t) * size;

    TPipe pipe;
    TBuf<TPosition::VECCALC> bufIn1;
    TBuf<TPosition::VECCALC> bufIn2;
    TBuf<TPosition::VECCALC> bufOut;
    TBuf<TPosition::VECCALC> bufScratch;

    // Order of InitBuffer calls is kept as-is: it determines where each buffer lands.
    pipe.InitBuffer(bufIn1, inputBytes);
    pipe.InitBuffer(bufIn2, inputBytes);
    pipe.InitBuffer(bufOut, outputBytes);
    pipe.InitBuffer(bufScratch, tmp_size);

    l_x1 = bufIn1.Get<T>();
    l_x2 = bufIn2.Get<T>();
    l_y = bufOut.Get<uint8_t>();
    l_tmp = bufScratch.Get<uint8_t>();
}
/**
 * \brief End-to-end check of LogicalOr / LogicalAnd for element type T using the
 *        InitParamsUnalign buffer layout (larger scratch buffer).
 *
 * Fills two host buffers with random 0/1 values, runs the kernel lambda below through
 * ICPU_RUN_KF, and counts output elements that differ from a host-computed reference
 * (max(a, b) for "or", a * b for "and").
 *
 * \param logical  Operation to test: "or" or "and". Any other value fails the test immediately.
 * \param size     Number of elements in x1, x2 and y.
 */
template<typename T>
void TestLogicalCommonUnalign (const std::string &logical, uint32_t size) {
    ASSERT_TRUE(logical == "or" || logical == "and");
    auto *x1 = static_cast<T*>(GmAlloc(sizeof(T) * size));
    auto *x2 = static_cast<T*>(GmAlloc(sizeof(T) * size));
    // y holds one uint8_t per element (sizing it by sizeof(T) over-allocated for wider T).
    auto *y = static_cast<uint8_t*>(GmAlloc(sizeof(uint8_t) * size));
    // Sentinel that is neither 0 nor 1: an element the kernel never writes shows up as a
    // mismatch instead of whatever happened to be in freshly allocated memory.
    std::fill(y, y + size, static_cast<uint8_t>(0xFF));
    // std::vector instead of `uint8_t expect[size]`, which is a VLA (non-standard C++).
    std::vector<uint8_t> expect(size);
    for (uint32_t i = 0; i < size; ++i) {
        T tmp1 = std::rand() % 2;
        x1[i] = tmp1;
        T tmp2 = std::rand() % 2;
        x2[i] = tmp2;
        if (logical == "or") {
            expect[i] = std::max(tmp1, tmp2);
        } else {
            expect[i] = tmp1 * tmp2;
        }
    }
    auto kernel = [](uint32_t size, T *x1, T *x2, uint8_t *y, const std::string &logical) {
        LocalTensor<T> l_x1;
        LocalTensor<T> l_x2;
        LocalTensor<uint8_t> l_y;
        LocalTensor<uint8_t> l_tmp;
        InitParamsUnalign(l_x1, l_x2, l_y, l_tmp, size);
        GmToUb(l_x1, x1, size);
        GmToUb(l_x2, x2, size);
        // Seeds l_y with the 0xFF sentinel so unwritten outputs are detected after copy-back.
        GmToUb(l_y, y, size);
        if (logical == "or") {
            LogicalOr(l_y, l_x1, l_x2, l_tmp, size);
        } else {
            LogicalAnd(l_y, l_x1, l_x2, l_tmp, size);
        }
        UbToGm(y, l_y, size);
    };
    AscendC::SetKernelMode(KernelMode::AIV_MODE);
    ICPU_RUN_KF(kernel, 1, size, x1, x2, y, logical);
    int diff_count = 0;
    for (uint32_t i = 0; i < size; ++i) {
        if (y[i] != expect[i]) {
            diff_count++;
        }
    }
    EXPECT_EQ(diff_count, 0);
    GmFree(x1);
    GmFree(x2);
    GmFree(y);
}
TEST(TestApiLogical, Test_1_blk) {
    // size = ONE_BLK_HALF_NUM; all seven dtypes, first for "or", then for "and".
    const uint32_t size = ONE_BLK_HALF_NUM;
    for (const char *op : {"or", "and"}) {
        TestLogicalCommon<float>(op, size);
        TestLogicalCommon<half>(op, size);
        TestLogicalCommon<uint8_t>(op, size);
        TestLogicalCommon<int8_t>(op, size);
        TestLogicalCommon<int16_t>(op, size);
        TestLogicalCommon<int32_t>(op, size);
        TestLogicalCommon<int64_t>(op, size);
    }
}
TEST(TestApiLogical, Test_1_repeat) {
    // size = ONE_REPEAT_HALF_SIZE; all seven dtypes, first for "or", then for "and".
    const uint32_t size = ONE_REPEAT_HALF_SIZE;
    for (const char *op : {"or", "and"}) {
        TestLogicalCommon<float>(op, size);
        TestLogicalCommon<half>(op, size);
        TestLogicalCommon<uint8_t>(op, size);
        TestLogicalCommon<int8_t>(op, size);
        TestLogicalCommon<int16_t>(op, size);
        TestLogicalCommon<int32_t>(op, size);
        TestLogicalCommon<int64_t>(op, size);
    }
}
TEST(TestApiLogical, Test_max_repeat) {
    // Large size: only the 1- and 2-byte dtypes are exercised, "or" then "and".
    const uint32_t size = MAX_REPEAT_TIMES * ONE_REPEAT_HALF_SIZE;
    for (const char *op : {"or", "and"}) {
        TestLogicalCommon<half>(op, size);
        TestLogicalCommon<uint8_t>(op, size);
        TestLogicalCommon<int8_t>(op, size);
        TestLogicalCommon<int16_t>(op, size);
    }
}
TEST(TestApiLogical, Test_less_1_blk) {
    // One element short of ONE_BLK_HALF_NUM; all seven dtypes, "or" then "and".
    const uint32_t size = ONE_BLK_HALF_NUM - 1;
    for (const char *op : {"or", "and"}) {
        TestLogicalCommon<float>(op, size);
        TestLogicalCommon<half>(op, size);
        TestLogicalCommon<uint8_t>(op, size);
        TestLogicalCommon<int8_t>(op, size);
        TestLogicalCommon<int16_t>(op, size);
        TestLogicalCommon<int32_t>(op, size);
        TestLogicalCommon<int64_t>(op, size);
    }
}
TEST(TestApiLogical, Test_less_1_repeat) {
    // One element short of ONE_REPEAT_HALF_SIZE; all seven dtypes, "or" then "and".
    const uint32_t size = ONE_REPEAT_HALF_SIZE - 1;
    for (const char *op : {"or", "and"}) {
        TestLogicalCommon<float>(op, size);
        TestLogicalCommon<half>(op, size);
        TestLogicalCommon<uint8_t>(op, size);
        TestLogicalCommon<int8_t>(op, size);
        TestLogicalCommon<int16_t>(op, size);
        TestLogicalCommon<int32_t>(op, size);
        TestLogicalCommon<int64_t>(op, size);
    }
}
TEST(TestApiLogical, Test_less_max_repeat) {
    // (MAX_REPEAT_TIMES - 1) repeats; only the 1- and 2-byte dtypes, "or" then "and".
    const uint32_t size = (MAX_REPEAT_TIMES - 1) * ONE_REPEAT_HALF_SIZE;
    for (const char *op : {"or", "and"}) {
        TestLogicalCommon<half>(op, size);
        TestLogicalCommon<uint8_t>(op, size);
        TestLogicalCommon<int8_t>(op, size);
        TestLogicalCommon<int16_t>(op, size);
    }
}
TEST(TestApiLogical, Test_mix) {
    // Max repeats plus one extra repeat plus one extra block; 1- and 2-byte dtypes only.
    const uint32_t size = MAX_REPEAT_TIMES * ONE_REPEAT_HALF_SIZE + ONE_REPEAT_HALF_SIZE + ONE_BLK_HALF_NUM;
    for (const char *op : {"or", "and"}) {
        TestLogicalCommon<half>(op, size);
        TestLogicalCommon<uint8_t>(op, size);
        TestLogicalCommon<int8_t>(op, size);
        TestLogicalCommon<int16_t>(op, size);
    }
}
TEST(TestApiLogical, Test_1_blk_scalar_extend) {
    // Tensor-with-scalar variants at size = ONE_BLK_HALF_NUM; all seven dtypes.
    const uint32_t size = ONE_BLK_HALF_NUM;
    for (const char *op : {"orScalarExtend", "andScalarExtend"}) {
        TestLogicalCommonScalarExtend<float>(op, size);
        TestLogicalCommonScalarExtend<half>(op, size);
        TestLogicalCommonScalarExtend<uint8_t>(op, size);
        TestLogicalCommonScalarExtend<int8_t>(op, size);
        TestLogicalCommonScalarExtend<int16_t>(op, size);
        TestLogicalCommonScalarExtend<int32_t>(op, size);
        TestLogicalCommonScalarExtend<int64_t>(op, size);
    }
}
TEST(TestApiLogical, Test_1_repeat_scalar_extend) {
    // Tensor-with-scalar variants at size = ONE_REPEAT_HALF_SIZE; all seven dtypes.
    const uint32_t size = ONE_REPEAT_HALF_SIZE;
    for (const char *op : {"orScalarExtend", "andScalarExtend"}) {
        TestLogicalCommonScalarExtend<float>(op, size);
        TestLogicalCommonScalarExtend<half>(op, size);
        TestLogicalCommonScalarExtend<uint8_t>(op, size);
        TestLogicalCommonScalarExtend<int8_t>(op, size);
        TestLogicalCommonScalarExtend<int16_t>(op, size);
        TestLogicalCommonScalarExtend<int32_t>(op, size);
        TestLogicalCommonScalarExtend<int64_t>(op, size);
    }
}
TEST(TestApiLogical, Test_max_repeat_scalar_extend) {
    // Tensor-with-scalar variants at the max-repeat size; 1- and 2-byte dtypes only.
    const uint32_t size = MAX_REPEAT_TIMES * ONE_REPEAT_HALF_SIZE;
    for (const char *op : {"orScalarExtend", "andScalarExtend"}) {
        TestLogicalCommonScalarExtend<half>(op, size);
        TestLogicalCommonScalarExtend<uint8_t>(op, size);
        TestLogicalCommonScalarExtend<int8_t>(op, size);
        TestLogicalCommonScalarExtend<int16_t>(op, size);
    }
}
TEST(TestApiLogical, Test_less_1_blk_scalar_extend) {
    // Tensor-with-scalar variants, one element short of ONE_BLK_HALF_NUM; all seven dtypes.
    const uint32_t size = ONE_BLK_HALF_NUM - 1;
    for (const char *op : {"orScalarExtend", "andScalarExtend"}) {
        TestLogicalCommonScalarExtend<float>(op, size);
        TestLogicalCommonScalarExtend<half>(op, size);
        TestLogicalCommonScalarExtend<uint8_t>(op, size);
        TestLogicalCommonScalarExtend<int8_t>(op, size);
        TestLogicalCommonScalarExtend<int16_t>(op, size);
        TestLogicalCommonScalarExtend<int32_t>(op, size);
        TestLogicalCommonScalarExtend<int64_t>(op, size);
    }
}
TEST(TestApiLogical, Test_less_1_repeat_scalar_extend) {
    // Tensor-with-scalar variants, one element short of ONE_REPEAT_HALF_SIZE; all seven dtypes.
    const uint32_t size = ONE_REPEAT_HALF_SIZE - 1;
    for (const char *op : {"orScalarExtend", "andScalarExtend"}) {
        TestLogicalCommonScalarExtend<float>(op, size);
        TestLogicalCommonScalarExtend<half>(op, size);
        TestLogicalCommonScalarExtend<uint8_t>(op, size);
        TestLogicalCommonScalarExtend<int8_t>(op, size);
        TestLogicalCommonScalarExtend<int16_t>(op, size);
        TestLogicalCommonScalarExtend<int32_t>(op, size);
        TestLogicalCommonScalarExtend<int64_t>(op, size);
    }
}
TEST(TestApiLogical, Test_less_max_repeat_scalar_extend) {
    // Tensor-with-scalar variants at (MAX_REPEAT_TIMES - 1) repeats; 1- and 2-byte dtypes only.
    const uint32_t size = (MAX_REPEAT_TIMES - 1) * ONE_REPEAT_HALF_SIZE;
    for (const char *op : {"orScalarExtend", "andScalarExtend"}) {
        TestLogicalCommonScalarExtend<half>(op, size);
        TestLogicalCommonScalarExtend<uint8_t>(op, size);
        TestLogicalCommonScalarExtend<int8_t>(op, size);
        TestLogicalCommonScalarExtend<int16_t>(op, size);
    }
}
TEST(TestApiLogical, Test_mix_scalar_extend) {
    // Tensor-with-scalar variants at max repeats + one repeat + one block; 1- and 2-byte dtypes.
    const uint32_t size = MAX_REPEAT_TIMES * ONE_REPEAT_HALF_SIZE + ONE_REPEAT_HALF_SIZE + ONE_BLK_HALF_NUM;
    for (const char *op : {"orScalarExtend", "andScalarExtend"}) {
        TestLogicalCommonScalarExtend<half>(op, size);
        TestLogicalCommonScalarExtend<uint8_t>(op, size);
        TestLogicalCommonScalarExtend<int8_t>(op, size);
        TestLogicalCommonScalarExtend<int16_t>(op, size);
    }
}
TEST(TestApiLogical, Test_1_blk_unalign) {
    // Uses the InitParamsUnalign layout. NOTE(review): despite the name, size is 6160 elements.
    // int64_t is not exercised here; "or" then "and" for each remaining dtype.
    const uint32_t size = 6160;
    for (const char *op : {"or", "and"}) {
        TestLogicalCommonUnalign<float>(op, size);
        TestLogicalCommonUnalign<half>(op, size);
        TestLogicalCommonUnalign<uint8_t>(op, size);
        TestLogicalCommonUnalign<int8_t>(op, size);
        TestLogicalCommonUnalign<int16_t>(op, size);
        TestLogicalCommonUnalign<int32_t>(op, size);
    }
}