Copyright (c) 2025 Huawei Technologies Co., Ltd.
This program is free software, you can redistribute it and/or modify it under the terms and conditions of
CANN Open Software License Agreement Version 2.0 (the "License").
Please refer to the License for details. You may not use this file except in compliance with the License.
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef PTO_COMM_TREDUCE_HPP
#define PTO_COMM_TREDUCE_HPP
#include <type_traits>
#include "pto/common/debug.h"
#include "pto/common/type.hpp"
#include "pto/common/constants.hpp"
#include "pto/common/pto_instr.hpp"
#include "pto/comm/comm_types.hpp"
namespace pto {
namespace comm {
namespace detail {
/**
 * Fold a received tile into the accumulator in place, i.e. acc = op(acc, recv).
 *
 * Sum is issued as TADD, Max as TMAX and Min as TMIN. Any other ReduceOp value reaches
 * PTO_ASSERT(false, ...) and issues no tile instruction.
 */
template <typename TileData>
PTO_INTERNAL void ReduceTiles(TileData &acc, TileData &recv, ReduceOp op)
{
    if (op == ReduceOp::Sum) {
        TADD(acc, acc, recv);
    } else if (op == ReduceOp::Max) {
        TMAX(acc, acc, recv);
    } else if (op == ReduceOp::Min) {
        TMIN(acc, acc, recv);
    } else {
        PTO_ASSERT(false, "TREDUCE: unknown ReduceOp");
    }
}

/**
 * Translate the ordinal of a non-root participant (0 = first remote rank, root skipped)
 * into its index inside the parallel group.
 *
 * Ordinals below rootIdx are returned unchanged; from rootIdx onward they are shifted up
 * by one so the root's own slot is never produced. Example: with rootIdx = 1 the ordinals
 * 0, 1, 2 map to ranks 0, 2, 3.
 */
PTO_INTERNAL int GetRemoteRank(int rootIdx, int remoteOrdinal)
{
    const int rootSkip = (remoteOrdinal >= rootIdx) ? 1 : 0;
    return remoteOrdinal + rootSkip;
}
} // namespace detail
template <typename ParallelGroupType, typename GlobalDstData, typename TileData>
/**
 * Single-tile TREDUCE: the whole global tensor fits in one tile, so each rank's tensor is
 * loaded directly from the ParallelGroup without building chunk views.
 *
 * Flow: TLOAD root into accTileData -> for every non-root rank r, TLOAD r into recvTileData
 * and fold it into accTileData with detail::ReduceTiles -> TSTORE accTileData to dstGlobalData.
 *
 * @param parallelGroup  per-rank source tensors; parallelGroup[i] is rank i's data.
 * @param dstGlobalData  destination receiving the reduced result.
 * @param accTileData    accumulator tile holding the running reduction.
 * @param recvTileData   staging tile for each non-root rank's data.
 * @param op             reduction operator (Sum / Max / Min).
 * @param rootIdx        group index whose data seeds the accumulator.
 * @param nranks         number of ranks; 1 means a plain copy of the root tensor.
 */
PTO_INTERNAL void TreduceSimple(ParallelGroupType &parallelGroup, GlobalDstData &dstGlobalData, TileData &accTileData,
                                TileData &recvTileData, ReduceOp op, int rootIdx, int nranks)
{
    // The root's data seeds the accumulator in both the single-rank and multi-rank cases.
    TLOAD(accTileData, parallelGroup[rootIdx]);
    if (nranks == 1) {
        // Copy-through: no vector work, so the store only has to wait for the load.
        set_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID0);
        wait_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID0);
    } else {
        set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
        wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
        for (int r = 0; r < nranks; ++r) {
            if (r == rootIdx) {
                continue;
            }
            TLOAD(recvTileData, parallelGroup[r]);
            set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1);
            wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1);
            detail::ReduceTiles(accTileData, recvTileData, op);
            // recvTileData is overwritten by the next iteration's TLOAD; V must be done with it.
            set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0);
            wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0);
        }
        set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
        wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
    }
    TSTORE(dstGlobalData, accTileData);
    // Order any later MTE2 work after this store has finished reading accTileData.
    set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0);
    wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0);
}
template <typename ParallelGroupType, typename GlobalDstData, typename TileData, typename DynStrideT>
/**
 * Reduce one currentRows x currentCols chunk across all ranks using a single staging tile.
 *
 * The chunk starts at element offset srcOffset in every rank's source tensor and is written
 * at element offset dstOffset in dstGlobalData. Views are built as
 * Shape<1, 1, 1, currentRows, currentCols> with the supplied strides. For TileData with a
 * DYNAMIC ValidRow/ValidCol, the valid extent of both tiles is set to the chunk size before
 * loading, which is how partial trailing chunks are handled.
 *
 * @param srcOffset       element offset of the chunk inside each rank's source tensor.
 * @param dstOffset       element offset of the chunk inside dstGlobalData.
 * @param currentRows     rows in this chunk.
 * @param currentCols     columns in this chunk.
 * @param srcChunkStride  strides for the source chunk views (shared by all ranks).
 * @param dstChunkStride  strides for the destination chunk view.
 * @param rootIdx         group index whose data seeds the accumulator.
 * @param nranks          number of ranks; 1 means copy-through with no reduction.
 */
