* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <cstring>
#include <vector>
#include "verify.h"
#include "blas_test.h"
#include "csv_loader.h"
#include "gemv_batched_param.h"
#include "gemv_batched_golden.h"
#include "gemv_batched_npu_wrapper.h"
/**
 * Converts an IEEE-754 binary32 value to its binary16 bit pattern.
 *
 * Rounding is round-to-nearest, ties-to-even. Values too large for binary16
 * become signed infinity, values in the binary16 subnormal range are encoded
 * as subnormals, and anything smaller rounds to signed zero. NaN inputs stay
 * NaN (the quiet bit is forced so the payload can never collapse to infinity).
 *
 * @param val value to convert
 * @return the 16 raw bits of the binary16 encoding
 */
static uint16_t FloatToHalf(float val)
{
    uint32_t bits;
    memcpy(&bits, &val, sizeof(bits));

    const uint32_t sign = (bits >> 16) & 0x8000u;
    const uint32_t magnitude = bits & 0x7FFFFFFFu;
    uint32_t mant = magnitude & 0x007FFFFFu;
    // Re-bias the exponent from binary32 (127) to binary16 (15).
    const int32_t exp = static_cast<int32_t>(magnitude >> 23) - 127 + 15;

    if (magnitude > 0x7F800000u) {
        // NaN: keep the top payload bits and force the quiet bit.
        return static_cast<uint16_t>(sign | 0x7E00u | (mant >> 13));
    }
    if (exp >= 31) {
        // Infinity, or a finite value beyond the binary16 range.
        return static_cast<uint16_t>(sign | 0x7C00u);
    }
    if (exp <= 0) {
        if (exp < -10) {
            // Below half of the smallest subnormal (this also covers binary32
            // zeros and subnormals): rounds to signed zero.
            return static_cast<uint16_t>(sign);
        }
        // Subnormal result: make the implicit leading 1 explicit, then shift
        // it into the 10-bit mantissa field with rounding.
        mant |= 0x00800000u;
        const uint32_t shift = static_cast<uint32_t>(14 - exp);  // 14..24
        uint32_t half = mant >> shift;
        const uint32_t rem = mant & ((1u << shift) - 1u);
        const uint32_t halfway = 1u << (shift - 1u);
        if (rem > halfway || (rem == halfway && (half & 1u) != 0u)) {
            ++half;  // may carry into the smallest normal, which is correct
        }
        return static_cast<uint16_t>(sign | half);
    }

    uint32_t half = (static_cast<uint32_t>(exp) << 10) | (mant >> 13);
    const uint32_t rem = mant & 0x1FFFu;
    if (rem > 0x1000u || (rem == 0x1000u && (half & 1u) != 0u)) {
        ++half;  // a mantissa carry bumps the exponent; overflow yields infinity
    }
    return static_cast<uint16_t>(sign | half);
}
/**
 * Expands an IEEE-754 binary16 bit pattern to a binary32 value.
 *
 * Zeros, subnormals, normals, infinities and NaNs are all handled; every
 * binary16 value is exactly representable in binary32, so the conversion is
 * lossless.
 *
 * @param h raw binary16 bits
 * @return the equivalent float
 */
static float HalfToFloat(uint16_t h)
{
    const uint32_t sign = static_cast<uint32_t>(h & 0x8000u) << 16;
    const uint32_t exp = (h >> 10) & 0x1Fu;
    uint32_t mant = h & 0x03FFu;
    uint32_t bits;
    if (exp == 0) {
        if (mant == 0) {
            bits = sign;  // signed zero
        } else {
            // Subnormal half: value is mant * 2^-24. Normalise by shifting the
            // leading 1 up to the implicit-bit position, lowering the biased
            // binary32 exponent once per shift.
            uint32_t biasedExp = 127 - 15 + 1;
            while ((mant & 0x0400u) == 0) {
                mant <<= 1;
                --biasedExp;
            }
            bits = sign | (biasedExp << 23) | ((mant & 0x03FFu) << 13);
        }
    } else if (exp == 31) {
        // Infinity (mant == 0) or NaN (payload preserved in the top bits).
        bits = sign | 0x7F800000u | (mant << 13);
    } else {
        bits = sign | ((exp + 127 - 15) << 23) | (mant << 13);
    }
    float result;
    memcpy(&result, &bits, sizeof(result));
    return result;
}
/**
 * Converts an IEEE-754 binary32 value to its bfloat16 bit pattern.
 *
 * The low 16 bits are dropped with round-to-nearest, ties-to-even. NaN inputs
 * are kept as NaN by forcing a mantissa bit, so a payload that lives only in
 * the discarded bits cannot turn into infinity.
 *
 * @param val value to convert
 * @return the 16 raw bits of the bfloat16 encoding
 */
