* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef MEMORY_COPY_ARCH35_H
#define MEMORY_COPY_ARCH35_H
#include "vector_common.h"
#include "memory_copy.h"
/**
 * @brief Select the actual-sequence-length mode used for the query tensor.
 * @tparam LAYOUT Query layout.
 * @return ActualSeqLensMode::ACCUM for LAYOUT_TND / LAYOUT_NTD, ActualSeqLensMode::BY_BATCH for every other layout.
 */
template <LayOutTypeEnum LAYOUT>
__aicore__ inline constexpr ActualSeqLensMode GetQActSeqMode()
{
    constexpr bool useAccumMode = (LAYOUT == LayOutTypeEnum::LAYOUT_NTD) || (LAYOUT == LayOutTypeEnum::LAYOUT_TND);
    return useAccumMode ? ActualSeqLensMode::ACCUM : ActualSeqLensMode::BY_BATCH;
}
/**
 * @brief Select the actual-sequence-length mode used for the key/value tensors.
 * @tparam LAYOUT         Key/value layout.
 * @tparam PAGE_ATTENTION When true, the result is always BY_BATCH regardless of LAYOUT.
 * @return ActualSeqLensMode::ACCUM only for non-paged LAYOUT_TND / LAYOUT_NTD; BY_BATCH otherwise.
 */
template <LayOutTypeEnum LAYOUT, const bool PAGE_ATTENTION>
__aicore__ inline constexpr ActualSeqLensMode GetKvActSeqMode()
{
    constexpr bool packedLayout = (LAYOUT == LayOutTypeEnum::LAYOUT_NTD) || (LAYOUT == LayOutTypeEnum::LAYOUT_TND);
    // Paged attention overrides the layout-derived choice.
    constexpr bool useAccumMode = packedLayout && !PAGE_ATTENTION;
    return useAccumMode ? ActualSeqLensMode::ACCUM : ActualSeqLensMode::BY_BATCH;
}
/**
 * @brief Map a query layout to its GM format.
 * @tparam LAYOUT Query layout.
 * @return BSH -> BSNGD, SBH -> SBNGD, BNSD -> BNGSD, TND -> TNGD; any other layout -> NGTD.
 */
template <LayOutTypeEnum LAYOUT>
__aicore__ inline constexpr GmFormat GetQueryGmFormat()
{
    switch (LAYOUT) {
        case LayOutTypeEnum::LAYOUT_BSH:
            return GmFormat::BSNGD;
        case LayOutTypeEnum::LAYOUT_SBH:
            return GmFormat::SBNGD;
        case LayOutTypeEnum::LAYOUT_BNSD:
            return GmFormat::BNGSD;
        case LayOutTypeEnum::LAYOUT_TND:
            return GmFormat::TNGD;
        default:
            return GmFormat::NGTD;
    }
}
/**
 * @brief Map a key/value layout to its GM format.
 * @tparam LAYOUT Key/value layout.
 * @return BSH -> BSND, SBH -> SBND, BNSD -> BNSD, TND -> TND; any other layout -> NTD.
 */
template <LayOutTypeEnum LAYOUT>
__aicore__ inline constexpr GmFormat GetKVGmFormat()
{
    return (LAYOUT == LayOutTypeEnum::LAYOUT_BSH)    ? GmFormat::BSND
           : (LAYOUT == LayOutTypeEnum::LAYOUT_SBH)  ? GmFormat::SBND
           : (LAYOUT == LayOutTypeEnum::LAYOUT_BNSD) ? GmFormat::BNSD
           : (LAYOUT == LayOutTypeEnum::LAYOUT_TND)  ? GmFormat::TND
                                                     : GmFormat::NTD;
}
/**
 * @brief Map a query layout to the GM format of the query scale tensor.
 * @tparam LAYOUT Query layout.
 * @tparam useDn  Only consulted for LAYOUT_TND: false -> NTGD, true -> TNGD.
 * @return BSH / BNSD -> BNGSD; TND -> NTGD or TNGD depending on useDn; any other layout -> TNGD.
 */
