/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file CopyInL1.h
 * \brief
 */
#ifndef COPYINL1_H
#define COPYINL1_H

enum class KVLAYOUT
{
    BNBD, // [blockNums, headNum, blockSize, headDim]
    BBH, // [blockNums, blockSize, headNum * headDim]
    NZ // [blockNums, headNum, d1, blockSize, d0], d1 = headDim / d0, d0 = 32 (block byte) / sizeof(KV_T)
};

struct CopyParam{
    uint32_t width;
    uint32_t height;
    uint32_t orgWidth;
};

struct PAShape{
    uint32_t blockNum;
    uint32_t blockSize;
    uint32_t headNum; // 一般为kv的head num
    uint32_t headDim; // mla下rope为64, 非rope为512
    uint32_t maxblockNumPerBatch; // block table 每一行的最大个数
    uint32_t actHeadDim; // 实际拷贝col大小,考虑到N切块 s*d, 对应d
    uint32_t copyRowNum;
    uint32_t copyRowNumAlign;
    uint32_t pageStride;
};

struct Position{
    uint32_t bIdx;
    uint32_t n2Idx;
    uint32_t s2Offset;
    uint32_t dIdx; // N轴被切,对应D轴被切
};

template<typename L1Type>
__aicore__ inline void GmCopyInToL1(LocalTensor<L1Type>& L1Tensor, GlobalTensor<L1Type>& GmTensor, const CopyParam& mmCopyParam)
{
    Nd2NzParams Gm2L1Nd2NzParams;
    Gm2L1Nd2NzParams.ndNum = 1; // ND矩阵的个数
    Gm2L1Nd2NzParams.nValue = mmCopyParam.height; // 单个ND矩阵的实际行数,单位为元素个数
    Gm2L1Nd2NzParams.dValue = mmCopyParam.width; // 单个ND矩阵的实际列数(vD),单位为元素个数
    Gm2L1Nd2NzParams.srcNdMatrixStride = 0; // 相邻ND矩阵起始地址之间的偏移, 单位为元素个数
    Gm2L1Nd2NzParams.srcDValue = mmCopyParam.orgWidth; // 同一个ND矩阵中相邻行起始地址之间的偏移, 单位为元素个数
    Gm2L1Nd2NzParams.dstNzC0Stride = (Gm2L1Nd2NzParams.nValue + 15) >> 4 << 4; // 转换为NZ矩阵后,相邻Block起始地址之间的偏移, 单位为Block个数
    Gm2L1Nd2NzParams.dstNzNStride = 1; // 转换为NZ矩阵后,ND之间相邻两行在NZ矩阵中起始地址之间的偏移, 单位为Block个数
    Gm2L1Nd2NzParams.dstNzMatrixStride = 0; // 两个NZ矩阵,起始地址之间的偏移, 单位为元素数量
    DataCopy(L1Tensor, GmTensor, Gm2L1Nd2NzParams);
}

// 场景:key、value GM to L1
// GM按ND格式存储
// L1按NZ格式存储
// GM的行、列、列的stride(D or ND)BNSD 和 BSH的区别
template<typename L1Type>
__aicore__ inline void DataCopyGmNDToL1(LocalTensor<L1Type>& l1Tensor, GlobalTensor<L1Type>& gmTensor,
                                        uint32_t rowAct, 
                                        uint32_t rowAlign,
                                        uint32_t col, // D
                                        uint32_t colStride) // D or N*D
{
    Nd2NzParams nd2nzPara;
    nd2nzPara.ndNum = 1; 
    nd2nzPara.nValue = rowAct; // 行数

    nd2nzPara.dValue = col; 
    nd2nzPara.srcDValue = colStride; 
    nd2nzPara.dstNzC0Stride = rowAlign; 
    nd2nzPara.dstNzNStride = 1; 
    nd2nzPara.srcNdMatrixStride = 0; 
    nd2nzPara.dstNzMatrixStride = 0; 
    DataCopy(l1Tensor, gmTensor, nd2nzPara);
}

template<typename L1Type>
__aicore__ inline void DataCopyGmScaleNDToL1(LocalTensor<L1Type>& l1Tensor, GlobalTensor<L1Type>& gmTensor,
                                        uint32_t rowAct,
                                        uint32_t rowAlign,
                                        uint32_t col, // D
                                        uint32_t colStride) // D or N*D
{
    Nd2NzParams nd2nzPara;
    nd2nzPara.ndNum = 1;
    nd2nzPara.nValue = rowAct;
 
    nd2nzPara.dValue = col;
    nd2nzPara.srcDValue = colStride;
    nd2nzPara.dstNzC0Stride = rowAlign;
    nd2nzPara.dstNzNStride = 1;
    nd2nzPara.srcNdMatrixStride = 0;
    nd2nzPara.dstNzMatrixStride = nd2nzPara.nValue;
 
    LocalTensor<bfloat16_t> l1TensorCast = l1Tensor.template ReinterpretCast<bfloat16_t>();
    GlobalTensor<bfloat16_t> gmTensorCast;
    gmTensorCast.SetGlobalBuffer(((__gm__ bfloat16_t*)(gmTensor.GetPhyAddr())));
    DataCopy(l1TensorCast, gmTensorCast, nd2nzPara);
}

