Copyright (c) 2026 Huawei Technologies Co., Ltd.
This program is free software, you can redistribute it and/or modify it under the terms and conditions of
CANN Open Software License Agreement Version 2.0 (the "License").
Please refer to the License for details. You may not use this file except in compliance with the License.
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
See LICENSE in the root of the software repository for the full text of the License.
*/
#include "test_common.h"
#include "acl/acl.h"
#include <gtest/gtest.h>
using namespace std;
using namespace PtoTestCommon;
// Test fixture shared by every MGATHER case. All per-case work (ACL runtime setup, device buffers,
// golden-file I/O) is done inside run_mgather_test, so both fixture hooks are intentionally empty.
class MGATHERTest : public testing::Test {
protected:
    void SetUp() override {}
    void TearDown() override {}
};
static std::string GetGoldenDir()
{
const testing::TestInfo *testInfo = testing::UnitTest::GetInstance()->current_test_info();
return std::string("../") + testInfo->test_suite_name() + "." + testInfo->name();
}
template <typename T, typename TIdx, typename Launcher>
void run_mgather_test(size_t tableCount, size_t idxCount, size_t outCount, Launcher launcher)
{
size_t tableByteSize = tableCount * sizeof(T);
size_t idxByteSize = idxCount * sizeof(TIdx);
size_t outByteSize = outCount * sizeof(T);
aclInit(nullptr);
aclrtSetDevice(0);
aclrtStream stream;
aclrtCreateStream(&stream);
T *tableHost, *outHost;
TIdx *idxHost;
T *tableDevice, *outDevice;
TIdx *idxDevice;
aclrtMallocHost((void **)(&tableHost), tableByteSize);
aclrtMallocHost((void **)(&idxHost), idxByteSize);
aclrtMallocHost((void **)(&outHost), outByteSize);
aclrtMalloc((void **)&tableDevice, tableByteSize, ACL_MEM_MALLOC_HUGE_FIRST);
aclrtMalloc((void **)&idxDevice, idxByteSize, ACL_MEM_MALLOC_HUGE_FIRST);
aclrtMalloc((void **)&outDevice, outByteSize, ACL_MEM_MALLOC_HUGE_FIRST);
ReadFile(GetGoldenDir() + "/table.bin", tableByteSize, tableHost, tableByteSize);
ReadFile(GetGoldenDir() + "/indices.bin", idxByteSize, idxHost, idxByteSize);
aclrtMemcpy(tableDevice, tableByteSize, tableHost, tableByteSize, ACL_MEMCPY_HOST_TO_DEVICE);
aclrtMemcpy(idxDevice, idxByteSize, idxHost, idxByteSize, ACL_MEMCPY_HOST_TO_DEVICE);
aclrtMemset(outDevice, outByteSize, 0, outByteSize);
launcher(outDevice, tableDevice, idxDevice, stream);
aclrtSynchronizeStream(stream);
aclrtMemcpy(outHost, outByteSize, outDevice, outByteSize, ACL_MEMCPY_DEVICE_TO_HOST);
WriteFile(GetGoldenDir() + "/output.bin", outHost, outByteSize);
aclrtFree(tableDevice);
aclrtFree(idxDevice);
aclrtFree(outDevice);
aclrtFreeHost(tableHost);
aclrtFreeHost(idxHost);
aclrtFreeHost(outHost);
aclrtDestroyStream(stream);
aclrtResetDevice(0);
aclFinalize();
std::vector<T> golden(outCount);
std::vector<T> devFinal(outCount);
ReadFile(GetGoldenDir() + "/golden.bin", outByteSize, golden.data(), outByteSize);
ReadFile(GetGoldenDir() + "/output.bin", outByteSize, devFinal.data(), outByteSize);
bool ret = ResultCmp<T>(golden, devFinal, 0.0f);
EXPECT_TRUE(ret);
}
// Forward-declares the host-side launcher of one test case:
//   void Launch_<CASE>(ELEM_T *out, ELEM_T *table, INDEX_T *indices, void *stream);
// The launchers are not defined in this file (presumably they live with the kernel sources).
// bfloat16 cases use uint16_t as their host element type.
#define DECLARE_LAUNCH(CASE, ELEM_T, INDEX_T) \
    void Launch_##CASE(ELEM_T *out, ELEM_T *table, INDEX_T *indices, void *stream);