PTO_INTERNAL void TreduceProcessChunkSingle(ParallelGroupType &parallelGroup, GlobalDstData &dstGlobalData,
                                            TileData &accTileData, TileData &recvTileData, ReduceOp op,
                                            int64_t srcOffset, int64_t dstOffset, int currentRows, int currentCols,
                                            const DynStrideT &srcChunkStride, const DynStrideT &dstChunkStride,
                                            int rootIdx, int nranks)
{
    using GlobalSrcData = typename ParallelGroupTraits<ParallelGroupType>::GlobalDataType;
    using T = typename GlobalSrcData::RawDType;
    using DynShape = Shape<1, 1, 1, DYNAMIC, DYNAMIC>;
    using SrcViewT = GlobalTensor<T, DynShape, DynStrideT, GlobalSrcData::layout>;
    using DstViewT = GlobalTensor<T, DynShape, DynStrideT, GlobalDstData::layout>;
    constexpr bool isDynamicRow = (TileData::ValidRow == DYNAMIC);
    constexpr bool isDynamicCol = (TileData::ValidCol == DYNAMIC);
    // Fit the valid region of dynamic tiles to this chunk (matters for the trailing partial chunk).
    if constexpr (isDynamicRow) {
        accTileData.RowMaskInternal = currentRows;
        recvTileData.RowMaskInternal = currentRows;
    }
    if constexpr (isDynamicCol) {
        accTileData.ColMaskInternal = currentCols;
        recvTileData.ColMaskInternal = currentCols;
    }
    DynShape chunkShape(1, 1, 1, currentRows, currentCols);
    SrcViewT rootView(parallelGroup[rootIdx].data() + srcOffset, chunkShape, srcChunkStride);
    TLOAD(accTileData, rootView);
    if (nranks == 1) {
        // Copy-through: the store only has to wait for the load.
        set_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID0);
        wait_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID0);
    } else {
        set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
        wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
        for (int r = 0; r < nranks; ++r) {
            if (r == rootIdx) {
                continue;
            }
            SrcViewT remoteView(parallelGroup[r].data() + srcOffset, chunkShape, srcChunkStride);
            TLOAD(recvTileData, remoteView);
            set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1);
            wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1);
            detail::ReduceTiles(accTileData, recvTileData, op);
            // recvTileData is overwritten by the next TLOAD; wait until V is done with it.
            set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0);
            wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0);
        }
        set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
        wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
    }
    DstViewT dstView(dstGlobalData.data() + dstOffset, chunkShape, dstChunkStride);
    TSTORE(dstView, accTileData);
    // Order any later MTE2 work after this store has finished reading accTileData.
    set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0);
    wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0);
}
template <typename ParallelGroupType, typename GlobalDstData, typename TileData>
/**
 * Chunked TREDUCE using one staging tile.
 *
 * Iterates dims 0..2 of the global shape one index at a time and tiles dim 3 (rows) and
 * dim 4 (cols) into chunks of at most tileValidRow x tileValidCol. Each chunk is reduced by
 * TreduceProcessChunkSingle. Element offsets are computed in int64_t from the root tensor's
 * strides (used for every rank's source view) and from dstGlobalData's strides.
 *
 * TreduceProcessChunkSingle rewrites the valid row/col of DYNAMIC tiles for every chunk.
 * The extents the tiles held on entry are restored before returning, so a trailing partial
 * chunk does not leave the caller's tiles (and hence the next TREDUCE's chunk size) shrunk.
 * NOTE(review): assumes GetValidRow()/GetValidCol() read back RowMaskInternal/ColMaskInternal
 * for DYNAMIC tiles.
 */
