/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/* !
 * \file clock.asc
 * \brief 基于Gather和Adds混合编程样例,展示在SIMD&SIMT混编场景中使用clock接口统计执行周期。
 */

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <vector>

#include "acl/acl.h"
#include "kernel_operator.h"
#include "simt_api/asc_simt.h"

namespace {
// Number of SIMT threads: used as the __launch_bounds__ limit of simt_gather and as the
// dim3 launch dimension passed to asc_vf_call in gather_and_adds_kernel.
constexpr uint32_t THREAD_COUNT = 1024;
} // namespace

/**
 * SIMT vector function: gathers elements of `input` selected by `index` into the
 * per-block buffer `gatherOutput`, and reports the elapsed cycles measured with clock().
 *
 * The block with index b fills gatherOutput[j] from index[b * outputTotalLength + j] for
 * j in [0, outputTotalLength). This matches the caller, which stores each block's result
 * at an offset of (per-block length * block index) in the global output.
 * NOTE(review): assumes blockIdx.x is the same block index the caller uses for that offset - confirm.
 *
 * @param input             source data in global memory
 * @param index             gather indices in global memory
 * @param gatherOutput      destination buffer holding outputTotalLength elements
 * @param inputTotalLength  number of elements in `input`; larger indices are skipped
 * @param indexTotalLength  number of elements in `index`
 * @param outputTotalLength number of elements this block has to produce
 */
template <typename T>
__simt_vf__ __launch_bounds__(THREAD_COUNT) inline void simt_gather(
    __gm__ T* input, __gm__ uint32_t* index, __ubuf__ T* gatherOutput, uint32_t inputTotalLength,
    uint32_t indexTotalLength, uint32_t outputTotalLength)
{
    // Timestamp taken before the gather phase starts.
    uint64_t start = clock();

    // First global index owned by this block. Using the per-block length (not blockDim.x)
    // keeps the mapping correct when the per-block length differs from the thread count.
    const uint32_t blockBase = blockIdx.x * outputTotalLength;

    // Block-stride loop: each thread handles elements threadIdx.x, threadIdx.x + blockDim.x, ...
    // so per-block lengths larger than the thread count are fully covered.
    for (uint32_t local = threadIdx.x; local < outputTotalLength; local += blockDim.x) {
        const uint32_t idx = blockBase + local;
        if (idx >= indexTotalLength) {
            // idx only grows with `local`, so nothing further is in range for this thread.
            break;
        }

        const uint32_t gatherIdx = index[idx];
        // Out-of-range gather indices are skipped; the destination element is left untouched.
        if (gatherIdx < inputTotalLength) {
            gatherOutput[local] = input[gatherIdx];
        }
    }

    // Timestamp taken after the gather phase. There are no early returns above, so every
    // thread reaches this point; only the first thread of the first block prints the cycle
    // difference to avoid duplicated output.
    uint64_t end = clock();
    if (blockIdx.x == 0 && threadIdx.x == 0) {
        printf("simt_gather execute cycle : %lu\n", end - start);
    }
}

/**
 * SIMD vector function: adds the scalar `addsAddend` to `input` and stores the result in
 * `output`, processing `oneRepeatSize` elements per loop iteration.
 *
 * @param output        destination buffer
 * @param input         source buffer
 * @param addsAddend    scalar added to every element
 * @param count         number of elements to process; passed to UpdateMask on every iteration
 * @param oneRepeatSize element stride between two consecutive iterations
 * @param repeatTimes   number of loop iterations
 */
template <typename T>
__simd_vf__ inline void simd_adds(
    __ubuf__ T* output, __ubuf__ T* input, T addsAddend, uint32_t count, uint32_t oneRepeatSize, uint16_t repeatTimes)
{
    AscendC::Reg::RegTensor<T> srcReg0;
    AscendC::Reg::RegTensor<T> dstReg0;
    AscendC::Reg::MaskReg maskReg;

    for (uint16_t i = 0; i < repeatTimes; i++) {
        // NOTE(review): UpdateMask presumably builds the mask for the current repeat from the
        // remaining `count` and reduces `count` accordingly - confirm against the Reg API docs.
        maskReg = AscendC::Reg::UpdateMask<T>(count);
        // Load the i-th chunk of the input into a register.
        AscendC::Reg::LoadAlign(srcReg0, input + i * oneRepeatSize);
        // dstReg0 = srcReg0 + addsAddend under maskReg.
        AscendC::Reg::Adds(dstReg0, srcReg0, addsAddend, maskReg);
        // Store the i-th chunk of the result under the same mask.
        AscendC::Reg::StoreAlign(output + i * oneRepeatSize, dstReg0, maskReg);
    }
}

