* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file matmul.hpp
* \brief
*/
#ifndef CATLASS_GEMM_KERNEL_MATMUL_ALLTOALL_HPP
#define CATLASS_GEMM_KERNEL_MATMUL_ALLTOALL_HPP
#include "../../3rd/template_linear_algebra/op_kernel/template_linear_algebra/catlass.hpp"
#include "../../3rd/template_linear_algebra/op_kernel/template_linear_algebra/coord.hpp"
#include "../../3rd/template_linear_algebra/op_kernel/template_linear_algebra/gemm_coord.hpp"
#include "../../3rd/template_linear_algebra/op_kernel/template_linear_algebra/matrix_coord.hpp"
#include "../../3rd/template_linear_algebra/op_kernel/template_linear_algebra/arch/resource.hpp"
#include "../../3rd/template_linear_algebra/op_kernel/template_linear_algebra/arch/cross_core_sync.hpp"
#include "../../3rd/template_linear_algebra/op_kernel/template_linear_algebra/epilogue/tile/copy_gm_to_ub.hpp"
#include "../../3rd/template_linear_algebra/op_kernel/template_linear_algebra/epilogue/tile/copy_ub_to_gm.hpp"
#include "../../3rd/template_linear_algebra/op_kernel/template_linear_algebra/gemm/kernel/padding_matmul.hpp"
namespace Catlass::Gemm::Kernel {
template <class PrologueA, class PrologueB, class BlockMmad_, class BlockEpilogue_, class BlockScheduler_, bool HasBias>
class MatmulAlltoAllKernel : public CommBase {
public:
using BlockMmad = BlockMmad_;
using ArchTag = typename BlockMmad::ArchTag;
using ElementA = typename BlockMmad::ElementA;
using ElementB = typename BlockMmad::ElementB;
using LayoutWA = typename BlockMmad::LayoutA;
using LayoutWB = typename BlockMmad::LayoutB;
template <bool condition, class mmad>
struct BiasTypeHelper {
using type = typename mmad::ElementBias;
};
template <class mmad>
struct BiasTypeHelper<false, mmad> {
using type = float;
};
template <class T>
struct LayoutHelper {
using type = typename T::LayoutIn;
};
template <>
struct LayoutHelper<void> {
using type = void;
};
using ElementBias = typename BiasTypeHelper<HasBias, BlockMmad>::type;
using LayoutA = std::conditional_t<std::is_void_v<PrologueA>, LayoutWA, typename LayoutHelper<PrologueA>::type>;
using LayoutB = std::conditional_t<std::is_void_v<PrologueB>, LayoutWB, typename LayoutHelper<PrologueB>::type>;
using L1TileShape = typename BlockMmad::L1TileShape;
using ElementC = typename BlockMmad::ElementC;
using LayoutC = typename BlockMmad::LayoutC;
using BlockScheduler = BlockScheduler_;
struct Params {
GemmCoord problemShape;
GM_ADDR ptrA;
LayoutA layoutA;
GM_ADDR ptrB;
LayoutB layoutB;
GM_ADDR ptrBias;
GM_ADDR ptrC;
LayoutC layoutC;
int32_t pValue;
int32_t rankSize;
int32_t pipeDepth;
bool transB;
CATLASS_HOST_DEVICE
Params()
{
}
CATLASS_HOST_DEVICE
Params(GemmCoord const &problemShape_, GM_ADDR ptrA_, LayoutA layoutA_, GM_ADDR ptrB_, LayoutB layoutB_,
GM_ADDR ptrBias_, GM_ADDR ptrC_, LayoutC layoutC_, int32_t pValue_, int32_t rankSize_,
int32_t pipeDepth_, bool transB_)
: problemShape(problemShape_), ptrA(ptrA_), layoutA(layoutA_), ptrB(ptrB_), layoutB(layoutB_),
ptrBias(ptrBias_), ptrC(ptrC_), layoutC(layoutC_), pValue(pValue_), rankSize(rankSize_),
pipeDepth(pipeDepth_), transB(transB_)
{
}
};
CATLASS_DEVICE
MatmulAlltoAllKernel()
{
}
template <int32_t CORE_TYPE = g_coreType>
CATLASS_DEVICE void operator()(Params const ¶ms);
template <>
CATLASS_DEVICE void operator()<AscendC::AIC>(Params const ¶ms)
{
if constexpr (!std::is_void_v<PrologueA> || !std::is_void_v<PrologueB>) {
Catlass::Arch::CrossCoreWaitFlag(flagAivFinishPadding);
}
AscendC::GlobalTensor<ElementA> gmA;
gmA.SetGlobalBuffer((__gm__ ElementA *)params.ptrA);
AscendC::GlobalTensor<ElementB> gmB;
gmB.SetGlobalBuffer((__gm__ ElementB *)params.ptrB);
AscendC::GlobalTensor<ElementC> gmC;
gmC.SetGlobalBuffer((__gm__ ElementC *)params.ptrC);
AscendC::GlobalTensor<ElementBias> gmBias;
if constexpr (HasBias) {
gmBias.SetGlobalBuffer((__gm__ ElementBias *)params.ptrBias);
}
BlockMmad blockMmad(resource);
int32_t coreIdx = AscendC::GetBlockIdx();
int32_t coreNum = AscendC::GetBlockNum();
int32_t mLoops = (params.problemShape.m() + L1TileShape::M - 1) / L1TileShape::M;
int32_t commCount = (mLoops + params.pValue - 1) / params.pValue;
for (int32_t commIdx = 0; commIdx < commCount; commIdx++) {
uint64_t flagIdx = commIdx % params.pipeDepth;
int32_t tailN = (params.problemShape.n() / params.rankSize) % L1TileShape::N;
uint32_t nLoops = params.problemShape.n() / params.rankSize / L1TileShape::N;
if (tailN) {
nLoops += 1;
}
nLoops *= params.rankSize;
int32_t coreLoops = mLoops * nLoops;
uint32_t actualPValue = params.pValue;
if (commIdx == commCount - 1) {
actualPValue = mLoops - commIdx * params.pValue;
}
if (commIdx >= params.pipeDepth) {
WaitEvent(flagIdx);
}
int64_t gmAPingpongSize = L1TileShape::M * params.pValue * params.problemShape.n();
int32_t actualLoopNum = actualPValue * nLoops;
int32_t nLoopPerRank = nLoops / params.rankSize;
GemmCoord calSwizzleShape{actualPValue, nLoops, 1};
BlockScheduler matmulBlockScheduler(calSwizzleShape,
MakeCoord(static_cast<uint32_t>(1), static_cast<uint32_t>(1)));
for (int32_t loopOffset = 0; loopOffset < actualLoopNum; loopOffset++) {
int32_t loopIdx = commIdx * params.pValue * nLoops + loopOffset;
if (loopIdx % coreNum != coreIdx) {
continue;
}
GemmCoord blockIdxCoord = matmulBlockScheduler.GetBlockCoord(loopOffset);
uint32_t mActual = (commIdx == commCount - 1) && (blockIdxCoord.m() == (mLoops - 1) % params.pValue) &&
(params.problemShape.m() % L1TileShape::M != 0) ?
(params.problemShape.m() - commIdx * params.pValue * L1TileShape::M -
blockIdxCoord.m() * L1TileShape::M) :
L1TileShape::M;
uint32_t nActual = (blockIdxCoord.n() % nLoopPerRank == (nLoopPerRank - 1)) ?
(params.problemShape.n() / params.rankSize -
blockIdxCoord.n() % nLoopPerRank * L1TileShape::N) :
L1TileShape::N;
GemmCoord actualBlockShape{mActual, nActual, params.problemShape.k()};
MatrixCoord offsetA{blockIdxCoord.m() * L1TileShape::M, blockIdxCoord.k() * L1TileShape::K};
MatrixCoord offsetC{blockIdxCoord.m() * L1TileShape::M, blockIdxCoord.n() * L1TileShape::N};
int64_t gmOffsetA = params.layoutA.GetOffset(offsetA) +
commIdx * L1TileShape::M * params.pValue * params.problemShape.k();
int64_t gmOffsetB = ((blockIdxCoord.n() / nLoopPerRank) * (params.problemShape.n() / params.rankSize) +
(blockIdxCoord.n() % nLoopPerRank) * L1TileShape::N) *
(params.transB ? params.problemShape.k() : 1);
int64_t rankOffset = gmAPingpongSize / params.rankSize;
int64_t rankIdx = blockIdxCoord.n() / nLoopPerRank;
int64_t rankOffsetInRank = blockIdxCoord.n() % nLoopPerRank;
int64_t gmOffsetC = flagIdx * L1TileShape::M * params.pValue * params.problemShape.n() +
rankIdx * rankOffset +
blockIdxCoord.m() * L1TileShape::M * (params.problemShape.n() / params.rankSize) +
blockIdxCoord.n() % nLoopPerRank * L1TileShape::N;
if constexpr (HasBias) {
int64_t gmOffsetBias =
(blockIdxCoord.n() / nLoopPerRank) * (params.problemShape.n() / params.rankSize) +
(blockIdxCoord.n() % nLoopPerRank) * L1TileShape::N;
blockMmad(gmA[gmOffsetA], params.layoutA, gmB[gmOffsetB], params.layoutB, gmC[gmOffsetC],
params.layoutC, gmBias[gmOffsetBias], actualBlockShape);
} else {
bool isFirstBlock = (loopOffset < coreNum);
bool hasNextBlock = false;
GemmCoord nextBlockIdCoord;
GemmCoord nextActualBlockShape;
if (loopOffset + coreNum < actualLoopNum) {
hasNextBlock = true;
nextBlockIdCoord = matmulBlockScheduler.GetBlockCoord(loopOffset + coreNum);
uint32_t mActualNext = (commIdx == commCount - 1) &&
(nextBlockIdCoord.m() == (mLoops - 1) % params.pValue) &&
(params.problemShape.m() % L1TileShape::M != 0) ?
(params.problemShape.m() - commIdx * params.pValue * L1TileShape::M -
nextBlockIdCoord.m() * L1TileShape::M) :
L1TileShape::M;
uint32_t nActualNext = (nextBlockIdCoord.n() % nLoopPerRank == (nLoopPerRank - 1)) ?
(params.problemShape.n() / params.rankSize -
nextBlockIdCoord.n() % nLoopPerRank * L1TileShape::N) :
L1TileShape::N;
nextActualBlockShape = GemmCoord{mActualNext, nActualNext, params.problemShape.k()};
}
MatrixCoord offsetNextA{nextBlockIdCoord.m() * L1TileShape::M,
nextBlockIdCoord.k() * L1TileShape::K};
MatrixCoord offsetNextB{nextBlockIdCoord.k() * L1TileShape::K,
nextBlockIdCoord.n() * L1TileShape::N};
int64_t gmOffsetNextA = params.layoutA.GetOffset(offsetNextA) +
commIdx * L1TileShape::M * params.pValue * params.problemShape.k();
int64_t gmOffsetNextB =
((nextBlockIdCoord.n() / nLoopPerRank) * (params.problemShape.n() / params.rankSize) +
(nextBlockIdCoord.n() % nLoopPerRank) * L1TileShape::N) *
(params.transB ? params.problemShape.k() : 1);
blockMmad(gmA[gmOffsetA], params.layoutA, gmB[gmOffsetB], params.layoutB, gmC[gmOffsetC],
params.layoutC, gmA[gmOffsetNextA], gmB[gmOffsetNextB], actualBlockShape,
nextActualBlockShape, isFirstBlock, hasNextBlock);
}
}
FFTSCrossCoreSync<PIPE_FIX, 2>(flagIdx);
}
}
private:
static constexpr Arch::FlagID FLAG_AIV_FINISH_STORE = 0;
Arch::CrossCoreFlag flagAivFinishPadding{FLAG_AIV_FINISH_STORE};
Arch::Resource<ArchTag> resource;
};
}
#endif