PTO_INTERNAL void TreduceChunkedSingle(ParallelGroupType &parallelGroup, GlobalDstData &dstGlobalData,
                                       TileData &accTileData, TileData &recvTileData, ReduceOp op, int gShape0,
                                       int gShape1, int gShape2, int gShape3, int gShape4, int tileValidRow,
                                       int tileValidCol, int rootIdx, int nranks)
{
    using GlobalSrcData = typename ParallelGroupTraits<ParallelGroupType>::GlobalDataType;
    using DynStride = Stride<DYNAMIC, DYNAMIC, DYNAMIC, DYNAMIC, DYNAMIC>;
    constexpr bool isDynamicRow = (TileData::ValidRow == DYNAMIC);
    constexpr bool isDynamicCol = (TileData::ValidCol == DYNAMIC);
    // Snapshot the caller's valid extents; only consulted for DYNAMIC dimensions.
    [[maybe_unused]] const int accValidRow = accTileData.GetValidRow();
    [[maybe_unused]] const int accValidCol = accTileData.GetValidCol();
    [[maybe_unused]] const int recvValidRow = recvTileData.GetValidRow();
    [[maybe_unused]] const int recvValidCol = recvTileData.GetValidCol();

    GlobalSrcData &refTensor = parallelGroup[rootIdx];
    const int srcStride0 = refTensor.GetStride(GlobalTensorDim::DIM_0);
    const int srcStride1 = refTensor.GetStride(GlobalTensorDim::DIM_1);
    const int srcStride2 = refTensor.GetStride(GlobalTensorDim::DIM_2);
    const int srcStride3 = refTensor.GetStride(GlobalTensorDim::DIM_3);
    const int srcStride4 = refTensor.GetStride(GlobalTensorDim::DIM_4);
    const int dstStride0 = dstGlobalData.GetStride(GlobalTensorDim::DIM_0);
    const int dstStride1 = dstGlobalData.GetStride(GlobalTensorDim::DIM_1);
    const int dstStride2 = dstGlobalData.GetStride(GlobalTensorDim::DIM_2);
    const int dstStride3 = dstGlobalData.GetStride(GlobalTensorDim::DIM_3);
    const int dstStride4 = dstGlobalData.GetStride(GlobalTensorDim::DIM_4);
    DynStride srcChunkStride(srcStride0, srcStride1, srcStride2, srcStride3, srcStride4);
    DynStride dstChunkStride(dstStride0, dstStride1, dstStride2, dstStride3, dstStride4);
    for (int i0 = 0; i0 < gShape0; ++i0) {
        for (int i1 = 0; i1 < gShape1; ++i1) {
            for (int i2 = 0; i2 < gShape2; ++i2) {
                // 64-bit offsets: the products can exceed INT_MAX for large tensors.
                const int64_t srcBase = static_cast<int64_t>(i0) * srcStride0 +
                                        static_cast<int64_t>(i1) * srcStride1 +
                                        static_cast<int64_t>(i2) * srcStride2;
                const int64_t dstBase = static_cast<int64_t>(i0) * dstStride0 +
                                        static_cast<int64_t>(i1) * dstStride1 +
                                        static_cast<int64_t>(i2) * dstStride2;
                for (int rowOff = 0; rowOff < gShape3; rowOff += tileValidRow) {
                    // The last row chunk may be partial.
                    const int curRows = (rowOff + tileValidRow <= gShape3) ? tileValidRow : (gShape3 - rowOff);
                    for (int colOff = 0; colOff < gShape4; colOff += tileValidCol) {
                        // The last column chunk may be partial.
                        const int curCols = (colOff + tileValidCol <= gShape4) ? tileValidCol : (gShape4 - colOff);
                        const int64_t srcOff = srcBase + static_cast<int64_t>(rowOff) * srcStride3 +
                                               static_cast<int64_t>(colOff) * srcStride4;
                        const int64_t dstOff = dstBase + static_cast<int64_t>(rowOff) * dstStride3 +
                                               static_cast<int64_t>(colOff) * dstStride4;
                        TreduceProcessChunkSingle<ParallelGroupType, GlobalDstData, TileData>(
                            parallelGroup, dstGlobalData, accTileData, recvTileData, op, srcOff, dstOff, curRows,
                            curCols, srcChunkStride, dstChunkStride, rootIdx, nranks);
                    }
                }
            }
        }
    }
    // Undo the per-chunk valid-extent rewrites so they do not leak to the caller.
    if constexpr (isDynamicRow) {
        accTileData.RowMaskInternal = accValidRow;
        recvTileData.RowMaskInternal = recvValidRow;
    }
    if constexpr (isDynamicCol) {
        accTileData.ColMaskInternal = accValidCol;
        recvTileData.ColMaskInternal = recvValidCol;
    }
}
template <typename ParallelGroupType, typename GlobalDstData, typename TileData>
/**
 * TREDUCE entry point using a single staging tile.
 *
 * Reduces the tensors of all ranks in parallelGroup element-wise with op and writes the
 * result to dstGlobalData. The root rank (parallelGroup.GetRootIdx()) seeds the accumulator
 * and supplies the global shape.
 *
 * - Zero rows (dims 0..3 collapsed) or zero columns (dim 4): nothing is done.
 * - Collapsed rows <= tile valid row and dim 4 <= tile valid col: one-pass TreduceSimple.
 * - Otherwise: tile-sized chunks via TreduceChunkedSingle. A static tile ValidRow/ValidCol
 *   must divide shape3/shape4; use a DYNAMIC valid extent for partial chunks.
 *
 * @param accTileData   accumulator tile; its valid row/col define the chunk size.
 * @param recvTileData  staging tile for the non-root ranks (same TileData type).
 * @param op            reduction operator (Sum / Max / Min).
 */