template<typename L1Type>
__aicore__ inline void DataCopyGmScaleDNToL1(LocalTensor<L1Type>& l1Tensor, GlobalTensor<L1Type>& gmTensor,
                                        uint32_t rowAct,
                                        uint32_t rowAlign,
                                        uint32_t col,
                                        uint32_t colStride)
{
    Dn2NzParams dn2nzPara;
    dn2nzPara.dnNum = 1;
    dn2nzPara.nValue = col / 2;
    dn2nzPara.dValue = rowAct;
    dn2nzPara.srcDValue = colStride / 2;
    dn2nzPara.dstNzC0Stride = dn2nzPara.nValue;
    dn2nzPara.dstNzNStride = 1;
    dn2nzPara.srcDnMatrixStride = 0;
    dn2nzPara.dstNzMatrixStride = dn2nzPara.nValue;
 
    LocalTensor<bfloat16_t> l1TensorCast = l1Tensor.template ReinterpretCast<bfloat16_t>();
    GlobalTensor<bfloat16_t> gmTensorCast;
    gmTensorCast.SetGlobalBuffer(((__gm__ bfloat16_t*)(gmTensor.GetPhyAddr())));
    DataCopy(l1TensorCast, gmTensorCast, dn2nzPara);
}

template<typename L1Type>
__aicore__ inline void DataCopyGmNZToL1(LocalTensor<L1Type>& l1Tensor, GlobalTensor<L1Type>& gmTensor,
                                        uint32_t rowAct,
                                        uint32_t dstRowStride,
                                        uint32_t srcRowStride,
                                        uint32_t col)
{
    uint32_t blockElementCnt = 32U / sizeof(L1Type);
    if constexpr (IsSameType<L1Type, int4b_t>::value) {
        blockElementCnt = 64U;
    }
    DataCopyParams intriParams;
    intriParams.blockCount = col / blockElementCnt;
    intriParams.blockLen = rowAct;
    intriParams.dstStride = dstRowStride;
    intriParams.srcStride = srcRowStride;
    DataCopy(l1Tensor, gmTensor, intriParams);
}

