/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
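/*!
* \file moe_distributed_combine.cpp
* \brief Argument validation, tiling callbacks and graph-construction entry points of the MoE distributed
*        combine operator.
*/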
#include "distributed_common.h"
#include "interface/function/function.h"
#include "interface/inner/tilefwk.h"
#include "interface/operation/operation.h"
#include "interface/program/program.h"
#include "interface/tensor/logical_tensor.h"
#include "interface/utils/common.h"
#include "tilefwk/error_code.h"
#include "tilefwk/data_type.h"
#include "tilefwk/symbolic_distributed.h"
#include "tilefwk/tensor.h"
#include "tilefwk/tilefwk.h"
namespace npu::tile_fwk::Distributed {
/**
 * \brief Validates the "expandX" input of the MoE distributed combine operator.
 *
 * The checks run in this order, each one being an ASSERT tagged with a DistributedErrorCode:
 *   1. rank is 2;
 *   2. first axis equals min(topK * batchSize * epWorldSize, batchSize * moeExpertNum);
 *   3. second axis equals 5120;
 *   4. data type is DT_BF16;
 *   5. format is TILEOP_ND.
 *
 * \param expandX       Tensor under validation.
 * \param expertScales  Source of batchSize (axis 0) and topK (axis 1). Its own rank is not checked here.
 * \param epWorldSize   Factor of the first candidate in the expected-row formula.
 * \param moeExpertNum  Factor of the second candidate in the expected-row formula.
 */
void MoeDistributedCombineValidateExpandX(
    const Tensor& expandX, const Tensor& expertScales, int32_t epWorldSize, int32_t moeExpertNum)
{
    // Rank must be checked before any axis is read.
    constexpr uint32_t supportedDim = 2;
    ASSERT(DistributedErrorCode::INVALID_TENSOR_DIM, expandX.GetShape().size() == supportedDim)
        << "The dim of \"expandX\" only supports " << supportedDim << ", but got " << expandX.GetShape().size();

    // Row count: the smaller of the two candidate products.
    const int32_t expandXRow = expandX.GetShape(0);
    const int32_t topK = expertScales.GetShape(1);
    const int32_t batchSize = expertScales.GetShape(0);
    const int32_t rowsByWorldSize = topK * batchSize * epWorldSize;
    const int32_t rowsByExpertNum = batchSize * moeExpertNum;
    const int32_t expectedRow = std::min(rowsByWorldSize, rowsByExpertNum);
    ASSERT(DistributedErrorCode::INVALID_TENSOR_SHAPE, expandXRow == expectedRow)
        << "The first axis of \"expandX\" must be the smaller value between topK (the second axis of "
           "\"expertScales\") * batchSize (the first axis of \"expertScales\") * epWorldSize and batchSize * "
           "moeExpertNum, topK="
        << topK << ", batchSize=" << batchSize << ", epWorldSize=" << epWorldSize
        << ", moeExpertNum=" << moeExpertNum << ", the expected first axis of \"expandX\" should be "
        << expectedRow << " but got " << expandXRow;

    // Column count: only a single hidden size is accepted.
    const int32_t expandXCol = expandX.GetShape(1);
    constexpr int32_t supportedHiddenSize = 5120;
    ASSERT(DistributedErrorCode::INVALID_TENSOR_SHAPE, expandXCol == supportedHiddenSize)
        << "The second axis of \"expandX\" only supports " << supportedHiddenSize << ", but got " << expandXCol;

    ASSERT(DistributedErrorCode::INVALID_TENSOR_DTYPE, expandX.GetDataType() == DT_BF16)
        << "The data type of \"expandX\" only supports DT_BF16, but got "
        << DataType2String(expandX.GetDataType());

    ASSERT(DistributedErrorCode::INVALID_TENSOR_FORMAT, expandX.Format() == npu::tile_fwk::TileOpFormat::TILEOP_ND)
        << "The format of \"expandX\" only supports ND, but got NZ";
}
/**
 * \brief Validates the "assistInfoForCombine" input of the MoE distributed combine operator.
 *
 * The checks run in this order: rank is 2, first axis equals the first axis of expandX, second axis is 3,
 * data type is DT_INT32, format is TILEOP_ND.
 *
 * \param assistInfoForCombine  Tensor under validation.
 * \param expandX               Tensor whose first axis the row count is compared against.
 */
void MoeDistributedCombineValidateAssistInfoForCombine(const Tensor& assistInfoForCombine, const Tensor& expandX)
{
    constexpr uint32_t supportedDim = 2;
    ASSERT(DistributedErrorCode::INVALID_TENSOR_DIM, assistInfoForCombine.GetShape().size() == supportedDim)
        << "The dim of \"assistInfoForCombine\" only supports " << supportedDim << ", but got "
        << assistInfoForCombine.GetShape().size();

    // One row per expandX row.
    const int32_t assistInfoForCombineRow = assistInfoForCombine.GetShape(0);
    const int32_t expandXRow = expandX.GetShape(0);
    ASSERT(DistributedErrorCode::INVALID_TENSOR_SHAPE, assistInfoForCombineRow == expandXRow)
        << "The first axis of \"assistInfoForCombine\" must be consistent with that of \"expandX\", but expandXRow="
        << expandXRow << ", assistInfoForCombineRow=" << assistInfoForCombineRow;

    // Fixed number of columns per row.
    const int32_t assistInfoForCombineCol = assistInfoForCombine.GetShape(1);
    constexpr int32_t supportedAssistInfoForCombineCol = 3;
    ASSERT(DistributedErrorCode::INVALID_TENSOR_SHAPE, assistInfoForCombineCol == supportedAssistInfoForCombineCol)
        << "The second axis of \"assistInfoForCombine\" must be " << supportedAssistInfoForCombineCol
        << ", but got " << assistInfoForCombineCol;

    ASSERT(DistributedErrorCode::INVALID_TENSOR_DTYPE, assistInfoForCombine.GetDataType() == DT_INT32)
        << "The data type of \"assistInfoForCombine\" only supports DT_INT32, but got "
        << DataType2String(assistInfoForCombine.GetDataType());

    ASSERT(
        DistributedErrorCode::INVALID_TENSOR_FORMAT,
        assistInfoForCombine.Format() == npu::tile_fwk::TileOpFormat::TILEOP_ND)
        << "The format of \"assistInfoForCombine\" only supports ND, but got NZ";
}
/**
 * \brief Validates the "recvCounts" input of the MoE distributed combine operator.
 *
 * The tensor must be one-dimensional with exactly one element, of type DT_INT32 and in ND format.
 * The checks run in that order.
 *
 * \param recvCounts  Tensor under validation.
 */
void MoeDistributedCombineValidateRecvCounts(const Tensor& recvCounts)
{
    ASSERT(DistributedErrorCode::INVALID_TENSOR_DIM, recvCounts.GetShape().size() == 1)
        << "The dim of \"recvCounts\" only supports 1, but got " << recvCounts.GetShape().size();

    const int32_t recvCountsSize = recvCounts.GetShape(0);
    ASSERT(DistributedErrorCode::INVALID_TENSOR_SHAPE, recvCountsSize == 1)
        << "The size of \"recvCounts\" must be 1, but recvCountsSize=" << recvCountsSize;

    ASSERT(DistributedErrorCode::INVALID_TENSOR_DTYPE, recvCounts.GetDataType() == DT_INT32)
        << "The data type of \"recvCounts\" only supports DT_INT32, but got "
        << DataType2String(recvCounts.GetDataType());

    ASSERT(DistributedErrorCode::INVALID_TENSOR_FORMAT, recvCounts.Format() == npu::tile_fwk::TileOpFormat::TILEOP_ND)
        << "The format of \"recvCounts\" only supports ND, but got NZ";
}
/**
 * \brief Validates the "expertScales" input of the MoE distributed combine operator.
 *
 * The checks run in this order: rank is 2, first axis is 8 or 256, second axis is 8, data type is DT_FP32,
 * format is TILEOP_ND.
 *
 * \param expertScales  Tensor under validation.
 */
void MoeDistributedCombineValidateExpertScales(const Tensor& expertScales)
{
    constexpr uint32_t supportedDim = 2;
    ASSERT(DistributedErrorCode::INVALID_TENSOR_DIM, expertScales.GetShape().size() == supportedDim)
        << "The dim of \"expertScales\" only supports " << supportedDim << ", but got "
        << expertScales.GetShape().size();

    // Only two row counts are accepted.
    const int32_t expertScalesRow = expertScales.GetShape(0);
    constexpr int32_t supportedExpertScalesRow1 = 8;
    constexpr int32_t supportedExpertScalesRow2 = 256;
    ASSERT(
        DistributedErrorCode::INVALID_TENSOR_SHAPE,
        (expertScalesRow == supportedExpertScalesRow1) || (expertScalesRow == supportedExpertScalesRow2))
        << "The first axis of \"expertScales\" only supports " << supportedExpertScalesRow1 << " or "
        << supportedExpertScalesRow2 << ", but got " << expertScalesRow;

    // Only one column count is accepted.
    const int32_t expertScalesCol = expertScales.GetShape(1);
    constexpr int32_t supportedExpertScalesCol = 8;
    ASSERT(DistributedErrorCode::INVALID_TENSOR_SHAPE, expertScalesCol == supportedExpertScalesCol)
        << "The second axis of \"expertScales\" only supports " << supportedExpertScalesCol << ", but got "
        << expertScalesCol;

    ASSERT(DistributedErrorCode::INVALID_TENSOR_DTYPE, expertScales.GetDataType() == DT_FP32)
        << "The data type of \"expertScales\" only supports DT_FP32, but got "
        << DataType2String(expertScales.GetDataType());

    ASSERT(DistributedErrorCode::INVALID_TENSOR_FORMAT, expertScales.Format() == npu::tile_fwk::TileOpFormat::TILEOP_ND)
        << "The format of \"expertScales\" only supports ND, but got NZ";
}
/**
 * \brief Validates the "out" tensor of the MoE distributed combine operator.
 *
 * The checks run in this order: rank is 2, first axis equals the first axis of expertScales, second axis
 * equals the second axis of expandX, data type equals that of expandX, format is TILEOP_ND.
 *
 * \param out           Tensor under validation.
 * \param expertScales  Tensor whose first axis gives the expected row count.
 * \param expandX       Tensor whose second axis and data type give the expected column count and dtype.
 */
void MoeDistributedCombineValidateOut(const Tensor& out, const Tensor& expertScales, const Tensor& expandX)
{
    constexpr uint32_t supportedDim = 2;
    ASSERT(DistributedErrorCode::INVALID_TENSOR_DIM, out.GetShape().size() == supportedDim)
        << "The dim of \"out\" only supports " << supportedDim << ", but got " << out.GetShape().size();

    // Rows follow expertScales.
    const int32_t outRow = out.GetShape(0);
    const int32_t expertScalesRow = expertScales.GetShape(0);
    ASSERT(DistributedErrorCode::INVALID_TENSOR_SHAPE, outRow == expertScalesRow)
        << "The first axis of \"out\" must be consistent with that of \"expertScales\", but expertScalesRow="
        << expertScalesRow << ", outRow=" << outRow;

    // Columns follow expandX.
    const int32_t outCol = out.GetShape(1);
    const int32_t expandXCol = expandX.GetShape(1);
    ASSERT(DistributedErrorCode::INVALID_TENSOR_SHAPE, outCol == expandXCol)
        << "The second axis of \"out\" must be consistent with that of \"expandX\", but expandXCol="
        << expandXCol << ", outCol=" << outCol;

    ASSERT(DistributedErrorCode::INVALID_TENSOR_DTYPE, out.GetDataType() == expandX.GetDataType())
        << "The data type of \"out\" must be consistent with that of \"expandX\", but data type of \"expandX\" is "
        << DataType2String(expandX.GetDataType()) << " and the data type of \"out\" is "
        << DataType2String(out.GetDataType());

    ASSERT(DistributedErrorCode::INVALID_TENSOR_FORMAT, out.Format() == npu::tile_fwk::TileOpFormat::TILEOP_ND)
        << "The format of \"out\" only supports ND, but got NZ";
}
/**
 * \brief Validates the communication group name.
 *
 * The pointer must be non-null and the C string length must lie in [1, 128).
 *
 * \param group  NUL-terminated group name.
 */
void MoeDistributedCombineValidateGroup(const char* group)
{
    ASSERT(DistributedErrorCode::NULLPTR, group != nullptr) << "\"group\" cannot be nullptr";

    // Keep the length in std::size_t: narrowing the strlen result to a 32-bit signed integer could wrap a
    // very long name back into the accepted range.
    const std::size_t groupLen = std::strlen(group);
    constexpr std::size_t maxGroupLen = 128;
    ASSERT(DistributedErrorCode::INVALID_GROUP_NAME, (groupLen >= 1) && (groupLen < maxGroupLen))
        << "The length of \"group\" only supports [1, " << maxGroupLen << "), but got " << groupLen;
}
/**
 * \brief Validates "epWorldSize": only the values 4 and 8 are accepted.
 *
 * \param epWorldSize  Value under validation.
 */
void MoeDistributedCombineValidateMoeEpWorldSize(int32_t epWorldSize)
{
    constexpr int32_t supportedEpWorldSize1 = 4;
    constexpr int32_t supportedEpWorldSize2 = 8;
    ASSERT(
        DistributedErrorCode::INVALID_WORLD_SIZE,
        (epWorldSize == supportedEpWorldSize1) || (epWorldSize == supportedEpWorldSize2))
        << "epWorldSize only supports " << supportedEpWorldSize1 << " or " << supportedEpWorldSize2
        << ", but got " << epWorldSize;
}
/**
 * \brief Validates "moeExpertNum": only the value 160 is accepted.
 *
 * \param moeExpertNum  Value under validation.
 */
void MoeDistributedCombineValidateMoeExpertNum(int32_t moeExpertNum)
{
    constexpr int32_t supportedMoeExpertNum = 160;
    ASSERT(DistributedErrorCode::INVALID_MOE_EXPERT_NUM, moeExpertNum == supportedMoeExpertNum)
        << "moeExpertNum only supports " << supportedMoeExpertNum << ", but got " << moeExpertNum;
}
/**
 * \brief Tiling callback for OP_MOE_DISTRIBUTED_COMBINE_SEND.
 *
 * For every tile that CreateTileOp hands to the callback, one OP_MOE_DISTRIBUTED_COMBINE_SEND operation is
 * added to \p function. Each operation reads a view of expandX plus the untiled remaining inputs, and
 * writes the original output together with three freshly created buffers.
 *
 * \param function   Function receiving the per-tile operations.
 * \param tileShape  Tile description forwarded to CreateTileOp.
 * \param iOperand   Exactly five inputs: expandX, assistInfoForCombine, recvCounts, shmemData, shmemSignal.
 * \param oOperand   Exactly one output.
 * \param op         Untiled operation whose distOpAttr is copied onto every per-tile operation.
 */
void TiledMoeDistributedCombineSend(
    Function& function, const TileShape& tileShape, const std::vector<std::shared_ptr<LogicalTensor>>& iOperand,
    const std::vector<std::shared_ptr<LogicalTensor>>& oOperand, const Operation& op)
{
    ASSERT(DistributedErrorCode::INVALID_OPERAND_NUM, iOperand.size() == 5UL)
        << "TiledMoeDistributedCombineSend iOperand size is not equal to 5";
    ASSERT(DistributedErrorCode::INVALID_OPERAND_NUM, oOperand.size() == 1UL)
        << "TiledMoeDistributedCombineSend oOperand size is not equal to 1";

    const auto& expandX = iOperand[0];
    const auto& assistInfoForCombine = iOperand[1];
    const auto& recvCounts = iOperand[2];
    const auto& shmemData = iOperand[3];
    const auto& shmemSignal = iOperand[4];
    const auto& out = oOperand[0];

    // Column count after rounding one row of expandX up to a whole number of copy blocks.
    const int64_t hiddenSize = expandX->shape[1];
    const int64_t dataByteSize = BytesOf(expandX->Datatype());
    ASSERT(DistributedErrorCode::DIVISION_BY_ZERO, dataByteSize != 0) << "iOperand expandX dType size cannot be zero";
    const int64_t paddedColShape = AlignUp(dataByteSize * hiddenSize, COPY_BLOCK_BYTE_SIZE) / dataByteSize;

    // Element counts of the INT32 buffers attached to every per-tile operation.
    const Shape assistInfoForCombineShape =
        Shape{static_cast<int64_t>(COPY_BLOCK_BYTE_SIZE) / static_cast<int64_t>(BytesOf(DT_INT32))};
    const Shape signalShape = Shape{static_cast<int64_t>(REPEAT_BYTE) / static_cast<int64_t>(BytesOf(DT_INT32))};

    // The attribute is the same for all tiles, so it is completed once up front.
    MoeCombineAttr distOpAttr;
    op.GetAttr(OpAttributeKey::distOpAttr, distOpAttr);
    distOpAttr.paddedColShape = paddedColShape;

    CreateTileOp(
        tileShape,
        [&](int32_t /* tileIndex */, int32_t rowOffset, int32_t colOffset, int32_t rowShape, int32_t colShape) {
            auto expandXTile = expandX->View(function, {rowShape, colShape}, {rowOffset, colOffset});
            auto dataBuffer = std::make_shared<LogicalTensor>(function, expandX->Datatype(), Shape{hiddenSize});
            auto assistInfoForCombineBuffer =
                std::make_shared<LogicalTensor>(function, DT_INT32, assistInfoForCombineShape);
            auto signalBuffer = std::make_shared<LogicalTensor>(function, DT_INT32, signalShape);

            auto& tileOp = function.AddOperation(
                Opcode::OP_MOE_DISTRIBUTED_COMBINE_SEND,
                {expandXTile, assistInfoForCombine, recvCounts, shmemData, shmemSignal},
                {out, dataBuffer, assistInfoForCombineBuffer, signalBuffer});
            tileOp.SetAttr(OpAttributeKey::distOpAttr, distOpAttr);
            tileOp.SetAttr(OpAttributeKey::dontTouch, true);
        });
}
/**
 * \brief Tiling callback for OP_MOE_DISTRIBUTED_COMBINE_RECEIVE.
 *
 * For every tile that CreateTileOp hands to the callback, one OP_MOE_DISTRIBUTED_COMBINE_RECEIVE operation
 * is added to \p function. Each operation reads a view of the shared-memory data covering topK rows per
 * tile row, writes a view of the output, and gets two FP32 buffers plus one buffer of the output data type.
 *
 * \param function   Function receiving the per-tile operations.
 * \param tileShape  Tile description forwarded to CreateTileOp.
 * \param iOperand   Exactly four inputs: predToken, expertScales, shmemData view, shmemSignal view.
 * \param oOperand   Exactly one output.
 * \param op         Unused.
 */
void TiledMoeDistributedCombineReceive(
    Function& function, const TileShape& tileShape, const std::vector<std::shared_ptr<LogicalTensor>>& iOperand,
    const std::vector<std::shared_ptr<LogicalTensor>>& oOperand, const Operation& op)
{
    (void)op;
    ASSERT(DistributedErrorCode::INVALID_OPERAND_NUM, iOperand.size() == 4UL)
        << "TiledMoeDistributedCombineReceive iOperand size is not equal to 4";
    ASSERT(DistributedErrorCode::INVALID_OPERAND_NUM, oOperand.size() == 1UL)
        << "TiledMoeDistributedCombineReceive oOperand size is not equal to 1";

    const auto& predToken = iOperand[0];
    const auto& expertScales = iOperand[1];
    const auto& shmemDataThisRank = iOperand[2];
    const auto& shmemSignalThisRank = iOperand[3];
    const auto& out = oOperand[0];

    const int64_t topK = expertScales->shape[1];
    const int64_t hiddenSize = out->shape[1];

    // Column count after rounding one output row up to a whole number of copy blocks.
    const int64_t dataByteSize = BytesOf(out->Datatype());
    ASSERT(DistributedErrorCode::DIVISION_BY_ZERO, dataByteSize != 0) << "oOperand out dType size cannot be zero";
    const int64_t paddedColShape = AlignUp(dataByteSize * hiddenSize, COPY_BLOCK_BYTE_SIZE) / dataByteSize;

    // FP32 element count of one padded row, rounded up to REPEAT_BYTE.
    const int64_t floatByteSize = BytesOf(DataType::DT_FP32);
    ASSERT(DistributedErrorCode::DIVISION_BY_ZERO, floatByteSize != 0) << "floatByteSize cannot be zero";
    const int64_t floatEleNum = AlignUp(floatByteSize * paddedColShape, REPEAT_BYTE) / floatByteSize;

    // Tile-independent attribute fields are filled once; the row window is filled per tile below.
    MoeCombineAttr distOpAttr;
    distOpAttr.topK = topK;
    distOpAttr.paddedColShape = paddedColShape;

    CreateTileOp(
        tileShape,
        [&](int32_t /* tileIndex */, int32_t rowOffset, int32_t colOffset, int32_t rowShape, int32_t colShape) {
            auto shmemDataTile = shmemDataThisRank->View(
                function, {1, 1, topK * rowShape, colShape}, {0, 0, topK * rowOffset, colOffset});
            auto outTile = out->View(function, {rowShape, colShape}, {rowOffset, colOffset});
            auto mulFp32Buffer = std::make_shared<LogicalTensor>(function, DT_FP32, Shape{floatEleNum});
            auto sumFp32Buffer = std::make_shared<LogicalTensor>(function, DT_FP32, Shape{floatEleNum});
            auto outBuffer = std::make_shared<LogicalTensor>(function, out->Datatype(), Shape{hiddenSize});

            auto& tileOp = function.AddOperation(
                Opcode::OP_MOE_DISTRIBUTED_COMBINE_RECEIVE,
                {predToken, expertScales, shmemDataTile, shmemSignalThisRank},
                {outTile, mulFp32Buffer, sumFp32Buffer, outBuffer});
            distOpAttr.rowOffset = rowOffset;
            distOpAttr.rowShape = rowShape;
            tileOp.SetAttr(OpAttributeKey::distOpAttr, distOpAttr);
        });
}
/**
 * \brief Creates a shared-memory tensor in the current function and binds it to a communication group.
 *
 * The created tensor has shape {rankSize, shape...}. It is registered with the tensor slot manager as
 * SlotProperty::SHMEM_TENSOR and produced by an OP_BIND_TENSOR operation whose bindTensor attribute carries
 * the group index, the memory type and the byte size of one \p shape slice.
 *
 * \param shmemTensor     Receives the newly created tensor.
 * \param rankSize        Extent of the leading axis prepended to \p shape.
 * \param hcclGroupIndex  Group index stored in the bindTensor attribute.
 * \param dataType        Element type of the tensor.
 * \param shape           Shape of one slice, without the leading rank axis.
 * \param memType         Memory type stored in the bindTensor attribute.
 */
void CreateShmemTensor(
    Tensor& shmemTensor, int32_t rankSize, int32_t hcclGroupIndex, DataType dataType, const Shape& shape,
    uint64_t memType = 0)
{
    auto& function = *Program::GetInstance().GetCurrentFunction();

    Shape shmemShape{rankSize};
    shmemShape.insert(shmemShape.end(), shape.begin(), shape.end());
    auto shmemTensorInner = std::make_shared<LogicalTensor>(function, dataType, shmemShape);
    shmemTensor = shmemTensorInner;
    Program::GetInstance().GetTensorSlotManager()->TensorWrite(shmemTensor, SlotProperty::SHMEM_TENSOR);

    // The seed must be int64_t: std::accumulate keeps its running value in the seed's type, so seeding with
    // a plain `1` would truncate every partial product to int.
    const int64_t elementCount =
        std::accumulate(shape.begin(), shape.end(), int64_t{1}, std::multiplies<int64_t>());

    auto& op = function.AddOperation(Opcode::OP_BIND_TENSOR, {}, {shmemTensorInner});
    op.SetAttribute(
        OpAttributeKey::bindTensor, BindTensor(hcclGroupIndex, memType, BytesOf(dataType) * elementCount));
}
/**
 * \brief Adds the untiled OP_MOE_DISTRIBUTED_COMBINE_SEND operation to the current function.
 *
 * \param in                    First input of the operation.
 * \param assistInfoForCombine  Second input of the operation.
 * \param recvCounts            Third input of the operation.
 * \param shmemData             Fourth input of the operation.
 * \param shmemSignal           Fifth input of the operation.
 * \param topK                  Stored in the operation's distOpAttr.
 * \return The operation's single output, a DT_INT32 tensor of shape {1}.
 */
Tensor MoeDistributedCombineSend(
    const Tensor& in, const Tensor& assistInfoForCombine, const Tensor& recvCounts, const Tensor& shmemData,
    const Tensor& shmemSignal, int32_t topK)
{
    auto& function = *Program::GetInstance().GetCurrentFunction();

    MoeCombineAttr distOpAttr;
    distOpAttr.topK = topK;

    auto out = std::make_shared<LogicalTensor>(function, DT_INT32, Shape{1});
    auto& sendOp = function.AddOperation(
        Opcode::OP_MOE_DISTRIBUTED_COMBINE_SEND,
        {in.GetStorage(), assistInfoForCombine.GetStorage(), recvCounts.GetStorage(), shmemData.GetStorage(),
         shmemSignal.GetStorage()},
        {out});
    sendOp.SetAttr(OpAttributeKey::distOpAttr, distOpAttr);
    return out;
}
/**
 * \brief Adds the untiled OP_MOE_DISTRIBUTED_COMBINE_RECEIVE operation to the current function.
 *
 * The output has the data type of \p shmemData and shape {expertScales axis 0, shmemData axis 3}.
 *
 * \param predToken     First input of the operation.
 * \param expertScales  Second input; its first axis gives the output row count.
 * \param shmemData     Third input; its fourth axis gives the output column count.
 * \param shmemSignal   Fourth input of the operation.
 * \return The operation's single output tensor.
 */
Tensor MoeDistributedCombineReceive(
    const Tensor& predToken, const Tensor& expertScales, const Tensor& shmemData, const Tensor& shmemSignal)
{
    auto& function = *Program::GetInstance().GetCurrentFunction();

    int32_t outRows = expertScales.GetShape(0);
    int32_t outCols = shmemData.GetShape(3);
    auto out = std::make_shared<LogicalTensor>(function, shmemData.GetDataType(), Shape{outRows, outCols});

    function.AddOperation(
        Opcode::OP_MOE_DISTRIBUTED_COMBINE_RECEIVE,
        {predToken.GetStorage(), expertScales.GetStorage(), shmemData.GetStorage(), shmemSignal.GetStorage()},
        {out});
    return out;
}
/**
 * \brief Runs every argument check of the MoE distributed combine operator.
 *
 * The checks are ordered so that each validator only reads values that have already been validated:
 * scalar arguments first, then expertScales, then the tensors whose expected shape is derived from them.
 * Validating expandX reads both axes of expertScales and multiplies them by epWorldSize and moeExpertNum,
 * so those must be known to be well-formed and small before that arithmetic is performed.
 *
 * \param expandX               Validated by MoeDistributedCombineValidateExpandX.
 * \param assistInfoForCombine  Validated by MoeDistributedCombineValidateAssistInfoForCombine.
 * \param recvCounts            Validated by MoeDistributedCombineValidateRecvCounts.
 * \param expertScales          Validated by MoeDistributedCombineValidateExpertScales.
 * \param group                 Validated by MoeDistributedCombineValidateGroup.
 * \param epWorldSize           Validated by MoeDistributedCombineValidateMoeEpWorldSize.
 * \param moeExpertNum          Validated by MoeDistributedCombineValidateMoeExpertNum.
 * \param sharedExpertNum       Unused.
 * \param sharedExpertRankNum   Unused.
 * \param out                   Validated by MoeDistributedCombineValidateOut.
 */
void MoeDistributedCombineValidate(
    const Tensor& expandX, const Tensor& assistInfoForCombine, const Tensor& recvCounts, const Tensor& expertScales,
    const char* group, uint32_t epWorldSize, uint32_t moeExpertNum, uint32_t sharedExpertNum,
    uint32_t sharedExpertRankNum, Tensor& out)
{
    (void)sharedExpertNum;
    (void)sharedExpertRankNum;

    // Scalars: no dependency on any tensor.
    MoeDistributedCombineValidateGroup(group);
    MoeDistributedCombineValidateMoeEpWorldSize(epWorldSize);
    MoeDistributedCombineValidateMoeExpertNum(moeExpertNum);

    // expertScales comes before every validator that indexes its axes.
    MoeDistributedCombineValidateExpertScales(expertScales);

    // expandX depends on expertScales and the scalars; assistInfoForCombine and out depend on expandX.
    MoeDistributedCombineValidateExpandX(expandX, expertScales, epWorldSize, moeExpertNum);
    MoeDistributedCombineValidateAssistInfoForCombine(assistInfoForCombine, expandX);
    MoeDistributedCombineValidateRecvCounts(recvCounts);
    MoeDistributedCombineValidateOut(out, expertScales, expandX);
}
/**
 * \brief Builds the MoE distributed combine graph from one send operation and one receive operation.
 *
 * After validating all arguments, two shared-memory tensors (data and signal) are created, the untiled
 * send operation is added, the slices of both shared-memory tensors at this rank's index are viewed, and
 * the untiled receive operation is added. Its result is assigned to \p out.
 *
 * \param expandX               Input forwarded to the send operation.
 * \param assistInfoForCombine  Input forwarded to the send operation.
 * \param recvCounts            Input forwarded to the send operation.
 * \param expertScales          Supplies batchSize (axis 0) and topK (axis 1); input of the receive operation.
 * \param group                 Communication group name.
 * \param epWorldSize           Leading extent of both shared-memory tensors.
 * \param moeExpertNum          Only used for validation.
 * \param sharedExpertNum       Only forwarded to validation.
 * \param sharedExpertRankNum   Only forwarded to validation.
 * \param out                   Overwritten with the output of the receive operation.
 */
void MoeDistributedCombine(
const Tensor& expandX, const Tensor& assistInfoForCombine, const Tensor& recvCounts, const Tensor& expertScales,
const char* group, uint32_t epWorldSize, uint32_t moeExpertNum, uint32_t sharedExpertNum,
uint32_t sharedExpertRankNum, Tensor& out)
{
MoeDistributedCombineValidate(
expandX, assistInfoForCombine, recvCounts, expertScales, group, epWorldSize, moeExpertNum, sharedExpertNum,
sharedExpertRankNum, out);
int32_t batchSize = expertScales.GetShape(0);
int32_t topK = expertScales.GetShape(1);
int32_t hiddenSize = expandX.GetShape(1);
// Per-rank data slice: topK rows for every token of the batch.
int32_t shmemDataRow = topK * batchSize;
Shape shmemDataShape = {1, shmemDataRow, hiddenSize};
// NOTE(review): the column count is derived from the size of DT_FP32 although the signal tensor below is
// created as DT_INT32 - presumably both element sizes are equal; confirm.
int32_t shmemSignalCol = SAME_ADDR_BYTE_SIZE / BytesOf(DataType::DT_FP32);
Shape shmemSignalShape = {batchSize, shmemSignalCol};
Tensor shmemData;
Tensor shmemSignal;
int32_t hcclGroupIndex = static_cast<int>(CommGroupRecorder::GetInstance().Input(std::string(group)));
// First loop (single iteration): create both shared-memory tensors. CreateShmemTensor prepends
// epWorldSize as the leading axis of the given shape.
LOOP("CreateShmemTensor", FunctionType::DYNAMIC_LOOP, index, LoopRange(1))
{
(void)index;
CreateShmemTensor(shmemData, epWorldSize, hcclGroupIndex, expandX.GetDataType(), shmemDataShape);
CreateShmemTensor(shmemSignal, epWorldSize, hcclGroupIndex, DT_INT32, shmemSignalShape);
}
// Second loop (single iteration): send followed by receive.
LOOP("MoeDistributedCombine", FunctionType::DYNAMIC_LOOP, index, LoopRange(1))
{
(void)index;
int32_t expandXRow = expandX.GetShape(0);
// NOTE(review): 40 is hard-coded - presumably the number of vector cores the rows are spread over; confirm.
int32_t aivNum = 40;
// NOTE(review): the triples look like {tile size, tile count, tail size} per axis - confirm against SetDistTile.
TileShape::Current().SetDistTile(
{expandXRow / aivNum, aivNum, expandXRow % aivNum}, {hiddenSize, 1, 0}, {0, 0, 0});
auto sendOut =
MoeDistributedCombineSend(expandX, assistInfoForCombine, recvCounts, shmemData, shmemSignal, topK);
// Select the slices at this rank's index along the leading axis.
SymbolicScalar thisRank = GetHcclRankId(group);
auto shmemDataThisRank =
View(shmemData, {1, 1, shmemDataRow, hiddenSize}, std::vector<SymbolicScalar>{thisRank, 0, 0, 0});
auto shmemSignalThisRank =
View(shmemSignal, {1, batchSize, shmemSignalCol}, std::vector<SymbolicScalar>{thisRank, 0, 0});
// The tile configuration is replaced before the receive operation is added.
TileShape::Current().SetDistTile({1, batchSize, 0}, {hiddenSize, 1, 0}, {0, 0, 0});
// The output of the send operation is the first input of the receive operation.
out = MoeDistributedCombineReceive(sendOut, expertScales, shmemDataThisRank, shmemSignalThisRank);
}
}
/**
 * \brief Builds the MoE distributed combine graph from shared-memory put/signal/wait/get primitives.
 *
 * A send loop iterates over the row count read from \p recvCounts. For each row it puts that row of
 * \p expandX into shared memory on the rank named in \p assistInfoForCombine and then signals that rank.
 * A receive loop iterates over batchSize tokens. For each token it waits on the signal, fetches topK rows,
 * multiplies them by the token's row of \p expertScales via a matmul, and assembles the BF16 result into
 * \p out.
 *
 * \param expandX               Rows to send; axis 1 is the hidden size.
 * \param assistInfoForCombine  Per row: column 0 is read as rankId, column 1 as tokenId, column 2 as kOffset.
 * \param recvCounts            Element 0 is the trip count of the send loop.
 * \param expertScales          Supplies batchSize (axis 0) and topK (axis 1); left operand of the matmul.
 * \param group                 Communication group name.
 * \param epWorldSize           Forwarded to CreateShmemTensor.
 * \param moeExpertNum          Only used for validation.
 * \param sharedExpertNum       Only forwarded to validation.
 * \param sharedExpertRankNum   Only forwarded to validation.
 * \param out                   Target of Assemble, one row per token.
 */
void MoeDistributedCombineV2(
const Tensor& expandX, const Tensor& assistInfoForCombine, const Tensor& recvCounts, const Tensor& expertScales,
const char* group, uint32_t epWorldSize, uint32_t moeExpertNum, uint32_t sharedExpertNum,
uint32_t sharedExpertRankNum, Tensor& out)
{
MoeDistributedCombineValidate(
expandX, assistInfoForCombine, recvCounts, expertScales, group, epWorldSize, moeExpertNum, sharedExpertNum,
sharedExpertRankNum, out);
int32_t batchSize = expertScales.GetShape(0);
int32_t topK = expertScales.GetShape(1);
int32_t hiddenSize = expandX.GetShape(1);
// NOTE(review): this CreateShmemTensor overload (group name first, value returned) is not the one defined
// in this file - presumably declared in an included header; confirm.
auto shmemTensor = CreateShmemTensor(group, epWorldSize, expandX.GetDataType(), {batchSize * topK, hiddenSize});
SymbolicScalar recvCountsScalar = GetTensorData(recvCounts, {0});
std::set<int> unrollList = {64, 32, 16, 8, 4, 2, 1};
// Send phase: one iteration per row, trip count taken from recvCounts[0].
LOOP("MoeDistributedCombineSend", FunctionType::DYNAMIC_LOOP, rowIndex, LoopRange(recvCountsScalar), unrollList)
{
SymbolicScalar rankId = GetTensorData(assistInfoForCombine, {rowIndex, 0});
SymbolicScalar tokenId = GetTensorData(assistInfoForCombine, {rowIndex, 1});
SymbolicScalar kOffset = GetTensorData(assistInfoForCombine, {rowIndex, 2});
Tensor expandXTile = View(expandX, {1, hiddenSize}, {rowIndex, 0});
// Destination row: topK consecutive rows per token, selected by kOffset.
auto shmemDataTile = ShmemView(shmemTensor, {1, hiddenSize}, {topK * tokenId + kOffset, 0});
TileShape::Current().SetVecTile({1, hiddenSize});
Tensor predToken(DT_INT32, {1, 1}, "sendPredToken");
Tensor shmemPutOut = ShmemPut(expandXTile, shmemDataTile, rankId, AtomicType::SET, predToken);
// The signal view is indexed by tokenId; the put result is passed on to ShmemSignal.
auto shmemSignalTile = ShmemView(shmemTensor, {1, hiddenSize}, {tokenId, 0});
// NOTE(review): rankId is passed twice - confirm the meaning of both arguments against ShmemSignal.
ShmemSignal(shmemSignalTile, rankId, rankId, 1, AtomicType::ADD, shmemPutOut);
}
SymbolicScalar thisRank = GetHcclRankId(group);
// Receive phase: one iteration per token.
LOOP("MoeDistributedCombineReceive", FunctionType::DYNAMIC_LOOP, tokenId, LoopRange(batchSize))
{
auto shmemSignalTile = ShmemView(shmemTensor, {1, hiddenSize}, {tokenId, 0});
TileShape::Current().SetVecTile({1, hiddenSize});
Tensor predToken(DT_INT32, {1, 1}, "receivePredToken");
// Wait with OpType::EQ against topK; the send loop adds 1 per put.
Tensor waitUntilOut = ShmemWaitUntil(shmemSignalTile, thisRank, OpType::EQ, topK, true, predToken);
TileShape::Current().SetVecTile({topK, hiddenSize});
auto shmemDataTile = ShmemView(shmemTensor, {topK, hiddenSize}, {topK * tokenId, 0});
// NOTE(review): the "Fp16" names below hold data of expandX's type, which validation restricts to DT_BF16.
Tensor shmemGetOutFp16 = ShmemGet(shmemDataTile, thisRank, waitUntilOut);
TileShape::Current().SetVecTile({topK / 2, hiddenSize});
Tensor shmemGetOutFp32 = npu::tile_fwk::Cast(shmemGetOutFp16, DT_FP32);
Tensor expertScalesTile = View(expertScales, {1, topK}, {tokenId, 0});
int64_t kTileShape = AlignUp(topK, 16);
// NOTE(review): 65536 is presumably the capacity in bytes of the buffer named in the variable - confirm.
int64_t l0bSize = 65536;
// Guards both divisors used on the next statement; the message only names kTileShape.
ASSERT(DistributedErrorCode::DIVISION_BY_ZERO, (BytesOf(DT_FP32) != 0) && (kTileShape != 0))
<< "Divisor kTileShape cannot be zero";
int64_t nTileShape = l0bSize / BytesOf(DT_FP32) / kTileShape;
TileShape::Current().SetCubeTile({1, 1}, {kTileShape, kTileShape}, {nTileShape, nTileShape});
// {1, topK} x {topK, hiddenSize}: the matmul applies the scales and sums over the topK rows.
Tensor matmulOutFp32 = Matrix::Matmul(DT_FP32, expertScalesTile, shmemGetOutFp32);
Tensor matmulOutFp16 = npu::tile_fwk::Cast(matmulOutFp32, DT_BF16);
Assemble(matmulOutFp16, {tokenId, 0}, out);
}
}
}