/**
 * Kernel entry: gathers `input` elements selected by `index` (SIMT stage), adds
 * `addsAddend` to them (SIMD stage) and copies the per-block result to `output`.
 *
 * Each block handles indexTotalLength / GetBlockNum() elements (integer division, so a
 * remainder is not processed) and writes them at that length times its block index.
 *
 * @param input            source data in global memory
 * @param index            gather indices in global memory
 * @param output           result buffer in global memory
 * @param addsAddend       scalar added to every gathered element
 * @param inputTotalLength number of elements in `input`
 * @param indexTotalLength number of elements in `index`
 */
template <typename T>
__global__ __vector__ void gather_and_adds_kernel(
    __gm__ T* input, __gm__ uint32_t* index, __gm__ T* output, T addsAddend, uint32_t inputTotalLength,
    uint32_t indexTotalLength)
{
    AscendC::InitSocState();
    AscendC::LocalMemAllocator<AscendC::Hardware::UB> ubPool;

    // Number of elements handled by one block.
    uint32_t perBlockCount = indexTotalLength / AscendC::GetBlockNum();

    // Stage 1: SIMT gather into a local buffer.
    AscendC::LocalTensor<T> gatherBuf = ubPool.Alloc<T>(perBlockCount);
    __ubuf__ T* gatherAddr = (__ubuf__ T*)gatherBuf.GetPhyAddr();
    asc_vf_call<simt_gather<T>>(
        dim3(THREAD_COUNT), input, index, gatherAddr, inputTotalLength, indexTotalLength, perBlockCount);

    // Stage 2: SIMD add of the scalar, one vector-length chunk per repeat (rounded up).
    AscendC::LocalTensor<T> addsBuf = ubPool.Alloc<T>(perBlockCount);
    __ubuf__ T* addsAddr = (__ubuf__ T*)addsBuf.GetPhyAddr();
    constexpr uint32_t elemsPerRepeat = AscendC::GetVecLen() / sizeof(T);
    uint16_t repeatCount = (perBlockCount + elemsPerRepeat - 1) / elemsPerRepeat;
    asc_vf_call<simd_adds<T>>(addsAddr, gatherAddr, addsAddend, perBlockCount, elemsPerRepeat, repeatCount);

    // Set and wait on the V_MTE3 event before issuing the copy-out below.
    AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
    AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);

    // Stage 3: copy this block's slice of the result to global memory.
    AscendC::GlobalTensor<T> outputSlice;
    outputSlice.SetGlobalBuffer(output + perBlockCount * AscendC::GetBlockIdx());
    AscendC::DataCopy(outputSlice, addsBuf, perBlockCount);
    AscendC::PipeBarrier<PIPE_ALL>();
}

/**
 * Prints the leading elements of `output` and `golden` and compares the two vectors.
 *
 * The vectors must have the same size and equal elements (compared with operator==).
 * A size mismatch is reported as a failure instead of reading past the end of `output`.
 *
 * @param output values produced by the device
 * @param golden expected values computed on the host
 * @return 0 when the vectors match, 1 otherwise
 */