static uint16_t FloatToBfloat(float val)
{
    uint32_t bits;
    memcpy(&bits, &val, sizeof(bits));
    if ((bits & 0x7FFFFFFFu) > 0x7F800000u) {
        // NaN: skip rounding (it could carry into the exponent) and set the quiet bit.
        return static_cast<uint16_t>((bits >> 16) | 0x0040u);
    }
    // Adding 0x7FFF plus the current LSB of the kept part implements
    // round-half-to-even; the largest finite input cannot overflow uint32_t.
    const uint32_t roundingBias = 0x7FFFu + ((bits >> 16) & 1u);
    return static_cast<uint16_t>((bits + roundingBias) >> 16);
}
/**
 * Expands a bfloat16 bit pattern to a binary32 value by placing the 16 bits in
 * the upper half of a 32-bit word and zero-filling the lower half.
 *
 * @param b raw bfloat16 bits
 * @return the equivalent float
 */
static float BfloatToFloat(uint16_t b)
{
    const uint32_t widened = static_cast<uint32_t>(b) << 16;
    float value;
    memcpy(&value, &widened, sizeof(value));
    return value;
}
/**
 * Writes into dst each element of src after a float -> half -> float round
 * trip, i.e. the float value that survives half-precision storage.
 * dst is resized to src.size().
 */
static void QuantizeRoundTripWithHalf(const std::vector<float>& src, std::vector<float>& dst)
{
    const size_t count = src.size();
    dst.resize(count);
    size_t idx = 0;
    while (idx < count) {
        const uint16_t packed = FloatToHalf(src[idx]);
        dst[idx] = HalfToFloat(packed);
        ++idx;
    }
}
/**
 * Writes into dst each element of src after a float -> bfloat16 -> float round
 * trip, i.e. the float value that survives bfloat16 storage.
 * dst is resized to src.size().
 */
static void QuantizeRoundTripWithBf16(const std::vector<float>& src, std::vector<float>& dst)
{
    const size_t count = src.size();
    dst.resize(count);
    size_t idx = 0;
    while (idx < count) {
        const uint16_t packed = FloatToBfloat(src[idx]);
        dst[idx] = BfloatToFloat(packed);
        ++idx;
    }
}
/**
 * Encodes every float in src as half-precision bits via FloatToHalf.
 * dst is resized to src.size().
 */
static void QuantizeToHalf(const std::vector<float>& src, std::vector<uint16_t>& dst)
{
    const size_t total = src.size();
    dst.resize(total);
    for (size_t pos = 0; pos != total; ++pos) {
        const float value = src[pos];
        dst[pos] = FloatToHalf(value);
    }
}
/**
 * Encodes every float in src as bfloat16 bits via FloatToBfloat.
 * dst is resized to src.size().
 */
static void QuantizeToBf16(const std::vector<float>& src, std::vector<uint16_t>& dst)
{
    const size_t total = src.size();
    dst.resize(total);
    for (size_t pos = 0; pos != total; ++pos) {
        const float value = src[pos];
        dst[pos] = FloatToBfloat(value);
    }
}
/**
 * Decodes every half-precision bit pattern in src to float via HalfToFloat.
 * dst is resized to src.size().
 */
static void ConvertHalfToFloat(const std::vector<uint16_t>& src, std::vector<float>& dst)
{
    const size_t total = src.size();
    dst.resize(total);
    for (size_t pos = 0; pos != total; ++pos) {
        const uint16_t encoded = src[pos];
        dst[pos] = HalfToFloat(encoded);
    }
}
/**
 * Decodes every bfloat16 bit pattern in src to float via BfloatToFloat.
 * dst is resized to src.size().
 */
