/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file moe_dispatch.cpp
 * \brief
 */
#include <functional>
#include <memory>
#include <vector>
#include "interface/operation/operation.h"
#include "tilefwk/tensor.h"
#include "interface/tensor/logical_tensor.h"
#include "tilefwk/tilefwk.h"
#include "interface/inner/tilefwk.h"
#include "interface/program/program.h"
#include "interface/utils/common.h"
#include "distributed_common.h"
#include "tilefwk/symbolic_distributed.h"
#include "interface/function/function.h"

namespace npu::tile_fwk {
namespace Distributed {
void MoeDistributedDispatchValidateGroup(const char* group)
{
    ASSERT(DistributedErrorCode::NULLPTR, group != nullptr) << "\"group\" cannot be nullptr";
    int32_t groupLen = std::strlen(group);
    int32_t maxGroupLen = MAX_GROUP_NAME_LENGTH;
    ASSERT(DistributedErrorCode::INVALID_GROUP_NAME, (groupLen >= 1) && (groupLen < maxGroupLen))
        << "The length of \"group\" only supports [1, " << maxGroupLen << "), but got " << groupLen;
}

void MoeDistributedDispatchValidTensor(
    const Tensor& input, uint64_t dim, DataType dType, int32_t row, int32_t col)
{
    ASSERT(DistributedErrorCode::INVALID_TENSOR_FORMAT, input.Format() == TileOpFormat::TILEOP_ND)
        << "Distributed constraint violated: " + input.GetName() + " format must be TILEOP_ND.";
    ASSERT(DistributedErrorCode::INVALID_TENSOR_DIM, input.Dim() == dim)
        << "Distributed constraint violated: " + input.GetName() + " dim must be " + std::to_string(dim)
        << "but got " << std::to_string(input.Dim()) ;
    ASSERT(DistributedErrorCode::INVALID_TENSOR_DTYPE, input.GetDataType() == dType)
        << "Distributed constraint violated: " + input.GetName() + " dataType must be " << DataType2String(dType)
        << "but got " << DataType2String(input.GetDataType());
    ASSERT(DistributedErrorCode::INVALID_TENSOR_SHAPE, input.GetShape(0) == row)
        << "Distributed constraint violated: " + input.GetName() + " row must be " + std::to_string(row)
        << "but got " << std::to_string(input.GetShape(0));
    if (input.Dim() == MOE_INPUT_DIM) {
        ASSERT(DistributedErrorCode::INVALID_TENSOR_SHAPE, input.GetShape(1) == col)
            << "Distributed constraint violated: " + input.GetName() + " col must be " + std::to_string(col)
            << "but got " << std::to_string(input.GetShape(1));
    }
}

void MoeDistributedDispatchValidConfig(const MoeConfig& moeConfig)
{
    int32_t supportedEpWorldSize1 = 4;
    int32_t supportedEpWorldSize2 = 8;
    ASSERT(
        DistributedErrorCode::INVALID_WORLD_SIZE,
        (moeConfig.rankNum == supportedEpWorldSize1) || (moeConfig.rankNum == supportedEpWorldSize2))
        << "epWorldSize only "
        << "supports " << std::to_string(supportedEpWorldSize1) << " or " << std::to_string(supportedEpWorldSize2)
        << ", but got " << std::to_string(moeConfig.rankNum);
    ASSERT(DistributedErrorCode::INVALID_MOE_EXPERT_NUM, moeConfig.routedExpertNum == ROUTED_EXPET_NUM)
    "Distributed constraint violated: moeConfig routedExpertNum must be " + std::to_string(ROUTED_EXPET_NUM)
        << "but got " << std::to_string(moeConfig.routedExpertNum);
}

void TiledDispatchFFNSched(
    Function& function, const TileShape& tileShape, const std::vector<std::shared_ptr<LogicalTensor>>& iOperand,
    const std::vector<std::shared_ptr<LogicalTensor>>& oOperand, const Operation& op)
{
    auto syncTensor = iOperand[DIST_INDEX_ZERO];
    auto shmemFlag = iOperand[DIST_INDEX_ONE];
    auto recvTokenCntOut = oOperand[DIST_INDEX_ZERO];
    int flagColSize = shmemFlag->GetShape()[3];
    std::string hcclGroupIndex;
    std::vector<int64_t> bufferShape;
    int32_t sharedExpertNum = 0;
    int64_t expertNumPerRank;
    op.GetAttr("hcclGroupIndex", hcclGroupIndex);
    op.GetAttr("dispatchBufferSize", bufferShape);
    op.GetAttr("expertNumPerRank", expertNumPerRank);

    const auto& tileRank = tileShape.GetDistTileRank();
    int32_t totalTileNum = GetTotalTileNum(tileRank) * static_cast<int32_t>(expertNumPerRank);
    const int32_t tileRankShape = tileRank[DIST_HEAD_SHAPE];
    const int32_t tileRankCnt = tileRank[DIST_HEAD_COUNT] + (tileRank[DIST_TAIL_SHAPE] == 0 ? 0 : 1);
    const int32_t tailRankShape = tileRank[DIST_TAIL_SHAPE];
    int32_t tileIndex = 0;
    for (int expertIndex = 0; expertIndex < expertNumPerRank; ++expertIndex) {
        for (int rankIndex = 0; rankIndex < tileRankCnt; ++rankIndex) {
            int32_t rankShape = ((tileRank[2] != 0) && (rankIndex == tileRankCnt - 1) ? tailRankShape : tileRankShape);
            int32_t rankOffset = rankIndex * tileRankShape;
            auto bufferTensor = std::make_shared<LogicalTensor>(function, DataType::DT_INT32, bufferShape);
            auto shmemFlagTile =
                shmemFlag->View(function, {1, 1, rankShape, flagColSize}, {0, expertIndex, rankOffset, 0});
            auto& opr = function.AddOperation(
                Opcode::OP_FFN_SCHED, {syncTensor, shmemFlagTile}, {recvTokenCntOut, bufferTensor});
            std::string extraParam = std::to_string(tileIndex) + ", " + hcclGroupIndex + ", " +
                                     std::to_string(sharedExpertNum) + ", " + std::to_string(totalTileNum) + ", " +
                                     std::to_string(rankShape) + ", " + std::to_string(expertNumPerRank);
            MoeDispatchAttr distOpAttr;
            distOpAttr.extraTemplateParam = extraParam;
            opr.SetAttr(OpAttributeKey::distOpAttr, distOpAttr);
            tileIndex++;
        }
    }
}

void TiledDispatchFFNCombineInfo(
    Function& function, const TileShape& tileShape, const std::vector<std::shared_ptr<LogicalTensor>>& iOperand,
    const std::vector<std::shared_ptr<LogicalTensor>>& oOperand, const Operation& op)
{
    auto recvTokenCntOut = iOperand[DIST_INDEX_ZERO];
    auto shmemData = iOperand[DIST_INDEX_ONE];
    auto shmemFlag = iOperand[DIST_INDEX_TWO];
    auto combineInfo = oOperand[DIST_INDEX_ZERO];

    int32_t shmemDataLength = shmemData->GetShape()[3];
    Shape combineInfoBufferShape = {combineInfo->GetShape()[0] + 32};
    std::string hcclGroupIndex;
    std::vector<int64_t> bufferShape;
    std::string axisH;
    std::string batchSize;
    int32_t sharedExpertNum = 0;
    int64_t expertNumPerRank;
    op.GetAttr("expertNumPerRank", expertNumPerRank);
    op.GetAttr("hcclGroupIndex", hcclGroupIndex);
    op.GetAttr("dispatchBufferSize", bufferShape);
    op.GetAttr("hiddenSize", axisH);
    op.GetAttr("tokenBatchSize", batchSize);

    const auto& tileRank = tileShape.GetDistTileRank();
    int32_t totalTileNum = GetTotalTileNum(tileRank) * static_cast<int32_t>(expertNumPerRank);
    const int32_t tileRankShape = tileRank[DIST_HEAD_SHAPE];
    const int32_t tileRankCnt = tileRank[DIST_HEAD_COUNT] + (tileRank[DIST_TAIL_SHAPE] == 0 ? 0 : 1);
    const int32_t tailRankShape = tileRank[DIST_TAIL_SHAPE];

    int32_t tileIndex = 0;
    for (int expertIndex = 0; expertIndex < expertNumPerRank; ++expertIndex) {
        for (int rankIndex = 0; rankIndex < tileRankCnt; ++rankIndex) {
            int32_t rankShape = ((tileRank[2] != 0) && (rankIndex == tileRankCnt - 1) ? tailRankShape : tileRankShape);
            int32_t rankOffset = rankIndex * tileRankShape;
            auto bufferCombineInfo = std::make_shared<LogicalTensor>(function, DataType::DT_INT32, bufferShape);
            auto shmemDataTile =
                shmemData->View(function, {1, rankShape, 1, shmemDataLength}, {0, rankOffset, expertIndex, 0});
            auto& opr = function.AddOperation(
                Opcode::OP_FFN_COMBINEINFO, {shmemDataTile, shmemFlag, recvTokenCntOut},
                {combineInfo, bufferCombineInfo});
            std::string extraParam = std::to_string(tileIndex) + ", " + hcclGroupIndex + ", " +
                                     std::to_string(sharedExpertNum) + ", " + std::to_string(totalTileNum) + ", " +
                                     std::to_string(rankShape) + ", " + axisH + ", " + batchSize + ", " +
                                     std::to_string(combineInfo->GetShape()[0]);
            MoeDispatchAttr distOpAttr;
            distOpAttr.extraTemplateParam = extraParam;
            opr.SetAttr(OpAttributeKey::distOpAttr, distOpAttr);
            tileIndex++;
        }
    }
}

void TiledDispatchFFNBatching(
    Function& function, const TileShape& tileShape, const std::vector<std::shared_ptr<LogicalTensor>>& iOperand,
    const std::vector<std::shared_ptr<LogicalTensor>>& oOperand, const Operation& op)
{
    auto recvTokenCntOut = iOperand[DIST_INDEX_ZERO];
    auto shmemData = iOperand[DIST_INDEX_ONE];
    auto shmemFlag = iOperand[DIST_INDEX_TWO];
    auto expandX = oOperand[DIST_INDEX_ZERO];
    auto validCnt = oOperand[DIST_INDEX_ONE];

    int32_t shmemDataLength = shmemData->GetShape()[3];
    std::string groupIndex;
    std::vector<int64_t> bufferShape;
    std::string axisH;
    std::string batchSize;
    int32_t sharedExpertNum = 0;
    int64_t expertNumPerRank;
    op.GetAttr("expertNumPerRank", expertNumPerRank);
    op.GetAttr("hcclGroupIndex", groupIndex);
    op.GetAttr("dispatchBufferSize", bufferShape);
    op.GetAttr("hiddenSize", axisH);
    op.GetAttr("tokenBatchSize", batchSize);

    const auto& tileRank = tileShape.GetDistTileRank();
    int32_t totalTileNum = GetTotalTileNum(tileRank) * static_cast<int32_t>(expertNumPerRank);
    const int32_t tileRankShape = tileRank[DIST_HEAD_SHAPE];
    const int32_t tileRankCnt = tileRank[DIST_HEAD_COUNT] + (tileRank[DIST_TAIL_SHAPE] == 0 ? 0 : 1);
    const int32_t tailRankShape = tileRank[DIST_TAIL_SHAPE];

    int32_t tileIndex = 0;
    for (int expertIndex = 0; expertIndex < expertNumPerRank; ++expertIndex) {
        for (int rankIndex = 0; rankIndex < tileRankCnt; ++rankIndex) {
            int32_t rankShape = ((tileRank[2] != 0) && (rankIndex == tileRankCnt - 1) ? tailRankShape : tileRankShape);
            int32_t rankOffset = rankIndex * tileRankShape;
            auto bufferTensor = std::make_shared<LogicalTensor>(function, DataType::DT_INT32, bufferShape);
            auto shmemDataTile =
                shmemData->View(function, {1, rankShape, 1, shmemDataLength}, {0, rankOffset, expertIndex, 0});
            auto& opr = function.AddOperation(
                Opcode::OP_FFN_BATCHING, {shmemDataTile, shmemFlag, recvTokenCntOut},
                {expandX, validCnt, bufferTensor});
            std::string extraParam = std::to_string(tileIndex) + ", " + groupIndex + ", " +
                                     std::to_string(sharedExpertNum) + ", " + std::to_string(totalTileNum) + ", " +
                                     std::to_string(rankShape) + ", " + axisH + ", " + batchSize + ", " +
                                     std::to_string(expandX->GetShape()[0]);
            MoeDispatchAttr distOpAttr;
            distOpAttr.extraTemplateParam = extraParam;
            opr.SetAttr(OpAttributeKey::distOpAttr, distOpAttr);
            tileIndex++;
        }
    }
}

void TiledDispatchFFNValidCnt(
    Function& function, const TileShape& tileShape, const std::vector<std::shared_ptr<LogicalTensor>>& iOperand,
    const std::vector<std::shared_ptr<LogicalTensor>>& oOperand, const Operation& op)
{
    (void)op;
    auto recvTokenCntOut = iOperand[DIST_INDEX_ZERO];
    auto shmemFlag = iOperand[DIST_INDEX_ONE];
    auto validCnt = oOperand[DIST_INDEX_ZERO];

    int32_t flagColSize = shmemFlag->GetShape()[3];
    int32_t rankSize = shmemFlag->GetShape()[0];

    const auto& tileExpert = tileShape.GetDistTileRank();
    int32_t tileExpertShape = tileExpert[0];
    int32_t expertCount = tileExpert[1] + (tileExpert[2] == 0 ? 0 : 1);
    Shape bufferShape{shmemFlag->shape[0] * expertCount};

    for (int32_t expertIndex = 0; expertIndex < expertCount; ++expertIndex) {
        int32_t expertShape =
            ((tileExpert[2] != 0) && (expertIndex == expertCount - 1)) ? tileExpert[2] : tileExpert[0];
        int32_t expertOffset = expertIndex * tileExpertShape;
        auto validCntBuffer = std::make_shared<LogicalTensor>(function, DataType::DT_INT32, bufferShape);
        auto shmemFlagTile =
            shmemFlag->View(function, {1, expertShape, rankSize, flagColSize}, {0, expertOffset, 0, 0});
        auto& tileop = function.AddOperation(
            Opcode::OP_FFN_VALIDCNT, {recvTokenCntOut, shmemFlagTile}, {validCnt, validCntBuffer});
        std::string extraParam = std::to_string(expertShape);
        MoeDispatchAttr distOpAttr;
        distOpAttr.extraTemplateParam = extraParam;
        tileop.SetAttr(OpAttributeKey::distOpAttr, distOpAttr);
    }
}

Tensor DispatchFFNValidCnt(const Tensor& recvTokenCntOut, const Tensor& shmemFlag, const MoeConfig& moeConfig)
{
    auto& function = *Program::GetInstance().GetCurrentFunction();
    Shape validCntShape = {moeConfig.expertNumPerRank, 1};
    auto validCntPtr = std::make_shared<LogicalTensor>(function, DataType::DT_INT32, validCntShape);
    auto& oper = function.AddOperation(
        Opcode::OP_FFN_VALIDCNT, {recvTokenCntOut.GetStorage(), shmemFlag.GetStorage()}, {validCntPtr});
    (void)oper;
    return validCntPtr;
}

Tensor DispatchFFNCombineInfo(
    const char* group, const Tensor& tokenTensor, const Tensor& recvTokenCntOut, const Tensor& shmemData,
    const Tensor& shmemFlag, int32_t expandXRow, int32_t ffnTileNum, const MoeConfig& moeConfig)
{
    auto& function = *Program::GetInstance().GetCurrentFunction();
    Shape combineInfoShape = {expandXRow, 3};
    auto combineInfoPtr = std::make_shared<LogicalTensor>(function, DataType::DT_INT32, combineInfoShape);
    auto& oper = function.AddOperation(
        Opcode::OP_FFN_COMBINEINFO, {recvTokenCntOut.GetStorage(), shmemData.GetStorage(), shmemFlag.GetStorage()},
        {combineInfoPtr});
    int tempBufSize = AlignUp(moeConfig.expertNumPerRank * ffnTileNum * 32, 256) + 256 +
                      AlignUp(moeConfig.expertNumPerRank * ffnTileNum * 4, 32) +
                      512; // tempBufSize = recvTokenCnt数 + 存储mask + recvTokenCnt的int数 + 数据搬运
    std::string hcclGroupIndex = std::to_string(CommGroupRecorder::GetInstance().Input(std::string(group)));
    const std::vector<int64_t> bufferShape{tempBufSize};
    oper.SetAttr("hcclGroupIndex", hcclGroupIndex);
    oper.SetAttr("dispatchBufferSize", bufferShape);
    oper.SetAttr("hiddenSize", std::to_string(tokenTensor.GetShape()[1]));
    oper.SetAttr("tokenBatchSize", std::to_string(tokenTensor.GetShape()[0]));
    oper.SetAttr("expertNumPerRank", static_cast<int64_t>(moeConfig.expertNumPerRank));
    return combineInfoPtr;
}

Tensor DispatchFFNBatching(
    const char* group, const Tensor& tokenTensor, const Tensor& recvTokenCntOut, const Tensor& shmemData,
    const Tensor& shmemFlag, int32_t expandXRow, int32_t ffnTileNum, const MoeConfig& moeConfig)
{
    auto& function = *Program::GetInstance().GetCurrentFunction();
    Shape validCntShape = {moeConfig.expertNumPerRank, 1};
    auto validCntPtr = std::make_shared<LogicalTensor>(function, DataType::DT_INT32, validCntShape);
    Shape expandXShape = {expandXRow, tokenTensor.GetShape()[1]};
    auto expandXPtr = std::make_shared<LogicalTensor>(function, tokenTensor.GetDataType(), expandXShape);
    auto& oper = function.AddOperation(
        Opcode::OP_FFN_BATCHING, {recvTokenCntOut.GetStorage(), shmemData.GetStorage(), shmemFlag.GetStorage()},
        {expandXPtr, validCntPtr});
    int cumSumBuffer = AlignUp(moeConfig.expertNumPerRank * ffnTileNum * 32, 256) + 256 +
                       AlignUp(moeConfig.expertNumPerRank * ffnTileNum * 4, 32) +
                       512; // tempBufSize = recvTokenCnt数 + 存储mask + recvTokenCnt的int数 + 数据搬运
    int tokenCopyBuffer = tokenTensor.GetShape(1);
    int tempBufSize = (cumSumBuffer < tokenCopyBuffer) ? tokenCopyBuffer : cumSumBuffer;
    std::string hcclGroupIndex = std::to_string(CommGroupRecorder::GetInstance().Input(std::string(group)));
    const std::vector<int64_t> bufferShape{tempBufSize};
    oper.SetAttr("hcclGroupIndex", hcclGroupIndex);
    oper.SetAttr("dispatchBufferSize", bufferShape);
    oper.SetAttr("hiddenSize", std::to_string(tokenTensor.GetShape()[1]));
    oper.SetAttr("tokenBatchSize", std::to_string(tokenTensor.GetShape()[0]));
    oper.SetAttr("expertNumPerRank", static_cast<int64_t>(moeConfig.expertNumPerRank));
    return expandXPtr;
}

Tensor DispatchFFNSched(
    const char* group, const Tensor& flagDummy, Tensor& shmemFlag, const MoeConfig& moeConfig, int32_t ffnTileCnt)
{
    auto& function = *Program::GetInstance().GetCurrentFunction();
    int32_t totalTileNum = moeConfig.routedExpertNum * ffnTileCnt;
    Shape shape = {totalTileNum, 512}; // 每个flag_count预留512个int存储
    auto recvTokenCntOutPtr = std::make_shared<LogicalTensor>(function, DataType::DT_INT32, shape);
    auto& oper = function.AddOperation(
        Opcode::OP_FFN_SCHED, {flagDummy.GetStorage(), shmemFlag.GetStorage()}, {recvTokenCntOutPtr});
    int32_t moeOpProcessRankSize = ffnTileCnt;
    int32_t maxProcessRankSize = moeOpProcessRankSize;
    int tempBufSize = maxProcessRankSize * 32 + 256 + AlignUp(maxProcessRankSize * 4, 256);
    std::string hcclGroupIndex = std::to_string(CommGroupRecorder::GetInstance().Input(std::string(group)));
    oper.SetAttr("hcclGroupIndex", hcclGroupIndex);
    const std::vector<int64_t> bufferShape{tempBufSize / 8, 8};
    oper.SetAttr("dispatchBufferSize", bufferShape);
    oper.SetAttr("expertNumPerRank", static_cast<int64_t>(moeConfig.expertNumPerRank));
    return recvTokenCntOutPtr;
}

std::vector<int64_t> GetCommBufferSize(const std::shared_ptr<LogicalTensor>& tokenTensor)
{
    const int64_t hOutSize = tokenTensor->shape[1] * BytesOf(tokenTensor->Datatype());
    constexpr int64_t scaleParamPad = 512;
    const int64_t hCommuSize = AlignUp(hOutSize, 512) + scaleParamPad;
    return {1, static_cast<int64_t>(hCommuSize / BytesOf(tokenTensor->Datatype()))};
}

void TiledSendToRoutingExpert(
    Function& function, const TileShape& tileShape, const std::vector<std::shared_ptr<LogicalTensor>>& iOperand,
    const std::vector<std::shared_ptr<LogicalTensor>>& oOperand, const Operation& op)
{
    auto shmemData = iOperand[DIST_INDEX_ZERO];
    auto tokenTensor = iOperand[DIST_INDEX_ONE];
    auto expertTable = iOperand[DIST_INDEX_TWO];
    auto syncTensor = oOperand[DIST_INDEX_ZERO];
    std::string hcclGroupIndex;
    int64_t expertNumPerRank;
    op.GetAttr("expertNumPerRank", expertNumPerRank);
    op.GetAttr("hcclGroupIndex", hcclGroupIndex);
    CreateTileOp(
        tileShape, [&](int32_t tileIndex, int32_t rowOffset, int32_t colOffset, int32_t rowShape, int32_t colShape) {
            (void)tileIndex;
            auto expertTableTile = expertTable->View(function, {rowShape, colShape}, {rowOffset, colOffset});
            auto expertBufferUb = std::make_shared<LogicalTensor>(
                function, expertTable->Datatype(),
                std::vector<int64_t>{1, expertTable->shape[0] * expertTable->shape[1]});
            auto expertBuffer = std::make_shared<LogicalTensor>(
                function, expertTable->Datatype(),
                std::vector<int64_t>{
                    1, expertTable->shape[0] * expertTable->shape[1] * (static_cast<int64_t>(sizeof(int32_t)) + 1)});
            auto tokenBuffer =
                std::make_shared<LogicalTensor>(function, tokenTensor->Datatype(), GetCommBufferSize(tokenTensor));
            auto& tileop = function.AddOperation(
                Opcode::OP_SEND_TO_ROUTING_EXPERT, {tokenTensor, shmemData, expertTableTile},
                {syncTensor, tokenBuffer, expertBufferUb, expertBuffer});
            std::string extraParam = std::to_string(tokenTensor->shape[1]) + ", " + std::to_string(rowOffset) + ", " +
                                     std::to_string(colOffset) + ", " + std::to_string(rowShape) + ", " +
                                     std::to_string(colShape) + ", " + hcclGroupIndex;
            MoeDispatchAttr distOpAttr;
            distOpAttr.extraTemplateParam = extraParam;
            tileop.SetAttr(OpAttributeKey::distOpAttr, distOpAttr);
        });
}

void TiledSendToSharedExpert(
    Function& function, const TileShape& tileShape, const std::vector<std::shared_ptr<LogicalTensor>>& iOperand,
    const std::vector<std::shared_ptr<LogicalTensor>>& oOperand, const Operation& op)
{
    auto shmemData = iOperand[DIST_INDEX_ZERO];
    auto tokenTensor = iOperand[DIST_INDEX_ONE];
    auto syncTensor = oOperand[DIST_INDEX_ZERO];
    (void)oOperand;
    std::string hcclGroupIndex;
    op.GetAttr("hcclGroupIndex", hcclGroupIndex);
    CreateTileOp(
        tileShape, [&](int32_t tileIndex, int32_t rowOffset, int32_t colOffset, int32_t rowShape, int32_t colShape) {
            (void)tileIndex;
            Shape shape = {rowShape, colShape};
            auto tokenTensorTile = tokenTensor->View(function, {rowShape, colShape}, {rowOffset, colOffset});
            auto tokenBuffer =
                std::make_shared<LogicalTensor>(function, tokenTensor->Datatype(), GetCommBufferSize(tokenTensor));
            auto& tileop = function.AddOperation(
                Opcode::OP_SEND_TO_SHARED_EXPERT, {tokenTensorTile, shmemData}, {syncTensor, tokenBuffer});
            std::string extraParam = std::to_string(tokenTensor->shape[0]) + ", " +
                                     std::to_string(tokenTensor->shape[1]) + ", " + std::to_string(rowShape) + ", " +
                                     hcclGroupIndex;
            MoeDispatchAttr distOpAttr;
            distOpAttr.extraTemplateParam = extraParam;
            tileop.SetAttr(OpAttributeKey::distOpAttr, distOpAttr);
        });
}

void TiledCopyToLocalExpert(
    Function& function, const TileShape& tileShape, const std::vector<std::shared_ptr<LogicalTensor>>& iOperand,
    const std::vector<std::shared_ptr<LogicalTensor>>& oOperand, const Operation& op)
{
    auto tokenTensor = iOperand[DIST_INDEX_ZERO];
    auto expandX = oOperand[DIST_INDEX_ZERO];
    auto syncTensor = oOperand[DIST_INDEX_ONE];
    (void)op;
    CreateTileOp(
        tileShape, [&](int32_t tileIndex, int32_t rowOffset, int32_t colOffset, int32_t rowShape, int32_t colShape) {
            (void)tileIndex;
            auto tokenTensorTile = tokenTensor->View(function, {rowShape, colShape}, {rowOffset, colOffset});
            auto tokenBuffer =
                std::make_shared<LogicalTensor>(function, tokenTensor->Datatype(), GetCommBufferSize(tokenTensor));
            auto& tileop = function.AddOperation(
                Opcode::OP_COPY_TO_LOCAL_EXPERT, {tokenTensorTile}, {expandX, syncTensor, tokenBuffer});
            std::string extraParam = std::to_string(tokenTensor->shape[0]) + ", " +
                                     std::to_string(tokenTensor->shape[1]) + ", " + std::to_string(rowShape);
            MoeDispatchAttr distOpAttr;
            distOpAttr.extraTemplateParam = extraParam;
            tileop.SetAttr(OpAttributeKey::distOpAttr, distOpAttr);
        });
}

void TiledDispatchSetFlag(
    Function& function, const TileShape& tileShape, const std::vector<std::shared_ptr<LogicalTensor>>& iOperand,
    const std::vector<std::shared_ptr<LogicalTensor>>& oOperand, const Operation& op)
{
    auto shmemFlag = iOperand[DIST_INDEX_ZERO];
    auto syncTensor = iOperand[DIST_INDEX_ONE];
    auto tokenExpertTable = iOperand[DIST_INDEX_TWO];
    auto syncDummy = oOperand[DIST_INDEX_ZERO];
    int flagColSize = shmemFlag->GetShape()[3];
    std::string hcclGroupIndex;
    op.GetAttr("hcclGroupIndex", hcclGroupIndex);
    int64_t expertNumPerRank;
    op.GetAttr("expertNumPerRank", expertNumPerRank);

    const auto& tileExpert = tileShape.GetDistTileRank();
    const auto& tileRank = tileShape.GetDistTileCol();
    int32_t tileRankShape = tileRank[0];
    int32_t tileExpertShape = tileExpert[0];
    int32_t rankCount = tileRank[1] + (tileRank[2] == 0 ? 0 : 1);
    int32_t expertCount = tileExpert[1] + (tileExpert[2] == 0 ? 0 : 1);

    for (int32_t rankIndex = 0; rankIndex < rankCount; ++rankIndex) {
        int32_t rankShape = ((tileRank[2] != 0) && (rankIndex == rankCount - 1)) ? tileRank[2] : tileRank[0];
        for (int32_t expertIndex = 0; expertIndex < expertCount; ++expertIndex) {
            int32_t expertShape =
                ((tileExpert[2] != 0) && (expertIndex == expertCount - 1)) ? tileExpert[2] : tileExpert[0];
            int32_t rankOffset = rankIndex * tileRankShape;
            int32_t expertOffset = expertIndex * tileExpertShape;
            auto statusTensor = std::make_shared<LogicalTensor>(
                function, tokenExpertTable->Datatype(),
                std::vector<int64_t>{
                    1, expertNumPerRank * 16 + 32}); // 每个expert预留16B缓存flag跟count,最后一个expert后预留32位
            auto expertBufferUb = std::make_shared<LogicalTensor>(
                function, tokenExpertTable->Datatype(),
                std::vector<int64_t>{1, tokenExpertTable->shape[0] * tokenExpertTable->shape[1]});
            auto expertBuffer = std::make_shared<LogicalTensor>(
                function, tokenExpertTable->Datatype(),
                std::vector<int64_t>{
                    1, tokenExpertTable->shape[0] * tokenExpertTable->shape[1] *
                           (static_cast<int64_t>(sizeof(int32_t)) + 1)});
            auto shmemFlagTile =
                shmemFlag->View(function, {rankShape, expertShape, 1, flagColSize}, {rankOffset, expertOffset, 0, 0});
            auto& tileop = function.AddOperation(
                Opcode::OP_DISPATCH_SET_FLAG, {tokenExpertTable, shmemFlagTile, syncTensor},
                {syncDummy, statusTensor, expertBufferUb, expertBuffer});
            std::string extraParam = std::to_string(tokenExpertTable->shape[0]) + ", " +
                                     std::to_string(tokenExpertTable->shape[1]) + ", " + hcclGroupIndex + ", " +
                                     std::to_string(expertShape) + ", " + std::to_string(rankShape);
            MoeDispatchAttr distOpAttr;
            distOpAttr.extraTemplateParam = extraParam;
            tileop.SetAttr(OpAttributeKey::distOpAttr, distOpAttr);
        }
    }
}

Tensor SendToRoutingExpert(
    const Tensor& shmemData, const Tensor& tokenTensor, const Tensor& tokenExpertTable, const char* group,
    const MoeConfig& moeConfig)
{
    Shape shape{1, 1};
    auto& function = *Program::GetInstance().GetCurrentFunction();
    auto syncTensor = std::make_shared<LogicalTensor>(function, DataType::DT_INT32, shape);
    auto& oper = function.AddOperation(
        Opcode::OP_SEND_TO_ROUTING_EXPERT,
        {shmemData.GetStorage(), tokenTensor.GetStorage(), tokenExpertTable.GetStorage()}, {syncTensor});
    std::string hcclGroupIndex = std::to_string(CommGroupRecorder::GetInstance().Input(std::string(group)));
    oper.SetAttr("hcclGroupIndex", hcclGroupIndex);
    oper.SetAttr("expertNumPerRank", static_cast<int64_t>(moeConfig.expertNumPerRank));
    return syncTensor;
}

void SendToSharedExpert(const Tensor& shmemData, const Tensor& tokenTensor, const Tensor& syncTensor, const char* group)
{
    auto& function = *Program::GetInstance().GetCurrentFunction();
    auto& oper = function.AddOperation(
        Opcode::OP_SEND_TO_SHARED_EXPERT, {shmemData.GetStorage(), tokenTensor.GetStorage()},
        {syncTensor.GetStorage()});
    std::string hcclGroupIndex = std::to_string(CommGroupRecorder::GetInstance().Input(std::string(group)));
    oper.SetAttr("hcclGroupIndex", hcclGroupIndex);
}

Tensor DispatchSetFlag(
    Tensor& shmemFlag, const Tensor& tokenExpertTable, const Tensor& syncTensor, const char* group,
    const MoeConfig& moeConfig)
{
    Shape shape = {1, 1};
    auto& function = *Program::GetInstance().GetCurrentFunction();
    auto syncDummy = std::make_shared<LogicalTensor>(function, DataType::DT_INT32, shape);
    auto& oper = function.AddOperation(
        Opcode::OP_DISPATCH_SET_FLAG, {shmemFlag.GetStorage(), syncTensor.GetStorage(), tokenExpertTable.GetStorage()},
        {syncDummy});
    std::string hcclGroupIndex = std::to_string(CommGroupRecorder::GetInstance().Input(std::string(group)));
    oper.SetAttr("hcclGroupIndex", hcclGroupIndex);
    oper.SetAttr("expertNumPerRank", static_cast<int64_t>(moeConfig.expertNumPerRank));
    return syncDummy;
}

Tensor CopyToLocalExpert(const Tensor& tokenTensor, const Tensor& syncTensor, const MoeConfig& moeConfig)
{
    auto& function = *Program::GetInstance().GetCurrentFunction();
    Shape expandXShape = {tokenTensor.GetShape()[0] * moeConfig.routedExpertNum, tokenTensor.GetShape()[1]};
    auto expandXPtr = std::make_shared<LogicalTensor>(function, tokenTensor.GetDataType(), expandXShape);
    auto& oper = function.AddOperation(
        Opcode::OP_COPY_TO_LOCAL_EXPERT, {tokenTensor.GetStorage()}, {expandXPtr, syncTensor.GetStorage()});
    (void)oper;
    return expandXPtr;
}

std::tuple<int32_t, int32_t, int32_t> GetFFNTileParam(const MoeConfig& moeConfig)
{
    int32_t tileRankCnt = moeConfig.rankNum > FFN_TILE_SIZE ? FFN_TILE_SIZE : moeConfig.rankNum;
    int32_t tileNum = tileRankCnt == FFN_TILE_SIZE ? moeConfig.rankNum / FFN_TILE_SIZE : 1;
    int32_t tailNum = tileNum == 1 ? 0 : (moeConfig.rankNum % FFN_TILE_SIZE == 0 ? 0 : 1);
    return {tileRankCnt, tileNum, tailNum};
}

void MoeDispatchValidateV1(
    const Tensor& tokenTensor, const Tensor& tokenExpertTable, Tensor& expandX, Tensor& validCnt, Tensor& combineInfo,
    const char* group, const MoeConfig& moeConfig)
{
    MoeDistributedDispatchValidConfig(moeConfig);
    MoeDistributedDispatchValidateGroup(group);
    MoeDistributedDispatchValidTensor(tokenTensor, MOE_INPUT_DIM, DataType::DT_BF16, MOE_BATCH_SIZE, MOE_HIDDEN_SIZE);
    MoeDistributedDispatchValidTensor(tokenExpertTable, MOE_INPUT_DIM, DataType::DT_INT32, MOE_BATCH_SIZE, MOE_TOPK);
    int32_t expandXRow = std::min(
        static_cast<int32_t>(tokenTensor.GetShape(0)) * static_cast<int32_t>(tokenExpertTable.GetShape(1)) *
            moeConfig.rankNum, static_cast<int32_t>(tokenTensor.GetShape(0)) * moeConfig.routedExpertNum);
    MoeDistributedDispatchValidTensor(validCnt, 1, DataType::DT_INT32, moeConfig.expertNumPerRank, 1);
    MoeDistributedDispatchValidTensor(expandX, MOE_INPUT_DIM, DataType::DT_BF16, expandXRow, MOE_HIDDEN_SIZE);
    MoeDistributedDispatchValidTensor(combineInfo, MOE_INPUT_DIM, DataType::DT_INT32, expandXRow, MOE_ASSIST_INFO_COL);
}

void CreateShmemData(
    const char* group, int64_t worldSize, DataType dataType, const Shape& shape, Tensor& shmemTensor, uint64_t memType)
{
    auto& function = *Program::GetInstance().GetCurrentFunction();
    int32_t hcclGroupIndex = static_cast<int>(CommGroupRecorder::GetInstance().Input(std::string(group)));
    Shape shmemShape{worldSize};
    shmemShape.insert(shmemShape.end(), shape.begin(), shape.end());
    auto shmemTensorInner = std::make_shared<LogicalTensor>(function, dataType, shmemShape);
    shmemTensor = shmemTensorInner;
    Program::GetInstance().GetTensorSlotManager()->TensorWrite(shmemTensor, SlotProperty::SHMEM_TENSOR);
    auto& op = function.AddOperation(Opcode::OP_BIND_TENSOR, {}, {shmemTensorInner});
    op.SetAttribute(
        OpAttributeKey::bindTensor,
        BindTensor(
            hcclGroupIndex, memType,
            AlignUp(
                BytesOf(dataType) * std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>()), SHMEM_SIZE_ALIGN)));
}

void CreateShmemDispatchLoop(
    Tensor& shmemData, Tensor& shmemFlag, const char* group, const MoeConfig& moeConfig, int32_t shmemDataCol,
    int32_t flagCol, DataType tokenTendorDtype)
{
    LOOP("CreateShmemTensor", FunctionType::DYNAMIC_LOOP, index, LoopRange(1))
    {
        (void)index;
        Shape shmemDataShape = {moeConfig.rankNum, moeConfig.expertNumPerRank, shmemDataCol};
        Shape shmemFlagShape = {moeConfig.expertNumPerRank, moeConfig.rankNum, flagCol};
        CreateShmemData(group, moeConfig.rankNum, tokenTendorDtype, shmemDataShape, shmemData, 0);
        CreateShmemData(group, moeConfig.rankNum, DT_INT32, shmemFlagShape, shmemFlag, 0);
    }
}

void MoeDistributedDispatch(
    const Tensor& tokenTensor, const Tensor& tokenExpertTable, Tensor& expandX, Tensor& validCnt, Tensor& combineInfo,
    const char* group, const MoeConfig& moeConfig)
{
    MoeDispatchValidateV1(tokenTensor, tokenExpertTable, expandX, validCnt, combineInfo, group, moeConfig);
    SymbolicScalar thisRank = GetHcclRankId(group);
    int batchSize = tokenTensor.GetShape(0);
    int hiddenSize = tokenTensor.GetShape(1);
    int topK = tokenExpertTable.GetShape(1);
    int shmemDataLength = AlignUp(hiddenSize, 512) + 512;
    int flagRow = 1;
    int flagCol = 128;
    int shmemDataCol = shmemDataLength * batchSize;
    Tensor shmemData;
    Tensor shmemFlag;
    CreateShmemDispatchLoop(shmemData, shmemFlag, group, moeConfig, shmemDataCol, flagCol, tokenTensor.GetDataType());
    LOOP("L0", FunctionType::DYNAMIC_LOOP, index, LoopRange(1))
    {
        (void)index;
        TileShape::Current().SetDistTile({1, batchSize, 0}, {topK, 1, 0}, {moeConfig.rankNum, 1, 0});
        Tensor syncTensor = SendToRoutingExpert(shmemData, tokenTensor, tokenExpertTable, group, moeConfig);
        TileShape::Current().SetDistTile(
            {flagRow, 1, 0}, {moeConfig.rankNum, 1, 0}, {1, moeConfig.expertNumPerRank, 0});
        auto localShmemFlag =
            View(shmemFlag, {moeConfig.rankNum, moeConfig.expertNumPerRank, 1, flagCol}, {0, 0, thisRank, 0});
        Tensor flagDummy = DispatchSetFlag(localShmemFlag, tokenExpertTable, syncTensor, group, moeConfig);
        auto [ffnTileCnt, ffnTileNum, ffnTailNum] = GetFFNTileParam(moeConfig);
        TileShape::Current().SetDistTile({batchSize, 1, 0}, {hiddenSize, 1, 0}, {ffnTileCnt, ffnTileNum, ffnTailNum});
        auto shmemFlagSched =
            View(shmemFlag, {1, moeConfig.expertNumPerRank, moeConfig.rankNum, flagCol}, {thisRank, 0, 0, 0});
        auto recvTokenCntOut = DispatchFFNSched(group, flagDummy, shmemFlagSched, moeConfig, ffnTileCnt);
        auto shmemDataBatching =
            View(shmemData, {1, moeConfig.rankNum, moeConfig.expertNumPerRank, shmemDataLength}, {thisRank, 0, 0, 0});
        auto expandXPtr = DispatchFFNBatching(
            group, tokenTensor, recvTokenCntOut, shmemDataBatching, localShmemFlag, expandX.GetShape(0),
            ffnTileNum + ffnTailNum, moeConfig);
        auto combineInfoPtr = DispatchFFNCombineInfo(
            group, tokenTensor, recvTokenCntOut, shmemDataBatching, localShmemFlag, expandX.GetShape(0),
            ffnTileNum + ffnTailNum, moeConfig);
        TileShape::Current().SetDistTileRank({moeConfig.expertNumPerRank / 10, 10, 0});
        auto shmemFlagValidCnt =
            View(shmemFlag, {1, moeConfig.expertNumPerRank, moeConfig.rankNum, flagCol}, {thisRank, 0, 0, 0});
        auto validCntPtr = DispatchFFNValidCnt(recvTokenCntOut, shmemFlagValidCnt, moeConfig);
        expandX = expandXPtr;
        validCnt = validCntPtr;
        combineInfo = combineInfoPtr;
    }
}

void MoeDispatchValidateV2(
    const Tensor& x, const Tensor& expertIds, const char* group, uint32_t epWorldSize, uint32_t moeExpertNum,
    uint32_t sharedExpertNum, uint32_t sharedExpertRankNum, Tensor& expandX, Tensor& expertTokenNums,
    Tensor& assistInfoForCombine, Tensor& recvCounts)
{
    (void) sharedExpertRankNum;
    MoeDistributedDispatchValidateGroup(group);
    ASSERT(DistributedErrorCode::INVALID_WORLD_SIZE, epWorldSize > 0)
        << "MoeDispatch constraint violated: epWorldSize must be > 0, but got " << std::to_string(epWorldSize);
    ASSERT(DistributedErrorCode::INVALID_MOE_EXPERT_NUM, moeExpertNum == ROUTED_EXPET_NUM)
        << "MoeDispatch constraint violated: moeExpertNum must " << ROUTED_EXPET_NUM << ", but got "
        << std::to_string(moeExpertNum);
    int32_t routedExpertNum = moeExpertNum - sharedExpertNum;
    int32_t expertNumPerRank = routedExpertNum / epWorldSize;
    MoeDistributedDispatchValidTensor(x, MOE_INPUT_DIM, DataType::DT_BF16, MOE_BATCH_SIZE, MOE_HIDDEN_SIZE);
    MoeDistributedDispatchValidTensor(expertIds, MOE_INPUT_DIM, DataType::DT_INT32, MOE_BATCH_SIZE, MOE_TOPK);
    MoeDistributedDispatchValidTensor(expertTokenNums, 1, DataType::DT_INT32, expertNumPerRank, 1);
    MoeDistributedDispatchValidTensor(recvCounts, 1, DataType::DT_INT32, 1, 0);
    int batchSize = x.GetShape(0);
    int topK = expertIds.GetShape(1);
    int32_t expandXRow = std::min(
        static_cast<int32_t>(batchSize) * static_cast<int32_t>(topK) * static_cast<int32_t>(epWorldSize),
        static_cast<int32_t>(batchSize) * routedExpertNum);
    MoeDistributedDispatchValidTensor(expandX, MOE_INPUT_DIM, DataType::DT_BF16, expandXRow, MOE_HIDDEN_SIZE);
    MoeDistributedDispatchValidTensor(
        assistInfoForCombine, MOE_INPUT_DIM, DataType::DT_INT32, expandXRow, MOE_ASSIST_INFO_COL);
    uint64_t shmemSize =
        moeExpertNum * x.GetShape(0) * x.GetShape(1) * BytesOf(x.GetDataType()) +
        moeExpertNum * x.GetShape(0) * assistInfoForCombine.GetShape(1) * BytesOf(assistInfoForCombine.GetDataType()) +
        AlignUp(routedExpertNum, 256) * 8 * BytesOf(DataType::DT_INT32) +
        moeExpertNum * 128 * BytesOf(DataType::DT_INT32) + 128 * BytesOf(DataType::DT_INT32);
    const uint64_t winSize = 1024 * 1024 * 200;
    ASSERT(DistributedErrorCode::WIN_SIZE_EXCEED_LIMIT, shmemSize < winSize)
        << "Exceeds winSize limit. Masxmum allowed " << std::to_string(winSize)
        << ", got " << std::to_string(shmemSize);
}

Tensor Nop(const std::vector<Tensor>& inTensors)
{
    auto& function = *Program::GetInstance().GetCurrentFunction();
    auto out = std::make_shared<LogicalTensor>(function, DT_INT32, Shape{1, 1});
    LogicalTensors iOperands;
    for (const Tensor& inTensor : inTensors) {
        iOperands.emplace_back(inTensor.GetStorage());
    }
    function.AddOperation(Opcode::OP_NOP, iOperands, {out});
    return out;
}

Tensor DispatchCalcOccurrences(Tensor& expertIds, SymbolicScalar expertId, int32_t calcIndex)
{
    Tensor expertIdsDup = Full(expertId, DT_INT32, {1, expertIds.GetShape(1)});
    Tensor subResult = Sub(expertIdsDup, expertIds);
    Tensor subResultFp32 = Cast(subResult, DT_FP32, CAST_TRUNC);
    Tensor absSubResult = Abs(subResultFp32);
    Tensor subResultInt32 = Cast(absSubResult, DT_INT32, CAST_TRUNC);
    Tensor countOfEquals = Clip(subResultInt32, Element(DT_INT32, 0), Element(DT_INT32, 1));
    Tensor cumSumOffset = CumSum(countOfEquals, 1);
    Tensor cumSumOffsetInt32 = Cast(cumSumOffset, DT_INT32, CAST_TRUNC);
    Tensor expertOffsetResult = ScalarSubS(cumSumOffsetInt32, Element(DT_INT32, calcIndex));
    Tensor expertOffsetResultFp32 = Cast(expertOffsetResult, DT_FP32, CAST_TRUNC);
    Tensor expertOffsetAbsFp32 = Abs(expertOffsetResultFp32);
    Tensor expertOffset = Cast(expertOffsetAbsFp32, DT_INT32, CAST_TRUNC);
    return expertOffset;
}

void MoeDistributedDispatchV2(
    const Tensor& x, const Tensor& expertIds, const char* group, uint32_t epWorldSize, uint32_t moeExpertNum,
    uint32_t sharedExpertNum, uint32_t sharedExpertRankNum, Tensor& expandX, Tensor& assistInfoForCombine,
    Tensor& expertTokenNums, Tensor& recvCounts)
{
    MoeDispatchValidateV2(
        x, expertIds, group, epWorldSize, moeExpertNum, sharedExpertNum, sharedExpertRankNum, expandX, expertTokenNums,
        assistInfoForCombine, recvCounts);
    int32_t routedExpertNum = moeExpertNum - sharedExpertNum;
    ASSERT(DistributedErrorCode::INVALID_WORLD_SIZE, epWorldSize > 0)
        << "MoeDispatch constraint violated: epWorldSize must be > 0, but got " << std::to_string(epWorldSize);
    int32_t expertNumPerRank = routedExpertNum / epWorldSize;
    int32_t batchSize = x.GetShape(0);
    int32_t hiddenSize = x.GetShape(1);
    int32_t topK = expertIds.GetShape(1);
    ASSERT(DistributedErrorCode::INVALID_MOE_TOP_K, topK > 0)
        << "MoeDispatch constraint violated: topK must be > 0, but got " << std::to_string(topK);
    ASSERT(DistributedErrorCode::INVALID_EXPERT_NUM_PER_RANK, expertNumPerRank > 0)
        << "MoeDispatch constraint violated: expertNumPerRank must be > 0, but got "
        << std::to_string(expertNumPerRank);
    int32_t infoSize = AlignUp(assistInfoForCombine.GetShape(1), 8);
    int32_t countSize = 8;
    int32_t signalCol = 128;
    int32_t cumSumRowShape = AlignUp(routedExpertNum, 256);
    SymbolicScalar thisRank = GetHcclRankId(group);

    Shape shmemDataShape = {expertNumPerRank * epWorldSize * batchSize, hiddenSize};
    auto shmemData = CreateShmemTensor(group, epWorldSize, x.GetDataType(), shmemDataShape);

    Shape shmemInfoShape = {expertNumPerRank * epWorldSize * batchSize, infoSize};
    auto shmemInfo = CreateShmemTensor(group, epWorldSize, DT_INT32, shmemInfoShape);

    Shape shmemCountShape = {cumSumRowShape, countSize};
    auto shmemCount = CreateShmemTensor(group, epWorldSize, DT_INT32, shmemCountShape);

    Shape shmemCountSignalShape = {moeExpertNum, signalCol};
    auto shmemCountSignal = CreateShmemTensor(group, epWorldSize, DT_INT32, shmemCountSignalShape);

    Shape shmemDataSignalgShape = {1, signalCol};
    auto shmemDataSignal = CreateShmemTensor(group, epWorldSize, DT_INT32, shmemDataSignalgShape);

    TileShape::Current().SetVecTile({1, batchSize * topK});
    Tensor expertIdsVec = Reshape(expertIds, {1, batchSize * topK});
    Tensor offsetTable(DataType::DT_INT32, {batchSize, topK}, "offsetTable");
    LOOP("MoeDistributedDispatchPrepare", FunctionType::DYNAMIC_LOOP, i, LoopRange(1))
    {
        (void)i;
        for (int index = 0; index < batchSize * topK; ++index) {
            int32_t rowIndex = index / topK;
            int32_t colIndex = index % topK;
            SymbolicScalar remoteExpertId = GetTensorData(expertIds, {rowIndex, colIndex});
            Tensor tokenOffsetResult = DispatchCalcOccurrences(expertIdsVec, remoteExpertId, index);
            SymbolicScalar tokenOffset = GetTensorData(tokenOffsetResult, {0, index - 1});
            SetTensorData(tokenOffset, {rowIndex, colIndex}, offsetTable);
        }
    }

    LOOP("MoeDistributedDispatchSendData", FunctionType::DYNAMIC_LOOP, index, LoopRange(topK * batchSize))
    {
        Tensor moeInfo(DataType::DT_INT32, {1, infoSize}, "moeInfo");
        SymbolicScalar rowIndex = index / topK;
        SymbolicScalar colIndex = index % topK;
        Tensor tensorTile = View(x, {1, hiddenSize}, {rowIndex, 0});
        SetTensorData(thisRank, {0, 0}, moeInfo);
        SetTensorData(rowIndex, {0, 1}, moeInfo);
        SetTensorData(colIndex, {0, 2}, moeInfo);
        SymbolicScalar remoteExpertId = GetTensorData(expertIds, {rowIndex, colIndex});
        SymbolicScalar remoteExpertOffset = remoteExpertId % expertNumPerRank;
        SymbolicScalar remoteRankId = remoteExpertId / expertNumPerRank;
        SymbolicScalar tokenOffset = GetTensorData(offsetTable, {rowIndex, colIndex});
        auto shmemDataTile = ShmemView(
            shmemData, {1, hiddenSize},
            std::vector<SymbolicScalar>{(remoteExpertOffset * epWorldSize + thisRank) * batchSize + tokenOffset, 0});
        TileShape::Current().SetVecTile({1, hiddenSize});
        Tensor shmemDataPutOut = ShmemPut(tensorTile, shmemDataTile, remoteRankId, AtomicType::SET, offsetTable);
        auto shmemInfoTile = ShmemView(
            shmemInfo, {1, infoSize},
            std::vector<SymbolicScalar>{(remoteExpertOffset * epWorldSize + thisRank) * batchSize + tokenOffset, 0});
        TileShape::Current().SetVecTile({1, infoSize});
        Tensor shmemInfoPutOut = ShmemPut(moeInfo, shmemInfoTile, remoteRankId, AtomicType::SET, offsetTable);
        Tensor sendOut = Nop({shmemDataPutOut, shmemInfoPutOut});
        TileShape::Current().SetVecTile({1, signalCol});
        auto shmemDataSignalTile = ShmemView(shmemDataSignal, {1, signalCol}, {0, 0});
        ShmemSignalAll(shmemDataSignalTile, 0, 1, AtomicType::ADD, sendOut);
    }

    Tensor shmemCountOut(DT_INT32, {1, 1}, "shmemCountOut");
    LOOP("MoeDistributedDispatchSendCount", FunctionType::DYNAMIC_LOOP, expertId, LoopRange(moeExpertNum))
    {
        Tensor expertOffset = DispatchCalcOccurrences(expertIdsVec, expertId, batchSize * topK);
        TileShape::Current().SetVecTile({1, 1});
        SymbolicScalar remoteRankId = expertId / expertNumPerRank;
        SymbolicScalar remoteExpertOffset = expertId % expertNumPerRank;
        auto shmemCountTile = ShmemView(shmemCount, {1, 1}, {remoteExpertOffset * epWorldSize + thisRank + 1, 0});
        Tensor totalOffsetTile = View(expertOffset, {1, 1}, {0, batchSize * topK - 1});
        Tensor shmemPutOut = ShmemPut(totalOffsetTile, shmemCountTile, remoteRankId, AtomicType::SET, totalOffsetTile);
        TileShape::Current().SetVecTile({1, signalCol});
        auto shmemCountSignalTile = ShmemView(shmemCountSignal, {1, signalCol}, {0, 0});
        shmemCountOut = ShmemSignal(shmemCountSignalTile, 0, remoteRankId, 1, AtomicType::ADD, shmemPutOut);
    }

    Tensor cumSumResult(DT_INT32, {cumSumRowShape, countSize}, "cumSumResult");
    Tensor localExpertRecvCount(DT_INT32, {cumSumRowShape, countSize}, "localExpertRecvCount");
    LOOP("MoeDistributedDispatchCumSum", FunctionType::DYNAMIC_LOOP, i, LoopRange(1))
    {
        (void)i;
        TileShape::Current().SetVecTile({1, signalCol});
        auto shmemDataSignalLocalTile = ShmemView(shmemDataSignal, {1, signalCol}, {0, 0});
        Tensor waitUntilOut1 =
            ShmemWaitUntil(shmemDataSignalLocalTile, 0, OpType::EQ, batchSize * topK * epWorldSize, true, x);
        TileShape::Current().SetVecTile({1, signalCol});
        auto shmemCountSignalLocalTile = ShmemView(shmemCountSignal, {1, signalCol}, {0, 0});
        Tensor waitUntilOut =
            ShmemWaitUntil(shmemCountSignalLocalTile, 0, OpType::EQ, moeExpertNum, true, x);
        Tensor waitOut = Nop({waitUntilOut1, waitUntilOut});

        TileShape::Current().SetVecTile({cumSumRowShape, countSize});
        auto shmemReceiveCountTile = ShmemView(shmemCount, {cumSumRowShape, countSize}, {0, 0});
        localExpertRecvCount = ShmemGet(shmemReceiveCountTile, thisRank, waitOut);
        TileShape::Current().SetVecTile({cumSumRowShape, countSize});
        auto shmemCountTile = ShmemView(shmemCount, {cumSumRowShape, countSize}, {0, 0});
        Tensor shmemGetOut = ShmemGet(shmemCountTile, thisRank, waitOut);
        Tensor cumSumCurrent = CumSum(shmemGetOut, 0);
        cumSumResult = Cast(cumSumCurrent, DT_INT32, CAST_TRUNC);

        SymbolicScalar recvCountResult = GetTensorData(cumSumResult, {expertNumPerRank * epWorldSize, 0});
        SetTensorData(recvCountResult, {0}, recvCounts);

        for (int32_t expertId = 0; expertId < expertNumPerRank; ++expertId) {
            Tensor expertValidCnt = View(shmemGetOut, {epWorldSize, countSize}, {expertId * epWorldSize + 1, 0});
            Tensor expertValidCumSum = CumSum(expertValidCnt, 0);
            Tensor expertCumSumResult = Cast(expertValidCumSum, DT_INT32, CAST_TRUNC);
            SymbolicScalar recvValidResult = GetTensorData(expertCumSumResult, {epWorldSize - 1, 0});
            SetTensorData(recvValidResult, {expertId}, expertTokenNums);
        }
    }

    LOOP("MoeDistributedDispatchReceive", FunctionType::DYNAMIC_LOOP, index, LoopRange(expertNumPerRank * epWorldSize))
    {
        SymbolicScalar curCount = GetTensorData(localExpertRecvCount, {index + 1, 0});
        SymbolicScalar offset = GetTensorData(cumSumResult, {index, 0});
        auto curShmemDataTile = ShmemView(
            shmemData, {batchSize, hiddenSize}, std::vector<SymbolicScalar>{curCount, hiddenSize},
            {index * batchSize, 0});
        TileShape::Current().SetVecTile({batchSize, hiddenSize});
        Tensor localDataRecvCount = ShmemGet(curShmemDataTile, thisRank, cumSumResult);
        Assemble(localDataRecvCount, std::vector<SymbolicScalar>{offset, 0}, expandX);
        auto curShmemInfoTile = ShmemView(
            shmemInfo, {batchSize, assistInfoForCombine.GetShape(1)},
            std::vector<SymbolicScalar>{curCount, assistInfoForCombine.GetShape(1)}, {index * batchSize, 0});
        TileShape::Current().SetVecTile({batchSize, assistInfoForCombine.GetShape(1)});
        Tensor localInfoRecvCount = ShmemGet(curShmemInfoTile, thisRank, cumSumResult);
        Assemble(localInfoRecvCount, std::vector<SymbolicScalar>{offset, 0}, assistInfoForCombine);
    }
}
} // namespace Distributed
} // namespace npu::tile_fwk