PTO_INTERNAL void TREDUCE_IMPL(ParallelGroupType &parallelGroup, GlobalDstData &dstGlobalData, TileData &accTileData,
                               TileData &recvTileData, ReduceOp op)
{
    using GlobalSrcData = typename ParallelGroupTraits<ParallelGroupType>::GlobalDataType;
    using T = typename GlobalSrcData::RawDType;
    static_assert(std::is_same_v<T, typename GlobalDstData::RawDType>, "TREDUCE: GlobalData type mismatch!");
    static_assert(std::is_same_v<T, typename TileData::DType>,
                  "TREDUCE: TileData element type must match GlobalData element type");
    const int rootRank = parallelGroup.GetRootIdx();
    const int groupSize = parallelGroup.GetSize();
    PTO_ASSERT(groupSize > 0, "ParallelGroup size must be greater than 0!");
    PTO_ASSERT(rootRank >= 0 && rootRank < groupSize, "rootIdx must be in range [0, nranks)!");
    GlobalSrcData &refTensor = parallelGroup[rootRank];
    const int gShape0 = refTensor.GetShape(GlobalTensorDim::DIM_0);
    const int gShape1 = refTensor.GetShape(GlobalTensorDim::DIM_1);
    const int gShape2 = refTensor.GetShape(GlobalTensorDim::DIM_2);
    const int gShape3 = refTensor.GetShape(GlobalTensorDim::DIM_3);
    const int gShape4 = refTensor.GetShape(GlobalTensorDim::DIM_4);
    const int64_t totalRows = static_cast<int64_t>(gShape0) * gShape1 * gShape2 * gShape3;
    const int chunkRows = accTileData.GetValidRow();
    const int chunkCols = accTileData.GetValidCol();
    PTO_ASSERT(chunkRows > 0, "TREDUCE: tileValidRow must be greater than 0");
    PTO_ASSERT(chunkCols > 0, "TREDUCE: tileValidCol must be greater than 0");
    // In builds where PTO_ASSERT is compiled out, a non-positive chunk size would make the
    // chunk loops (offset += chunk size) never advance; bail out instead of hanging.
    if (chunkRows <= 0 || chunkCols <= 0) {
        return;
    }
    if (totalRows == 0 || gShape4 == 0) {
        return;
    }
    if (totalRows <= chunkRows && gShape4 <= chunkCols) {
        TreduceSimple<ParallelGroupType, GlobalDstData, TileData>(parallelGroup, dstGlobalData, accTileData,
                                                                  recvTileData, op, rootRank, groupSize);
        return;
    }
    constexpr bool isDynamicRow = (TileData::ValidRow == DYNAMIC);
    constexpr bool isDynamicCol = (TileData::ValidCol == DYNAMIC);
    if constexpr (!isDynamicRow) {
        PTO_ASSERT(gShape3 % chunkRows == 0,
                   "TREDUCE chunked: shape3 must be divisible by tile ValidRow when ValidRow is static. "
                   "Use a Tile with DYNAMIC ValidRow for partial row chunk support.");
    }
    if constexpr (!isDynamicCol) {
        PTO_ASSERT(gShape4 % chunkCols == 0,
                   "TREDUCE chunked: shape4 must be divisible by tile ValidCol when ValidCol is static. "
                   "Use a Tile with DYNAMIC ValidCol for partial column chunk support.");
    }
    TreduceChunkedSingle<ParallelGroupType, GlobalDstData, TileData>(
        parallelGroup, dstGlobalData, accTileData, recvTileData, op, gShape0, gShape1, gShape2, gShape3, gShape4,
        chunkRows, chunkCols, rootRank, groupSize);
}
template <typename ParallelGroupType, typename GlobalDstData, typename TileData>
/**
 * Single-tile TREDUCE with double-buffered staging tiles (ping/pong).
 *
 * While the current remote rank is reduced from one staging tile, the next remote rank is
 * loaded into the other, so MTE2 loads overlap vector work. Events: EVENT_ID0 covers the
 * accumulator load and the V->MTE2 / V->MTE3 hand-offs; EVENT_ID1 tracks loads into
 * pingTile and EVENT_ID2 loads into pongTile.
 *
 * @param rootIdx    group index whose data seeds the accumulator.
 * @param nranks     number of ranks; 1 means a plain copy of the root tensor.
 * @param numRemote  number of non-root ranks to fold in (TREDUCE_IMPL passes nranks - 1).
 */