template <LayOutTypeEnum LAYOUT, bool useDn = false>
__aicore__ inline constexpr GmFormat GetQueryScaleGmFormat()
{
    if constexpr ((LAYOUT == LayOutTypeEnum::LAYOUT_TND) && !useDn) {
        // The only case where useDn changes the result.
        return GmFormat::NTGD;
    } else if constexpr ((LAYOUT == LayOutTypeEnum::LAYOUT_BNSD) || (LAYOUT == LayOutTypeEnum::LAYOUT_BSH)) {
        return GmFormat::BNGSD;
    } else {
        return GmFormat::TNGD;
    }
}
/**
 * @brief Map the key layout to the GM format of the key scale tensor.
 * @tparam LAYOUT       Key layout; only consulted when kvLayoutType == 0.
 * @tparam kvLayoutType 0 -> layout-derived format; 1 -> PA_BnBsND; 2 -> PA_BnNBsD; anything else -> PA_NZ_K_SCALE.
 *                      (PA_* presumably denotes paged-attention formats — NOTE(review): confirm.)
 * @tparam isPa         Not referenced by this function.
 * @return For kvLayoutType == 0: BSH -> BSND, SBH -> SBND, BNSD -> BNSD, TND -> TND, other -> NTD.
 */
template <LayOutTypeEnum LAYOUT, uint8_t kvLayoutType = 0, bool isPa = false>
__aicore__ inline constexpr GmFormat GetKeyScaleGmFormat()
{
    switch (kvLayoutType) {
        case 0:
            break; // fall through to the layout-derived mapping below
        case 1:
            return GmFormat::PA_BnBsND;
        case 2:
            return GmFormat::PA_BnNBsD;
        default:
            return GmFormat::PA_NZ_K_SCALE;
    }
    switch (LAYOUT) {
        case LayOutTypeEnum::LAYOUT_BSH:
            return GmFormat::BSND;
        case LayOutTypeEnum::LAYOUT_SBH:
            return GmFormat::SBND;
        case LayOutTypeEnum::LAYOUT_BNSD:
            return GmFormat::BNSD;
        case LayOutTypeEnum::LAYOUT_TND:
            return GmFormat::TND;
        default:
            return GmFormat::NTD;
    }
}
/**
 * @brief Map the value layout to the GM format of the value scale tensor.
 * @tparam LAYOUT       Value layout; only consulted when kvLayoutType == 0.
 * @tparam kvLayoutType 0 -> layout-derived format; 1 -> PA_BnBsND; 2 -> PA_BnNBsD; anything else -> PA_NZ.
 * @tparam isPa         Not referenced by this function.
 * @return For kvLayoutType == 0: BSH -> BSND, SBH -> SBND, BNSD -> BNSD, TND -> TND2, other -> NTD.
 *         Unlike GetKeyScaleGmFormat, TND maps to TND2 here.
 */
template <LayOutTypeEnum LAYOUT, uint8_t kvLayoutType = 0, bool isPa = false>
__aicore__ inline constexpr GmFormat GetValueScaleGmFormat()
{
    if constexpr (kvLayoutType != 0) {
        return (kvLayoutType == 1)   ? GmFormat::PA_BnBsND
               : (kvLayoutType == 2) ? GmFormat::PA_BnNBsD
                                     : GmFormat::PA_NZ;
    } else {
        return (LAYOUT == LayOutTypeEnum::LAYOUT_BSH)    ? GmFormat::BSND
               : (LAYOUT == LayOutTypeEnum::LAYOUT_SBH)  ? GmFormat::SBND
               : (LAYOUT == LayOutTypeEnum::LAYOUT_BNSD) ? GmFormat::BNSD
               : (LAYOUT == LayOutTypeEnum::LAYOUT_TND)  ? GmFormat::TND2
                                                         : GmFormat::NTD;
    }
}
/**
 * @brief Map the attention-output layout to its GM format.
 * @tparam LAYOUT Output layout; must be one of BSH, BNSD, TND, NTD, NBSD (enforced by static_assert).
 * @return BSH -> BSNGD, BNSD -> BNGSD, TND -> TNGD, NTD -> NGTD, NBSD -> NGBSD.
 */
