/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file main.cpp
 * \brief Main implementation file for Ascend matrix multiplication kernel
 *        This file contains the kernel implementation and host-side setup code
 */

#include <fcntl.h>
#include <sys/stat.h>
#include <unistd.h>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <filesystem>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>

#include "acl/acl.h"
#include "kernel_basic_intf.h"
#include "tiling/platform/platform_ascendc.h"
#include "include/tensor_api/tensor.h"

// printf-style error logger: writes "[ERROR]  " + formatted message + newline to stdout.
#define ERROR_LOG(fmt, ...) fprintf(stdout, "[ERROR]  " fmt "\n", ##__VA_ARGS__)
namespace AscendC::Te {
// Layout definitions for matrices A and B (NZ format by default, can be transposed to ZN)
static constexpr bool transA = false;
static constexpr bool transB = false;
template <typename T>
using MakeLayoutAL1 =
    AscendC::Std::conditional_t<transA, AscendC::Te::FrameLayoutFormat<AscendC::Te::ZNLayoutPtn, AscendC::Te::LayoutTraitDefault<T>>, AscendC::Te::FrameLayoutFormat<AscendC::Te::NZLayoutPtn, AscendC::Te::LayoutTraitDefault<T>>>;
template <typename T>
using MakeLayoutBL1 =
    AscendC::Std::conditional_t<transB, AscendC::Te::FrameLayoutFormat<AscendC::Te::ZNLayoutPtn, AscendC::Te::LayoutTraitDefault<T>>, AscendC::Te::FrameLayoutFormat<AscendC::Te::NZLayoutPtn, AscendC::Te::LayoutTraitDefault<T>>>;
} // namespace AscendC::Te

namespace tool {
// ----- On-chip buffer geometry -----
static constexpr int32_t L0C_C0 = 16;                                    // C0 width used by the L0C NZ layout
static constexpr uint64_t NUM_TWO = 2;                                   // Halving factor
static constexpr uint64_t DOUBLE_BUFFER_COUNT = 2;                       // Ping-pong buffer count
static constexpr int64_t L0A_SIZE = 64 * 1024;                           // L0A buffer size (64KB)
static constexpr int64_t TOTAL_L0C_SIZE = 256 * 1024;                    // Total L0C buffer size (256KB)
static constexpr uint64_t HALF_L0_SIZE = L0A_SIZE / DOUBLE_BUFFER_COUNT; // One ping-pong half of L0A

// ----- Workspace sizing -----
static constexpr uint64_t BASIC_BLOCK_SIZE_256 = 256UL;
static constexpr uint64_t DATA_SIZE_FP32 = 4UL;   // Bytes per fp32 element
static constexpr uint64_t RPC_WORKSIZE = 20UL;    // Extra workspace, counted in MB_SIZE units
static constexpr uint64_t MB_SIZE = 1024 * 1024UL;
static constexpr uint16_t BLOCK_BASE_M = 256;     // A workspace slot holds BLOCK_BASE_M * BLOCK_BASE_N elements
static constexpr uint16_t BLOCK_BASE_N = 256;

// ----- Alignment granules -----
static constexpr uint64_t BLOCK_BYTE_SIZE = 32UL;
static constexpr uint64_t BASIC_BLOCK_SIZE_16 = 16UL;

// ----- Cube -> vector cross-core synchronization -----
static constexpr uint16_t AIC_SYNC_AIV_FLAG = 8;
static constexpr uint16_t FLAG_ID_MAX = 16;
static constexpr uint16_t AIC_SYNC_AIV_MODE_4 = 4;

// ----- Event IDs of the ping (0) and pong (1) buffers -----
static constexpr uint16_t ZERO_FLAG = 0;
static constexpr uint16_t FIRST_FLAG = 1;

// ----- Unit-flag values handed to Mmad / Fixpipe -----
static constexpr uint32_t NO_FINAL_ACCUMULATION = 2; // Intermediate accumulation step
static constexpr uint32_t FINAL_ACCUMULATION = 3;    // Last accumulation step of an output tile

// ----- Forward declarations (the definitions live outside this section) -----
// Host-side data helpers.
template <typename T>
void FillRandomData(std::vector<T>& data, T min, T max);
template <typename T>
void ComputeGolden(
    int m, int k, int n, std::vector<T>& hostInput, std::vector<T>& hostWeight, std::vector<T>& goldenOutput);
template <typename T>
std::vector<uint64_t> Compare(std::vector<T>& hostOutput, std::vector<T>& goldenOutput);
float Bf16ToFloat(uint16_t h);
uint16_t FloatToBf16(float f);
inline bool ReadFile(const std::string& filePath, size_t& fileSize, void* buffer, size_t bufferSize);
inline bool WriteFile(const std::string& filePath, const void* buffer, size_t size);

// Host-side scheduling predicates.
bool CheckIsSk(int m, int n, int k, int numBlocks);
bool isSKScene(int m, int n, int blockNum);

// Device-side arithmetic and scheduling helpers (CeilDiv: ceiling division).
__aicore__ inline uint64_t CeilDiv(uint64_t a, uint64_t b);
__aicore__ inline uint64_t CeilAlign(uint64_t a, uint64_t b);
__aicore__ inline bool CheckIsSkScene(uint32_t tileIdx, uint32_t blockNum, uint32_t tileNum);
} // namespace tool