template <typename T>
uint32_t VerifyResult(const std::vector<T>& output, const std::vector<T>& golden)
{
    auto printTensor = [](const std::vector<T>& tensor, const char* name) {
        // Only the first few elements are printed to keep the log short.
        constexpr size_t maxPrintSize = 20;
        const size_t shown = std::min(tensor.size(), maxPrintSize);
        std::cout << name << ": ";
        std::copy_n(tensor.begin(), shown, std::ostream_iterator<T>(std::cout, " "));
        if (tensor.size() > maxPrintSize) {
            std::cout << "...";
        }
        std::cout << std::endl;
    };
    printTensor(output, "Output");
    printTensor(golden, "Golden");

    if (output.size() != golden.size()) {
        std::cout << "size mismatch: output has " << output.size() << " elements, golden has " << golden.size()
                  << std::endl;
    } else if (std::equal(golden.begin(), golden.end(), output.begin())) {
        // Sizes are equal here, so the three-iterator std::equal stays in bounds.
        std::cout << "test pass!" << std::endl;
        return 0;
    }
    std::cout << "test failed!" << std::endl;
    return 1;
}

template <typename T, uint32_t inputTotalLength, uint32_t indexTotalLength, uint32_t numBlocks, uint32_t dynUBufSize>
uint32_t RunGatherAndAdds(T addsAddend)
{
    size_t inputByteSize = inputTotalLength * sizeof(T);
    size_t indexByteSize = indexTotalLength * sizeof(uint32_t);
    size_t outputByteSize = indexTotalLength * sizeof(T);

    aclInit(nullptr);
    int32_t deviceId = 0;
    aclrtSetDevice(deviceId);
    aclrtStream stream = nullptr;
    aclrtCreateStream(&stream);

    T* inputDevice = nullptr;
    aclrtMalloc((void**)&inputDevice, inputByteSize, ACL_MEM_MALLOC_HUGE_FIRST);

    uint32_t* indexDevice = nullptr;
    aclrtMalloc((void**)&indexDevice, indexByteSize, ACL_MEM_MALLOC_HUGE_FIRST);

    T* outputDevice = nullptr;
    aclrtMalloc((void**)&outputDevice, outputByteSize, ACL_MEM_MALLOC_HUGE_FIRST);

    std::vector<T> input(inputTotalLength);
    std::vector<uint32_t> index(indexTotalLength);
    std::vector<T> output(indexTotalLength);
    std::vector<T> golden(indexTotalLength);
    for (uint32_t i = 0; i < inputTotalLength; ++i) {
        input[i] = static_cast<T>(static_cast<float>(i) * 0.01f);
    }
    for (uint32_t i = 0; i < indexTotalLength; ++i) {
        index[i] = i % inputTotalLength;
        golden[i] = input[index[i]] + addsAddend;
    }

    aclrtMemcpy(inputDevice, inputByteSize, input.data(), inputByteSize, ACL_MEMCPY_HOST_TO_DEVICE);
    aclrtMemcpy(indexDevice, indexByteSize, index.data(), indexByteSize, ACL_MEMCPY_HOST_TO_DEVICE);

    gather_and_adds_kernel<T><<<numBlocks, dynUBufSize, stream>>>(
        inputDevice, indexDevice, outputDevice, addsAddend, inputTotalLength, indexTotalLength);

    aclrtSynchronizeStream(stream);
    aclrtMemcpy(output.data(), outputByteSize, outputDevice, outputByteSize, ACL_MEMCPY_DEVICE_TO_HOST);

    aclrtFree(inputDevice);
    aclrtFree(indexDevice);
    aclrtFree(outputDevice);
    aclrtDestroyStream(stream);
    aclrtResetDevice(deviceId);
    aclFinalize();

    return VerifyResult(output, golden);
}

/**
 * Sample entry point: runs the gather-and-adds example with a fixed float configuration.
 * The command-line arguments are not used.
 *
 * @return the value returned by RunGatherAndAdds (0 on success, 1 on failure)
 */
int32_t main(int32_t argc, char* argv[])
{
    using ElementType = float;

    // Problem size: 100000 source elements, 8 * 1024 gather indices.
    constexpr uint32_t sourceLength = 100000;
    constexpr uint32_t gatherLength = 8 * 1024;

    // Kernel launch parameters.
    constexpr uint32_t blockCount = 8;
    constexpr uint32_t ubufBytes = 2048;

    // Scalar added to each gathered element.
    constexpr ElementType addend = 1.0f;

    const uint32_t status = RunGatherAndAdds<ElementType, sourceLength, gatherLength, blockCount, ubufBytes>(addend);
    return status;
}