/**
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/**
* @file test_common.h
* @brief Common test framework: test statistics, assertion macros and ACL initialization.
*        公共测试框架 - 提供测试统计、断言宏和 ACL 初始化功能
*/
#ifndef CANN_OPS_TENSOR_TEST_COMMON_H
#define CANN_OPS_TENSOR_TEST_COMMON_H
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <iterator>
#include <sstream>
#include <string>
#include <vector>
#include "cann_ops_tensor_types.h"
#include "cann_ops_tensor.h"
#include "utils/type_utils.hpp"
#include "acl/acl.h"
namespace OpsTensorTest {
/**
 * @brief Pass/fail counters for a group of test cases.
 */
struct TestStats {
    int total = 0;   // number of recorded results
    int passed = 0;  // results recorded as passing
    int failed = 0;  // results recorded as failing

    /**
     * @brief Print the counters on one line to std::cout.
     * @param name Optional label; when non-empty the line is prefixed with "<name>: ".
     */
    void print(const std::string& name = "") const {
        std::ostream& out = std::cout;
        if (name.size() > 0) {
            out << name << ": ";
        }
        out << "总测试数=" << total
            << ", 通过=" << passed
            << ", 失败=" << failed
            << std::endl;
    }
};

/// Process-wide counters updated by the TEST_* macros; defined in a separate source file.
extern TestStats g_global_stats;
/**
 * @brief Announce the start of a test case on std::cout as "[RUN] <test_name>...".
 */
#define TEST_CASE_BEGIN(test_name) \
    do { \
        std::cout << "[RUN] " << test_name << "..." << '\n' << std::flush; \
    } while(0)
/**
 * @brief Record a passing test case in both the local and the global counters,
 *        then print "[PASS] <test_name>".
 */
#define TEST_CASE_PASS(local_stats, test_name) \
    do { \
        ++(local_stats).total; \
        ++(local_stats).passed; \
        ++OpsTensorTest::g_global_stats.total; \
        ++OpsTensorTest::g_global_stats.passed; \
        std::cout << "[PASS] " << test_name << '\n' << std::flush; \
    } while(0)
/**
 * @brief Check `condition`; on failure, count the failure, report it and abort the run.
 *
 * A passing assertion records nothing. A failing one increments `total`/`failed` in
 * both the local and the global counters, and writes " [ERROR] <error_msg>" to
 * std::cerr. It then terminates the process with std::exit(1).
 */
#define TEST_ASSERT(local_stats, condition, error_msg) \
    do { \
        if (condition) { \
            break; /* assertion holds: nothing to record */ \
        } \
        ++(local_stats).total; \
        ++(local_stats).failed; \
        ++OpsTensorTest::g_global_stats.total; \
        ++OpsTensorTest::g_global_stats.failed; \
        std::cerr << " [ERROR] " << error_msg << std::endl; \
        std::exit(1); \
    } while(0)
/**
 * @brief Assert that `length` elements starting at `actual` equal those starting at `expected`.
 *
 * `actual` and `expected` are pointers, arrays or iterators; arrays decay to pointers.
 * Each macro argument is evaluated exactly once. On the first mismatch, the index and
 * both values are written to std::cerr, so element types must be streamable. The
 * result is then handed to TEST_ASSERT, which exits the process on failure.
 */
#define TEST_ASSERT_ARRAY_EQ(local_stats, actual, expected, length, error_msg) \
    do { \
        const auto eq_expected_ = (expected); \
        const auto eq_actual_ = (actual); \
        const auto eq_end_ = eq_expected_ + (length); \
        /* Single pass: std::mismatch returns the first differing pair, or end. */ \
        const auto eq_pos_ = std::mismatch(eq_expected_, eq_end_, eq_actual_); \
        const bool all_match = (eq_pos_.first == eq_end_); \
        if (!all_match) { \
            const size_t mismatch_idx = \
                static_cast<size_t>(std::distance(eq_expected_, eq_pos_.first)); \
            std::cerr << " [ERROR] " << error_msg << " at index " << mismatch_idx \
                      << ": actual=" << *eq_pos_.second \
                      << ", expected=" << *eq_pos_.first << std::endl; \
        } \
        TEST_ASSERT((local_stats), all_match, error_msg); \
    } while(0)
/**
 * @brief Assert that the first `length` elements of two indexable containers agree within
 *        an absolute tolerance `tol`.
 *
 * Both containers must provide size() and operator[]. The assertion fails when:
 *   - the two sizes differ, or
 *   - `length` exceeds the container size (previously an out-of-bounds read), or
 *   - some element pair does not match.
 * Two elements match if they compare equal (this also covers equal infinities), if both
 * are NaN, or if |actual - expected| <= tol. A NaN on only one side is a mismatch. On a
 * mismatch, the first offending index and both values are written to std::cerr. The
 * result is then handed to TEST_ASSERT, which exits the process on failure.
 */
#define TEST_ASSERT_ARRAY_NEAR(local_stats, actual, expected, length, tol, error_msg) \
    do { \
        bool all_match = true; \
        const size_t near_len_ = static_cast<size_t>(length); \
        if ((actual).size() != (expected).size()) { \
            all_match = false; \
            std::cerr << " [ERROR] " << error_msg << ": actual.size() (" << (actual).size() \
                      << ") != expected.size() (" << (expected).size() << ")" << std::endl; \
        } else if (near_len_ > (actual).size()) { \
            /* Guards the loop below against indexing past the end of both containers. */ \
            all_match = false; \
            std::cerr << " [ERROR] " << error_msg << ": length (" << near_len_ \
                      << ") exceeds container size (" << (actual).size() << ")" << std::endl; \
        } else { \
            size_t first_mismatch = 0; \
            for (size_t near_i_ = 0; near_i_ < near_len_; ++near_i_) { \
                const auto& near_a_ = (actual)[near_i_]; \
                const auto& near_e_ = (expected)[near_i_]; \
                /* "<= tol" (not "> tol" as a failure test) so that a lone NaN fails. */ \
                const bool near_ok_ = (near_a_ == near_e_) || \
                                      (std::isnan(near_a_) && std::isnan(near_e_)) || \
                                      (std::abs(near_a_ - near_e_) <= (tol)); \
                if (!near_ok_) { \
                    all_match = false; \
                    first_mismatch = near_i_; \
                    break; \
                } \
            } \
            if (!all_match) { \
                std::cerr << " [ERROR] " << error_msg << " at index " << first_mismatch \
                          << ": actual=" << (actual)[first_mismatch] \
                          << ", expected=" << (expected)[first_mismatch] << std::endl; \
            } \
        } \
        TEST_ASSERT((local_stats), all_match, error_msg); \
    } while(0)
/**
 * @brief Print the banner that opens an operator's unit-test run.
 */
#define TEST_PRINT_HEADER(op_name) \
    do { \
        const char* const header_banner_ = "========================================"; \
        std::cout << header_banner_ << std::endl; \
        std::cout << " " << op_name << "算子单元测试" << std::endl; \
        std::cout << header_banner_ << std::endl; \
        std::cout << std::endl; \
    } while(0)
/**
 * @brief Print the result banner for one operator, followed by its counters.
 *
 * `op_name` is streamed in both the title line and the counter line. Previously the
 * counter line stringized it with `#op_name`. That printed the argument's source
 * spelling: quotes included for a literal, or a variable's name instead of its value.
 */
#define TEST_PRINT_RESULT(op_name, local_stats) \
    do { \
        std::cout << std::endl; \
        std::cout << "========================================" << std::endl; \
        std::cout << " " << op_name << "算子测试结果" << std::endl; \
        std::cout << "========================================" << std::endl; \
        std::ostringstream op_name_label_; \
        op_name_label_ << op_name; \
        (local_stats).print(op_name_label_.str()); \
        std::cout << "========================================" << std::endl; \
    } while(0)
/**
 * @brief Print an empty line, then the counters of `local_stats` labelled with `name`.
 */
#define TEST_PRINT_RESULT_NAME(name, local_stats) \
    do { \
        std::cout << '\n' << std::flush; \
        (local_stats).print(name); \
    } while(0)
/**
 * @brief Print the summary banner with the process-wide counters (g_global_stats).
 */
#define TEST_PRINT_GLOBAL_RESULT() \
    do { \
        const char* const global_banner_ = "========================================"; \
        std::cout << std::endl; \
        std::cout << global_banner_ << std::endl; \
        std::cout << " 全局测试结果" << std::endl; \
        std::cout << global_banner_ << std::endl; \
        OpsTensorTest::g_global_stats.print("全部算子"); \
        std::cout << global_banner_ << std::endl; \
    } while(0)
/**
 * @brief Static helpers for setting up and tearing down the ACL runtime in tests.
 *
 * Only declarations are visible here; the definitions live in another source file.
 * NOTE(review): presumably init() prepares `stream` and returns 0 on success, and
 * finalize() releases it. Confirm against the implementation.
 */
class ACLManager {
public:
    static auto init(aclrtStream& stream) -> int;
    static auto finalize(aclrtStream& stream) -> void;
};
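/*============================================================================
* Elementwise test interface declarations (Elementwise 测试接口声明)
*============================================================================*/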
* @brief Elementwise 操作配置基类
* 包含所有 Elementwise 操作的公共配置
*/
struct ElementwiseTestConfigBase {
std::vector<int64_t> dimensions;
acltensorDataType_t dataType = ACLTENSOR_R_32F;
ElementwiseTestConfigBase(const std::vector<int64_t>& dims, acltensorDataType_t dt)
: dimensions(dims), dataType(dt) {}
int64_t numElements() const {
if (dimensions.empty()) {
return 0;
}
int64_t total = 1;
for (auto dim : dimensions) {
total *= dim;
}
return total;
}
uint32_t numModes() const {
return static_cast<uint32_t>(dimensions.size());
}
};
* @brief Elementwise Binary 操作测试配置(2个输入:A、C)
*/
struct ElementwiseBinaryTestConfig : public ElementwiseTestConfigBase {
std::vector<int32_t> modeA;
std::vector<int32_t> modeC;
std::vector<int32_t> modeD;
ElementwiseBinaryTestConfig(int64_t size, acltensorDataType_t dt = ACLTENSOR_R_32F)
: ElementwiseTestConfigBase(std::vector<int64_t>{size}, dt)
, modeA(1, 0)
, modeC(1, 0)
, modeD(1, 0) {}
ElementwiseBinaryTestConfig(
const std::vector<int64_t>& dims,
const std::vector<int32_t>& ma,
const std::vector<int32_t>& mc,
const std::vector<int32_t>& md,
acltensorDataType_t dt = ACLTENSOR_R_32F)
: ElementwiseTestConfigBase(dims, dt)
, modeA(ma)
, modeC(mc)
, modeD(md) {}
};
* @brief Elementwise Trinary 操作测试配置(3个输入:A、B、C)
* 未来实现:D = op2(op1(A, B), C)
*/
struct ElementwiseTrinaryTestConfig : public ElementwiseTestConfigBase {
std::vector<int32_t> modeA;
std::vector<int32_t> modeB;
std::vector<int32_t> modeC;
std::vector<int32_t> modeD;
};
* @brief 执行 Elementwise Binary 操作(通用测试函数)
*
* 支持的操作类型:ADD, SUB, MUL, DIV
* 支持任意维度、任意数据类型
*
* @param h_A 输入张量 A(主机内存)
* @param h_C 输入张量 C(主机内存)
* @param h_D 输出张量 D(主机内存)
* @param config 操作配置(包含维度、mode、数据类型等)
* @param opType 操作类型 (ACLTENSOR_OP_ADD/SUB/MUL/DIV)
* @param stream ACL 流
* @return ACLTENSOR_STATUS_SUCCESS 成功,否则返回错误码
*
* 使用示例:
* @code
* // 简单一维场景
* float A[] = {1.0f, 2.0f, 3.0f};
* float C[] = {4.0f, 5.0f, 6.0f};
* float D[3];
*
* ElementwiseBinaryTestConfig config(3); // 一维,大小为3
* ExecuteElementwiseBinaryTest(A, C, D, config, ACLTENSOR_OP_ADD, stream);
* @endcode
*/
acltensorStatus_t ExecuteElementwiseBinaryTest(
const void* h_A,
const void* h_C,
void* h_D,
const ElementwiseBinaryTestConfig& config,
acltensorOperator_t opType,
aclrtStream stream);
using TestFunc = void(*)(aclrtStream, TestStats&);
struct TestRegistry {
struct TestEntry {
const char* name;
TestFunc func;
};
static std::vector<TestEntry>& get_tests() {
static std::vector<TestEntry> tests;
return tests;
}
static void register_test(const char* name, TestFunc func) {
get_tests().push_back({name, func});
}
};
/**
 * @brief Register an operator's test suite with TestRegistry during static initialization.
 *
 * Expands to a namespace `<op_name>Test` containing:
 *   - a declaration of `void run_all_tests(aclrtStream, OpsTensorTest::TestStats&)`,
 *     which must be defined elsewhere (presumably in the same test source);
 *   - an object in an unnamed namespace whose constructor registers that function
 *     under the name "op_name".
 * Each translation unit gets its own object, so expanding this macro in a header that
 * is included by several .cpp files would register the suite once per file.
 */
#define REGISTER_OP_TEST(op_name) \
    namespace op_name##Test { \
    void run_all_tests(aclrtStream, OpsTensorTest::TestStats&); \
    namespace { \
    struct AutoRegistrar { \
        AutoRegistrar() { \
            OpsTensorTest::TestRegistry::register_test(#op_name, &run_all_tests); \
        } \
    }; \
    AutoRegistrar auto_registrar_instance; \
    } \
    }
}
#endif