namespace matmul {
/**
 * @brief Matrix multiplication kernel (C = A * B) for the Ascend AI processor
 *
 * Cube (AIC) side:
 * - C is tiled into baseM x baseN (256 x 256) output tiles distributed over the cores
 * - Double buffering between GM -> L1 and L1 -> L0
 * - The tiles of the last, incomplete round of cores may additionally be split along K into
 *   skKTileNum slices; in that case the fp32 partial sums are written to workspaceGm instead
 *   of C (see the tool::CheckIsSkScene() branches).
 * Vector (AIV) side:
 * - After the cube side signals it via CrossCoreSetFlag, sums the K-slice partials read from
 *   workspaceGm, casts them to bfloat16_t and writes them into C.
 *
 * @tparam T Input element type. The host code in this file instantiates bfloat16_t; the AIV
 *           reduction path always writes C as bfloat16_t.
 * @param aGm Global memory pointer to matrix A (m x k, ND layout)
 * @param bGm Global memory pointer to matrix B (k x n, ND layout)
 * @param cGm Global memory pointer to output matrix C (m x n, ND layout)
 * @param workspaceGm Global memory scratch for fp32 K-slice partials; one slot holds
 *                    BLOCK_BASE_M * BLOCK_BASE_N floats
 * @param m Rows of A and C
 * @param k Columns of A, rows of B
 * @param n Columns of B and C
 */
template <typename T>
__global__ __aicore__ __mix__(1, 2) void MatmulKernel(
    GM_ADDR aGm, GM_ADDR bGm, GM_ADDR cGm, GM_ADDR workspaceGm, uint32_t m, uint32_t k, uint32_t n)
{
    // Initialize tiling parameters for memory hierarchy
    uint64_t baseM = 256;
    uint64_t baseN = 256;
    uint64_t baseK = 128 / sizeof(T); // K extent of one L0 tile (128 bytes of T)
    uint64_t kL1 = 512 / sizeof(T);   // K extent of one L1 tile (512 bytes of T)
    uint64_t mTileNum = tool::CeilDiv(m, baseM);
    uint64_t nTileNum = tool::CeilDiv(n, baseN);
    uint64_t tileNum = mTileNum * nTileNum;
    uint64_t kL1TileNum = tool::CeilDiv(k, kL1);
    uint64_t tailBaseM = m - (mTileNum - 1) * baseM;
    uint64_t tailBaseN = n - (nTileNum - 1) * baseN;
    uint64_t l0cOffset = 0; // L0C buffer offset
    uint64_t kTileIdx = 0;
    uint64_t mnIdxInCurLoop = 1;

    uint64_t curBlockIdx = AscendC::GetBlockIdx();
    uint64_t blockNum = AscendC::GetBlockNum();

    // K split of the (M,N) tiles in the last, incomplete round of cores.
    // skKTileNum: number of K slices per such tile; skKSingleCore: K extent of one slice.
    uint64_t skKTileNum = 1;
    uint64_t skKSingleCore = 1;

    if (tileNum <= blockNum / tool::NUM_TWO) {
        // Few tiles: spread all cores over them.
        skKTileNum = blockNum / tileNum;
    } else if (tileNum % blockNum != 0) {
        // Spread the cores over the tail round only. When tileNum is an exact multiple of
        // blockNum there is no tail round and skKTileNum stays 1; without this guard the
        // division below divides by zero (e.g. tileNum == blockNum, or blockNum == 1).
        skKTileNum = blockNum / (tileNum % blockNum);
    }
    skKSingleCore = tool::CeilDiv(k, skKTileNum);
    // Re-derive the slice count from the slice length so that every slice starts inside [0, k).
    // Otherwise a small k leaves trailing slices with a negative length
    // (e.g. k = 100 split 24 ways: slices of 5, the 24th would start at 115).
    skKTileNum = tool::CeilDiv(k, skKSingleCore);

    // Tiles in the tail round are replaced by skKTileNum K-slices each.
    int64_t tailMNTileNum = tileNum < blockNum ? tileNum : tileNum % blockNum;
    uint64_t totalMNTileNumInDP = tileNum - tailMNTileNum;
    tileNum = totalMNTileNumInDP + tailMNTileNum * skKTileNum;
    int64_t tailSKTotalTileNum = tailMNTileNum * skKTileNum;
    uint64_t usedCoreNum = tileNum < blockNum ? tileNum : blockNum;

    // Double buffering indices
    uint64_t l0PingPong = 0;             // L0 buffer ping-pong index
    uint64_t l1PingPong = 0;             // L1 buffer ping-pong index
    uint64_t l1BufferAOffset[2] = {0UL}; // L1 buffer offsets for matrix A (ping/pong)
    uint64_t l1BufferBOffset[2] = {0UL}; // L1 buffer offsets for matrix B (ping/pong)

    struct CopyGm2UbParams {
        uint64_t offsetWorkspaceGM = 0;
        uint64_t kCnt = 0;
        uint64_t mBurstOri = 0;
        uint64_t mBurst = 0;
        uint64_t burstLen = 0;
        uint64_t srcGap = 0;
    };
    CopyGm2UbParams copyGm2UbParams_;

    struct CopyUb2GmParams {
        uint64_t offsetCGm = 0;
        uint64_t mLength = 0;
        uint64_t burstLen = 0;
        uint64_t dstGap = 0;
        uint64_t srcGap = 0;
    };
    CopyUb2GmParams copyUb2GmParams_;

    __gm__ float* workspaceGmAddr = reinterpret_cast<__gm__ float*>(workspaceGm);

    AscendC::GlobalTensor<float> workspaceGlobal_;
    workspaceGlobal_.SetGlobalBuffer(reinterpret_cast<__gm__ float*>(workspaceGm));

    AscendC::GlobalTensor<bfloat16_t> cGlobal_;
    cGlobal_.SetGlobalBuffer(reinterpret_cast<__gm__ bfloat16_t*>(cGm));

    if ASCEND_IS_AIC {
        // Cores beyond the tile count have no work: release the paired vector cores and exit.
        // This check precedes the SetFlag calls below so that an idle core never leaves an
        // event flag set without its matching WaitFlag.
        if (curBlockIdx >= usedCoreNum) {
            AscendC::CrossCoreSetFlag<tool::AIC_SYNC_AIV_MODE_4, PIPE_FIX>(tool::AIC_SYNC_AIV_FLAG);
            AscendC::CrossCoreSetFlag<tool::AIC_SYNC_AIV_MODE_4, PIPE_FIX>(tool::AIC_SYNC_AIV_FLAG + tool::FLAG_ID_MAX);
            return;
        }

        auto layoutA = AscendC::Te::MakeFrameLayout<AscendC::Te::NDExtLayoutPtn>(m, k);
        auto layoutB = AscendC::Te::MakeFrameLayout<AscendC::Te::NDExtLayoutPtn>(k, n);
        auto layoutC = AscendC::Te::MakeFrameLayout<AscendC::Te::NDExtLayoutPtn>(m, n);

        auto tensorAgm = AscendC::Te::MakeTensor(AscendC::Te::MakeMemPtr<AscendC::Te::Location::GM>(reinterpret_cast<__gm__ T*>(aGm)), layoutA);
        auto tensorBgm = AscendC::Te::MakeTensor(AscendC::Te::MakeMemPtr<AscendC::Te::Location::GM>(reinterpret_cast<__gm__ T*>(bGm)), layoutB);
        auto tensorCgm = AscendC::Te::MakeTensor(AscendC::Te::MakeMemPtr<AscendC::Te::Location::GM>(reinterpret_cast<__gm__ T*>(cGm)), layoutC);

        // Initialize hardware event flags for synchronization (one per ping/pong buffer);
        // each is consumed by the matching WaitFlag after the tile loop.
        AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(tool::ZERO_FLAG);  // MTE1->MTE2 sync
        AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(tool::FIRST_FLAG); // Second sync flag
        AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(tool::ZERO_FLAG);     // M->MTE1 sync
        AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(tool::FIRST_FLAG);    // Second M sync flag

        AscendC::SetMMLayoutTransform(true);

        for (uint64_t tileIdx = curBlockIdx; tileIdx < tileNum; tileIdx += blockNum) {
            // Decompose the core-level tile index into K-tile and (M,N)-tile components
            // skKTileNum: number of K-dimension tiles
            int64_t tmpTileIdx = tileIdx;
            if (!tool::CheckIsSkScene(0, blockNum, tileNum)) { // SK Preload in DP+SK
                // Swap this core's last two tiles so a K-split tile is computed one round early.
                if (tileIdx % usedCoreNum < tailSKTotalTileNum &&
                    (tool::CeilDiv(tileIdx + 1, usedCoreNum) == (tool::CeilDiv(tileNum, usedCoreNum) - 1))) {
                    tmpTileIdx = tileIdx + usedCoreNum;
                } else if (
                    tileIdx % usedCoreNum < tailSKTotalTileNum &&
                    (tool::CeilDiv(tileIdx + 1, usedCoreNum) == tool::CeilDiv(tileNum, usedCoreNum))) {
                    tmpTileIdx = tileIdx - usedCoreNum;
                }
            }

            uint64_t curKTileNum = tool::CheckIsSkScene(tmpTileIdx, blockNum, tileNum) ? skKTileNum : 1;
            if (tool::CheckIsSkScene(tmpTileIdx, blockNum, tileNum)) { // SK scene
                kTileIdx = (tmpTileIdx % usedCoreNum) % curKTileNum;
                mnIdxInCurLoop = (tmpTileIdx % usedCoreNum) / curKTileNum + totalMNTileNumInDP;
            } else { // DP scene
                kTileIdx = 0;
                mnIdxInCurLoop = tmpTileIdx / curKTileNum;
            }

            // Further decompose (M,N) tile index into M and N dimensions
            uint64_t mTileIdx = mnIdxInCurLoop / nTileNum;
            uint64_t nTileIdx = mnIdxInCurLoop % nTileNum;

            // Get actual sizes (handle tail tiles which may be smaller than base size)
            int64_t curM = mTileIdx == (mTileNum - 1) ? tailBaseM : baseM;
            int64_t curN = nTileIdx == (nTileNum - 1) ? tailBaseN : baseN;
            int64_t curSK = kTileIdx == (curKTileNum - 1) ? k - (curKTileNum - 1) * skKSingleCore : skKSingleCore;

            // Calculate offset into workspace GM memory for this core's output tile
            // Layout: each (M,N) tile contains skKTileNum contiguous K-tile results
            int64_t offsetWorkspace = (((tmpTileIdx % blockNum) / curKTileNum) * curKTileNum + kTileIdx) * tool::BLOCK_BASE_M * tool::BLOCK_BASE_N;

            // Define layout for the workspace tensor (2D matrix of size curM x curN)
            auto layoutWorkspace = AscendC::Te::MakeFrameLayout<AscendC::Te::NDExtLayoutPtn>(curM, curN);

            // Create tensor representing this core's output region in GM workspace
            auto gmWorkSpace =
                AscendC::Te::MakeTensor(AscendC::Te::MakeMemPtr<AscendC::Te::Location::GM>(workspaceGmAddr + offsetWorkspace), layoutWorkspace);

            auto tensorAGmBlock = tensorAgm.Slice(
                AscendC::Te::MakeCoord(mTileIdx * baseM, kTileIdx * skKSingleCore),
                AscendC::Te::MakeShape(curM, curSK));
            auto tensorBGmBlock = tensorBgm.Slice(
                AscendC::Te::MakeCoord(kTileIdx * skKSingleCore, nTileIdx * baseN),
                AscendC::Te::MakeShape(curSK, curN));
            auto tensorCGmBlock = tensorCgm.Slice(
                AscendC::Te::MakeCoord(mTileIdx * baseM, nTileIdx * baseN), AscendC::Te::MakeShape(curM, curN));

            auto layoutL0C = AscendC::Te::MakeFrameLayout<AscendC::Te::NZLayoutPtn, AscendC::Std::Int<tool::L0C_C0>>(curM, curN);
            auto tensorL0C = AscendC::Te::MakeTensor(AscendC::Te::MakeMemPtr<AscendC::Te::Location::L0C, float>(l0cOffset), layoutL0C);

            kL1TileNum = tool::CeilDiv(curSK, kL1);

            for (uint64_t iter0 = 0; iter0 < kL1TileNum; ++iter0) {
                uint64_t l1BufId = l1PingPong & 1;                         // Current L1 buffer ID
                AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1BufId); // Wait for previous transfer

                auto curGmBKL1 = (iter0 + 1 == kL1TileNum) ? (curSK - iter0 * kL1) : kL1;
                auto curGmAKL1 = curGmBKL1;

                // Copy GM to L1 buffers with double buffering
                uint64_t AOffsetL1 = baseM * kL1 * sizeof(T);
                uint64_t BOffsetL1 = baseN * kL1 * sizeof(T);
                l1BufferAOffset[l1BufId] = l1BufId * AOffsetL1;
                l1BufferBOffset[l1BufId] = tool::DOUBLE_BUFFER_COUNT * AOffsetL1 + l1BufId * BOffsetL1;

                auto copyGM2L1 = AscendC::Te::MakeCopy(AscendC::Te::CopyGM2L1{});
                auto layoutAL1 = AscendC::Te::MakeLayoutAL1<T>{}(curM, curGmAKL1);
                auto tensorAL1 =
                    AscendC::Te::MakeTensor(AscendC::Te::MakeMemPtr<AscendC::Te::Location::L1, T>(l1BufferAOffset[l1BufId]), layoutAL1);
                auto tensorAGmTile =
                    tensorAGmBlock.Slice(AscendC::Te::MakeCoord(0, iter0 * kL1), AscendC::Te::MakeShape(curM, curGmAKL1));
                AscendC::Te::Copy(copyGM2L1, tensorAL1, tensorAGmTile);

                auto layoutBL1 = AscendC::Te::MakeLayoutBL1<T>{}(curGmBKL1, curN);
                auto tensorBL1 =
                    AscendC::Te::MakeTensor(AscendC::Te::MakeMemPtr<AscendC::Te::Location::L1, T>(l1BufferBOffset[l1BufId]), layoutBL1);
                auto tensorBGmTile =
                    tensorBGmBlock.Slice(AscendC::Te::MakeCoord(iter0 * kL1, 0), AscendC::Te::MakeShape(curGmBKL1, curN));
                AscendC::Te::Copy(copyGM2L1, tensorBL1, tensorBGmTile);

                AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE1>(l1BufId);
                AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE1>(l1BufId);

                uint64_t kL0IterNum = tool::CeilDiv(curGmBKL1, baseK);
                uint64_t tailKL0 = curGmBKL1 - (kL0IterNum - 1) * baseK;

                for (uint16_t iter1 = 0; iter1 < kL0IterNum; ++iter1) {
                    uint64_t l0BufId = l0PingPong & 1;                // Current L0 buffer ID
                    uint64_t l0Offset = tool::HALF_L0_SIZE * l0BufId; // Offset in L0
                    uint64_t curKL0 = (iter1 + 1 == kL0IterNum) ? tailKL0 : baseK;

                    AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(l0BufId);

                    auto copyL12L0A = AscendC::Te::MakeCopy(AscendC::Te::CopyL12L0A{});
                    auto copyL12L0B = AscendC::Te::MakeCopy(AscendC::Te::CopyL12L0B{});
                    auto layoutAL0 = AscendC::Te::MakeFrameLayout<AscendC::Te::NZLayoutPtn, AscendC::Te::LayoutTraitDefault<T>>(curM, curKL0);
                    auto tensorAL0 = AscendC::Te::MakeTensor(AscendC::Te::MakeMemPtr<AscendC::Te::Location::L0A, T>(l0Offset), layoutAL0);
                    auto tensorAL1Tile =
                        tensorAL1.Slice(AscendC::Te::MakeCoord(0, iter1 * baseK), AscendC::Te::MakeShape(curM, curKL0));
                    AscendC::Te::Copy(copyL12L0A, tensorAL0, tensorAL1Tile);

                    auto layoutBL0 = AscendC::Te::MakeFrameLayout<AscendC::Te::ZNLayoutPtn, AscendC::Te::LayoutTraitDefault<T>>(curKL0, curN);
                    auto tensorBL0 = AscendC::Te::MakeTensor(AscendC::Te::MakeMemPtr<AscendC::Te::Location::L0B, T>(l0Offset), layoutBL0);
                    auto tensorBL1Tile =
                        tensorBL1.Slice(AscendC::Te::MakeCoord(iter1 * baseK, 0), AscendC::Te::MakeShape(curKL0, curN));
                    AscendC::Te::Copy(copyL12L0B, tensorBL0, tensorBL1Tile);

                    AscendC::SetFlag<AscendC::HardEvent::MTE1_M>(l0BufId);
                    AscendC::WaitFlag<AscendC::HardEvent::MTE1_M>(l0BufId);

                    AscendC::MmadParams para;
                    // The first K step of the tile initializes L0C; later steps accumulate into it.
                    para.cmatrixInitVal = (iter1 == 0 && iter0 == 0);
                    para.m = curM;
                    para.n = curN;
                    para.k = curKL0;

                    if (iter1 == (kL0IterNum - 1) && iter0 == (kL1TileNum - 1)) {
                        para.unitFlag = tool::FINAL_ACCUMULATION;
                    } else {
                        para.unitFlag = tool::NO_FINAL_ACCUMULATION;
                    }

                    auto MadOp = AscendC::Te::MakeMmad(AscendC::Te::MmadOperation{}, AscendC::Te::MmadTraitDefault{});
                    AscendC::Te::Mmad(MadOp, tensorL0C, tensorAL0, tensorBL0, para);

                    AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(l0BufId);
                    l0PingPong++; // Toggle L0 buffer
                }
                if (iter0 + 1 == kL1TileNum) {
                    // K-split partials go to the fp32 workspace for the vector-side reduction;
                    // all other tiles are written straight into C.
                    auto CopyL0C2GM = AscendC::Te::MakeCopy(AscendC::Te::CopyL0C2GM{});
                    if (tool::CheckIsSkScene(tmpTileIdx, blockNum, tileNum) && tailMNTileNum <= (blockNum / tool::NUM_TWO)) {
                        AscendC::Te::Copy(
                            CopyL0C2GM, gmWorkSpace, tensorL0C, AscendC::Te::FixpipeParams(tool::FINAL_ACCUMULATION));
                    } else {
                        AscendC::Te::Copy(CopyL0C2GM, tensorCGmBlock, tensorL0C, AscendC::Te::FixpipeParams(tool::FINAL_ACCUMULATION));
                    }
                }
                AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1BufId);
                l1PingPong++; // Toggle L1 buffer
            }
            // Notify the paired vector cores once this core's final-round tile is done.
            if (tmpTileIdx + usedCoreNum >= tileNum) {
                AscendC::CrossCoreSetFlag<tool::AIC_SYNC_AIV_MODE_4, PIPE_FIX>(tool::AIC_SYNC_AIV_FLAG);
                AscendC::CrossCoreSetFlag<tool::AIC_SYNC_AIV_MODE_4, PIPE_FIX>(tool::AIC_SYNC_AIV_FLAG + tool::FLAG_ID_MAX);
            }
        }
        // Final synchronization waits (pair the initial SetFlag calls)
        AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(tool::ZERO_FLAG);
        AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(tool::ZERO_FLAG);
        AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(tool::FIRST_FLAG);
        AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(tool::FIRST_FLAG);
        AscendC::SetMMLayoutTransform(false);
    }

    if ASCEND_IS_AIV {
        // Number of (M,N) tiles assigned to the last round of cores (remainder distribution)
        uint64_t lastLoopTotalCnt = (mTileNum * nTileNum % usedCoreNum) * skKTileNum;

        // Idle cores exit early after synchronization
        uint64_t curBlockIdxInAiv = AscendC::GetBlockIdx();
        if (curBlockIdxInAiv >= lastLoopTotalCnt * AscendC::GetTaskRation()) {
            AscendC::CrossCoreWaitFlag<tool::AIC_SYNC_AIV_MODE_4, PIPE_MTE3>(tool::AIC_SYNC_AIV_FLAG);
            AscendC::SyncAll();
            return;
        }

        // Active cores synchronize before proceeding
        AscendC::CrossCoreWaitFlag<tool::AIC_SYNC_AIV_MODE_4, PIPE_MTE3>(tool::AIC_SYNC_AIV_FLAG);
        AscendC::SyncAll();

        // Decompose balanced workload index back to (M,N) tile and K-tile
        uint64_t newBlockIdx = AscendC::GetBlockIdx() / (AscendC::GetTaskRation() * skKTileNum);
        uint64_t cGmIndex = newBlockIdx + (mTileNum * nTileNum - (mTileNum * nTileNum) % blockNum);
        uint64_t kTileIdx = AscendC::GetBlockIdx() % (AscendC::GetTaskRation() * skKTileNum);
        uint64_t mTileIdx = cGmIndex / nTileNum;
        uint64_t nTileIdx = cGmIndex % nTileNum;

        uint64_t curM = mTileIdx != (mTileNum - 1) ? baseM : (m - (mTileNum - 1) * baseM);
        uint64_t curN = nTileIdx != (nTileNum - 1) ? baseN : (n - (nTileNum - 1) * baseN);

        if (!tool::CheckIsSkScene(0, blockNum, tileNum)) {
            AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(tool::ZERO_FLAG);
            AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(tool::ZERO_FLAG);
        }

        uint64_t aivMte2Num = tool::CheckIsSkScene(0, blockNum, tileNum) ? AscendC::GetTaskRation() : AscendC::BLOCK_CUBE;

        for (uint64_t index = 0; index < aivMte2Num; ++index) {
            // Rows of the output tile handled by this vector core (kTileIdx selects the row chunk).
            uint64_t mBurstBase =
                tool::CeilAlign(tool::CeilDiv(curM, skKTileNum * AscendC::GetTaskRation()), tool::CeilDiv(tool::BLOCK_BYTE_SIZE, curN));
            uint64_t mBurstCnt = tool::CeilDiv(curM, mBurstBase);
            uint64_t mBurstTail = curM - (mBurstCnt - 1) * mBurstBase;
            if (kTileIdx >= mBurstCnt) {
                copyGm2UbParams_.mBurstOri = 0;
            } else {
                copyGm2UbParams_.mBurstOri = (kTileIdx == mBurstCnt - 1) ? mBurstTail : mBurstBase;
            }

            copyGm2UbParams_.kCnt = skKTileNum;
            copyGm2UbParams_.mBurst = tool::CeilDiv(copyGm2UbParams_.mBurstOri, aivMte2Num);
            // Calculate init address of workspace for moving into UB.
            copyGm2UbParams_.offsetWorkspaceGM = newBlockIdx * skKTileNum * tool::BLOCK_BASE_M * tool::BLOCK_BASE_N +
                                                 (kTileIdx * mBurstBase + copyGm2UbParams_.mBurst * index) * curN;
            // Calculate init address of GM for moving out to GM.
            copyUb2GmParams_.offsetCGm = nTileIdx * baseN + mTileIdx * baseM * n +
                                         (kTileIdx * mBurstBase + copyGm2UbParams_.mBurst * index) * n;
            // With singleCnt == 1 only index 0 carries rows; later indices get mBurst == 0 and
            // leave through the mBurst == 0 check below.
            uint64_t singleCnt = 1;
            if (index == singleCnt - 1) {
                copyGm2UbParams_.mBurst = copyGm2UbParams_.mBurstOri - (singleCnt - 1) * copyGm2UbParams_.mBurst;
            } else if (index >= singleCnt) {
                copyGm2UbParams_.mBurst = 0;
            }
            // datasize for moving in ub, align to 32B
            copyGm2UbParams_.burstLen = tool::CeilAlign(copyGm2UbParams_.mBurst * curN, tool::BASIC_BLOCK_SIZE_16);
            // gap of src between cur burst and next burst
            copyGm2UbParams_.srcGap = tool::BLOCK_BASE_M * tool::BLOCK_BASE_N - copyGm2UbParams_.burstLen;

            // args for ub2gm
            copyUb2GmParams_.mLength = copyGm2UbParams_.mBurst;
            copyUb2GmParams_.burstLen = curN;
            copyUb2GmParams_.dstGap = n - curN;
            copyUb2GmParams_.srcGap = 0;

            AscendC::LocalTensor<float> ubAddTensor{AscendC::TPosition::VECIN, 0, AscendC::TOTAL_UB_SIZE};
            AscendC::DataCopyExtParams dataCopyExtParams{
                static_cast<uint16_t>(copyGm2UbParams_.kCnt), static_cast<uint32_t>(copyGm2UbParams_.burstLen * sizeof(float)),
                static_cast<uint32_t>(copyGm2UbParams_.srcGap * sizeof(float)), 0, 0};
            if (copyGm2UbParams_.mBurst == 0) {
                return;
            }
            // Load the kCnt K-slice partials back to back into UB.
            AscendC::DataCopyPad<float>(
                ubAddTensor, workspaceGlobal_[copyGm2UbParams_.offsetWorkspaceGM], dataCopyExtParams, {false, 0, 0, 0});
            AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(tool::ZERO_FLAG);
            AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(tool::ZERO_FLAG);

            // Reduce all partials into the first slice.
            for (uint64_t i = 1; i < skKTileNum; ++i) {
                Add(ubAddTensor, ubAddTensor, ubAddTensor[i * copyGm2UbParams_.burstLen], copyGm2UbParams_.burstLen);
            }

            AscendC::DataCopyExtParams ub2gmExtParams{
                static_cast<uint16_t>(copyUb2GmParams_.mLength),
                static_cast<uint32_t>(copyUb2GmParams_.burstLen * sizeof(bfloat16_t)),
                static_cast<uint32_t>(copyUb2GmParams_.srcGap * sizeof(bfloat16_t) / tool::BLOCK_BYTE_SIZE),
                static_cast<uint32_t>(copyUb2GmParams_.dstGap * sizeof(bfloat16_t)), 0};

            // The bf16 result is cast in place over the fp32 sums (both tensors start at UB offset 0).
            AscendC::LocalTensor<bfloat16_t> ubCastDst{AscendC::TPosition::VECIN, 0, AscendC::TOTAL_UB_SIZE};
            Cast(ubCastDst, ubAddTensor, AscendC::RoundMode::CAST_RINT, copyGm2UbParams_.burstLen);
            AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(tool::ZERO_FLAG);
            AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(tool::ZERO_FLAG);
            AscendC::DataCopyPad<bfloat16_t, AscendC::PaddingMode::Compact>(cGlobal_[copyUb2GmParams_.offsetCGm], ubCastDst, ub2gmExtParams);
        }
    }
}

} // namespace matmul

