/**
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/**
* \file fill.asc
* \brief 本样例展示如何使用Fill接口对L0A Buffer和L0B Buffer进行初始化
*/
#include "acl/acl.h"
#include "kernel_operator.h"
#include "data_utils.h"
#ifdef ASCENDC_CPU_DEBUG
#include "cpu_debug_launch.h"
#endif
/**
 * Single-tile matrix-multiply sample whose A and B operands are generated on chip with Fill
 * instead of being copied from global memory.
 *
 * T: element type of A (A1/A2 tensors); U: element type of B (B1/B2 tensors);
 * S: element type of the CO1 result and of the global-memory output.
 * M, N, K: problem shape -- A is M x K, B is K x N, C is M x N.
 *
 * Every operand element is filled with 1, so each element of C should equal K.
 */
template <typename T, typename U, typename S, int32_t M, int32_t N, int32_t K>
class KernelFill {
public:
    __aicore__ inline KernelFill() {}

    /**
     * Binds the global-memory buffers.
     * Only cGlobal is used by Process(); aGlobal/bGlobal are bound but never read because the
     * operands are produced by Fill.
     */
    __aicore__ inline void Init(__gm__ uint8_t* aGm, __gm__ uint8_t* bGm, __gm__ uint8_t* cGm)
    {
        aGlobal.SetGlobalBuffer((__gm__ T*)aGm);
        bGlobal.SetGlobalBuffer((__gm__ U*)bGm);
        cGlobal.SetGlobalBuffer((__gm__ S*)cGm);
    }

    // L0A initialization: fill all M*K elements of A2 with 1.
    // blockNum is expressed in 512-byte units here (divisor used for A2/B2).
    __aicore__ inline void InitConstA2(AscendC::LocalTensor<T>& a2Local)
    {
        AscendC::InitConstValueParams<T> fillParams(
            1, static_cast<uint16_t>(M * K * sizeof(T) / 512), 0, static_cast<T>(1));
        AscendC::Fill(a2Local, fillParams);
    }

    // L0B initialization: fill all K*N elements of B2 with 1 (512-byte units, as for A2).
    __aicore__ inline void InitConstB2(AscendC::LocalTensor<U>& b2Local)
    {
        AscendC::InitConstValueParams<U> fillParams(
            1, static_cast<uint16_t>(K * N * sizeof(U) / 512), 0, static_cast<U>(1));
        AscendC::Fill(b2Local, fillParams);
    }

    // A1 (L1) initialization: fill all M*K elements with 1; blockNum is in 32-byte units here.
    // Uses the same explicitly typed params as InitConstA2 so the fill value has type T.
    __aicore__ inline void InitConstA1(AscendC::LocalTensor<T>& a1Local)
    {
        AscendC::InitConstValueParams<T> fillParams(
            1, static_cast<uint16_t>(M * K * sizeof(T) / 32), 0, static_cast<T>(1));
        AscendC::Fill(a1Local, fillParams);
    }

    // B1 (L1) initialization: fill all K*N elements with 1 (32-byte units, as for A1).
    __aicore__ inline void InitConstB1(AscendC::LocalTensor<U>& b1Local)
    {
        AscendC::InitConstValueParams<U> fillParams(
            1, static_cast<uint16_t>(K * N * sizeof(U) / 32), 0, static_cast<U>(1));
        AscendC::Fill(b1Local, fillParams);
    }

    // Moves A from A1 to A2 with a non-transposing 2D LoadData.
    // Steps/strides are expressed in 16-row and 32-byte granules (see the DivCeil divisors).
    __aicore__ inline void Load2DA1ToA2(AscendC::LocalTensor<T>& a1Local, AscendC::LocalTensor<T>& a2Local)
    {
        AscendC::LoadData2DParamsV2 loadDataParams;
        loadDataParams.mStartPosition = 0;
        loadDataParams.kStartPosition = 0;
        loadDataParams.mStep = AscendC::DivCeil(M, 16);
        loadDataParams.kStep = AscendC::DivCeil(K * sizeof(T), 32);
        loadDataParams.srcStride = AscendC::DivCeil(M, 16);
        loadDataParams.dstStride = AscendC::DivCeil(M, 16);
        loadDataParams.sid = 0;
        loadDataParams.ifTranspose = false;
        AscendC::LoadData(a2Local, a1Local, loadDataParams);
    }

