/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/**
 * \file fill.asc
 * \brief 本样例展示如何使用Fill接口对L0A Buffer和L0B Buffer进行初始化
 */

#include <cstdio>

#include "acl/acl.h"
#include "kernel_operator.h"
#include "data_utils.h"

#ifdef ASCENDC_CPU_DEBUG
#include "cpu_debug_launch.h"
#endif

/**
 * Cube kernel that initializes the left/right matrices with the constant 1 via AscendC::Fill,
 * multiplies them with Mmad and writes the result matrix to global memory.
 *
 * @tparam T element type of the left matrix (M x K)
 * @tparam U element type of the right matrix (K x N)
 * @tparam S element type of the result matrix (M x N)
 * @tparam M row count of the left/result matrix
 * @tparam N column count of the right/result matrix
 * @tparam K shared dimension of the multiplication
 */
template <typename T, typename U, typename S, int32_t M, int32_t N, int32_t K>
class KernelFill {
public:
    __aicore__ inline KernelFill() {}

    /**
     * Binds the global-memory buffers.
     *
     * @param aGm left-matrix buffer; stored in aGlobal but never read by this kernel
     * @param bGm right-matrix buffer; stored in bGlobal but never read by this kernel
     * @param cGm output buffer that receives the M x N result
     */
    __aicore__ inline void Init(__gm__ uint8_t* aGm, __gm__ uint8_t* bGm, __gm__ uint8_t* cGm)
    {
        aGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(aGm));
        bGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ U*>(bGm));
        cGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ S*>(cGm));
    }

    /**
     * Initializes L0A: fills a2Local with the value 1 in a single repeat of 512-byte blocks.
     */
    __aicore__ inline void InitConstA2(AscendC::LocalTensor<T>& a2Local)
    {
        // The block count is an integer division; a remainder would leave the tail unfilled.
        static_assert(A_BYTES % L0_FILL_BLOCK_BYTES == 0, "M * K * sizeof(T) must be a multiple of 512 bytes");
        static_assert(A_BYTES / L0_FILL_BLOCK_BYTES <= MAX_FILL_BLOCKS, "L0A block count must fit in uint16_t");
        AscendC::InitConstValueParams<T> fillParams(
            1, static_cast<uint16_t>(A_BYTES / L0_FILL_BLOCK_BYTES), 0, static_cast<T>(1));
        AscendC::Fill(a2Local, fillParams);
    }

    /**
     * Initializes L0B: fills b2Local with the value 1 in a single repeat of 512-byte blocks.
     */
    __aicore__ inline void InitConstB2(AscendC::LocalTensor<U>& b2Local)
    {
        static_assert(B_BYTES % L0_FILL_BLOCK_BYTES == 0, "K * N * sizeof(U) must be a multiple of 512 bytes");
        static_assert(B_BYTES / L0_FILL_BLOCK_BYTES <= MAX_FILL_BLOCKS, "L0B block count must fit in uint16_t");
        AscendC::InitConstValueParams<U> fillParams(
            1, static_cast<uint16_t>(B_BYTES / L0_FILL_BLOCK_BYTES), 0, static_cast<U>(1));
        AscendC::Fill(b2Local, fillParams);
    }

    /**
     * Initializes A1: fills a1Local with the value 1 in a single repeat of 32-byte blocks.
     */
    __aicore__ inline void InitConstA1(AscendC::LocalTensor<T>& a1Local)
    {
        static_assert(A_BYTES % L1_FILL_BLOCK_BYTES == 0, "M * K * sizeof(T) must be a multiple of 32 bytes");
        static_assert(A_BYTES / L1_FILL_BLOCK_BYTES <= MAX_FILL_BLOCKS, "A1 block count must fit in uint16_t");
        AscendC::Fill(a1Local, {1, static_cast<uint16_t>(A_BYTES / L1_FILL_BLOCK_BYTES), 0, 1});
    }

    /**
     * Initializes B1: fills b1Local with the value 1 in a single repeat of 32-byte blocks.
     */
    __aicore__ inline void InitConstB1(AscendC::LocalTensor<U>& b1Local)
    {
        static_assert(B_BYTES % L1_FILL_BLOCK_BYTES == 0, "K * N * sizeof(U) must be a multiple of 32 bytes");
        static_assert(B_BYTES / L1_FILL_BLOCK_BYTES <= MAX_FILL_BLOCKS, "B1 block count must fit in uint16_t");
        AscendC::Fill(b1Local, {1, static_cast<uint16_t>(B_BYTES / L1_FILL_BLOCK_BYTES), 0, 1});
    }

    /**
     * Moves the left matrix from A1 to A2 with a non-transposing 2D LoadData.
     */
    __aicore__ inline void Load2DA1ToA2(AscendC::LocalTensor<T>& a1Local, AscendC::LocalTensor<T>& a2Local)
    {
        // Steps/strides are expressed in blocks of 16 rows (M) and 32 bytes (K);
        // presumably these are fractal units -- confirm against the LoadData2DParamsV2 documentation.
        AscendC::LoadData2DParamsV2 loadDataParams;
        loadDataParams.mStartPosition = 0;
        loadDataParams.kStartPosition = 0;
        loadDataParams.mStep = AscendC::DivCeil(M, 16);
        loadDataParams.kStep = AscendC::DivCeil(K * sizeof(T), 32);
        loadDataParams.srcStride = AscendC::DivCeil(M, 16);
        loadDataParams.dstStride = AscendC::DivCeil(M, 16);
        loadDataParams.sid = 0;
        loadDataParams.ifTranspose = false;
        AscendC::LoadData(a2Local, a1Local, loadDataParams);
    }

    /**
     * Moves the right matrix from B1 to B2 with a transposing 2D LoadData.
     */
    __aicore__ inline void Load2DB1ToB2(AscendC::LocalTensor<U>& b1Local, AscendC::LocalTensor<U>& b2Local)
    {
        // For the right matrix K takes the "m" role and N the "k" role of the load parameters.
        uint16_t nAlign = AscendC::DivCeil(N * sizeof(U), 32);
        uint16_t kAlign = AscendC::DivCeil(K, 16);
        AscendC::LoadData2DParamsV2 loadDataParams;
        loadDataParams.mStartPosition = 0;
        loadDataParams.kStartPosition = 0;
        loadDataParams.mStep = kAlign;
        loadDataParams.kStep = nAlign;
        loadDataParams.srcStride = kAlign;
        loadDataParams.dstStride = nAlign;
        loadDataParams.sid = 0;
        loadDataParams.ifTranspose = true;
        AscendC::LoadData(b2Local, b1Local, loadDataParams);
    }

    /**
     * Runs Mmad on a2Local and b2Local with dimensions M, N, K and no bias; the result goes to co1Local.
     */
    __aicore__ inline void Compute(
        AscendC::LocalTensor<T>& a2Local, AscendC::LocalTensor<U>& b2Local, AscendC::LocalTensor<S>& co1Local)
    {
        AscendC::MmadParams mmadParams;
        mmadParams.m = M;
        mmadParams.n = N;
        mmadParams.k = K;
        mmadParams.isBias = false;
        AscendC::Mmad(co1Local, a2Local, b2Local, mmadParams);
    }

    /**
     * Copies the M x N result from co1Local to cGlobal through Fixpipe in row-major layout,
     * without quantization.
     */
    __aicore__ inline void CopyL0CToGm(AscendC::LocalTensor<S>& co1Local)
    {
        AscendC::FixpipeParamsV220 fixpipeParams;
        fixpipeParams.nSize = N;
        fixpipeParams.mSize = M;
        // NOTE(review): srcStride is set to M as-is; if the API expects a 16-aligned value this
        // only holds for M that are multiples of 16 -- confirm before reusing with other shapes.
        fixpipeParams.srcStride = M;
        fixpipeParams.dstStride = N;
        fixpipeParams.ndNum = 1;
        fixpipeParams.srcNdStride = 2;
        fixpipeParams.dstNdStride = M * N;
        fixpipeParams.quantPre = QuantMode_t::NoQuant;
        AscendC::Fixpipe<S, S, AscendC::CFG_ROW_MAJOR>(cGlobal, co1Local, fixpipeParams);
    }

    /**
     * Full pipeline: initialize the operand buffers, multiply, and write the result to global memory.
     */
    __aicore__ inline void Process()
    {
        AscendC::LocalTensor<T> a2Local(AscendC::TPosition::A2, a2Addr, M * K);
        AscendC::LocalTensor<U> b2Local(AscendC::TPosition::B2, b2Addr, N * K);
        AscendC::LocalTensor<S> co1Local(AscendC::TPosition::CO1, co1Addr, M * N);
#if defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201)
        // Atlas A2/A3 training/inference series products: initialize L0A and L0B directly via Fill.
        InitConstA2(a2Local);
        InitConstB2(b2Local);
#elif defined(__NPU_ARCH__) && (__NPU_ARCH__ == 3510)
        AscendC::LocalTensor<T> a1Local(AscendC::TPosition::A1, a1Addr, M * K);
        AscendC::LocalTensor<U> b1Local(AscendC::TPosition::B1, b1Addr, K * N);
        // Ascend 950PR / Ascend 950DT cannot initialize L0A and L0B directly, so L1 is initialized
        // via Fill and the data is then moved to L0A and L0B with LoadData.
        InitConstA1(a1Local);
        InitConstB1(b1Local);
        AscendC::PipeBarrier<PIPE_MTE1>();
        Load2DA1ToA2(a1Local, a2Local);
        Load2DB1ToB2(b1Local, b2Local);
#endif
        // MTE1 -> M event: the matrix multiplication must not start before the operands are in place.
        AscendC::SetFlag<AscendC::HardEvent::MTE1_M>(EVENT_ID0);
        AscendC::WaitFlag<AscendC::HardEvent::MTE1_M>(EVENT_ID0);
        // Call the Mmad interface to perform the matrix multiplication.
        Compute(a2Local, b2Local, co1Local);
        // M -> FIX event: the Fixpipe copy must not start before Mmad has produced the result.
        AscendC::SetFlag<AscendC::HardEvent::M_FIX>(EVENT_ID0);
        AscendC::WaitFlag<AscendC::HardEvent::M_FIX>(EVENT_ID0);
        CopyL0CToGm(co1Local);
    }