template <LayOutTypeEnum LAYOUT>
__aicore__ inline constexpr GmFormat GetOutGmFormat()
{
    static_assert((LAYOUT == LayOutTypeEnum::LAYOUT_BSH) || (LAYOUT == LayOutTypeEnum::LAYOUT_BNSD) ||
                      (LAYOUT == LayOutTypeEnum::LAYOUT_TND) || (LAYOUT == LayOutTypeEnum::LAYOUT_NTD) ||
                      (LAYOUT == LayOutTypeEnum::LAYOUT_NBSD),
                  "Get OutAttention GmFormat fail, OUT_LAYOUT_T is incorrect");
    switch (LAYOUT) {
        case LayOutTypeEnum::LAYOUT_BSH:
            return GmFormat::BSNGD;
        case LayOutTypeEnum::LAYOUT_BNSD:
            return GmFormat::BNGSD;
        case LayOutTypeEnum::LAYOUT_TND:
            return GmFormat::TNGD;
        case LayOutTypeEnum::LAYOUT_NTD:
            return GmFormat::NGTD;
        default:
            // The static_assert above leaves LAYOUT_NBSD as the only layout that can reach here.
            return GmFormat::NGBSD;
    }
}
/**
 * @brief Map the attention-output layout to the UB ordering of the (G, S1) rows.
 * @tparam LAYOUT Output layout; must be one of BSH, BNSD, TND, NTD (enforced by static_assert).
 * @return UbFormat::S1G for BSH / TND, UbFormat::GS1 for BNSD / NTD.
 */
template <LayOutTypeEnum LAYOUT>
__aicore__ inline constexpr UbFormat GetOutUbFormat()
{
    static_assert((LAYOUT == LayOutTypeEnum::LAYOUT_BSH) || (LAYOUT == LayOutTypeEnum::LAYOUT_BNSD) ||
                      (LAYOUT == LayOutTypeEnum::LAYOUT_TND) || (LAYOUT == LayOutTypeEnum::LAYOUT_NTD),
                  "Get OutAttention UB GmFormat fail, LAYOUT is incorrect");
    constexpr bool s1Outer = (LAYOUT == LayOutTypeEnum::LAYOUT_TND) || (LAYOUT == LayOutTypeEnum::LAYOUT_BSH);
    return s1Outer ? UbFormat::S1G : UbFormat::GS1;
}
/**
 * @brief Read one dimension of the bIndex-th tensor described by a tensor-list descriptor in GM.
 * @tparam LAYOUT  LAYOUT_BSH reads shape dimension 1; every other layout reads shape dimension 2
 *                 (presumably the sequence axis of (B, S, H) vs (B, N, S, D) — NOTE(review): confirm).
 * @param keyPtr   GM address of the tensor-list descriptor.
 * @param bIndex   Index of the tensor inside the list.
 * @return The selected shape dimension.
 */
template <LayOutTypeEnum LAYOUT>
__aicore__ inline uint64_t SeqLenFromTensorList(__gm__ uint8_t *keyPtr, uint32_t bIndex)
{
    constexpr uint32_t seqAxis = (LAYOUT == LayOutTypeEnum::LAYOUT_BSH) ? 1U : 2U;
    // Scratch storage the descriptor writes the shape into.
    // NOTE(review): sized for 4 entries — confirm every listed tensor fits.
    uint64_t shapeBuf[4];
    AscendC::TensorDesc<__gm__ uint8_t> batchDesc;
    ListTensorDesc tensorList((__gm__ void *)keyPtr);
    batchDesc.SetShapeAddr(shapeBuf);
    tensorList.GetDesc(batchDesc, bIndex);
    return batchDesc.GetShape(seqAxis);
}
/**
 * @brief Map a layout to the UB ordering used for the PSE tensor.
 * @tparam LAYOUT_T Layout; must be one of BSH, BNSD, TND, NTD (enforced by static_assert).
 * @return UbFormat::GS1 for BNSD / NTD, UbFormat::S1G for BSH / TND.
 */
