/**
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/* !
* \file clock.asc
* \brief 基于Gather和Adds混合编程样例,展示在SIMD&SIMT混编场景中使用clock接口统计执行周期。
*/
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <vector>
#include "acl/acl.h"
#include "kernel_operator.h"
#include "simt_api/asc_simt.h"
namespace {
// Number of SIMT threads: used as the __launch_bounds__ value of simt_gather and as the dim3
// passed to asc_vf_call when gather_and_adds_kernel launches it.
constexpr uint32_t THREAD_COUNT = 1024;
} // namespace
/**
 * SIMT vector function: gathers `input[index[i]]` from global memory into this block's unified
 * buffer and reports how many cycles the gather took.
 *
 * Each block handles the contiguous slice
 * `[blockIdx.x * outputTotalLength, (blockIdx.x + 1) * outputTotalLength)` of `index`, which matches
 * the per-block offset the calling kernel uses when it copies the result back to global memory.
 * Threads walk that slice with a stride of `blockDim.x`, so the slice may be shorter or longer than
 * the number of launched threads.
 *
 * @param input             Source values in global memory.
 * @param index             Gather indices in global memory.
 * @param gatherOutput      Per-block destination in the unified buffer; element `k` receives the
 *                          value selected by the k-th index of this block's slice.
 * @param inputTotalLength  Number of elements in `input`; indices at or beyond it are skipped and
 *                          the matching `gatherOutput` element is left unwritten.
 * @param indexTotalLength  Number of elements in `index` across all blocks.
 * @param outputTotalLength Number of elements this block has to produce.
 */
template <typename T>
__simt_vf__ __launch_bounds__(THREAD_COUNT) inline void simt_gather(
    __gm__ T* input, __gm__ uint32_t* index, __ubuf__ T* gatherOutput, uint32_t inputTotalLength,
    uint32_t indexTotalLength, uint32_t outputTotalLength)
{
    // Timestamp taken before the gather work starts.
    uint64_t start = clock();

    // First element of `index` owned by this block. Using the per-block length (not blockDim.x)
    // keeps the gather consistent with the caller's output layout for any slice size.
    const uint32_t blockBase = blockIdx.x * outputTotalLength;

    for (uint32_t local = threadIdx.x; local < outputTotalLength; local += blockDim.x) {
        const uint32_t idx = blockBase + local;
        if (idx >= indexTotalLength) {
            // `idx` only grows with `local`, so nothing further is in range for this thread.
            break;
        }
        const uint32_t gatherIdx = index[idx];
        if (gatherIdx < inputTotalLength) {
            gatherOutput[local] = input[gatherIdx];
        }
    }

    // Timestamp taken after the gather. No early return above, so the report below is always
    // reached; only the first thread of the first block prints to avoid duplicated output.
    uint64_t end = clock();
    if (blockIdx.x == 0 && threadIdx.x == 0) {
        printf("simt_gather execute cycle : %lu\n", end - start);
    }
}
/**
 * SIMD vector function: adds the scalar `addsAddend` to `input` and stores the result in `output`,
 * processing `oneRepeatSize` elements per loop iteration.
 *
 * @param output        Destination buffer in the unified buffer.
 * @param input         Source buffer in the unified buffer.
 * @param addsAddend    Scalar added to every element.
 * @param count         Element count handed to UpdateMask; taken by value, so the caller's
 *                      variable is not modified.
 * @param oneRepeatSize Element offset between two consecutive iterations.
 * @param repeatTimes   Number of loop iterations.
 */