    // Moves B from B1 to B2 with a transposing 2D LoadData; the K axis is passed as the "m"
    // step and the N axis as the "k" step.
    __aicore__ inline void Load2DB1ToB2(AscendC::LocalTensor<U>& b1Local, AscendC::LocalTensor<U>& b2Local)
    {
        uint16_t nAlign = AscendC::DivCeil(N * sizeof(U), 32);
        uint16_t kAlign = AscendC::DivCeil(K, 16);
        AscendC::LoadData2DParamsV2 loadDataParams;
        loadDataParams.mStartPosition = 0;
        loadDataParams.kStartPosition = 0;
        loadDataParams.mStep = kAlign;
        loadDataParams.kStep = nAlign;
        loadDataParams.srcStride = kAlign;
        loadDataParams.dstStride = nAlign;
        loadDataParams.sid = 0;
        loadDataParams.ifTranspose = true;
        AscendC::LoadData(b2Local, b1Local, loadDataParams);
    }

    // Runs one M x N x K Mmad from A2/B2 into CO1, without bias.
    __aicore__ inline void Compute(
        AscendC::LocalTensor<T>& a2Local, AscendC::LocalTensor<U>& b2Local, AscendC::LocalTensor<S>& co1Local)
    {
        AscendC::MmadParams mmadParams;
        mmadParams.m = M;
        mmadParams.n = N;
        mmadParams.k = K;
        mmadParams.isBias = false;
        AscendC::Mmad(co1Local, a2Local, b2Local, mmadParams);
    }

    // Writes the M x N result from CO1 to global memory in row-major order, without quantization.
    __aicore__ inline void CopyL0CToGm(AscendC::LocalTensor<S>& co1Local)
    {
        AscendC::FixpipeParamsV220 fixpipeParams;
        fixpipeParams.nSize = N;
        fixpipeParams.mSize = M;
        fixpipeParams.srcStride = M;
        fixpipeParams.dstStride = N;
        // Single ND matrix; the Nd strides below presumably have no effect while ndNum == 1.
        fixpipeParams.ndNum = 1;
        fixpipeParams.srcNdStride = 2;
        fixpipeParams.dstNdStride = M * N;
        fixpipeParams.quantPre = QuantMode_t::NoQuant;
        AscendC::Fixpipe<S, S, AscendC::CFG_ROW_MAJOR>(cGlobal, co1Local, fixpipeParams);
    }

