* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file moe_combine.h
* \brief
*/
#ifndef __DISTRIBUTED_COMBINE__
#define __DISTRIBUTED_COMBINE__
#include "common.h"
#include <type_traits>
namespace TileOp::Distributed {
template <typename T, uint32_t topK, uint16_t rowShape, uint16_t colShape, uint16_t paddedColShape>
TILEOP void MoeDistributedCombineSend(
CoreFuncParam* param, __ubuf__ T* dataBuffer, __ubuf__ int32_t* assistInfoForCombineBuffer,
__ubuf__ int32_t* signalBuffer, __gm__ T* expandX, __gm__ int32_t* assistInfoForCombine,
__ubuf__ int32_t* recvCounts, __gm__ T* shmemDataBaseAddr, __gm__ int32_t* shmemSignalBaseAddr,
uint64_t expandXOffset0, uint64_t expandXOffset1, __gm__ int64_t* hcclContext)
{
signalBuffer[0] = 1;
uint64_t maxRow = ((expandXOffset0 + rowShape) < recvCounts[0]) ? (expandXOffset0 + rowShape) : recvCounts[0];
for (uint64_t row = expandXOffset0; row < maxRow; row++) {
set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0);
wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0);
TileOp::UBCopyIn<int32_t, 1, 3, 8, 3>(
assistInfoForCombineBuffer, assistInfoForCombine + MOE_COMBINE_INFO_NUM * row);
set_flag(PIPE_MTE2, PIPE_S, EVENT_ID0);
wait_flag(PIPE_MTE2, PIPE_S, EVENT_ID0);
int32_t rankId = assistInfoForCombineBuffer[0];
int32_t tokenId = assistInfoForCombineBuffer[1];
int32_t kOffset = assistInfoForCombineBuffer[2];
TileOp::UBCopyIn<T, 1, colShape, paddedColShape, colShape>(dataBuffer, expandX + colShape * row);
__gm__ T* winDataAddr =
MapVirtualAddr<T>(hcclContext, shmemDataBaseAddr, rankId) + colShape * (topK * tokenId + kOffset);
set_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID0);
wait_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID0);
TileOp::UBCopyOut<T, 1, colShape, colShape, paddedColShape>(winDataAddr, dataBuffer);
__gm__ int32_t* winSignalAddr =
MapVirtualAddr<int32_t>(hcclContext, shmemSignalBaseAddr, rankId) + MOE_COMBINE_SIGNAL_OFFSET * tokenId;
set_atomic_add();
set_atomic_s32();
pipe_barrier(PIPE_MTE3);
copy_ubuf_to_gm(winSignalAddr, signalBuffer, 0, 1, 1, 0, 0);
set_atomic_none();
}
}
TILEOP void MoeDistributedCombineWaitSignal(
__gm__ int32_t* winSignalAddr, __ubuf__ int32_t* signalBuffer, int32_t expectedValue)
{
set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0);
wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0);
do {
copy_gm_to_ubuf(signalBuffer, winSignalAddr, 0, 1, 1, 0, 0);
set_flag(PIPE_MTE2, PIPE_S, EVENT_ID0);
wait_flag(PIPE_MTE2, PIPE_S, EVENT_ID0);
} while (signalBuffer[0] != expectedValue);
}
template <typename T, uint32_t topK, uint16_t colShape, uint16_t paddedColShape>
TILEOP void MoeDistributedCombineCompute(
__ubuf__ T* out, __ubuf__ float* mulFp32Buffer, __ubuf__ float* sumFp32Buffer, __ubuf__ float* expertScales,
__gm__ T* winDataAddr)
{
uint8_t repeat = static_cast<uint8_t>(
AlignUp<uint16_t>(sizeof(float) * paddedColShape, VECTOR_INSTRUCTION_BYTE_SIZE) / VECTOR_INSTRUCTION_BYTE_SIZE);
vector_dup(sumFp32Buffer, 0.0f, repeat, 1, 1, 8, 8);
for (int kOffset = 0; kOffset < topK; kOffset++) {
set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0);
wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0);
TileOp::UBCopyIn<T, 1, colShape, paddedColShape, colShape>(out, winDataAddr);
set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
vconv_bf162f32(mulFp32Buffer, out, repeat, 1, 1, 8, 4);
pipe_barrier(PIPE_V);
vmuls(mulFp32Buffer, mulFp32Buffer, static_cast<float>(expertScales[kOffset]), repeat, 1, 1, 8, 8);
pipe_barrier(PIPE_V);
vadd(sumFp32Buffer, mulFp32Buffer, sumFp32Buffer, repeat, 1, 1, 1, 8, 8, 8);
winDataAddr += colShape;
}
pipe_barrier(PIPE_V);
vconv_f322bf16a(out, sumFp32Buffer, repeat, 1, 1, 4, 8);
}
template <typename T, uint32_t topK, uint16_t rowShape, uint16_t colShape, uint16_t paddedColShape>
TILEOP void MoeDistributedCombineReceive(
CoreFuncParam* param, __gm__ T* out, __ubuf__ float* mulFp32Buffer, __ubuf__ float* sumFp32Buffer,
__ubuf__ T* outBuffer, __ubuf__ float* expertScales, __gm__ T* shmemDataBaseAddr,
__gm__ int32_t* shmemSignalBaseAddr, uint64_t shmemDataOffset0, uint64_t shmemDataOffset1,
uint64_t shmemDataOffset2, uint64_t shmemDataOffset3, int64_t rowOffset, __gm__ int64_t* hcclContext)
{
uint64_t thisRankId = shmemDataOffset0;
for (uint64_t tokenId = rowOffset; tokenId < rowOffset + rowShape; tokenId++) {
__gm__ int32_t* winSignalAddr =
MapVirtualAddr<int32_t>(hcclContext, shmemSignalBaseAddr, thisRankId) + MOE_COMBINE_SIGNAL_OFFSET * tokenId;
__ubuf__ int32_t* signalBuffer = reinterpret_cast<__ubuf__ int32_t*>(outBuffer);
MoeDistributedCombineWaitSignal(winSignalAddr, signalBuffer, topK);
constexpr uint32_t expertScalesColShape =
AlignUp<uint32_t>(sizeof(float) * topK, COPY_BLOCK_BYTE_SIZE) / sizeof(float);
__gm__ T* winDataAddr =
MapVirtualAddr<T>(hcclContext, shmemDataBaseAddr, thisRankId) + colShape * topK * tokenId;
MoeDistributedCombineCompute<T, topK, colShape, paddedColShape>(
outBuffer, mulFp32Buffer, sumFp32Buffer, expertScales + expertScalesColShape * tokenId, winDataAddr);
set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
TileOp::UBCopyOut<T, 1, colShape, colShape, paddedColShape>(out + colShape * tokenId, outBuffer);
}
}
}
#endif