template <typename T>
__simd_vf__ inline void simd_adds(
__ubuf__ T* output, __ubuf__ T* input, T addsAddend, uint32_t count, uint32_t oneRepeatSize, uint16_t repeatTimes)
{
AscendC::Reg::RegTensor<T> srcReg0;
AscendC::Reg::RegTensor<T> dstReg0;
AscendC::Reg::MaskReg maskReg;
for (uint16_t i = 0; i < repeatTimes; i++) {
// Build the lane mask for this iteration from the remaining element count.
// NOTE(review): presumably UpdateMask also reduces `count` by the lanes it covers, so the
// final iteration is only partially enabled - confirm against the Reg API documentation.
maskReg = AscendC::Reg::UpdateMask<T>(count);
// The load takes no mask; only the add and the store are masked.
AscendC::Reg::LoadAlign(srcReg0, input + i * oneRepeatSize);
AscendC::Reg::Adds(dstReg0, srcReg0, addsAddend, maskReg);
AscendC::Reg::StoreAlign(output + i * oneRepeatSize, dstReg0, maskReg);
}
}
/**
 * Kernel entry point: runs the SIMT gather, then the SIMD scalar add, and finally copies this
 * block's result from the unified buffer to global memory.
 *
 * Every block works on `indexTotalLength / GetBlockNum()` elements; the integer division means a
 * remainder, if any, is not processed.
 *
 * @param input            Source values in global memory.
 * @param index            Gather indices in global memory.
 * @param output           Result buffer in global memory; block b writes at offset b * per-block count.
 * @param addsAddend       Scalar added to every gathered value.
 * @param inputTotalLength Number of elements in `input`.
 * @param indexTotalLength Number of elements in `index`.
 */
template <typename T>
__global__ __vector__ void gather_and_adds_kernel(
    __gm__ T* input, __gm__ uint32_t* index, __gm__ T* output, T addsAddend, uint32_t inputTotalLength,
    uint32_t indexTotalLength)
{
    AscendC::InitSocState();
    AscendC::LocalMemAllocator<AscendC::Hardware::UB> ubAllocator;

    // Share of the index array handled by one block.
    uint32_t perBlockCount = indexTotalLength / AscendC::GetBlockNum();

    // Stage 1: SIMT gather from global memory into the unified buffer.
    AscendC::LocalTensor<T> gathered = ubAllocator.Alloc<T>(perBlockCount);
    asc_vf_call<simt_gather<T>>(
        dim3(THREAD_COUNT), input, index, (__ubuf__ T*)gathered.GetPhyAddr(), inputTotalLength,
        indexTotalLength, perBlockCount);

    // Stage 2: SIMD add of the scalar, one vector register per repeat (count rounded up).
    AscendC::LocalTensor<T> summed = ubAllocator.Alloc<T>(perBlockCount);
    constexpr uint32_t elemsPerRepeat = AscendC::GetVecLen() / sizeof(T);
    uint32_t roundedUpCount = perBlockCount + elemsPerRepeat - 1;
    uint16_t repeatTimes = roundedUpCount / elemsPerRepeat;
    asc_vf_call<simd_adds<T>>(
        (__ubuf__ T*)summed.GetPhyAddr(), (__ubuf__ T*)gathered.GetPhyAddr(), addsAddend,
        perBlockCount,
        elemsPerRepeat, repeatTimes);

    // Order the vector stage before the MTE3 copy-out that follows.
    AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
    AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);

    // Stage 3: copy this block's slice of the result to global memory.
    AscendC::GlobalTensor<T> blockOutput;
    blockOutput.SetGlobalBuffer(output + perBlockCount * AscendC::GetBlockIdx());
    AscendC::DataCopy(blockOutput, summed, perBlockCount);
    AscendC::PipeBarrier<PIPE_ALL>();
}
/**
 * Prints the leading elements of both tensors and compares them element by element.
 *
 * The comparison is exact (`operator==`). Tensors of different length never match; checking the
 * sizes first also keeps the element comparison from reading past the end of the shorter vector.
 *
 * @param output Values produced by the device.
 * @param golden Expected values computed on the host.
 * @return 0 when both tensors have the same length and identical elements, 1 otherwise.
 */