// Row-gather launchers.
DECLARE_LAUNCH(row_float_8x32_64rows, float, int32_t)
DECLARE_LAUNCH(row_half_16x64_64rows, aclFloat16, int32_t)
DECLARE_LAUNCH(row_bfloat16_16x16_64rows, uint16_t, int32_t)
DECLARE_LAUNCH(row_int32_8x16_32rows, int32_t, int32_t)
DECLARE_LAUNCH(row_uint32_8x16_32rows, uint32_t, int32_t)
DECLARE_LAUNCH(row_int16_8x16_32rows, int16_t, int32_t)
DECLARE_LAUNCH(row_uint16_8x16_32rows, uint16_t, int32_t)
DECLARE_LAUNCH(row_int8_8x32_32rows, int8_t, int32_t)
DECLARE_LAUNCH(row_uint8_8x32_32rows, uint8_t, int32_t)
DECLARE_LAUNCH(row_float_clamp_8x32_8rows, float, int32_t)
DECLARE_LAUNCH(row_int32_wrap_8x16_8rows, int32_t, int32_t)
DECLARE_LAUNCH(row_half_zero_8x32_8rows, aclFloat16, int32_t)
DECLARE_LAUNCH(row_int32_unaligned_3x8_8rows, int32_t, int32_t)
DECLARE_LAUNCH(row_float_partial_4x16_in_8x16, float, int32_t)
DECLARE_LAUNCH(row_half_partial_5x32_in_8x32, aclFloat16, int32_t)
DECLARE_LAUNCH(row_uint8_unaligned_3x32_32rows, uint8_t, int32_t)
DECLARE_LAUNCH(row_int16_partial_3x16_in_4x16, int16_t, int32_t)

// 1-D elementwise-gather launchers.
DECLARE_LAUNCH(elem_float_64_128size, float, int32_t)
DECLARE_LAUNCH(elem_half_64_128size, aclFloat16, int32_t)
DECLARE_LAUNCH(elem_bfloat16_64_128size, uint16_t, int32_t)
DECLARE_LAUNCH(elem_int32_32_64size, int32_t, int32_t)
DECLARE_LAUNCH(elem_uint32_32_64size, uint32_t, int32_t)
DECLARE_LAUNCH(elem_int16_32_64size, int16_t, int32_t)
DECLARE_LAUNCH(elem_uint16_32_64size, uint16_t, int32_t)
DECLARE_LAUNCH(elem_int8_64_128size, int8_t, int32_t)
DECLARE_LAUNCH(elem_uint8_64_128size, uint8_t, int32_t)
DECLARE_LAUNCH(elem_float_clamp_32_16size, float, int32_t)
DECLARE_LAUNCH(elem_int32_wrap_32_16size, int32_t, int32_t)
DECLARE_LAUNCH(elem_half_zero_32_16size, aclFloat16, int32_t)

// 2-D elementwise-gather launchers.
DECLARE_LAUNCH(elem2d_float_8x32_256size, float, int32_t)
DECLARE_LAUNCH(elem2d_int32_8x16_256size, int32_t, int32_t)
DECLARE_LAUNCH(elem2d_half_4x32_256size, aclFloat16, int32_t)
DECLARE_LAUNCH(elem2d_bfloat16_4x32_256size, uint16_t, int32_t)
DECLARE_LAUNCH(elem2d_uint8_4x64_256size, uint8_t, int32_t)
DECLARE_LAUNCH(elem2d_int8_4x64_256size, int8_t, int32_t)
DECLARE_LAUNCH(elem2d_int16_4x32_256size, int16_t, int32_t)
DECLARE_LAUNCH(elem2d_uint16_4x32_256size, uint16_t, int32_t)
DECLARE_LAUNCH(elem2d_uint32_8x16_256size, uint32_t, int32_t)
DECLARE_LAUNCH(elem2d_float_wrap_4x16_64size, float, int32_t)
DECLARE_LAUNCH(elem2d_int32_clamp_4x8_32size, int32_t, int32_t)
DECLARE_LAUNCH(elem2d_half_zero_4x32_64size, aclFloat16, int32_t)
DECLARE_LAUNCH(elem2d_int32_unaligned_3x3_in_3x8_64size, int32_t, int32_t)
DECLARE_LAUNCH(elem2d_float_unaligned_5x5_in_5x8_64size, float, int32_t)
DECLARE_LAUNCH(elem2d_half_unaligned_3x9_in_3x16_64size, aclFloat16, int32_t)
DECLARE_LAUNCH(elem2d_int8_unaligned_3x17_in_3x32_64size, int8_t, int32_t)