static void ConvertBf16ToFloat(const std::vector<uint16_t>& src, std::vector<float>& dst)
{
    const size_t total = src.size();
    dst.resize(total);
    for (size_t pos = 0; pos != total; ++pos) {
        const uint16_t encoded = src[pos];
        dst[pos] = BfloatToFloat(encoded);
    }
}
// Test fixture for the batched GEMV cases in this file. It adds nothing of its
// own: everything (including the handle_ member used by the tests below) comes
// from BlasTest<GemvBatchedParam>.
class GemvBatchedArch35Test : public BlasTest<GemvBatchedParam> {};
// Calling the fp32 batched GEMV wrapper with a null handle must report
// ACLBLAS_STATUS_HANDLE_IS_NULLPTR; all other arguments describe a valid
// 8x8, single-batch, unit-stride problem.
TEST_F(GemvBatchedArch35Test, NullHandle)
{
    const float alpha = 1.0f;
    const float beta = 0.0f;
    std::vector<float> matrix(64, 1.0f);
    std::vector<float> vecX(8, 1.0f);
    std::vector<float> vecY(8, 0.0f);

    const aclblasStatus_t status = aclblasGemvBatchedS_npu(
        nullptr, ACLBLAS_OP_N, 8, 8, &alpha, matrix.data(), 8, vecX.data(), 1, &beta, vecY.data(), 1, 1);

    const int actual = static_cast<int>(status);
    const int expected = static_cast<int>(ACLBLAS_STATUS_HANDLE_IS_NULLPTR);
    EXPECT_EQ(actual, expected);
}
// Instantiates the value-parameterized tests of GemvBatchedArch35Test with the
// cases returned by GetCasesFromCsv for the path that ReplaceFileExtension2Csv
// derives from this source file's name (presumably a sibling .csv file —
// confirm against csv_loader.h). PrintCaseInfoString supplies the per-case
// test name suffix.
INSTANTIATE_TEST_SUITE_P(
GemvBatched, GemvBatchedArch35Test,
::testing::ValuesIn(GetCasesFromCsv<GemvBatchedParam>(ReplaceFileExtension2Csv(__FILE__))),
PrintCaseInfoString<GemvBatchedParam>);
/**
 * CSV-driven batched GEMV test.
 *
 * Flow:
 *  1. Build float host buffers for A, x and y from the case's fill descriptors.
 *     Dimensions are clamped to at least 1 so the buffers are never empty, even
 *     for cases that deliberately pass invalid sizes.
 *  2. If the case expects a failure status, call the fp32 entry point with the
 *     raw (possibly null / invalid) arguments and only compare the status.
 *  3. Otherwise dispatch on p.dtype to one of the NPU wrappers:
 *       1 -> aclblasGemvBatchedS_npu    (float A/x/y)
 *       0 -> aclblasGemvBatchedHSH_npu  (half-encoded A/x/y)
 *       2 -> aclblasGemvBatchedHSS_npu  (half-encoded A/x, float y)
 *       3 -> aclblasGemvBatchedTST_npu  (bf16-encoded A/x/y)
 *       4 -> aclblasGemvBatchedTSS_npu  (bf16-encoded A/x, float y)
 *     The golden inputs are the same values after the matching quantisation
 *     round trip, so both sides start from identical data.
 *  4. Run aclblasGemvBatched_cpu on the golden buffers, gather the logical y
 *     elements (undoing the incy stride) and compare with Verifier::verifyVector.
 */