PTO_INTERNAL void TreduceSimplePingPong(ParallelGroupType &parallelGroup, GlobalDstData &dstGlobalData,
                                        TileData &accTileData, TileData &pingTile, TileData &pongTile, ReduceOp op,
                                        int rootIdx, int nranks, int numRemote)
{
    auto &rootTensor = parallelGroup[rootIdx];
    // The root's data seeds the accumulator in both the single-rank and multi-rank cases.
    TLOAD(accTileData, rootTensor);
    if (nranks == 1) {
        PtoSetWaitFlag<PIPE_MTE2, PIPE_MTE3>();
        TSTORE(dstGlobalData, accTileData);
        PtoSetWaitFlag<PIPE_MTE3, PIPE_MTE2>();
        return;
    }
    set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
    // Prefetch the first remote rank into ping before waiting on the accumulator load.
    TLOAD(pingTile, parallelGroup[detail::GetRemoteRank(rootIdx, 0)]);
    set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1);
    wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
    for (int i = 0; i < numRemote; ++i) {
        const bool hasNext = (i + 1 < numRemote);
        const bool usePing = ((i & 1) == 0);
        TileData &currentTile = usePing ? pingTile : pongTile;
        TileData &nextTile = usePing ? pongTile : pingTile;
        const auto currentEvent = usePing ? EVENT_ID1 : EVENT_ID2;
        const auto nextEvent = usePing ? EVENT_ID2 : EVENT_ID1;
        if (hasNext) {
            // nextTile was last read by V in the previous iteration, which ended with a
            // V->MTE2 barrier, so it is safe to overwrite now.
            TLOAD(nextTile, parallelGroup[detail::GetRemoteRank(rootIdx, i + 1)]);
            set_flag(PIPE_MTE2, PIPE_V, nextEvent);
        }
        wait_flag(PIPE_MTE2, PIPE_V, currentEvent);
        detail::ReduceTiles(accTileData, currentTile, op);
        if (hasNext) {
            set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0);
            wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0);
        } else {
            set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
            wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
        }
    }
    TSTORE(dstGlobalData, accTileData);
    set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0);
    wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0);
}
template <typename ParallelGroupType, typename TileData, typename DynStrideT>
/**
 * Double-buffered reduction loop for one chunk.
 *
 * Folds numRemote non-root ranks into accTileData, alternating pingTile and pongTile as
 * staging buffers. Remote ranks are visited in ordinal order via detail::GetRemoteRank (root
 * skipped); while rank ordinal i is reduced from one staging tile, ordinal i + 1 is loaded
 * into the other.
 *
 * Expects the caller to have just queued TLOAD(accTileData, ...) on MTE2 (as
 * TreduceProcessChunkPingPong does). The first set_flag below is ordered after that load.
 * For numRemote >= 1, the last iteration ends with a V->MTE3 barrier, so the caller may
 * TSTORE accTileData immediately.
 *
 * @param srcOffset       element offset of the chunk inside each rank's source tensor.
 * @param currentRows     rows in this chunk.
 * @param currentCols     columns in this chunk.
 * @param srcChunkStride  strides for the source chunk views.
 * @param numRemote       number of non-root ranks to fold in.
 */
PTO_INTERNAL void TreducePingPongLoop(ParallelGroupType &parallelGroup, TileData &accTileData, TileData &pingTile,
                                      TileData &pongTile, ReduceOp op, int64_t srcOffset, int currentRows,
                                      int currentCols, const DynStrideT &srcChunkStride, int rootIdx, int numRemote)
{
    using GlobalSrcData = typename ParallelGroupTraits<ParallelGroupType>::GlobalDataType;
    using T = typename GlobalSrcData::RawDType;
    using DynShape = Shape<1, 1, 1, DYNAMIC, DYNAMIC>;
    using SrcViewT = GlobalTensor<T, DynShape, DynStrideT, GlobalSrcData::layout>;
    DynShape chunkShape(1, 1, 1, currentRows, currentCols);
    const int firstRemoteRank = detail::GetRemoteRank(rootIdx, 0);
    // Signals completion of the caller's accumulator load (queued on MTE2 before this call).
    set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
    SrcViewT firstView(parallelGroup[firstRemoteRank].data() + srcOffset, chunkShape, srcChunkStride);
    TLOAD(pingTile, firstView);
    set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1);
    wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
    for (int i = 0; i < numRemote; ++i) {
        const bool scheduleNext = (i + 1 < numRemote);
        const bool currentIsPing = ((i & 1) == 0);
        TileData &currentTile = currentIsPing ? pingTile : pongTile;
        TileData &nextTile = currentIsPing ? pongTile : pingTile;
        // EVENT_ID1 tracks loads into pingTile, EVENT_ID2 loads into pongTile.
        const event_t currentReady = currentIsPing ? EVENT_ID1 : EVENT_ID2;
        const event_t nextReady = currentIsPing ? EVENT_ID2 : EVENT_ID1;
        if (scheduleNext) {
            // nextTile was last read by V in the previous iteration, which ended with a
            // V->MTE2 barrier, so it is safe to overwrite now.
            const int nextRemoteRank = detail::GetRemoteRank(rootIdx, i + 1);
            SrcViewT nextView(parallelGroup[nextRemoteRank].data() + srcOffset, chunkShape, srcChunkStride);
            TLOAD(nextTile, nextView);
            set_flag(PIPE_MTE2, PIPE_V, nextReady);
        }
        wait_flag(PIPE_MTE2, PIPE_V, currentReady);
        detail::ReduceTiles(accTileData, currentTile, op);
        if (!scheduleNext) {
            // Final reduction: hand the accumulator to MTE3 for the store.
            set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
            wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
            continue;
        }
        set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0);
        wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0);
    }
}
template <typename ParallelGroupType, typename GlobalDstData, typename TileData, typename DynStrideT>
/**
 * Reduce one currentRows x currentCols chunk across all ranks with ping/pong staging tiles.
 *
 * Loads the root rank's chunk (element offset srcOffset) into accTileData. It then either
 * copies it through (nranks == 1) or folds in the numRemote non-root ranks via
 * TreducePingPongLoop. Finally it stores the result at element offset dstOffset in
 * dstGlobalData. For TileData with a DYNAMIC ValidRow/ValidCol, all three tiles get their
 * valid extent set to the chunk size first.
 *
 * @param srcOffset       element offset of the chunk inside each rank's source tensor.
 * @param dstOffset       element offset of the chunk inside dstGlobalData.
 * @param currentRows     rows in this chunk.
 * @param currentCols     columns in this chunk.
 * @param srcChunkStride  strides for the source chunk views (shared by all ranks).
 * @param dstChunkStride  strides for the destination chunk view.
 * @param nranks          number of ranks; 1 means copy-through with no reduction.
 * @param numRemote       number of non-root ranks to fold in.
 */