template<typename L1Type>
__aicore__ inline void GmCopyInToL1HasRopePANoContinue(LocalTensor<L1Type>& nopeTensor, LocalTensor<L1Type>& ropeTensor,
                                GlobalTensor<L1Type>& nopeGmTensor, GlobalTensor<L1Type>& ropeGmTensor,
                                GlobalTensor<int32_t>& blockTableGm, KVLAYOUT kvLayout, 
                                const PAShape &shape,
                                const PAShape &ropeShape,
                                const Position &startPos)
{
    uint32_t copyFinishRowCnt = 0;
    uint64_t blockTableBaseOffset = startPos.bIdx * shape.maxblockNumPerBatch; // 块表的基偏移量
    uint32_t curS2Idx = startPos.s2Offset;
    uint32_t blockElementCnt = 32U / sizeof(L1Type); // 每个块的元素数量
    // ropeshape的M方向与nopeshape保持一样, 此处只判断nopeshape的
    while(copyFinishRowCnt < shape.copyRowNum){
        uint64_t blockIdOffset = curS2Idx / shape.blockSize; // 获取block table上的索引
        uint64_t remainRowCnt = curS2Idx % shape.blockSize; // 获取在单个块上超出的行数
        uint64_t idInBlockTable = blockTableGm.GetValue(blockTableBaseOffset + blockIdOffset); // 从block table上获取的编号
        //计算可以拷贝行数
        uint32_t copyRowCnt = shape.blockSize - remainRowCnt; // 一次只能处理一个Block
        if (copyFinishRowCnt + copyRowCnt > shape.copyRowNum){
            copyRowCnt = shape.copyRowNum - copyFinishRowCnt; // 一个block未拷满
        }
        uint64_t offset = idInBlockTable * shape.blockSize * shape.headNum * shape.headDim; // PA的偏移
        if (shape.pageStride > 0) {
            offset = idInBlockTable * shape.pageStride;
        }
        uint64_t keyRopeOffset = idInBlockTable * ropeShape.blockSize * ropeShape.headNum * ropeShape.headDim;
        if (ropeShape.pageStride > 0) {
            keyRopeOffset = idInBlockTable * ropeShape.pageStride;
        }

        if (kvLayout == KVLAYOUT::NZ) {
            offset += static_cast<uint64_t>(startPos.n2Idx * shape.blockSize * shape.headDim) + remainRowCnt * blockElementCnt + startPos.dIdx * shape.blockSize;
            keyRopeOffset += static_cast<uint64_t>(startPos.n2Idx * ropeShape.blockSize * ropeShape.headDim) + remainRowCnt * blockElementCnt + startPos.dIdx * ropeShape.blockSize;
            LocalTensor<L1Type> tmpNopeDstTensor = nopeTensor[copyFinishRowCnt * blockElementCnt];
            GlobalTensor<L1Type> tmpNopeSrcTensor = nopeGmTensor[offset];
            DataCopyGmNZToL1(tmpNopeDstTensor, tmpNopeSrcTensor, copyRowCnt, (shape.copyRowNumAlign - copyRowCnt), (shape.blockSize - copyRowCnt), shape.actHeadDim);

            LocalTensor<L1Type> tmpRopeDstTensor = ropeTensor[copyFinishRowCnt * blockElementCnt];
            GlobalTensor<L1Type> tmpRopeSrcTensor = ropeGmTensor[keyRopeOffset];
            DataCopyGmNZToL1(tmpRopeDstTensor, tmpRopeSrcTensor, copyRowCnt, (ropeShape.copyRowNumAlign - copyRowCnt), (ropeShape.blockSize - copyRowCnt), ropeShape.actHeadDim);
        } else {
            uint64_t dStride = shape.headDim;
            uint64_t dRopeStride = ropeShape.headDim;
            if (kvLayout == KVLAYOUT::BBH) {
                offset += static_cast<uint64_t>(startPos.n2Idx * shape.headDim) + remainRowCnt * shape.headDim * shape.headNum + startPos.dIdx;
                keyRopeOffset += static_cast<uint64_t>(startPos.n2Idx * ropeShape.headDim) + remainRowCnt * ropeShape.headDim * ropeShape.headNum;
                dStride = shape.headDim * shape.headNum;
                dRopeStride = ropeShape.headDim * ropeShape.headNum;
            } else{
                offset += static_cast<uint64_t>(startPos.n2Idx * shape.headDim * shape.blockSize) + remainRowCnt * shape.headDim + startPos.dIdx;
                keyRopeOffset += static_cast<uint64_t>(startPos.n2Idx * ropeShape.headDim * ropeShape.blockSize) + remainRowCnt * ropeShape.headDim;
            }

            uint32_t dValue = shape.actHeadDim;
            uint32_t srcDValue = dStride;
            uint32_t dRopeValue = ropeShape.actHeadDim;
            uint32_t srcRopeDValue = dRopeStride;
            LocalTensor<L1Type> tmpNopeDstTensor = nopeTensor[copyFinishRowCnt * blockElementCnt];
            GlobalTensor<L1Type> tmpNopeSrcTensor = nopeGmTensor[offset];
            DataCopyGmNDToL1(tmpNopeDstTensor, tmpNopeSrcTensor, copyRowCnt, shape.copyRowNumAlign, dValue, srcDValue);

            LocalTensor<L1Type> tmpRopeDstTensor = ropeTensor[copyFinishRowCnt * blockElementCnt];
            GlobalTensor<L1Type> tmpRopeSrcTensor = ropeGmTensor[keyRopeOffset];
            DataCopyGmNDToL1(tmpRopeDstTensor, tmpRopeSrcTensor, copyRowCnt, shape.copyRowNumAlign, dRopeValue, srcRopeDValue);
        }
        copyFinishRowCnt += copyRowCnt;
        curS2Idx += copyRowCnt;
    }
}

