/**
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
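/*!
* \file chunk_cat.h
* \brief Kernel-side ChunkCat class: copies column slices of a list of input tensors into UB,
*        rearranges (and optionally casts) them, and writes the concatenated tile to global memory.
*/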
#ifndef _CHUNK_CAT_DATA_H_
#define _CHUNK_CAT_DATA_H_
#include "chunk_cat_common.h"
using namespace AscendC;
template <typename T1, typename T2, bool NEAD_CAST=false>
class ChunkCat
{
public:
// Stores the pipe pointer only; the UB buffers are allocated from it later, in Init().
__aicore__ inline ChunkCat(TPipe *pipe) : pipe_(pipe) {}
/**
 * \brief Caches the tiling parameters, resolves this core's output window and allocates the UB buffers.
 * \param x          GM address wrapped in a ListTensorDesc to enumerate the input tensors.
 * \param y          GM address of the output tensor (elements of type T2).
 * \param tilingData Tiling parameters; tilingData.blockColNum is used as a divisor and must be non-zero.
 */
__aicore__ inline void Init(GM_ADDR x, GM_ADDR y, const ChunkCatTilingData& tilingData)
{
    blockIdx_ = GetBlockIdx();

    // Mode flags and shape / tiling factors, copied verbatim from the tiling data.
    isAllAlign_ = tilingData.isAllAlign;
    isHalfAlign_ = tilingData.isHalfAlign;
    isOneConcat_ = tilingData.isOneConcat;
    dim_ = tilingData.dim;
    numChunk_ = tilingData.numChunk;
    outputRow_ = tilingData.outputRow;
    outputCol_ = tilingData.outputCol;
    blockRowFactor_ = tilingData.blockRowFactor;
    blockColFactor_ = tilingData.blockColFactor;
    tailBlockRowFactor_ = tilingData.tailBlockRowFactor;
    tailBlockColFactor_ = tilingData.tailBlockColFactor;
    ubRowFactor_ = tilingData.ubRowFactor;
    ubColFactor_ = tilingData.ubColFactor;
    inputNum_ = tilingData.inputNum;

    // Number of source / destination elements that fit in UB_BLOCK_SIZE bytes.
    srcEleUbBlock_ = UB_BLOCK_SIZE / sizeof(T1);
    dstEleUbBlock_ = UB_BLOCK_SIZE / sizeof(T2);
    colRepeatNum_ = isHalfAlign_ ? HALF : srcEleUbBlock_;

    // Position of this core in the (blockRowNum x blockColNum) grid; the last row / column of the
    // grid uses the tail factors instead of the regular ones.
    blockRowGroup_ = blockIdx_ / tilingData.blockColNum;
    blockColGroup_ = blockIdx_ % tilingData.blockColNum;
    currentBlockRowFactor_ = blockRowGroup_ == tilingData.blockRowNum - 1 ? tailBlockRowFactor_ : blockRowFactor_;
    currentBlockColFactor_ = blockColGroup_ == tilingData.blockColNum - 1 ? tailBlockColFactor_ : blockColFactor_;

    // Element offset of this core's first output element inside the (outputRow_ x outputCol_) result.
    int64_t dstGmOffset = blockRowGroup_ * blockRowFactor_ * outputCol_ + blockColGroup_ * blockColFactor_;
    dstGlobal_.SetGlobalBuffer(reinterpret_cast<__gm__ T2*>(y) + dstGmOffset);
    inputList_ = ListTensorDesc(reinterpret_cast<__gm__ void*>(x));

    // Two UB buffers; each one is additionally exposed through reinterpret-cast views so that the
    // same memory can be addressed as T1, T2 or float.
    pipe_->InitBuffer(srcBuf_, tilingData.inUbSize);
    pipe_->InitBuffer(dstBuf_, tilingData.outUbSize);
    srcLocal_ = srcBuf_.Get<T1>();
    srcLocalT2_ = srcLocal_.template ReinterpretCast<T2>();
    dstLocal_ = dstBuf_.Get<T2>();
    dstLocalT1_ = dstLocal_.template ReinterpretCast<T1>();
    dstLocalFP32_ = dstLocal_.template ReinterpretCast<float>();
}
/**
 * \brief Walks this core's output window tile by tile (row-major) and runs
 *        CopyIn -> Compute -> CopyCout for every UB tile.
 */
__aicore__ inline void Process()
{
    // Tile counts are ceil(blockFactor / ubFactor); a non-zero remainder is the size of the last tile.
    const int64_t rowTiles = GetAlign(currentBlockRowFactor_, ubRowFactor_) / ubRowFactor_;
    const int64_t colTiles = GetAlign(currentBlockColFactor_, ubColFactor_) / ubColFactor_;
    const int64_t lastRowSize = currentBlockRowFactor_ % ubRowFactor_;
    const int64_t lastColSize = currentBlockColFactor_ % ubColFactor_;

    // Stack storage that desc_ uses for shapes while the inputs are inspected in CopyIn.
    uint64_t shapeStorage[10];
    desc_.SetShapeAddr(shapeStorage);
    // Per-tile list of gathered column widths; shared by every tile through UbLoopInfo::inputCol.
    int64_t colWidths[32];

    for (int64_t rowIdx = 0; rowIdx < rowTiles; ++rowIdx) {
        for (int64_t colIdx = 0; colIdx < colTiles; ++colIdx) {
            UbLoopInfo ubLoopInfo{};
            ubLoopInfo.inputCol = colWidths;
            ubLoopInfo.ubRowGroup = rowIdx;
            ubLoopInfo.ubColGroup = colIdx;

            if (lastRowSize != 0 && ubLoopInfo.ubRowGroup == rowTiles - 1) {
                ubLoopInfo.currentUbRowFactor = lastRowSize;
            } else {
                ubLoopInfo.currentUbRowFactor = ubRowFactor_;
            }
            if (lastColSize != 0 && ubLoopInfo.ubColGroup == colTiles - 1) {
                ubLoopInfo.currentUbColFactor = lastColSize;
            } else {
                ubLoopInfo.currentUbColFactor = ubColFactor_;
            }

            CopyIn(ubLoopInfo);
            Compute(ubLoopInfo);
            CopyCout(ubLoopInfo);
        }
    }
}
private:
// Gathers, for one UB tile, every input slice that intersects the tile's column window into srcLocal_.
// ubLoopInfo is updated in place: colStart/rowStart are set here, and CopyInChunk maintains
// count, inputCol[], totalUbCol, totalUbColAlign and isAllZero.
// inputCol points at the 32-entry array allocated in Process(), hence the "count > 31" flushes below.
__aicore__ inline void CopyIn(UbLoopInfo& ubLoopInfo)
{
// Zero the whole source buffer before any data is copied in.
dupToZero();
// totalCol: running column position across the inputs, advanced by CopyInChunk.
int64_t totalCol = 0;
// localOffset: element offset in srcLocal_ at which the next slice is written.
int64_t localOffset = 0;
// Absolute column / row of the tile's first element within the full output.
ubLoopInfo.colStart = blockColGroup_ * blockColFactor_ + ubLoopInfo.ubColGroup * ubColFactor_;
ubLoopInfo.rowStart = blockRowGroup_ * blockRowFactor_ + ubLoopInfo.ubRowGroup * ubRowFactor_;
for (uint32_t i = 0; i < inputNum_; i++) {
// Stop as soon as the tile's column width has been filled.
if (ubLoopInfo.totalUbCol >= ubLoopInfo.currentUbColFactor) {
break;
}
srcGlobal_.SetGlobalBuffer(inputList_.GetDataPtr<T1>(i));
TensorInfo tensorInfo{};
inputList_.GetDesc(desc_, i);
tensorInfo.chunkDimSize = desc_.GetShape(dim_);
// ceil(chunkDimSize / numChunk_); numChunk_ must be non-zero.
tensorInfo.chunkCol = (tensorInfo.chunkDimSize + numChunk_ - 1) / numChunk_;
CopyInChunk(totalCol, localOffset, ubLoopInfo, tensorInfo);
if (isOneConcat_ && ubLoopInfo.count > 31) {
// inputCol[] is full: write out what has been gathered so far and start a new batch.
// preCatCol remembers how many columns of this tile have already been written.
Compute(ubLoopInfo);
CopyCout(ubLoopInfo);
localOffset = 0;
ubLoopInfo.preCatCol += ubLoopInfo.totalUbCol;
ubLoopInfo.count = 0;
ubLoopInfo.totalUbCol = 0;
ubLoopInfo.totalUbColAlign = 0;
}
else if (ubLoopInfo.count > 31) {
// inputCol[] is full: merge the gathered slices into one combined slice that stays in srcLocal_.
if (!ubLoopInfo.isAllZero) {
// Wait on the MTE2_V event before the rearrangement reads srcLocal_
// (presumably orders the vector work after the preceding GM->UB copies).
SetFlag<HardEvent::MTE2_V>(event_);
WaitFlag<HardEvent::MTE2_V>(event_);
if (isAllAlign_) {
UBRearrange4Concat(ubLoopInfo, srcLocal_, dstLocalT1_);
DataCopy(srcLocal_, dstLocalT1_, ubLoopInfo.currentUbRowFactor * ubLoopInfo.totalUbCol);
} else {
// Same four-step pipeline as ComputeNotAlign, but with the <true> (NO_NEED_ALIGN) variants.
UBRearrange4Trans(ubLoopInfo, srcLocal_, dstLocalT1_);
Trans1(ubLoopInfo, dstLocalT1_, srcLocal_);
UBRearrange4TransConcat<true>(ubLoopInfo, srcLocal_, dstLocalT1_);
Trans2<true>(ubLoopInfo, dstLocalT1_, srcLocal_);
}
// Wait on the V_MTE2 event before further GM->UB copies target srcLocal_.
SetFlag<HardEvent::V_MTE2>(event_);
WaitFlag<HardEvent::V_MTE2>(event_);
}
// The merged data is now tracked as a single entry of width totalUbCol.
ubLoopInfo.inputCol[0] = ubLoopInfo.totalUbCol;
ubLoopInfo.count = 1;
ubLoopInfo.totalUbColAlign = ubLoopInfo.totalUbCol;
// Next write position: directly behind the merged entry, using the same per-column
// increment that CopyInChunk applies for the active mode.
localOffset = isAllAlign_ ? ubLoopInfo.currentUbRowFactor * ubLoopInfo.totalUbCol :
(isHalfAlign_ ? TRANS_BLOCK * HALF * ubLoopInfo.totalUbCol :
TRANS_BLOCK * srcEleUbBlock_ * ubLoopInfo.totalUbCol);
}
}
}
// Produces the output tile in the destination buffer from the data gathered in srcLocal_.
// Branch order matters; the first matching case wins:
//   1. isOneConcat_                              -> ComputeOneConcat
//   2. tile contains no input data (isAllZero)   -> copy / cast the zero-filled buffer
//   3. a single slice whose width is a multiple of srcEleUbBlock_ -> plain copy / cast
//   4. isAllAlign_                               -> ComputeAllAlign
//   5. otherwise                                 -> ComputeNotAlign
// Every path ends by waiting on the V_MTE3 event (presumably so that the following
// UB->GM copy in CopyCout sees the finished vector results).
__aicore__ inline void Compute(const UbLoopInfo& ubLoopInfo)
{
if (isOneConcat_) {
ComputeOneConcat(ubLoopInfo);
}
else if (ubLoopInfo.isAllZero) {
// Nothing was copied in for this tile, so srcLocal_ still holds the zeros written by dupToZero().
PipeBarrier<PIPE_V>();
SetFlag<HardEvent::MTE3_V>(event_);
WaitFlag<HardEvent::MTE3_V>(event_);
if constexpr (NEAD_CAST) {
// Rows are padded to a multiple of dstEleUbBlock_ elements.
uint32_t castCount = ubLoopInfo.currentUbRowFactor * GetAlign(ubLoopInfo.currentUbColFactor, dstEleUbBlock_);
DoCast(ubLoopInfo, castCount);
} else {
DataCopy(dstLocalT1_, srcLocal_, ubLoopInfo.currentUbRowFactor * GetAlign(ubLoopInfo.currentUbColFactor, dstEleUbBlock_));
}
}
else if (ubLoopInfo.count == 1 && ubLoopInfo.currentUbColFactor % srcEleUbBlock_ == 0) {
// One slice with a block-aligned width: no rearrangement is performed, only a copy or cast.
SetFlag<HardEvent::MTE2_V>(event_);
WaitFlag<HardEvent::MTE2_V>(event_);
if constexpr (NEAD_CAST) {
uint32_t castCount = ubLoopInfo.currentUbRowFactor * ubLoopInfo.currentUbColFactor;
DoCast(ubLoopInfo, castCount);
} else {
DataCopy(dstLocalT1_, srcLocal_, ubLoopInfo.currentUbRowFactor * ubLoopInfo.currentUbColFactor);
}
}
else if (isAllAlign_) {
ComputeAllAlign(ubLoopInfo);
} else {
ComputeNotAlign(ubLoopInfo);
}
SetFlag<HardEvent::V_MTE3>(event_);
WaitFlag<HardEvent::V_MTE3>(event_);
}
/**
 * \brief Writes the finished tile from dstLocal_ to this core's window of the output in GM.
 *
 * In one-concat mode each gathered slice is written separately, side by side in GM;
 * otherwise the tile is written with a single strided copy. Both paths end by waiting on
 * the MTE3_MTE2 event.
 */
__aicore__ inline void CopyCout(const UbLoopInfo& ubLoopInfo)
{
    if (!isOneConcat_) {
        // One copy: currentUbRowFactor rows of currentUbColFactor elements, with a GM stride
        // that skips the rest of each output row.
        uint16_t rowCount = ubLoopInfo.currentUbRowFactor;
        uint32_t rowBytes = ubLoopInfo.currentUbColFactor * sizeof(T2);
        uint32_t gmGapBytes = (outputCol_ - ubLoopInfo.currentUbColFactor)* sizeof(T2);
        DataCopyExtParams tileParams{rowCount, rowBytes, 0, gmGapBytes, 0};
        int64_t tileOffset = ubLoopInfo.ubRowGroup * ubRowFactor_ * outputCol_ + ubLoopInfo.ubColGroup * ubColFactor_;
        DataCopyPad(dstGlobal_[tileOffset], dstLocal_, tileParams);
    } else {
        int64_t ubCursor = 0;
        // preCatCol: columns of this tile already written by an earlier flush in CopyIn.
        int64_t gmCursor = ubLoopInfo.ubRowGroup * ubRowFactor_ * outputCol_ + ubLoopInfo.ubColGroup * ubColFactor_ + ubLoopInfo.preCatCol;
        for (int sliceIdx = 0; sliceIdx < ubLoopInfo.count; sliceIdx++) {
            uint16_t rowCount = ubLoopInfo.currentUbRowFactor;
            uint32_t rowBytes = ubLoopInfo.inputCol[sliceIdx] * sizeof(T2);
            uint32_t gmGapBytes = (outputCol_ - ubLoopInfo.inputCol[sliceIdx]) * sizeof(T2);
            DataCopyExtParams sliceParams{rowCount, rowBytes, 0, gmGapBytes, 0};
            DataCopyPad(dstGlobal_[gmCursor], dstLocal_[ubCursor], sliceParams);
            // NOTE(review): the UB cursor advances by one block-aligned row width measured in
            // source elements (srcEleUbBlock_) although dstLocal_ holds T2 - confirm intended.
            ubCursor += GetAlign(ubLoopInfo.inputCol[sliceIdx], srcEleUbBlock_);
            gmCursor += ubLoopInfo.inputCol[sliceIdx];
        }
    }
    SetFlag<HardEvent::MTE3_MTE2>(event_);
    WaitFlag<HardEvent::MTE3_MTE2>(event_);
}
/**
 * \brief Rounds value up to the next multiple of align.
 * \param value Quantity to round.
 * \param align Alignment unit; 0 disables rounding.
 * \return value itself when align is 0, otherwise ((value + align - 1) / align) * align.
 */
__aicore__ inline int64_t GetAlign(int64_t value, int64_t align)
{
    if (align == 0) {
        return value;
    }
    const int64_t unitCount = (value + align - 1) / align;
    return unitCount * align;
}
/**
 * \brief Fills the entire source buffer with zeros, then waits on the V_MTE2 event
 *        before the caller starts copying input data into it.
 */
__aicore__ inline void dupToZero()
{
    T1 zeroValue(0.0);
    const auto fillCount = srcLocal_.GetSize();
    Duplicate<T1>(srcLocal_, zeroValue, fillCount);

    SetFlag<HardEvent::V_MTE2>(event_);
    WaitFlag<HardEvent::V_MTE2>(event_);
}
/**
 * \brief Tells whether an input's column span overlaps the current tile's column window.
 * \param totalCol   Column at which the input starts.
 * \param ubLoopInfo Supplies the window [colStart, colStart + currentUbColFactor).
 * \param tensorInfo Supplies the input's width (tensorCol).
 * \return true when [totalCol, totalCol + tensorCol) intersects the window.
 */
__aicore__ inline bool IsTensorInRange(int64_t totalCol, const UbLoopInfo& ubLoopInfo, const TensorInfo& tensorInfo)
{
    const auto windowEnd = ubLoopInfo.colStart + ubLoopInfo.currentUbColFactor;
    if (!(totalCol < windowEnd)) {
        // The input starts at or beyond the end of the window.
        return false;
    }
    const auto tensorEnd = totalCol + tensorInfo.tensorCol;
    return tensorEnd > ubLoopInfo.colStart;
}
/**
 * \brief Clips an input's column span to the tile's column window.
 *
 * Writes splitCol (columns that fall inside the window), splitColAlign (splitCol rounded up to
 * srcEleUbBlock_) and, when the input sticks out on either side, isSplit. startOffset is set
 * only when the input begins left of the window.
 *
 * \param totalCol   Column at which the input starts (not modified here).
 * \param ubLoopInfo Supplies colStart and currentUbColFactor.
 * \param tensorInfo Input description; tensorCol is read, the split fields are written.
 */
__aicore__ inline void SplitTensorDim0(int64_t& totalCol, const UbLoopInfo& ubLoopInfo, TensorInfo& tensorInfo)
{
    int64_t windowEnd = ubLoopInfo.colStart + ubLoopInfo.currentUbColFactor;
    const bool startsLeftOfWindow = totalCol < ubLoopInfo.colStart;
    const bool endsRightOfWindow = (totalCol + tensorInfo.tensorCol) > windowEnd;

    // Default: the whole input lies inside the window.
    tensorInfo.splitCol = tensorInfo.tensorCol;

    if (startsLeftOfWindow) {
        tensorInfo.isSplit = true;
        if (endsRightOfWindow) {
            // The input covers the complete window.
            tensorInfo.splitCol = ubLoopInfo.currentUbColFactor;
        } else {
            // Only the input's tail reaches into the window.
            tensorInfo.splitCol = totalCol + tensorInfo.tensorCol - ubLoopInfo.colStart;
        }
        tensorInfo.startOffset = ubLoopInfo.colStart - totalCol;
    } else if (endsRightOfWindow) {
        // Only the input's head lies inside the window.
        tensorInfo.isSplit = true;
        tensorInfo.splitCol = windowEnd - totalCol;
    }

    tensorInfo.splitColAlign = GetAlign(tensorInfo.splitCol, srcEleUbBlock_);
}
/**
 * \brief Issues one padded GM -> UB copy from srcGlobal_ into srcLocal_.
 * \param localOffset Element offset into srcLocal_.
 * \param gmOffset    Element offset into srcGlobal_.
 * \param blockCount  Number of bursts.
 * \param blockLen    Length of one burst in bytes.
 * \param srcStride   Value passed as the source stride of DataCopyExtParams (callers pass bytes).
 */
__aicore__ inline void ExecuteDataCopy(int64_t localOffset, int64_t gmOffset, uint16_t blockCount,
    uint32_t blockLen, uint32_t srcStride)
{
    // Each burst is right-padded with zeros up to the next UB_BLOCK_SIZE boundary.
    const int64_t paddedLen = GetAlign(blockLen, UB_BLOCK_SIZE);
    uint8_t padElemCount = (paddedLen - blockLen) / sizeof(T1);

    AscendC::DataCopyExtParams burstParams{blockCount, blockLen, srcStride, 0, 0};
    AscendC::DataCopyPadExtParams<T1> rightPadParams{true, 0, padElemCount, 0};
    AscendC::DataCopyPad(srcLocal_[localOffset], srcGlobal_[gmOffset], burstParams, rightPadParams);
}
/**
 * \brief Copies the tile's full row range of one input into srcLocal_ (used when the tile ends
 *        before the input's last chunk row).
 *
 * A split input is copied row by row (one burst per row, skipping the columns outside the
 * window); an unsplit input is contiguous in GM and is copied as a single burst.
 *
 * \param localOffset Element offset into srcLocal_.
 * \param ubLoopInfo  Supplies rowStart and currentUbRowFactor.
 * \param tensorInfo  Supplies tensorCol, splitCol, startOffset and isSplit.
 */
__aicore__ inline void DoRowsCopy(int64_t localOffset, const UbLoopInfo& ubLoopInfo, const TensorInfo& tensorInfo)
{
    uint16_t burstCount;
    uint32_t burstBytes;
    if (tensorInfo.isSplit) {
        burstCount = static_cast<uint16_t>(ubLoopInfo.currentUbRowFactor);
        burstBytes = static_cast<uint32_t>(tensorInfo.splitCol * sizeof(T1));
    } else {
        burstCount = 1;
        burstBytes = static_cast<uint32_t>(ubLoopInfo.currentUbRowFactor * tensorInfo.splitCol * sizeof(T1));
    }

    // Bytes to skip in GM between two rows (zero when the input is not split).
    uint32_t rowGapBytes = (tensorInfo.tensorCol - tensorInfo.splitCol) * sizeof(T1);
    int64_t srcGmOffset = ubLoopInfo.rowStart * tensorInfo.tensorCol + tensorInfo.startOffset;
    ExecuteDataCopy(localOffset, srcGmOffset, burstCount, burstBytes, rowGapBytes);
}
// Copies one input into srcLocal_ when the tile's row range reaches or passes the input's
// last chunk row (CopyInChunk calls this when rowStart + currentUbRowFactor >= chunkRowAlign).
// The last chunk row may be partial: it then holds only
// remainderCol = (chunkDimSize % chunkCol) * originCol valid columns.
__aicore__ inline void DoLastRowsCopy(int64_t localOffset, const UbLoopInfo& ubLoopInfo, const TensorInfo& tensorInfo)
{
int64_t srcGmOffset = ubLoopInfo.rowStart * tensorInfo.tensorCol;
uint32_t srcStride = (tensorInfo.tensorCol - tensorInfo.splitCol) * sizeof(T1);
if (!tensorInfo.isSplit) {
// Unsplit input: everything from rowStart to the end of the input's data is contiguous,
// so it is copied as a single burst.
uint32_t blockLen = static_cast<uint32_t>(
(tensorInfo.chunkDimSize * tensorInfo.originCol - ubLoopInfo.rowStart * tensorInfo.tensorCol) * sizeof(T1));
ExecuteDataCopy(localOffset, srcGmOffset + tensorInfo.startOffset, 1, blockLen, srcStride);
return;
}
uint16_t blockCount = 0;
uint32_t blockLen = 0;
int64_t remainderCol = (tensorInfo.chunkDimSize % tensorInfo.chunkCol) * tensorInfo.originCol;
if (remainderCol == 0) {
// No partial row: copy the remaining full rows.
blockCount = static_cast<uint16_t>(tensorInfo.chunkRow - ubLoopInfo.rowStart);
blockLen = static_cast<uint32_t>(tensorInfo.splitCol * sizeof(T1));
}
else if (tensorInfo.startOffset + tensorInfo.splitCol <= remainderCol) {
// The requested column window lies entirely inside the valid part of the partial row,
// so the partial row is copied like a full one (chunkRowAlign includes it).
blockCount = static_cast<uint16_t>(tensorInfo.chunkRowAlign - ubLoopInfo.rowStart);
blockLen = static_cast<uint32_t>(tensorInfo.splitCol * sizeof(T1));
}
else if (tensorInfo.startOffset >= remainderCol) {
// The window starts beyond the valid part of the partial row: only full rows are copied.
blockCount = static_cast<uint16_t>(tensorInfo.chunkRow - ubLoopInfo.rowStart);
blockLen = static_cast<uint32_t>(tensorInfo.splitCol * sizeof(T1));
}
else {
// The window straddles the end of the partial row: copy its valid part with a separate,
// shorter burst placed behind the full rows, then fall through to copy the full rows.
blockLen = (remainderCol - tensorInfo.startOffset) * sizeof(T1);
int64_t localOffsetPart = localOffset + (tensorInfo.chunkRow - ubLoopInfo.rowStart) * tensorInfo.splitColAlign;
int64_t gmOffsetPart = srcGmOffset + tensorInfo.startOffset +
(tensorInfo.chunkRow - ubLoopInfo.rowStart) * tensorInfo.tensorCol;
ExecuteDataCopy(localOffsetPart, gmOffsetPart, 1, blockLen, srcStride);
blockCount = static_cast<uint16_t>(tensorInfo.chunkRow - ubLoopInfo.rowStart);
blockLen = static_cast<uint32_t>(tensorInfo.splitCol * sizeof(T1));
}
// NOTE(review): blockCount can be 0 here when rowStart == chunkRow - confirm that
// DataCopyPad accepts a zero burst count.
ExecuteDataCopy(localOffset, srcGmOffset + tensorInfo.startOffset, blockCount, blockLen, srcStride);
}
// Handles one input tensor for the current tile: derives its 2-D view, skips it if it does not
// overlap the tile's column window, otherwise copies the overlapping part into srcLocal_ and
// records it in ubLoopInfo (inputCol[count], totalUbCol, totalUbColAlign, count).
// totalCol is advanced by the input's full width in every case; localOffset is advanced by the
// space the slice occupies in srcLocal_.
__aicore__ inline void CopyInChunk(int64_t& totalCol, int64_t& localOffset, UbLoopInfo& ubLoopInfo, TensorInfo& tensorInfo)
{
// originCol becomes the product of all dimensions after the first.
// Presumably TensorInfo{} initialises originCol to 1 - TODO confirm in chunk_cat_common.h.
for (uint32_t j = 1; j < desc_.GetDim(); j++) {
tensorInfo.originCol *= desc_.GetShape(j);
}
tensorInfo.tensorCol = tensorInfo.chunkCol * tensorInfo.originCol;
if (!IsTensorInRange(totalCol, ubLoopInfo, tensorInfo)) {
totalCol += tensorInfo.tensorCol;
return;
}
SplitTensorDim0(totalCol, ubLoopInfo, tensorInfo);
// chunkRow: number of complete chunk rows; chunkRowAlign: rows including a trailing partial one.
tensorInfo.chunkRow = tensorInfo.chunkDimSize / tensorInfo.chunkCol;
tensorInfo.chunkRowAlign = GetAlign(tensorInfo.chunkDimSize, tensorInfo.chunkCol) / tensorInfo.chunkCol;
// UB elements reserved per column of the slice; depends on the layout used by the active mode.
int64_t localOffsetIncrement = (isOneConcat_ || isAllAlign_) ? ubLoopInfo.currentUbRowFactor :
(isHalfAlign_ ? TRANS_BLOCK * HALF : TRANS_BLOCK * srcEleUbBlock_);
if (ubLoopInfo.rowStart >= tensorInfo.chunkRowAlign) {
// The tile starts below this input's last row: nothing is copied, but the slice is still
// recorded so that its columns keep their place (the buffer was zeroed by dupToZero in CopyIn).
ubLoopInfo.inputCol[ubLoopInfo.count] = tensorInfo.splitCol;
ubLoopInfo.totalUbColAlign += tensorInfo.splitCol;
ubLoopInfo.totalUbCol += tensorInfo.splitCol;
ubLoopInfo.count++;
totalCol += tensorInfo.tensorCol;
localOffset += localOffsetIncrement * tensorInfo.splitCol;
return;
}
ubLoopInfo.isAllZero = false;
int64_t rowEnd = ubLoopInfo.rowStart + ubLoopInfo.currentUbRowFactor;
if (rowEnd < tensorInfo.chunkRowAlign) {
DoRowsCopy(localOffset, ubLoopInfo, tensorInfo);
} else {
DoLastRowsCopy(localOffset, ubLoopInfo, tensorInfo);
}
// Row-by-row copies (split inputs, and every input in one-concat mode) are accounted with the
// block-aligned width; single-burst copies with the exact width.
localOffsetIncrement *= (isOneConcat_ || tensorInfo.isSplit) ? tensorInfo.splitColAlign : tensorInfo.splitCol;
if (isOneConcat_) {
ubLoopInfo.inputCol[ubLoopInfo.count] = tensorInfo.splitCol;
ubLoopInfo.totalUbColAlign += tensorInfo.splitColAlign;
} else if (isAllAlign_) {
ubLoopInfo.inputCol[ubLoopInfo.count] = tensorInfo.splitCol;
ubLoopInfo.totalUbColAlign += tensorInfo.splitCol;
} else if (tensorInfo.isSplit) {
// A negative width marks a slice stored with padded rows; UBRearrange4Trans and
// UBRearrange4TransConcat test the sign to account for the padding.
ubLoopInfo.inputCol[ubLoopInfo.count] = -tensorInfo.splitCol;
ubLoopInfo.totalUbColAlign += tensorInfo.splitColAlign;
} else {
ubLoopInfo.inputCol[ubLoopInfo.count] = tensorInfo.splitCol;
ubLoopInfo.totalUbColAlign += tensorInfo.splitCol;
}
ubLoopInfo.totalUbCol += tensorInfo.splitCol;
ubLoopInfo.count++;
totalCol += tensorInfo.tensorCol;
localOffset += localOffsetIncrement;
}
// First step of the non-aligned path: interleaves the per-input regions of srcLocal into
// TRANS_BLOCK rows of dstLocal, each row totalUbColAlign * colRepeatNum_ elements wide, so that
// Trans1 can transpose the combined data.
// Negative inputCol entries (padded slices, see CopyInChunk) use their block-aligned width.
__aicore__ inline void UBRearrange4Trans(const UbLoopInfo& ubLoopInfo, LocalTensor<T1>& srcLocal, LocalTensor<T1>& dstLocal)
{
int64_t srcOffset = 0;
int64_t dstOffset = 0;
for (int64_t i = 0; i < ubLoopInfo.count; i++) {
uint16_t blockCount = TRANS_BLOCK;
uint16_t actualCol = ubLoopInfo.inputCol[i] > 0 ? ubLoopInfo.inputCol[i] :
GetAlign(-ubLoopInfo.inputCol[i], srcEleUbBlock_);
// blockLen / dstGap are divided by srcEleUbBlock_, i.e. expressed in UB_BLOCK_SIZE-byte blocks.
// NOTE(review): all four values are narrowed to uint16_t - assumes tiling keeps them in range.
uint16_t blockLen = actualCol * colRepeatNum_ / srcEleUbBlock_;
uint16_t dstGap = ubLoopInfo.totalUbColAlign * colRepeatNum_ / srcEleUbBlock_ - blockLen;
DataCopyParams copyParams{blockCount, blockLen, 0, dstGap};
DataCopy(dstLocal[dstOffset], srcLocal[srcOffset], copyParams);
// Source regions are stored back to back; destination slices sit side by side within a row.
srcOffset += blockCount * blockLen * srcEleUbBlock_;
dstOffset += blockLen * srcEleUbBlock_;
}
PipeBarrier<PIPE_V>();
}
// Second step of the non-aligned path: runs TransDataTo5HD from srcLocal into dstLocal.
// The TRANS_BLOCK source addresses are spaced totalUbColAlign * colRepeatNum_ elements apart
// (one per row written by UBRearrange4Trans); the destination addresses address one
// TRANS_BLOCK-wide group, and each repeat advances the source by 1 and the destination by
// TRANS_BLOCK stride units.
__aicore__ inline void Trans1(const UbLoopInfo& ubLoopInfo, LocalTensor<T1>& srcLocal, LocalTensor<T1>& dstLocal)
{
// NOTE(review): narrowed to uint8_t - assumes the tiling keeps the repeat count below 256.
uint8_t repeatTimes = ubLoopInfo.totalUbColAlign * colRepeatNum_ / srcEleUbBlock_;
// With a single repeat both strides are set to 0.
uint16_t srcRepStride = repeatTimes == 1 ? 0 : 1;
uint16_t dstRepStride = repeatTimes == 1 ? 0 : TRANS_BLOCK;
TransDataTo5HDParams transDataParams{false, false, repeatTimes, dstRepStride, srcRepStride};
uint64_t srcLocalList[TRANS_BLOCK];
uint64_t dstLocalList[TRANS_BLOCK];
if constexpr (sizeof(T1) == 2) {
// 2-byte element types are transposed through a half view of the same memory.
LocalTensor<half> srcLocalFP16 = srcLocal.template ReinterpretCast<half>();
LocalTensor<half> dstLocalFP16 = dstLocal.template ReinterpretCast<half>();
for (int i = 0; i < TRANS_BLOCK; i++) {
uint64_t offset = i * ubLoopInfo.totalUbColAlign * colRepeatNum_;
srcLocalList[i] = reinterpret_cast<uint64_t>(srcLocalFP16[offset].GetPhyAddr());
}
for (int i = 0; i < TRANS_BLOCK; i++) {
uint64_t offset = i * TRANS_BLOCK;
dstLocalList[i] = reinterpret_cast<uint64_t>(dstLocalFP16[offset].GetPhyAddr());
}
TransDataTo5HD<half>(dstLocalList, srcLocalList, transDataParams);
} else {
for (int i = 0; i < TRANS_BLOCK; i++) {
uint64_t offset = i * ubLoopInfo.totalUbColAlign * colRepeatNum_;
srcLocalList[i] = reinterpret_cast<uint64_t>(srcLocal[offset].GetPhyAddr());
}
// For other element sizes the destination list is filled in groups of
// TRANS_BLOCK / srcEleUbBlock_ addresses, srcEleUbBlock_ elements apart within a group.
for (uint64_t i = 0; i < srcEleUbBlock_; i++) {
for (uint64_t j = 0; j < TRANS_BLOCK / srcEleUbBlock_; j++) {
uint64_t offset = i * TRANS_BLOCK + j * srcEleUbBlock_;
dstLocalList[i * TRANS_BLOCK / srcEleUbBlock_ + j] =
reinterpret_cast<uint64_t>(dstLocal[offset].GetPhyAddr());
}
}
TransDataTo5HD<T1>(dstLocalList, srcLocalList, transDataParams);
}
PipeBarrier<PIPE_V>();
}
// Third step of the non-aligned path: copies each slice's columns from the transposed layout in
// srcLocal into dstLocal so that the slices end up adjacent, dropping the padding columns of
// padded slices (negative inputCol entries).
// NO_NEED_ALIGN selects the destination row width: totalUbCol as is (true) or totalUbCol
// rounded up to dstEleUbBlock_ (false).
template <bool NO_NEED_ALIGN=false>
__aicore__ inline void UBRearrange4TransConcat(const UbLoopInfo& ubLoopInfo, LocalTensor<T1>& srcLocal, LocalTensor<T1>& dstLocal)
{
int64_t srcOffset = 0;
int64_t dstOffset = 0;
for (int64_t i = 0; i < ubLoopInfo.count; i++) {
uint16_t blockCount = colRepeatNum_;
// Real (unpadded) width of the slice.
uint16_t actualCol = ubLoopInfo.inputCol[i] > 0 ? ubLoopInfo.inputCol[i] : -ubLoopInfo.inputCol[i];
uint16_t blockLen = actualCol * TRANS_BLOCK / srcEleUbBlock_;
// For padded slices the source gap skips the padding columns: (alignedWidth - realWidth).
uint16_t srcGap = ubLoopInfo.inputCol[i] > 0 ? 0 :
(GetAlign(-ubLoopInfo.inputCol[i], srcEleUbBlock_) + ubLoopInfo.inputCol[i]) * TRANS_BLOCK / srcEleUbBlock_;
uint16_t dstGap = GetAlign(ubLoopInfo.totalUbCol, dstEleUbBlock_) * TRANS_BLOCK / srcEleUbBlock_ - blockLen;
if constexpr (NO_NEED_ALIGN) {
dstGap = ubLoopInfo.totalUbCol * TRANS_BLOCK / srcEleUbBlock_ - blockLen;
}
DataCopyParams copyParams{blockCount, blockLen, srcGap, dstGap};
DataCopy(dstLocal[dstOffset], srcLocal[srcOffset], copyParams);
// The source advances by the stored (possibly padded) width, the destination by the real width.
srcOffset += ubLoopInfo.inputCol[i] > 0 ? (colRepeatNum_ * actualCol * TRANS_BLOCK) :
(colRepeatNum_ * GetAlign(actualCol, srcEleUbBlock_) * TRANS_BLOCK);
dstOffset += (actualCol * TRANS_BLOCK);
}
PipeBarrier<PIPE_V>();
}
/**
 * \brief Last step of the non-aligned path: runs TransDataTo5HD from srcLocal into dstLocal with
 *        the address roles of Trans1 swapped (sources address one TRANS_BLOCK-wide group, the
 *        destinations are spaced one output row apart).
 *
 * \tparam NO_NEED_ALIGN When true the output row width is totalUbCol; otherwise it is totalUbCol
 *                       rounded up to dstEleUbBlock_.
 * \param ubLoopInfo Supplies totalUbCol.
 * \param srcLocal   Buffer written by UBRearrange4TransConcat.
 * \param dstLocal   Buffer that receives the result.
 */
template <bool NO_NEED_ALIGN=false>
__aicore__ inline void Trans2(const UbLoopInfo& ubLoopInfo, LocalTensor<T1>& srcLocal, LocalTensor<T1>& dstLocal)
{
    int64_t actualTotalUbCol = GetAlign(ubLoopInfo.totalUbCol, dstEleUbBlock_);
    if constexpr (NO_NEED_ALIGN) {
        actualTotalUbCol = ubLoopInfo.totalUbCol;
    }
    // NOTE(review): narrowed to uint8_t - assumes the tiling keeps the repeat count below 256.
    uint8_t repeatTimes = actualTotalUbCol * colRepeatNum_ / srcEleUbBlock_;
    // With a single repeat both strides are set to 0.
    uint16_t srcRepStride = repeatTimes == 1 ? 0 : TRANS_BLOCK;
    uint16_t dstRepStride = repeatTimes == 1 ? 0 : 1;
    TransDataTo5HDParams transDataParams = {false, false, repeatTimes, dstRepStride, srcRepStride};
    uint64_t srcLocalList[TRANS_BLOCK];
    uint64_t dstLocalList[TRANS_BLOCK];
    // Compile-time dispatch on the element size (same form as Trans1), so only the branch that
    // matches T1 is instantiated.
    if constexpr (sizeof(T1) == 2) {
        // 2-byte element types are transposed through a half view of the same memory.
        LocalTensor<half> srcLocalFP16 = srcLocal.template ReinterpretCast<half>();
        LocalTensor<half> dstLocalFP16 = dstLocal.template ReinterpretCast<half>();
        for (int i = 0; i < TRANS_BLOCK; i++) {
            uint64_t offset = i * TRANS_BLOCK;
            srcLocalList[i] = reinterpret_cast<uint64_t>(srcLocalFP16[offset].GetPhyAddr());
        }
        for (int i = 0; i < TRANS_BLOCK; i++) {
            uint64_t offset = i * actualTotalUbCol * colRepeatNum_;
            dstLocalList[i] = reinterpret_cast<uint64_t>(dstLocalFP16[offset].GetPhyAddr());
        }
        TransDataTo5HD<half>(dstLocalList, srcLocalList, transDataParams);
    } else {
        // Source list: groups of srcEleUbBlock_ addresses, TRANS_BLOCK elements apart in a group.
        for (uint64_t i = 0; i < TRANS_BLOCK / srcEleUbBlock_; i++) {
            for (uint64_t j = 0; j < srcEleUbBlock_; j++) {
                uint64_t offset = i * srcEleUbBlock_ + j * TRANS_BLOCK;
                srcLocalList[i * srcEleUbBlock_ + j] =
                    reinterpret_cast<uint64_t>(srcLocal[offset].GetPhyAddr());
            }
        }
        // Destination list: even entries address rows i / 2, odd entries address rows
        // i / 2 + srcEleUbBlock_ (each row is actualTotalUbCol * colRepeatNum_ elements wide).
        for (uint64_t i = 0; i < TRANS_BLOCK; i += 2) {
            uint64_t offset = (i / 2) * actualTotalUbCol * colRepeatNum_;
            dstLocalList[i] = reinterpret_cast<uint64_t>(dstLocal[offset].GetPhyAddr());
        }
        for (uint64_t i = 1; i < TRANS_BLOCK; i += 2) {
            uint64_t offset = (i / 2 + srcEleUbBlock_) * actualTotalUbCol * colRepeatNum_;
            dstLocalList[i] = reinterpret_cast<uint64_t>(dstLocal[offset].GetPhyAddr());
        }
        TransDataTo5HD<T1>(dstLocalList, srcLocalList, transDataParams);
    }
    PipeBarrier<PIPE_V>();
}
/**
 * \brief Aligned path: places the per-input regions of srcLocal side by side in dstLocal so that
 *        every output row is totalUbColAlign elements wide.
 *
 * Each region holds currentUbRowFactor rows of inputCol[i] elements, stored back to back.
 * Lengths and gaps are divided by srcEleUbBlock_, i.e. expressed in UB_BLOCK_SIZE-byte blocks,
 * so the column widths are expected to be multiples of srcEleUbBlock_.
 */
__aicore__ inline void UBRearrange4Concat(const UbLoopInfo& ubLoopInfo, LocalTensor<T1>& srcLocal, LocalTensor<T1>& dstLocal)
{
    int64_t readPos = 0;
    int64_t writePos = 0;
    for (int64_t sliceIdx = 0; sliceIdx < ubLoopInfo.count; ++sliceIdx) {
        const int64_t sliceCols = ubLoopInfo.inputCol[sliceIdx];
        uint16_t rowCount = ubLoopInfo.currentUbRowFactor;
        uint16_t rowBlocks = sliceCols / srcEleUbBlock_;
        uint16_t srcGapBlocks = 0;
        uint16_t dstGapBlocks = (ubLoopInfo.totalUbColAlign - sliceCols) / srcEleUbBlock_;
        DataCopyParams rowCopyParams{rowCount, rowBlocks, srcGapBlocks, dstGapBlocks};
        DataCopy(dstLocal[writePos], srcLocal[readPos], rowCopyParams);
        // The next region starts right after this one in the source, and one slice to the right
        // in the destination row.
        readPos += rowCount * sliceCols;
        writePos += sliceCols;
    }
    PipeBarrier<PIPE_V>();
}
/**
 * \brief Converts castCount elements from srcLocal_ (T1) into dstLocal_ (T2).
 *
 * When T1 and T2 differ in size a single Cast is issued. When they have the same size the
 * conversion goes through float: T1 -> float in the destination buffer, float -> T2 back in the
 * source buffer (CAST_RINT), then a copy into dstLocal_.
 *
 * \param ubLoopInfo Unused.
 * \param castCount  Number of elements to convert.
 */
__aicore__ inline void DoCast(const UbLoopInfo& ubLoopInfo, uint32_t castCount)
{
    if constexpr (sizeof(T1) != sizeof(T2)) {
        Cast(dstLocal_, srcLocal_, RoundMode::CAST_NONE, castCount);
    } else {
        Cast(dstLocalFP32_, srcLocal_, RoundMode::CAST_NONE, castCount);
        Cast(srcLocalT2_, dstLocalFP32_, RoundMode::CAST_RINT, castCount);
        DataCopy(dstLocal_, srcLocalT2_, castCount);
    }
}
/**
 * \brief One-concat mode: waits on the MTE2_V event, then copies or casts the gathered data from
 *        the source buffer to the destination buffer. Nothing is moved when no columns were
 *        gathered (totalUbColAlign == 0).
 */
__aicore__ inline void ComputeOneConcat(const UbLoopInfo& ubLoopInfo)
{
    SetFlag<HardEvent::MTE2_V>(event_);
    WaitFlag<HardEvent::MTE2_V>(event_);

    if (ubLoopInfo.totalUbColAlign != 0) {
        if constexpr (NEAD_CAST) {
            uint32_t castCount = ubLoopInfo.currentUbRowFactor * ubLoopInfo.totalUbColAlign;
            DoCast(ubLoopInfo, castCount);
        } else {
            // NOTE(review): unlike the cast path, this count is not multiplied by
            // currentUbRowFactor - confirm that one-concat tiles without a cast hold one row.
            DataCopy(dstLocalT1_, srcLocal_, ubLoopInfo.totalUbColAlign);
        }
    }
}
/**
 * \brief Aligned mode: waits on the MTE2_V event and concatenates the gathered slices into the
 *        destination buffer. When a cast is required the concatenated tile is first copied back
 *        to the source buffer and then converted by DoCast.
 */
__aicore__ inline void ComputeAllAlign(const UbLoopInfo& ubLoopInfo)
{
    SetFlag<HardEvent::MTE2_V>(event_);
    WaitFlag<HardEvent::MTE2_V>(event_);

    UBRearrange4Concat(ubLoopInfo, srcLocal_, dstLocalT1_);

    if constexpr (NEAD_CAST) {
        const auto tileElemCount = ubLoopInfo.currentUbRowFactor * ubLoopInfo.currentUbColFactor;
        // DoCast reads from srcLocal_, so move the concatenated tile back there first.
        DataCopy(srcLocal_, dstLocalT1_, tileElemCount);
        PipeBarrier<PIPE_V>();
        uint32_t castCount = tileElemCount;
        DoCast(ubLoopInfo, castCount);
    }
}
/**
 * \brief Non-aligned mode: waits on the MTE2_V event, then concatenates the gathered slices via
 *        rearrange -> transpose -> rearrange -> transpose. The result ends up in the source
 *        buffer and is finally copied or cast into the destination buffer.
 */
__aicore__ inline void ComputeNotAlign(const UbLoopInfo& ubLoopInfo)
{
    SetFlag<HardEvent::MTE2_V>(event_);
    WaitFlag<HardEvent::MTE2_V>(event_);

    // The two buffers are used in ping-pong fashion; each step reads the previous step's output.
    UBRearrange4Trans(ubLoopInfo, srcLocal_, dstLocalT1_);
    Trans1(ubLoopInfo, dstLocalT1_, srcLocal_);
    UBRearrange4TransConcat(ubLoopInfo, srcLocal_, dstLocalT1_);
    Trans2(ubLoopInfo, dstLocalT1_, srcLocal_);

    // Rows are padded to a multiple of dstEleUbBlock_ elements.
    const auto paddedElemCount = ubLoopInfo.currentUbRowFactor * GetAlign(ubLoopInfo.currentUbColFactor, dstEleUbBlock_);
    if constexpr (NEAD_CAST) {
        uint32_t castCount = paddedElemCount;
        DoCast(ubLoopInfo, castCount);
    } else {
        DataCopy(dstLocalT1_, srcLocal_, paddedElemCount);
    }
}
private:
// Mode flags copied from the tiling data in Init().
bool isAllAlign_{false};
bool isHalfAlign_{false};
bool isOneConcat_{false};
// Index of the executing core (GetBlockIdx()).
int64_t blockIdx_{0};
// Number of input tensors in inputList_.
int64_t inputNum_{0};
// Size of one UB tile in rows / columns.
int64_t ubRowFactor_{0};
int64_t ubColFactor_{0};
// Elements of T1 / T2 per UB_BLOCK_SIZE bytes.
int64_t srcEleUbBlock_{0};
int64_t dstEleUbBlock_{0};
// Shape dimension whose size is read as chunkDimSize, and the divisor used to derive chunkCol.
int64_t dim_{0};
int64_t numChunk_{0};
// Output extent; outputCol_ is the row pitch used for all GM output offsets.
int64_t outputRow_{0};
int64_t outputCol_{0};
// Per-core window size: regular factors and the factors of the last core row / column.
int64_t blockRowFactor_{0};
int64_t blockColFactor_{0};
int64_t tailBlockRowFactor_{0};
int64_t tailBlockColFactor_{0};
// Position of this core in the core grid (blockIdx_ / blockColNum and blockIdx_ % blockColNum).
int64_t blockRowGroup_{0};
int64_t blockColGroup_{0};
// Window size that applies to this core (regular or tail factor).
int64_t currentBlockRowFactor_{0};
int64_t currentBlockColFactor_{0};
// HALF when isHalfAlign_ is set, otherwise srcEleUbBlock_.
int64_t colRepeatNum_{0};
// Pipe supplied to the constructor; used to allocate srcBuf_ / dstBuf_.
TPipe *pipe_;
// Event id shared by every SetFlag / WaitFlag pair in this class.
TEventID event_{0};
// Descriptor reused for each input; its shape storage is attached in Process().
TensorDesc<T1> desc_;
ListTensorDesc inputList_;
// dstGlobal_ starts at this core's first output element; srcGlobal_ is re-pointed per input.
GlobalTensor<T2> dstGlobal_;
GlobalTensor<T1> srcGlobal_;
TBuf<AscendC::TPosition::VECCALC> srcBuf_;
TBuf<AscendC::TPosition::VECCALC> dstBuf_;
// Typed views: srcLocal_ / srcLocalT2_ alias srcBuf_; dstLocal_ / dstLocalT1_ / dstLocalFP32_ alias dstBuf_.
LocalTensor<T1> srcLocal_;
LocalTensor<T2> srcLocalT2_;
LocalTensor<T2> dstLocal_;
LocalTensor<T1> dstLocalT1_;
LocalTensor<float> dstLocalFP32_;
};
#endif