* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef CATLASS_GEMM_KERNEL_PADDING_MATMUL_HPP
#define CATLASS_GEMM_KERNEL_PADDING_MATMUL_HPP
#include "catlass/catlass.hpp"
#include "catlass/coord.hpp"
#include "catlass/gemm_coord.hpp"
#include "catlass/matrix_coord.hpp"
#include "catlass/arch/resource.hpp"
#include "catlass/arch/cross_core_sync.hpp"
#include "catlass/epilogue/tile/copy_gm_to_ub.hpp"
#include "catlass/epilogue/tile/copy_ub_to_gm.hpp"
namespace Catlass::Gemm::Kernel {
enum class PaddingTag {
NO_PADDING,
PADDING_ND,
PADDING_BLOCK_ND,
PADDING_NZ
};
template<
class ArchTag_,
class Element_,
class LayoutIn_,
class LayoutOut_,
uint32_t COMPUTE_LENGTH
>
struct PaddingMatrixNZ {
using ArchTag = ArchTag_;
using Element = Element_;
using LayoutIn = LayoutIn_;
using LayoutOut = LayoutOut_;
using ComputeLayoutSrc = Catlass::layout::RowMajor;
using ComputeLayoutDst = Catlass::layout::zN;
using CopyGm2Ub = Catlass::Epilogue::Tile::CopyGm2Ub<ArchTag, Gemm::GemmType<Element, ComputeLayoutSrc>>;
constexpr static uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element);
constexpr static uint32_t VNCHW_SIZE = 16;
static const PaddingTag paddingTag = PaddingTag::PADDING_NZ;
CATLASS_HOST_DEVICE static LayoutOut GetWorkspaceLayout(const LayoutIn& layout)
{
return LayoutOut::template MakeLayout<Element>(layout.shape(0), layout.shape(1));
}
static size_t GetWorkspaceSize(uint32_t rows, uint32_t cols)
{
if constexpr (std::is_same_v<LayoutIn, Catlass::layout::RowMajor>) {
return static_cast<size_t>(
RoundUp(rows, Catlass::C0_NUM_PER_FRACTAL)) * RoundUp(cols, ELE_NUM_PER_C0) * sizeof(Element);
} else {
return static_cast<size_t>(
RoundUp(rows, ELE_NUM_PER_C0)) * RoundUp(cols, Catlass::C0_NUM_PER_FRACTAL) * sizeof(Element);
}
}
CopyGm2Ub copyGm2Ub;
CATLASS_DEVICE
PaddingMatrixNZ(Arch::Resource<ArchTag> &resource)
{
int64_t bufferOffset = 0;
for (uint32_t i = 0; i < BUFFER_NUM; i++) {
inputBuffer[i] = resource.ubBuf.template GetBufferByByte<Element>(bufferOffset * sizeof(Element));
bufferOffset += COMPUTE_LENGTH;
outputBuffer[i] = resource.ubBuf.template GetBufferByByte<Element>(bufferOffset * sizeof(Element));
bufferOffset += COMPUTE_LENGTH;
}
}
CATLASS_DEVICE
void CopyUb2Ub(AscendC::LocalTensor<Element> const &dst, AscendC::LocalTensor<Element> const &src,
ComputeLayoutDst const &layoutDst, ComputeLayoutSrc const &layoutSrc)
{
uint32_t loops = CeilDiv(layoutSrc.shape(0), BLK_NUM_PER_VECTOR_FRACTAL);
for (uint32_t i = 0; i < loops; ++i) {
uint64_t offsetSrc = i * BLK_NUM_PER_VECTOR_FRACTAL * layoutSrc.stride(0);
uint64_t offsetDst = i * BLK_NUM_PER_VECTOR_FRACTAL * layoutDst.stride(0);
uint32_t repeatTimes = CeilDiv(layoutSrc.shape(1), ELE_NUM_PER_C0);
uint64_t mask = 256 / sizeof(Element);
AscendC::CopyRepeatParams copyRepeatParams {
1, static_cast<uint16_t>(layoutSrc.stride(0) / ELE_NUM_PER_C0),
static_cast<uint16_t>(layoutDst.stride(3) / ELE_NUM_PER_C0), 1
};
AscendC::Copy(dst[offsetDst], src[offsetSrc], mask, repeatTimes, copyRepeatParams);
}
}
CATLASS_DEVICE
void CopyUb2Gm(AscendC::GlobalTensor<Element> const &dst, AscendC::LocalTensor<Element> const &src,
ComputeLayoutDst const &layoutDst, ComputeLayoutDst const &layoutSrc)
{
if (layoutDst.stride(3) / ELE_NUM_PER_C0 < STRIDE_LIMIT) {
AscendC::DataCopyParams dataCopyParams(
layoutDst.shape(3),
layoutDst.shape(0) * layoutDst.shape(1),
(layoutSrc.stride(3) - layoutSrc.shape(0) * layoutSrc.shape(1) * layoutSrc.shape(2)) / ELE_NUM_PER_C0,
(layoutDst.stride(3) - layoutDst.shape(0) * layoutDst.shape(1) * layoutDst.shape(2)) / ELE_NUM_PER_C0
);
AscendC::DataCopy(dst, src, dataCopyParams);
} else {
uint32_t blockCount = CeilDiv<ELE_NUM_PER_C0>(layoutDst.orgShape(1));
uint32_t blockLen = RoundUp<C0_NUM_PER_FRACTAL>(layoutDst.orgShape(0));
AscendC::DataCopyParams dataCopyParams(1, blockLen, 0, 0);
for (uint32_t i = 0; i < blockCount; i++) {
uint64_t dstOffset = i * layoutDst.stride(3);
uint64_t srcOffset = i * layoutSrc.stride(3);
AscendC::DataCopy(dst[dstOffset], src[srcOffset], dataCopyParams);
}
}
}
CATLASS_DEVICE
ComputeLayoutSrc GetPaddingComputeLayout(layout::RowMajor const &layout)
{
return layout;
}
CATLASS_DEVICE
ComputeLayoutSrc GetPaddingComputeLayout(layout::ColumnMajor const &layout)
{
return ComputeLayoutSrc(layout.shape(1), layout.shape(0), layout.stride(1));
}
CATLASS_DEVICE
ComputeLayoutDst GetPaddingComputeLayout(layout::zN const &layout)
{
return layout;
}
CATLASS_DEVICE
ComputeLayoutDst GetPaddingComputeLayout(layout::nZ const &layout)
{
return ComputeLayoutDst::MakeLayout<Element>(layout.orgShape(1), layout.orgShape(0));
}
CATLASS_DEVICE
void CommonND2NZ(AscendC::GlobalTensor<Element> const &dst, AscendC::GlobalTensor<Element> const &src,
ComputeLayoutDst const &computeLayoutDst, ComputeLayoutSrc const &computeLayoutSrc)
{
uint32_t aivNum = AscendC::GetBlockNum() * AscendC::GetSubBlockNum();
uint32_t aivId = AscendC::GetBlockIdx();
uint32_t refTileRows = 16;
uint32_t refTileCols = COMPUTE_LENGTH / refTileRows;
uint32_t tileRows = refTileRows;
uint32_t tileCols = RoundUp(
computeLayoutSrc.shape(1) / CeilDiv(computeLayoutSrc.shape(1), refTileCols), ELE_NUM_PER_C0);
uint32_t rowTiles = CeilDiv(computeLayoutSrc.shape(0), tileRows);
uint32_t colTiles = CeilDiv(computeLayoutSrc.shape(1), tileCols);
uint32_t totalTiles = rowTiles * colTiles;
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventIds[0]);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventIds[1]);
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventIds[0]);
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventIds[1]);
for (uint32_t loopIdx = aivId; loopIdx < totalTiles; loopIdx += aivNum) {
uint32_t rowTileIdx = loopIdx / colTiles;
uint32_t colTileIdx = loopIdx % colTiles;
MatrixCoord tileCoord{rowTileIdx * tileRows, colTileIdx * tileCols};
uint32_t rowTileActual = (rowTileIdx == rowTiles - 1) ?
(computeLayoutSrc.shape(0) - rowTileIdx * tileRows) : tileRows;
uint32_t colTileActual = (colTileIdx == colTiles - 1) ?
(computeLayoutSrc.shape(1) - colTileIdx * tileCols) : tileCols;
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventIds[bufferIndex]);
uint64_t tileOffsetSrc = computeLayoutSrc.GetOffset(tileCoord);
ComputeLayoutSrc tileLayoutSrc = computeLayoutSrc.GetTileLayout(MatrixCoord(rowTileActual, colTileActual));
ComputeLayoutSrc ubLayoutIn = ComputeLayoutSrc{rowTileActual, RoundUp(colTileActual, ELE_NUM_PER_C0)};
if (tileCols < computeLayoutSrc.shape(1) || (computeLayoutSrc.shape(1) % ELE_NUM_PER_C0)) {
copyGm2Ub(inputBuffer[bufferIndex], src[tileOffsetSrc], ubLayoutIn, tileLayoutSrc);
} else {
ComputeLayoutSrc tileLayoutSrcConti =
computeLayoutSrc.GetTileLayout(MatrixCoord(1, rowTileActual * colTileActual));
ComputeLayoutSrc ubLayoutInConti = ComputeLayoutSrc{1, rowTileActual * colTileActual};
copyGm2Ub(inputBuffer[bufferIndex], src[tileOffsetSrc], ubLayoutInConti, tileLayoutSrcConti);
}
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventIds[bufferIndex]);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventIds[bufferIndex]);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventIds[bufferIndex]);
ComputeLayoutDst ubLayoutOut = ComputeLayoutDst::template MakeLayout<Element>(rowTileActual, colTileActual);
CopyUb2Ub(outputBuffer[bufferIndex], inputBuffer[bufferIndex], ubLayoutOut, ubLayoutIn);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventIds[bufferIndex]);
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(eventIds[bufferIndex]);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(eventIds[bufferIndex]);
uint64_t tileOffsetDst = computeLayoutDst.GetOffset(tileCoord);
ComputeLayoutDst tileLayoutDst = computeLayoutDst.GetTileLayout(MatrixCoord(rowTileActual, colTileActual));
CopyUb2Gm(dst[tileOffsetDst], outputBuffer[bufferIndex], tileLayoutDst, ubLayoutOut);
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventIds[bufferIndex]);
bufferIndex = 1 - bufferIndex;
}
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventIds[0]);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventIds[1]);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventIds[0]);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventIds[1]);
}
CATLASS_DEVICE
void VnchwND2NZ(AscendC::GlobalTensor<Element> const &dst, AscendC::GlobalTensor<Element> const &src,
ComputeLayoutDst const &computeLayoutDst, ComputeLayoutSrc const &computeLayoutSrc)
{
uint32_t aivNum = AscendC::GetBlockNum() * AscendC::GetSubBlockNum();
uint32_t aivId = AscendC::GetBlockIdx();
uint64_t dstLocalList[VNCHW_SIZE];
uint64_t srcLocalList[VNCHW_SIZE];
uint32_t vnchwMaxInnerAxisLen = COMPUTE_LENGTH / VNCHW_SIZE / ELE_NUM_PER_C0;
uint32_t innerAxisLen = computeLayoutSrc.shape(1);
uint32_t transMatrixCols = innerAxisLen * ELE_NUM_PER_C0;
uint32_t transMatrixRows = VNCHW_SIZE;
bool cond1 = (computeLayoutSrc.shape(0) > 8192) && (innerAxisLen < vnchwMaxInnerAxisLen);
bool cond2 =
(computeLayoutSrc.shape(0) > 4096) && (innerAxisLen < vnchwMaxInnerAxisLen * 2 && innerAxisLen % 2 == 0);
bool cond3 =
(computeLayoutSrc.shape(0) > 2048) && (innerAxisLen < vnchwMaxInnerAxisLen * 4 && innerAxisLen % 4 == 0);
if (cond2 && !cond1) {
transMatrixCols /= 2;
} else if (cond3 && !cond1) {
transMatrixCols /= 4;
} else if (!cond1) {
if (innerAxisLen % 4 == 0) {
transMatrixCols /= 4;
} else if (innerAxisLen % 2 == 0) {
transMatrixCols /= 2;
}
}
uint32_t innerAxisPerCols = transMatrixCols / innerAxisLen;
uint32_t tileRows = innerAxisPerCols * transMatrixRows;
AscendC::TransDataTo5HDParams paramsRepeatOne{false, false, 1, 0, 0};
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventIds[0]);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventIds[1]);
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventIds[0]);
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventIds[1]);
uint32_t rowTiles = CeilDiv(computeLayoutSrc.shape(0), tileRows);
for (uint32_t loopIdx = aivId; loopIdx < rowTiles; loopIdx += aivNum) {
uint32_t rowTileIdx = loopIdx;
uint32_t rowTileActual = (rowTileIdx == rowTiles - 1) ?
(computeLayoutSrc.shape(0) - rowTileIdx * tileRows) : tileRows;
MatrixCoord tileCoord {rowTileIdx * tileRows, 0};
uint64_t tileOffsetSrc = computeLayoutSrc.GetOffset(tileCoord);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventIds[bufferIndex]);
ComputeLayoutSrc tileLayoutSrc =
computeLayoutSrc.GetTileLayout(MatrixCoord(1, rowTileActual * innerAxisLen));
ComputeLayoutSrc ubLayoutIn = ComputeLayoutSrc{1, RoundUp(rowTileActual * innerAxisLen, ELE_NUM_PER_C0)};
copyGm2Ub(inputBuffer[bufferIndex], src[tileOffsetSrc], ubLayoutIn, tileLayoutSrc);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventIds[bufferIndex]);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventIds[bufferIndex]);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventIds[bufferIndex]);
for (uint32_t i = 0; i < VNCHW_SIZE; ++i) {
srcLocalList[i] = i * transMatrixCols * sizeof(Element) + inputBuffer[bufferIndex].GetPhyAddr();
dstLocalList[i] = i * ELE_NUM_PER_C0 * sizeof(Element) + outputBuffer[bufferIndex].GetPhyAddr();
}
uint8_t repeatTimes = transMatrixCols / ELE_NUM_PER_C0;
if (repeatTimes > 1) {
AscendC::TransDataTo5HDParams params{false, false, repeatTimes, VNCHW_SIZE, 1};
AscendC::TransDataTo5HD<Element>(dstLocalList, srcLocalList, params);
} else {
AscendC::TransDataTo5HD<Element>(dstLocalList, srcLocalList, paramsRepeatOne);
}
uint32_t vnchwPerInnerAxis = CeilDiv(innerAxisLen, ELE_NUM_PER_C0);
for (uint32_t i = 0; i < vnchwPerInnerAxis; ++i) {
if constexpr (std::is_same_v<Element, half>) {
for (uint32_t j = 0; j < VNCHW_SIZE; ++j) {
srcLocalList[j] = (j * ELE_NUM_PER_C0 + i * VNCHW_SIZE * ELE_NUM_PER_C0) * sizeof(Element)
+ outputBuffer[bufferIndex].GetPhyAddr();
dstLocalList[j] = (i * tileRows * ELE_NUM_PER_C0 + j * innerAxisPerCols * ELE_NUM_PER_C0)
* sizeof(Element) + inputBuffer[bufferIndex].GetPhyAddr();
}
}
if constexpr (std::is_same_v<Element, float>) {
for (uint32_t j = 0; j < VNCHW_SIZE; ++j) {
srcLocalList[j] = ((2 * (j % 8) + j / 8) * ELE_NUM_PER_C0 + i * VNCHW_SIZE * ELE_NUM_PER_C0)
* sizeof(Element) + outputBuffer[bufferIndex].GetPhyAddr();
dstLocalList[j] = (i * tileRows * ELE_NUM_PER_C0 + (j / 2 + 8 * (j % 2)) * innerAxisPerCols
* ELE_NUM_PER_C0) * sizeof(Element) + inputBuffer[bufferIndex].GetPhyAddr();
}
}
uint8_t repeatTimes = innerAxisPerCols;
if (repeatTimes > 1) {
AscendC::TransDataTo5HDParams params{false, false, repeatTimes, 1,
static_cast<uint16_t>(innerAxisLen * VNCHW_SIZE / ELE_NUM_PER_C0)};
AscendC::TransDataTo5HD<Element>(dstLocalList, srcLocalList, params);
} else {
AscendC::TransDataTo5HD<Element>(dstLocalList, srcLocalList, paramsRepeatOne);
}
}
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventIds[bufferIndex]);
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(eventIds[bufferIndex]);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(eventIds[bufferIndex]);
ComputeLayoutDst ubLayoutOut = ComputeLayoutDst::template MakeLayout<Element>(tileRows, innerAxisLen);
ComputeLayoutDst tileLayoutUb = ubLayoutOut.GetTileLayout(MatrixCoord(rowTileActual, innerAxisLen));
uint64_t tileOffsetDst = computeLayoutDst.GetOffset(tileCoord);
ComputeLayoutDst tileLayoutDst = computeLayoutDst.GetTileLayout(MatrixCoord(rowTileActual, innerAxisLen));
CopyUb2Gm(dst[tileOffsetDst], inputBuffer[bufferIndex], tileLayoutDst, tileLayoutUb);
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventIds[bufferIndex]);
auto tmpBuffer = inputBuffer[bufferIndex];
inputBuffer[bufferIndex] = outputBuffer[bufferIndex];
outputBuffer[bufferIndex] = tmpBuffer;
bufferIndex = 1 - bufferIndex;
}
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventIds[0]);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventIds[1]);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventIds[0]);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventIds[1]);
}
CATLASS_DEVICE
void operator()(AscendC::GlobalTensor<Element> const &dst, AscendC::GlobalTensor<Element> const &src,
LayoutOut const &layoutDst, LayoutIn const &layoutSrc)
{
auto computeLayoutSrc = GetPaddingComputeLayout(layoutSrc);
auto computeLayoutDst = GetPaddingComputeLayout(layoutDst);
uint32_t vnchwMaxInnerAxisLen = COMPUTE_LENGTH / VNCHW_SIZE / ELE_NUM_PER_C0;
bool cond0 = computeLayoutSrc.shape(1) % ELE_NUM_PER_C0 != 0;
bool cond1 = computeLayoutSrc.shape(1) < vnchwMaxInnerAxisLen && computeLayoutSrc.shape(1);
bool cond2 = computeLayoutSrc.shape(1) < vnchwMaxInnerAxisLen * 2 && computeLayoutSrc.shape(1) % 2 == 0;
bool cond3 = computeLayoutSrc.shape(1) < vnchwMaxInnerAxisLen * 4 && computeLayoutSrc.shape(1) % 4 == 0;
bool cond4 = computeLayoutSrc.shape(0) > 1024;
if (cond0 && cond4 && (cond1 || cond2 || cond3)) {
VnchwND2NZ(dst, src, computeLayoutDst, computeLayoutSrc);
} else {
CommonND2NZ(dst, src, computeLayoutDst, computeLayoutSrc);
}
}
private:
static const uint32_t BUFFER_NUM = 2;
AscendC::LocalTensor<Element> inputBuffer[BUFFER_NUM];
AscendC::LocalTensor<Element> outputBuffer[BUFFER_NUM];
AscendC::TEventID eventIds[BUFFER_NUM] = {EVENT_ID0, EVENT_ID1};
uint32_t bufferIndex{ 0 };
static_assert(BUFFER_NUM * COMPUTE_LENGTH * sizeof(Element) * 2 <= ArchTag::UB_SIZE, "Excedding the UB space!");
static_assert(std::is_same_v<LayoutIn, layout::RowMajor> ||
std::is_same_v<LayoutIn, layout::ColumnMajor>, "Unsported layout for PaddingMatrixNZ!");
};
template<
class ArchTag_,
class Element_,
class LayoutIn_,
class LayoutOut_,
uint32_t COMPUTE_LENGTH
>
struct PaddingMatrixBlockND {
public:
using ArchTag = ArchTag_;
using Element = Element_;
using LayoutIn = LayoutIn_;
using LayoutOut = LayoutOut_;
using ComputeLayout = Catlass::layout::RowMajor;
using ComputeLayoutDst = Catlass::layout::PaddingRowMajor;
using CopyGm2Ub = Catlass::Epilogue::Tile::CopyGm2Ub<
ArchTag, Gemm::GemmType<Element, ComputeLayout>>;
using CopyUb2Gm = Catlass::Epilogue::Tile::CopyUb2Gm<
ArchTag, Gemm::GemmType<Element, ComputeLayout>>;
static const PaddingTag paddingTag = PaddingTag::PADDING_BLOCK_ND;
CATLASS_HOST_DEVICE static
LayoutOut GetWorkspaceLayout(const LayoutIn& layout, uint32_t rowAlign, uint32_t colAlign)
{
return LayoutOut(layout.shape(0), layout.shape(1), rowAlign, colAlign);
}
static size_t GetWorkspaceSize(uint32_t rows, uint32_t cols, uint32_t rowAlign, uint32_t colAlign)
{
return static_cast<size_t>(RoundUp(rows, rowAlign)) * RoundUp(cols, colAlign) * sizeof(Element);
}
CopyGm2Ub copyGm2Ub;
CopyUb2Gm copyUb2Gm;
CATLASS_DEVICE
PaddingMatrixBlockND(Arch::Resource<ArchTag> &resource)
{
int64_t bufferOffset = 0;
for (uint32_t i = 0; i < BUFFER_NUM; i++) {
inputBuffer[i] = resource.ubBuf.template GetBufferByByte<Element>(bufferOffset * sizeof(Element));
bufferOffset += COMPUTE_LENGTH;
}
}
CATLASS_DEVICE
ComputeLayout GetPaddingComputeLayout(layout::RowMajor const &layout)
{
return ComputeLayout(layout.shape(0), layout.shape(1), layout.stride(0));
}
CATLASS_DEVICE
ComputeLayout GetPaddingComputeLayout(layout::ColumnMajor const &layout)
{
return ComputeLayout(layout.shape(1), layout.shape(0), layout.stride(1));
}
CATLASS_DEVICE
ComputeLayoutDst GetPaddingComputeLayout(layout::PaddingRowMajor const &layout)
{
return ComputeLayoutDst(
layout.shape(0) * layout.shape(1),
layout.shape(2) * layout.shape(3),
layout.shape(0),
layout.shape(2)
);
}
CATLASS_DEVICE
ComputeLayoutDst GetPaddingComputeLayout(layout::PaddingColumnMajor const &layout)
{
return ComputeLayoutDst(
layout.shape(2) * layout.shape(3),
layout.shape(0) * layout.shape(1),
layout.shape(2),
layout.shape(0)
);
}
CATLASS_DEVICE
void operator()(AscendC::GlobalTensor<Element> const &dst,
AscendC::GlobalTensor<Element> const &src,
LayoutOut const &layoutDst, LayoutIn const &layoutSrc)
{
auto computeLayoutSrc = GetPaddingComputeLayout(layoutSrc);
auto computeLayoutDst = GetPaddingComputeLayout(layoutDst);
uint32_t aivNum = AscendC::GetBlockNum() * AscendC::GetSubBlockNum();
uint32_t aivId = AscendC::GetBlockIdx();
uint32_t tilesNum = computeLayoutSrc.shape(0);
uint32_t tileLen = computeLayoutSrc.shape(1);
uint32_t roundTileLen = RoundUp<BYTE_PER_BLK / sizeof(Element)>(computeLayoutSrc.shape(1));
uint32_t tilesPerAiv = tilesNum / aivNum;
uint32_t tileRemain = tilesNum % aivNum;
if (aivId < tileRemain) {
tilesPerAiv++;
}
uint32_t mIdx = aivId * tilesPerAiv;
if (aivId >= tileRemain) {
mIdx += tileRemain;
}
MatrixCoord blockOffset(mIdx, 0);
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(eventIds[0]);
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(eventIds[1]);
uint32_t coreLoops{ 0 };
if (roundTileLen > COMPUTE_LENGTH) {
uint32_t loopsPerTile = (tileLen + COMPUTE_LENGTH - 1) / COMPUTE_LENGTH;
coreLoops = tilesPerAiv * loopsPerTile;
for (uint32_t loopIdx = 0; loopIdx < coreLoops; ++loopIdx) {
uint32_t tileIdx = loopIdx / loopsPerTile;
uint32_t inTileLoopIdx = loopIdx % loopsPerTile;
MatrixCoord loopOffset(tileIdx, inTileLoopIdx * COMPUTE_LENGTH);
uint64_t gmSrcOffset = computeLayoutSrc.GetOffset(blockOffset + loopOffset);
uint32_t actualDataNum = COMPUTE_LENGTH;
if (tileLen - inTileLoopIdx * COMPUTE_LENGTH < COMPUTE_LENGTH) {
actualDataNum = tileLen - inTileLoopIdx * COMPUTE_LENGTH;
}
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(eventIds[bufferIndex]);
ComputeLayout srcLayout = computeLayoutSrc.GetTileLayout(MatrixCoord(1, actualDataNum));
ComputeLayout ubLayout = ComputeLayout{1, actualDataNum};
ComputeLayout dstLayout = ComputeLayout(CeilDiv(actualDataNum, computeLayoutDst.shape(2)),
computeLayoutDst.shape(2), computeLayoutDst.stride(3));
copyGm2Ub(inputBuffer[bufferIndex], src[gmSrcOffset], ubLayout, srcLayout);
AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE3>(eventIds[bufferIndex]);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE3>(eventIds[bufferIndex]);
uint64_t gmDstOffset = computeLayoutDst.GetOffset(blockOffset + loopOffset);
copyUb2Gm(dst[gmDstOffset], inputBuffer[bufferIndex], dstLayout, ubLayout);
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(eventIds[bufferIndex]);
bufferIndex = (bufferIndex + 1) % BUFFER_NUM;
}
} else {
uint32_t tilesPerLoop = COMPUTE_LENGTH / roundTileLen;
coreLoops = (tilesPerAiv + tilesPerLoop - 1) / tilesPerLoop;
for (uint32_t loopIdx = 0; loopIdx < coreLoops; ++loopIdx) {
uint32_t tileIdx = loopIdx * tilesPerLoop;
MatrixCoord tileOffset(tileIdx, 0);
uint64_t gmSrcOffset = computeLayoutSrc.GetOffset(blockOffset + tileOffset);
uint32_t actualTilesNum = tilesPerLoop;
if (tilesPerAiv - tileIdx < tilesPerLoop) {
actualTilesNum = tilesPerAiv - tileIdx;
}
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(eventIds[bufferIndex]);
ComputeLayout srcLayout = computeLayoutSrc.GetTileLayout(MatrixCoord(actualTilesNum, tileLen));
ComputeLayout ubLayout = ComputeLayout{actualTilesNum, tileLen, roundTileLen};
ComputeLayout dstLayout = ComputeLayout{CeilDiv(tileLen, computeLayoutDst.shape(2)),
computeLayoutDst.shape(2), computeLayoutDst.stride(3)};
copyGm2Ub(inputBuffer[bufferIndex], src[gmSrcOffset], ubLayout, srcLayout);
AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE3>(eventIds[bufferIndex]);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE3>(eventIds[bufferIndex]);
for (uint32_t i = 0; i < actualTilesNum; ++i) {
uint64_t gmDstOffset = computeLayoutDst.GetOffset(blockOffset + tileOffset + MatrixCoord(i, 0));
uint64_t ubOffset = ubLayout.GetOffset(MatrixCoord(i, 0));
copyUb2Gm(dst[gmDstOffset], inputBuffer[bufferIndex][ubOffset], dstLayout, ubLayout);
}
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(eventIds[bufferIndex]);
bufferIndex = (bufferIndex + 1) % BUFFER_NUM;
}
}
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(eventIds[0]);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(eventIds[1]);
}
CATLASS_DEVICE
~PaddingMatrixBlockND() {}
private:
static const uint32_t BUFFER_NUM = 2;
AscendC::LocalTensor<Element> inputBuffer[BUFFER_NUM];
AscendC::TEventID eventIds[BUFFER_NUM] = {EVENT_ID0, EVENT_ID1};
uint32_t bufferIndex{ 0 };
static_assert(BUFFER_NUM * COMPUTE_LENGTH * sizeof(Element) <= ArchTag::UB_SIZE, "Excedding the UB space!");
static_assert(std::is_same_v<LayoutIn, layout::RowMajor> ||
std::is_same_v<LayoutIn, layout::ColumnMajor>, "Unsported layout for PaddingMatrixBlockNd!");
};
template<
class ArchTag_,
class Element_,
class Layout_,
uint32_t COMPUTE_LENGTH
>
struct PaddingMatrixND {
public:
using ArchTag = ArchTag_;
using Element = Element_;
using Layout = Layout_;
using CopyGm2Ub = Catlass::Epilogue::Tile::CopyGm2Ub<
ArchTag, Gemm::GemmType<Element, Catlass::layout::RowMajor>>;
using CopyUb2Gm = Catlass::Epilogue::Tile::CopyUb2Gm<
ArchTag, Gemm::GemmType<Element, Catlass::layout::RowMajor>>;
using ComputeLayout = Catlass::layout::RowMajor;
constexpr static uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element);
static const PaddingTag paddingTag = PaddingTag::PADDING_ND;
using LayoutIn = Layout_;
using LayoutOut = Layout_;
CATLASS_HOST_DEVICE static
LayoutOut GetWorkspaceLayout(const LayoutIn& layout, uint32_t align)
{
if constexpr (std::is_same_v<LayoutIn, layout::RowMajor>) {
return LayoutOut{layout.shape(0), layout.shape(1), RoundUp(layout.shape(1), align)};
} else {
return LayoutOut{layout.shape(0), layout.shape(1), RoundUp(layout.shape(0), align)};
}
}
static size_t GetWorkspaceSize(uint32_t rows, uint32_t cols, uint32_t align)
{
if constexpr (std::is_same_v<LayoutIn, layout::RowMajor>) {
return static_cast<size_t>(rows) * RoundUp(cols, align) * sizeof(Element);
} else {
return static_cast<size_t>(cols) * RoundUp(rows, align) * sizeof(Element);
}
}
CopyGm2Ub copyGm2Ub;
CopyUb2Gm copyUb2Gm;
CATLASS_DEVICE
PaddingMatrixND(Arch::Resource<ArchTag> &resource)
{
int64_t bufferOffset = 0;
for (uint32_t i = 0; i < BUFFER_NUM; i++) {
inputBuffer[i] = resource.ubBuf.template GetBufferByByte<Element>(bufferOffset * sizeof(Element));
bufferOffset += COMPUTE_LENGTH;
}
}
CATLASS_DEVICE
ComputeLayout GetPaddingComputeLayout(layout::RowMajor const &layout)
{
return ComputeLayout(layout.shape(0), layout.shape(1), layout.stride(0));
}
CATLASS_DEVICE
ComputeLayout GetPaddingComputeLayout(layout::ColumnMajor const &layout)
{
return ComputeLayout(layout.shape(1), layout.shape(0), layout.stride(1));
}
CATLASS_DEVICE
void operator()(AscendC::GlobalTensor<Element> const &dst,
AscendC::GlobalTensor<Element> const &src,
Layout const &layoutDst, Layout const &layoutSrc)
{
ComputeLayout computeLayoutSrc = GetPaddingComputeLayout(layoutSrc);
ComputeLayout computeLayoutDst = GetPaddingComputeLayout(layoutDst);
uint32_t aivNum = AscendC::GetBlockNum() * AscendC::GetSubBlockNum();
uint32_t aivId = AscendC::GetBlockIdx();
uint32_t tilesNum = computeLayoutSrc.shape(0);
uint32_t tileLen = computeLayoutSrc.shape(1);
uint32_t roundTileLen = RoundUp(computeLayoutSrc.shape(1), ELE_NUM_PER_C0);
uint32_t tilesPerAiv = tilesNum / aivNum;
uint32_t tileRemain = tilesNum % aivNum;
if (aivId < tileRemain) {
tilesPerAiv++;
}
uint32_t mIdx = aivId * tilesPerAiv;
if (aivId >= tileRemain) {
mIdx += tileRemain;
}
MatrixCoord blockOffset(mIdx, 0);
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(eventIds[0]);
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(eventIds[1]);
uint32_t coreLoops{ 0 };
if (roundTileLen > COMPUTE_LENGTH) {
uint32_t loopsPerTile = (tileLen + COMPUTE_LENGTH - 1) / COMPUTE_LENGTH;
coreLoops = tilesPerAiv * loopsPerTile;
for (uint32_t loopIdx = 0; loopIdx < coreLoops; ++loopIdx) {
uint32_t tileIdx = loopIdx / loopsPerTile;
uint32_t inTileLoopIdx = loopIdx % loopsPerTile;
MatrixCoord loopOffset(tileIdx, inTileLoopIdx * COMPUTE_LENGTH);
uint64_t gmSrcOffset = computeLayoutSrc.GetOffset(blockOffset + loopOffset);
uint32_t actualDataNum = COMPUTE_LENGTH;
if (tileLen - inTileLoopIdx * COMPUTE_LENGTH < COMPUTE_LENGTH) {
actualDataNum = tileLen - inTileLoopIdx * COMPUTE_LENGTH;
}
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(eventIds[bufferIndex]);
ComputeLayout dstLayout = computeLayoutDst.GetTileLayout(MatrixCoord(1, actualDataNum));
ComputeLayout srcLayout = computeLayoutSrc.GetTileLayout(MatrixCoord(1, actualDataNum));
ComputeLayout &ubLayout = dstLayout;
copyGm2Ub(inputBuffer[bufferIndex], src[gmSrcOffset], ubLayout, srcLayout);
AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE3>(eventIds[bufferIndex]);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE3>(eventIds[bufferIndex]);
uint64_t gmDstOffset = computeLayoutDst.GetOffset(blockOffset + loopOffset);
copyUb2Gm(dst[gmDstOffset], inputBuffer[bufferIndex], dstLayout, ubLayout);
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(eventIds[bufferIndex]);
bufferIndex = (bufferIndex + 1) % BUFFER_NUM;
}
} else {
uint32_t tilesPerLoop = COMPUTE_LENGTH / roundTileLen;
coreLoops = (tilesPerAiv + tilesPerLoop - 1) / tilesPerLoop;
for (uint32_t loopIdx = 0; loopIdx < coreLoops; ++loopIdx) {
uint32_t tileIdx = loopIdx * tilesPerLoop;
MatrixCoord tileOffset(tileIdx, 0);
uint64_t gmSrcOffset = computeLayoutSrc.GetOffset(blockOffset + tileOffset);
uint32_t actualTilesNum = tilesPerLoop;
if (tilesPerAiv - tileIdx < tilesPerLoop) {
actualTilesNum = tilesPerAiv - tileIdx;
}
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(eventIds[bufferIndex]);
ComputeLayout dstLayout = computeLayoutDst.GetTileLayout(MatrixCoord(actualTilesNum, tileLen));
ComputeLayout srcLayout = computeLayoutSrc.GetTileLayout(MatrixCoord(actualTilesNum, tileLen));
ComputeLayout ubLayout = ComputeLayout{actualTilesNum, tileLen, roundTileLen};
copyGm2Ub(inputBuffer[bufferIndex], src[gmSrcOffset], ubLayout, srcLayout);
AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE3>(eventIds[bufferIndex]);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE3>(eventIds[bufferIndex]);
uint64_t gmDstOffset = computeLayoutDst.GetOffset(blockOffset + tileOffset);
copyUb2Gm(dst[gmDstOffset], inputBuffer[bufferIndex], dstLayout, ubLayout);
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(eventIds[bufferIndex]);
bufferIndex = (bufferIndex + 1) % BUFFER_NUM;
}
}
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(eventIds[0]);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(eventIds[1]);
}
CATLASS_DEVICE
~PaddingMatrixND() {}
private:
static const uint32_t BUFFER_NUM = 2;
AscendC::LocalTensor<Element> inputBuffer[BUFFER_NUM];
AscendC::TEventID eventIds[BUFFER_NUM] = {EVENT_ID0, EVENT_ID1};
uint32_t bufferIndex{ 0 };
static_assert(BUFFER_NUM * COMPUTE_LENGTH * sizeof(Element) <= ArchTag::UB_SIZE, "Excedding the UB space!");
static_assert(std::is_same_v<LayoutIn, layout::RowMajor> ||
std::is_same_v<LayoutIn, layout::ColumnMajor>, "Unsported layout for PaddingMatrixND!");
};
template <PaddingTag, class ArchTag, class Element, class LayoutIn, uint32_t COMPUTE_LENGTH = 0>
struct PaddingBuilder {
static_assert(DEPENDENT_FALSE<ArchTag>, "Padding is not implemented for this layout");
};
template <class ArchTag, class Element, class LayoutIn, uint32_t COMPUTE_LENGTH>
struct PaddingBuilder<PaddingTag::NO_PADDING, ArchTag, Element, LayoutIn, COMPUTE_LENGTH> {
using LayoutAfterPadding = LayoutIn;
using Padding = void;
};
template <class ArchTag, class Element, class LayoutIn, uint32_t COMPUTE_LENGTH>
struct PaddingBuilder<PaddingTag::PADDING_ND, ArchTag, Element, LayoutIn, COMPUTE_LENGTH> {
using LayoutAfterPadding = LayoutIn;
using Padding = std::conditional_t<
(COMPUTE_LENGTH > 0),
Catlass::Gemm::Kernel::PaddingMatrixND<ArchTag, Element, LayoutIn, COMPUTE_LENGTH>,
Catlass::Gemm::Kernel::PaddingMatrixND<ArchTag, Element, LayoutIn, 96 * 1024 / sizeof(Element)>
>;
};
template <class ArchTag, class Element, class LayoutIn, uint32_t COMPUTE_LENGTH>
struct PaddingBuilder<PaddingTag::PADDING_BLOCK_ND, ArchTag, Element, LayoutIn, COMPUTE_LENGTH> {
using LayoutAfterPadding = std::conditional_t<std::is_same_v<LayoutIn, layout::RowMajor>,
layout::PaddingRowMajor, layout::PaddingColumnMajor>;
using Padding = std::conditional_t<
(COMPUTE_LENGTH > 0),
Catlass::Gemm::Kernel::PaddingMatrixBlockND<ArchTag, Element, LayoutIn, LayoutAfterPadding, COMPUTE_LENGTH>,
Catlass::Gemm::Kernel::PaddingMatrixBlockND<
ArchTag, Element, LayoutIn, LayoutAfterPadding, 96 * 1024 / sizeof(Element)>
>;
};
template <class ArchTag, class Element, class LayoutIn, uint32_t COMPUTE_LENGTH>
struct PaddingBuilder<PaddingTag::PADDING_NZ, ArchTag, Element, LayoutIn, COMPUTE_LENGTH> {
using LayoutAfterPadding = std::conditional_t<std::is_same_v<LayoutIn, layout::RowMajor>,
layout::zN, layout::nZ>;
using Padding = std::conditional_t<
(COMPUTE_LENGTH > 0),
Catlass::Gemm::Kernel::PaddingMatrixNZ<ArchTag, Element, LayoutIn, LayoutAfterPadding, COMPUTE_LENGTH>,
Catlass::Gemm::Kernel::PaddingMatrixNZ<
ArchTag, Element, LayoutIn, LayoutAfterPadding, 48 * 1024 / sizeof(Element)>
>;
};
template<
PaddingTag paddingTag_,
class ArchTag_,
class ElementIn_,
class ElementOut_,
class Layout_
>
struct RemovePaddingNDAndCast {
public:
using ArchTag = ArchTag_;
using ElementIn = ElementIn_;
using ElementOut = ElementOut_;
using Layout = Layout_;
using CopyGm2Ub = Catlass::Epilogue::Tile::CopyGm2Ub<
ArchTag, Gemm::GemmType<ElementIn, Catlass::layout::RowMajor>>;
using CopyUb2Gm = Catlass::Epilogue::Tile::CopyUb2Gm<
ArchTag, Gemm::GemmType<ElementOut, Catlass::layout::RowMajor>>;
using ComputeLayout = Catlass::layout::RowMajor;
constexpr static uint32_t ELE_NUM_PER_C0 = std::max(BYTE_PER_C0 / sizeof(ElementIn), BYTE_PER_C0 / sizeof(ElementOut));
static const PaddingTag paddingTag = paddingTag_;
using LayoutIn = Layout_;
using LayoutOut = Layout_;
CATLASS_HOST_DEVICE static
LayoutOut GetWorkspaceLayout(const LayoutIn& layout, uint32_t align)
{
if constexpr (std::is_same_v<LayoutIn, layout::RowMajor>) {
return LayoutOut{layout.shape(0), layout.shape(1), RoundUp(layout.shape(1), align)};
} else {
return LayoutOut{layout.shape(0), layout.shape(1), RoundUp(layout.shape(0), align)};
}
}
static size_t GetWorkspaceSize(uint32_t rows, uint32_t cols, uint32_t align = 1)
{
if constexpr (std::is_same_v<LayoutIn, layout::RowMajor>) {
return static_cast<size_t>(rows) * RoundUp(cols, align) * sizeof(ElementIn);
} else {
return static_cast<size_t>(cols) * RoundUp(rows, align) * sizeof(ElementIn);
}
}
CopyGm2Ub copyGm2Ub;
CopyUb2Gm copyUb2Gm;
CATLASS_DEVICE
RemovePaddingNDAndCast(Arch::Resource<ArchTag> &resource)
{
int64_t bufferOffset = 0;
constexpr uint32_t inputBufSize = COMPUTE_LENGTH * sizeof(ElementIn);
for (uint32_t i = 0; i < BUFFER_NUM; i++) {
inputBuffer[i] = resource.ubBuf.template GetBufferByByte<ElementIn>(bufferOffset + inputBufSize * i);
}
if constexpr (!std::is_same_v<ElementIn, ElementOut>) {
constexpr uint32_t outputBufSize = COMPUTE_LENGTH * sizeof(ElementOut);
bufferOffset = inputBufSize * BUFFER_NUM;
for (uint32_t i = 0; i < BUFFER_NUM; i++) {
outputBuffer[i] = resource.ubBuf.template GetBufferByByte<ElementOut>(bufferOffset + outputBufSize * i);
}
}
}
CATLASS_DEVICE
ComputeLayout GetPaddingComputeLayout(layout::RowMajor const &layout)
{
return ComputeLayout(layout.shape(0), layout.shape(1), layout.stride(0));
}
CATLASS_DEVICE
ComputeLayout GetPaddingComputeLayout(layout::ColumnMajor const &layout)
{
return ComputeLayout(layout.shape(1), layout.shape(0), layout.stride(1));
}
CATLASS_DEVICE
void operator()(AscendC::GlobalTensor<ElementOut> const &dst,
AscendC::GlobalTensor<ElementIn> const &src,
Layout const &layoutDst, Layout const &layoutSrc, bool useSingleCore = false)
{
ComputeLayout computeLayoutSrc = GetPaddingComputeLayout(layoutSrc);
ComputeLayout computeLayoutDst = GetPaddingComputeLayout(layoutDst);
uint32_t aivNum, aivId;
if (useSingleCore) {
aivNum = AscendC::GetSubBlockNum();
aivId = AscendC::GetSubBlockIdx();
} else {
aivNum = AscendC::GetBlockNum() * AscendC::GetSubBlockNum();
aivId = AscendC::GetBlockIdx();
}
uint32_t tilesNum = computeLayoutSrc.shape(0);
uint32_t tileLen = computeLayoutSrc.shape(1);
uint32_t roundTileLen = RoundUp(computeLayoutSrc.shape(1), ELE_NUM_PER_C0);
uint32_t tilesPerAiv = tilesNum / aivNum;
uint32_t tileRemain = tilesNum % aivNum;
if (aivId < tileRemain) {
tilesPerAiv++;
}
uint32_t mIdx = aivId * tilesPerAiv;
if (aivId >= tileRemain) {
mIdx += tileRemain;
}
MatrixCoord blockOffset(mIdx, 0);
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(eventIds[0]);
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(eventIds[1]);
uint32_t coreLoops{ 0 };
if (roundTileLen > COMPUTE_LENGTH) {
uint32_t loopsPerTile = (tileLen + COMPUTE_LENGTH - 1) / COMPUTE_LENGTH;
coreLoops = tilesPerAiv * loopsPerTile;
for (uint32_t loopIdx = 0; loopIdx < coreLoops; ++loopIdx) {
uint32_t tileIdx = loopIdx / loopsPerTile;
uint32_t inTileLoopIdx = loopIdx % loopsPerTile;
MatrixCoord loopOffset(tileIdx, inTileLoopIdx * COMPUTE_LENGTH);
uint64_t gmSrcOffset = computeLayoutSrc.GetOffset(blockOffset + loopOffset);
uint32_t actualDataNum = COMPUTE_LENGTH;
if (tileLen - inTileLoopIdx * COMPUTE_LENGTH < COMPUTE_LENGTH) {
actualDataNum = tileLen - inTileLoopIdx * COMPUTE_LENGTH;
}
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(eventIds[bufferIndex]);
ComputeLayout dstLayout = computeLayoutDst.GetTileLayout(MatrixCoord(1, actualDataNum));
ComputeLayout srcLayout = computeLayoutSrc.GetTileLayout(MatrixCoord(1, actualDataNum));
ComputeLayout &ubLayout = dstLayout;
copyGm2Ub(inputBuffer[bufferIndex], src[gmSrcOffset], ubLayout, srcLayout);
if constexpr (!std::is_same_v<ElementIn, ElementOut>) {
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventIds[bufferIndex]);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventIds[bufferIndex]);
if constexpr (std::is_same_v<ElementOut, half>) {
AscendC::Cast(outputBuffer[bufferIndex], inputBuffer[bufferIndex], AscendC::RoundMode::CAST_NONE, actualDataNum);
} else {
AscendC::Cast(outputBuffer[bufferIndex], inputBuffer[bufferIndex], AscendC::RoundMode::CAST_RINT, actualDataNum);
}
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(eventIds[bufferIndex]);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(eventIds[bufferIndex]);
uint64_t gmDstOffset = computeLayoutDst.GetOffset(blockOffset + loopOffset);
copyUb2Gm(dst[gmDstOffset], outputBuffer[bufferIndex], dstLayout, ubLayout);
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(eventIds[bufferIndex]);
} else {
AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE3>(eventIds[bufferIndex]);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE3>(eventIds[bufferIndex]);
uint64_t gmDstOffset = computeLayoutDst.GetOffset(blockOffset + loopOffset);
copyUb2Gm(dst[gmDstOffset], inputBuffer[bufferIndex], dstLayout, ubLayout);
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(eventIds[bufferIndex]);
}
bufferIndex = (bufferIndex + 1) % BUFFER_NUM;
}
} else {
uint32_t tilesPerLoop = COMPUTE_LENGTH / roundTileLen;
coreLoops = (tilesPerAiv + tilesPerLoop - 1) / tilesPerLoop;
for (uint32_t loopIdx = 0; loopIdx < coreLoops; ++loopIdx) {
uint32_t tileIdx = loopIdx * tilesPerLoop;
MatrixCoord tileOffset(tileIdx, 0);
uint64_t gmSrcOffset = computeLayoutSrc.GetOffset(blockOffset + tileOffset);
uint32_t actualTilesNum = tilesPerLoop;
if (tilesPerAiv - tileIdx < tilesPerLoop) {
actualTilesNum = tilesPerAiv - tileIdx;
}
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(eventIds[bufferIndex]);
ComputeLayout dstLayout = computeLayoutDst.GetTileLayout(MatrixCoord(actualTilesNum, tileLen));
ComputeLayout srcLayout = computeLayoutSrc.GetTileLayout(MatrixCoord(actualTilesNum, tileLen));
ComputeLayout ubLayout = ComputeLayout{actualTilesNum, tileLen, roundTileLen};
copyGm2Ub(inputBuffer[bufferIndex], src[gmSrcOffset], ubLayout, srcLayout);
if constexpr (!std::is_same_v<ElementIn, ElementOut>) {
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventIds[bufferIndex]);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventIds[bufferIndex]);
uint32_t castDataNum = actualTilesNum * roundTileLen;
if constexpr (std::is_same_v<ElementOut, half>) {
AscendC::Cast(outputBuffer[bufferIndex], inputBuffer[bufferIndex], AscendC::RoundMode::CAST_NONE, castDataNum);
} else {
AscendC::Cast(outputBuffer[bufferIndex], inputBuffer[bufferIndex], AscendC::RoundMode::CAST_RINT, castDataNum);
}
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(eventIds[bufferIndex]);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(eventIds[bufferIndex]);
uint64_t gmDstOffset = computeLayoutDst.GetOffset(blockOffset + tileOffset);
copyUb2Gm(dst[gmDstOffset], outputBuffer[bufferIndex], dstLayout, ubLayout);
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(eventIds[bufferIndex]);
} else {
AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE3>(eventIds[bufferIndex]);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE3>(eventIds[bufferIndex]);
uint64_t gmDstOffset = computeLayoutDst.GetOffset(blockOffset + tileOffset);
copyUb2Gm(dst[gmDstOffset], inputBuffer[bufferIndex], dstLayout, ubLayout);
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(eventIds[bufferIndex]);
}
bufferIndex = (bufferIndex + 1) % BUFFER_NUM;
}
}
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(eventIds[0]);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(eventIds[1]);
}
CATLASS_DEVICE
~RemovePaddingNDAndCast() {}
private:
static const uint32_t BUFFER_NUM = 2;
AscendC::LocalTensor<ElementIn> inputBuffer[BUFFER_NUM];
AscendC::LocalTensor<ElementOut> outputBuffer[BUFFER_NUM];
AscendC::TEventID eventIds[BUFFER_NUM] = {EVENT_ID0, EVENT_ID1};
uint32_t bufferIndex{ 0 };
static constexpr uint32_t COMPUTE_LENGTH =
std::is_same_v<ElementIn, ElementOut> ?
ArchTag::UB_SIZE / sizeof(ElementIn) / BUFFER_NUM :
ArchTag::UB_SIZE / (sizeof(ElementIn) + sizeof(ElementOut)) / BUFFER_NUM;
static_assert(
std::is_same_v<ElementIn, ElementOut> ?
BUFFER_NUM * COMPUTE_LENGTH * sizeof(ElementIn) <= ArchTag::UB_SIZE :
BUFFER_NUM * COMPUTE_LENGTH * (sizeof(ElementIn) + sizeof(ElementOut)) <= ArchTag::UB_SIZE,
"Excedding the UB space!");
static_assert(std::is_same_v<LayoutIn, layout::RowMajor> ||
std::is_same_v<LayoutIn, layout::ColumnMajor>, "Unsported layout for RemovePaddingNDAndCast!");
};
template <
class BlockMmad_,
class BlockEpilogue_,
class BlockScheduler_
>
class PaddingMatmul {
public:
using BlockMmad = BlockMmad_;
using ArchTag = typename BlockMmad::ArchTag;
using ElementA = typename BlockMmad::ElementA;
using ElementB = typename BlockMmad::ElementB;
using LayoutA = typename BlockMmad::LayoutA;
using LayoutB = typename BlockMmad::LayoutB;
static const uint32_t COMPUTE_LENGTH_A = 96 * 1024 / sizeof(ElementA);
using PaddingA = PaddingMatrixND<ArchTag, ElementA, LayoutA, COMPUTE_LENGTH_A>;
static const uint32_t COMPUTE_LENGTH_B = 96 * 1024 / sizeof(ElementB);
using PaddingB = PaddingMatrixND<ArchTag, ElementB, LayoutB, COMPUTE_LENGTH_B>;
using L1TileShape = typename BlockMmad::L1TileShape;
using ElementC = typename BlockMmad::ElementC;
using LayoutC = typename BlockMmad::LayoutC;
using BlockScheduler = BlockScheduler_;
struct Params {
GemmCoord problemShape;
GM_ADDR ptrA;
LayoutA layoutA;
GM_ADDR ptrB;
LayoutB layoutB;
GM_ADDR ptrC;
LayoutC layoutC;
GM_ADDR ptrWA;
LayoutA layoutWA;
GM_ADDR ptrWB;
LayoutB layoutWB;
CATLASS_HOST_DEVICE
Params() {}
CATLASS_HOST_DEVICE
Params(GemmCoord const &problemShape_,
GM_ADDR ptrA_, LayoutA layoutA_, GM_ADDR ptrB_, LayoutB layoutB_, GM_ADDR ptrC_, LayoutC layoutC_,
GM_ADDR ptrWA_, LayoutA layoutWA_, GM_ADDR ptrWB_, LayoutB layoutWB_)
: problemShape(problemShape_), ptrA(ptrA_), layoutA(layoutA_), ptrB(ptrB_), layoutB(layoutB_),
ptrC(ptrC_), layoutC(layoutC_), ptrWA(ptrWA_), layoutWA(layoutWA_), ptrWB(ptrWB_), layoutWB(layoutWB_) {}
};
struct Arguments {
GemmCoord problemShape;
uint32_t align;
size_t elementSize;
GM_ADDR ptrA;
GM_ADDR ptrB;
GM_ADDR ptrC;
};
static bool CanImplement(const Arguments &args)
{
return true;
}
static layout::RowMajor GetWorkspaceLayout(layout::RowMajor layout, uint32_t align)
{
if (align == 0) {
return 0;
}
return layout::RowMajor(layout.shape(0), layout.shape(1),
(layout.shape(1) + align - 1) / align * align);
}
static layout::ColumnMajor GetWorkspaceLayout(layout::ColumnMajor layout, uint32_t align)
{
return layout::ColumnMajor(layout.shape(0), layout.shape(1),
(layout.shape(0) + align - 1) / align * align);
}
static size_t GetWorkspaceLen(layout::RowMajor layout)
{
return layout.shape(0) * layout.stride(0);
}
static size_t GetWorkspaceLen(layout::ColumnMajor layout)
{
return layout.shape(1) * layout.stride(1);
}
static size_t GetWorkspaceSize(const Arguments &args)
{
GemmCoord problemShape = args.problemShape;
LayoutA layoutA{problemShape.m(), problemShape.k()};
LayoutB layoutB{problemShape.k(), problemShape.n()};
size_t sizeWA = GetWorkspaceLen(GetWorkspaceLayout(layoutA, args.align)) * args.elementSize;
size_t sizeWB = GetWorkspaceLen(GetWorkspaceLayout(layoutB, args.align)) * args.elementSize;
return sizeWA + sizeWB;
}
static Params ToUnderlyingArguments(const Arguments &args, uint8_t *workspace)
{
GemmCoord problemShape = args.problemShape;
uint32_t m = problemShape.m();
uint32_t n = problemShape.n();
uint32_t k = problemShape.k();
LayoutA layoutA{m, k};
LayoutB layoutB{k, n};
LayoutC layoutC{m, n};
size_t sizeWA = GetWorkspaceLen(GetWorkspaceLayout(layoutA, args.align)) * args.elementSize;
uint8_t *workspaceWB = workspace + sizeWA;
Params params{problemShape,
args.ptrA,
layoutA,
args.ptrB,
layoutB,
args.ptrC,
layoutC,
workspace,
GetWorkspaceLayout(layoutA, args.align),
workspaceWB,
GetWorkspaceLayout(layoutB, args.align)};
return params;
}
CATLASS_DEVICE
PaddingMatmul() {}
template <int32_t CORE_TYPE = g_coreType>
CATLASS_DEVICE
void operator()(Params const ¶ms);
template <>
CATLASS_DEVICE
void operator()<AscendC::AIV>(Params const ¶ms)
{
AscendC::GlobalTensor<ElementA> gmA;
AscendC::GlobalTensor<ElementA> gmWA;
gmA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA *>(params.ptrA));
gmWA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA *>(params.ptrWA));
PaddingA paddingA(resource);
paddingA(gmWA, gmA, params.layoutWA, params.layoutA);
AscendC::GlobalTensor<ElementB> gmB;
AscendC::GlobalTensor<ElementB> gmWB;
gmB.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(params.ptrB));
gmWB.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(params.ptrWB));
PaddingB paddingB(resource);
paddingB(gmWB, gmB, params.layoutWB, params.layoutB);
Catlass::Arch::CrossCoreBarrier<0x0, PIPE_MTE3>();
Catlass::Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(flagAivFinishPadding);
AscendC::PipeBarrier<PIPE_ALL>();
}
template <>
CATLASS_DEVICE
void operator()<AscendC::AIC>(Params const ¶ms)
{
Catlass::Arch::CrossCoreWaitFlag(flagAivFinishPadding);
BlockScheduler matmulBlockScheduler(params.problemShape, MakeCoord(L1TileShape::M, L1TileShape::N));
uint32_t coreLoops = matmulBlockScheduler.GetCoreLoops();
AscendC::GlobalTensor<ElementA> gmA;
gmA.SetGlobalBuffer((__gm__ ElementA *)params.ptrWA);
AscendC::GlobalTensor<ElementB> gmB;
gmB.SetGlobalBuffer((__gm__ ElementB *)params.ptrWB);
AscendC::GlobalTensor<ElementC> gmC;
gmC.SetGlobalBuffer((__gm__ ElementC *)params.ptrC);
BlockMmad blockMmad(resource);
for (uint32_t loopIdx = AscendC::GetBlockIdx(); loopIdx < coreLoops; loopIdx += AscendC::GetBlockNum()) {
GemmCoord blockIdxCoord = matmulBlockScheduler.GetBlockCoord(loopIdx);
GemmCoord actualBlockShape = matmulBlockScheduler.GetActualBlockShape(blockIdxCoord);
MatrixCoord offsetA{blockIdxCoord.m() * L1TileShape::M, blockIdxCoord.k() * L1TileShape::K};
MatrixCoord offsetB{blockIdxCoord.k() * L1TileShape::K, blockIdxCoord.n() * L1TileShape::N};
MatrixCoord offsetC{blockIdxCoord.m() * L1TileShape::M, blockIdxCoord.n() * L1TileShape::N};
int64_t gmOffsetA = params.layoutWA.GetOffset(offsetA);
int64_t gmOffsetB = params.layoutWB.GetOffset(offsetB);
int64_t gmOffsetC = params.layoutC.GetOffset(offsetC);
blockMmad(
gmA[gmOffsetA], params.layoutWA,
gmB[gmOffsetB], params.layoutWB,
gmC[gmOffsetC], params.layoutC,
actualBlockShape);
}
AscendC::PipeBarrier<PIPE_ALL>();
}
private:
static constexpr Arch::FlagID FLAG_AIV_FINISH_STORE = 0;
Arch::CrossCoreFlag flagAivFinishPadding{FLAG_AIV_FINISH_STORE};
Arch::Resource<ArchTag> resource;
};
}
#endif