/**
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/* !
* \file assert.asc
* \brief 使用静态Tensor编程模式实现矩阵乘法,展示ascendc_assert接口的基本使用方法
*/
#include <cstdio>

#include "acl/acl.h"
#include "kernel_operator.h"
#include "data_utils.h"
// Edge length, in elements, of one half-precision cube fractal (a fractal is CUBE_BLOCK x CUBE_BLOCK = 16 x 16).
constexpr uint32_t CUBE_BLOCK = 16U;
/**
 * Half-precision matmul kernel that demonstrates ascendc_assert.
 *
 * Each block (core) computes one singleCoreM x N slice of C = A * B. A is M x K and B is K x N, both read
 * from GM in ND layout (row stride K for A, N for B). C is written back to GM with row stride N.
 * Data path: GM -> L1 (ND to NZ via DataCopy) -> L0A/L0B (LoadData) -> Mmad into L0C (float)
 * -> Fixpipe to GM with F32 -> F16 conversion.
 *
 * The kernel issues exactly one copy/load/Mmad/Fixpipe sequence per core and never loops over M, K or N.
 * It is therefore only correct when singleCoreM == baseM, K == baseK and N == baseN. Those requirements
 * are asserted together with the block-range and alignment checks.
 *
 * @tparam M, K, N              Global matmul shape.
 * @tparam singleCoreM          Rows of A/C handled by one core.
 * @tparam baseM, baseK, baseN  Tile shape of the single Mmad call; each must be a multiple of CUBE_BLOCK.
 * @param a  GM address of A (half elements).
 * @param b  GM address of B (half elements).
 * @param c  GM address of C (half elements), written by Fixpipe.
 */