// CHECK_COND(cond, message, return_expr): when `cond` is false, print "ERROR: <message>"
// to std::cerr and then run `return_expr` (usually a return statement). `message` is streamed
// with operator<<, so it may be a chain such as "bad value " << v.
#define CHECK_COND(cond, message, return_expr)              \
    do {                                                    \
        if (cond) {                                         \
            break;                                          \
        }                                                   \
        std::cerr << "ERROR: " << message << std::endl;     \
        return_expr;                                        \
    } while (0)

// Write the command-line synopsis, one line per argument, and an example invocation to stderr.
void printUsage(const std::string& programName)
{
    std::ostream& out = std::cerr;
    const char* const argumentHelp[] = {
        "  m: row of matrix A",
        "  k: col of matrix A",
        "  n: col of matrix B",
    };

    out << "Usage: " << programName << " m k n" << std::endl;
    out << "Args: " << std::endl;
    for (const char* line : argumentHelp) {
        out << line << std::endl;
    }
    out << "Example: " << programName << " 100 50 200" << std::endl;
}

// Parse one command-line token as a base-10 int.
// The whole token must be consumed, so "12abc" is rejected. The value must also fit in int.
// Every failure is reported as std::invalid_argument with a readable message; the previous
// code let std::out_of_range escape with the unhelpful what() text "stoi".
static int parseIntegerArgument(const std::string& token)
{
    size_t consumed = 0;
    int value = 0;
    try {
        value = std::stoi(token, &consumed);
    } catch (const std::invalid_argument&) {
        throw std::invalid_argument("ERROR: m k n must be Integer");
    } catch (const std::out_of_range&) {
        throw std::invalid_argument("ERROR: m k n is out of range");
    }
    if (consumed != token.size()) {
        throw std::invalid_argument("ERROR: m k n must be Integer");
    }
    return value;
}

