* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file test_dynamic_bmm.cpp
* \brief
*/
#include <gtest/gtest.h>
#include "tilefwk/data_type.h"
#include "interface/function/function.h"
#include "tilefwk/tilefwk_op.h"
#include "test_suite_stest_ops.h"
#include "interface/interpreter/raw_tensor_data.h"
#include "test_dev_func_runner.h"
using namespace npu::tile_fwk;
using namespace npu::tile_fwk::dynamic;
namespace {
const size_t BMM_SHAPE_SIZE = 4;
const size_t BMM_VIEW_SHAPE_SIZE = 3;
const size_t BMM_SHAPE_N_DIM = 2;
const size_t BMM_SHAPE_N_IDX = 3;
class DynamicBatchMatmulTest : public npu::tile_fwk::stest::TestSuite_STest_Ops_Aihac {};
template <typename dtype>
Tensor constructMatmulTensor(const std::vector<int64_t>& shape, const string& name, bool isNz)
{
auto dataType = GetAstDtype<dtype>();
return isNz ? Tensor(dataType, shape, name, TileOpFormat::TILEOP_NZ) : Tensor(dataType, shape, name);
}
inline SymbolicScalar CeilDivSymbolicScalar(SymbolicScalar a, int64_t b) { return b == 0 ? a : (a + b - 1) / b; }
template <typename outputDtype, bool transA, bool transB, bool isCNz>
static void NonSplitFunc(const Tensor& tensor_a, const Tensor& tensor_b, Tensor& tensor_c)
{
const auto& aShape = tensor_a.GetShape();
std::vector<SymbolicScalar> aValidShape = {aShape[0], aShape[1], aShape[2]};
const auto& bShape = tensor_b.GetShape();
std::vector<SymbolicScalar> bValidShape = {bShape[0], bShape[1], bShape[2]};
int64_t m = transA ? aShape[2] : aShape[1];
int64_t k = transA ? aShape[1] : aShape[2];
int64_t n = transB ? bShape[1] : bShape[2];
FUNCTION("testNoSplit", {tensor_a, tensor_b}, {tensor_c})
{
LOOP("mLoop", FunctionType::DYNAMIC_LOOP, mIdx, LoopRange(1))
{
Tensor dyn_a = View(tensor_a, aShape, aValidShape, {0, mIdx, 0});
Tensor dyn_b = View(tensor_b, bShape, bValidShape, {0, 0, 0});
TileShape::Current().SetMatrixSize({m, k, n});
tensor_c = Matrix::BatchMatmul(GetAstDtype<outputDtype>(), dyn_a, dyn_b, transA, transB, isCNz);
}
}
}
template <typename outputDtype, bool transA, bool transB, bool isCNz>
static void MSplitFunc(
const std::vector<int64_t>& viewShape, const Tensor& tensor_a, const Tensor& tensor_b, Tensor& tensor_c)
{
const auto& aShape = tensor_a.GetShape();
std::vector<SymbolicScalar> aValidShape = {aShape[0], aShape[1], aShape[2]};
const auto& bShape = tensor_b.GetShape();
std::vector<SymbolicScalar> bValidShape = {bShape[0], bShape[1], bShape[2]};
int64_t n = transB ? bShape[1] : bShape[2];
int64_t m = transA ? aShape[2] : aShape[1];
int64_t k = transA ? aShape[1] : aShape[2];
FUNCTION("testMSplit", {tensor_a, tensor_b}, {tensor_c})
{
LOOP(
"mLoop", FunctionType::DYNAMIC_LOOP, mIdx,
LoopRange(0, CeilDivSymbolicScalar(transA ? aShape[BMM_SHAPE_N_DIM] : aShape[1], viewShape[1]), 1))
{
Tensor dyn_a;
if (transA) {
dyn_a = View(
tensor_a, {aShape[0], aShape[1], viewShape[1]},
{aShape[0], aShape[1], std::min(aShape[2] - viewShape[1] * mIdx, viewShape[1])},
{0, 0, mIdx * viewShape[1]});
} else {
dyn_a = View(
tensor_a, {aShape[0], viewShape[1], aShape[2]},
{aShape[0], std::min(aShape[1] - viewShape[1] * mIdx, viewShape[1]), aShape[2]},
{0, mIdx * viewShape[1], 0});
}
Tensor dyn_b = View(tensor_b, bShape, bValidShape, {0, 0, 0});
TileShape::Current().SetMatrixSize({m, k, n});
Tensor res = Matrix::BatchMatmul(GetAstDtype<outputDtype>(), dyn_a, dyn_b, transA, transB, isCNz);
Assemble(res, {0, mIdx * viewShape[0], 0}, tensor_c);
}
}
}
template <typename outputDtype, bool transA, bool transB, bool isCNz>
static void NSplitFunc(
const std::vector<int64_t>& viewShape, const Tensor& tensor_a, const Tensor& tensor_b, Tensor& tensor_c)
{
const auto& aShape = tensor_a.GetShape();
std::vector<SymbolicScalar> aValidShape = {aShape[0], aShape[1], aShape[2]};
const auto& bShape = tensor_b.GetShape();
std::vector<SymbolicScalar> bValidShape = {bShape[0], bShape[1], bShape[2]};
int64_t m = transA ? aShape[2] : aShape[1];
int64_t k = transA ? aShape[1] : aShape[2];
int64_t n = transB ? bShape[1] : bShape[2];
FUNCTION("testNSplit", {tensor_a, tensor_b}, {tensor_c})
{
LOOP(
"nLoop", FunctionType::DYNAMIC_LOOP, nIdx,
LoopRange(
0, CeilDivSymbolicScalar(transB ? bShape[1] : bShape[BMM_SHAPE_N_DIM], viewShape[BMM_SHAPE_N_DIM]), 1))
{
Tensor dyn_a = View(tensor_a, aShape, aValidShape, {0, 0, 0});
Tensor dyn_b;
if (!transB) {
dyn_b = View(
tensor_b, {bShape[0], bShape[1], viewShape[2]},
{bShape[0], bShape[1], std::min(bShape[2] - viewShape[2] * nIdx, viewShape[2])},
{0, 0, nIdx * viewShape[2]});
} else {
dyn_b = View(
tensor_b, {bShape[0], viewShape[2], bShape[2]},
{bShape[0], std::min(bShape[1] - viewShape[2] * nIdx, viewShape[2]), bShape[2]},
{0, nIdx * viewShape[2], 0});
}
TileShape::Current().SetMatrixSize({m, k, n});
Tensor res = Matrix::BatchMatmul(GetAstDtype<outputDtype>(), dyn_a, dyn_b, transA, transB, isCNz);
Assemble(res, {0, 0, nIdx * viewShape[2]}, tensor_c);
}
}
}
template <typename outputDtype, bool transA, bool transB, bool isCNz>
static void MNSplitFunc(
const std::vector<int64_t>& viewShape, const Tensor& tensor_a, const Tensor& tensor_b, Tensor& tensor_c)
{
const auto& aShape = tensor_a.GetShape();
std::vector<SymbolicScalar> aValidShape = {aShape[0], aShape[1], aShape[2]};
const auto& bShape = tensor_b.GetShape();
std::vector<SymbolicScalar> bValidShape = {bShape[0], bShape[1], bShape[2]};
int64_t m = transA ? aShape[2] : aShape[1];
int64_t n = transB ? bShape[1] : bShape[2];
int64_t k = transA ? aShape[1] : aShape[2];
FUNCTION("testNSplit", {tensor_a, tensor_b}, {tensor_c})
{
LOOP(
"mLoop", FunctionType::DYNAMIC_LOOP, mIdx,
LoopRange(0, CeilDivSymbolicScalar(transA ? aShape[BMM_SHAPE_N_DIM] : aShape[1], viewShape[1]), 1))
{
LOOP(
"nLoop", FunctionType::DYNAMIC_LOOP, nIdx,
LoopRange(
0, CeilDivSymbolicScalar(transB ? bShape[1] : bShape[BMM_SHAPE_N_DIM], viewShape[BMM_SHAPE_N_DIM]),
1))
{
Tensor dyn_a;
if (transA) {
dyn_a = View(
tensor_a, {aShape[0], aShape[1], viewShape[1]},
{aShape[0], aShape[1], std::min(aShape[2] - viewShape[1] * mIdx, viewShape[1])},
{0, 0, mIdx * viewShape[1]});
} else {
dyn_a = View(
tensor_a, {aShape[0], viewShape[1], aShape[2]},
{aShape[0], std::min(aShape[1] - viewShape[1] * mIdx, viewShape[1]), aShape[2]},
{0, mIdx * viewShape[1], 0});
}
Tensor dyn_b;
if (!transB) {
dyn_b = View(
tensor_b, {bShape[0], bShape[1], viewShape[2]},
{bShape[0], bShape[1], std::min(bShape[2] - viewShape[2] * nIdx, viewShape[2])},
{0, 0, nIdx * viewShape[2]});
} else {
dyn_b = View(
tensor_b, {bShape[0], viewShape[2], bShape[2]},
{bShape[0], std::min(bShape[1] - viewShape[2] * nIdx, viewShape[2]), bShape[2]},
{0, nIdx * viewShape[2], 0});
}
TileShape::Current().SetMatrixSize({m, k, n});
Tensor res = Matrix::BatchMatmul(GetAstDtype<outputDtype>(), dyn_a, dyn_b, transA, transB, isCNz);
Assemble(res, {0, mIdx * viewShape[0], nIdx * viewShape[1]}, tensor_c);
}
}
}
}
template <typename inputDtype, typename outputDtype, bool transA, bool transB, bool isCNz>
void TestDynBatchMatmul(
const std::vector<int64_t>& mmShape, bool isANz, bool isBNz, const std::vector<int64_t>& viewShape, string dataPath)
{
SetInterpreterConfig();
if (mmShape.size() != BMM_SHAPE_SIZE || viewShape.size() != BMM_VIEW_SHAPE_SIZE) {
return;
}
int64_t b = mmShape[0];
int64_t m = mmShape[1];
int64_t k = mmShape[2];
int64_t n = mmShape[BMM_SHAPE_N_IDX];
Tensor tensor_a = transA ? constructMatmulTensor<inputDtype>({b, k, m}, "tensor_a", isANz) :
constructMatmulTensor<inputDtype>({b, m, k}, "tensor_a", isANz);
Tensor tensor_b = transB ? constructMatmulTensor<inputDtype>({b, n, k}, "tensor_b", isBNz) :
constructMatmulTensor<inputDtype>({b, k, n}, "tensor_b", isBNz);
Tensor tensor_c = constructMatmulTensor<outputDtype>({b, m, n}, "tensor_c", isCNz);
std::vector<inputDtype> aData(b * m * k, 0);
std::vector<inputDtype> bData(b * k * n, 0);
std::vector<outputDtype> golden(b * m * n, 0);
readInput<outputDtype>(dataPath + "/mat_c.bin", golden);
readInput<inputDtype>(dataPath + "/mat_a.bin", aData);
readInput<inputDtype>(dataPath + "/mat_b.bin", bData);
ProgramData::GetInstance().AppendOutputs({
RawTensorData::CreateConstantTensor<outputDtype>(tensor_c, 0.0f),
});
ProgramData::GetInstance().AppendInputs({
RawTensorData::CreateTensor<inputDtype>(tensor_a, aData),
RawTensorData::CreateTensor<inputDtype>(tensor_b, bData),
});
ProgramData::GetInstance().AppendGoldens({
RawTensorData::CreateTensor<outputDtype>(tensor_c, golden),
});
int64_t viewM = viewShape[1];
int64_t viewN = viewShape[2];
if (viewM > 0 && viewN > 0) {
MNSplitFunc<outputDtype, transA, transB, isCNz>(viewShape, tensor_a, tensor_b, tensor_c);
} else if (viewM > 0) {
MSplitFunc<outputDtype, transA, transB, isCNz>(viewShape, tensor_a, tensor_b, tensor_c);
} else if (viewN > 0) {
NSplitFunc<outputDtype, transA, transB, isCNz>(viewShape, tensor_a, tensor_b, tensor_c);
} else {
NonSplitFunc<outputDtype, transA, transB, isCNz>(tensor_a, tensor_b, tensor_c);
}
DevFuncRunner::Run(Program::GetInstance().GetLastFunction());
auto outs = npu::tile_fwk::ProgramData::GetInstance().GetOutputData(0);
EXPECT_TRUE(resultCmp(golden, (outputDtype*)outs->data(), 0.001f));
}
TEST_F(DynamicBatchMatmulTest, test_bmm_A_Bt_ND_fp16)
{
TileShape::Current().SetCubeTile({128, 128}, {128, 128}, {128, 128});
int64_t b = 3;
int64_t m = 2;
int64_t k = 576;
int64_t n = 4096;
bool isANz = false;
bool isBNz = false;
std::vector<int64_t> viewShape = {-1, -1, -1};
TestDynBatchMatmul<npu::tile_fwk::float16, float, false, true, false>(
{b, m, k, n}, isANz, isBNz, viewShape, GetGoldenDir());
}
TEST_F(DynamicBatchMatmulTest, test_bmm_A_Bt_NZ_fp16)
{
TileShape::Current().SetCubeTile({128, 128}, {128, 128}, {128, 128});
int64_t b = 4;
int64_t m = 96;
int64_t k = 128;
int64_t n = 256;
bool isANz = false;
bool isBNz = true;
std::vector<int64_t> viewShape = {-1, -1, -1};
TestDynBatchMatmul<npu::tile_fwk::float16, float, false, true, false>(
{b, m, k, n}, isANz, isBNz, viewShape, GetGoldenDir());
}
TEST_F(DynamicBatchMatmulTest, test_bmm_A_B_ND_bf16_tile1)
{
TileShape::Current().SetCubeTile({128, 128}, {256, 256}, {128, 128});
int64_t b = 6;
int64_t m = 1;
int64_t k = 576;
int64_t n = 4096;
bool isANz = false;
bool isBNz = false;
std::vector<int64_t> viewShape = {-1, -1, -1};
TestDynBatchMatmul<npu::tile_fwk::bfloat16, float, false, false, false>(
{b, m, k, n}, isANz, isBNz, viewShape, GetGoldenDir());
}
TEST_F(DynamicBatchMatmulTest, test_bmm_At_Bt_ND_fp16)
{
TileShape::Current().SetCubeTile({128, 128}, {128, 128}, {128, 128});
int64_t b = 3;
int64_t m = 2;
int64_t k = 576;
int64_t n = 4096;
bool isANz = false;
bool isBNz = false;
std::vector<int64_t> viewShape = {-1, -1, -1};
TestDynBatchMatmul<npu::tile_fwk::float16, float, true, true, false>(
{b, m, k, n}, isANz, isBNz, viewShape, GetGoldenDir());
}
TEST_F(DynamicBatchMatmulTest, test_bmm_At_Bt_ANZ_BND_fp16)
{
TileShape::Current().SetCubeTile({128, 128}, {128, 128}, {128, 128});
int64_t b = 3;
int64_t m = 128;
int64_t k = 256;
int64_t n = 512;
bool isANz = true;
bool isBNz = false;
std::vector<int64_t> viewShape = {-1, -1, -1};
TestDynBatchMatmul<npu::tile_fwk::float16, float, true, true, false>(
{b, m, k, n}, isANz, isBNz, viewShape, GetGoldenDir());
}
}