/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
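/*!
 * \file distributed_common.h
 * \brief Shared constants, operation attribute structs and string / tile-count helpers for distributed operations.
 */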
#ifndef DISTRIBUTED_COMMON_H
#define DISTRIBUTED_COMMON_H
#include <array>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iterator>
#include <optional>
#include <sstream>
#include <string>
#include <string_view>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
#include "interface/configs/config_manager.h"
#include "interface/function/function.h"
#include "tilefwk/tilefwk.h"
#include "interface/inner/tilefwk.h"
#include "interface/program/program.h"
#include "interface/operation/opcode.h"
#include "interface/operation/operation.h"
#include "interface/configs/config_manager.h"
#include "distributed_expand.h"
#include "tilefwk/comm_group_recorder.h"
#include "tilefwk/error_code.h"
namespace npu::tile_fwk {
namespace Distributed {
// All constants are `inline constexpr` so that every translation unit including this header refers to
// the same entity; a plain namespace-scope `constexpr` has internal linkage, which is fragile when the
// constant is used from the inline functions defined in this header.

// Values 0/1/2 match the enumerator order of TileIndex (HEAD_SHAPE, HEAD_NUM, TAIL_SHAPE).
// NOTE(review): presumably integer aliases for the same tile-triple positions - confirm at use sites.
inline constexpr int32_t DIST_HEAD_SHAPE = 0;
inline constexpr int32_t DIST_HEAD_COUNT = 1;
inline constexpr int32_t DIST_TAIL_SHAPE = 2;

// Generic named indices 0, 1 and 2.
inline constexpr int32_t DIST_INDEX_ZERO = 0;
inline constexpr int32_t DIST_INDEX_ONE = 1;
inline constexpr int32_t DIST_INDEX_TWO = 2;

// Byte sizes / alignments. NOTE(review): the exact hardware meaning is not visible in this header.
inline constexpr uint16_t COPY_BLOCK_BYTE_SIZE = 32;
inline constexpr uint16_t SAME_ADDR_BYTE_SIZE = 512;
inline constexpr uint64_t SHMEM_SIZE_ALIGN = 512;

// NOTE(review): identifier looks misspelled (presumably "EXPERT"); kept as-is for source compatibility.
inline constexpr int32_t ROUTED_EXPET_NUM = 160;
inline constexpr int32_t FFN_TILE_SIZE = 8;

// Minimum number of tile dimensions accepted by GetTotalTileNum(const VecTile&, const ShmemTensor&).
inline constexpr size_t MIN_TILE_SHAPE_DIM = 2;
inline constexpr int32_t AIV_NUM = 4;
inline constexpr int32_t RECEIVE_CNT_OUT_ROW = 1024;
inline constexpr int32_t RECEIVE_CNT_OUT_COL = 512;

// Default value of `signalStride` in ShmemSignalAttr and ShmemWaitUntilAttr.
inline constexpr int32_t SHMEM_SIGNAL_STRIDE = 8;
inline constexpr int32_t MAX_SHMEM_TILE_DIMS = 4;
inline constexpr int32_t MAX_GROUP_NAME_LENGTH = 128;

// MoE-related sizes. NOTE(review): these look like model-specific defaults - confirm with the MoE ops.
inline constexpr uint64_t MOE_INPUT_DIM = 2;
inline constexpr int32_t MOE_HIDDEN_SIZE = 5120;
inline constexpr int32_t MOE_BATCH_SIZE = 8;
inline constexpr int32_t MOE_TOPK = 8;
inline constexpr int32_t MOE_ASSIST_INFO_COL = 3;
/// Positions used to index a tile-description array (see GetTotalTileNum(const std::array<...>&)).
enum class TileIndex : size_t {
    HEAD_SHAPE = 0,
    HEAD_NUM = 1,
    TAIL_SHAPE = 2,
};
/// Selects the all-reduce variant. NOTE(review): the algorithms behind each value are not visible in this header.
enum class AllReduceType { ONE_SHOT = 0, TWO_SHOT = 1 };
/**
 * \brief Maps an AtomicType to its fully qualified textual name.
 * \param type Atomic mode to convert.
 * \return "TileOp::Distributed::AtomicType::SET" or "...::ADD"; an empty string for any other value.
 */
inline std::string AtomicTypeToString(AtomicType type)
{
    if (type == AtomicType::SET) {
        return "TileOp::Distributed::AtomicType::SET";
    }
    if (type == AtomicType::ADD) {
        return "TileOp::Distributed::AtomicType::ADD";
    }
    // Unknown values are reported as an empty string rather than as an error.
    return "";
}
/**
 * \brief Maps a comparison OpType to its textual name ("OpType::EQ", "OpType::NE", ...).
 * \param type Comparison kind to convert.
 * \return The matching name, or an empty string when \p type is none of EQ/NE/LT/LE/GT/GE.
 */
inline std::string OpTypeToString(OpType type)
{
    struct NameEntry {
        OpType kind;
        const char* text;
    };
    static constexpr NameEntry kNames[] = {
        {OpType::EQ, "OpType::EQ"}, {OpType::NE, "OpType::NE"}, {OpType::LT, "OpType::LT"},
        {OpType::LE, "OpType::LE"}, {OpType::GT, "OpType::GT"}, {OpType::GE, "OpType::GE"},
    };
    for (const NameEntry& entry : kNames) {
        if (entry.kind == type) {
            return entry.text;
        }
    }
    // Unknown values are reported as an empty string rather than as an error.
    return "";
}
namespace iterable_detail {
/// Type of std::begin applied to an rvalue of T; ill-formed (SFINAE-friendly) when no such call exists.
template <typename T>
using begin_result_t = decltype(std::begin(std::declval<T>()));
/// Type of std::end applied to an rvalue of T; ill-formed (SFINAE-friendly) when no such call exists.
template <typename T>
using end_result_t = decltype(std::end(std::declval<T>()));
} // namespace iterable_detail

/// Primary template: T is treated as non-iterable unless the specialization below matches.
template <typename T, typename = void>
struct is_iterable : std::false_type {};

/// Chosen when both std::begin and std::end are callable on an rvalue of T.
template <typename T>
struct is_iterable<T, std::void_t<iterable_detail::begin_result_t<T>, iterable_detail::end_result_t<T>>>
    : std::true_type {};

/// Convenience variable template for is_iterable<T>::value.
template <typename T>
inline constexpr bool is_iterable_v = is_iterable<T>::value;
template <typename T>
typename std::enable_if<!is_iterable_v<T>, std::string>::type ToString(T value)
{
if constexpr (std::is_same_v<T, std::string>) {
return value;
} else if constexpr (std::is_convertible_v<T, std::string>) {
return std::string(value);
} else if constexpr (std::is_integral_v<T>) {
return std::to_string(value);
} else if constexpr (std::is_same_v<T, AtomicType>) {
return AtomicTypeToString(value);
} else if constexpr (std::is_same_v<T, DataType>) {
return DataType2String(value);
} else if constexpr (std::is_same_v<T, Opcode>) {
return OpcodeManager::Inst().GetOpcodeStr(value);
} else if constexpr (std::is_same_v<T, OpType>) {
return OpTypeToString(value);
} else {
return "";
}
}
template <typename Container>
typename std::enable_if<is_iterable_v<Container>, std::string>::type ToString(const Container& c)
{
std::ostringstream oss;
oss << "[";
bool first = true;
for (const auto& item : c) {
if (!first) {
oss << ", ";
}
oss << ToString(item);
first = false;
}
oss << "]";
return oss.str();
}
/// Attribute bundle for a shmem "put" operation.
/// NOTE(review): field meanings are inferred from their names only - confirm against the op implementation.
struct ShmemPutAttr {
    std::string group;
    Shape copyBufferShape;
    AtomicType atomicType{AtomicType::SET}; // defaults to SET
    SymbolicScalar ownerRank;
};
/// Attribute bundle for a shmem "get" operation.
/// NOTE(review): field meanings are inferred from their names only - confirm against the op implementation.
/// Member order differs from ShmemPutAttr (group is last here); keep it, aggregate initialization depends on it.
struct ShmemGetAttr {
    Shape copyBufferShape;
    AtomicType atomicType{AtomicType::SET}; // defaults to SET
    SymbolicScalar ownerRank;
    std::string group;
};
struct ShmemSignalAttr {
int64_t signalValue = 1;
int32_t signalStride = SHMEM_SIGNAL_STRIDE;
std::vector<int64_t> tileShape;
AtomicType atomicType = AtomicType::SET;
bool notifyAll{false};
int64_t worldSize{0};
std::vector<int64_t> viewshapes;
int64_t viewTileNum{0};
int64_t totalTileNum{0};
SymbolicScalar ownerRank;
std::string group;
};
struct ShmemWaitUntilAttr {
int32_t expectedSum = 0;
int32_t signalStride = SHMEM_SIGNAL_STRIDE;
bool resetSignal = false;
std::vector<int64_t> tileShape;
std::vector<int64_t> viewshapes;
std::vector<int64_t> viewTileStrides;
std::vector<int64_t> viewIndexStrides;
int64_t viewTileNum{0};
int64_t totalTileNum{0};
SymbolicScalar ownerRank;
std::string group;
};
/// Attribute bundle for a shmem "set" operation.
/// NOTE(review): field meanings are inferred from their names only - confirm against the op implementation.
struct ShmemSetAttr {
std::string group;
bool isSetData{true}; // defaults to true
Shape setBufferShape;
SymbolicScalar ownerRank;
};
/// Attribute bundle for the MoE dispatch operation.
/// NOTE(review): field meanings are inferred from their names only - confirm against the op implementation.
struct MoeDispatchAttr {
    std::string extraTemplateParam{}; // empty by default
    int64_t topK{0};
    SymbolicScalar ownerRank;
};
/// Attribute bundle for the MoE combine operation.
/// NOTE(review): field meanings are inferred from their names only - confirm against the op implementation.
struct MoeCombineAttr {
    int64_t setType{0};
    int64_t topK{0};
    int64_t paddedColShape{0};
    int64_t rowOffset{-1}; // defaults to -1; NOTE(review): presumably a "not set" marker - confirm
    int64_t rowShape{-1};  // defaults to -1; NOTE(review): presumably a "not set" marker - confirm
    SymbolicScalar ownerRank;
};
/**
 * \brief Counts the tiles described by a tile-description array.
 * \param tile Array indexed by TileIndex; only the HEAD_NUM and TAIL_SHAPE entries are read.
 * \return tile[HEAD_NUM], plus one when tile[TAIL_SHAPE] is non-zero.
 */
inline int GetTotalTileNum(const std::array<int, MAX_DIST_DIM_SIZE>& tile)
{
    const int headTiles = tile[static_cast<size_t>(TileIndex::HEAD_NUM)];
    const bool hasTail = tile[static_cast<size_t>(TileIndex::TAIL_SHAPE)] != 0;
    return hasTail ? headTiles + 1 : headTiles;
}
/**
 * \brief Computes tile counts and stride tables for the signal tensor of \p src split by \p tileShape.
 *
 * Let `data` be `src.signal.GetShape()` and `raw` the raw shape of the first output operand of the
 * signal operation. The leading dimension of `data` is skipped; dimension `i` of \p tileShape is matched
 * with dimension `i + 1` of `data` / `raw`.
 *
 * \param tileShape Tile extents; needs at least MIN_TILE_SHAPE_DIM entries, each strictly positive.
 * \param src       Shmem tensor whose signal tensor / signal operation are inspected.
 * \return Tuple {totalTileNum, viewTileNum, viewshapes, viewTileStrides, viewIndexStrides} where
 *         - viewshapes[i]       = data[i + 1]
 *         - viewTileNum         = product over i of ceil(data[i + 1] / tileShape[i])
 *         - viewTileStrides[i]  = product of the per-dimension tile counts of dimensions 0 .. i-1
 *         - viewIndexStrides[i] = product of raw[d] / data[d] over the matched dimensions 0 .. i-1
 *         - totalTileNum        = viewTileNum * product over i of (raw[i + 1] / data[i + 1])
 *
 * Violations are reported through ASSERT with DistributedErrorCode::INVALID_TENSOR_DIM.
 * NOTE(review): this relies on ASSERT stopping execution on failure - confirm the macro's semantics.
 */
inline std::tuple<int64_t, int64_t, std::vector<int64_t>, std::vector<int64_t>, std::vector<int64_t>> GetTotalTileNum(
    const VecTile& tileShape, const ShmemTensor& src)
{
    // NOTE(review): src.signalOp is assumed to reference a live Operation with at least one output
    // operand; neither is verified here. The cast is kept as written because the declared type of
    // signalOp is not visible in this header.
    Shape rawShape = ((Operation*)src.signalOp)->GetOOperands()[0]->tensor->rawshape;
    Shape dataShape = src.signal.GetShape();
    ASSERT(DistributedErrorCode::INVALID_TENSOR_DIM, tileShape.size() >= MIN_TILE_SHAPE_DIM)
        << "Invalid dimensional: "
        << " tileShape dim must >= " << MIN_TILE_SHAPE_DIM << ", but got dimensional=" << tileShape.size();
    ASSERT(DistributedErrorCode::INVALID_TENSOR_DIM, dataShape.size() == (tileShape.size() + 1))
        << "Invalid dimensional: "
        << " shape parameter dim must = tileShape dim, but got shape parameter dim=" << (dataShape.size() - 1)
        << ", tileShape dim=" << tileShape.size();
    // rawShape is indexed with the same indices as dataShape below, so it must be at least as long.
    ASSERT(DistributedErrorCode::INVALID_TENSOR_DIM, rawShape.size() >= dataShape.size())
        << "Invalid dimensional: "
        << " signal raw shape dim must >= shape parameter dim, but got raw shape dim=" << rawShape.size()
        << ", shape parameter dim=" << dataShape.size();

    const size_t vecTileDim = tileShape.size();
    const size_t startDim = dataShape.size() - vecTileDim;
    for (size_t i = 0; i < vecTileDim; ++i) {
        const size_t curDim = startDim + i;
        const int64_t tileExtent = tileShape[i];
        // Both values are used as divisors further down; reject them before dividing.
        ASSERT(DistributedErrorCode::INVALID_TENSOR_DIM, tileExtent > 0)
            << "tileShape[" << i << "]=" << tileExtent << " must be greater than 0";
        ASSERT(DistributedErrorCode::INVALID_TENSOR_DIM, dataShape[curDim] != 0)
            << "shape parameter[" << i << "] must not be 0";
        ASSERT(DistributedErrorCode::INVALID_TENSOR_DIM, rawShape[curDim] % dataShape[curDim] == 0)
            << "signal of shmem tensor shape[" << i << "]=" << rawShape[curDim]
            << " must be divisible by shape parameter[" << i << "]=" << dataShape[curDim];
    }

    std::vector<int64_t> viewshapes(vecTileDim);
    std::vector<int64_t> dimTileNums(vecTileDim);
    std::vector<int64_t> viewTileStrides(vecTileDim);
    std::vector<int64_t> viewIndexStrides(vecTileDim);
    int64_t viewTileNum = 1;
    int64_t crossViewNum = 1;
    // vecTileDim >= MIN_TILE_SHAPE_DIM was asserted above, so element 0 exists.
    viewTileStrides[0] = 1;
    viewIndexStrides[0] = 1;
    for (size_t i = 0; i < vecTileDim; ++i) {
        const size_t curDim = startDim + i;
        const int64_t totalShape = dataShape[curDim];
        const int64_t tileShapeVal = tileShape[i];
        viewshapes[i] = totalShape;
        // Ceiling division: a partial trailing tile counts as one tile.
        dimTileNums[i] = totalShape / tileShapeVal + (totalShape % tileShapeVal == 0 ? 0 : 1);
        viewTileNum *= dimTileNums[i];
        crossViewNum *= (rawShape[curDim] / dataShape[curDim]);
        if (i > 0) {
            viewTileStrides[i] = viewTileStrides[i - 1] * dimTileNums[i - 1];
            viewIndexStrides[i] = viewIndexStrides[i - 1] * (rawShape[curDim - 1] / dataShape[curDim - 1]);
        }
    }
    const int64_t totalTileNum = viewTileNum * crossViewNum;
    // Move the vectors into the result instead of copying them.
    return {totalTileNum, viewTileNum, std::move(viewshapes), std::move(viewTileStrides),
            std::move(viewIndexStrides)};
}
}
}
#endif