/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/* !
 * \file assert.asc
 * \brief 使用静态Tensor编程模式实现矩阵乘法,展示ascendc_assert接口的基本使用方法
 */

#include <cstdio>

#include "acl/acl.h"
#include "kernel_operator.h"
#include "data_utils.h"

// Edge length, in elements, of the square [16, 16] block used for half data.
// The kernel below asserts that baseM, baseK and baseN are multiples of this value and uses it to
// derive the repeat counts and strides of its LoadData calls.
constexpr uint32_t CUBE_BLOCK = 16;

/**
 * Cube-core kernel: each block multiplies one [baseM, baseK] tile of A by one [baseK, baseN] tile of B and
 * writes the [baseM, baseN] result to C, using statically allocated L1/L0A/L0B/L0C tensors. It also shows
 * how ascendc_assert is used to check launch and tiling preconditions on the device.
 *
 * Template parameters:
 *   M, K, N      - dimensions of the full problem; K and N are used as the source row lengths of A and B in
 *                  global memory, and N as the destination stride of C.
 *   singleCoreM  - number of rows of A/C assigned to one block.
 *   baseM/K/N    - tile sizes handed to Mmad; each is asserted to be a multiple of CUBE_BLOCK.
 *
 * Parameters:
 *   a, b - global-memory input buffers, reinterpreted as half.
 *   c    - global-memory output buffer, reinterpreted as half. Accumulation happens in float (cLocal) and
 *          Fixpipe is invoked with QuantMode_t::F322F16.
 *
 * NOTE(review): there is no loop over tiles, so the full product is only covered when singleCoreM == baseM,
 * K == baseK and N == baseN -- confirm this is the intended restriction of the sample.
 * NOTE(review): every assertion condition except the block-index check depends only on template parameters.
 */
