* Copyright (c) Huawei Technologies Co., Ltd. 2026-2026. All rights reserved.
* MindIE is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
* \file stream_k.h
* \brief Common type aliases, enumerations and small arithmetic helpers of the load_balance namespace.
*/
#ifndef LOAD_BALANCE_COMMON_H
#define LOAD_BALANCE_COMMON_H
#include <array>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
namespace load_balance {
// Status code macros (plain integer literals). Being preprocessor macros they are
// not scoped to the load_balance namespace even though they are defined inside it.
#define LOAD_BALANCE_SUCCESS 0
#define LOAD_BALANCE_ERROR_INVALID_INPUT 1
// A pair holding two values of the same type T.
// NOTE(review): presumably used as (begin, end) bounds — confirm against callers.
template <class T> using Range = std::pair<T, T>;
// Fixed-size two-dimensional array: ROW rows, each a std::array of COL elements of type T.
template <class T, size_t ROW, size_t COL> using Table = std::array<std::array<T, COL>, ROW>;
// Type-erased callable taking two uint32_t arguments and returning a uint32_t.
// The meaning of the two arguments is not fixed by this header.
using CostFunc = std::function<uint32_t(uint32_t, uint32_t)>;
// SoC version identifier. Only the BUTT enumerator is declared in this header;
// it is also the default value of DeviceInfo::version.
enum class SocVersion : uint32_t { BUTT };

// Description of a device: core-count limits, a ratio value and the SoC version.
// Every member has a default initializer, so a default-constructed DeviceInfo
// is fully initialized.
// NOTE(review): "aic"/"aiv" presumably denote the two core kinds of the device and
// cvRatio the ratio between them — confirm against the code that fills this struct.
struct DeviceInfo {
    uint32_t aicCoreMaxNum{0U};
    uint32_t aicCoreMinNum{0U};
    uint32_t aivCoreMaxNum{0U};
    uint32_t aivCoreMinNum{0U};
    // Previously left without an initializer, which made `DeviceInfo d;` carry an
    // indeterminate value here; now zero-initialized like the other counters.
    uint32_t cvRatio{0U};
    SocVersion version{SocVersion::BUTT};
};
// Tuning parameters for load balancing. All members carry defaults, so a
// default-constructed instance is fully initialized.
// NOTE(review): the exact meaning of the "fa"/"fd" prefixes is not visible in this
// header — confirm against the code consuming these fields.
struct GeneralBalanceParam {
    uint32_t mBaseSize{1U};       // defaults to 1
    uint32_t s2BaseSize{1U};      // defaults to 1
    int64_t faToleranceRatio{2};  // signed field, defaults to 2
    bool fdOn{true};              // enabled by default
    int64_t fdTolerance{0};       // signed field, defaults to 0
    CostFunc costFunc{};          // empty (holds no callable) by default
};
// Sparse-mask mode selector stored in one byte. Values are consecutive from 0;
// BUTT is the last enumerator and equals the number of modes declared before it.
enum class SparseMode : uint8_t {
    DEFAULT_MASK = 0,
    ALL_MASK = 1,
    LEFT_UP_CAUSAL = 2,
    RIGHT_DOWN_CAUSAL = 3,
    BAND = 4,
    BUTT = 5,
};
// Tensor layout identifier stored in one byte. Values are consecutive from 0;
// BUTT is the last enumerator and equals the number of layouts declared before it.
enum class Layout : uint8_t {
    BSND = 0,
    BNSD = 1,
    BSH = 2,
    NBSD = 3,
    TND = 4,
    NTD = 5,
    PA_NZ = 6,
    PA_BBND = 7,
    PA_BNBD = 8,
    BUTT = 9,
};
// Element data type identifier stored in one byte. Values are consecutive from 0;
// BUTT is the last enumerator and equals the number of types declared before it.
enum class DataType : uint8_t {
    FP32 = 0,
    FP16 = 1,
    INT4 = 2,
    INT8 = 3,
    INT32 = 4,
    BUTT = 5,
};
// Integer ceiling division: returns the smallest integer that is not less than
// the exact quotient a / b.
//   a, b    - dividend and divisor (integer type T).
//   returns - ceil(a / b), or 0 when b is 0 (division by zero is never performed).
// Built-in integer division truncates toward zero, so the quotient only needs to be
// bumped by one when the division is inexact AND the exact quotient is positive;
// for operands of opposite sign the truncated quotient already is the ceiling.
template <class T> inline T CeilDiv(T a, T b) {
    static_assert(std::is_integral_v<T>, "must be integer type");
    if (b == 0) {
        return 0;
    }
    const T quotient = static_cast<T>(a / b);
    const T remainder = static_cast<T>(a % b);
    if (remainder == 0) {
        return quotient;
    }
    if constexpr (std::is_signed_v<T>) {
        // The remainder carries the sign of the dividend, so differing signs of
        // remainder and divisor mean the exact quotient is negative.
        if ((remainder < 0) != (b < 0)) {
            return quotient;
        }
    }
    return static_cast<T>(quotient + 1);
}
// Floor division: returns the largest integer that is not greater than the exact
// quotient a / b.
//   a, b    - dividend and divisor.
//   returns - floor(a / b) for integer T, or 0 when b is 0 (division by zero is
//             never performed).
// Built-in signed integer division truncates toward zero, so an inexact quotient of
// operands with opposite signs has to be lowered by one to obtain the floor.
// For non-integer T the plain result of a / b is returned unchanged.
template <class T> inline T FloorDiv(T a, T b) {
    if (b == 0) {
        return 0;
    }
    if constexpr (std::is_integral_v<T> && std::is_signed_v<T>) {
        const T quotient = static_cast<T>(a / b);
        const bool inexact = (a % b) != 0;
        const bool oppositeSigns = (a < 0) != (b < 0);
        return (inexact && oppositeSigns) ? static_cast<T>(quotient - 1) : quotient;
    } else {
        return a / b;
    }
}
// Floor division with a caller-chosen fallback for a zero divisor.
//   a, b    - dividend and divisor (integer type T).
//   val     - value returned when b is 0.
//   returns - floor(a / b), or val when b is 0 (division by zero is never performed).
// Built-in signed integer division truncates toward zero, so an inexact quotient of
// operands with opposite signs has to be lowered by one to obtain the floor.
template <class T> inline T SafeFloorDiv(T a, T b, T val) {
    static_assert(std::is_integral_v<T>, "must be integer type");
    if (b == 0) {
        return val;
    }
    const T quotient = static_cast<T>(a / b);
    if constexpr (std::is_signed_v<T>) {
        const bool inexact = (a % b) != 0;
        const bool oppositeSigns = (a < 0) != (b < 0);
        if (inexact && oppositeSigns) {
            return static_cast<T>(quotient - 1);
        }
    }
    return quotient;
}
// Returns val + 1 converted back to T. Only integer types are accepted
// (enforced at compile time). No overflow check is performed.
template <typename T> inline T AddOne(T val) {
    static_assert(std::is_integral_v<T>, "must be integer type");
    const T successor = val + 1;
    return successor;
}
// Returns val - 1 converted back to T. Only integer types are accepted
// (enforced at compile time). No underflow check is performed.
template <typename T> inline T MinusOne(T val) {
    static_assert(std::is_integral_v<T>, "must be integer type");
    const T predecessor = val - 1;
    return predecessor;
}
template <typename T> inline T ToOpenInterval(T val) { return AddOne(val); }
template <typename T> inline T ToClosedInterval(T val) { return MinusOne(val); }
template <typename T> inline T IndexToNum(T val) { return AddOne(val); }
template <typename T> inline T NumToIndex(T val) { return MinusOne(val); }
// Limits value to the range [minValue, maxValue].
//   returns - minValue when value < minValue, otherwise maxValue when
//             value > maxValue, otherwise value itself.
// The lower bound is tested first, so when minValue > maxValue a value below
// minValue still yields minValue.
template <typename T> T Clip(T value, T minValue, T maxValue) {
    return (value < minValue) ? minValue : ((value > maxValue) ? maxValue : value);
}
// Tells whether value does not exceed limit by more than tolerance.
//   returns - true when value <= limit + tolerance, false otherwise.
// The sum limit + tolerance is computed with ordinary arithmetic on T; no
// overflow protection is applied.
template <typename T> inline bool IsWithinTolerance(T limit, T tolerance, T value) {
    return value <= limit + tolerance;
}
// Maps a layout name to its Layout enumerator.
//   layoutStr - layout name; matching is exact and case-sensitive.
//   returns   - the matching Layout, or Layout::BUTT when the name is unknown.
// The lookup table is built once on first use and is const afterwards; a single
// find() replaces the former find()-then-operator[] double lookup.
static Layout ConvertToLayout(const std::string &layoutStr) {
    static const std::unordered_map<std::string, Layout> layoutTable{
        {"BSND", Layout::BSND},       {"BNSD", Layout::BNSD},       {"BSH", Layout::BSH},
        {"NBSD", Layout::NBSD},       {"TND", Layout::TND},         {"NTD", Layout::NTD},
        {"PA_NZ", Layout::PA_NZ},     {"PA_BBND", Layout::PA_BBND}, {"PA_BNBD", Layout::PA_BNBD}};
    const auto entry = layoutTable.find(layoutStr);
    return (entry != layoutTable.end()) ? entry->second : Layout::BUTT;
}
// Returns the element size in bytes associated with a DataType:
//   FP32, INT32 -> 4
//   INT8        -> 1
//   FP16 and every other enumerator (INT4, BUTT) -> 2
// NOTE(review): INT4 ends up in the 2-byte fallback — confirm this is intended.
static uint32_t GetDataTypeByteSize(DataType type) {
    if (type == DataType::FP32 || type == DataType::INT32) {
        return 4U;
    }
    if (type == DataType::INT8) {
        return 1U;
    }
    // FP16 shares its size with the fallback for all remaining enumerators.
    return 2U;
}
}
#endif