* Copyright (c) 2025-2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file tileop_shmem.h
* \brief Shmem (shared memory) tileops: clear/set, GM/UB copy, Put/Get/Signal, Reduce.
*/
#ifndef __DISTRIBUTED_SHMEM__
#define __DISTRIBUTED_SHMEM__
#include "common.h"
#include <type_traits>
#include "utils/layout.h"
#include "utils/tile_tensor.h"
#include "vector/mte.h"
#ifdef SUPPORT_TILE_TENSOR
#include "pto/comm/pto_comm_inst.hpp"
#endif
namespace TileOp::Distributed {
template <typename TensorType>
TILEOP void GetStrideInfo(const TensorType& tensor, uint64_t strides[MAX_DIMS])
{
const auto layout = tensor.GetLayout();
strides[0] = layout.template GetStrideDim<DIM_1ST, MAX_DIMS>();
strides[1] = layout.template GetStrideDim<DIM_2ND, MAX_DIMS>();
strides[2] = layout.template GetStrideDim<DIM_3RD, MAX_DIMS>();
strides[4] = layout.template GetStrideDim<DIM_4TH, MAX_DIMS>();
strides[5] = layout.template GetStrideDim<DIM_5TH, MAX_DIMS>();
}
using ShapeDyn = pto::Shape<pto::DYNAMIC, pto::DYNAMIC, pto::DYNAMIC, pto::DYNAMIC, pto::DYNAMIC>;
using StrideDyn = pto::Stride<pto::DYNAMIC, pto::DYNAMIC, pto::DYNAMIC, pto::DYNAMIC, pto::DYNAMIC>;
TILEOP inline ShapeDyn MakeShape(uint32_t row, uint32_t col) { return ShapeDyn(1, 1, 1, row, col); }
TILEOP inline StrideDyn MakeStride(uint32_t row, uint32_t stride) { return StrideDyn(row, row, row, stride, 1); }
TILEOP inline uint32_t ToggleEvent(uint32_t eventId) { return eventId == EVENT_ID0 ? EVENT_ID1 : EVENT_ID0; }
template <typename T>
TILEOP constexpr T CeilDiv(T x, T y)
{
return (x + y - 1) / y;
}
template <AtomicType atomicType, typename TileType, typename GlobalType>
TILEOP void AtomicStore(GlobalType& global, TileType& tile)
{
if constexpr (atomicType == AtomicType::ADD) {
pto::TSTORE<TileType, GlobalType, pto::AtomicType::AtomicAdd>(global, tile);
} else {
pto::TSTORE<TileType, GlobalType, pto::AtomicType::AtomicNone>(global, tile);
}
}
template <typename T1>
TILEOP void GetShapesFromLayout(T1 tileTensor, int32_t shapes[MAX_DIMS])
{
const auto layout = tileTensor.GetLayout();
shapes[0] = static_cast<int32_t>(layout.template GetShapeDim<DIM_1ST, MAX_DIMS>());
shapes[1] = static_cast<int32_t>(layout.template GetShapeDim<DIM_2ND, MAX_DIMS>());
shapes[2] = static_cast<int32_t>(layout.template GetShapeDim<DIM_3RD, MAX_DIMS>());
shapes[3] = static_cast<int32_t>(layout.template GetShapeDim<DIM_4TH, MAX_DIMS>());
shapes[4] = static_cast<int32_t>(layout.template GetShapeDim<DIM_5TH, MAX_DIMS>());
}
template <typename C1>
TILEOP void GetCoordsFromCoordinate(C1 coordinate, int32_t coords[MAX_DIMS])
{
coords[0] = static_cast<int32_t>(TileOp::GetTupleElement<C1, DIM_1ST, MAX_DIMS, 0>(coordinate));
coords[1] = static_cast<int32_t>(TileOp::GetTupleElement<C1, DIM_2ND, MAX_DIMS, 0>(coordinate));
coords[2] = static_cast<int32_t>(TileOp::GetTupleElement<C1, DIM_3RD, MAX_DIMS, 0>(coordinate));
coords[3] = static_cast<int32_t>(TileOp::GetTupleElement<C1, DIM_4TH, MAX_DIMS, 0>(coordinate));
coords[4] = static_cast<int32_t>(TileOp::GetTupleElement<C1, DIM_5TH, MAX_DIMS, 0>(coordinate));
}
template <typename T>
using ShmemGlobalTensor = pto::GlobalTensor<T, ShapeDyn, StrideDyn, pto::Layout::ND>;
template <typename T, uint32_t MaxRowShape, uint32_t MaxColShape>
using ShmemUbTile = pto::Tile<
pto::TileType::Vec, T, MaxRowShape, AlignUp<uint32_t>(MaxColShape * sizeof(T), COPY_BLOCK_BYTE_SIZE) / sizeof(T),
pto::BLayout::RowMajor, pto::DYNAMIC, pto::DYNAMIC>;
template <typename T, uint32_t bufferEleNum>
TILEOP void ShmemClear(__ubuf__ T* buffer, __gm__ T* shmemTensorAddr, uint32_t shmemTensorEleNum)
{
ShmemUbTile<T, 1, bufferEleNum> ubTile(1, bufferEleNum);
pto::TASSIGN(ubTile, reinterpret_cast<uintptr_t>(buffer));
pto::TEXPANDS(ubTile, static_cast<T>(0));
PIPE_SYNC_EVENT(PIPE_V, PIPE_MTE3, EVENT_ID0);
uint32_t fullChunkCount = shmemTensorEleNum / bufferEleNum;
for (int32_t i = 0; i < fullChunkCount; i++) {
__gm__ T* dstAddr = shmemTensorAddr + bufferEleNum * i;
ShapeDyn shape = MakeShape(1, bufferEleNum);
StrideDyn strideDyn = MakeStride(1, bufferEleNum);
ShmemGlobalTensor<T> gmTensor(dstAddr, shape, strideDyn);
pto::TSTORE<decltype(ubTile), decltype(gmTensor), pto::AtomicType::AtomicNone>(gmTensor, ubTile);
}
uint32_t tailEleNum = shmemTensorEleNum % bufferEleNum;
if (tailEleNum != 0) {
__gm__ T* tailDstAddr = shmemTensorAddr + bufferEleNum * fullChunkCount;
ShapeDyn tailShape = MakeShape(1, tailEleNum);
StrideDyn tailStrideDyn = MakeStride(1, tailEleNum);
ShmemGlobalTensor<T> tailGmTensor(tailDstAddr, tailShape, tailStrideDyn);
ShmemUbTile<T, 1, bufferEleNum> tailUbTile(1, tailEleNum);
pto::TASSIGN(tailUbTile, reinterpret_cast<uintptr_t>(buffer));
pto::TSTORE<decltype(tailUbTile), decltype(tailGmTensor), pto::AtomicType::AtomicNone>(
tailGmTensor, tailUbTile);
}
}
template <typename T, uint32_t bufferEleNum>
TILEOP void ShmemClearFlag(
__ubuf__ T* buffer, __gm__ T* shmemTensorAddr, uint32_t worldSize, uint32_t maxTileNume, uint32_t stride)
{
ShmemUbTile<T, 1, bufferEleNum> ubTile(1, bufferEleNum);
pto::TASSIGN(ubTile, reinterpret_cast<uintptr_t>(buffer));
pto::TEXPANDS(ubTile, static_cast<T>(0));
PIPE_SYNC_EVENT(PIPE_V, PIPE_MTE3, EVENT_ID0);
uint32_t shmemTensorEleNum = worldSize * maxTileNume;
uint32_t fullChunkCount = shmemTensorEleNum / bufferEleNum;
for (int32_t i = 0; i < fullChunkCount; i++) {
__gm__ T* dstAddr = shmemTensorAddr + bufferEleNum * i;
ShapeDyn shape = MakeShape(1, bufferEleNum);
StrideDyn strideDyn = MakeStride(1, bufferEleNum);
ShmemGlobalTensor<T> gmTensor(dstAddr, shape, strideDyn);
pto::TSTORE<decltype(ubTile), decltype(gmTensor), pto::AtomicType::AtomicNone>(gmTensor, ubTile);
}
}
template <typename T, uint32_t bufferEleNum, typename T1, typename C1>
TILEOP void ShmemSet(
CoreFuncParam* param, __ubuf__ T* buffer, T1 shmemTensor, C1 coordinate, uint32_t ownerRank,
__gm__ int64_t* hcclContext)
{
const auto shmemTensorLayout = shmemTensor.GetLayout();
auto shmemTensorOffset = shmemTensorLayout.template GetGmOffset<C1, MAX_DIMS>(coordinate);
auto shape0 = shmemTensorLayout.template GetShapeDim<DIM_1ST, MAX_DIMS>();
auto shape1 = shmemTensorLayout.template GetShapeDim<DIM_2ND, MAX_DIMS>();
auto shape2 = shmemTensorLayout.template GetShapeDim<DIM_3RD, MAX_DIMS>();
auto shape3 = shmemTensorLayout.template GetShapeDim<DIM_4TH, MAX_DIMS>();
auto shape4 = shmemTensorLayout.template GetShapeDim<DIM_5TH, MAX_DIMS>();
using shmemTensorDtype = typename T1::Type;
__gm__ shmemTensorDtype* shmemTensorAddr =
MapVirtualAddr<shmemTensorDtype>(hcclContext, shmemTensor.GetAddr(), ownerRank) + shmemTensorOffset;
uint32_t shmemTensorEleNum = shape0 * shape1 * shape2 * shape3 * shape4;
ShmemClear<T, bufferEleNum>(buffer, shmemTensorAddr, shmemTensorEleNum);
}
template <typename T, uint32_t worldSize, uint32_t stride, uint32_t bufferEleNum, typename T1>
TILEOP void ShmemSet(
CoreFuncParam* param, __ubuf__ T* buffer, T1 shmemTensor, uint32_t ownerRank, __gm__ int64_t* hcclContext)
{
uint32_t maxTileNume =
static_cast<uint32_t>(TileOp::Distributed::DecodeShmemAddrMaxTileNum((uint64_t)shmemTensor.GetAddr()));
__gm__ T* shmemTensorAddr = MapVirtualAddr<T>(hcclContext, shmemTensor.GetAddr(), ownerRank);
ShmemClearFlag<T, bufferEleNum>(buffer, shmemTensorAddr, worldSize, maxTileNume, stride);
}
template <
typename TargetType, typename UBType, typename SourceType, uint32_t bufferRowShape, uint32_t bufferColShape,
uint32_t srcStride, uint32_t dstStride, AtomicType atomicType>
TILEOP void CopyGmToGmBlock(
__gm__ TargetType* target, __ubuf__ UBType* buffer, __gm__ SourceType* source, uint32_t rowShape, uint32_t colShape,
uint32_t eventId = EVENT_ID0)
{
wait_flag(PIPE_MTE3, PIPE_S, eventId);
PIPE_SYNC_EVENT(PIPE_S, PIPE_MTE2, eventId);
ShapeDyn shape = MakeShape(rowShape, colShape);
StrideDyn srcStrideDyn = MakeStride(rowShape, srcStride);
StrideDyn dstStrideDyn = MakeStride(rowShape, dstStride);
ShmemGlobalTensor<SourceType> srcGlobal(source, shape, srcStrideDyn);
ShmemGlobalTensor<TargetType> dstGlobal(target, shape, dstStrideDyn);
constexpr uint64_t copyLen =
bufferRowShape * AlignUp<uint64_t>(bufferColShape * sizeof(UBType), 32) / sizeof(UBType);
__ubuf__ float* castUb = (__ubuf__ float*)(buffer + copyLen);
constexpr bool kAtomicAdd = (atomicType == AtomicType::ADD);
using SrcElemType = std::conditional_t<kAtomicAdd, UBType, float>;
using DstElemType = std::conditional_t<kAtomicAdd, float, UBType>;
ShmemUbTile<SrcElemType, bufferRowShape, bufferColShape> srcTile(rowShape, colShape);
ShmemUbTile<DstElemType, bufferRowShape, bufferColShape> dstTile(rowShape, colShape);
if constexpr (kAtomicAdd) {
pto::TASSIGN(srcTile, reinterpret_cast<uintptr_t>(buffer));
pto::TASSIGN(dstTile, reinterpret_cast<uintptr_t>(castUb));
} else {
pto::TASSIGN(srcTile, reinterpret_cast<uintptr_t>(castUb));
pto::TASSIGN(dstTile, reinterpret_cast<uintptr_t>(buffer));
}
pto::TLOAD(srcTile, srcGlobal);
PIPE_SYNC_EVENT(PIPE_MTE2, PIPE_V, eventId);
pto::TCVT(dstTile, srcTile, pto::RoundMode::CAST_NONE);
PIPE_SYNC_EVENT(PIPE_V, PIPE_MTE3, eventId);
AtomicStore<atomicType>(dstGlobal, dstTile);
set_flag(PIPE_MTE3, PIPE_S, eventId);
}
template <
typename TargetType, typename UBType, typename SourceType, uint32_t bufferRowShape, uint32_t bufferColShape,
uint32_t srcStride, uint32_t dstStride, AtomicType atomicType>
TILEOP void CopyGmToGmRow(
__gm__ TargetType* dstPtr, __gm__ SourceType* srcPtr, __ubuf__ UBType* bufferA, __ubuf__ UBType* bufferB,
uint32_t rowShape, uint32_t colTailShape, uint32_t colFullBlockCount, uint32_t& eventId)
{
uint32_t colOffset = 0;
for (uint32_t colIndex = 0; colIndex < colFullBlockCount; ++colIndex, colOffset += bufferColShape) {
__ubuf__ UBType* useBuffer = eventId == EVENT_ID0 ? bufferA : bufferB;
CopyGmToGmBlock<
TargetType, UBType, SourceType, bufferRowShape, bufferColShape, srcStride, dstStride, atomicType>(
dstPtr + colOffset, useBuffer, srcPtr + colOffset, rowShape, bufferColShape, eventId);
eventId = eventId == EVENT_ID0 ? EVENT_ID1 : EVENT_ID0;
}
if (colTailShape > 0) {
__ubuf__ UBType* useBuffer = eventId == EVENT_ID0 ? bufferA : bufferB;
CopyGmToGmBlock<
TargetType, UBType, SourceType, bufferRowShape, bufferColShape, srcStride, dstStride, atomicType>(
dstPtr + colOffset, useBuffer, srcPtr + colOffset, rowShape, colTailShape, eventId);
eventId = eventId == EVENT_ID0 ? EVENT_ID1 : EVENT_ID0;
}
}
template <
bool useTPut, typename DataType, uint32_t tileRowShape, uint32_t tileColShape, uint32_t bufferRowShape,
uint32_t srcStride, uint32_t dstStride, AtomicType atomicType = AtomicType::SET>
TILEOP void CopyGmToGmByPutGet(
__gm__ DataType* target, __ubuf__ DataType* buffer, __gm__ DataType* source, uint32_t validRowShape,
uint32_t validColShape)
{
static_assert(bufferRowShape > 0, "bufferRowShape must be greater than 0.");
constexpr uint32_t kMaxTileRows = 4095;
constexpr uint32_t kEffectiveRows = bufferRowShape < kMaxTileRows ? bufferRowShape : kMaxTileRows;
constexpr uint32_t kChunkRows = tileRowShape < kEffectiveRows ? tileRowShape : kEffectiveRows;
constexpr uint32_t kAlignedCols =
AlignUp<uint32_t>(tileColShape * sizeof(DataType), COPY_BLOCK_BYTE_SIZE) / sizeof(DataType);
constexpr uint32_t kHalfBufferEleCount = bufferRowShape * kAlignedCols;
ShapeDyn shape = MakeShape(validRowShape, validColShape);
StrideDyn srcStrideDyn = MakeStride(validRowShape, srcStride);
StrideDyn dstStrideDyn = MakeStride(validRowShape, dstStride);
ShmemGlobalTensor<DataType> srcGlobal(source, shape, srcStrideDyn);
ShmemGlobalTensor<DataType> dstGlobal(target, shape, dstStrideDyn);
ShmemUbTile<DataType, kChunkRows, tileColShape> pingTile(kChunkRows, validColShape);
ShmemUbTile<DataType, kChunkRows, tileColShape> pongTile(kChunkRows, validColShape);
pto::TASSIGN(pingTile, reinterpret_cast<uintptr_t>(buffer));
pto::TASSIGN(pongTile, reinterpret_cast<uintptr_t>(buffer + kHalfBufferEleCount));
PIPE_SYNC_EVENT(PIPE_MTE3, PIPE_S, EVENT_ID0);
if constexpr (useTPut) {
if constexpr (atomicType == AtomicType::ADD) {
pto::comm::TPUT<pto::AtomicType::AtomicAdd>(dstGlobal, srcGlobal, pingTile, pongTile);
} else {
pto::comm::TPUT<pto::AtomicType::AtomicNone>(dstGlobal, srcGlobal, pingTile, pongTile);
}
} else {
pto::comm::TGET(dstGlobal, srcGlobal, pingTile, pongTile);
}
PIPE_SYNC_EVENT(PIPE_MTE3, PIPE_S, EVENT_ID0);
}
template <
typename TargetType, typename UBType, typename SourceType, uint32_t tileRowShape, uint32_t tileColShape,
uint32_t bufferRowShape, uint32_t bufferColShape, uint32_t srcStride, uint32_t dstStride, AtomicType atomicType>
TILEOP void CopyGmToGmByLaodStore(
__gm__ TargetType* target, __ubuf__ UBType* buffer, __gm__ SourceType* source, uint32_t validRowShape,
uint32_t validColShape)
{
uint32_t rowFullBlockCount = validRowShape / bufferRowShape;
uint32_t colFullBlockCount = validColShape / bufferColShape;
uint32_t rowTailShape = validRowShape % bufferRowShape;
uint32_t colTailShape = validColShape % bufferColShape;
constexpr uint32_t srcRowStride = bufferRowShape * srcStride;
constexpr uint32_t dstRowStride = bufferRowShape * dstStride;
set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0);
set_flag(PIPE_MTE3, PIPE_S, EVENT_ID1);
uint32_t eventId = EVENT_ID0;
constexpr uint32_t copyLen =
bufferRowShape * AlignUp<uint32_t>(bufferColShape * sizeof(UBType), 32) / sizeof(UBType);
__ubuf__ UBType* bufferA = buffer;
__ubuf__ UBType* bufferB = buffer + copyLen;
if constexpr (!std::is_same_v<TargetType, SourceType>) {
constexpr uint64_t castSize = AlignUp<uint64_t>(copyLen * sizeof(float), 256);
bufferB = buffer + copyLen + castSize / sizeof(UBType);
}
__gm__ SourceType* srcPtr = source;
__gm__ TargetType* dstPtr = target;
for (uint32_t rowIndex = 0; rowIndex < rowFullBlockCount;
++rowIndex, srcPtr += srcRowStride, dstPtr += dstRowStride) {
CopyGmToGmRow<TargetType, UBType, SourceType, bufferRowShape, bufferColShape, srcStride, dstStride, atomicType>(
dstPtr, srcPtr, bufferA, bufferB, bufferRowShape, colTailShape, colFullBlockCount, eventId);
}
if (rowTailShape > 0) {
CopyGmToGmRow<TargetType, UBType, SourceType, bufferRowShape, bufferColShape, srcStride, dstStride, atomicType>(
dstPtr, srcPtr, bufferA, bufferB, rowTailShape, colTailShape, colFullBlockCount, eventId);
}
wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0);
wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID1);
}
template <
typename NonShmemType, typename ShmemType, uint32_t tileRowShape, uint32_t tileColShape, uint32_t bufferRowShape,
uint32_t bufferColShape, uint32_t srcStride, uint32_t dstStride, AtomicType atomicType, typename T1, typename T2,
typename C1, typename C2>
TILEOP void ShmemPut(
CoreFuncParam* param, __ubuf__ NonShmemType* buffer, T1 src, T2 dst, C1 srcCoordinate, C2 dstCoordinate,
uint32_t srcValidShape0, uint32_t srcValidShape1, uint32_t srcValidShape2, uint32_t srcValidShape3,
uint32_t srcValidShape4, uint32_t ownerRank, __gm__ int64_t* hcclContext)
{
static_assert(T1::FORMAT == Hardware::GM && T2::FORMAT == Hardware::GM);
uint64_t srcStrides[MAX_DIMS];
GetStrideInfo(src, srcStrides);
auto srcGmOffset = src.GetLayout().template GetGmOffset<C1, MAX_DIMS>(srcCoordinate);
uint64_t dstStrides[MAX_DIMS];
GetStrideInfo(dst, dstStrides);
auto dstGmOffset = dst.GetLayout().template GetGmOffset<C2, MAX_DIMS>(dstCoordinate);
if constexpr (atomicType == AtomicType::ADD) {
SetAttomicType<ShmemType>();
set_atomic_add();
}
for (LoopVar index0 = 0; index0 < srcValidShape0; ++index0) {
for (LoopVar index1 = 0; index1 < srcValidShape1; ++index1) {
for (LoopVar index2 = 0; index2 < srcValidShape2; ++index2) {
auto srcOffset = index0 * srcStrides[0] + index1 * srcStrides[1] + index2 * srcStrides[2];
auto dstOffset = index0 * dstStrides[0] + index1 * dstStrides[1] + index2 * dstStrides[2];
__gm__ NonShmemType* srcAddr = src.GetAddr() + srcGmOffset + srcOffset;
__gm__ ShmemType* dstAddr =
MapVirtualAddr<ShmemType>(hcclContext, dst.GetAddr(), ownerRank) + dstGmOffset + dstOffset;
if constexpr (std::is_same_v<ShmemType, NonShmemType>) {
CopyGmToGmByPutGet<
true, ShmemType, tileRowShape, tileColShape, bufferRowShape, srcStride, dstStride, atomicType>(
dstAddr, buffer, srcAddr, srcValidShape3, srcValidShape4);
} else {
CopyGmToGmByLaodStore<
ShmemType, NonShmemType, NonShmemType, tileRowShape, tileColShape, bufferRowShape,
bufferColShape, srcStride, dstStride, atomicType>(
dstAddr, buffer, srcAddr, srcValidShape3, srcValidShape4);
}
}
}
}
if constexpr (atomicType == AtomicType::ADD) {
set_atomic_none();
}
}
template <
typename InShmemType, typename OutShmemType, uint32_t tileRowShape, uint32_t tileColShape, uint32_t bufferRowShape,
uint32_t bufferColShape, uint32_t srcStride, uint32_t dstStride, AtomicType atomicType>
TILEOP void ShmemPut(
CoreFuncParam* param, __ubuf__ InShmemType* buffer, __gm__ InShmemType* inShmemDataBaseAddr,
__gm__ OutShmemType* shmemDataBaseAddr, uint32_t inShmemDataOffset0, uint32_t inShmemDataOffset1,
uint32_t inShmemDataOffset2, uint32_t inShmemDataRawShape0, uint32_t inShmemDataRawShape1,
uint32_t inShmemDataRawShape2, uint32_t shmemDataOffset0, uint32_t shmemDataOffset1, uint32_t shmemDataOffset2,
uint32_t shmemDataRawShape0, uint32_t shmemDataRawShape1, uint32_t shmemDataRawShape2, uint32_t validShape0,
uint32_t validShape1, __gm__ int64_t* hcclContext)
{
(void)inShmemDataRawShape0;
(void)shmemDataRawShape0;
__gm__ InShmemType* inShmemDataAddr =
MapVirtualAddr<InShmemType>(hcclContext, inShmemDataBaseAddr, inShmemDataOffset0) +
CalcLinearOffset(inShmemDataRawShape2, inShmemDataOffset1, inShmemDataOffset2);
__gm__ OutShmemType* shmemDataAddr =
MapVirtualAddr<OutShmemType>(hcclContext, shmemDataBaseAddr, shmemDataOffset0) +
CalcLinearOffset(shmemDataRawShape2, shmemDataOffset1, shmemDataOffset2);
if constexpr (std::is_same_v<OutShmemType, InShmemType>) {
CopyGmToGmByPutGet<
true, OutShmemType, tileRowShape, tileColShape, bufferRowShape, srcStride, dstStride, atomicType>(
shmemDataAddr, buffer, inShmemDataAddr, validShape0, validShape1);
} else {
CopyGmToGmByLaodStore<
OutShmemType, InShmemType, InShmemType, tileRowShape, tileColShape, bufferRowShape, bufferColShape,
srcStride, dstStride, atomicType>(shmemDataAddr, buffer, inShmemDataAddr, validShape0, validShape1);
}
}
template <
typename Type, uint32_t tileRowShape, uint32_t tileColShape, uint32_t dstStride, AtomicType atomicType, typename T1,
typename T2, typename C1, typename C2>
TILEOP void ShmemStore(
CoreFuncParam* param, T1 src, T2 dst, C1 srcCoordinate, C2 dstCoordinate, uint32_t outValidShape0,
uint32_t outValidShape1, uint32_t outValidShape2, uint32_t outValidShape3, uint32_t outValidShape4,
uint32_t ownerRank, __gm__ int64_t* hcclContext)
{
static_assert(T1::FORMAT == Hardware::UB && T2::FORMAT == Hardware::GM);
uint64_t dstStrides[MAX_DIMS];
GetStrideInfo(dst, dstStrides);
auto dstGmOffset = dst.GetLayout().template GetGmOffset<C2, MAX_DIMS>(dstCoordinate);
for (LoopVar index0 = 0; index0 < outValidShape0; ++index0) {
for (LoopVar index1 = 0; index1 < outValidShape1; ++index1) {
for (LoopVar index2 = 0; index2 < outValidShape2; ++index2) {
auto dstOffset = index0 * dstStrides[0] + index1 * dstStrides[1] + index2 * dstStrides[2];
__gm__ Type* dstAddr =
MapVirtualAddr<Type>(hcclContext, dst.GetAddr(), ownerRank) + dstGmOffset + dstOffset;
auto srcTileOffets = TileOffset(index0, index1, index2);
uint64_t srcAddr = src.GetLinearAddr(srcTileOffets);
PIPE_SYNC_EVENT(PIPE_S, PIPE_MTE3, EVENT_ID0);
ShapeDyn shape = MakeShape(outValidShape3, outValidShape4);
StrideDyn dstStrideDyn = MakeStride(outValidShape3, dstStride);
ShmemGlobalTensor<Type> dstGlobal(dstAddr, shape, dstStrideDyn);
ShmemUbTile<Type, tileRowShape, tileColShape> ubTile(outValidShape3, outValidShape4);
pto::TASSIGN(ubTile, srcAddr);
AtomicStore<atomicType>(dstGlobal, ubTile);
PIPE_SYNC_EVENT(PIPE_MTE3, PIPE_S, EVENT_ID0);
}
}
}
}
template <
int64_t value, int32_t stride, AtomicType atomicType, bool notifyAll, uint32_t worldSize, uint32_t viewshape0,
uint32_t viewshape1, uint32_t viewshape2, uint32_t viewshape3, uint32_t viewTileNum, uint32_t totalTileNum,
int32_t tileShape0, int32_t tileShape1, int32_t tileShape2, int32_t tileShape3, int32_t tileShapeDim, typename T1,
typename C1>
TILEOP void ShmemSignal(
CoreFuncParam* param, __ubuf__ int32_t* buffer, T1 src, C1 srcCoordinate, uint32_t ownerRank,
__gm__ int64_t* hcclContext)
{
int32_t srcShapes[MAX_DIMS];
GetShapesFromLayout<T1>(src, srcShapes);
int32_t coords[MAX_DIMS];
GetCoordsFromCoordinate<C1>(srcCoordinate, coords);
constexpr int32_t startDim = MAX_DIMS - tileShapeDim;
constexpr int32_t tileShape[] = {tileShape0, tileShape1, tileShape2, tileShape3};
constexpr int32_t viewshapes[] = {viewshape0, viewshape1, viewshape2, viewshape3};
uint32_t tileIndex = 0;
uint32_t viewIndexAccum = 0;
uint32_t viewTileStride = 1;
uint32_t viewIndexStride = 1;
for (uint32_t dimIdx = 0; dimIdx < tileShapeDim; ++dimIdx) {
int32_t curDim = startDim + static_cast<int32_t>(dimIdx);
int32_t rawShape = srcShapes[curDim];
int32_t viewShape = viewshapes[dimIdx];
int32_t offset = coords[curDim];
int32_t tileShapeVal = tileShape[dimIdx];
int32_t viewIdx = offset / viewShape;
int32_t viewOffset = offset % viewShape;
int32_t viewTileIdx = viewOffset / tileShapeVal;
tileIndex += static_cast<uint32_t>(viewTileIdx) * viewTileStride;
viewIndexAccum += static_cast<uint32_t>(viewIdx) * viewIndexStride;
viewTileStride *= static_cast<uint32_t>(CeilDiv(viewShape, tileShapeVal));
viewIndexStride *= static_cast<uint32_t>(rawShape / viewShape);
}
tileIndex += viewIndexAccum * viewTileNum;
buffer[0] = static_cast<int32_t>(value);
constexpr uint32_t signalColShape = 8;
ShmemUbTile<int32_t, 1, signalColShape> signalTile(1, 1);
pto::TASSIGN(signalTile, reinterpret_cast<uintptr_t>(buffer));
PIPE_SYNC_EVENT(PIPE_S, PIPE_MTE3, EVENT_ID0);
ShapeDyn signalShape = MakeShape(1, 1);
StrideDyn signalStride = MakeStride(1, 1);
uint32_t sRank = notifyAll ? 0 : ownerRank;
uint32_t eRank = notifyAll ? worldSize : sRank + 1;
uint64_t baseOffset =
static_cast<uint64_t>(CalcLinearOffset(totalTileNum, coords[startDim - 1], tileIndex) * stride);
for (uint32_t rankId = sRank; rankId < eRank; rankId++) {
__gm__ int32_t* shmemSignalAddr = MapVirtualAddr<int32_t, 1>(hcclContext, src.GetAddr(), rankId) + baseOffset;
ShmemGlobalTensor<int32_t> signalGlobal(shmemSignalAddr, signalShape, signalStride);
AtomicStore<atomicType>(signalGlobal, signalTile);
}
PIPE_SYNC_EVENT(PIPE_MTE3, PIPE_S, EVENT_ID0);
}
template <
typename NonShmemType, typename ShmemType, uint32_t tileRowShape, uint32_t tileColShape, uint32_t bufferRowShape,
uint32_t bufferColShape, uint32_t srcStride, uint32_t dstStride, AtomicType atomicType, typename T1, typename T2,
typename C1, typename C2>
TILEOP void ShmemGet(
CoreFuncParam* param, __ubuf__ NonShmemType* buffer, T1 dst, T2 src, C1 dstCoordinate, C2 srcCoordinate,
uint32_t dstValidShape0, uint32_t dstValidShape1, uint32_t dstValidShape2, uint32_t dstValidShape3,
uint32_t dstValidShape4, uint32_t ownerRank, __gm__ int64_t* hcclContext)
{
static_assert(T1::FORMAT == Hardware::GM && T2::FORMAT == Hardware::GM);
uint64_t srcStrides[MAX_DIMS];
GetStrideInfo(src, srcStrides);
auto srcGmOffset = src.GetLayout().template GetGmOffset<C2, MAX_DIMS>(srcCoordinate);
uint64_t dstStrides[MAX_DIMS];
GetStrideInfo(dst, dstStrides);
auto dstGmOffset = dst.GetLayout().template GetGmOffset<C1, MAX_DIMS>(dstCoordinate);
for (LoopVar index0 = 0; index0 < dstValidShape0; ++index0) {
for (LoopVar index1 = 0; index1 < dstValidShape1; ++index1) {
for (LoopVar index2 = 0; index2 < dstValidShape2; ++index2) {
auto srcOffset = index0 * srcStrides[0] + index1 * srcStrides[1] + index2 * srcStrides[2];
auto dstOffset = index0 * dstStrides[0] + index1 * dstStrides[1] + index2 * dstStrides[2];
__gm__ ShmemType* srcAddr =
MapVirtualAddr<ShmemType>(hcclContext, src.GetAddr(), ownerRank) + srcGmOffset + srcOffset;
__gm__ NonShmemType* dstAddr = dst.GetAddr() + dstGmOffset + dstOffset;
if constexpr (std::is_same_v<NonShmemType, ShmemType>) {
CopyGmToGmByPutGet<
false, NonShmemType, tileRowShape, tileColShape, bufferRowShape, srcStride, dstStride>(
dstAddr, buffer, srcAddr, dstValidShape3, dstValidShape4);
} else {
CopyGmToGmByLaodStore<
NonShmemType, NonShmemType, ShmemType, tileRowShape, tileColShape, bufferRowShape,
bufferColShape, srcStride, dstStride, atomicType>(
dstAddr, buffer, srcAddr, dstValidShape3, dstValidShape4);
}
}
}
}
}
template <typename UBType, typename ShmemType, typename T1, typename T2, typename C>
TILEOP void ShmemLoad(
CoreFuncParam* param, T1 dst, T2 src, C srcCoordinate, uint32_t ownerRank, __gm__ int64_t* hcclContext)
{
__gm__ ShmemType* srcAddr = MapVirtualAddr<ShmemType>(hcclContext, src.GetAddr(), ownerRank);
src.SetAddr(srcAddr);
PIPE_SYNC_EVENT(PIPE_S, PIPE_MTE2, EVENT_ID0);
::TLoad(dst, src, srcCoordinate);
PIPE_SYNC_EVENT(PIPE_MTE2, PIPE_S, EVENT_ID0);
}
}
#endif