template <uint32_t M, uint32_t K, uint32_t N, uint32_t singleCoreM, uint32_t baseM, uint32_t baseK, uint32_t baseN>
__global__ __cube__ void mmad_custom(__gm__ uint8_t* a, __gm__ uint8_t* b, __gm__ uint8_t* c)
{
    AscendC::GlobalTensor<half> aGM;
    AscendC::GlobalTensor<half> bGM;
    AscendC::GlobalTensor<half> cGM;
    // Row-band index of this block. The modulo keeps it in [0, M / singleCoreM), so the offsets below stay
    // inside A and C even when the block-index assertion further down would fail.
    uint32_t mIterIdx = AscendC::GetBlockIdx() % (M / singleCoreM);
    // A and C are offset by whole row bands (singleCoreM rows of K resp. N elements); B is shared by all blocks.
    aGM.SetGlobalBuffer((__gm__ half*)a + mIterIdx * singleCoreM * K);
    bGM.SetGlobalBuffer((__gm__ half*)b);
    cGM.SetGlobalBuffer((__gm__ half*)c + mIterIdx * singleCoreM * N);

    // Assertion: the block index must be within the valid range [0, M / singleCoreM).
    // NOTE(review): the %u specifier assumes GetBlockIdx() yields a 32-bit unsigned value -- confirm against
    // its declared return type.
    ascendc_assert(
        AscendC::GetBlockIdx() < M / singleCoreM, "BlockIdx %u is out of valid range.\n", AscendC::GetBlockIdx());
    // Assertion: the tile sizes must be aligned to CUBE_BLOCK.
    ascendc_assert(baseM % CUBE_BLOCK == 0, "baseM %u must be aligned to CUBE_BLOCK %u.\n", baseM, CUBE_BLOCK);
    ascendc_assert(baseK % CUBE_BLOCK == 0, "baseK %u must be aligned to CUBE_BLOCK %u.\n", baseK, CUBE_BLOCK);
    ascendc_assert(baseN % CUBE_BLOCK == 0, "baseN %u must be aligned to CUBE_BLOCK %u.\n", baseN, CUBE_BLOCK);
    // Assertion: the dimensions must match, i.e. M is an exact multiple of singleCoreM.
    ascendc_assert(
        M == singleCoreM * (M / singleCoreM), "M %u should be divisible by singleCoreM %u.\n", M, singleCoreM);

    AscendC::LocalMemAllocator<AscendC::Hardware::L1> l1Allocator;
    AscendC::LocalMemAllocator<AscendC::Hardware::L0A> l0aAllocator;
    AscendC::LocalMemAllocator<AscendC::Hardware::L0B> l0bAllocator;
    AscendC::LocalMemAllocator<AscendC::Hardware::L0C> l0cAllocator;
    // LocalMemAllocator hands out on-chip memory in allocation order, so the LocalTensor address offsets do
    // not have to be maintained by hand.
    AscendC::LocalTensor<half> a1Local = l1Allocator.Alloc<AscendC::TPosition::A1, half>(baseM * baseK);
    AscendC::LocalTensor<half> b1Local = l1Allocator.Alloc<AscendC::TPosition::B1, half>(baseK * baseN);
    AscendC::LocalTensor<half> a2Local = l0aAllocator.Alloc<AscendC::TPosition::A2, half>(baseM * baseK);
    AscendC::LocalTensor<half> b2Local = l0bAllocator.Alloc<AscendC::TPosition::B2, half>(baseK * baseN);
    AscendC::LocalTensor<float> cLocal = l0cAllocator.Alloc<AscendC::TPosition::CO1, float>(baseM * baseN);

    // Assertion: the a2/b2 tensors must not exceed 64 KiB, which the messages describe as the L0A/L0B
    // buffer limits. Note that these checks run after the allocations above.
    ascendc_assert(baseM * baseK * sizeof(half) <= 64 * 1024, "a2 LocalTensor size exceeds L0A buffer limit.\n");
    ascendc_assert(baseK * baseN * sizeof(half) <= 64 * 1024, "b2 LocalTensor size exceeds L0B buffer limit.\n");

    // A matrix, GM -> L1: DataCopy with Nd2NzParams fields {ndNum, nValue, dValue, srcNdMatrixStride, srcDValue,
    // dstNzC0Stride, dstNzNStride, dstNzMatrixStride}
    AscendC::Nd2NzParams nd2nzAParams = {1, baseM, baseK, 0, K, baseM, 1, 0};
    AscendC::DataCopy(a1Local, aGM, nd2nzAParams);
    // B matrix, GM -> L1: DataCopy with Nd2NzParams fields {ndNum, nValue, dValue, srcNdMatrixStride, srcDValue,
    // dstNzC0Stride, dstNzNStride, dstNzMatrixStride}
    AscendC::Nd2NzParams nd2nzBParams = {1, baseK, baseN, 0, N, baseK, 1, 0};
    AscendC::DataCopy(b1Local, bGM, nd2nzBParams);

    // Set and wait on the MTE2_MTE1 event between the GM -> L1 copies above and the L1 -> L0 loads below.
    AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE1>(EVENT_ID0);
    AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE1>(EVENT_ID0);
    // NOTE(review): if __NPU_ARCH__ is neither 2201 nor 3510, neither branch below is compiled and a2Local is
    // never loaded before Mmad.
#if defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201)
    // A matrix, L1 -> L0A (Nz -> Zz): LoadData with LoadData2DParams fields {startIndex, repeatTimes, srcStride,
    // sid, dstGap, ifTranspose, addrMode}
    AscendC::LoadData2DParams loadDataAParams = {0, baseK / CUBE_BLOCK, baseM / CUBE_BLOCK, 0, 0, false, 0};
    // One LoadData per CUBE_BLOCK-row band of A. The source advances by 512 bytes per iteration
    // (512 / sizeof(half) elements); the destination advances by baseK * CUBE_BLOCK elements.
    for (int i = 0; i < baseM / CUBE_BLOCK; ++i) {
        AscendC::LoadData(a2Local[i * baseK * CUBE_BLOCK], a1Local[i * 512 / sizeof(half)], loadDataAParams);
    }
#elif defined(__NPU_ARCH__) && (__NPU_ARCH__ == 3510)
    // A matrix, L1 -> L0A (Nz -> Nz)
    AscendC::LoadData2DParams loadDataAParams = {
        0, baseK / CUBE_BLOCK, baseM / CUBE_BLOCK, 0, (baseM / CUBE_BLOCK - 1), false, 0};
    // Source and destination both advance by 512 bytes per iteration; dstGap is baseM / CUBE_BLOCK - 1.
    for (int i = 0; i < baseM / CUBE_BLOCK; ++i) {
        AscendC::LoadData(a2Local[i * 512 / sizeof(half)], a1Local[i * 512 / sizeof(half)], loadDataAParams);
    }
#endif
    // B matrix, L1 -> L0B: LoadData with LoadData2DParams fields {startIndex, repeatTimes, srcStride, sid, dstGap,
    // ifTranspose, addrMode}
    AscendC::LoadData2DParams loadDataBParams = {0, baseN / CUBE_BLOCK, baseK / CUBE_BLOCK, 0, 0, true, 0};
    // One LoadData per CUBE_BLOCK-row band of B, with ifTranspose enabled.
    for (int i = 0; i < baseK / CUBE_BLOCK; ++i) {
        AscendC::LoadData(b2Local[i * baseN * CUBE_BLOCK], b1Local[i * 512 / sizeof(half)], loadDataBParams);
    }

    // Set and wait on the MTE1_M event between the L0A/L0B loads above and the Mmad below.
    AscendC::SetFlag<AscendC::HardEvent::MTE1_M>(EVENT_ID0);
    AscendC::WaitFlag<AscendC::HardEvent::MTE1_M>(EVENT_ID0);
    // Matrix multiply: Mmad with MmadParams fields {m, n, k, unitFlag, cmatrixSource, cmatrixInitVal}
    AscendC::MmadParams mmadParams = {baseM, baseN, baseK, 0, false, true};
    AscendC::Mmad(cLocal, a2Local, b2Local, mmadParams);

    // Set and wait on the M_FIX event between the Mmad above and the Fixpipe write-back below.
    AscendC::SetFlag<AscendC::HardEvent::M_FIX>(EVENT_ID0);
    AscendC::WaitFlag<AscendC::HardEvent::M_FIX>(EVENT_ID0);
    // L0C -> GM: Fixpipe with FixpipeParamsV220 fields {nSize, mSize, srcStride, dstStride, reluEn, quantPre,
    // deqScalar, ndNum, srcNdStride, dstNdStride, unitFlag}
    AscendC::FixpipeParamsV220 fixpipeParams = {baseN, baseM, baseM, N, false, QuantMode_t::F322F16, 0, 1, 0, 0, 0};
    AscendC::Fixpipe(cGM, cLocal, fixpipeParams);
    // Barrier on all pipes before the kernel returns.
    AscendC::PipeBarrier<PIPE_ALL>();
}