template<typename L1Type>
__aicore__ inline void GmCopyInToL1HasRopePA(LocalTensor<L1Type>& nopeTensor, LocalTensor<L1Type>& ropeTensor,
                                GlobalTensor<L1Type>& nopeGmTensor, GlobalTensor<L1Type>& ropeGmTensor,
                                GlobalTensor<int32_t>& blockTableGm, KVLAYOUT kvLayout,
                                const PAShape &shape,
                                const PAShape &ropeShape,
                                const Position &startPos)
{
    uint32_t copyFinishRowCnt = 0;
    uint64_t blockTableBaseOffset = startPos.bIdx * shape.maxblockNumPerBatch; // 块表的基偏移量
    uint32_t curS2Idx = startPos.s2Offset;
    uint32_t blockElementCnt = 32U / sizeof(L1Type); // 每个块的元素数量
    // ropeshape的M方向与nopeshape保持一样, 此处只判断nopeshape的
    while(copyFinishRowCnt < shape.copyRowNum){
        uint64_t blockIdOffset = curS2Idx / shape.blockSize; // 获取block table上的索引
        uint64_t remainRowCnt = curS2Idx % shape.blockSize; // 获取在单个块上超出的行数
        uint64_t idInBlockTable = blockTableGm.GetValue(blockTableBaseOffset + blockIdOffset); // 从block table上获取的编号
        //计算可以拷贝行数
        uint32_t copyRowCnt = shape.blockSize - remainRowCnt; // 一次只能处理一个Block
        if (copyFinishRowCnt + copyRowCnt > shape.copyRowNum){
            copyRowCnt = shape.copyRowNum - copyFinishRowCnt; // 一个block未拷满
        }
        uint64_t offset = idInBlockTable * shape.blockSize * shape.headNum * shape.headDim; // PA的偏移
        uint64_t keyRopeOffset = idInBlockTable * ropeShape.blockSize * ropeShape.headNum * ropeShape.headDim;
        if (kvLayout == KVLAYOUT::NZ) {
            offset += static_cast<uint64_t>(startPos.n2Idx * shape.blockSize * shape.headDim) + remainRowCnt * blockElementCnt + startPos.dIdx * shape.blockSize;
            keyRopeOffset += static_cast<uint64_t>(startPos.n2Idx * ropeShape.blockSize * ropeShape.headDim) + remainRowCnt * blockElementCnt + startPos.dIdx * ropeShape.blockSize;
            LocalTensor<L1Type> tmpNopeDstTensor = nopeTensor[copyFinishRowCnt * blockElementCnt];
            GlobalTensor<L1Type> tmpNopeSrcTensor = nopeGmTensor[offset];
            DataCopyGmNZToL1(tmpNopeDstTensor, tmpNopeSrcTensor, copyRowCnt, (shape.copyRowNumAlign - copyRowCnt), (shape.blockSize - copyRowCnt), shape.actHeadDim);

            LocalTensor<L1Type> tmpRopeDstTensor = ropeTensor[copyFinishRowCnt * blockElementCnt];
            GlobalTensor<L1Type> tmpRopeSrcTensor = ropeGmTensor[keyRopeOffset];
            DataCopyGmNZToL1(tmpRopeDstTensor, tmpRopeSrcTensor, copyRowCnt, (ropeShape.copyRowNumAlign - copyRowCnt), (ropeShape.blockSize - copyRowCnt), ropeShape.actHeadDim);
        } else {
            uint64_t dStride = shape.headDim;
            uint64_t dRopeStride = ropeShape.headDim;
            if (kvLayout == KVLAYOUT::BBH) {
                offset += static_cast<uint64_t>(startPos.n2Idx * shape.headDim) + remainRowCnt * shape.headDim * shape.headNum + startPos.dIdx;
                keyRopeOffset += static_cast<uint64_t>(startPos.n2Idx * ropeShape.headDim) + remainRowCnt * ropeShape.headDim * ropeShape.headNum;
                dStride = shape.headDim * shape.headNum;
                dRopeStride = ropeShape.headDim * ropeShape.headNum;
            } else{
                offset += static_cast<uint64_t>(startPos.n2Idx * shape.headDim * shape.blockSize) + remainRowCnt * shape.headDim + startPos.dIdx;
                keyRopeOffset += static_cast<uint64_t>(startPos.n2Idx * ropeShape.headDim * ropeShape.blockSize) + remainRowCnt * ropeShape.headDim;
            }

            uint32_t dValue = shape.actHeadDim;
            uint32_t srcDValue = dStride;
            uint32_t dRopeValue = ropeShape.actHeadDim;
            uint32_t srcRopeDValue = dRopeStride;
            LocalTensor<L1Type> tmpNopeDstTensor = nopeTensor[copyFinishRowCnt * blockElementCnt];
            GlobalTensor<L1Type> tmpNopeSrcTensor = nopeGmTensor[offset];
            DataCopyGmNDToL1(tmpNopeDstTensor, tmpNopeSrcTensor, copyRowCnt, shape.copyRowNumAlign, dValue, srcDValue);

            LocalTensor<L1Type> tmpRopeDstTensor = ropeTensor[copyFinishRowCnt * blockElementCnt];
            GlobalTensor<L1Type> tmpRopeSrcTensor = ropeGmTensor[keyRopeOffset];
            DataCopyGmNDToL1(tmpRopeDstTensor, tmpRopeSrcTensor, copyRowCnt, shape.copyRowNumAlign, dRopeValue, srcRopeDValue);
        }
        copyFinishRowCnt += copyRowCnt;
        curS2Idx += copyRowCnt;
    }
}