// Parse and validate the m, k, n command-line arguments.
// "--help" or "-h" as the first argument prints usage and exits with status 1.
// Throws std::invalid_argument when an argument is missing, is not an integer, is out of
// int range, or is not positive.
void parseArguments(int argc, char* argv[], int& m, int& k, int& n)
{
    if (argc >= 2 && (std::string(argv[1]) == "--help" || std::string(argv[1]) == "-h")) {
        printUsage(argv[0]);
        exit(1);
    }
    if (argc < 4) {
        throw std::invalid_argument("ERROR: Lacks Arguments");
    }
    m = parseIntegerArgument(argv[1]);
    k = parseIntegerArgument(argv[2]);
    n = parseIntegerArgument(argv[3]);

    if (m <= 0 || k <= 0 || n <= 0) {
        throw std::invalid_argument("ERROR: m k n must be positive");
    }
}

/**
 * @brief Main function - host-side setup and execution
 *
 * This function:
 * 1. Parses command line arguments
 * 2. Initializes Ascend Computing Language (ACL) resources
 * 3. Allocates and initializes host/device memory
 * 4. Launches the kernel
 * 5. Verifies results against CPU reference
 * 6. Cleans up resources
 */
int main(int argc, char* argv[])
{
    using namespace tool;
    int m, k, n;
    try {
        parseArguments(argc, argv, m, k, n);
    } catch (const std::exception& e) {
        std::cerr << e.what() << std::endl;
        printUsage(argv[0]);
        return 1;
    }

    // Initialize ACL (Ascend Computing Language) resources
    int32_t deviceId = 0;
    aclrtStream stream;
    auto ret = aclInit(nullptr);
    CHECK_COND(ret == ACL_SUCCESS, "aclInit failed.", return 1);
    ret = aclrtSetDevice(deviceId);
    CHECK_COND(ret == ACL_SUCCESS, "aclrtSetDevice failed.", return 1);
    ret = aclrtCreateStream(&stream);
    CHECK_COND(ret == ACL_SUCCESS, "aclrtCreateStream failed.", return 1);

    std::vector<uint16_t> hostInput(m * k, 0);
    std::vector<uint16_t> hostWeight(k * n, 0);
    std::vector<uint16_t> hostOutput(m * n, 0);
    std::vector<uint16_t> goldenOutput(m * n, 0);

    auto sizeInput = hostInput.size() * sizeof(uint16_t);
    auto sizeWeight = hostWeight.size() * sizeof(uint16_t);
    auto sizeOutput = hostOutput.size() * sizeof(uint16_t);
    
    std::string cmd = "python3 gen_data.py " + std::to_string(m) + " " + std::to_string(k) + " " + std::to_string(n);
    system(cmd.c_str());

    std::string baseDir = std::filesystem::current_path();
    std::string inputDir = baseDir + "/input";
    std::string outputDir = baseDir + "/output";
    ReadFile(inputDir + "/input_a.bin", sizeInput, hostInput.data(), sizeInput);
    ReadFile(inputDir + "/input_b.bin", sizeWeight, hostWeight.data(), sizeWeight);

    auto ascendcPlatform = platform_ascendc::PlatformAscendCManager::GetInstance();
    CHECK_COND(ascendcPlatform != nullptr, "get ascendcPlatform failed.", return 1);
    uint32_t numBlocks = ascendcPlatform->GetCoreNumAic(); // Number of AI cores

    uint8_t *deviceInput = nullptr;
    uint8_t *deviceWeight = nullptr;
    uint8_t *deviceOutput = nullptr;
    uint8_t *dWorkSpace = nullptr;
    size_t sizeWorkSpace = numBlocks * tool::BASIC_BLOCK_SIZE_256 * tool::BASIC_BLOCK_SIZE_256 * tool::DATA_SIZE_FP32 +
                           tool::RPC_WORKSIZE * tool::MB_SIZE;
    std::unique_ptr<void, aclError (*)(void*)> deviceInputPtr(deviceInput, aclrtFree);
    ret = aclrtMalloc((void**)&deviceInput, sizeInput, ACL_MEM_MALLOC_HUGE_FIRST);
    CHECK_COND(ret == ACL_SUCCESS, "aclrtMalloc deviceInput failed.", return 1);

    std::unique_ptr<void, aclError (*)(void*)> deviceWeightPtr(deviceWeight, aclrtFree);
    ret = aclrtMalloc((void**)&deviceWeight, sizeWeight, ACL_MEM_MALLOC_HUGE_FIRST);
    CHECK_COND(ret == ACL_SUCCESS, "aclrtMalloc deviceWeight failed.", return 1);

    std::unique_ptr<void, aclError (*)(void*)> deviceOutputPtr(deviceOutput, aclrtFree);
    ret = aclrtMalloc((void**)&deviceOutput, sizeOutput, ACL_MEM_MALLOC_HUGE_FIRST);
    CHECK_COND(ret == ACL_SUCCESS, "aclrtMalloc deviceOutput failed.", return 1);

    ret = aclrtMemcpy(deviceInput, sizeInput, hostInput.data(), sizeInput, ACL_MEMCPY_HOST_TO_DEVICE);
    CHECK_COND(ret == ACL_SUCCESS, "aclrtMemcpy deviceInput failed.", return 1);
    ret = aclrtMemcpy(deviceWeight, sizeWeight, hostWeight.data(), sizeWeight, ACL_MEMCPY_HOST_TO_DEVICE);
    CHECK_COND(ret == ACL_SUCCESS, "aclrtMemcpy deviceWeight failed.", return 1);

    std::unique_ptr<void, aclError (*)(void*)> deviceWorkSpace(dWorkSpace, aclrtFree);
    CHECK_COND(
        aclrtMalloc((void**)&dWorkSpace, sizeWorkSpace, ACL_MEM_MALLOC_HUGE_ONLY) == ACL_SUCCESS,
        "Failed to allocate the device buffer for WorkSpace.", return 1);

    matmul::MatmulKernel<bfloat16_t><<<numBlocks, nullptr, stream>>>(deviceInput, deviceWeight, deviceOutput, dWorkSpace, m, k, n);

    ret = aclrtSynchronizeStream(stream);
    CHECK_COND(ret == ACL_SUCCESS, "aclrtSynchronizeStream failed.", return 1);

    ret = aclrtMemcpy(hostOutput.data(), sizeOutput, deviceOutput, sizeOutput, ACL_MEMCPY_DEVICE_TO_HOST);
    CHECK_COND(ret == ACL_SUCCESS, "aclrtMemcpy deviceOutput failed.", return 1);

    auto toFloat = [](const std::vector<uint16_t>& src, std::vector<float>& dst) {
        std::transform(src.begin(), src.end(), dst.begin(), Bf16ToFloat);
    };

    WriteFile(outputDir + "/npu_out.bin", hostOutput.data(), sizeOutput);

    cmd = "python3 verify_result.py " + std::to_string(m) + " " + std::to_string(n);
    if (std::system(cmd.c_str()) != 0) {
        return 1;
    }

    aclrtDestroyStream(stream);
    aclrtResetDevice(deviceId);
    aclFinalize();
    return 0;
}

