* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef OP_API_INC_GROUPED_MATMUL_UTIL_H
#define OP_API_INC_GROUPED_MATMUL_UTIL_H
#include "opdev/common_types.h"
#include "opdev/op_executor.h"
namespace gmm {
using namespace op;
constexpr int64_t NO_SPLIT = -1L;
constexpr int64_t SPLIT_M = 0L;
constexpr int64_t SPLIT_K = 2L;
constexpr int64_t SPLIT_N = 1L;
constexpr int64_t GROUP_LIST_SPARSE_M = 2L;
constexpr size_t MAX_FM_DIM = 6UL;
constexpr size_t MIN_FM_DIM = 2UL;
constexpr size_t MIN_DIM_FOR_TRANSPOSE = 2UL;
constexpr size_t SPLIT_M_SINGLE_WEIGHT_DIM = 3UL;
constexpr size_t SPLIT_K_SINGLE_WEIGHT_DIM = 2UL;
constexpr size_t MX_SPLIT_M_SINGLE_X_DIM = 2UL;
constexpr size_t MX_SPLIT_M_SINGLE_WEIGHT_DIM = 3UL;
constexpr size_t MX_SPLIT_M_SCALE_DIM = 4UL;
constexpr size_t MX_BIAS_DIM = 2UL;
constexpr size_t MX_SPLIT_M_PER_TOKEN_SCALE_DIM = 3UL;
constexpr size_t MX_SPLIT_K_SINGLE_X_DIM = 2UL;
constexpr size_t MX_SPLIT_K_SINGLE_WEIGHT_DIM = 2UL;
constexpr size_t MX_SPLIT_K_SCALE_DIM = 3UL;
constexpr size_t MX_SPLIT_K_PER_TOKEN_SCALE_DIM = 3UL;
constexpr size_t MX_TUPLE_X_DIM_INDEX = 0UL;
constexpr size_t MX_TUPLE_WEIGHT_DIM_INDEX = 1UL;
constexpr size_t MX_TUPLE_SCALE_DIM_INDEX = 2UL;
constexpr size_t MX_TUPLE_PERTOKENSCALE_DIM_INDEX = 3UL;
constexpr size_t MX_TUPLE_GROUP_NUMBER_DIM_INDEX = 4UL;
constexpr int64_t CUBE_BLOCK_SIZE_16 = 16L;
constexpr int64_t CUBE_BLOCK_SIZE_32 = 32L;
constexpr int64_t CUBE_BLOCK_SIZE_64 = 64L;
constexpr int64_t PERTILE_GROUP_SIZE = 128L;
constexpr int64_t MXFP_DIVISOR_SIZE = 64L;
constexpr int64_t MXFP_MULTI_BASE_SIZE = 2L;
constexpr int64_t EVEN_FACTOR = 2L;
constexpr int64_t LAST_TWO_DIM_INDEX = 2L;
constexpr int64_t N_K_MAX_VALUE_WEIGHT_QUANT = 65535L;
constexpr int64_t N_K_ALIGN_VALUE_WEIGHT_QUANT_4BIT = 64L;
constexpr size_t LAST_FIRST_DIM_INDEX = 1;
constexpr size_t LAST_SECOND_DIM_INDEX = 2;
constexpr size_t LAST_THIRD_DIM_INDEX = 3;
constexpr size_t LAST_FOURTH_DIM_INDEX = 4;
constexpr size_t QUANT_WEIGHTNZ_STORAGE_DIM = 5;
enum class GMMApiVersion : uint32_t {
V1 = 1U,
V2 = 2U,
V3 = 3U,
V4 = 4U,
V5 = 5U,
WeightNz = 6U
};
template <typename T>
struct GroupedMatmulParamsBase {
const T *x = nullptr;
const T *weight = nullptr;
const aclTensorList *biasOptional = nullptr;
const aclIntArray *groupListOptional = nullptr;
const aclTensor *groupTensorOptional = nullptr;
const T *scaleOptional = nullptr;
const aclTensorList *offsetOptional = nullptr;
const aclTensorList *antiquantScaleOptional = nullptr;
const aclTensorList *antiquantOffsetOptional = nullptr;
const T *perTokenScaleOptional = nullptr;
const aclTensorList *activationInputOptional = nullptr;
const aclTensorList *activationQuantScaleOptional = nullptr;
const aclTensorList *activationQuantOffsetOptional = nullptr;
int64_t splitItem = 0;
int64_t groupListType = 0;
int64_t activeType = 0;
bool transposeWeight = false;
bool transposeX = false;
bool isSingleWeight = false;
GMMApiVersion apiVersion = GMMApiVersion::V1;
int64_t groupType = -1L;
aclIntArray *tuningConfigOptional = nullptr;
const T *y = nullptr;
const aclTensorList *activationFeatureOutOptional = nullptr;
const aclTensorList *dynQuantScaleOutOptional = nullptr;
DataType xDtype = DataType::DT_FLOAT16;
};
using GroupedMatmulParams = GroupedMatmulParamsBase<aclTensorList>;
const std::map<DataType, aclDataType> BIAS_DTYPE {
{DataType::DT_FLOAT16, aclDataType::ACL_FLOAT16},
{DataType::DT_BF16, aclDataType::ACL_FLOAT},
{DataType::DT_INT8, aclDataType::ACL_INT32},
{DataType::DT_FLOAT, aclDataType::ACL_FLOAT},
{DataType::DT_FLOAT8_E4M3FN, aclDataType::ACL_FLOAT},
{DataType::DT_FLOAT8_E5M2, aclDataType::ACL_FLOAT},
{DataType::DT_HIFLOAT8, aclDataType::ACL_FLOAT},
{DataType::DT_INT4, aclDataType::ACL_FLOAT16},
{DataType::DT_FLOAT4_E2M1, aclDataType::ACL_FLOAT},
{DataType::DT_FLOAT4_E1M2, aclDataType::ACL_FLOAT}
};
const std::map<DataType, std::string> DTYPE_STRING{
{DataType::DT_FLOAT16, "FLOAT16"},
{DataType::DT_BF16, "BF16"},
{DataType::DT_INT8, "INT8"},
{DataType::DT_INT4, "INT4"},
{DataType::DT_FLOAT, "FLOAT"},
{DataType::DT_FLOAT8_E4M3FN, "FLOAT8_E4M3FN"},
{DataType::DT_FLOAT8_E5M2, "FLOAT8_E5M2"},
{DataType::DT_HIFLOAT8, "HIFLOAT8"}
};
const std::initializer_list<op::DataType> DTYPE_SUPPORT_LIST = {DataType::DT_FLOAT, DataType::DT_FLOAT16,
DataType::DT_BF16};
bool IsTransposeLastTwoDims(const aclTensor *tensor);
bool IsTransposeForMxShape(const aclTensor *tensor);
void CreateContiguousTensorListForPertoken(const aclTensorList *tensorList, std::vector<aclTensor *> &newTensorList,
aclOpExecutor *executor);
void CreateContiguousTensorListForMXTypeMScale(const aclTensorList *tensorList, std::vector<aclTensor *> &newTensorList,
aclOpExecutor *executor);
void CreateContiguousTensorList(const aclTensorList *tensorList, std::vector<aclTensor *> &newTensorList,
aclOpExecutor *executor);
std::string dTypeToString(const ge::DataType &dtype);
}
#endif