template<typename L1Type>
__aicore__ inline void GmCopyInToL1PA(LocalTensor<L1Type>& l1Tensor, GlobalTensor<L1Type>& gmTensor,
                                GlobalTensor<int32_t>& blockTableGm, KVLAYOUT kvLayout,
                                const PAShape &shape, const Position &startPos)
{
    uint32_t copyFinishRowCnt = 0;
    uint64_t blockTableBaseOffset = startPos.bIdx * shape.maxblockNumPerBatch; // 块表的基偏移量
    uint32_t curS2Idx = startPos.s2Offset;
    uint32_t blockElementCnt = 32U / sizeof(L1Type); // 每个块的元素数量
    while(copyFinishRowCnt < shape.copyRowNum){
        uint64_t blockIdOffset = curS2Idx / shape.blockSize; // 获取block table上的索引
        uint64_t remainRowCnt = curS2Idx % shape.blockSize; // 获取在单个块上超出的行数
        uint64_t idInBlockTable = blockTableGm.GetValue(blockTableBaseOffset + blockIdOffset); // 从block table上获取的编号
        //计算可以拷贝行数
        uint32_t copyRowCnt = shape.blockSize - remainRowCnt; // 一次只能处理一个Block
        if (copyFinishRowCnt + copyRowCnt > shape.copyRowNum){
            copyRowCnt = shape.copyRowNum - copyFinishRowCnt; // 一个block未拷满
        }
        uint64_t offset = idInBlockTable * shape.blockSize * shape.headNum * shape.headDim; // PA的偏移
        if (kvLayout == KVLAYOUT::NZ) {
            offset += static_cast<uint64_t>(startPos.n2Idx * shape.blockSize * shape.headDim) + remainRowCnt * blockElementCnt + startPos.dIdx * shape.blockSize;
            
            LocalTensor<L1Type> tmpNopeDstTensor = l1Tensor[copyFinishRowCnt * blockElementCnt];
            GlobalTensor<L1Type> tmpNopeSrcTensor = gmTensor[offset];
            DataCopyGmNZToL1(tmpNopeDstTensor, tmpNopeSrcTensor, copyRowCnt, (shape.copyRowNumAlign - copyRowCnt), (shape.blockSize - copyRowCnt), shape.actHeadDim);
        } else {
            uint64_t dStride = shape.headDim;
            if (kvLayout == KVLAYOUT::BBH) {
                offset += static_cast<uint64_t>(startPos.n2Idx * shape.headDim) + remainRowCnt * shape.headDim * shape.headNum + startPos.dIdx;
                dStride = shape.headDim * shape.headNum;
            } else {
                offset += static_cast<uint64_t>(startPos.n2Idx * shape.headDim * shape.blockSize) + remainRowCnt * shape.headDim + startPos.dIdx;
            }

            uint32_t dValue = shape.actHeadDim;
            uint32_t srcDValue = dStride;
            LocalTensor<L1Type> tmpNopeDstTensor = l1Tensor[copyFinishRowCnt * blockElementCnt];
            GlobalTensor<L1Type> tmpNopeSrcTensor = gmTensor[offset];
            DataCopyGmNDToL1(tmpNopeDstTensor, tmpNopeSrcTensor, copyRowCnt, shape.copyRowNumAlign, dValue, srcDValue);
        }
        copyFinishRowCnt += copyRowCnt;
        curS2Idx += copyRowCnt;
    }
}