namespace tool {
/**
 * @brief Fill a vector with uniformly distributed random data
 *
 * Integral T is drawn from [min, max] (inclusive). Floating-point T is drawn from [min, max).
 * std::uniform_int_distribution is only defined for short/int/long/long long and their
 * unsigned forms, so narrower integral types (int8_t, uint8_t, char, ...) are drawn through
 * their promoted type and converted back. Any other T leaves `data` unchanged.
 *
 * @tparam T Data type (integral or floating point)
 * @param data Vector to fill (every element is overwritten)
 * @param min Minimum value
 * @param max Maximum value
 */
template <typename T>
void FillRandomData(std::vector<T>& data, T min, T max)
{
    std::random_device rd;
    std::mt19937 gen(rd());
    if constexpr (std::is_integral<T>::value) {
        // Promoted draw type: int for types narrower than int, T itself otherwise.
        using DrawType = std::common_type_t<T, int>;
        std::uniform_int_distribution<DrawType> dist(static_cast<DrawType>(min), static_cast<DrawType>(max));
        for (auto& elem : data) {
            elem = static_cast<T>(dist(gen));
        }
    } else if constexpr (std::is_floating_point<T>::value) {
        std::uniform_real_distribution<T> dist(min, max);
        for (auto& elem : data) {
            elem = dist(gen);
        }
    }
}

/**
 * @brief Compute matrix multiplication C = A * B on CPU as golden reference
 *
 * All matrices are stored row-major: A is m x k, B is k x n, C is m x n. Offsets are computed in
 * size_t so that large matrices (m * n or m * k of 2^32 or more) do not wrap around. A non-positive
 * dimension is treated as 0, so m <= 0 or n <= 0 writes nothing and k <= 0 writes zeros.
 *
 * @tparam T Data type (also used as the accumulator type)
 * @param m Rows of A
 * @param k Columns of A / Rows of B
 * @param n Columns of B
 * @param hostInput Matrix A (at least m * k elements)
 * @param hostWeight Matrix B (at least k * n elements)
 * @param goldenOutput Output matrix C (reference, at least m * n elements)
 */
template <typename T>
void ComputeGolden(
    int m, int k, int n, const std::vector<T>& hostInput, const std::vector<T>& hostWeight,
    std::vector<T>& goldenOutput)
{
    const size_t rows = m > 0 ? static_cast<size_t>(m) : 0;
    const size_t depth = k > 0 ? static_cast<size_t>(k) : 0;
    const size_t cols = n > 0 ? static_cast<size_t>(n) : 0;

    for (size_t row = 0; row < rows; ++row) {
        const size_t inputRowOffset = row * depth;
        const size_t outputRowOffset = row * cols;
        for (size_t col = 0; col < cols; ++col) {
            T sum = 0;
            for (size_t iter = 0; iter < depth; ++iter) {
                sum += hostInput[inputRowOffset + iter] * hostWeight[iter * cols + col];
            }
            goldenOutput[outputRowOffset + col] = sum;
        }
    }
}

/**
 * @brief Compare kernel output with golden reference
 *
 * An element passes when it equals the reference exactly (which also covers matching infinities) or
 * when |actual - expected| <= rtol * max(1, |expected|). The arithmetic is done in double, so the
 * function works for float and double alike. NaN on either side is always reported as a mismatch.
 * If the vectors differ in length, every index past the shorter one is reported as well.
 *
 * @tparam T Data type (must be convertible to double)
 * @param hostOutput Kernel output
 * @param goldenOutput CPU reference
 * @param rtol Relative tolerance (defaults to 1/256)
 * @return std::vector<uint64_t> Indices where values differ beyond tolerance, in ascending order
 */
template <typename T>
std::vector<uint64_t> Compare(
    const std::vector<T>& hostOutput, const std::vector<T>& goldenOutput, double rtol = 1.0 / 256)
{
    std::vector<uint64_t> errorIndices;
    const size_t common = std::min(hostOutput.size(), goldenOutput.size());
    for (size_t i = 0; i < common; ++i) {
        const double actualValue = static_cast<double>(hostOutput[i]);
        const double expectValue = static_cast<double>(goldenOutput[i]);
        if (actualValue == expectValue) {
            continue;
        }
        const double diff = std::fabs(actualValue - expectValue);
        const double tolerance = rtol * std::max(1.0, std::fabs(expectValue));
        // Negated "<=" so that a NaN diff is counted as an error instead of slipping through.
        if (!(diff <= tolerance)) {
            errorIndices.push_back(i);
        }
    }
    // Elements present on only one side cannot match anything.
    const size_t longest = std::max(hostOutput.size(), goldenOutput.size());
    for (size_t i = common; i < longest; ++i) {
        errorIndices.push_back(i);
    }
    return errorIndices;
}

/**
 * @brief Ceiling division for unsigned integer arithmetic: returns ceil(a / b)
 *
 * A divisor of 0 returns `a` unchanged instead of dividing by zero. The result is computed as
 * quotient plus a remainder carry. This avoids the `(a + b - 1) / b` form, which wraps around for
 * values of `a` close to UINT64_MAX.
 */
__aicore__ inline uint64_t CeilDiv(uint64_t a, uint64_t b)
{
    if (b == 0) {
        return a;
    }
    return a / b + ((a % b != 0) ? 1 : 0);
}

/**
 * @brief Round `a` up to the nearest multiple of `b`
 *
 * An alignment of 0 yields 0. CeilDiv returns `a` for a zero divisor, and multiplying by 0 then
 * gives 0; that case is spelled out explicitly here.
 * NOTE(review): confirm callers never rely on a non-zero result for b == 0.
 */
__aicore__ inline uint64_t CeilAlign(uint64_t a, uint64_t b)
{
    if (b == 0) {
        return 0;
    }
    const uint64_t blockCount = CeilDiv(a, b);
    return blockCount * b;
}

/**
 * @brief Convert a 16-bit brain floating-point (bfloat16) value to a 32-bit float
 *
 * bfloat16 is the upper half of an IEEE-754 binary32. Widening is therefore exact: place the 16 bits
 * in the high half and reinterpret them. The reinterpretation uses std::memcpy because dereferencing
 * a float* that points at a uint32_t violates strict aliasing (undefined behaviour).
 *
 * @param h 16-bit bfloat16 value stored in uint16_t format
 * @return The converted 32-bit floating-point value (exact; sign, Inf, NaN and subnormals preserved)
 */
float Bf16ToFloat(uint16_t h)
{
    const uint32_t fBits = static_cast<uint32_t>(h) << 16;
    float result;
    std::memcpy(&result, &fBits, sizeof(result));
    return result;
}

/**
 * @brief Convert a 32-bit float to a 16-bit brain floating-point (bfloat16) value
 *
 * Finite values and infinities are converted by truncation: the high 16 bits are kept. A NaN whose
 * payload lives only in the low 16 mantissa bits would truncate to an all-zero bfloat16 mantissa,
 * i.e. infinity. Such NaNs get the quiet bit set so they stay NaN.
 *
 * @param f 32-bit floating-point value to convert
 * @return The converted 16-bit bfloat16 value stored in uint16_t format (truncated rounding)
 */
uint16_t FloatToBf16(float f)
{
    uint32_t fBits;
    std::memcpy(&fBits, &f, sizeof(fBits));

    // Extract the high 16 bits (simple truncation)
    uint16_t high = static_cast<uint16_t>(fBits >> 16);
    const bool isNan = (fBits & 0x7FFFFFFFU) > 0x7F800000U;
    if (isNan) {
        high = static_cast<uint16_t>(high | 0x0040U);
    }
    return high;
}

/**
 * @brief Check whether tile `tileIdx` falls in the last round when `tileNum` tiles are dealt out
 *        `blockNum` per round
 *
 * Compares the 1-based round of the tile, ceil((tileIdx + 1) / blockNum), with the total number of
 * rounds, ceil(tileNum / blockNum). With blockNum == 0, CeilDiv returns its numerator, so the check
 * reduces to tileIdx + 1 == tileNum.
 * NOTE(review): presumably used to route the tail round to the SK path; confirm against callers.
 */
__aicore__ inline bool CheckIsSkScene(uint32_t tileIdx, uint32_t blockNum, uint32_t tileNum)
{
    const uint64_t tileRound = tool::CeilDiv(tileIdx + 1, blockNum);
    const uint64_t lastRound = tool::CeilDiv(tileNum, blockNum);
    return tileRound == lastRound;
}

/**
 * @brief Check if Split-K scheduling should be used for tail tiles in matrix multiplication
 *
 * The output is divided into 256x256 tiles, and full rounds of `blockNum` tiles are dealt out first.
 * Split-K is chosen when tiles remain after that (a tail exists) and they occupy at most half of the
 * blocks.
 *
 * @param m Number of rows in output matrix (M dimension)
 * @param n Number of columns in output matrix (N dimension)
 * @param blockNum Number of available compute blocks
 * @return true if Split-K scheduling should be used for tail tiles; false otherwise, including when
 *         any argument is non-positive
 */
bool isSKScene(int m, int n, int blockNum)
{
    constexpr int64_t TILE_SIZE = 256;
    // blockNum == 0 would otherwise divide by zero in the modulo below.
    if (m <= 0 || n <= 0 || blockNum <= 0) {
        return false;
    }

    // Total number of 256x256 tiles, computed in int64_t so m + TILE_SIZE - 1 cannot overflow int.
    const int64_t tileNum = ((static_cast<int64_t>(m) + TILE_SIZE - 1) / TILE_SIZE) *
        ((static_cast<int64_t>(n) + TILE_SIZE - 1) / TILE_SIZE);

    // Remaining tiles after distributing full rounds (equals tileNum when tileNum < blockNum).
    const int64_t tailTiles = tileNum % blockNum;

    // Use SK scene if tail tiles are no more than half the blocks
    return (tailTiles > 0) && (tailTiles <= (blockNum / 2));
}

/**
 * @brief Read a whole regular file into a caller-provided buffer
 * @param [in] filePath: file path
 * @param [out] fileSize: number of bytes read (only written on success)
 * @param [out] buffer: destination buffer with room for bufferSize bytes
 * @param [in] bufferSize: capacity of buffer in bytes
 * @return true on success. false (after logging) if buffer is null, the path is not an openable
 *         regular file, the file is empty or larger than bufferSize, or it could not be read completely
 */
inline bool ReadFile(const std::string& filePath, size_t& fileSize, void* buffer, size_t bufferSize)
{
    if (buffer == nullptr) {
        ERROR_LOG("Read file failed. buffer is nullptr");
        return false;
    }

    struct stat sBuf;
    int fileStatus = stat(filePath.c_str(), &sBuf);
    if (fileStatus == -1) {
        ERROR_LOG("failed to get file");
        return false;
    }
    if (S_ISREG(sBuf.st_mode) == 0) {
        ERROR_LOG("%s is not a file, please enter a file", filePath.c_str());
        return false;
    }

    // The stream's destructor closes the file on every return path.
    std::ifstream file(filePath, std::ios::binary);
    if (!file.is_open()) {
        ERROR_LOG("Open file failed. path = %s", filePath.c_str());
        return false;
    }

    std::filebuf* buf = file.rdbuf();
    const std::streamoff endOffset = buf->pubseekoff(0, std::ios::end, std::ios::in);
    // A failed seek returns -1, which would otherwise be reinterpreted as a huge size_t.
    if (endOffset < 0) {
        ERROR_LOG("Seek file failed. path = %s", filePath.c_str());
        return false;
    }
    const size_t size = static_cast<size_t>(endOffset);
    if (size == 0) {
        ERROR_LOG("file size is 0");
        return false;
    }
    if (size > bufferSize) {
        ERROR_LOG("file size is larger than buffer size");
        return false;
    }

    buf->pubseekpos(0, std::ios::in);
    const std::streamsize readSize = buf->sgetn(static_cast<char*>(buffer), static_cast<std::streamsize>(size));
    // Report a short read instead of claiming success with a partly filled buffer.
    if (readSize != static_cast<std::streamsize>(size)) {
        ERROR_LOG("Read file failed. path = %s", filePath.c_str());
        return false;
    }
    fileSize = size;
    return true;
}

/**
 * @brief Write data to file, creating it or truncating an existing one (owner read/write permissions)
 * @param [in] filePath: file path
 * @param [in] buffer: data to write to file
 * @param [in] size: size to write
 * @return true if all `size` bytes were written and the descriptor closed cleanly, false otherwise
 */
inline bool WriteFile(const std::string& filePath, const void* buffer, size_t size)
{
    if (buffer == nullptr) {
        ERROR_LOG("Write file failed. buffer is nullptr");
        return false;
    }

    int fd = open(filePath.c_str(), O_RDWR | O_CREAT | O_TRUNC, S_IRUSR | S_IWUSR);
    if (fd < 0) {
        ERROR_LOG("Open file failed. path = %s", filePath.c_str());
        return false;
    }

    const char* cursor = static_cast<const char*>(buffer);
    size_t remaining = size;
    while (remaining > 0) {
        // write() may transfer fewer bytes than requested (Linux caps one call at 0x7ffff000 bytes),
        // so keep going until everything is written, or stop on an error or a call that makes no progress.
        const ssize_t written = write(fd, cursor, remaining);
        if (written <= 0) {
            break;
        }
        cursor += written;
        remaining -= static_cast<size_t>(written);
    }

    // close() can report deferred write errors, so its result counts too.
    const bool closed = (close(fd) == 0);
    if (remaining != 0 || !closed) {
        ERROR_LOG("Write file Failed.");
        return false;
    }

    return true;
}

} // namespace tool