private:
    // Byte sizes of the left (M x K) and right (K x N) matrices.
    static constexpr uint64_t A_BYTES = static_cast<uint64_t>(M) * K * sizeof(T);
    static constexpr uint64_t B_BYTES = static_cast<uint64_t>(K) * N * sizeof(U);
    // Fill block granularity used for the L0A/L0B fills and for the A1/B1 fills respectively.
    static constexpr uint64_t L0_FILL_BLOCK_BYTES = 512;
    static constexpr uint64_t L1_FILL_BLOCK_BYTES = 32;
    // Largest block count representable by the uint16_t block-count argument of the fill parameters.
    static constexpr uint64_t MAX_FILL_BLOCKS = 65535;

    uint64_t a1Addr = 0;
    // Both A1 and B1 are filled as part of L1, so B1 is placed directly after A1; with both at
    // offset 0 the two tensors would alias each other.
    uint64_t b1Addr = a1Addr + A_BYTES;
    // NOTE(review): the L0 offsets below are chained as if A2, B2 and CO1 shared one address space;
    // if they are independent buffers, offset 0 would suffice for each -- confirm.
    uint64_t a2Addr = 0;
    uint64_t b2Addr = a2Addr + A_BYTES;
    uint64_t co1Addr = b2Addr + B_BYTES;
    AscendC::GlobalTensor<T> aGlobal;
    AscendC::GlobalTensor<U> bGlobal;
    AscendC::GlobalTensor<S> cGlobal;
};