template <LayOutTypeEnum LAYOUT_T>
__aicore__ inline constexpr UbFormat GetPseUbFormat()
{
    static_assert((LAYOUT_T == LayOutTypeEnum::LAYOUT_BSH) || (LAYOUT_T == LayOutTypeEnum::LAYOUT_BNSD) ||
                      (LAYOUT_T == LayOutTypeEnum::LAYOUT_TND) || (LAYOUT_T == LayOutTypeEnum::LAYOUT_NTD),
                  "Get PSE UbFormat fail, LAYOUT_T is incorrect");
    switch (LAYOUT_T) {
        case LayOutTypeEnum::LAYOUT_BNSD:
        case LayOutTypeEnum::LAYOUT_NTD:
            return UbFormat::GS1;
        default:
            return UbFormat::S1G;
    }
}
/**
 * @brief Whether PSE is supported for the given layout.
 * @return true only for LAYOUT_BSH and LAYOUT_BNSD.
 */
template <LayOutTypeEnum LAYOUT_T>
__aicore__ inline constexpr bool IsSupportPse()
{
    return (LAYOUT_T == LayOutTypeEnum::LAYOUT_BSH) || (LAYOUT_T == LayOutTypeEnum::LAYOUT_BNSD);
}
// Per-tile parameters consumed by CopyParamsGmToUb when staging post-quant parameters into UB.
// Field meanings below are inferred from how CopyParamsGmToUb reads them; unused fields are hedged.
struct PostQuantInfo_V2 {
    uint32_t gSize;       // presumably the G dimension size; NOTE(review): CopyParamsGmToUb uses GetDimG() instead
    uint32_t dSize;       // presumably the D (head) dimension size; not read by CopyParamsGmToUb
    uint32_t s1Size;      // S1 extent used to split gS1Idx into (g, s1) for the GS1 UB format
    uint32_t n2Idx;       // N2 index passed as the first coordinate to OffsetCalculator::GetOffset
    uint32_t gS1Idx;      // start of the tile in the flattened (G, S1) / (S1, G) index space
    uint32_t gS1DealSize; // number of flattened (G, S1) positions covered by the tile
    uint32_t colCount;    // row width in dstUb; consecutive copied rows are colCount elements apart
};
/**
 * @brief Copy the per-(n2, g) parameter rows needed by one gS1 tile from GM into UB.
 *
 * The tile covers flattened indices [gS1Idx, gS1Idx + gS1DealSize). How those indices map to g depends on UB_FORMAT:
 *  - UbFormat::S1G: index = s1 * G + g, so the range may wrap around G one or more times.
 *  - otherwise (GS1): index = g * s1Size + s1, so the range touches a contiguous run of g rows.
 * Rows are copied with CopySingleMatrixNDToND, presumably as (rows, GetDimD() cols, GM row stride GetStrideG(),
 * UB row width colCount) — NOTE(review): confirm against that helper's signature.
 *
 * @param dstUb         Destination UB tensor.
 * @param srcTensor     Source GM tensor together with its offset calculator.
 * @param postQuantInfo Tile description (n2Idx, gS1Idx, gS1DealSize, s1Size, colCount are read).
 */