template <typename T>
uint32_t VerifyResult(const std::vector<T>& output, const std::vector<T>& golden)
{
    // Prints at most the first 20 elements, followed by "..." when the tensor is longer.
    auto printTensor = [](const std::vector<T>& tensor, const char* name) {
        constexpr size_t maxPrintSize = 20;
        std::cout << name << ": ";
        std::copy(tensor.begin(), tensor.begin() + std::min(tensor.size(), maxPrintSize),
                  std::ostream_iterator<T>(std::cout, " "));
        if (tensor.size() > maxPrintSize) {
            std::cout << "...";
        }
        std::cout << std::endl;
    };
    printTensor(output, "Output");
    printTensor(golden, "Golden");

    if (output.size() != golden.size()) {
        std::cout << "size mismatch: output has " << output.size() << " elements, golden has "
                  << golden.size() << std::endl;
        std::cout << "test failed!" << std::endl;
        return 1;
    }
    if (std::equal(golden.begin(), golden.end(), output.begin())) {
        std::cout << "test pass!" << std::endl;
        return 0;
    }
    std::cout << "test failed!" << std::endl;
    return 1;
}
template <typename T, uint32_t inputTotalLength, uint32_t indexTotalLength, uint32_t numBlocks, uint32_t dynUBufSize>
uint32_t RunGatherAndAdds(T addsAddend)
{
size_t inputByteSize = inputTotalLength * sizeof(T);
size_t indexByteSize = indexTotalLength * sizeof(uint32_t);
size_t outputByteSize = indexTotalLength * sizeof(T);
aclInit(nullptr);
int32_t deviceId = 0;
aclrtSetDevice(deviceId);
aclrtStream stream = nullptr;
aclrtCreateStream(&stream);
T* inputDevice = nullptr;
aclrtMalloc((void**)&inputDevice, inputByteSize, ACL_MEM_MALLOC_HUGE_FIRST);
uint32_t* indexDevice = nullptr;
aclrtMalloc((void**)&indexDevice, indexByteSize, ACL_MEM_MALLOC_HUGE_FIRST);
T* outputDevice = nullptr;
aclrtMalloc((void**)&outputDevice, outputByteSize, ACL_MEM_MALLOC_HUGE_FIRST);
std::vector<T> input(inputTotalLength);
std::vector<uint32_t> index(indexTotalLength);
std::vector<T> output(indexTotalLength);
std::vector<T> golden(indexTotalLength);
for (uint32_t i = 0; i < inputTotalLength; ++i) {
input[i] = static_cast<T>(static_cast<float>(i) * 0.01f);
}
for (uint32_t i = 0; i < indexTotalLength; ++i) {
index[i] = i % inputTotalLength;
golden[i] = input[index[i]] + addsAddend;
}
aclrtMemcpy(inputDevice, inputByteSize, input.data(), inputByteSize, ACL_MEMCPY_HOST_TO_DEVICE);
aclrtMemcpy(indexDevice, indexByteSize, index.data(), indexByteSize, ACL_MEMCPY_HOST_TO_DEVICE);
gather_and_adds_kernel<T><<<numBlocks, dynUBufSize, stream>>>(
inputDevice, indexDevice, outputDevice, addsAddend, inputTotalLength, indexTotalLength);
aclrtSynchronizeStream(stream);
aclrtMemcpy(output.data(), outputByteSize, outputDevice, outputByteSize, ACL_MEMCPY_DEVICE_TO_HOST);
aclrtFree(inputDevice);
aclrtFree(indexDevice);
aclrtFree(outputDevice);
aclrtDestroyStream(stream);
aclrtResetDevice(deviceId);
aclFinalize();
return VerifyResult(output, golden);
}
/**
 * Program entry: runs the gather-and-add sample for float data with a fixed configuration.
 * The command-line arguments are not used.
 *
 * @return The status reported by RunGatherAndAdds.
 */
int32_t main(int32_t argc, char* argv[])
{
    using Element = float;

    // Fixed sample configuration, forwarded as template arguments below.
    constexpr uint32_t inputCount = 100000;
    constexpr uint32_t indexCount = 8 * 1024;
    constexpr uint32_t blockCount = 8;
    constexpr uint32_t ubufSize = 2048;
    constexpr Element addend = 1.0f;

    const uint32_t status = RunGatherAndAdds<Element, inputCount, indexCount, blockCount, ubufSize>(addend);
    return status;
}