* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <algorithm>
#include "fallback/fallback_comm.h"
#include "fallback/fallback.h"
#include "log/log.h"
#ifdef __cplusplus
extern "C" {
#endif
namespace fallback {
using namespace ge;
using namespace gert;
// Positions of the GroupedMatmul inputs, used when querying tensors from the execute context.
constexpr size_t INDEX_GMM_INPUT_X = 0;
constexpr size_t INDEX_GMM_INPUT_WEIGHT = 1;
constexpr size_t INDEX_GMM_INPUT_BIAS = 2;
constexpr size_t INDEX_GMM_INPUT_SCALE = 3;
constexpr size_t INDEX_GMM_INPUT_OFFSET = 4;
constexpr size_t INDEX_GMM_INPUT_ANTIQUANT_SCALE = 5;
constexpr size_t INDEX_GMM_INPUT_ANTIQUANT_OFFSET = 6;
constexpr size_t INDEX_GMM_INPUT_GROUP_LIST = 7;
constexpr size_t INDEX_GMM_INPUT_PER_TOKEN_SCALE = 8;
// Position of the first GroupedMatmul output tensor.
constexpr size_t INDEX_GMM_OUTPUT_Y = 0;
// Positions of the GroupedMatmul attributes read by the execute function (attribute 1 is not read in this file).
constexpr size_t INDEX_GMM_ATTR_SPLIT_ITEM = 0;
constexpr size_t INDEX_GMM_ATTR_TRANSPOSE_WEIGHT = 2;
constexpr size_t INDEX_GMM_ATTR_TRANSPOSE_X = 3;
constexpr size_t INDEX_GMM_ATTR_GROUP_TYPE = 4;
constexpr size_t INDEX_GMM_ATTR_GROUP_LIST_TYPE = 5;
constexpr size_t INDEX_GMM_ATTR_ACT_TYPE = 6;
constexpr size_t INDEX_GMM_ATTR_TUNING_CONFIG = 7;
// Minimum rank required before the view of an E8M0 scale tensor is transposed:
// 4 for the scale input, 3 for the per-token scale input ("TPYE" is a historical spelling of "TYPE").
constexpr size_t MXFP_TPYEM_SCALE_DIM_NUM = 4;
constexpr size_t MXFP_TPYEK_PER_TOKEN_SCALE_DIM_NUM = 3;
// Stride factor used for the transposed E8M0 scale views.
// NOTE(review): presumably the extent of the innermost (paired) dimension - confirm against the MX scale layout.
constexpr size_t MXFP_MULTI_BASE_SIZE = 2;
// The two dimensions of the scale input that are exchanged when its E8M0 view is transposed.
constexpr size_t TYPEM_SCALE_K_DIM_INDEX = 1;
constexpr size_t TYPEM_SCALE_N_DIM_INDEX = 2;
// Offset subtracted from a rank to address the second-to-last dimension.
constexpr size_t PENULTIMATE_DIM = 2;
// group_type attribute values that IsPerTileQuantMode distinguishes.
constexpr int64_t GMM_SPLIT_K = 2;
constexpr int64_t GMM_SPLIT_M = 0;
// Tile length IsPerTileQuantMode uses to relate scale dimensions to weight dimensions.
constexpr int64_t PERTILE_GROUP_SIZE = 128;
/**
 * @brief Converts one of the FP8/FP4 GE data types into the matching aclDataType.
 *
 * @param dtype GE data type to convert.
 * @return The GE enumerator value cast to aclDataType when dtype is one of FLOAT8_E4M3FN, FLOAT8_E5M2,
 *         HIFLOAT8, FLOAT8_E8M0 or FLOAT4_E2M1; ACL_DT_UNDEFINED for every other type.
 */
static inline aclDataType ToAclDataTypeForGMM(ge::DataType dtype)
{
    // Types whose GE enumerator value is forwarded unchanged as the ACL value.
    static constexpr ge::DataType kPassThroughTypes[] = {
        ge::DataType::DT_FLOAT8_E4M3FN, ge::DataType::DT_FLOAT8_E5M2, ge::DataType::DT_HIFLOAT8,
        ge::DataType::DT_FLOAT8_E8M0, ge::DataType::DT_FLOAT4_E2M1};
    for (const ge::DataType candidate : kPassThroughTypes) {
        if (candidate == dtype) {
            return static_cast<aclDataType>(dtype);
        }
    }
    return aclDataType::ACL_DT_UNDEFINED;
}
// Pass-through overload: an argument that already is an aclTensorList* is returned unchanged.
// NOTE(review): presumably provided so that generic argument conversion (e.g. inside EXEC_OPAPI_CMD)
// accepts the tensor lists built in this file - confirm against fallback.h. The parameter name is
// historical; the value is an ACL tensor list, not a GE object.
static inline aclTensorList* ConvertType(aclTensorList* geTensorList) {
return geTensorList;
}
/**
 * @brief Tells whether a GE data type is one of the FP8/FP4 types converted by ToAclDataTypeForGMM.
 *
 * @param dataType GE data type to classify.
 * @return true for FLOAT8_E4M3FN, FLOAT8_E5M2, HIFLOAT8, FLOAT8_E8M0 and FLOAT4_E2M1; false otherwise.
 */
static bool IsFP8FP4BitsDataType(ge::DataType dataType)
{
    switch (dataType) {
        case ge::DataType::DT_FLOAT8_E4M3FN:
        case ge::DataType::DT_FLOAT8_E5M2:
        case ge::DataType::DT_HIFLOAT8:
        case ge::DataType::DT_FLOAT8_E8M0:
        case ge::DataType::DT_FLOAT4_E2M1:
            return true;
        default:
            return false;
    }
}
/**
 * @brief Rewrites shape and strides so that dimensions 0 and 1 of an E8M0 per-token scale view are exchanged.
 *
 * After the call, dimension 0 advances by MXFP_MULTI_BASE_SIZE elements, dimension 1 by
 * (new dimension 0) * MXFP_MULTI_BASE_SIZE elements and the last dimension by one element.
 *
 * @param viewShape View shape, modified in place; must have at least MXFP_TPYEK_PER_TOKEN_SCALE_DIM_NUM dims.
 * @param strides   Strides, modified in place; indexed with the same positions as viewShape.
 * @return false (after logging) when viewShape has too few dimensions, true otherwise.
 */
static bool SetShapeAndStrideForMXFPTypeKForGMM(std::vector<int64_t> &viewShape, std::vector<int64_t> &strides)
{
    const size_t rank = viewShape.size();
    OP_CHECK_IF(rank < MXFP_TPYEK_PER_TOKEN_SCALE_DIM_NUM,
                OP_LOGE("GroupedMatmul aclnnfallback",
                        "When split k, the dim num should be greater than 2 in mx quant mode, but the actual is %zu.",
                        rank),
                return false);

    // Exchange the two leading dimensions by hand.
    const int64_t newLeadingDim = viewShape[1];
    viewShape[1] = viewShape[0];
    viewShape[0] = newLeadingDim;

    // rank >= 3, so the last index never aliases index 0 or 1.
    strides[rank - 1] = 1;
    strides[1] = newLeadingDim * MXFP_MULTI_BASE_SIZE;
    strides[0] = MXFP_MULTI_BASE_SIZE;
    return true;
}
/**
 * @brief Rewrites shape and strides so that dimensions TYPEM_SCALE_K_DIM_INDEX and TYPEM_SCALE_N_DIM_INDEX
 *        of an E8M0 scale view are exchanged.
 *
 * After the call, the K-index dimension advances by MXFP_MULTI_BASE_SIZE elements, the N-index dimension by
 * (new K-index dimension) * MXFP_MULTI_BASE_SIZE elements and the last dimension by one element.
 *
 * @param viewShape View shape, modified in place; must have at least MXFP_TPYEM_SCALE_DIM_NUM dims.
 * @param strides   Strides, modified in place; indexed with the same positions as viewShape.
 * @return false (after logging) when viewShape has too few dimensions, true otherwise.
 */
static bool SetShapeAndStrideForMXFPTypeMForGMM(std::vector<int64_t> &viewShape, std::vector<int64_t> &strides)
{
    const size_t rank = viewShape.size();
    OP_CHECK_IF(rank < MXFP_TPYEM_SCALE_DIM_NUM,
                OP_LOGE("GroupedMatmul aclnnfallback",
                        "When split m, the dim num should be greater than 3 in mx quant mode, but the actual is %zu.",
                        rank),
                return false);

    // Exchange the K-index and N-index dimensions by hand.
    const int64_t newKDim = viewShape[TYPEM_SCALE_N_DIM_INDEX];
    viewShape[TYPEM_SCALE_N_DIM_INDEX] = viewShape[TYPEM_SCALE_K_DIM_INDEX];
    viewShape[TYPEM_SCALE_K_DIM_INDEX] = newKDim;

    // rank >= 4, so the last index never aliases index 1 or 2.
    strides[rank - 1] = 1;
    strides[TYPEM_SCALE_N_DIM_INDEX] = newKDim * MXFP_MULTI_BASE_SIZE;
    strides[TYPEM_SCALE_K_DIM_INDEX] = MXFP_MULTI_BASE_SIZE;
    return true;
}
/**
 * @brief Exchanges the last two dimensions of a view, together with their strides.
 *
 * @param viewShape View shape, modified in place; must have at least two dimensions.
 * @param strides   Strides, modified in place; indexed with the same positions as viewShape.
 * @return false (after logging) when viewShape has fewer than two dimensions, true otherwise.
 */
static bool SetShapeAndStrideForLastTwoDimForGMM(std::vector<int64_t> &viewShape, std::vector<int64_t> &strides)
{
    const size_t rank = viewShape.size();
    // A rank of exactly 2 is valid, so the message states "at least 2" to match the check.
    OP_CHECK_IF(rank < 2,
                OP_LOGE("GroupedMatmul aclnnfallback", "The dim num should be at least 2, but the actual is %zu.",
                        rank),
                return false);
    const size_t dimM = rank - 2;
    const size_t dimN = rank - 1;
    std::swap(strides[dimM], strides[dimN]);
    std::swap(viewShape[dimM], viewShape[dimN]);
    return true;
}
/**
 * @brief Fills viewShape with the dimensions of originShape and strides with the matching row-major strides.
 *
 * Any previous contents of the two output vectors are discarded, so the result depends only on originShape.
 * The last stride is 1 and every other stride is the product of all dimensions to its right.
 *
 * @param originShape Shape whose dimensions are copied.
 * @param viewShape   Output: dimensions of originShape, in order.
 * @param strides     Output: one stride per dimension; empty when originShape has no dimensions.
 */
static void InitShapeAndStrideForGMM(const gert::Shape &originShape, std::vector<int64_t> &viewShape,
                                     std::vector<int64_t> &strides)
{
    const size_t dimNum = originShape.GetDimNum();
    viewShape.clear();
    viewShape.reserve(dimNum);
    for (size_t i = 0; i < dimNum; ++i) {
        viewShape.push_back(originShape.GetDim(i));
    }
    strides.assign(dimNum, 1);
    // Walk from the second-to-last dimension towards the first; the loop is written on an unsigned
    // counter so ranks 0 and 1 simply skip it instead of relying on unsigned wrap-around.
    for (size_t i = dimNum; i > 1; --i) {
        strides[i - 2] = viewShape[i - 1] * strides[i - 1];
    }
}
/**
 * @brief Wraps a GE tensor into an aclTensor, optionally describing it through a transposed view.
 *
 * Tensors whose storage shape has at most one dimension are delegated to ConvertType. Otherwise the view
 * shape and strides are derived from the origin shape and, when enableTranspose is set, adjusted:
 *  - FLOAT8_E8M0 tensors at INDEX_GMM_INPUT_PER_TOKEN_SCALE use SetShapeAndStrideForMXFPTypeKForGMM,
 *  - FLOAT8_E8M0 tensors at INDEX_GMM_INPUT_SCALE use SetShapeAndStrideForMXFPTypeMForGMM,
 *  - FLOAT8_E8M0 tensors at any other index are left untransposed,
 *  - every other data type has its last two dimensions exchanged.
 *
 * @param geTensor        Source tensor; nullptr is logged and rejected.
 * @param enableTranspose Whether the view should be transposed as described above.
 * @param index           Input index the tensor was read from; selects the E8M0 transposition rule.
 * @param enableNZ        When true and the storage format is FRACTAL_NZ, the aclTensor is created with
 *                        ACL_FORMAT_FRACTAL_NZ; otherwise ACL_FORMAT_ND is used.
 * @return The created aclTensor, or nullptr on any failure.
 */
static inline aclTensor *GeTensor2AclTensor(const gert::Tensor *geTensor, bool enableTranspose, size_t index,
                                            bool enableNZ = false)
{
    OP_CHECK_IF(geTensor == nullptr, OP_LOGE("GroupedMatmul aclnnfallback", "geTensor nullptr"), return nullptr);

    const auto storageShape = geTensor->GetStorageShape();
    const size_t storageRank = storageShape.GetDimNum();
    if (storageRank <= 1) {
        return ConvertType(geTensor);
    }
    std::vector<int64_t> storageDims;
    for (size_t dim = 0; dim < storageRank; ++dim) {
        storageDims.push_back(storageShape.GetDim(dim));
    }

    static const auto aclCreateTensor = GET_OP_API_FUNC(aclCreateTensor);
    OP_CHECK_IF(aclCreateTensor == nullptr, OP_LOGE("aclnnfallback", "aclCreateTensor nullptr"), return nullptr);

    void *tensorAddr = const_cast<void *>(geTensor->GetAddr());

    // FP8/FP4 types go through the GroupedMatmul-specific conversion, everything else through the common one.
    const auto geType = geTensor->GetDataType();
    aclDataType aclType = ACL_DT_UNDEFINED;
    if (!IsFP8FP4BitsDataType(geType)) {
        aclType = ToAclDataType(geType);
    } else {
        aclType = ToAclDataTypeForGMM(geType);
    }

    std::vector<int64_t> viewDims;
    std::vector<int64_t> viewStrides;
    InitShapeAndStrideForGMM(geTensor->GetOriginShape(), viewDims, viewStrides);

    if (enableTranspose) {
        const bool isE8M0 = (geType == ge::DataType::DT_FLOAT8_E8M0);
        if (!isE8M0) {
            OP_CHECK_IF(!SetShapeAndStrideForLastTwoDimForGMM(viewDims, viewStrides),
                        OP_LOGE("GroupedMatmul aclnnfallback", "SetShapeAndStrideForLastTwoDimForGMM failed"),
                        return nullptr);
        } else if (index == INDEX_GMM_INPUT_PER_TOKEN_SCALE) {
            OP_CHECK_IF(!SetShapeAndStrideForMXFPTypeKForGMM(viewDims, viewStrides),
                        OP_LOGE("GroupedMatmul aclnnfallback", "SetShapeAndStrideForMXFPTypeKForGMM failed"),
                        return nullptr);
        } else if (index == INDEX_GMM_INPUT_SCALE) {
            OP_CHECK_IF(!SetShapeAndStrideForMXFPTypeMForGMM(viewDims, viewStrides),
                        OP_LOGE("GroupedMatmul aclnnfallback", "SetShapeAndStrideForMXFPTypeMForGMM failed"),
                        return nullptr);
        }
    }

    const bool useNZ =
        enableNZ && GetPrimaryFormat(geTensor->GetStorageFormat()) == ge::Format::FORMAT_FRACTAL_NZ;
    const auto tensorFormat = useNZ ? aclFormat::ACL_FORMAT_FRACTAL_NZ : aclFormat::ACL_FORMAT_ND;

    aclTensor *out = aclCreateTensor(viewDims.data(), viewDims.size(), aclType, viewStrides.data(),
                                     0, tensorFormat, storageDims.data(), storageDims.size(), tensorAddr);
    OP_CHECK_IF(out == nullptr, OP_LOGE("aclnnfallback", "out nullptr"), return nullptr);
    return out;
}
/**
 * @brief Appends every tensor of one dynamic input to tensorVector.
 *
 * Tensors are queried at increasing positions until the context returns nullptr.
 *
 * @param host_api_ctx Execute context to query; must not be nullptr.
 * @param tensorVector Receives the tensors in order.
 * @param index        Index of the dynamic input.
 * @return Always GRAPH_SUCCESS.
 */
static graphStatus PrepareGeTensorVector(const OpExecuteContext *host_api_ctx,
                                         std::vector<const gert::Tensor *> &tensorVector, size_t index)
{
    for (size_t pos = 0;; ++pos) {
        const auto tensor = host_api_ctx->GetDynamicInputTensor(index, pos);
        if (tensor == nullptr) {
            return GRAPH_SUCCESS;
        }
        tensorVector.push_back(tensor);
    }
}
/**
 * @brief Converts every tensor of one dynamic input into an aclTensor and appends it to tensorVector.
 *
 * Tensors are queried at increasing positions until the context returns nullptr. Conversion stops at the
 * first tensor that GeTensor2AclTensor cannot convert, so a nullptr entry is never stored in tensorVector.
 *
 * @param host_api_ctx    Execute context to query; must not be nullptr.
 * @param tensorVector    Receives the converted tensors in order.
 * @param index           Index of the dynamic input; also forwarded to GeTensor2AclTensor.
 * @param enableTranspose Forwarded to GeTensor2AclTensor.
 * @param enableNZ        Forwarded to GeTensor2AclTensor.
 * @return GRAPH_SUCCESS when every tensor was converted, GRAPH_FAILED when a conversion returned nullptr.
 */
static graphStatus PrepareAclTensorVector(const OpExecuteContext *host_api_ctx, std::vector<const aclTensor *> &tensorVector,
                                          size_t index, bool enableTranspose, bool enableNZ)
{
    for (size_t cnt = 0;; ++cnt) {
        const auto inputGe = host_api_ctx->GetDynamicInputTensor(index, cnt);
        if (inputGe == nullptr) {
            break;
        }
        const aclTensor *inputAcl = GeTensor2AclTensor(inputGe, enableTranspose, index, enableNZ);
        OP_CHECK_IF(inputAcl == nullptr,
                    OP_LOGE("GroupedMatmul aclnnfallback", "Convert tensor %zu of input %zu to aclTensor failed.",
                            cnt, index),
                    return GRAPH_FAILED);
        tensorVector.push_back(inputAcl);
    }
    return GRAPH_SUCCESS;
}
/**
 * @brief Collects the output tensors that the given split_item value implies.
 *
 * split_item 0 or 1 expects numGeWeight outputs, split_item 2 or 3 expects a single output.
 *
 * @param host_api_ctx Execute context to query; must not be nullptr.
 * @param tensorVector Receives the output tensors in order; may be partially filled when GRAPH_FAILED is returned.
 * @param index        Index of the first output tensor.
 * @param numGeWeight  Number of weight tensors, used as output count for split_item 0/1.
 * @param splitItem    Value of the split_item attribute.
 * @return GRAPH_SUCCESS on success; GRAPH_FAILED (after logging) for an unsupported split_item or a
 *         missing output tensor.
 */
static graphStatus PrepareOutputTensorVector(const OpExecuteContext *host_api_ctx,
                                             std::vector<const gert::Tensor *> &tensorVector, size_t index,
                                             size_t numGeWeight, int32_t splitItem)
{
    size_t numGeY = 0;
    switch (splitItem) {
        case 0:
        case 1:
            numGeY = numGeWeight;
            break;
        case 2:
        case 3:
            numGeY = 1;
            break;
        default:
            OP_LOGE("aclnnfallback", "Invalid value of split_item: %d, which must be one of 0/1/2/3.", splitItem);
            return GRAPH_FAILED;
    }
    for (size_t k = 0; k < numGeY; k++) {
        auto outputGe = host_api_ctx->GetOutputTensor(index + k);
        // Report which output is missing instead of failing silently.
        OP_CHECK_IF(outputGe == nullptr,
                    OP_LOGE("aclnnfallback", "Output tensor %zu is null, %zu output(s) expected.", index + k, numGeY),
                    return GRAPH_FAILED);
        tensorVector.push_back(outputGe);
    }
    return GRAPH_SUCCESS;
}
/**
 * @brief Checks whether the input shapes match the per-tile quantization layout.
 *
 * The function returns true only when all of the following hold:
 *  - x, weight, scale (first tensors), per-token scale and group list are all present;
 *  - per-token scale has the same rank as x (at least 2) and the same M dimension;
 *  - scale has the same rank as weight (at least 2);
 *  - the N dimension of scale equals ceil(weight N / PERTILE_GROUP_SIZE);
 *  - the K dimension of scale equals weight K / PERTILE_GROUP_SIZE + group count for GMM_SPLIT_K,
 *    or ceil(weight K / PERTILE_GROUP_SIZE) for GMM_SPLIT_M.
 *
 * @param host_api_ctx    Execute context to query; must not be nullptr.
 * @param transposeX      Whether x is transposed; selects which of the last two dims is M.
 * @param transposeWeight Whether weight is transposed; selects which of the last two dims is K and N.
 * @param groupType       Value of the group_type attribute.
 * @return true when the layout matches, false otherwise.
 */
static bool IsPerTileQuantMode(const OpExecuteContext* host_api_ctx, bool transposeX, bool transposeWeight, int64_t groupType)
{
    auto perTokenScaleTensor = host_api_ctx->GetOptionalInputTensor(INDEX_GMM_INPUT_PER_TOKEN_SCALE);
    auto x = host_api_ctx->GetDynamicInputTensor(INDEX_GMM_INPUT_X, 0);
    auto weight = host_api_ctx->GetDynamicInputTensor(INDEX_GMM_INPUT_WEIGHT, 0);
    auto scale = host_api_ctx->GetDynamicInputTensor(INDEX_GMM_INPUT_SCALE, 0);
    auto groupListTensor = host_api_ctx->GetOptionalInputTensor(INDEX_GMM_INPUT_GROUP_LIST);
    if (x == nullptr || weight == nullptr || scale == nullptr ||
        perTokenScaleTensor == nullptr || groupListTensor == nullptr) {
        return false;
    }

    const auto &perTokenScaleOriginShape = perTokenScaleTensor->GetOriginShape();
    const auto &xOriginShape = x->GetOriginShape();
    const auto perTokenScaleDimNum = perTokenScaleOriginShape.GetDimNum();
    const auto xDimNum = xOriginShape.GetDimNum();
    // Validate the ranks first: the index arithmetic below would underflow or address a
    // non-existent dimension when x has fewer than two dims or the ranks differ.
    if (perTokenScaleDimNum < 2 || perTokenScaleDimNum != xDimNum) {
        return false;
    }
    const size_t mIndex = transposeX ? xDimNum - 1 : xDimNum - PENULTIMATE_DIM;
    if (perTokenScaleOriginShape.GetDim(mIndex) != xOriginShape.GetDim(mIndex)) {
        return false;
    }

    const auto &scaleShape = scale->GetOriginShape();
    const auto &weightShape = weight->GetOriginShape();
    const auto scaleDimNum = scaleShape.GetDimNum();
    const auto weightDimNum = weightShape.GetDimNum();
    if (scaleDimNum < 2 || scaleDimNum != weightDimNum) {
        return false;
    }
    // scale and weight have the same rank here, so one index pair addresses both shapes.
    const size_t kIndex = transposeWeight ? scaleDimNum - 1 : scaleDimNum - PENULTIMATE_DIM;
    const size_t nIndex = transposeWeight ? scaleDimNum - PENULTIMATE_DIM : scaleDimNum - 1;
    const auto scaleKDim = scaleShape.GetDim(kIndex);
    const auto scaleNDim = scaleShape.GetDim(nIndex);
    const auto weightKDim = weightShape.GetDim(kIndex);
    const auto weightNDim = weightShape.GetDim(nIndex);
    const auto groupNum = groupListTensor->GetOriginShape().GetDim(0);

    const bool isKdimValid =
        ((groupType == GMM_SPLIT_K && scaleKDim == (weightKDim / PERTILE_GROUP_SIZE) + groupNum) ||
         (groupType == GMM_SPLIT_M && scaleKDim == (weightKDim + PERTILE_GROUP_SIZE - 1) / PERTILE_GROUP_SIZE));
    return (scaleNDim == (weightNDim + PERTILE_GROUP_SIZE - 1) / PERTILE_GROUP_SIZE && isKdimValid);
}
/**
 * @brief Execute function of the GroupedMatmul fallback: gathers attributes and tensors from the execute
 *        context and dispatches them to aclnnGroupedMatmulV5.
 *
 * x, weight, scale and per-token scale are handed over as aclTensorList objects so that transposed views can
 * be described; bias, offset, antiquant scale/offset and the outputs are passed as GE tensor vectors.
 * The scale view follows the weight transposition, and the per-token scale view follows the x transposition,
 * only when the tensor is FLOAT8_E8M0 or IsPerTileQuantMode reports per-tile quantization.
 *
 * @param host_api_ctx Execute context supplied by the framework.
 * @return GRAPH_SUCCESS when aclnnGroupedMatmulV5 was executed successfully, GRAPH_FAILED otherwise.
 */
static graphStatus GroupedMatmulExecuteFunc(OpExecuteContext* host_api_ctx)
{
    OP_CHECK_IF(host_api_ctx == nullptr, OP_LOGE("aclnnfallback", "host_api_ctx is null"), return GRAPH_FAILED);

    // ---- attributes ----
    auto attrs = host_api_ctx->GetAttrs();
    OP_CHECK_IF(attrs == nullptr, OP_LOGE("aclnnfallback", "attrs is null"), return GRAPH_FAILED);
    const int64_t* splitItemGe = attrs->GetAttrPointer<int64_t>(INDEX_GMM_ATTR_SPLIT_ITEM);
    OP_CHECK_IF(splitItemGe == nullptr, OP_LOGE("aclnnfallback", "splitItemGe is null"), return GRAPH_FAILED);
    const bool* isWeightTransposedPtr = attrs->GetAttrPointer<bool>(INDEX_GMM_ATTR_TRANSPOSE_WEIGHT);
    OP_CHECK_IF(isWeightTransposedPtr == nullptr, OP_LOGE("aclnnfallback", "isWeightTransposedPtr is null"),
                return GRAPH_FAILED);
    const bool* isXTransposed = attrs->GetAttrPointer<bool>(INDEX_GMM_ATTR_TRANSPOSE_X);
    OP_CHECK_IF(isXTransposed == nullptr, OP_LOGE("aclnnfallback", "isXTransposed is null"), return GRAPH_FAILED);
    const int64_t* groupTypeGe = attrs->GetAttrPointer<int64_t>(INDEX_GMM_ATTR_GROUP_TYPE);
    OP_CHECK_IF(groupTypeGe == nullptr, OP_LOGE("aclnnfallback", "groupTypeGe is null"), return GRAPH_FAILED);
    const int64_t* groupListTypeGe = attrs->GetAttrPointer<int64_t>(INDEX_GMM_ATTR_GROUP_LIST_TYPE);
    OP_CHECK_IF(groupListTypeGe == nullptr, OP_LOGE("aclnnfallback", "groupListTypeGe is null"), return GRAPH_FAILED);
    const int64_t* actTypeGe = attrs->GetAttrPointer<int64_t>(INDEX_GMM_ATTR_ACT_TYPE);
    OP_CHECK_IF(actTypeGe == nullptr, OP_LOGE("aclnnfallback", "actTypeGe is null"), return GRAPH_FAILED);
    const auto tuningConfigGe = attrs->GetListInt(INDEX_GMM_ATTR_TUNING_CONFIG);
    OP_CHECK_IF(tuningConfigGe == nullptr, OP_LOGE("aclnnfallback", "tuningConfigGe is null"), return GRAPH_FAILED);

    static const auto aclCreateTensorList = GET_OP_API_FUNC(aclCreateTensorList);
    OP_CHECK_IF(aclCreateTensorList == nullptr,
                OP_LOGE("aclnnfallback", "Get opapi func aclCreateTensorList failed"), return GRAPH_FAILED);

    // ---- x and weight ----
    std::vector<const aclTensor*> aclTensorVectorX;
    OP_CHECK_IF(PrepareAclTensorVector(host_api_ctx, aclTensorVectorX, INDEX_GMM_INPUT_X, *isXTransposed, false) !=
                    GRAPH_SUCCESS,
                OP_LOGE("aclnnfallback", "Prepare x tensor list failed"), return GRAPH_FAILED);
    auto aclTensorListX = aclCreateTensorList(aclTensorVectorX.data(), aclTensorVectorX.size());

    std::vector<const aclTensor*> aclTensorVectorWeight;
    OP_CHECK_IF(PrepareAclTensorVector(host_api_ctx, aclTensorVectorWeight, INDEX_GMM_INPUT_WEIGHT,
                                       *isWeightTransposedPtr, true) != GRAPH_SUCCESS,
                OP_LOGE("aclnnfallback", "Prepare weight tensor list failed"), return GRAPH_FAILED);
    const size_t numGeWeight = aclTensorVectorWeight.size();
    auto aclTensorListWeight = aclCreateTensorList(aclTensorVectorWeight.data(), aclTensorVectorWeight.size());

    std::vector<const gert::Tensor*> geTensorVectorBias;
    OP_CHECK_IF(PrepareGeTensorVector(host_api_ctx, geTensorVectorBias, INDEX_GMM_INPUT_BIAS) != GRAPH_SUCCESS,
                OP_LOGE("aclnnfallback", "Prepare bias tensors failed"), return GRAPH_FAILED);

    // ---- scale: its view is transposed together with weight only for E8M0 or per-tile quantization ----
    auto scaleTensor = host_api_ctx->GetDynamicInputTensor(INDEX_GMM_INPUT_SCALE, 0);
    const bool isPerTile = IsPerTileQuantMode(host_api_ctx, *isXTransposed, *isWeightTransposedPtr, *groupTypeGe);
    bool isScaleTransposed = false;
    if (scaleTensor != nullptr && (scaleTensor->GetDataType() == ge::DataType::DT_FLOAT8_E8M0 || isPerTile)) {
        isScaleTransposed = *isWeightTransposedPtr;
    }
    std::vector<const aclTensor*> aclTensorVectorScale;
    OP_CHECK_IF(PrepareAclTensorVector(host_api_ctx, aclTensorVectorScale, INDEX_GMM_INPUT_SCALE, isScaleTransposed,
                                       false) != GRAPH_SUCCESS,
                OP_LOGE("aclnnfallback", "Prepare scale tensor list failed"), return GRAPH_FAILED);
    auto aclTensorListScale = aclCreateTensorList(aclTensorVectorScale.data(), aclTensorVectorScale.size());

    std::vector<const gert::Tensor*> geTensorVectorOffset;
    OP_CHECK_IF(PrepareGeTensorVector(host_api_ctx, geTensorVectorOffset, INDEX_GMM_INPUT_OFFSET) != GRAPH_SUCCESS,
                OP_LOGE("aclnnfallback", "Prepare offset tensors failed"), return GRAPH_FAILED);
    std::vector<const gert::Tensor*> geTensorVectorAntiquantScale;
    OP_CHECK_IF(PrepareGeTensorVector(host_api_ctx, geTensorVectorAntiquantScale, INDEX_GMM_INPUT_ANTIQUANT_SCALE) !=
                    GRAPH_SUCCESS,
                OP_LOGE("aclnnfallback", "Prepare antiquant scale tensors failed"), return GRAPH_FAILED);
    std::vector<const gert::Tensor*> geTensorVectorAntiquantOffset;
    OP_CHECK_IF(PrepareGeTensorVector(host_api_ctx, geTensorVectorAntiquantOffset,
                                      INDEX_GMM_INPUT_ANTIQUANT_OFFSET) != GRAPH_SUCCESS,
                OP_LOGE("aclnnfallback", "Prepare antiquant offset tensors failed"), return GRAPH_FAILED);

    auto groupListTensor = host_api_ctx->GetOptionalInputTensor(INDEX_GMM_INPUT_GROUP_LIST);

    // ---- per-token scale ----
    auto perTokenScaleTensor = host_api_ctx->GetOptionalInputTensor(INDEX_GMM_INPUT_PER_TOKEN_SCALE);
    std::vector<const aclTensor*> geTensorVectorPerTokenScale;
    if (perTokenScaleTensor != nullptr && (perTokenScaleTensor->GetDataType() == ge::DataType::DT_FLOAT8_E8M0 ||
                                           isPerTile)) {
        OP_CHECK_IF(PrepareAclTensorVector(host_api_ctx, geTensorVectorPerTokenScale,
                                           INDEX_GMM_INPUT_PER_TOKEN_SCALE, *isXTransposed, false) != GRAPH_SUCCESS,
                    OP_LOGE("aclnnfallback", "Prepare per-token scale tensor list failed"), return GRAPH_FAILED);
    } else {
        // The plain conversion is done only on this path, so no aclTensor is created that ends up unused.
        auto perTokenScale = ConvertType(perTokenScaleTensor);
        if (perTokenScale == nullptr) {
            // No usable per-token scale: hand over an empty float tensor (shape {0}, no data) instead.
            std::vector<int64_t> shape{0};
            static const auto aclCreateTensor = GET_OP_API_FUNC(aclCreateTensor);
            OP_CHECK_IF(aclCreateTensor == nullptr, OP_LOGE("aclnnfallback", "aclCreateTensor nullptr"),
                        return GRAPH_FAILED);
            perTokenScale = aclCreateTensor(shape.data(), shape.size(), aclDataType::ACL_FLOAT, shape.data(),
                                            0, aclFormat::ACL_FORMAT_ND, shape.data(), shape.size(), nullptr);
            OP_CHECK_IF(perTokenScale == nullptr, OP_LOGE("aclnnfallback", "perTokenScale nullptr"),
                        return GRAPH_FAILED);
        }
        geTensorVectorPerTokenScale.push_back(perTokenScale);
    }
    auto aclTensorListPerTokenScale = aclCreateTensorList(geTensorVectorPerTokenScale.data(),
                                                          geTensorVectorPerTokenScale.size());

    // ---- outputs: an invalid split_item or a missing output must stop execution ----
    std::vector<const gert::Tensor*> geTensorVectorY;
    OP_CHECK_IF(PrepareOutputTensorVector(host_api_ctx, geTensorVectorY, INDEX_GMM_OUTPUT_Y, numGeWeight,
                                          *splitItemGe) != GRAPH_SUCCESS,
                OP_LOGE("aclnnfallback", "Prepare output tensors failed"), return GRAPH_FAILED);

    aclTensorList* activationInputOptional = nullptr;
    aclTensorList* activationQuantScaleOptional = nullptr;
    aclTensorList* activationQuantOffsetOptional = nullptr;
    aclTensorList* actFeatureOutOptional = nullptr;
    aclTensorList* dynQuantScaleOutOptional = nullptr;

    // Only the first tuning_config element is forwarded; an empty attribute list is never dereferenced.
    // NOTE(review): an empty list is forwarded as an empty array - confirm aclnnGroupedMatmulV5 accepts that.
    std::vector<int64_t> tuningConfig;
    tuningConfig.reserve(1);
    if (tuningConfigGe->GetSize() > 0) {
        tuningConfig.push_back(tuningConfigGe->GetData()[0]);
    }

    auto api_ret_gmm = EXEC_OPAPI_CMD(aclnnGroupedMatmulV5, aclTensorListX, aclTensorListWeight, geTensorVectorBias,
                                      aclTensorListScale, geTensorVectorOffset, geTensorVectorAntiquantScale,
                                      geTensorVectorAntiquantOffset, aclTensorListPerTokenScale, groupListTensor,
                                      activationInputOptional, activationQuantScaleOptional, activationQuantOffsetOptional,
                                      *splitItemGe, *groupTypeGe, *groupListTypeGe, *actTypeGe, tuningConfig,
                                      geTensorVectorY, actFeatureOutOptional, dynQuantScaleOutOptional);
    OP_CHECK_IF(api_ret_gmm != GRAPH_SUCCESS, OP_LOGE("aclnnfallback", "api_ret failed:%u", api_ret_gmm),
                return GRAPH_FAILED);
    return GRAPH_SUCCESS;
}
// Registers GroupedMatmulExecuteFunc as the OpExecuteFunc of the GroupedMatmul op.
IMPL_OP(GroupedMatmul).OpExecuteFunc(GroupedMatmulExecuteFunc);
}
#ifdef __cplusplus
}
#endif