    // Initializes A2/B2 (directly or via L1, depending on the architecture), runs Mmad, and
    // copies the result to global memory.
    __aicore__ inline void Process()
    {
        AscendC::LocalTensor<T> a2Local(AscendC::TPosition::A2, a2Addr, M * K);
        AscendC::LocalTensor<U> b2Local(AscendC::TPosition::B2, b2Addr, N * K);
        AscendC::LocalTensor<S> co1Local(AscendC::TPosition::CO1, co1Addr, M * N);
#if defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201)
        // Atlas A2/A3 training/inference products: initialize L0A and L0B directly with Fill.
        InitConstA2(a2Local);
        InitConstB2(b2Local);
#elif defined(__NPU_ARCH__) && (__NPU_ARCH__ == 3510)
        AscendC::LocalTensor<T> a1Local(AscendC::TPosition::A1, a1Addr, M * K);
        AscendC::LocalTensor<U> b1Local(AscendC::TPosition::B1, b1Addr, K * N);
        // Ascend 950PR/Ascend 950DT cannot initialize L0A/L0B directly, so Fill initializes L1
        // and LoadData then moves the data into L0A and L0B.
        InitConstA1(a1Local);
        InitConstB1(b1Local);
        // NOTE(review): confirm Fill to A1/B1 is issued on the MTE1 pipe; if it is not, a
        // cross-pipe SetFlag/WaitFlag is needed here instead of a same-pipe barrier.
        AscendC::PipeBarrier<PIPE_MTE1>();
        Load2DA1ToA2(a1Local, a2Local);
        Load2DB1ToB2(b1Local, b2Local);
#endif
        // Operands must be in place before the cube unit starts.
        AscendC::SetFlag<AscendC::HardEvent::MTE1_M>(EVENT_ID0);
        AscendC::WaitFlag<AscendC::HardEvent::MTE1_M>(EVENT_ID0);
        // Matrix multiply via Mmad.
        Compute(a2Local, b2Local, co1Local);
        // Mmad must finish writing CO1 before Fixpipe reads it.
        AscendC::SetFlag<AscendC::HardEvent::M_FIX>(EVENT_ID0);
        AscendC::WaitFlag<AscendC::HardEvent::M_FIX>(EVENT_ID0);
        CopyL0CToGm(co1Local);
    }

private:
    // A1 and B1 are both carved out of L1, so B1 is placed directly after A1.
    // (Placing both at offset 0 made InitConstB1 overwrite A1, which only went unnoticed
    // because both operands are filled with the same value.)
    uint64_t a1Addr = 0;
    uint64_t b1Addr = a1Addr + static_cast<uint64_t>(M) * K * sizeof(T);
    // Offsets within the A2/B2/CO1 positions, laid out sequentially.
    uint64_t a2Addr = 0;
    uint64_t b2Addr = a2Addr + static_cast<uint64_t>(M) * K * sizeof(T);
    uint64_t co1Addr = b2Addr + static_cast<uint64_t>(N) * K * sizeof(U);
    AscendC::GlobalTensor<T> aGlobal;
    AscendC::GlobalTensor<U> bGlobal;
    AscendC::GlobalTensor<S> cGlobal;
};
template <int32_t M, int32_t N, int32_t K>
__global__ __cube__ void fill(__gm__ uint8_t* x, __gm__ uint8_t* y, __gm__ uint8_t* z)
{
    // Kernel entry point: half-precision A/B operands with a float accumulator/output.
    using FillKernel = KernelFill<half, half, float, M, N, K>;

    AscendC::InitSocState();

    FillKernel kernel;
    kernel.Init(x, y, z);
    kernel.Process();

    // Block until every pipe has drained before the kernel returns.
    AscendC::PipeBarrier<PIPE_ALL>();
}
int32_t main(int32_t argc, char* argv[])
{
    // Launch configuration and problem shape (A: M x K, B: K x N, C: M x N).
    constexpr uint32_t NUM_BLOCKS = 1;
    constexpr int32_t M = 128;
    constexpr int32_t N = 64;
    constexpr int32_t K = 128;

    // Runtime bring-up: initialize ACL, bind device 0, create a stream.
    aclInit(nullptr);
    int32_t deviceId = 0;
    aclrtSetDevice(deviceId);
    aclrtStream stream = nullptr;
    aclrtCreateStream(&stream);

    // Byte sizes: A and B hold half, C holds float.
    // Kept mutable because they are handed to ReadFile exactly as before.
    size_t aBytes = M * K * sizeof(half);
    size_t bBytes = K * N * sizeof(half);
    size_t cBytes = M * N * sizeof(float);

    // Host-side staging buffers.
    uint8_t* aHost = nullptr;
    uint8_t* bHost = nullptr;
    uint8_t* cHost = nullptr;
    aclrtMallocHost(reinterpret_cast<void**>(&aHost), aBytes);
    aclrtMallocHost(reinterpret_cast<void**>(&bHost), bBytes);
    aclrtMallocHost(reinterpret_cast<void**>(&cHost), cBytes);

    // Device-side buffers.
    uint8_t* aDevice = nullptr;
    uint8_t* bDevice = nullptr;
    uint8_t* cDevice = nullptr;
    aclrtMalloc(reinterpret_cast<void**>(&aDevice), aBytes, ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMalloc(reinterpret_cast<void**>(&bDevice), bBytes, ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMalloc(reinterpret_cast<void**>(&cDevice), cBytes, ACL_MEM_MALLOC_HUGE_FIRST);

    // Stage the input files and upload them to the device.
    ReadFile("./input/input_x.bin", aBytes, aHost, aBytes);
    ReadFile("./input/input_y.bin", bBytes, bHost, bBytes);
    aclrtMemcpy(aDevice, aBytes, aHost, aBytes, ACL_MEMCPY_HOST_TO_DEVICE);
    aclrtMemcpy(bDevice, bBytes, bHost, bBytes, ACL_MEMCPY_HOST_TO_DEVICE);

    // Run the kernel and wait for it to finish.
    fill<M, N, K><<<NUM_BLOCKS, nullptr, stream>>>(aDevice, bDevice, cDevice);
    aclrtSynchronizeStream(stream);

    // Download the result and write it out.
    aclrtMemcpy(cHost, cBytes, cDevice, cBytes, ACL_MEMCPY_DEVICE_TO_HOST);
    WriteFile("./output/output.bin", cHost, cBytes);

    // Release device buffers, then host buffers (same order as allocation), then the runtime.
    uint8_t* deviceBuffers[] = {aDevice, bDevice, cDevice};
    for (uint8_t* buffer : deviceBuffers) {
        aclrtFree(buffer);
    }
    uint8_t* hostBuffers[] = {aHost, bHost, cHost};
    for (uint8_t* buffer : hostBuffers) {
        aclrtFreeHost(buffer);
    }
    aclrtDestroyStream(stream);
    aclrtResetDevice(deviceId);
    aclFinalize();
    return 0;
}