* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file test_dynamic_post.cpp
* \brief
*/
#include "interface/interpreter/raw_tensor_data.h"
#include "interface/tensor/float.h"
#include "test_dev_func_runner.h"
#include "test_data_prepare.h"
#include "test_suite_stest_ops.h"
#include "operator/models/nsa/attention_post.h"
#include "tilefwk/tensor.h"
using namespace npu::tile_fwk;
using namespace npu::tile_fwk::dynamic;
class AttentionPostSTest : public npu::tile_fwk::stest::TestSuite_STest_Ops_Aihac {};
namespace {
struct TestPostParams {
int b;
int n;
int s;
int h;
int kvLoraRank;
int vHeadDim;
};
template <typename T>
static std::vector<T> getGoldenVec(std::vector<int64_t> shape, std::string fileName)
{
int capacity = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>());
std::vector<T> golden(capacity, 0);
readInput<T>(GetGoldenDir() + fileName, golden);
return golden;
}
template <
typename T = npu::tile_fwk::float16, bool nz = true, typename wUvDType = int8_t, bool isSmoothWUv = false,
typename wODType = int8_t, bool isSmoothWo = false>
void TestAttentionPost(const TestPostParams& params, const PostTileConfig& tileConfig, float precision)
{
SetInterpreterConfig();
int b = params.b;
int n = params.n;
int s = params.s;
int h = params.h;
int kvLoraRank = params.kvLoraRank;
int vHeadDim = params.vHeadDim;
DataType dType = (std::is_same<T, npu::tile_fwk::float16>::value) ? DT_FP16 : DT_BF16;
bool isQuantWUv = std::is_same<wUvDType, int8_t>::value;
bool isQuantWo = std::is_same<wODType, int8_t>::value;
std::vector<int64_t> xShape = {b, s, n, kvLoraRank};
std::vector<int64_t> wUvShape = {n, kvLoraRank, vHeadDim};
std::vector<int64_t> wUvScaleShape = {n, 1, vHeadDim};
std::vector<int64_t> smoothWUvShape = {1, kvLoraRank};
std::vector<int64_t> woShape = {n * vHeadDim, h};
std::vector<int64_t> woScaleShape = {1, h};
std::vector<int64_t> smoothWoShape = {1, n * vHeadDim};
std::vector<int64_t> outShape = {b, s, h};
TileOpFormat weightFormat = nz ? TileOpFormat::TILEOP_NZ : TileOpFormat::TILEOP_ND;
Tensor x(dType, xShape, "x");
Tensor wUv(isQuantWUv ? DT_INT8 : dType, wUvShape, "wUv");
Tensor wo(isQuantWo ? DT_INT8 : dType, woShape, "wo", weightFormat);
Tensor postOut(dType, outShape, "postOut");
std::vector<T> goldenDate = getGoldenVec<T>(outShape, "/golden_output.bin");
auto xData = CreateTensorData<T>(x, "/x.bin");
auto wUvData = CreateTensorData<wUvDType>(wUv, "/w_uv.bin");
auto woData = CreateTensorData<wODType>(wo, "/w_o.bin");
auto outputData = RawTensorData::CreateConstantTensor<T>(postOut, 0.0);
std::vector<RawTensorDataPtr> outputDataList = {outputData};
std::vector<RawTensorDataPtr> inputDataList = {xData, wUvData, woData};
QuantTensorWithData wUvQuant{isQuantWUv, isSmoothWUv, wUvScaleShape, smoothWUvShape,
"wUvScale", "smoothWUv", "/w_uv_scale.bin", "/smooth_w_uv.bin"};
CreateQuantTensorAndData(wUvQuant);
inputDataList.emplace_back(wUvQuant.scale.dataPtr);
inputDataList.emplace_back(wUvQuant.smooth.dataPtr);
QuantTensorWithData wOQuant{isQuantWo, isSmoothWo, woScaleShape, smoothWoShape,
"woScale", "smoothWo", "/w_o_scale.bin", "/smooth_w_o.bin"};
CreateQuantTensorAndData(wOQuant);
inputDataList.emplace_back(wOQuant.scale.dataPtr);
inputDataList.emplace_back(wOQuant.smooth.dataPtr);
ProgramData::GetInstance().AppendInputs({inputDataList});
ProgramData::GetInstance().AppendOutputs({outputDataList});
ProgramData::GetInstance().AppendGoldens({
RawTensorData::CreateTensor<T>(postOut, goldenDate),
});
PostTensors postTensors{
wUv, wo, wUvQuant.scale.tensor, wUvQuant.smooth.tensor, wOQuant.scale.tensor, wOQuant.smooth.tensor};
AttentionPostStandalone(x, postTensors, tileConfig, postOut);
#ifdef BUILD_WITH_CANN
DevFuncRunner::Run(Program::GetInstance().GetLastFunction(), inputDataList, outputDataList);
std::cout << "postOut ====== " << std::endl;
EXPECT_TRUE(resultCmp<T>(goldenDate, (T*)outputData->data(), precision));
resultCmp<T>(goldenDate, (T*)outputData->data(), precision, int(b * s * h * precision), 1000, false, false, 16);
#endif
}
void PerformanceConfig()
{
config::SetPassOption(CUBE_NBUFFER_SETTING, std::map<int64_t, int64_t>{{0, 4}});
}
TEST_F(AttentionPostSTest, b16_s1_nz_fp16_quant)
{
TestPostParams params = {16, 128, 1, 7168, 512, 128};
PostTileConfig tileConfig = {16, 1};
PerformanceConfig();
TestAttentionPost<npu::tile_fwk::float16, true, npu::tile_fwk::float16, false, int8_t, true>(
params, tileConfig, 0.006f);
}
TEST_F(AttentionPostSTest, b16_s2_nz_fp16_quant)
{
TestPostParams params = {16, 128, 2, 7168, 512, 128};
PostTileConfig tileConfig = {16, 2};
PerformanceConfig();
TestAttentionPost<npu::tile_fwk::float16, true, npu::tile_fwk::float16, false, int8_t, true>(
params, tileConfig, 0.006f);
}
TEST_F(AttentionPostSTest, b32_s1_nz_fp16_quant)
{
TestPostParams params = {32, 128, 1, 7168, 512, 128};
PostTileConfig tileConfig = {32, 1};
PerformanceConfig();
TestAttentionPost<npu::tile_fwk::float16, true, npu::tile_fwk::float16, false, int8_t, true>(
params, tileConfig, 0.006f);
}
TEST_F(AttentionPostSTest, b32_s2_nz_fp16_quant)
{
TestPostParams params = {32, 128, 2, 7168, 512, 128};
PostTileConfig tileConfig = {32, 2};
PerformanceConfig();
TestAttentionPost<npu::tile_fwk::float16, true, npu::tile_fwk::float16, false, int8_t, true>(
params, tileConfig, 0.006f);
}
TEST_F(AttentionPostSTest, b64_s1_nz_fp16_quant)
{
TestPostParams params = {64, 128, 1, 7168, 512, 128};
PostTileConfig tileConfig = {64, 1};
PerformanceConfig();
TestAttentionPost<npu::tile_fwk::float16, true, npu::tile_fwk::float16, false, int8_t, true>(
params, tileConfig, 0.006f);
}
TEST_F(AttentionPostSTest, b64_s2_nz_fp16_quant)
{
TestPostParams params = {64, 128, 2, 7168, 512, 128};
PostTileConfig tileConfig = {32, 2};
PerformanceConfig();
TestAttentionPost<npu::tile_fwk::float16, true, npu::tile_fwk::float16, false, int8_t, true>(
params, tileConfig, 0.006f);
}
TEST_F(AttentionPostSTest, b24_s1_nz_fp16_quant)
{
TestPostParams params = {24, 128, 1, 7168, 512, 128};
PostTileConfig tileConfig = {24, 1};
PerformanceConfig();
TestAttentionPost<npu::tile_fwk::float16, true, npu::tile_fwk::float16, false, int8_t, true>(
params, tileConfig, 0.006f);
}
TEST_F(AttentionPostSTest, b24_s2_nz_fp16_quant)
{
TestPostParams params = {24, 128, 2, 7168, 512, 128};
PostTileConfig tileConfig = {24, 2};
PerformanceConfig();
TestAttentionPost<npu::tile_fwk::float16, true, npu::tile_fwk::float16, false, int8_t, true>(
params, tileConfig, 0.006f);
}
TEST_F(AttentionPostSTest, b48_s1_nz_fp16_quant)
{
TestPostParams params = {48, 128, 1, 7168, 512, 128};
PostTileConfig tileConfig = {48, 1};
PerformanceConfig();
TestAttentionPost<npu::tile_fwk::float16, true, npu::tile_fwk::float16, false, int8_t, true>(
params, tileConfig, 0.006f);
}
TEST_F(AttentionPostSTest, b48_s2_nz_fp16_quant)
{
TestPostParams params = {48, 128, 2, 7168, 512, 128};
PostTileConfig tileConfig = {48, 1};
PerformanceConfig();
TestAttentionPost<npu::tile_fwk::float16, true, npu::tile_fwk::float16, false, int8_t, true>(
params, tileConfig, 0.006f);
}
TEST_F(AttentionPostSTest, b96_s1_nz_fp16_quant)
{
TestPostParams params = {96, 128, 1, 7168, 512, 128};
PostTileConfig tileConfig = {32, 1};
PerformanceConfig();
TestAttentionPost<npu::tile_fwk::float16, true, npu::tile_fwk::float16, false, int8_t, true>(
params, tileConfig, 0.006f);
}
TEST_F(AttentionPostSTest, b96_s2_nz_fp16_quant)
{
TestPostParams params = {96, 128, 2, 7168, 512, 128};
PostTileConfig tileConfig = {32, 2};
PerformanceConfig();
TestAttentionPost<npu::tile_fwk::float16, true, npu::tile_fwk::float16, false, int8_t, true>(
params, tileConfig, 0.006f);
}
TEST_F(AttentionPostSTest, b16_s1_nz_bf16_quant)
{
TestPostParams params = {16, 128, 1, 7168, 512, 128};
PostTileConfig tileConfig = {16, 1};
PerformanceConfig();
TestAttentionPost<npu::tile_fwk::bfloat16, true, npu::tile_fwk::bfloat16, false, int8_t, true>(
params, tileConfig, 0.006f);
}
TEST_F(AttentionPostSTest, b16_s2_nz_bf16_quant)
{
TestPostParams params = {16, 128, 2, 7168, 512, 128};
PostTileConfig tileConfig = {16, 2};
PerformanceConfig();
TestAttentionPost<npu::tile_fwk::bfloat16, true, npu::tile_fwk::bfloat16, false, int8_t, true>(
params, tileConfig, 0.006f);
}
TEST_F(AttentionPostSTest, b32_s1_nz_bf16_quant)
{
TestPostParams params = {32, 128, 1, 7168, 512, 128};
PostTileConfig tileConfig = {32, 1};
PerformanceConfig();
TestAttentionPost<npu::tile_fwk::bfloat16, true, npu::tile_fwk::bfloat16, false, int8_t, true>(
params, tileConfig, 0.006f);
}
TEST_F(AttentionPostSTest, b32_s2_nz_bf16_quant)
{
TestPostParams params = {32, 128, 2, 7168, 512, 128};
PostTileConfig tileConfig = {32, 2};
PerformanceConfig();
TestAttentionPost<npu::tile_fwk::bfloat16, true, npu::tile_fwk::bfloat16, false, int8_t, true>(
params, tileConfig, 0.006f);
}
TEST_F(AttentionPostSTest, b64_s1_nz_bf16_quant)
{
TestPostParams params = {64, 128, 1, 7168, 512, 128};
PostTileConfig tileConfig = {64, 1};
PerformanceConfig();
TestAttentionPost<npu::tile_fwk::bfloat16, true, npu::tile_fwk::bfloat16, false, int8_t, true>(
params, tileConfig, 0.006f);
}
TEST_F(AttentionPostSTest, b64_s2_nz_bf16_quant)
{
TestPostParams params = {64, 128, 2, 7168, 512, 128};
PostTileConfig tileConfig = {32, 2};
PerformanceConfig();
TestAttentionPost<npu::tile_fwk::bfloat16, true, npu::tile_fwk::bfloat16, false, int8_t, true>(
params, tileConfig, 0.006f);
}
TEST_F(AttentionPostSTest, b24_s1_nz_bf16_quant)
{
TestPostParams params = {24, 128, 1, 7168, 512, 128};
PostTileConfig tileConfig = {24, 1};
PerformanceConfig();
TestAttentionPost<npu::tile_fwk::bfloat16, true, npu::tile_fwk::bfloat16, false, int8_t, true>(
params, tileConfig, 0.006f);
}
TEST_F(AttentionPostSTest, b24_s2_nz_bf16_quant)
{
TestPostParams params = {24, 128, 2, 7168, 512, 128};
PostTileConfig tileConfig = {24, 2};
PerformanceConfig();
TestAttentionPost<npu::tile_fwk::bfloat16, true, npu::tile_fwk::bfloat16, false, int8_t, true>(
params, tileConfig, 0.006f);
}
TEST_F(AttentionPostSTest, b48_s1_nz_bf16_quant)
{
TestPostParams params = {48, 128, 1, 7168, 512, 128};
PostTileConfig tileConfig = {48, 1};
PerformanceConfig();
TestAttentionPost<npu::tile_fwk::bfloat16, true, npu::tile_fwk::bfloat16, false, int8_t, true>(
params, tileConfig, 0.006f);
}
TEST_F(AttentionPostSTest, b48_s2_nz_bf16_quant)
{
TestPostParams params = {48, 128, 2, 7168, 512, 128};
PostTileConfig tileConfig = {48, 1};
PerformanceConfig();
TestAttentionPost<npu::tile_fwk::bfloat16, true, npu::tile_fwk::bfloat16, false, int8_t, true>(
params, tileConfig, 0.006f);
}
TEST_F(AttentionPostSTest, b96_s1_nz_bf16_quant)
{
TestPostParams params = {96, 128, 1, 7168, 512, 128};
PostTileConfig tileConfig = {32, 1};
PerformanceConfig();
TestAttentionPost<npu::tile_fwk::bfloat16, true, npu::tile_fwk::bfloat16, false, int8_t, true>(
params, tileConfig, 0.006f);
}
TEST_F(AttentionPostSTest, b96_s2_nz_bf16_quant)
{
TestPostParams params = {96, 128, 2, 7168, 512, 128};
PostTileConfig tileConfig = {32, 2};
PerformanceConfig();
TestAttentionPost<npu::tile_fwk::bfloat16, true, npu::tile_fwk::bfloat16, false, int8_t, true>(
params, tileConfig, 0.006f);
}
TEST_F(AttentionPostSTest, b16_s1_nd_fp16_quant)
{
TestPostParams params = {16, 128, 1, 7168, 512, 128};
PostTileConfig tileConfig = {16, 1};
PerformanceConfig();
TestAttentionPost<npu::tile_fwk::float16, false, npu::tile_fwk::float16, false, int8_t, true>(
params, tileConfig, 0.006f);
}
TEST_F(AttentionPostSTest, b16_s2_nd_fp16_quant)
{
TestPostParams params = {16, 128, 2, 7168, 512, 128};
PostTileConfig tileConfig = {16, 2};
PerformanceConfig();
TestAttentionPost<npu::tile_fwk::float16, false, npu::tile_fwk::float16, false, int8_t, true>(
params, tileConfig, 0.006f);
}
TEST_F(AttentionPostSTest, b32_s1_nd_fp16_quant)
{
TestPostParams params = {32, 128, 1, 7168, 512, 128};
PostTileConfig tileConfig = {32, 1};
PerformanceConfig();
TestAttentionPost<npu::tile_fwk::float16, false, npu::tile_fwk::float16, false, int8_t, true>(
params, tileConfig, 0.006f);
}
TEST_F(AttentionPostSTest, b32_s2_nd_fp16_quant)
{
TestPostParams params = {32, 128, 2, 7168, 512, 128};
PostTileConfig tileConfig = {32, 1};
PerformanceConfig();
TestAttentionPost<npu::tile_fwk::float16, false, npu::tile_fwk::float16, false, int8_t, true>(
params, tileConfig, 0.006f);
}
TEST_F(AttentionPostSTest, b32_s1_nz_fp16)
{
TestPostParams params = {32, 128, 1, 7168, 512, 128};
PostTileConfig tileConfig = {32, 1};
PerformanceConfig();
TestAttentionPost<npu::tile_fwk::float16, true, npu::tile_fwk::float16, false, npu::tile_fwk::float16, false>(
params, tileConfig, 0.002f);
}
TEST_F(AttentionPostSTest, b32_s2_nz_fp16)
{
TestPostParams params = {32, 128, 2, 7168, 512, 128};
PostTileConfig tileConfig = {32, 1};
PerformanceConfig();
TestAttentionPost<npu::tile_fwk::float16, true, npu::tile_fwk::float16, false, npu::tile_fwk::float16, false>(
params, tileConfig, 0.002f);
}
TEST_F(AttentionPostSTest, b16_s1_nd_fp16)
{
TestPostParams params = {16, 128, 1, 7168, 512, 128};
PostTileConfig tileConfig = {16, 1};
PerformanceConfig();
TestAttentionPost<npu::tile_fwk::float16, false, npu::tile_fwk::float16, false, npu::tile_fwk::float16, false>(
params, tileConfig, 0.002f);
}
TEST_F(AttentionPostSTest, b16_s2_nd_fp16)
{
TestPostParams params = {16, 128, 2, 7168, 512, 128};
PostTileConfig tileConfig = {16, 2};
PerformanceConfig();
TestAttentionPost<npu::tile_fwk::float16, false, npu::tile_fwk::float16, false, npu::tile_fwk::float16, false>(
params, tileConfig, 0.002f);
}
TEST_F(AttentionPostSTest, b32_s1_nd_fp16)
{
TestPostParams params = {32, 128, 1, 7168, 512, 128};
PostTileConfig tileConfig = {32, 1};
PerformanceConfig();
TestAttentionPost<npu::tile_fwk::float16, false, npu::tile_fwk::float16, false, npu::tile_fwk::float16, false>(
params, tileConfig, 0.002f);
}
TEST_F(AttentionPostSTest, b32_s2_nd_fp16)
{
TestPostParams params = {32, 128, 2, 7168, 512, 128};
PostTileConfig tileConfig = {32, 1};
PerformanceConfig();
TestAttentionPost<npu::tile_fwk::float16, false, npu::tile_fwk::float16, false, npu::tile_fwk::float16, false>(
params, tileConfig, 0.002f);
}
TEST_F(AttentionPostSTest, b16_s1_nz_fp16_quant_all)
{
TestPostParams params = {16, 128, 1, 7168, 512, 128};
PostTileConfig tileConfig = {16, 1};
PerformanceConfig();
TestAttentionPost<npu::tile_fwk::float16, true, int8_t, true, int8_t, true>(params, tileConfig, 0.006f);
}
TEST_F(AttentionPostSTest, b32_s2_nz_bf16_quant_all)
{
TestPostParams params = {32, 128, 2, 7168, 512, 128};
PostTileConfig tileConfig = {32, 2};
PerformanceConfig();
TestAttentionPost<npu::tile_fwk::bfloat16, true, int8_t, true, int8_t, true>(params, tileConfig, 0.006f);
}
TEST_F(AttentionPostSTest, b48_s1_nz_fp16_quant_all)
{
TestPostParams params = {48, 128, 1, 7168, 512, 128};
PostTileConfig tileConfig = {48, 1};
PerformanceConfig();
TestAttentionPost<npu::tile_fwk::float16, true, int8_t, true, int8_t, true>(params, tileConfig, 0.006f);
}
}