template <uint32_t M, uint32_t K, uint32_t N, uint32_t singleCoreM, uint32_t baseM, uint32_t baseK, uint32_t baseN>
__global__ __cube__ void mmad_custom(__gm__ uint8_t* a, __gm__ uint8_t* b, __gm__ uint8_t* c)
{
    AscendC::GlobalTensor<half> aGM;
    AscendC::GlobalTensor<half> bGM;
    AscendC::GlobalTensor<half> cGM;
    // Index of the singleCoreM-row slice of A and C owned by this block.
    uint32_t mIterIdx = AscendC::GetBlockIdx() % (M / singleCoreM);
    aGM.SetGlobalBuffer((__gm__ half*)a + mIterIdx * singleCoreM * K);
    bGM.SetGlobalBuffer((__gm__ half*)b);
    cGM.SetGlobalBuffer((__gm__ half*)c + mIterIdx * singleCoreM * N);
    // Assertion: blockIdx must be within the valid range.
    ascendc_assert(
        AscendC::GetBlockIdx() < M / singleCoreM, "BlockIdx %u is out of valid range.\n", AscendC::GetBlockIdx());
    // Assertion: tile sizes must be aligned to the cube fractal size.
    ascendc_assert(baseM % CUBE_BLOCK == 0, "baseM %u must be aligned to CUBE_BLOCK %u.\n", baseM, CUBE_BLOCK);
    ascendc_assert(baseK % CUBE_BLOCK == 0, "baseK %u must be aligned to CUBE_BLOCK %u.\n", baseK, CUBE_BLOCK);
    ascendc_assert(baseN % CUBE_BLOCK == 0, "baseN %u must be aligned to CUBE_BLOCK %u.\n", baseN, CUBE_BLOCK);
    // Assertion: M must split evenly across the blocks.
    ascendc_assert(
        M == singleCoreM * (M / singleCoreM), "M %u should be divisible by singleCoreM %u.\n", M, singleCoreM);
    // Assertion: only one base tile is computed per block (there are no M/K/N loops below). The per-block
    // workload must therefore equal the tile shape; otherwise C is only partially computed, or accumulated
    // over a truncated K, without any error.
    ascendc_assert(singleCoreM == baseM, "singleCoreM %u must equal baseM %u.\n", singleCoreM, baseM);
    ascendc_assert(K == baseK, "K %u must equal baseK %u.\n", K, baseK);
    ascendc_assert(N == baseN, "N %u must equal baseN %u.\n", N, baseN);
    AscendC::LocalMemAllocator<AscendC::Hardware::L1> l1Allocator;
    AscendC::LocalMemAllocator<AscendC::Hardware::L0A> l0aAllocator;
    AscendC::LocalMemAllocator<AscendC::Hardware::L0B> l0bAllocator;
    AscendC::LocalMemAllocator<AscendC::Hardware::L0C> l0cAllocator;
    // LocalMemAllocator hands out on-chip memory in request order, so LocalTensor address offsets do not
    // have to be maintained by hand.
    AscendC::LocalTensor<half> a1Local = l1Allocator.Alloc<AscendC::TPosition::A1, half>(baseM * baseK);
    AscendC::LocalTensor<half> b1Local = l1Allocator.Alloc<AscendC::TPosition::B1, half>(baseK * baseN);
    AscendC::LocalTensor<half> a2Local = l0aAllocator.Alloc<AscendC::TPosition::A2, half>(baseM * baseK);
    AscendC::LocalTensor<half> b2Local = l0bAllocator.Alloc<AscendC::TPosition::B2, half>(baseK * baseN);
    AscendC::LocalTensor<float> cLocal = l0cAllocator.Alloc<AscendC::TPosition::CO1, float>(baseM * baseN);
    // Assertion: the L0A/L0B tiles must fit in a 64 KiB buffer.
    ascendc_assert(baseM * baseK * sizeof(half) <= 64 * 1024, "a2 LocalTensor size exceeds L0A buffer limit.\n");
    ascendc_assert(baseK * baseN * sizeof(half) <= 64 * 1024, "b2 LocalTensor size exceeds L0B buffer limit.\n");
    // A: GM -> L1 via DataCopy with Nd2NzParams {ndNum, nValue, dValue, srcNdMatrixStride, srcDValue,
    // dstNzC0Stride, dstNzNStride, dstNzMatrixStride}.
    AscendC::Nd2NzParams nd2nzAParams = {1, baseM, baseK, 0, K, baseM, 1, 0};
    AscendC::DataCopy(a1Local, aGM, nd2nzAParams);
    // B: GM -> L1 via DataCopy with Nd2NzParams {ndNum, nValue, dValue, srcNdMatrixStride, srcDValue,
    // dstNzC0Stride, dstNzNStride, dstNzMatrixStride}.
    AscendC::Nd2NzParams nd2nzBParams = {1, baseK, baseN, 0, N, baseK, 1, 0};
    AscendC::DataCopy(b1Local, bGM, nd2nzBParams);
    AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE1>(EVENT_ID0);
    AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE1>(EVENT_ID0);
#if defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201)
    // A: L1 -> L0A (Nz -> Zz) via LoadData with LoadData2DParams {startIndex, repeatTimes, srcStride, sid,
    // dstGap, ifTranspose, addrMode}. Each iteration loads one CUBE_BLOCK-row strip of fractals;
    // 512 bytes is one 16 x 16 half fractal.
    AscendC::LoadData2DParams loadDataAParams = {0, baseK / CUBE_BLOCK, baseM / CUBE_BLOCK, 0, 0, false, 0};
    for (uint32_t i = 0; i < baseM / CUBE_BLOCK; ++i) {
        AscendC::LoadData(a2Local[i * baseK * CUBE_BLOCK], a1Local[i * 512 / sizeof(half)], loadDataAParams);
    }
#elif defined(__NPU_ARCH__) && (__NPU_ARCH__ == 3510)
    // A: L1 -> L0A (Nz -> Nz).
    AscendC::LoadData2DParams loadDataAParams = {
        0, baseK / CUBE_BLOCK, baseM / CUBE_BLOCK, 0, (baseM / CUBE_BLOCK - 1), false, 0};
    for (uint32_t i = 0; i < baseM / CUBE_BLOCK; ++i) {
        AscendC::LoadData(a2Local[i * 512 / sizeof(half)], a1Local[i * 512 / sizeof(half)], loadDataAParams);
    }
#endif
    // B: L1 -> L0B via LoadData with LoadData2DParams {startIndex, repeatTimes, srcStride, sid, dstGap,
    // ifTranspose, addrMode}; ifTranspose is set.
    AscendC::LoadData2DParams loadDataBParams = {0, baseN / CUBE_BLOCK, baseK / CUBE_BLOCK, 0, 0, true, 0};
    for (uint32_t i = 0; i < baseK / CUBE_BLOCK; ++i) {
        AscendC::LoadData(b2Local[i * baseN * CUBE_BLOCK], b1Local[i * 512 / sizeof(half)], loadDataBParams);
    }
    AscendC::SetFlag<AscendC::HardEvent::MTE1_M>(EVENT_ID0);
    AscendC::WaitFlag<AscendC::HardEvent::MTE1_M>(EVENT_ID0);
    // Matrix multiply via Mmad with MmadParams {m, n, k, unitFlag, cmatrixSource, cmatrixInitVal}.
    AscendC::MmadParams mmadParams = {baseM, baseN, baseK, 0, false, true};
    AscendC::Mmad(cLocal, a2Local, b2Local, mmadParams);
    AscendC::SetFlag<AscendC::HardEvent::M_FIX>(EVENT_ID0);
    AscendC::WaitFlag<AscendC::HardEvent::M_FIX>(EVENT_ID0);
    // L0C -> GM via Fixpipe with FixpipeParamsV220 {nSize, mSize, srcStride, dstStride, reluEn, quantPre,
    // deqScalar, ndNum, srcNdStride, dstNdStride, unitFlag}; F322F16 converts the float result to half.
    AscendC::FixpipeParamsV220 fixpipeParams = {baseN, baseM, baseM, N, false, QuantMode_t::F322F16, 0, 1, 0, 0, 0};
    AscendC::Fixpipe(cGM, cLocal, fixpipeParams);
    AscendC::PipeBarrier<PIPE_ALL>();
}
int32_t main(int32_t argc, char* argv[])
{
// matmul参数
constexpr uint32_t M = 256;
constexpr uint32_t K = 64;
constexpr uint32_t N = 256;
// 单核计算量
constexpr uint32_t singleCoreM = 128;
// 矩阵乘执行单元Tile参数
constexpr uint32_t baseM = 128;
constexpr uint32_t baseK = 64;
constexpr uint32_t baseN = 256;
size_t aFileSize = M * K * sizeof(half);
size_t bFileSize = K * N * sizeof(half);
size_t cFileSize = M * N * sizeof(half);
uint32_t numBlocks = 2;
aclInit(nullptr);
int32_t deviceId = 0;
aclrtSetDevice(deviceId);
aclrtStream stream = nullptr;
aclrtCreateStream(&stream);
uint8_t* aHost;
uint8_t* aDevice;
aclrtMallocHost((void**)(&aHost), aFileSize);
aclrtMalloc((void**)&aDevice, aFileSize, ACL_MEM_MALLOC_HUGE_FIRST);
ReadFile("./input/x1_gm.bin", aFileSize, aHost, aFileSize);
aclrtMemcpy(aDevice, aFileSize, aHost, aFileSize, ACL_MEMCPY_HOST_TO_DEVICE);
uint8_t* bHost;
uint8_t* bDevice;
aclrtMallocHost((void**)(&bHost), bFileSize);
aclrtMalloc((void**)&bDevice, bFileSize, ACL_MEM_MALLOC_HUGE_FIRST);
ReadFile("./input/x2_gm.bin", bFileSize, bHost, bFileSize);
aclrtMemcpy(bDevice, bFileSize, bHost, bFileSize, ACL_MEMCPY_HOST_TO_DEVICE);
uint8_t* cHost;
uint8_t* cDevice;
aclrtMallocHost((void**)(&cHost), cFileSize);
aclrtMalloc((void**)&cDevice, cFileSize, ACL_MEM_MALLOC_HUGE_FIRST);
mmad_custom<M, K, N, singleCoreM, baseM, baseK, baseN><<<numBlocks, nullptr, stream>>>(aDevice, bDevice, cDevice);
aclrtSynchronizeStream(stream);
aclrtMemcpy(cHost, cFileSize, cDevice, cFileSize, ACL_MEMCPY_DEVICE_TO_HOST);
WriteFile("./output/output.bin", cHost, cFileSize);
aclrtFree(aDevice);
aclrtFreeHost(aHost);
aclrtFree(bDevice);
aclrtFreeHost(bHost);
aclrtFree(cDevice);
aclrtFreeHost(cHost);
aclrtDestroyStream(stream);
aclrtResetDevice(deviceId);
aclFinalize();
return 0;
}