template<typename L1Type>
__aicore__ inline void GmScaleCopyInToL1PAForND(LocalTensor<L1Type>& l1Tensor, GlobalTensor<L1Type>& gmTensor,
                                GlobalTensor<int32_t>& blockTableGm, KVLAYOUT kvLayout,
                                const PAShape &shape, const Position &startPos)
{
    uint32_t copyFinishRowCnt = 0;
    uint64_t blockTableBaseOffset = startPos.bIdx * shape.maxblockNumPerBatch;
    uint32_t curS2Idx = startPos.s2Offset;
    constexpr uint32_t blockElementCnt = 32U / sizeof(L1Type);
    while(copyFinishRowCnt < shape.copyRowNum) {
        uint64_t blockIdOffset = curS2Idx / shape.blockSize;
        uint64_t remainRowCnt = curS2Idx % shape.blockSize;
        uint64_t idInBlockTable = blockTableGm.GetValue(blockTableBaseOffset + blockIdOffset);
        uint32_t copyRowCnt = shape.blockSize - remainRowCnt;
        if (copyFinishRowCnt + copyRowCnt > shape.copyRowNum) {
            copyRowCnt = shape.copyRowNum - copyFinishRowCnt;
        }
        uint64_t offset = idInBlockTable * shape.blockSize * shape.headNum * shape.headDim;
        if (kvLayout == KVLAYOUT::NZ) {
            offset += static_cast<uint64_t>(startPos.n2Idx * shape.blockSize * shape.headDim) + remainRowCnt * blockElementCnt + startPos.dIdx * shape.blockSize;
 
            LocalTensor<L1Type> tmpNopeDstTensor = l1Tensor[copyFinishRowCnt * blockElementCnt];
            GlobalTensor<L1Type> tmpNopeSrcTensor = gmTensor[offset];
            DataCopyGmNZToL1(tmpNopeDstTensor, tmpNopeSrcTensor, copyRowCnt, (shape.copyRowNumAlign - copyRowCnt), (shape.blockSize - copyRowCnt), shape.actHeadDim);
        } else {
            uint64_t dStride = shape.headDim;
            if (kvLayout == KVLAYOUT::BBH) {
                offset += static_cast<uint64_t>(startPos.n2Idx * shape.headDim) + remainRowCnt * shape.headDim * shape.headNum + startPos.dIdx;
                dStride = shape.headDim * shape.headNum;
            } else {
                offset += static_cast<uint64_t>(startPos.n2Idx * shape.headDim * shape.blockSize) + remainRowCnt * shape.headDim + startPos.dIdx;
            }
 
            uint32_t dValue = shape.actHeadDim;
            uint32_t srcDValue = dStride;
 
            LocalTensor<L1Type> tmpNopeDstTensor = l1Tensor[copyFinishRowCnt * blockElementCnt];
            GlobalTensor<L1Type> tmpNopeSrcTensor = gmTensor[offset * 2];
            DataCopyGmScaleNDToL1(tmpNopeDstTensor, tmpNopeSrcTensor, copyRowCnt, copyRowCnt, dValue, srcDValue);
        }
        copyFinishRowCnt += copyRowCnt;
        curS2Idx += copyRowCnt;
    }
}
 
template<typename L1Type>
__aicore__ inline void GmScaleCopyInToL1PAForDN(LocalTensor<L1Type>& l1Tensor, GlobalTensor<L1Type>& gmTensor,
                                GlobalTensor<int32_t>& blockTableGm, KVLAYOUT kvLayout,
                                const PAShape &shape, const Position &startPos)
{
    uint32_t copyFinishRowCnt = 0;
    uint64_t blockTableBaseOffset = startPos.bIdx * shape.maxblockNumPerBatch;
    uint32_t curS2Idx = startPos.s2Offset;
    constexpr uint32_t blockElementCnt = 32U / sizeof(L1Type);
    while(copyFinishRowCnt < shape.copyRowNum) {
        uint64_t blockIdOffset = curS2Idx / shape.blockSize;
        uint64_t remainRowCnt = curS2Idx % shape.blockSize;
        uint64_t idInBlockTable = blockTableGm.GetValue(blockTableBaseOffset + blockIdOffset);
        uint32_t copyRowCnt = shape.blockSize - remainRowCnt;
        if (copyFinishRowCnt + copyRowCnt > shape.copyRowNum) {
            copyRowCnt = shape.copyRowNum - copyFinishRowCnt;
        }
        uint64_t offset = idInBlockTable * shape.blockSize * shape.headNum * shape.headDim;
        if (kvLayout == KVLAYOUT::NZ) {
            offset += static_cast<uint64_t>(startPos.n2Idx * shape.blockSize * shape.headDim) + remainRowCnt * blockElementCnt + startPos.dIdx * shape.blockSize;
            
            LocalTensor<L1Type> tmpNopeDstTensor = l1Tensor[copyFinishRowCnt * blockElementCnt];
            GlobalTensor<L1Type> tmpNopeSrcTensor = gmTensor[offset];
            DataCopyGmNZToL1(tmpNopeDstTensor, tmpNopeSrcTensor, copyRowCnt, (shape.copyRowNumAlign - copyRowCnt), (shape.blockSize - copyRowCnt), shape.actHeadDim);
        } else {
            uint64_t dStride = shape.headDim;
            if (kvLayout == KVLAYOUT::BBH) {
                offset += static_cast<uint64_t>(startPos.n2Idx * shape.headDim) + remainRowCnt * shape.headDim * shape.headNum + startPos.dIdx;
                dStride = shape.headDim * shape.headNum;
            } else {
                offset += static_cast<uint64_t>(startPos.n2Idx * shape.headDim * shape.blockSize) + remainRowCnt * shape.headDim + startPos.dIdx;
            }
 
            uint32_t dValue = shape.actHeadDim;
            uint32_t srcDValue = dStride;
            LocalTensor<L1Type> tmpNopeDstTensor = l1Tensor[copyFinishRowCnt * blockElementCnt];
            GlobalTensor<L1Type> tmpNopeSrcTensor = gmTensor[offset];
            
            DataCopyGmScaleDNToL1(tmpNopeDstTensor, tmpNopeSrcTensor, copyRowCnt, copyRowCnt, dValue, srcDValue);
        }
        copyFinishRowCnt += copyRowCnt;
        curS2Idx += copyRowCnt;
    }
}