// Single-element (scalar) launchers.
DECLARE_LAUNCH(elem_scalar_float_1x1_in_1x8_8size, float, int32_t)
DECLARE_LAUNCH(elem_scalar_int32_1x1_in_1x8_8size, int32_t, int32_t)
DECLARE_LAUNCH(elem_scalar_half_1x1_in_1x16_16size, aclFloat16, int32_t)

// Dynamic-shape launchers.
DECLARE_LAUNCH(elem2d_dyn_float_4x8_64size, float, int32_t)
DECLARE_LAUNCH(elem2d_dyn_int32_3x3_in_3x8_64size, int32_t, int32_t)
DECLARE_LAUNCH(row_dyn_int32_3x16_8rows, int32_t, int32_t)
DECLARE_LAUNCH(row_dyn_half_4x32_16rows, aclFloat16, int32_t)

// NZ-layout launchers.
DECLARE_LAUNCH(row_nz_float_16x16_2blk, float, int32_t)
DECLARE_LAUNCH(row_nz_half_32x16_2blk, aclFloat16, int32_t)
DECLARE_LAUNCH(row_nz_int32_16x16_2blk, int32_t, int32_t)
DECLARE_LAUNCH(row_nz_int16_32x16_1blk, int16_t, int32_t)
DECLARE_LAUNCH(row_nz_int8_16x32_1blk, int8_t, int32_t)
DECLARE_LAUNCH(row_nz_float_clamp_16x8_1blk, float, int32_t)
DECLARE_LAUNCH(row_nz_half_zero_16x16_2blk, aclFloat16, int32_t)
DECLARE_LAUNCH(elem2d_nz_float_16x16_2blk, float, int32_t)
DECLARE_LAUNCH(elem2d_nz_half_16x16_1blk, aclFloat16, int32_t)
DECLARE_LAUNCH(elem2d_nz_int32_16x8_1blk, int32_t, int32_t)
DECLARE_LAUNCH(elem2d_nz_half_zero_16x16_1blk, aclFloat16, int32_t)
// Common expansion of the test-case macros below. It registers MGATHERTest.case_<NAME>, which runs
// Launch_<NAME> through run_mgather_test with the given table / index / output element counts.
#define MGATHER_CASE(NAME, THOST, TIDX, TABLE_COUNT, IDX_COUNT, OUT_COUNT) \
    TEST_F(MGATHERTest, case_##NAME) \
    { \
        run_mgather_test<THOST, TIDX>((TABLE_COUNT), (IDX_COUNT), (OUT_COUNT), Launch_##NAME); \
    }

// Row gather: a TR x C table, R row indices, an R x C output.
#define ROW_TEST(NAME, THOST, TIDX, R, C, TR) \
    MGATHER_CASE(NAME, THOST, TIDX, (size_t)TR * C, (size_t)R, (size_t)R * C)

// 1-D elementwise gather: a TS-element table, N indices, N outputs.
#define ELEM_TEST(NAME, THOST, TIDX, N, TS) \
    MGATHER_CASE(NAME, THOST, TIDX, (size_t)TS, (size_t)N, (size_t)N)

// 2-D elementwise gather: a TS-element table, R x C indices, R x C outputs.
#define ELEM2D_TEST(NAME, THOST, TIDX, R, C, TS) \
    MGATHER_CASE(NAME, THOST, TIDX, (size_t)TS, (size_t)R * C, (size_t)R * C)

// Scalar gather: a TS-element table, one index, one output.
#define SCALAR_TEST(NAME, THOST, TIDX, TS) \
    MGATHER_CASE(NAME, THOST, TIDX, (size_t)TS, (size_t)1, (size_t)1)
// Row-gather cases: ROW_TEST(name, host type, index type, R, C, TR) allocates a TR x C table, R row
// indices and an R x C output. By convention the trailing "_<N>rows" in a case name is the table row
// count TR.
ROW_TEST(row_float_8x32_64rows, float, int32_t, 8, 32, 64)
ROW_TEST(row_half_16x64_64rows, aclFloat16, int32_t, 16, 64, 64)
ROW_TEST(row_bfloat16_16x16_64rows, uint16_t, int32_t, 16, 16, 64)
ROW_TEST(row_int32_8x16_32rows, int32_t, int32_t, 8, 16, 32)
ROW_TEST(row_uint32_8x16_32rows, uint32_t, int32_t, 8, 16, 32)
ROW_TEST(row_int16_8x16_32rows, int16_t, int32_t, 8, 16, 32)
ROW_TEST(row_uint16_8x16_32rows, uint16_t, int32_t, 8, 16, 32)
ROW_TEST(row_int8_8x32_32rows, int8_t, int32_t, 8, 32, 32)
ROW_TEST(row_uint8_8x32_32rows, uint8_t, int32_t, 8, 32, 32)
ROW_TEST(row_float_clamp_8x32_8rows, float, int32_t, 8, 32, 8)
ROW_TEST(row_int32_wrap_8x16_8rows, int32_t, int32_t, 8, 16, 8)
ROW_TEST(row_half_zero_8x32_8rows, aclFloat16, int32_t, 8, 32, 8)
ROW_TEST(row_int32_unaligned_3x8_8rows, int32_t, int32_t, 3, 8, 8)
ROW_TEST(row_float_partial_4x16_in_8x16, float, int32_t, 4, 16, 8)
ROW_TEST(row_half_partial_5x32_in_8x32, aclFloat16, int32_t, 5, 32, 8)
// TR was 8, although the case name says the table has 32 rows. A 32-row table.bin would not fit an
// 8-row host buffer, and the kernel could read past the 8-row device allocation. If the table really has
// fewer rows, a 32-row buffer is merely over-allocated.
// NOTE(review): confirm the table height against the definition of Launch_row_uint8_unaligned_3x32_32rows.
ROW_TEST(row_uint8_unaligned_3x32_32rows, uint8_t, int32_t, 3, 32, 32)
ROW_TEST(row_int16_partial_3x16_in_4x16, int16_t, int32_t, 3, 16, 8)
// 1-D elementwise cases: ELEM_TEST(name, host type, index type, N indices/outputs, table size).
ELEM_TEST(elem_float_64_128size,      float,      int32_t, 64, 128)
ELEM_TEST(elem_half_64_128size,       aclFloat16, int32_t, 64, 128)
ELEM_TEST(elem_bfloat16_64_128size,   uint16_t,   int32_t, 64, 128)
ELEM_TEST(elem_int32_32_64size,       int32_t,    int32_t, 32, 64)
ELEM_TEST(elem_uint32_32_64size,      uint32_t,   int32_t, 32, 64)
ELEM_TEST(elem_int16_32_64size,       int16_t,    int32_t, 32, 64)
ELEM_TEST(elem_uint16_32_64size,      uint16_t,   int32_t, 32, 64)
ELEM_TEST(elem_int8_64_128size,       int8_t,     int32_t, 64, 128)
ELEM_TEST(elem_uint8_64_128size,      uint8_t,    int32_t, 64, 128)
ELEM_TEST(elem_float_clamp_32_16size, float,      int32_t, 32, 16)
ELEM_TEST(elem_int32_wrap_32_16size,  int32_t,    int32_t, 32, 16)
ELEM_TEST(elem_half_zero_32_16size,   aclFloat16, int32_t, 32, 16)

// 2-D elementwise cases: ELEM2D_TEST(name, host type, index type, R, C, table size); R x C indices/outputs.
ELEM2D_TEST(elem2d_float_8x32_256size,                  float,      int32_t, 8, 32, 256)
ELEM2D_TEST(elem2d_int32_8x16_256size,                  int32_t,    int32_t, 8, 16, 256)
ELEM2D_TEST(elem2d_half_4x32_256size,                   aclFloat16, int32_t, 4, 32, 256)
ELEM2D_TEST(elem2d_bfloat16_4x32_256size,               uint16_t,   int32_t, 4, 32, 256)
ELEM2D_TEST(elem2d_uint8_4x64_256size,                  uint8_t,    int32_t, 4, 64, 256)
ELEM2D_TEST(elem2d_int8_4x64_256size,                   int8_t,     int32_t, 4, 64, 256)
ELEM2D_TEST(elem2d_int16_4x32_256size,                  int16_t,    int32_t, 4, 32, 256)
ELEM2D_TEST(elem2d_uint16_4x32_256size,                 uint16_t,   int32_t, 4, 32, 256)
ELEM2D_TEST(elem2d_uint32_8x16_256size,                 uint32_t,   int32_t, 8, 16, 256)
ELEM2D_TEST(elem2d_float_wrap_4x16_64size,              float,      int32_t, 4, 16, 64)
ELEM2D_TEST(elem2d_int32_clamp_4x8_32size,              int32_t,    int32_t, 4, 8,  32)
ELEM2D_TEST(elem2d_half_zero_4x32_64size,               aclFloat16, int32_t, 4, 32, 64)
ELEM2D_TEST(elem2d_int32_unaligned_3x3_in_3x8_64size,   int32_t,    int32_t, 3, 3,  64)
ELEM2D_TEST(elem2d_float_unaligned_5x5_in_5x8_64size,   float,      int32_t, 5, 5,  64)
ELEM2D_TEST(elem2d_half_unaligned_3x9_in_3x16_64size,   aclFloat16, int32_t, 3, 9,  64)
ELEM2D_TEST(elem2d_int8_unaligned_3x17_in_3x32_64size,  int8_t,     int32_t, 3, 17, 64)

// Scalar cases: SCALAR_TEST(name, host type, index type, table size); one index, one output.
SCALAR_TEST(elem_scalar_float_1x1_in_1x8_8size,   float,      int32_t, 8)
SCALAR_TEST(elem_scalar_int32_1x1_in_1x8_8size,   int32_t,    int32_t, 8)
SCALAR_TEST(elem_scalar_half_1x1_in_1x16_16size,  aclFloat16, int32_t, 16)
// Dynamic-shape cases. On the host side they are sized exactly like their static counterparts:
// RVR x RVC indices/outputs and an RTS-element table. The 2-D elementwise variant therefore forwards
// to ELEM2D_TEST; presumably only the kernel side differs in how the shape is supplied.
#define ELEM2D_DYN_TEST(NAME, THOST, TIDX, RVR, RVC, RTS) ELEM2D_TEST(NAME, THOST, TIDX, RVR, RVC, RTS)

ELEM2D_DYN_TEST(elem2d_dyn_float_4x8_64size, float, int32_t, 4, 8, 64)
ELEM2D_DYN_TEST(elem2d_dyn_int32_3x3_in_3x8_64size, int32_t, int32_t, 3, 3, 64)
ROW_TEST(row_dyn_int32_3x16_8rows, int32_t, int32_t, 3, 16, 8)
ROW_TEST(row_dyn_half_4x32_16rows, aclFloat16, int32_t, 4, 32, 16)
// Element count of an NZ-layout table: BR * 16 * BC * C0. Presumably this is BR x BC fractal blocks
// of 16 rows by C0 columns each (TODO confirm against the kernel's NZ tile definition).
#define MGATHER_NZ_TABLE_COUNT(BR, BC, C0) ((size_t)(BR) * 16 * (BC) * (C0))

// Row gather from an NZ-layout table: R row indices, an R x C output.
#define ROW_NZ_TEST(NAME, THOST, TIDX, R, C, BR, BC, C0) \
    TEST_F(MGATHERTest, case_##NAME) \
    { \
        run_mgather_test<THOST, TIDX>(MGATHER_NZ_TABLE_COUNT(BR, BC, C0), (size_t)(R), (size_t)(R) * (C), \
                                      Launch_##NAME); \
    }

// 2-D elementwise gather from an NZ-layout table: R x C indices and outputs.
#define ELEM2D_NZ_TEST(NAME, THOST, TIDX, R, C, BR, BC, C0) \
    TEST_F(MGATHERTest, case_##NAME) \
    { \
        run_mgather_test<THOST, TIDX>(MGATHER_NZ_TABLE_COUNT(BR, BC, C0), (size_t)(R) * (C), (size_t)(R) * (C), \
                                      Launch_##NAME); \
    }

ROW_NZ_TEST(row_nz_float_16x16_2blk, float, int32_t, 16, 16, 2, 2, 8)
ROW_NZ_TEST(row_nz_half_32x16_2blk, aclFloat16, int32_t, 32, 16, 2, 1, 16)
ROW_NZ_TEST(row_nz_int32_16x16_2blk, int32_t, int32_t, 16, 16, 2, 2, 8)
ROW_NZ_TEST(row_nz_int16_32x16_1blk, int16_t, int32_t, 32, 16, 2, 1, 16)
ROW_NZ_TEST(row_nz_int8_16x32_1blk, int8_t, int32_t, 16, 32, 2, 1, 32)
ROW_NZ_TEST(row_nz_float_clamp_16x8_1blk, float, int32_t, 16, 8, 2, 1, 8)
ROW_NZ_TEST(row_nz_half_zero_16x16_2blk, aclFloat16, int32_t, 16, 16, 2, 1, 16)
ELEM2D_NZ_TEST(elem2d_nz_float_16x16_2blk, float, int32_t, 16, 16, 2, 2, 8)
ELEM2D_NZ_TEST(elem2d_nz_half_16x16_1blk, aclFloat16, int32_t, 16, 16, 2, 1, 16)
ELEM2D_NZ_TEST(elem2d_nz_int32_16x8_1blk, int32_t, int32_t, 16, 8, 2, 1, 8)
ELEM2D_NZ_TEST(elem2d_nz_half_zero_16x16_1blk, aclFloat16, int32_t, 16, 16, 2, 1, 16)