/**
 * Cube kernel entry point: runs KernelFill with half inputs and a float result
 * for the compile-time shape M x N x K.
 *
 * @param x global-memory buffer forwarded as the left matrix
 * @param y global-memory buffer forwarded as the right matrix
 * @param z global-memory buffer that receives the M x N float result
 */
template <int32_t M, int32_t N, int32_t K>
__global__ __cube__ void fill(__gm__ uint8_t* x, __gm__ uint8_t* y, __gm__ uint8_t* z)
{
    using FillKernel = KernelFill<half, half, float, M, N, K>;

    AscendC::InitSocState();

    FillKernel kernel;
    kernel.Init(x, y, z);
    kernel.Process();

    // Wait for every pipeline to drain before the kernel returns.
    AscendC::PipeBarrier<PIPE_ALL>();
}

int32_t main(int32_t argc, char* argv[])
{
    constexpr uint32_t NUM_BLOCKS = 1;
    constexpr int32_t M = 128;
    constexpr int32_t N = 64;
    constexpr int32_t K = 128;

    aclInit(nullptr);
    int32_t deviceId = 0;
    aclrtSetDevice(deviceId);
    aclrtStream stream = nullptr;
    aclrtCreateStream(&stream);
    size_t aInputByteSize = M * K * sizeof(half);
    size_t bInputByteSize = K * N * sizeof(half);
    size_t outputByteSize = M * N * sizeof(float);
    uint8_t *xHost, *yHost, *zHost;
    uint8_t *xDevice, *yDevice, *zDevice;
    aclrtMallocHost((void**)(&xHost), aInputByteSize);
    aclrtMallocHost((void**)(&yHost), bInputByteSize);
    aclrtMallocHost((void**)(&zHost), outputByteSize);
    aclrtMalloc((void**)&xDevice, aInputByteSize, ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMalloc((void**)&yDevice, bInputByteSize, ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMalloc((void**)&zDevice, outputByteSize, ACL_MEM_MALLOC_HUGE_FIRST);
    ReadFile("./input/input_x.bin", aInputByteSize, xHost, aInputByteSize);
    ReadFile("./input/input_y.bin", bInputByteSize, yHost, bInputByteSize);
    aclrtMemcpy(xDevice, aInputByteSize, xHost, aInputByteSize, ACL_MEMCPY_HOST_TO_DEVICE);
    aclrtMemcpy(yDevice, bInputByteSize, yHost, bInputByteSize, ACL_MEMCPY_HOST_TO_DEVICE);
    fill<M, N, K><<<NUM_BLOCKS, nullptr, stream>>>(xDevice, yDevice, zDevice);
    aclrtSynchronizeStream(stream);
    aclrtMemcpy(zHost, outputByteSize, zDevice, outputByteSize, ACL_MEMCPY_DEVICE_TO_HOST);
    WriteFile("./output/output.bin", zHost, outputByteSize);
    aclrtFree(xDevice);
    aclrtFree(yDevice);
    aclrtFree(zDevice);
    aclrtFreeHost(xHost);
    aclrtFreeHost(yHost);
    aclrtFreeHost(zHost);
    aclrtDestroyStream(stream);
    aclrtResetDevice(deviceId);
    aclFinalize();
    return 0;
}