int32_t main(int32_t argc, char* argv[])
{
    // matmul参数
    constexpr uint32_t M = 256;
    constexpr uint32_t K = 64;
    constexpr uint32_t N = 256;
    // 单核计算量
    constexpr uint32_t singleCoreM = 128;
    // 矩阵乘执行单元Tile参数
    constexpr uint32_t baseM = 128;
    constexpr uint32_t baseK = 64;
    constexpr uint32_t baseN = 256;

    size_t aFileSize = M * K * sizeof(half);
    size_t bFileSize = K * N * sizeof(half);
    size_t cFileSize = M * N * sizeof(half);
    uint32_t numBlocks = 2;

    aclInit(nullptr);
    int32_t deviceId = 0;
    aclrtSetDevice(deviceId);
    aclrtStream stream = nullptr;
    aclrtCreateStream(&stream);

    uint8_t* aHost;
    uint8_t* aDevice;
    aclrtMallocHost((void**)(&aHost), aFileSize);
    aclrtMalloc((void**)&aDevice, aFileSize, ACL_MEM_MALLOC_HUGE_FIRST);
    ReadFile("./input/x1_gm.bin", aFileSize, aHost, aFileSize);
    aclrtMemcpy(aDevice, aFileSize, aHost, aFileSize, ACL_MEMCPY_HOST_TO_DEVICE);

    uint8_t* bHost;
    uint8_t* bDevice;
    aclrtMallocHost((void**)(&bHost), bFileSize);
    aclrtMalloc((void**)&bDevice, bFileSize, ACL_MEM_MALLOC_HUGE_FIRST);
    ReadFile("./input/x2_gm.bin", bFileSize, bHost, bFileSize);
    aclrtMemcpy(bDevice, bFileSize, bHost, bFileSize, ACL_MEMCPY_HOST_TO_DEVICE);

    uint8_t* cHost;
    uint8_t* cDevice;
    aclrtMallocHost((void**)(&cHost), cFileSize);
    aclrtMalloc((void**)&cDevice, cFileSize, ACL_MEM_MALLOC_HUGE_FIRST);

    mmad_custom<M, K, N, singleCoreM, baseM, baseK, baseN><<<numBlocks, nullptr, stream>>>(aDevice, bDevice, cDevice);
    aclrtSynchronizeStream(stream);

    aclrtMemcpy(cHost, cFileSize, cDevice, cFileSize, ACL_MEMCPY_DEVICE_TO_HOST);
    WriteFile("./output/output.bin", cHost, cFileSize);

    aclrtFree(aDevice);
    aclrtFreeHost(aHost);
    aclrtFree(bDevice);
    aclrtFreeHost(bHost);
    aclrtFree(cDevice);
    aclrtFreeHost(cHost);
    aclrtDestroyStream(stream);
    aclrtResetDevice(deviceId);
    aclFinalize();
    return 0;
}