PTO_INTERNAL void TreduceProcessChunkPingPong(ParallelGroupType &parallelGroup, GlobalDstData &dstGlobalData,
                                              TileData &accTileData, TileData &pingTile, TileData &pongTile,
                                              ReduceOp op, int64_t srcOffset, int64_t dstOffset, int currentRows,
                                              int currentCols, const DynStrideT &srcChunkStride,
                                              const DynStrideT &dstChunkStride, int rootIdx, int nranks, int numRemote)
{
    using GlobalSrcData = typename ParallelGroupTraits<ParallelGroupType>::GlobalDataType;
    using ElemT = typename GlobalSrcData::RawDType;
    using DynShape = Shape<1, 1, 1, DYNAMIC, DYNAMIC>;
    using SrcViewT = GlobalTensor<ElemT, DynShape, DynStrideT, GlobalSrcData::layout>;
    using DstViewT = GlobalTensor<ElemT, DynShape, DynStrideT, GlobalDstData::layout>;
    constexpr bool isDynamicRow = (TileData::ValidRow == DYNAMIC);
    constexpr bool isDynamicCol = (TileData::ValidCol == DYNAMIC);
    // Fit the valid region of dynamic tiles to this chunk (matters for the trailing partial chunk).
    if constexpr (isDynamicRow) {
        accTileData.RowMaskInternal = currentRows;
        pingTile.RowMaskInternal = currentRows;
        pongTile.RowMaskInternal = currentRows;
    }
    if constexpr (isDynamicCol) {
        accTileData.ColMaskInternal = currentCols;
        pingTile.ColMaskInternal = currentCols;
        pongTile.ColMaskInternal = currentCols;
    }
    DynShape chunkShape(1, 1, 1, currentRows, currentCols);
    SrcViewT rootView(parallelGroup[rootIdx].data() + srcOffset, chunkShape, srcChunkStride);
    TLOAD(accTileData, rootView);
    if (nranks == 1) {
        // Copy-through: the store only has to wait for the load.
        set_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID0);
        wait_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID0);
    } else {
        // Relies on the accumulator TLOAD above being the most recent MTE2 operation.
        TreducePingPongLoop<ParallelGroupType, TileData>(parallelGroup, accTileData, pingTile, pongTile, op, srcOffset,
                                                         currentRows, currentCols, srcChunkStride, rootIdx, numRemote);
    }
    DstViewT dstView(dstGlobalData.data() + dstOffset, chunkShape, dstChunkStride);
    TSTORE(dstView, accTileData);
    // Order any later MTE2 work after this store has finished reading accTileData.
    set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0);
    wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0);
}
template <typename ParallelGroupType, typename GlobalDstData, typename TileData>
/**
 * Chunked TREDUCE with ping/pong staging tiles.
 *
 * Iterates dims 0..2 of the global shape one index at a time and tiles dim 3 (rows) and
 * dim 4 (cols) into chunks of at most tileValidRow x tileValidCol. Each chunk is reduced by
 * TreduceProcessChunkPingPong. Element offsets are computed in int64_t from the root tensor's
 * strides (used for every rank's source view) and from dstGlobalData's strides.
 *
 * TreduceProcessChunkPingPong rewrites the valid row/col of DYNAMIC tiles for every chunk.
 * The extents the three tiles held on entry are restored before returning, so a trailing
 * partial chunk does not leave the caller's tiles (and hence the next TREDUCE's chunk size)
 * shrunk. NOTE(review): assumes GetValidRow()/GetValidCol() read back
 * RowMaskInternal/ColMaskInternal for DYNAMIC tiles.
 */