template<typename INPUT_T>
__aicore__ inline void CopyToL1Nd2Nz(const LocalTensor<INPUT_T> &l1Tensor, const GlobalTensor<INPUT_T> &gmTensor,
    uint32_t nValue, uint32_t dValue, uint32_t srcDValue)
{
    Nd2NzParams gm2L1Nd2NzParams;
    gm2L1Nd2NzParams.ndNum = 1; // ND矩阵的个数
    gm2L1Nd2NzParams.nValue = nValue; // 单个ND矩阵的实际行数,单位为元素个数
    gm2L1Nd2NzParams.dValue = dValue; // 单个ND矩阵的实际列数,单位为元素个数
    gm2L1Nd2NzParams.srcNdMatrixStride = 0; // 相邻ND矩阵起始地址之间的偏移, 单位为元素个数
    gm2L1Nd2NzParams.srcDValue = srcDValue; // 同一个ND矩阵中相邻行起始地址之间的偏移, 单位为元素个数
#if (__CCE_AICORE__ == 310) || (defined __DAV_310R6__) || (__NPU_ARCH__ == 5102)
    if constexpr (IsSameType<INPUT_T, fp8_e5m2_t>::value || IsSameType<INPUT_T, fp8_e4m3fn_t>::value ||
        IsSameType<INPUT_T, hifloat8_t>::value || IsSameType<INPUT_T, int8_t>::value) {
        gm2L1Nd2NzParams.dstNzC0Stride = (nValue + 31) >> 5 << 5;
    } else {
        gm2L1Nd2NzParams.dstNzC0Stride = (nValue + 15) >> 4 << 4;
    }
#else
    gm2L1Nd2NzParams.dstNzC0Stride = (nValue + 15) >> 4 << 4; // NZ矩阵相邻Block起始地址之间的偏移, 单位为Block个数
#endif
    gm2L1Nd2NzParams.dstNzNStride = 1; // 转换为NZ矩阵后,ND之间相邻两行在NZ矩阵中起始地址之间的偏移, 单位为Block个数
    gm2L1Nd2NzParams.dstNzMatrixStride = 0; // 两个NZ矩阵,起始地址之间的偏移, 单位为元素数量
    DataCopy(l1Tensor, gmTensor, gm2L1Nd2NzParams);
}

template<typename INPUT_T>
__aicore__ inline void CopyScaleToL1Nd2Nz(const LocalTensor<INPUT_T> &l1Tensor, const GlobalTensor<INPUT_T> &gmTensor,
    uint32_t nValue, uint32_t dValue, uint32_t srcDValue)
{
    Nd2NzParams gm2L1Nd2NzParams;
    gm2L1Nd2NzParams.ndNum = 1; // ND矩阵的个数
    gm2L1Nd2NzParams.nValue = nValue / 2; // 单个ND矩阵的实际行数,单位为元素个数
    gm2L1Nd2NzParams.dValue = dValue; // 单个ND矩阵的实际列数,单位为元素个数
    gm2L1Nd2NzParams.srcNdMatrixStride = 0; // 相邻ND矩阵起始地址之间的偏移, 单位为元素个数
    gm2L1Nd2NzParams.srcDValue = srcDValue; // 同一个ND矩阵中相邻行起始地址之间的偏移, 单位为元素个数
    gm2L1Nd2NzParams.dstNzC0Stride = nValue / 2; // NZ矩阵相邻Block起始地址之间的偏移, 单位为Block个数
    gm2L1Nd2NzParams.dstNzNStride = 1; // 转换为NZ矩阵后,ND之间相邻两行在NZ矩阵中起始地址之间的偏移, 单位为Block个数
    gm2L1Nd2NzParams.dstNzMatrixStride = gm2L1Nd2NzParams.nValue; // 两个NZ矩阵,起始地址之间的偏移, 单位为元素数量
 
    LocalTensor<bfloat16_t> l1TensorCast = l1Tensor.template ReinterpretCast<bfloat16_t>();
    GlobalTensor<bfloat16_t> gmTensorCast;
    gmTensorCast.SetGlobalBuffer(((__gm__ bfloat16_t*)(gmTensor.GetPhyAddr())));
    DataCopy(l1TensorCast, gmTensorCast, gm2L1Nd2NzParams);
}
 