TEST_P(GemvBatchedArch35Test, CsvDriven)
{
    const auto& p = GetParam();

    // Clamped sizes used only for host allocation; the API calls below still
    // receive the raw p.m / p.n / p.lda / p.batchCount values.
    const int safeM = std::max(1, std::abs(p.m));
    const int safeN = std::max(1, std::abs(p.n));
    const int safeBatchCount = std::max(1, p.batchCount);
    const int safeLda = std::max(safeM, p.lda);

    const bool isTransN = (p.trans == ACLBLAS_OP_N);
    const int xCount = isTransN ? safeN : safeM;
    const int yCount = isTransN ? safeM : safeN;
    // Per-batch footprint of a strided vector, computed in size_t so large
    // counts/increments cannot overflow int before widening.
    const size_t xStride = static_cast<size_t>(xCount - 1) * static_cast<size_t>(std::abs(p.incx)) + 1;
    const size_t yStride = static_cast<size_t>(yCount - 1) * static_cast<size_t>(std::abs(p.incy)) + 1;

    auto aFloat = makeBlasArray(
        static_cast<int64_t>(safeBatchCount) * safeLda * safeN, p.a,
        p.randomSeed);
    std::vector<float> xFloat(safeBatchCount * xStride);
    std::vector<float> yFloat(safeBatchCount * yStride);
    for (int b = 0; b < safeBatchCount; b++) {
        // A vector flagged M_NULLPTR is passed as nullptr on the error path,
        // so its host buffer is left zero-initialised.
        if (p.x.method != BlasFillMode::M_NULLPTR) {
            auto xBatch = makeBlasStrided(xCount, p.incx, p.x, p.randomSeed + 1 + b);
            for (size_t i = 0; i < xStride; i++) {
                xFloat[b * xStride + i] = xBatch[i];
            }
        }
        if (p.y.method != BlasFillMode::M_NULLPTR) {
            auto yBatch = makeBlasStrided(yCount, p.incy, p.y, p.randomSeed + 2 + b);
            for (size_t i = 0; i < yStride; i++) {
                yFloat[b * yStride + i] = yBatch[i];
            }
        }
    }

    const float* alphaPtr = (p.alphaFill.method == BlasFillMode::M_NULLPTR) ? nullptr : &p.alpha;
    const float* betaPtr = (p.betaFill.method == BlasFillMode::M_NULLPTR) ? nullptr : &p.beta;
    const float* aPtr = (p.a.method == BlasFillMode::M_NULLPTR) ? nullptr : aFloat.data();
    const float* xPtr = (p.x.method == BlasFillMode::M_NULLPTR) ? nullptr : xFloat.data();
    float* yErrPtr = (p.y.method == BlasFillMode::M_NULLPTR) ? nullptr : yFloat.data();

    // Negative cases: only the returned status is checked.
    if (p.expectResult != ACLBLAS_STATUS_SUCCESS) {
        aclblasStatus_t ret = aclblasGemvBatchedS_npu(
            GemvBatchedArch35Test::handle_, p.trans, p.m, p.n, alphaPtr,
            aPtr, p.lda, xPtr, p.incx, betaPtr, yErrPtr, p.incy, p.batchCount);
        EXPECT_EQ(static_cast<int>(ret), static_cast<int>(p.expectResult));
        return;
    }

    aclblasStatus_t ret = ACLBLAS_STATUS_SUCCESS;
    // Both result buffers are assigned in every dtype branch below, so they
    // start empty instead of being sized from the (possibly non-positive)
    // raw batch count.
    std::vector<float> npuFloat;
    std::vector<float> goldenFloat;
    std::vector<float> aGolden;
    std::vector<float> xGolden;
    switch (p.dtype) {
        case 1: {
            aGolden = aFloat;
            xGolden = xFloat;
            npuFloat = yFloat;
            ret = aclblasGemvBatchedS_npu(
                GemvBatchedArch35Test::handle_, p.trans, p.m, p.n, alphaPtr,
                aFloat.data(), p.lda, xFloat.data(), p.incx, betaPtr,
                npuFloat.data(), p.incy, p.batchCount);
            goldenFloat = yFloat;
            break;
        }
        case 0: {
            std::vector<uint16_t> aHalf, xHalf, yHalf;
            QuantizeToHalf(aFloat, aHalf);
            QuantizeToHalf(xFloat, xHalf);
            QuantizeToHalf(yFloat, yHalf);
            QuantizeRoundTripWithHalf(aFloat, aGolden);
            QuantizeRoundTripWithHalf(xFloat, xGolden);
            std::vector<float> yGoldenTmp;
            QuantizeRoundTripWithHalf(yFloat, yGoldenTmp);
            std::vector<uint16_t> yNpu = yHalf;
            ret = aclblasGemvBatchedHSH_npu(
                GemvBatchedArch35Test::handle_, p.trans, p.m, p.n, alphaPtr,
                aHalf.data(), p.lda, xHalf.data(), p.incx, betaPtr,
                yNpu.data(), p.incy, p.batchCount);
            ConvertHalfToFloat(yNpu, npuFloat);
            goldenFloat = yGoldenTmp;
            break;
        }
        case 2: {
            std::vector<uint16_t> aHalf, xHalf;
            QuantizeToHalf(aFloat, aHalf);
            QuantizeToHalf(xFloat, xHalf);
            QuantizeRoundTripWithHalf(aFloat, aGolden);
            QuantizeRoundTripWithHalf(xFloat, xGolden);
            goldenFloat = yFloat;
            npuFloat = yFloat;
            ret = aclblasGemvBatchedHSS_npu(
                GemvBatchedArch35Test::handle_, p.trans, p.m, p.n, alphaPtr,
                aHalf.data(), p.lda, xHalf.data(), p.incx, betaPtr,
                npuFloat.data(), p.incy, p.batchCount);
            break;
        }
        case 3: {
            std::vector<uint16_t> aBf16, xBf16, yBf16;
            QuantizeToBf16(aFloat, aBf16);
            QuantizeToBf16(xFloat, xBf16);
            QuantizeToBf16(yFloat, yBf16);
            QuantizeRoundTripWithBf16(aFloat, aGolden);
            QuantizeRoundTripWithBf16(xFloat, xGolden);
            std::vector<float> yGoldenTmp;
            QuantizeRoundTripWithBf16(yFloat, yGoldenTmp);
            std::vector<uint16_t> yNpu = yBf16;
            ret = aclblasGemvBatchedTST_npu(
                GemvBatchedArch35Test::handle_, p.trans, p.m, p.n, alphaPtr,
                aBf16.data(), p.lda, xBf16.data(), p.incx, betaPtr,
                yNpu.data(), p.incy, p.batchCount);
            ConvertBf16ToFloat(yNpu, npuFloat);
            goldenFloat = yGoldenTmp;
            break;
        }
        case 4: {
            std::vector<uint16_t> aBf16, xBf16;
            QuantizeToBf16(aFloat, aBf16);
            QuantizeToBf16(xFloat, xBf16);
            QuantizeRoundTripWithBf16(aFloat, aGolden);
            QuantizeRoundTripWithBf16(xFloat, xGolden);
            goldenFloat = yFloat;
            npuFloat = yFloat;
            ret = aclblasGemvBatchedTSS_npu(
                GemvBatchedArch35Test::handle_, p.trans, p.m, p.n, alphaPtr,
                aBf16.data(), p.lda, xBf16.data(), p.incx, betaPtr,
                npuFloat.data(), p.incy, p.batchCount);
            break;
        }
        default:
            ASSERT_TRUE(false) << "Unknown dtype: " << p.dtype;
            return;
    }
    ASSERT_EQ(ret, ACLBLAS_STATUS_SUCCESS);

    // CPU reference on the quantisation-matched inputs; goldenFloat is read
    // back below as the expected y, so it is presumably updated in place.
    aclblasGemvBatched_cpu(
        GemvBatchedArch35Test::handle_, p.trans, p.m, p.n, alphaPtr, aGolden.data(), p.lda,
        xGolden.data(), p.incx, betaPtr, goldenFloat.data(), p.incy, p.batchCount);

    // Nothing to compare for empty problems or when the case disables the
    // precision check with a non-positive threshold.
    if (yCount == 0 || p.batchCount <= 0 || p.mereThreshold <= 0.0) {
        return;
    }

    // Gather logical element i of every batch out of the strided layout. For a
    // non-positive incy the original indexing is kept: element i is read from
    // the mirrored slot (yCount - 1 - i) * |incy|.
    const size_t batches = static_cast<size_t>(p.batchCount);
    const size_t logicalLen = static_cast<size_t>(yCount);
    const size_t absIncy = static_cast<size_t>(std::abs(p.incy));
    std::vector<float> npuLogical(batches * logicalLen);
    std::vector<float> cpuLogical(batches * logicalLen);
    for (size_t b = 0; b < batches; b++) {
        for (int i = 0; i < yCount; i++) {
            const size_t slot = static_cast<size_t>((p.incy > 0) ? i : (yCount - 1 - i));
            const size_t yIdx = slot * absIncy;
            npuLogical[b * logicalLen + i] = npuFloat[b * yStride + yIdx];
            cpuLogical[b * logicalLen + i] = goldenFloat[b * yStride + yIdx];
        }
    }

    VerifyConfig cfg;
    cfg.mode = PrecisionMode::MERE_MARE;
    cfg.mereThreshold = p.mereThreshold;
    cfg.mareMultiplier = p.mareMultiplier;
    EXPECT_TRUE(Verifier::verifyVector(
        npuLogical.data(), cpuLogical.data(), batches * logicalLen, 1, cfg, p.caseName));
}