PTO_INTERNAL void TreduceChunkedPingPong(ParallelGroupType &parallelGroup, GlobalDstData &dstGlobalData,
                                         TileData &accTileData, TileData &pingTile, TileData &pongTile, ReduceOp op,
                                         int gShape0, int gShape1, int gShape2, int gShape3, int gShape4,
                                         int tileValidRow, int tileValidCol, int rootIdx, int nranks, int numRemote)
{
    using GlobalSrcData = typename ParallelGroupTraits<ParallelGroupType>::GlobalDataType;
    using DynStride = Stride<DYNAMIC, DYNAMIC, DYNAMIC, DYNAMIC, DYNAMIC>;
    constexpr bool isDynamicRow = (TileData::ValidRow == DYNAMIC);
    constexpr bool isDynamicCol = (TileData::ValidCol == DYNAMIC);
    // Snapshot the caller's valid extents; only consulted for DYNAMIC dimensions.
    [[maybe_unused]] const int accValidRow = accTileData.GetValidRow();
    [[maybe_unused]] const int accValidCol = accTileData.GetValidCol();
    [[maybe_unused]] const int pingValidRow = pingTile.GetValidRow();
    [[maybe_unused]] const int pingValidCol = pingTile.GetValidCol();
    [[maybe_unused]] const int pongValidRow = pongTile.GetValidRow();
    [[maybe_unused]] const int pongValidCol = pongTile.GetValidCol();

    GlobalSrcData &refTensor = parallelGroup[rootIdx];
    const int srcStride[5] = {static_cast<int>(refTensor.GetStride(GlobalTensorDim::DIM_0)),
                              static_cast<int>(refTensor.GetStride(GlobalTensorDim::DIM_1)),
                              static_cast<int>(refTensor.GetStride(GlobalTensorDim::DIM_2)),
                              static_cast<int>(refTensor.GetStride(GlobalTensorDim::DIM_3)),
                              static_cast<int>(refTensor.GetStride(GlobalTensorDim::DIM_4))};
    const int dstStride[5] = {static_cast<int>(dstGlobalData.GetStride(GlobalTensorDim::DIM_0)),
                              static_cast<int>(dstGlobalData.GetStride(GlobalTensorDim::DIM_1)),
                              static_cast<int>(dstGlobalData.GetStride(GlobalTensorDim::DIM_2)),
                              static_cast<int>(dstGlobalData.GetStride(GlobalTensorDim::DIM_3)),
                              static_cast<int>(dstGlobalData.GetStride(GlobalTensorDim::DIM_4))};
    DynStride srcChunkStride(srcStride[0], srcStride[1], srcStride[2], srcStride[3], srcStride[4]);
    DynStride dstChunkStride(dstStride[0], dstStride[1], dstStride[2], dstStride[3], dstStride[4]);
    for (int dim0 = 0; dim0 < gShape0; ++dim0) {
        for (int dim1 = 0; dim1 < gShape1; ++dim1) {
            for (int dim2 = 0; dim2 < gShape2; ++dim2) {
                // 64-bit offsets: the products can exceed INT_MAX for large tensors.
                const int64_t srcBase = static_cast<int64_t>(dim0) * srcStride[0] +
                                        static_cast<int64_t>(dim1) * srcStride[1] +
                                        static_cast<int64_t>(dim2) * srcStride[2];
                const int64_t dstBase = static_cast<int64_t>(dim0) * dstStride[0] +
                                        static_cast<int64_t>(dim1) * dstStride[1] +
                                        static_cast<int64_t>(dim2) * dstStride[2];
                int rowCursor = 0;
                while (rowCursor < gShape3) {
                    // The last row chunk may be partial.
                    const int rowRemain = gShape3 - rowCursor;
                    const int curRows = (rowRemain < tileValidRow) ? rowRemain : tileValidRow;
                    int colCursor = 0;
                    while (colCursor < gShape4) {
                        // The last column chunk may be partial.
                        const int colRemain = gShape4 - colCursor;
                        const int curCols = (colRemain < tileValidCol) ? colRemain : tileValidCol;
                        const int64_t srcOff = srcBase + static_cast<int64_t>(rowCursor) * srcStride[3] +
                                               static_cast<int64_t>(colCursor) * srcStride[4];
                        const int64_t dstOff = dstBase + static_cast<int64_t>(rowCursor) * dstStride[3] +
                                               static_cast<int64_t>(colCursor) * dstStride[4];
                        TreduceProcessChunkPingPong<ParallelGroupType, GlobalDstData, TileData>(
                            parallelGroup, dstGlobalData, accTileData, pingTile, pongTile, op, srcOff, dstOff, curRows,
                            curCols, srcChunkStride, dstChunkStride, rootIdx, nranks, numRemote);
                        colCursor += tileValidCol;
                    }
                    rowCursor += tileValidRow;
                }
            }
        }
    }
    // Undo the per-chunk valid-extent rewrites so they do not leak to the caller.
    if constexpr (isDynamicRow) {
        accTileData.RowMaskInternal = accValidRow;
        pingTile.RowMaskInternal = pingValidRow;
        pongTile.RowMaskInternal = pongValidRow;
    }
    if constexpr (isDynamicCol) {
        accTileData.ColMaskInternal = accValidCol;
        pingTile.ColMaskInternal = pingValidCol;
        pongTile.ColMaskInternal = pongValidCol;
    }
}
template <typename ParallelGroupType, typename GlobalDstData, typename TileData>
/**
 * TREDUCE entry point using ping/pong staging tiles (loads overlap reductions).
 *
 * Reduces the tensors of all ranks in parallelGroup element-wise with op and writes the
 * result to dstGlobalData. The root rank (parallelGroup.GetRootIdx()) seeds the accumulator
 * and supplies the global shape.
 *
 * - Zero rows (dims 0..3 collapsed) or zero columns (dim 4): nothing is done.
 * - Collapsed rows <= tile valid row and dim 4 <= tile valid col: one-pass
 *   TreduceSimplePingPong.
 * - Otherwise: tile-sized chunks via TreduceChunkedPingPong. A static tile ValidRow/ValidCol
 *   must divide shape3/shape4; use a DYNAMIC valid extent for partial chunks.
 *
 * @param accTileData  accumulator tile; its valid row/col define the chunk size.
 * @param pingTile     first staging tile for non-root ranks (same TileData type).
 * @param pongTile     second staging tile for non-root ranks (same TileData type).
 * @param op           reduction operator (Sum / Max / Min).
 */