template<typename INPUT_T>
__aicore__ inline void CopyScaleToL1Dn2Nz(const LocalTensor<INPUT_T> &l1Tensor, const GlobalTensor<INPUT_T> &gmTensor,
    uint32_t nValue, uint32_t dValue, uint32_t srcDValue)
{
    Dn2NzParams gm2L1Dn2NzParams;
    gm2L1Dn2NzParams.dnNum = 1; // ND矩阵的个数
    gm2L1Dn2NzParams.nValue = nValue / 2; // 单个DN矩阵的实际列数,单位为元素个数
    gm2L1Dn2NzParams.dValue = dValue; // 单个DN矩阵的实际行数,单位为元素个数
    gm2L1Dn2NzParams.srcDnMatrixStride = 0; // 相邻Dn矩阵起始地址之间的偏移, 单位为元素个数
    gm2L1Dn2NzParams.srcDValue = srcDValue / 2; // 同一个Dn矩阵中相邻行起始地址之间的偏移, 单位为元素个数
    gm2L1Dn2NzParams.dstNzC0Stride = nValue / 2;
    gm2L1Dn2NzParams.dstNzNStride = 1; // 转换为NZ矩阵后,ND之间相邻两行在NZ矩阵中起始地址之间的偏移, 单位为Block个数
    gm2L1Dn2NzParams.dstNzMatrixStride = gm2L1Dn2NzParams.nValue; // 两个NZ矩阵,起始地址之间的偏移, 单位为元素数量
 
    LocalTensor<bfloat16_t> l1TensorCast = l1Tensor.template ReinterpretCast<bfloat16_t>();
    GlobalTensor<bfloat16_t> gmTensorCast;
    gmTensorCast.SetGlobalBuffer(((__gm__ bfloat16_t*)(gmTensor.GetPhyAddr())));
    DataCopy(l1TensorCast, gmTensorCast, gm2L1Dn2NzParams);
}

template<typename INPUT_T>
__aicore__ inline void CopyToL1Nd2NzGS1Merge(const LocalTensor<INPUT_T> &l1Tensor, const GlobalTensor<INPUT_T> &gmTensor,
    uint32_t ndNum, uint32_t nValue, uint32_t dValue, uint32_t srcNdMatrixStride, uint32_t srcDValue, uint32_t dstNzC0Stride) // BSNGD 合轴拷贝
{
    Nd2NzParams gm2L1Nd2NzParams;
    gm2L1Nd2NzParams.ndNum = ndNum; // ND矩阵的个数
    gm2L1Nd2NzParams.nValue = nValue; // 单个ND矩阵的实际行数,单位为元素个数
    gm2L1Nd2NzParams.dValue = dValue; // 单个ND矩阵的实际列数,单位为元素个数
    gm2L1Nd2NzParams.srcNdMatrixStride = srcNdMatrixStride; // 相邻ND矩阵起始地址之间的偏移, 单位为元素个数
    gm2L1Nd2NzParams.srcDValue = srcDValue; // 同一个ND矩阵中相邻行起始地址之间的偏移, 单位为元素个数
#if (__CCE_AICORE__ == 310) || (defined __DAV_310R6__) || (__NPU_ARCH__ == 5102)
    if constexpr (IsSameType<INPUT_T, fp8_e5m2_t>::value || IsSameType<INPUT_T, fp8_e4m3fn_t>::value ||
        IsSameType<INPUT_T, hifloat8_t>::value || IsSameType<INPUT_T, int8_t>::value) {
        gm2L1Nd2NzParams.dstNzC0Stride = (dstNzC0Stride + 31) >> 5 << 5; // NZ矩阵相邻Block起始地址之间的偏移,单位为Block个数,32对齐
    } else {
        gm2L1Nd2NzParams.dstNzC0Stride = (dstNzC0Stride + 15) >> 4 << 4; // NZ矩阵相邻Block起始地址之间的偏移,单位为Block个数,16对齐
    }
#else
    gm2L1Nd2NzParams.dstNzC0Stride = (dstNzC0Stride + 15) >> 4 << 4; // NZ矩阵相邻Block起始地址之间的偏移,单位为Block个数,16对齐
#endif
    gm2L1Nd2NzParams.dstNzNStride = 1; // 转换为NZ矩阵后,ND之间相邻两行在NZ矩阵中起始地址之间的偏移, 单位为Block个数
    gm2L1Nd2NzParams.dstNzMatrixStride = nValue * 32 / sizeof(INPUT_T); // 两个NZ矩阵,起始地址之间的偏移, 单位为元素数量
    DataCopy(l1Tensor, gmTensor, gm2L1Nd2NzParams);
}
#endif