template <typename PARAM_T, GmFormat GM_FORMAT, UbFormat UB_FORMAT>
__aicore__ inline void CopyParamsGmToUb(LocalTensor<PARAM_T> &dstUb, FaGmTensor<PARAM_T, GM_FORMAT> &srcTensor,
                                        PostQuantInfo_V2 &postQuantInfo)
{
    OffsetCalculator<GM_FORMAT> &offsetCalculator = srcTensor.offsetCalculator;
    if constexpr (UB_FORMAT == UbFormat::S1G) {
        const auto gSize = offsetCalculator.GetDimG();
        const uint32_t gS1End = postQuantInfo.gS1Idx + postQuantInfo.gS1DealSize;
        const uint32_t s1IdxStart = postQuantInfo.gS1Idx / gSize;
        const uint32_t gIdxStart = postQuantInfo.gS1Idx % gSize;
        const uint32_t s1IdxEnd = gS1End / gSize;
        const uint32_t gIdxEnd = gS1End % gSize;
        if (s1IdxEnd - s1IdxStart > 1) {
            // The tile spans more than one s1 row, so every g is needed: copy all G rows starting at g = 0.
            // NOTE(review): on this path dstUb row 0 holds g = 0, while the other paths start at g = gIdxStart;
            // confirm the consumer accounts for the difference.
            const uint64_t offset = offsetCalculator.GetOffset(postQuantInfo.n2Idx, 0, 0);
            CopySingleMatrixNDToND<PARAM_T>(dstUb, srcTensor.gmTensor[offset], gSize, offsetCalculator.GetDimD(),
                                            offsetCalculator.GetStrideG(), postQuantInfo.colCount);
        } else {
            // Head segment: g in [gIdxStart, gIdxEnd) if the tile stays within one s1 row, else [gIdxStart, G).
            uint32_t headSize = 0;
            if (s1IdxStart == s1IdxEnd) {
                headSize = gIdxEnd - gIdxStart;
            } else {
                headSize = gSize - gIdxStart;
            }
            uint64_t offset = offsetCalculator.GetOffset(postQuantInfo.n2Idx, gIdxStart, 0);
            CopySingleMatrixNDToND<PARAM_T>(dstUb, srcTensor.gmTensor[offset], headSize, offsetCalculator.GetDimD(),
                                            offsetCalculator.GetStrideG(), postQuantInfo.colCount);
            if ((s1IdxEnd - s1IdxStart == 1) && (gIdxEnd > 0)) {
                // Tail segment: the tile wrapped into the next s1 row; g in [0, gIdxEnd) follows the head rows.
                offset = offsetCalculator.GetOffset(postQuantInfo.n2Idx, 0, 0);
                const uint32_t ubOffset = headSize * postQuantInfo.colCount;
                CopySingleMatrixNDToND<PARAM_T>(dstUb[ubOffset], srcTensor.gmTensor[offset], gIdxEnd,
                                                offsetCalculator.GetDimD(), offsetCalculator.GetStrideG(),
                                                postQuantInfo.colCount);
            }
        }
    } else {
        // GS1: consecutive indices walk s1 first, so the tile covers a contiguous run of g rows from gIdxStart.
        const uint32_t gIdxStart = postQuantInfo.gS1Idx / postQuantInfo.s1Size;
        const uint32_t s1IdxStart = postQuantInfo.gS1Idx % postQuantInfo.s1Size;
        // ceil((s1IdxStart + gS1DealSize) / s1Size): number of g rows the tile touches.
        const uint32_t gRowCount =
            ((postQuantInfo.gS1DealSize + s1IdxStart) + (postQuantInfo.s1Size - 1)) / postQuantInfo.s1Size;
        const uint64_t offset = offsetCalculator.GetOffset(postQuantInfo.n2Idx, gIdxStart, 0);
        CopySingleMatrixNDToND<PARAM_T>(dstUb, srcTensor.gmTensor[offset], gRowCount, offsetCalculator.GetDimD(),
                                        offsetCalculator.GetStrideG(), postQuantInfo.colCount);
    }
}
#endif