PTO_INTERNAL void TREDUCE_IMPL(ParallelGroupType &parallelGroup, GlobalDstData &dstGlobalData, TileData &accTileData,
                               TileData &pingTile, TileData &pongTile, ReduceOp op)
{
    using GlobalSrcData = typename ParallelGroupTraits<ParallelGroupType>::GlobalDataType;
    using T = typename GlobalSrcData::RawDType;
    static_assert(std::is_same_v<T, typename GlobalDstData::RawDType>, "TREDUCE: GlobalData type mismatch!");
    static_assert(std::is_same_v<T, typename TileData::DType>,
                  "TREDUCE: TileData element type must match GlobalData element type");
    const int rootIdx = parallelGroup.GetRootIdx();
    const int nranks = parallelGroup.GetSize();
    PTO_ASSERT(nranks > 0, "ParallelGroup size must be greater than 0!");
    PTO_ASSERT(rootIdx >= 0 && rootIdx < nranks, "rootIdx must be in range [0, nranks)!");
    GlobalSrcData &refTensor = parallelGroup[rootIdx];
    const int gShape0 = refTensor.GetShape(GlobalTensorDim::DIM_0);
    const int gShape1 = refTensor.GetShape(GlobalTensorDim::DIM_1);
    const int gShape2 = refTensor.GetShape(GlobalTensorDim::DIM_2);
    const int gShape3 = refTensor.GetShape(GlobalTensorDim::DIM_3);
    const int gShape4 = refTensor.GetShape(GlobalTensorDim::DIM_4);
    const int64_t totalRows = static_cast<int64_t>(gShape0) * gShape1 * gShape2 * gShape3;
    const int tileValidRow = accTileData.GetValidRow();
    const int tileValidCol = accTileData.GetValidCol();
    PTO_ASSERT(tileValidRow > 0, "TREDUCE: tileValidRow must be greater than 0");
    PTO_ASSERT(tileValidCol > 0, "TREDUCE: tileValidCol must be greater than 0");
    // In builds where PTO_ASSERT is compiled out, a non-positive chunk size would make the
    // chunk loops (cursor += chunk size) never advance; bail out instead of hanging.
    if (tileValidRow <= 0 || tileValidCol <= 0) {
        return;
    }
    if (totalRows == 0 || gShape4 == 0) {
        return;
    }
    const int remoteCount = nranks - 1;
    if (totalRows <= tileValidRow && gShape4 <= tileValidCol) {
        TreduceSimplePingPong<ParallelGroupType, GlobalDstData, TileData>(
            parallelGroup, dstGlobalData, accTileData, pingTile, pongTile, op, rootIdx, nranks, remoteCount);
        return;
    }
    constexpr bool isDynamicRow = (TileData::ValidRow == DYNAMIC);
    constexpr bool isDynamicCol = (TileData::ValidCol == DYNAMIC);
    if constexpr (!isDynamicRow) {
        PTO_ASSERT(gShape3 % tileValidRow == 0,
                   "TREDUCE chunked: shape3 must be divisible by tile ValidRow when ValidRow is static. "
                   "Use a Tile with DYNAMIC ValidRow for partial row chunk support.");
    }
    if constexpr (!isDynamicCol) {
        PTO_ASSERT(gShape4 % tileValidCol == 0,
                   "TREDUCE chunked: shape4 must be divisible by tile ValidCol when ValidCol is static. "
                   "Use a Tile with DYNAMIC ValidCol for partial column chunk support.");
    }
    TreduceChunkedPingPong<ParallelGroupType, GlobalDstData, TileData>(
        parallelGroup, dstGlobalData, accTileData, pingTile, pongTile, op, gShape0, gShape1, gShape2, gShape3, gShape4,
        tileValidRow, tileValidCol, rootIdx, nranks, remoteCount);
}
}
}
#endif