* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "aclnn_grouped_matmul.h"
#include "aclnn_grouped_matmul_v2.h"
#include "aclnn_grouped_matmul_v3.h"
#include "aclnn_grouped_matmul_v4.h"
#include "aclnn_grouped_matmul_v5.h"
#include "aclnn_grouped_matmul_weight_nz.h"
#include <dlfcn.h>
#include <new>
#include "aclnn_kernels/transdata.h"
#include "grouped_matmul.h"
#include "aclnn_kernels/contiguous.h"
#include "../op_host/grouped_matmul_host_util.h"
#include "acl/acl.h"
#include "aclnn/aclnn_base.h"
#include "aclnn_kernels/common/op_error_check.h"
#include "opdev/common_types.h"
#include "opdev/data_type_utils.h"
#include "opdev/format_utils.h"
#include "opdev/op_dfx.h"
#include "opdev/op_executor.h"
#include "opdev/op_log.h"
#include "opdev/platform.h"
#include "opdev/shape_utils.h"
#include "opdev/tensor_view_utils.h"
#include "opdev/make_op_executor.h"
#include "grouped_matmul_util.h"
#include "grouped_matmul_950_checker.h"
#include "grouped_matmul_weight_quant_950_checker.h"
#include "grouped_matmul_no_quant_950_checker.h"
using namespace op;
#define DEPRECATED_API_WARN_ONCE(oldApiName, newApiName) \
do { \
static bool isFirstWarn = true; \
if (isFirstWarn){ \
OP_LOGW("%s is scheduled to be deprecated in a post-December 2026 version update, " \
"and will be replaced by the %s. " \
"We apologize for any inconvenience caused and appreciate your timely migration to the new interface.", (oldApiName), (newApiName)); \
isFirstWarn = false; \
} \
} while(0)
#ifdef __cplusplus
extern "C" {
#endif
namespace {
static constexpr int64_t X_Y_SEPARATED = 0L;
static constexpr int64_t Y_SEPARATED = 1L;
static constexpr int64_t X_SEPARATED = 2L;
static constexpr int64_t NO_SEPARATED = 3L;
static constexpr int64_t MAX_GROUP_LIST_SIZE_ARRAY = 128L;
static constexpr int64_t MAX_GROUP_LIST_SIZE_TENSOR = 1024L;
static constexpr int64_t MAX_INNER_AXIS = 65535L;
static constexpr size_t SEPARATED_WEIGHT_DIM = 2UL;
static constexpr int64_t END_ACT_TYPE_ENUM = 6L;
static constexpr size_t ALIGN_NZ_4BIT_N = 64UL;
static constexpr size_t ALIGN_NZ_4BIT_K = 64UL;
static constexpr size_t ALIGN_NZ_INT8_N = 32UL;
static constexpr size_t ALIGN_NZ_K = 16UL;
static constexpr uint64_t B4_PER_B32 = 8UL;
static constexpr size_t WEIGHT_DIM_A8W4 = 3UL;
static constexpr size_t OFFSET_DIM_A8W4 = 3UL;
static constexpr size_t BIAS_DIM_A8W4 = 2UL;
static constexpr size_t PER_CHANNEL_SCALE_DIM = 2UL;
static constexpr size_t PER_GROUP_SCALE_DIM = 3UL;
static constexpr size_t DIMS_THREE_FOR_GMM = 3UL;
static constexpr size_t GROUP_LIST_SPARSE_DIM_NUM = 2UL;
static constexpr int ANTIQUANT_SCALE_3D_DIMS = 3;
static constexpr int ANTIQUANT_SCALE_4D_DIMS = 4;
static constexpr int SCALE_TENSOR_EXPECTED_DIMS = 2;
static bool IsFormatNZWithC0(const aclTensor* tensor) {
return ge::GetPrimaryFormat(tensor->GetStorageFormat()) == op::Format::FORMAT_FRACTAL_NZ_C0_2 ||
ge::GetPrimaryFormat(tensor->GetStorageFormat()) == op::Format::FORMAT_FRACTAL_NZ_C0_4;
}
static bool IsFormatNZ(const aclTensor* tensor) {
return ge::GetPrimaryFormat(tensor->GetStorageFormat()) == op::Format::FORMAT_FRACTAL_NZ ||
IsFormatNZWithC0(tensor);
}
static aclnnStatus SetSpecialNZTensorToNormalNZFormat(const aclTensorList *&tensorListInput) {
if (tensorListInput->Size() <= 0 || !IsFormatNZWithC0((*tensorListInput)[0])) {
return ACLNN_SUCCESS;
}
OP_LOGD("Set NZ_C0 format to NZ format begin.");
auto tensorList = const_cast<aclTensorList *>(tensorListInput);
for (size_t i = 0; i < tensorList->Size();++i) {
(*tensorList)[i]->SetViewFormat(op::Format::FORMAT_ND);
(*tensorList)[i]->SetOriginalFormat(op::Format::FORMAT_ND);
(*tensorList)[i]->SetStorageFormat(op::Format::FORMAT_FRACTAL_NZ);
}
OP_LOGD("Set NZ_C0 format to NZ format finish.");
return ACLNN_SUCCESS;
}
static void SetStorageShapeForNZ(aclTensor* tensor) {
auto storageShape = tensor->GetStorageShape();
auto storageShapeDim = storageShape.GetDimNum();
storageShape[storageShapeDim - 1] *= B4_PER_B32;
tensor->SetStorageShape(storageShape);
}
static void UnpackB32ToB4(const aclTensorList *&tensorListB32, const std::string& tensorListType, bool skipTranspose = false)
{
if (tensorListB32->Size() <= 0) {
return;
}
DataType b32Dtype = (*tensorListB32)[0]->GetDataType();
DataType b4Dtype = DataType::DT_INT4;
if (b32Dtype == DataType::DT_FLOAT) {
b4Dtype = DataType::DT_FLOAT4_E2M1;
}
OP_LOGD("Unpack %s from %s to %s start.", tensorListType.c_str(), gmm::dTypeToString(b32Dtype).c_str(),
gmm::dTypeToString(b4Dtype).c_str());
auto tensorListB4 = const_cast<aclTensorList *>(tensorListB32);
for (size_t i = 0; i < tensorListB4->Size();++i) {
op::Shape tensorShape = (*tensorListB4)[i]->GetViewShape();
op::Strides newStride = (*tensorListB4)[i]->GetViewStrides();
auto viewShapeDim = tensorShape.GetDimNum();
bool transposeTensor = false;
auto changeDimIdx = viewShapeDim - 1;
if (!skipTranspose && viewShapeDim >= 2 && gmm::IsTransposeLastTwoDims((*tensorListB4)[i])) {
transposeTensor = true;
changeDimIdx = viewShapeDim - 2;
}
tensorShape[changeDimIdx] = tensorShape[changeDimIdx] * B4_PER_B32;
(*tensorListB4)[i]->SetViewShape(tensorShape);
(*tensorListB4)[i]->SetDataType(b4Dtype);
if (IsFormatNZ((*tensorListB4)[i])) {
SetStorageShapeForNZ((*tensorListB4)[i]);
}
if (transposeTensor) {
auto strideSize = newStride.size();
newStride[strideSize - 1] *= B4_PER_B32;
for (int64_t batchDim = strideSize - 3; batchDim >= 0; batchDim--) {
newStride[batchDim] *= B4_PER_B32;
}
(*tensorListB4)[i]->SetViewStrides(newStride);
}
OP_LOGD("Current tensorlist dim : %zu, transpose status: %d.", i, transposeTensor);
}
OP_LOGD("Unpack %s from %s to %s finished.", tensorListType.c_str(), gmm::dTypeToString(b32Dtype).c_str(),
gmm::dTypeToString(b4Dtype).c_str());
}
bool IsQuant(const DataType &xDtype, const DataType &weightDtype)
{
if (xDtype == DataType::DT_FLOAT4_E2M1 || xDtype == DataType::DT_INT4 || xDtype == DataType::DT_FLOAT4_E1M2) {
return true;
}
return ge::GetSizeByDataType(xDtype) == 1 && ge::GetSizeByDataType(weightDtype) == 1;
}
bool IsWeightQuant(const DataType &xDtype, const DataType &weightDtype)
{
return ge::GetSizeByDataType(xDtype) != ge::GetSizeByDataType(weightDtype);
}
}
namespace {
static bool CheckSpecialTranspose(const aclTensorList *tensorList) {
int64_t loopNum = tensorList->Size();
int64_t dimNum = 0;
int64_t checkedAxisNum = 0;
int64_t lastAxisSize = 0;
bool transpose = false;
for (int64_t i = 0; i < loopNum; ++i) {
lastAxisSize = 1;
transpose = gmm::IsTransposeLastTwoDims((*tensorList)[i]);
auto shape = (*tensorList)[i]->GetViewShape();
dimNum = shape.GetDimNum();
checkedAxisNum = dimNum > 1 ? 2 : 1;
for (int64_t j = 1; j <= checkedAxisNum; ++j) {
lastAxisSize *= shape.GetDim(dimNum - j);
}
transpose = transpose && (lastAxisSize != 1);
if (transpose) {
break;
}
}
return transpose;
}
}
namespace {
static aclnnStatus CheckShapeSameLengthTensorList(const aclTensorList *tensorList1, const aclTensorList *tensorList2,
const std::vector<size_t>& dimIds, const int64_t innerAxisDimId,
const std::vector<std::string>& tensorType) {
uint64_t groupNum = tensorList1->Size();
for (uint64_t i = 0; i < groupNum; i++) {
int64_t dimValue1 = (*tensorList1)[i]->GetViewShape().GetDim(dimIds[0]);
if (op::GetCurrentPlatformInfo().GetCurNpuArch() != NpuArch::DAV_3510) {
if (tensorType[2] == "true" && innerAxisDimId > -1) {
int64_t innerAxisValue = (*tensorList1)[i]->GetViewShape().GetDim(innerAxisDimId);
CHECK_COND(innerAxisValue <= MAX_INNER_AXIS, ACLNN_ERR_PARAM_INVALID, "Dim %lu value of %s[%lu] should less or equal to 65535, but now is %ld.",
dimIds[0], tensorType[0].c_str(), i, innerAxisValue);
}
}
int64_t dimValue2 = (*tensorList2)[i]->GetViewShape().GetDim(dimIds[1]);
CHECK_COND(dimValue1 == dimValue2, ACLNN_ERR_PARAM_INVALID,
"Dim %lu value of %s[%lu] should be equal with dim %lu value of %s[%lu], but now is %ld and %ld respectively.",
dimIds[0], tensorType[0].c_str(), i, dimIds[1], tensorType[1].c_str(), i, dimValue1, dimValue2);
}
return ACLNN_SUCCESS;
}
static aclnnStatus CheckShapeDiffLengthTensorList(const aclTensorList *longTensorList,
const aclTensorList *singleTensorList,
const std::vector<size_t>& dimIds,
const int64_t innerAxisdimId,
const std::vector<std::string>& tensorType) {
int64_t dimValueSingle = (*singleTensorList)[0]->GetViewShape().GetDim(dimIds[1]);
if (op::GetCurrentPlatformInfo().GetCurNpuArch() != NpuArch::DAV_3510) {
if (tensorType[2] == "true" && innerAxisdimId > -1) {
int64_t dimValue = (*singleTensorList)[0]->GetViewShape().GetDim(innerAxisdimId);
CHECK_COND(dimValue <= MAX_INNER_AXIS, ACLNN_ERR_PARAM_INVALID, "Dim %ld value of %s[0] should less or equal to 65535, but now is %ld.",
innerAxisdimId, tensorType[1].c_str(), dimValue);
}
}
uint64_t groupNum = longTensorList->Size();
for (uint64_t i = 0; i < groupNum; i++) {
int64_t dimValueLong = (*longTensorList)[i]->GetViewShape().GetDim(dimIds[0]);
CHECK_COND(dimValueLong == dimValueSingle, ACLNN_ERR_PARAM_INVALID,
"Dim %lu value of %s[%lu] %ld should be equal with dim %lu value of %s[0] %ld.",
dimIds[0], tensorType[0].c_str(), i, dimValueLong,
dimIds[1], tensorType[1].c_str(), dimValueSingle);
}
return ACLNN_SUCCESS;
}
static aclnnStatus CheckFormat(const aclTensor *tensor, const std::string& tensorType, size_t idx) {
bool isWeightTensor = tensorType == "weight";
op::Format tensorFormat = tensor->GetStorageFormat();
CHECK_COND(tensorFormat < Format::FORMAT_END, ACLNN_ERR_PARAM_INVALID, "Format of %s[%lu] %s is invalid.",
tensorType.c_str(), idx, op::ToString(tensorFormat).GetString());
if (isWeightTensor) {
CHECK_COND(!op::IsPrivateFormat(tensorFormat) || tensorFormat == Format::FORMAT_FRACTAL_NZ ||
tensorFormat == Format::FORMAT_FRACTAL_NZ_C0_16 || tensorFormat == Format::FORMAT_FRACTAL_NZ_C0_32,
ACLNN_ERR_PARAM_INVALID, "Format of %s[%lu] %s is invalid.", tensorType.c_str(), idx,
op::ToString(tensorFormat).GetString());
} else {
CHECK_COND(!op::IsPrivateFormat(tensorFormat), ACLNN_ERR_PARAM_INVALID, "Format of %s[%lu] %s is invalid.",
tensorType.c_str(), idx, op::ToString(tensorFormat).GetString());
}
return ACLNN_SUCCESS;
}
static aclnnStatus CheckShapeDiffLengthTensorListSplitAxis(const aclTensorList *longTensorList,
const aclTensorList *singleTensorList,
const size_t dimIdxLongTensorList,
const size_t dimIdxSingleTensorList,
const std::vector<std::string>& tensorType) {
int64_t dimValueSingle = (*singleTensorList)[0]->GetViewShape().GetDim(dimIdxSingleTensorList);
uint64_t groupNum = longTensorList->Size();
int64_t preOffset = 0;
for (uint64_t i = 0; i < groupNum; i++) {
int64_t dimValueLong = (*longTensorList)[i]->GetViewShape().GetDim(dimIdxLongTensorList);
preOffset += dimValueLong;
}
CHECK_COND(preOffset == dimValueSingle, ACLNN_ERR_PARAM_INVALID,
"Sum of dim %lu value of %s %ld should be equal with dim %lu value of %s[0] %ld.",
dimIdxLongTensorList, tensorType[0].c_str(), preOffset,
dimIdxSingleTensorList, tensorType[1].c_str(), dimValueSingle);
return ACLNN_SUCCESS;
}
static aclnnStatus PreCheckGroupType(int64_t splitItem, int64_t groupType) {
CHECK_COND(groupType != gmm::SPLIT_N, ACLNN_ERR_PARAM_INVALID, "Not support split n dim now, groupType can not be 1.");
CHECK_COND(groupType == gmm::SPLIT_M || groupType == gmm::SPLIT_K || groupType == gmm::NO_SPLIT, ACLNN_ERR_PARAM_INVALID,
"groupType only support -1/0/2 now, but given groupType is %ld", groupType);
if (splitItem == X_SEPARATED || splitItem == NO_SEPARATED) {
CHECK_COND(groupType != gmm::NO_SPLIT, ACLNN_ERR_PARAM_INVALID, "When splitItem is 2/3, groupType can not be -1.");
}
return ACLNN_SUCCESS;
}
static aclnnStatus CheckDimNumAndFormat(const gmm::GroupedMatmulParams &gmmParams, const aclTensorList *tensorList,
const size_t expectedDimNum, const std::string& tensorType) {
uint64_t tensorListLength = tensorList->Size();
for (size_t i = 0; i < tensorListLength; ++i) {
CHECK_COND((*tensorList)[i] != nullptr, ACLNN_ERR_PARAM_INVALID,
"%s[%lu] is null, which is not supported.", tensorType.c_str(), i);
CHECK_COND(CheckFormat((*tensorList)[i], tensorType, i) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"Invalid format.");
size_t dimNum = (*tensorList)[i]->GetViewShape().GetDimNum();
CHECK_COND(dimNum == expectedDimNum, ACLNN_ERR_PARAM_INVALID,
"%s[%lu] dim num should be %lu in this case, but now is %lu.",
tensorType.c_str(), i, expectedDimNum, dimNum);
if (tensorType == "weight") {
bool transpose = gmmParams.groupType == gmm::SPLIT_K ? CheckSpecialTranspose(tensorList) : gmm::IsTransposeLastTwoDims((*gmmParams.weight)[i]);
CHECK_COND(transpose == gmmParams.transposeWeight, ACLNN_ERR_PARAM_INVALID,
"The transpose state must be the same for each tensor in weight.");
}
}
return ACLNN_SUCCESS;
}
static aclnnStatus CheckDimNumAndGroupListNoSplitAndFormat(const gmm::GroupedMatmulParams &gmmParams) {
if (gmmParams.apiVersion != gmm::GMMApiVersion::V1) {
CHECK_COND(gmmParams.groupListOptional == nullptr, ACLNN_ERR_PARAM_INVALID,
"GroupListOptional should be nullptr when groupType is -1.");
}
size_t tensorListLength = gmmParams.x->Size();
if (gmmParams.groupListOptional != nullptr) {
CHECK_COND(gmmParams.groupListOptional->Size() == tensorListLength, ACLNN_ERR_PARAM_INVALID,
"Size of groupListOptional %lu should be equal to size of x %lu.",
gmmParams.groupListOptional->Size(), tensorListLength);
}
if (gmmParams.groupTensorOptional != nullptr) {
CHECK_COND(gmmParams.groupTensorOptional->GetViewShape().GetDim(0) == static_cast<int64_t>(tensorListLength),
ACLNN_ERR_PARAM_INVALID, "Size of groupListOptional(tensor) %ld should be equal to size of x %zu.",
gmmParams.groupTensorOptional->GetViewShape().GetDim(0), tensorListLength);
}
int64_t preGoupList = 0;
for (size_t i = 0; i < tensorListLength; ++i) {
CHECK_COND((*gmmParams.x)[i] != nullptr, ACLNN_ERR_PARAM_INVALID, "X[%lu] is null, which is not supported.", i);
CHECK_COND((*gmmParams.weight)[i] != nullptr, ACLNN_ERR_PARAM_INVALID, "Weight[%lu] is null, which is not supported.", i);
CHECK_COND(gmm::IsTransposeLastTwoDims((*gmmParams.weight)[i]) == gmmParams.transposeWeight, ACLNN_ERR_PARAM_INVALID,
"The transpose state must be the same for each tensor in weight.");
CHECK_COND((*gmmParams.y)[i] != nullptr, ACLNN_ERR_PARAM_INVALID, "Y[%lu] is null, which is not supported.", i);
CHECK_COND(CheckFormat((*gmmParams.x)[i], "x", i) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID, "Invalid format.");
CHECK_COND(CheckFormat((*gmmParams.weight)[i], "weight", i) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID, "Invalid format.");
CHECK_COND(CheckFormat((*gmmParams.y)[i], "y", i) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID, "Invalid format.");
size_t xDimNum = (*gmmParams.x)[i]->GetViewShape().GetDimNum();
size_t weightDimNum = (*gmmParams.weight)[i]->GetViewShape().GetDimNum();
size_t yDimNum = (*gmmParams.y)[i]->GetViewShape().GetDimNum();
CHECK_COND(xDimNum <= gmm::MAX_FM_DIM && xDimNum >= gmm::MIN_FM_DIM, ACLNN_ERR_PARAM_INVALID,
"X[%lu] dimNum is %lu , but only support 2-6.", i, xDimNum);
CHECK_COND(weightDimNum == SEPARATED_WEIGHT_DIM, ACLNN_ERR_PARAM_INVALID,
"Weight[%lu] dimNum is %lu , but only support 2 when weight separated.", i, weightDimNum);
CHECK_COND(xDimNum == yDimNum, ACLNN_ERR_PARAM_INVALID,
"Y[%lu] dimNum %lu should be equal with x[%lu] DimNum %lu.", i, yDimNum, i, xDimNum);
if (xDimNum > gmm::MIN_FM_DIM) {
CHECK_COND(gmmParams.groupListOptional == nullptr, ACLNN_ERR_PARAM_INVALID,
"GroupListOptional should be nullptr when x, y both separated and dim num larger than 2.");
}
if (xDimNum == gmm::MIN_FM_DIM && gmmParams.groupListOptional != nullptr) {
int64_t xMDimValue = (*gmmParams.x)[i]->GetViewShape().GetDim(0);
std::string errorMessage = i == 0UL ? "GroupListOptional[0]" :
"GroupListOptional[" + std::to_string(i) + "] - groupListOptional[" + std::to_string(i - 1UL) + "]";
CHECK_COND(xMDimValue == (*gmmParams.groupListOptional)[i] - preGoupList, ACLNN_ERR_PARAM_INVALID,
"X[%lu] dim 0 value %ld should be equal to %s %ld.",
i, xMDimValue, errorMessage.c_str(), (*gmmParams.groupListOptional)[i] - preGoupList);
preGoupList = (*gmmParams.groupListOptional)[i];
}
}
return ACLNN_SUCCESS;
}
static aclnnStatus CheckTensorListNotNull(const aclTensorList *tensorList, const std::string &tensorType)
{
uint64_t tensorListLength = tensorList->Size();
for (size_t i = 0; i < tensorListLength; ++i) {
CHECK_COND((*tensorList)[i] != nullptr, ACLNN_ERR_PARAM_NULLPTR, "%s[%lu] is null, which is not supported.",
tensorType.c_str(), i);
}
return ACLNN_SUCCESS;
}
static aclnnStatus CheckNotNull(const aclTensorList *x, const aclTensorList *weight, const aclTensorList *y)
{
CHECK_COND(x != nullptr, ACLNN_ERR_PARAM_NULLPTR, "X must not be nullptr.");
CHECK_COND(x->Size() != 0, ACLNN_ERR_PARAM_INVALID, "X must not be empty tensorlist.");
CHECK_COND(CheckTensorListNotNull(x, "X") == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID, "X must not be nullptr.");
CHECK_COND(weight != nullptr, ACLNN_ERR_PARAM_NULLPTR, "Weight must not be nullptr.");
CHECK_COND(weight->Size() != 0, ACLNN_ERR_PARAM_INVALID, "Weight must not be empty tensorlist.");
CHECK_COND(CheckTensorListNotNull(weight, "Weight") == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"Weight must not be nullptr.");
CHECK_COND(y != nullptr, ACLNN_ERR_PARAM_NULLPTR, "Y must not be nullptr.");
CHECK_COND(y->Size() != 0, ACLNN_ERR_PARAM_INVALID, "Y must not be empty tensorlist.");
CHECK_COND(CheckTensorListNotNull(y, "Y") == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID, "Y must not be nullptr.");
return ACLNN_SUCCESS;
}
static aclnnStatus CheckGroupListCommonIntArray(const gmm::GroupedMatmulParams &gmmParams, const bool isRequiredGroupList,
const size_t groupNum, int64_t &groupListLastValue) {
CHECK_COND(gmmParams.groupListOptional != nullptr || !isRequiredGroupList, ACLNN_ERR_PARAM_NULLPTR,
"GroupListOptional required in this case, but get nullptr.");
if (gmmParams.groupListOptional != nullptr) {
uint64_t groupListSize = gmmParams.groupListOptional->Size();
CHECK_COND(groupListSize <= MAX_GROUP_LIST_SIZE_ARRAY, ACLNN_ERR_PARAM_INVALID,
"When groupList type is int array, size of groupList %lu should be less than or equal to %ld.",
groupListSize, MAX_GROUP_LIST_SIZE_ARRAY);
int64_t preGoupList = 0;
for (size_t i = 0; i < groupListSize; i++) {
CHECK_COND((*gmmParams.groupListOptional)[i] >= preGoupList, ACLNN_ERR_PARAM_INVALID,
"GroupListOptional should be non-negative and incremental.");
preGoupList = (*gmmParams.groupListOptional)[i];
}
CHECK_COND((groupListSize == groupNum && groupNum > 1) || groupNum == 1, ACLNN_ERR_PARAM_INVALID,
"When groupList is not null, size of groupList %lu should be equal to groupNum %lu.",
groupListSize, groupNum);
groupListLastValue = preGoupList;
}
return ACLNN_SUCCESS;
}
static aclnnStatus CheckGroupListCommonTensor(const gmm::GroupedMatmulParams &gmmParams, const bool isRequiredGroupList,
const size_t groupNum) {
CHECK_COND(!(gmmParams.groupTensorOptional == nullptr && isRequiredGroupList), ACLNN_ERR_PARAM_INVALID,
"GroupListOptional(tensor) is required in this case, but get nullptr.");
if (gmmParams.groupTensorOptional != nullptr) {
int64_t groupListSize = gmmParams.groupTensorOptional->GetViewShape().GetDim(0);
size_t groupListDimNum = gmmParams.groupTensorOptional->GetViewShape().GetDimNum();
if (gmmParams.groupListType == gmm::GROUP_LIST_SPARSE_M) {
CHECK_COND(groupListDimNum == GROUP_LIST_SPARSE_DIM_NUM,
ACLNN_ERR_PARAM_INVALID,
"When groupList type is 2, dim num of groupList should be 2, but now is %zu.", groupListDimNum);
CHECK_COND(gmmParams.groupTensorOptional->GetViewShape().GetDim(1) == GROUP_LIST_SPARSE_DIM_NUM,
ACLNN_ERR_PARAM_INVALID,
"When groupList type is 2, the 2nd dim of groupList should be 2, but now is %ld.",
gmmParams.groupTensorOptional->GetViewShape().GetDim(1));
}
CHECK_COND(groupListSize <= MAX_GROUP_LIST_SIZE_TENSOR, ACLNN_ERR_PARAM_INVALID,
"When groupList type is tenosr, size of groupList %ld should be less than or equal to %ld.",
groupListSize, MAX_GROUP_LIST_SIZE_TENSOR);
CHECK_COND((groupListSize == static_cast<int64_t>(groupNum) && groupNum > 1) || groupNum == 1, ACLNN_ERR_PARAM_INVALID,
"When groupList is not null, size of groupList(tensor) %ld should be equal to %s %lu with groupType %ld.",
groupListSize, gmmParams.groupType == gmm::SPLIT_M ? "weight dim 0" : "y dim 0", groupNum,
gmmParams.groupType);
CHECK_COND(gmmParams.groupTensorOptional->GetDataType() == DataType::DT_INT64, ACLNN_ERR_PARAM_INVALID,
"Invalid dtype: Only int64 is supported for groupList.");
}
return ACLNN_SUCCESS;
}
static aclnnStatus CheckGroupListSplitK(const gmm::GroupedMatmulParams &gmmParams, const bool isRequiredGroupList,
const bool xSeparated, const bool weightSeparated, const size_t groupNum) {
int64_t groupListLastValue = 0;
if (gmmParams.apiVersion == gmm::GMMApiVersion::WeightNz || gmmParams.apiVersion == gmm::GMMApiVersion::V5
|| gmmParams.apiVersion == gmm::GMMApiVersion::V4 || gmmParams.apiVersion == gmm::GMMApiVersion::V3) {
CHECK_COND(CheckGroupListCommonTensor(gmmParams, isRequiredGroupList, groupNum) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "CheckGroupListCommonTensor failed in groupType 2.");
return ACLNN_SUCCESS;
}
CHECK_COND(
CheckGroupListCommonIntArray(gmmParams, isRequiredGroupList, groupNum, groupListLastValue) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "CheckGroupListCommonIntArray failed.");
if (gmmParams.groupListOptional != nullptr) {
if (xSeparated) {
int64_t preOffset = 0;
for (size_t i = 0; i < groupNum; i++) {
int64_t xKDimValue = (*gmmParams.x)[i]->GetViewShape().GetDim(1);
std::string errorMessage = i == 0UL ? "GroupListOptional[0]" :
"GroupListOptional[" + std::to_string(i) + "] - groupListOptional[" + std::to_string(i - 1UL) + "]";
CHECK_COND(
xKDimValue == (*gmmParams.groupListOptional)[i] - preOffset, ACLNN_ERR_PARAM_INVALID, "X[%lu] dim 1 value %ld should be equal to %s %ld.",
i, xKDimValue, errorMessage.c_str(), (*gmmParams.groupListOptional)[i] - preOffset);
preOffset = (*gmmParams.groupListOptional)[i];
}
} else if (weightSeparated) {
int64_t preOffset = 0;
for (size_t i = 0; i < groupNum; i++) {
int64_t weightKDimValue = (*gmmParams.weight)[i]->GetViewShape().GetDim(0);
std::string errorMessage = i == 0UL ? "GroupListOptional[0]" :
"GroupListOptional[" + std::to_string(i) + "] - groupListOptional[" + std::to_string(i - 1UL) + "]";
CHECK_COND(
weightKDimValue == (*gmmParams.groupListOptional)[i] - preOffset, ACLNN_ERR_PARAM_INVALID, "Weight[%lu] dim 0 %ld value should be equal to %s %ld.",
i, weightKDimValue, errorMessage.c_str(), (*gmmParams.groupListOptional)[i] - preOffset);
preOffset = (*gmmParams.groupListOptional)[i];
}
} else {
CHECK_COND(
(*gmmParams.x)[0]->GetViewShape().GetDim(1) == groupListLastValue, ACLNN_ERR_PARAM_INVALID, "When splited axis is K, the last value of group list(%ld) must equal with x shape[1] (%ld).",
groupListLastValue, (*gmmParams.x)[0]->GetViewShape().GetDim(1));
}
}
return ACLNN_SUCCESS;
}
static aclnnStatus CheckGroupListSplitM(const gmm::GroupedMatmulParams &gmmParams, const bool isRequiredGroupList,
const bool xSeparated, const bool ySeparated, const size_t groupNum) {
int64_t groupListLastValue = 0;
if (gmmParams.apiVersion == gmm::GMMApiVersion::WeightNz || gmmParams.apiVersion == gmm::GMMApiVersion::V5
|| gmmParams.apiVersion == gmm::GMMApiVersion::V4 || gmmParams.apiVersion == gmm::GMMApiVersion::V3) {
CHECK_COND(CheckGroupListCommonTensor(gmmParams, isRequiredGroupList, groupNum) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "CheckGroupListCommonTensor failed in groupType 0.");
return ACLNN_SUCCESS;
}
CHECK_COND(
CheckGroupListCommonIntArray(gmmParams, isRequiredGroupList, groupNum, groupListLastValue) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "CheckGroupListCommonIntArray failed!");
if (gmmParams.groupListOptional != nullptr) {
if (xSeparated) {
int64_t preGoupList = 0;
for (size_t i = 0; i < groupNum; i++) {
int64_t xMDimValue = (*gmmParams.x)[i]->GetViewShape().GetDim(0);
std::string errorMessage = i == 0UL ? "GroupListOptional[0]" :
"GroupListOptional[" + std::to_string(i) + "] - groupListOptional[" + std::to_string(i - 1UL) + "]";
CHECK_COND(
xMDimValue == (*gmmParams.groupListOptional)[i] - preGoupList, ACLNN_ERR_PARAM_INVALID, "X[%lu] dim 0 value %ld should be equal to %s %ld.",
i, xMDimValue, errorMessage.c_str(), (*gmmParams.groupListOptional)[i] - preGoupList);
preGoupList = (*gmmParams.groupListOptional)[i];
}
} else if (ySeparated) {
int64_t preGoupList = 0;
for (size_t i = 0; i < groupNum; i++) {
int64_t yMDimValue = (*gmmParams.y)[i]->GetViewShape().GetDim(0);
std::string errorMessage = i == 0UL ? "GroupListOptional[0]" :
"GroupListOptional[" + std::to_string(i) + "] - groupListOptional[" + std::to_string(i - 1UL) + "]";
CHECK_COND(
yMDimValue == (*gmmParams.groupListOptional)[i] - preGoupList, ACLNN_ERR_PARAM_INVALID, "Y[%lu] dim 0 value %ld should be equal to %s %ld.",
i, yMDimValue, errorMessage.c_str(), (*gmmParams.groupListOptional)[i] - preGoupList);
preGoupList = (*gmmParams.groupListOptional)[i];
}
} else {
CHECK_COND(
(*gmmParams.x)[0]->GetViewShape().GetDim(0) == groupListLastValue, ACLNN_ERR_PARAM_INVALID, "When splited axis is M, the last value of group list(%ld) must equal with x shape[0] (%ld).",
groupListLastValue, (*gmmParams.x)[0]->GetViewShape().GetDim(0));
}
}
return ACLNN_SUCCESS;
}
static uint64_t GetGroupSize(const gmm::GroupedMatmulParams &gmmParams) {
if (gmmParams.x->Size() > 1) {
return gmmParams.x->Size();
}
if (gmmParams.weight->Size() > 1) {
return gmmParams.weight->Size();
}
if (gmmParams.y->Size() > 1) {
return gmmParams.y->Size();
}
if (gmmParams.groupListOptional != nullptr) {
return gmmParams.groupListOptional->Size();
}
if (gmmParams.groupTensorOptional != nullptr) {
return gmmParams.groupTensorOptional->GetViewShape().GetDim(0);
}
return 1;
}
static aclnnStatus CheckDimNumAndPerGroupNum(bool isAntiquantInt4, std::tuple<size_t, size_t> tensorDimNums,
const gert::Shape& tensorShape, const std::string& tensorType, int64_t weightKDimValue) {
size_t tensorDimNum = std::get<0>(tensorDimNums);
size_t expectedDimNum = std::get<1>(tensorDimNums);
if (isAntiquantInt4) {
if (tensorDimNum == expectedDimNum) {
int64_t perGroupNum = tensorShape.GetDim(tensorDimNum - 2);
CHECK_COND(perGroupNum > 0 && weightKDimValue % perGroupNum == 0, ACLNN_ERR_PARAM_INVALID,
"PerGroupNum must be larger than 0, and can evenly divided by K[%ld] in A16W4-pergroup case,"
" but now perGroupNum is %ld.", weightKDimValue, perGroupNum);
} else {
CHECK_COND(tensorDimNum == expectedDimNum - 1, ACLNN_ERR_PARAM_INVALID,
"%s Dim must be %zu in A16W4-perchannel case, but now is %zu.",
tensorType.c_str(), expectedDimNum - 1, tensorDimNum);
}
} else {
CHECK_COND(tensorDimNum == expectedDimNum - 1, ACLNN_ERR_PARAM_INVALID,
"%s Dim must be %zu, but now is %zu.",
tensorType.c_str(), expectedDimNum - 1, tensorDimNum);
}
return ACLNN_SUCCESS;
}
static aclnnStatus CheckOptionalTensorList(const gmm::GroupedMatmulParams &gmmParams, const aclTensorList *tensorList,
const std::string& tensorType) {
uint64_t numTotal = GetGroupSize(gmmParams);
uint64_t tensorSize = tensorList->Size();
uint64_t weightGroupedSize = gmmParams.weight->Size();
auto w0Shape = (*gmmParams.weight)[0]->GetViewShape();
uint64_t weightNDimIdx = w0Shape.GetDimNum() - 1;
int64_t weightKDimValue = w0Shape.GetDim(w0Shape.GetDimNum() - 2);
CHECK_COND(tensorSize == weightGroupedSize, ACLNN_ERR_PARAM_INVALID, "%s size[%lu] must be "
"equal with weight size[%lu].", tensorType.c_str(), tensorSize, weightGroupedSize);
DataType w0Dtype = (*gmmParams.weight)[0]->GetDataType();
bool isAntiquantInt4 = (w0Dtype == DataType::DT_INT4 && tensorType.find("antiquant") != std::string::npos);
if (gmmParams.isSingleWeight) {
CHECK_COND((*tensorList)[0] != nullptr, ACLNN_ERR_PARAM_INVALID,
"%s[0] must not be nullptr, but now is nullptr.", tensorType.c_str());
auto tensor0Shape = (*tensorList)[0]->GetViewShape();
size_t tensorDimNum = tensor0Shape.GetDimNum();
CHECK_COND(CheckDimNumAndPerGroupNum(isAntiquantInt4, {tensorDimNum, 3}, tensor0Shape, tensorType, weightKDimValue)
== ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID, "CheckDimNumAndPerGroupNum failed.");
uint64_t batchSize = tensor0Shape.GetDim(0);
CHECK_COND(batchSize == numTotal, ACLNN_ERR_PARAM_INVALID, "%s batch size[%lu] should be euqal "
"with groupList length[%lu].", tensorType.c_str(), batchSize, numTotal);
int64_t weightNDimValue = w0Shape.GetDim(weightNDimIdx);
int64_t tensorNDimIdx = tensorDimNum == 4 ? tensorDimNum - 2 : tensorDimNum - 1;
int64_t tensorNDimValue = tensor0Shape.GetDim(tensorNDimIdx);
CHECK_COND(tensorNDimValue == weightNDimValue, ACLNN_ERR_PARAM_INVALID,
"NDim[%ld] of %s should be equal with NDim[%ld] of weight.",
tensorNDimValue, tensorType.c_str(), weightNDimValue);
} else {
for (uint64_t i = 0; i < numTotal; i++) {
CHECK_COND((*tensorList)[i] != nullptr, ACLNN_ERR_PARAM_INVALID,
"%s[%lu] must not be nullptr, but now is nullptr.", tensorType.c_str(), i);
auto tensorShape = (*tensorList)[i]->GetViewShape();
size_t tensorDimNum = tensorShape.GetDimNum();
auto wShape = (*gmmParams.weight)[i]->GetViewShape();
CHECK_COND(CheckDimNumAndPerGroupNum(isAntiquantInt4, {tensorDimNum, 2}, tensorShape, tensorType,
wShape.GetDim(0)) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID, "CheckDimNumAndPerGroupNum failed.");
int64_t weightNDimValue = wShape.GetDim(weightNDimIdx);
int64_t tensorNDimIdx = tensorDimNum == 4 ? tensorDimNum - 2 : tensorDimNum - 1;
int64_t tensorNDimValue = tensorShape.GetDim(tensorNDimIdx);
CHECK_COND(tensorNDimValue == weightNDimValue, ACLNN_ERR_PARAM_INVALID,
"NDim[%ld] of %s[%lu] should be equal with NDim[%ld] of weight[%lu].",
tensorNDimValue, tensorType.c_str(), i, weightNDimValue, i);
}
}
return ACLNN_SUCCESS;
}
static aclnnStatus CheckPerTokenScale(const gmm::GroupedMatmulParams &gmmParams) {
uint64_t perTokenScaleSize = gmmParams.perTokenScaleOptional->Size();
uint64_t xGroupedSize = gmmParams.x->Size();
uint64_t weightGroupedSize = gmmParams.weight->Size();
uint64_t yGroupedSize = gmmParams.y->Size();
uint64_t xMDimIdx = 0;
if (xGroupedSize == 1UL && yGroupedSize == 1UL) {
CHECK_COND(perTokenScaleSize == xGroupedSize && perTokenScaleSize == 1, ACLNN_ERR_PARAM_INVALID,
"PerTokenScaleOptional size[%zu] must be 1 and equal with x size[%zu].",
perTokenScaleSize, xGroupedSize);
CHECK_COND((*gmmParams.perTokenScaleOptional)[0] != nullptr, ACLNN_ERR_PARAM_INVALID,
"PerTokenScaleOptional[0] must not be nullptr, but now is nullptr.");
size_t tensorDimNum = (*gmmParams.perTokenScaleOptional)[0]->GetViewShape().GetDimNum();
CHECK_COND(tensorDimNum == 1, ACLNN_ERR_PARAM_INVALID,
"PerTokenScaleOptional dim num must be 1 when x is single tensor, but now is %zu.", tensorDimNum);
int64_t xMDimValue = (*gmmParams.x)[0]->GetViewShape().GetDim(xMDimIdx);
int64_t tensorMDimValue = (*gmmParams.perTokenScaleOptional)[0]->GetViewShape().GetDim(tensorDimNum - 1UL);
CHECK_COND(tensorMDimValue == xMDimValue, ACLNN_ERR_PARAM_INVALID,
"MDim[%ld] of perTokenScaleOptional should be equal with MDim[%ld] of x.",
tensorMDimValue, xMDimValue);
} else {
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "Per-token quant case is only supported "
"when x, weight and y are all single tensor, but now x size is %zu, weight size is %zu, y size is %zu",
xGroupedSize, weightGroupedSize, yGroupedSize);
return ACLNN_ERR_PARAM_INVALID;
}
return ACLNN_SUCCESS;
}
static aclnnStatus CheckTensorListDataType(const aclTensorList *tensorList, const DataType dtype) {
for (size_t i = 0; i < tensorList->Size(); i++) {
const aclTensor* tensor = (*tensorList)[i];
OP_CHECK_NULL(tensor, continue);
OP_CHECK_DTYPE_NOT_MATCH(tensor, dtype, return ACLNN_ERR_PARAM_INVALID);
}
return ACLNN_SUCCESS;
}
static aclnnStatus CheckMatmulDataType(const gmm::GroupedMatmulParams &gmmParams, const DataType xDtype,
const DataType weightDtype, const DataType yDtype, const DataType biasDtype) {
CHECK_COND(CheckTensorListDataType(gmmParams.x, xDtype) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"GMM: x dtype does not match with required dtype[%s].", gmm::dTypeToString(xDtype).c_str());
CHECK_COND(CheckTensorListDataType(gmmParams.weight, weightDtype) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"GMM: weight dtype does not match with required dtype[%s].", gmm::dTypeToString(weightDtype).c_str());
CHECK_COND(CheckTensorListDataType(gmmParams.y, yDtype) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"GMM: y dtype does not match with required dtype[%s].", gmm::dTypeToString(yDtype).c_str());
if (gmmParams.biasOptional != nullptr) {
CHECK_COND(CheckTensorListDataType(gmmParams.biasOptional, biasDtype) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"GMM: bias dtype does not match with required dtype[%s].", gmm::dTypeToString(biasDtype).c_str());
}
return ACLNN_SUCCESS;
}
static aclnnStatus CheckNoQuantUnusedParams(const gmm::GroupedMatmulParams &gmmParams) {
CHECK_COND(gmmParams.scaleOptional == nullptr, ACLNN_ERR_PARAM_INVALID,
"ScaleOptional must be nullptr in no quant case.");
CHECK_COND(gmmParams.offsetOptional == nullptr, ACLNN_ERR_PARAM_INVALID,
"OffsetOptional must be nullptr in no quant case.");
CHECK_COND(gmmParams.antiquantScaleOptional == nullptr, ACLNN_ERR_PARAM_INVALID,
"AntiquantScaleOptional must be nullptr in no quant case.");
CHECK_COND(gmmParams.antiquantOffsetOptional == nullptr, ACLNN_ERR_PARAM_INVALID,
"AntiquantOffsetOptional must be nullptr in no quant case.");
CHECK_COND(gmmParams.perTokenScaleOptional == nullptr, ACLNN_ERR_PARAM_INVALID,
"PerTokenScaleOptional must be nullptr in no quant case.");
return ACLNN_SUCCESS;
}
static aclnnStatus CheckNonQuantMatmulDataType(const gmm::GroupedMatmulParams &gmmParams, const DataType weightDtype) {
DataType biasDtype = gmmParams.xDtype == DataType::DT_BF16 ? DataType::DT_FLOAT : gmmParams.xDtype;
if (op::GetCurrentPlatformInfo().GetCurNpuArch() == NpuArch::DAV_3510) {
if (gmmParams.biasOptional != nullptr) {
biasDtype = (*gmmParams.biasOptional)[0]->GetDataType();
CHECK_COND(biasDtype == gmmParams.xDtype || biasDtype == DataType::DT_FLOAT, ACLNN_ERR_PARAM_INVALID,
"Non quant case biasDtype should same as xDtype or float32");
CHECK_COND(CheckNoQuantUnusedParams(gmmParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"Invalid unused params");
}
}
CHECK_RET(CheckMatmulDataType(gmmParams, gmmParams.xDtype, weightDtype, gmmParams.xDtype, biasDtype) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID);
return ACLNN_SUCCESS;
}
static aclnnStatus IsGmmQuantEmpty(const gmm::GroupedMatmulParams &gmmParams) {
CHECK_RET(gmmParams.scaleOptional == nullptr, ACLNN_ERR_PARAM_INVALID);
CHECK_RET(gmmParams.offsetOptional == nullptr, ACLNN_ERR_PARAM_INVALID);
CHECK_RET(gmmParams.perTokenScaleOptional == nullptr, ACLNN_ERR_PARAM_INVALID);
return ACLNN_SUCCESS;
}
static aclnnStatus IsGmmAntiQuantEmpty(const gmm::GroupedMatmulParams &gmmParams) {
CHECK_RET(gmmParams.antiquantScaleOptional == nullptr, ACLNN_ERR_PARAM_INVALID);
CHECK_RET(gmmParams.antiquantOffsetOptional == nullptr, ACLNN_ERR_PARAM_INVALID);
return ACLNN_SUCCESS;
}
static aclnnStatus CheckNonQuant(const gmm::GroupedMatmulParams &gmmParams) {
CHECK_COND(IsGmmQuantEmpty(gmmParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"Detected nonquant, but quant inputs are not empty!");
CHECK_COND(IsGmmAntiQuantEmpty(gmmParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"Detected nonquant, but antiquant inputs are not empty!");
return ACLNN_SUCCESS;
}
static aclnnStatus CheckQuantParamsDtype(const gmm::GroupedMatmulParams &gmmParams, bool isPerTokenQuant) {
DataType yDtype = (*gmmParams.y)[0]->GetDataType();
for (size_t i = 0; i < gmmParams.scaleOptional->Size(); i++) {
DataType scaleDtype = (*gmmParams.scaleOptional)[i]->GetDataType();
if (isPerTokenQuant) {
bool isOutputBF16 = scaleDtype == DataType::DT_BF16 && yDtype == DataType::DT_BF16;
bool isOutputFloat16 = scaleDtype == DataType::DT_FLOAT && yDtype == DataType::DT_FLOAT16;
CHECK_COND(isOutputBF16 || isOutputFloat16, ACLNN_ERR_PARAM_INVALID,
"Per-token quant case only supports scale data type bfloat16 with output data type bfloat16,"
"or scale with data type float32 when output is float16,"
" but now scale[%zu] has data type %s and output has data type %s!",
i, gmm::dTypeToString(scaleDtype).c_str(), gmm::dTypeToString(yDtype).c_str());
} else {
bool isOutputInt8 = (scaleDtype == DataType::DT_INT64 || scaleDtype == DataType::DT_UINT64) &&
yDtype == DataType::DT_INT8;
bool isOutputBF16 = scaleDtype == DataType::DT_BF16 && yDtype == DataType::DT_BF16;
bool isOutputFP16 = scaleDtype == DataType::DT_FLOAT && yDtype == DataType::DT_FLOAT16;
CHECK_COND(isOutputInt8 || isOutputBF16 || isOutputFP16, ACLNN_ERR_PARAM_INVALID,
"Per-channel quant case only supports scale with data type int64/uint64 when output is int8, "
"or data type bfloat16 when output is bfloat16, "
"or data type float32 when output is float16, "
"but scale[%zu] has data type %s and output has data type %s!",
i, gmm::dTypeToString(scaleDtype).c_str(), gmm::dTypeToString(yDtype).c_str());
}
}
if (isPerTokenQuant) {
for (size_t i = 0; i < gmmParams.perTokenScaleOptional->Size(); i++) {
DataType perTokenScaleDtype = (*gmmParams.perTokenScaleOptional)[i]->GetDataType();
CHECK_COND(perTokenScaleDtype == DataType::DT_FLOAT, ACLNN_ERR_PARAM_INVALID,
"Per-token quant case only support perTokenScale with data type float32, "
"but perTokenScale[%zu] has data type %s!", i, gmm::dTypeToString(perTokenScaleDtype).c_str());
}
}
return ACLNN_SUCCESS;
}
static aclnnStatus CheckGroupedMatmulQuant(const gmm::GroupedMatmulParams &gmmParams) {
bool is310P = GetCurrentPlatformInfo().GetSocVersion() == SocVersion::ASCEND310P;
CHECK_COND(!is310P, ACLNN_ERR_PARAM_INVALID,
"GMM: quant cases do not support on Ascend310P.");
CHECK_COND(gmmParams.groupType != gmm::SPLIT_K, ACLNN_ERR_PARAM_INVALID,
"GMM: quant cases do not support splited axis is K.");
CHECK_COND(gmmParams.offsetOptional == nullptr, ACLNN_ERR_PARAM_INVALID,
"GMM: offset must be nullptr in quant, but now is not nullptr.");
CHECK_COND(gmmParams.scaleOptional != nullptr, ACLNN_ERR_PARAM_INVALID,
"GMM: scale must not be nullptr in quant, but now is nullptr.");
CHECK_COND(CheckOptionalTensorList(gmmParams, gmmParams.scaleOptional, "scale") == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "Invalid scale.");
bool isPerTokenQuant = gmmParams.perTokenScaleOptional != nullptr;
CHECK_COND(CheckQuantParamsDtype(gmmParams, isPerTokenQuant) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"Check quant params data type failed!");
if (isPerTokenQuant) {
CHECK_COND(CheckPerTokenScale(gmmParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"Check perTokenScale failed!");
}
CHECK_COND(IsGmmAntiQuantEmpty(gmmParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"Detected quant, but antiquant inputs are not empty!");
return ACLNN_SUCCESS;
}
static int64_t GetPergroupSize(const gmm::GroupedMatmulParams &gmmParams, size_t w0DimNum, const gert::Shape& wShape, const gert::Shape& shape) {
int64_t pergroupSize = 0;
size_t shapeDimNum = shape.GetDimNum();
if (gmmParams.isSingleWeight && w0DimNum > SEPARATED_WEIGHT_DIM) {
if (shapeDimNum > SEPARATED_WEIGHT_DIM) {
int64_t k = wShape.GetDim(1);
pergroupSize = k / shape.GetDim(shapeDimNum - 2);
}
} else {
if (shapeDimNum > 1UL) {
int64_t k = wShape.GetDim(0);
pergroupSize = k / shape.GetDim(shapeDimNum - 2);
}
}
return pergroupSize;
}
static aclnnStatus CheckGroupedMatmulAntiQuant(const gmm::GroupedMatmulParams &gmmParams) {
DataType weightDtype = (*gmmParams.weight)[0]->GetDataType();
DataType xDtype = gmmParams.xDtype;
string sXtype = gmm::DTYPE_STRING.at(xDtype);
string sWtype = gmm::DTYPE_STRING.at(weightDtype);
CHECK_COND(GetCurrentPlatformInfo().GetSocVersion() != SocVersion::ASCEND310P, ACLNN_ERR_PARAM_INVALID,
"GMM Xtype:%s Wtype:%s: antiquant cases do not support on Ascend310P.", sXtype.c_str(), sWtype.c_str());
CHECK_COND(gmmParams.groupType != gmm::SPLIT_K, ACLNN_ERR_PARAM_INVALID,
"GMM Xtype:%s Wtype:%s: antiquant cases do not support splited axis is k.", sXtype.c_str(), sWtype.c_str());
CHECK_COND(gmmParams.antiquantScaleOptional != nullptr, ACLNN_ERR_PARAM_INVALID,
"GMM Xtype:%s Wtype:%s: antiquantScale must not be nullptr in antiquant, but now is nullptr.", sXtype.c_str(), sWtype.c_str());
DataType w0Dtype = (*gmmParams.weight)[0]->GetDataType();
bool isAntiquantInt4 = w0Dtype == DataType::DT_INT4;
CHECK_COND(((isAntiquantInt4 && gmmParams.isSingleWeight) || gmmParams.antiquantOffsetOptional != nullptr),
ACLNN_ERR_PARAM_INVALID,
"GMM Xtype:%s Wtype:%s: antiquantOffset must not be nullptr in antiquant, but now is nullptr.",
sXtype.c_str(), sWtype.c_str());
CHECK_COND(CheckOptionalTensorList(gmmParams, gmmParams.antiquantScaleOptional, "antiquantScale") == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "Invalid antiquantScale");
if (gmmParams.antiquantOffsetOptional != nullptr) {
CHECK_COND(
CheckOptionalTensorList(gmmParams, gmmParams.antiquantOffsetOptional, "antiquantOffset") == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "Invalid antiquantOffset");
}
if (isAntiquantInt4) {
auto antiquantScale0Shape = (*gmmParams.antiquantScaleOptional)[0]->GetViewShape();
size_t antiquantScale0DimNum = antiquantScale0Shape.GetDimNum();
auto w0Shape = (*gmmParams.weight)[0]->GetViewShape();
size_t w0DimNum = w0Shape.GetDimNum();
int64_t pergroupSize = GetPergroupSize(gmmParams, w0DimNum, w0Shape, antiquantScale0Shape);
CHECK_COND(!gmmParams.transposeWeight || pergroupSize % 2 == 0, ACLNN_ERR_PARAM_INVALID,
"PergroupSize should be even when weight is transposed in A16W4-pergroup case, but now is %ld", pergroupSize);
for (size_t i = 0; i < gmmParams.antiquantScaleOptional->Size(); ++i) {
auto antiquantScaleShape = (*gmmParams.antiquantScaleOptional)[i]->GetViewShape();
size_t antiquantScaleDimNum = antiquantScaleShape.GetDimNum();
CHECK_COND(antiquantScaleDimNum == antiquantScale0DimNum,
ACLNN_ERR_PARAM_INVALID, "AntiquantScale[%zu]'s dim num[%zu] is not equal with first tensor's dim"
" num[%zu]",
i, antiquantScaleDimNum, antiquantScale0DimNum);
auto wShape = (*gmmParams.weight)[i]->GetViewShape();
int64_t pergroupSizeOfScale = GetPergroupSize(gmmParams, w0DimNum, wShape, antiquantScaleShape);
CHECK_COND(pergroupSizeOfScale == pergroupSize, ACLNN_ERR_PARAM_INVALID,
"AntiquantScale[%zu]'s pergroup size[%ld] is not the required value[%ld]",
i, pergroupSizeOfScale, pergroupSize);
if (gmmParams.antiquantOffsetOptional != nullptr) {
auto antiquantOffsetShape = (*gmmParams.antiquantOffsetOptional)[i]->GetViewShape();
size_t antiquantOffsetDimNum = antiquantOffsetShape.GetDimNum();
CHECK_COND(antiquantScale0DimNum == antiquantOffsetDimNum,
ACLNN_ERR_PARAM_INVALID,
"AntiquantOffset[%zu]'s dim num[%zu] is not equal with antiquantScale[0]'s dim num[%zu]",
i, antiquantOffsetDimNum, antiquantScale0DimNum);
int64_t pergroupSizeOfOffset = GetPergroupSize(gmmParams, w0DimNum, wShape, antiquantOffsetShape);
CHECK_COND(pergroupSizeOfOffset == pergroupSize, ACLNN_ERR_PARAM_INVALID,
"AntiquantOffset[%zu]'s pergroup size[%ld] is not the required value[%ld]",
i, pergroupSizeOfOffset, pergroupSize);
}
}
}
CHECK_COND(CheckTensorListDataType(gmmParams.antiquantScaleOptional, gmmParams.xDtype) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "GMM: antiquantScale dtype does not match with x dtype[%s].",
gmm::dTypeToString(gmmParams.xDtype).c_str());
if (gmmParams.antiquantOffsetOptional != nullptr) {
CHECK_COND(CheckTensorListDataType(gmmParams.antiquantOffsetOptional, gmmParams.xDtype) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "GMM: antiquantOffset dtype does not match with x dtype[%s].",
gmm::dTypeToString(gmmParams.xDtype).c_str());
}
CHECK_COND(IsGmmQuantEmpty(gmmParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"Detected antiquant, but quant inputs are not empty!");
return ACLNN_SUCCESS;
}
static aclnnStatus Check310PlatformForFunction(const gmm::GroupedMatmulParams &gmmParams, const DataType &weightDtype, bool isNoActivation) {
if (GetCurrentPlatformInfo().GetSocVersion() == SocVersion::ASCEND310P) {
bool isAllInputFP16 = gmmParams.xDtype == DataType::DT_FLOAT16 && weightDtype == DataType::DT_FLOAT16;
if (gmmParams.biasOptional != nullptr) {
isAllInputFP16 = isAllInputFP16 && (*gmmParams.biasOptional)[0]->GetDataType() == DataType::DT_FLOAT16;
}
CHECK_COND(isAllInputFP16, ACLNN_ERR_PARAM_INVALID, "Only float16 is supported on Ascend310P platforms.");
CHECK_COND(isNoActivation, ACLNN_ERR_PARAM_INVALID, "Activation is not supported on Ascend310P platforms.");
}
return ACLNN_SUCCESS;
}
static aclnnStatus CheckFunctionQuantParams(const gmm::GroupedMatmulParams &gmmParams) {
DataType yDtypeOrg = (*gmmParams.y)[0]->GetDataType();
for (size_t i = 0; i < gmmParams.y->Size(); i++) {
const aclTensor* yTensor = (*gmmParams.y)[i];
OP_CHECK_NULL(yTensor, continue);
DataType yDtype = yTensor->GetDataType();
CHECK_COND(yDtype == yDtypeOrg, ACLNN_ERR_PARAM_INVALID,
"Output tensorlist has different data type, y[0] data type is %s, and y[%zu] data type id %s.",
gmm::dTypeToString(yDtypeOrg).c_str(), i, gmm::dTypeToString(yDtype).c_str());
if (!(yDtype == DataType::DT_INT8 || yDtype == DataType::DT_BF16 || yDtype == DataType::DT_FLOAT16 || yDtype == DataType::DT_INT32)) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "Expect y dtype is int8, int32, float16 or bfloat16 in quant case, "
"but now y[%zu] dtype is %s", i, gmm::dTypeToString(yDtype).c_str());
return ACLNN_ERR_PARAM_INVALID;
}
}
if (gmmParams.biasOptional != nullptr) {
CHECK_COND(CheckTensorListDataType(gmmParams.biasOptional, DataType::DT_INT32) == ACLNN_SUCCESS ||
CheckTensorListDataType(gmmParams.biasOptional, DataType::DT_BF16) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "GMM: bias dtype does not match with required dtype int32 or bfloat16.");
}
if (yDtypeOrg == DataType::DT_INT32) {
return ACLNN_SUCCESS;
}
CHECK_COND(CheckGroupedMatmulQuant(gmmParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"CheckGroupedMatmulQuant failed.");
return ACLNN_SUCCESS;
}
static aclnnStatus CheckA8W4AsymQuantParamsRelationship(const gmm::GroupedMatmulParams &gmmParams) {
const op::Shape &xShape = (*gmmParams.x)[0]->GetViewShape();
const op::Shape &weightShape = (*gmmParams.weight)[0]->GetViewShape();
const op::Shape &biasShape = (*gmmParams.biasOptional)[0]->GetViewShape();
const op::Shape &scaleShape = (*gmmParams.scaleOptional)[0]->GetViewShape();
const op::Shape &perTokenScaleShape = (*gmmParams.perTokenScaleOptional)[0]->GetViewShape();
size_t biasDim = biasShape.GetDimNum();
size_t scaleDim = scaleShape.GetDimNum();
size_t perTokenScaleDim = perTokenScaleShape.GetDimNum();
int64_t e = weightShape.GetDim(0);
int64_t m = xShape.GetDim(0);
int64_t n = weightShape.GetDim(WEIGHT_DIM_A8W4 - 1);
CHECK_COND(biasDim == BIAS_DIM_A8W4 && biasShape.GetDim(0) == e && biasShape.GetDim(1) == n,
ACLNN_ERR_PARAM_INVALID, "GMM Asymmetric Quant: bias shape is invalid, "
"must be (e, n), the current shape is (%ld, %ld)", biasShape.GetDim(0), biasShape.GetDim(1));
CHECK_COND(scaleDim == WEIGHT_DIM_A8W4 && scaleShape.GetDim(0) == e
&& scaleShape.GetDim(1) == 1 && scaleShape.GetDim(WEIGHT_DIM_A8W4 - 1) == n, ACLNN_ERR_PARAM_INVALID,
"GMM Asymmetric Quant: scale shape is invalid, must be (e, 1, n), the current shape is (%ld, %ld, %ld)",
scaleShape.GetDim(0), scaleShape.GetDim(1), scaleShape.GetDim(WEIGHT_DIM_A8W4 - 1));
CHECK_COND(perTokenScaleDim == BIAS_DIM_A8W4 && perTokenScaleShape.GetDim(0) == m
&& perTokenScaleShape.GetDim(BIAS_DIM_A8W4 - 1) == 1, ACLNN_ERR_PARAM_INVALID,
"GMM Asymmetric Quant: perTokenScale shape is invalid, must be (m, 1), the current shape is (%ld, %ld)",
perTokenScaleShape.GetDim(0), perTokenScaleShape.GetDim(1));
return ACLNN_SUCCESS;
}
static aclnnStatus CheckA8W4AsymQuantParams(const gmm::GroupedMatmulParams &gmmParams) {
DataType yDtype = (*gmmParams.y)[0]->GetDataType();
CHECK_COND(yDtype == DataType::DT_FLOAT16, ACLNN_ERR_PARAM_INVALID,
"GMM Asymmetric Quant: output y dtype should be float16, current dtype is %s.",
gmm::dTypeToString(yDtype).c_str());
DataType offsetDtype = (*gmmParams.offsetOptional)[0]->GetDataType();
CHECK_COND(offsetDtype == DataType::DT_FLOAT, ACLNN_ERR_PARAM_INVALID,
"GMM Asymmetric Quant: offset dtype does not match with required dtype float32, current dtype is %s.",
gmm::dTypeToString(offsetDtype).c_str());
CHECK_COND(gmmParams.biasOptional != nullptr, ACLNN_ERR_PARAM_INVALID,
"GMM Asymmetric Quant: bias must not be null");
CHECK_COND(gmmParams.scaleOptional != nullptr, ACLNN_ERR_PARAM_INVALID,
"GMM Asymmetric Quant: scale must not be null");
CHECK_COND(gmmParams.perTokenScaleOptional != nullptr, ACLNN_ERR_PARAM_INVALID,
"GMM Asymmetric Quant: perTokenScale must not be null");
DataType biasDtype = (*gmmParams.biasOptional)[0]->GetDataType();
CHECK_COND(biasDtype == DataType::DT_FLOAT, ACLNN_ERR_PARAM_INVALID,
"GMM Asymmetric Quant: bias dtype does not match with required dtype float32, current dtype is %s.",
gmm::dTypeToString(biasDtype).c_str());
DataType scaleDtype = (*gmmParams.scaleOptional)[0]->GetDataType();
CHECK_COND(scaleDtype == DataType::DT_UINT64, ACLNN_ERR_PARAM_INVALID,
"GMM Asymmetric Quant: scale dtype does not match with required dtype uint64, current dtype is %s.",
gmm::dTypeToString(scaleDtype).c_str());
DataType perTokenScaleDtype = (*gmmParams.perTokenScaleOptional)[0]->GetDataType();
CHECK_COND(perTokenScaleDtype == DataType::DT_FLOAT, ACLNN_ERR_PARAM_INVALID,
"GMM Asymmetric Quant: perTokenScale dtype does not match with required dtype float32, current dtype is %s.",
gmm::dTypeToString(perTokenScaleDtype).c_str());
CHECK_COND(gmmParams.antiquantScaleOptional == nullptr, ACLNN_ERR_PARAM_INVALID,
"GMM Asymmetric Quant: antiquantScale must be nullptr.");
CHECK_COND(gmmParams.antiquantOffsetOptional == nullptr, ACLNN_ERR_PARAM_INVALID,
"GMM Asymmetric Quant: antiquantOffset must be nullptr.");
CHECK_COND(CheckA8W4AsymQuantParamsRelationship(gmmParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID, "CheckA8W4AsymQuantParamsRelationship failed.");
return ACLNN_SUCCESS;
}
static bool isA8W8AsymmetricQuant(const gmm::GroupedMatmulParams &gmmParams) {
if (gmmParams.offsetOptional == nullptr) {
return false;
}
const op::Shape &offsetShape = (*gmmParams.offsetOptional)[0]->GetViewShape();
int64_t offsetDim = offsetShape.GetDimNum();
if (offsetDim != OFFSET_DIM_A8W4) {
return false;
}
const op::Shape &weightShape = (*gmmParams.weight)[0]->GetViewShape();
if (offsetShape.GetDim(0) == weightShape.GetDim(0) && offsetShape.GetDim(1) == 1
&& offsetShape.GetDim(OFFSET_DIM_A8W4 - 1) == weightShape.GetDim(OFFSET_DIM_A8W4 - 1)) {
return true;
}
return false;
}
static aclnnStatus CheckA8W4SymmQuantParamsRelationship(const gmm::GroupedMatmulParams &gmmParams) {
bool hasBias = (gmmParams.biasOptional != nullptr && (*gmmParams.biasOptional)[0] != nullptr);
const op::Shape &xShape = (*gmmParams.x)[0]->GetViewShape();
const op::Shape &weightShape = (*gmmParams.weight)[0]->GetViewShape();
size_t biasDim = 0;
op::Shape biasShape;
const op::Shape &scaleShape = (*gmmParams.scaleOptional)[0]->GetViewShape();
const op::Shape &perTokenScaleShape = (*gmmParams.perTokenScaleOptional)[0]->GetViewShape();
size_t scaleDim = scaleShape.GetDimNum();
size_t perTokenScaleDim = perTokenScaleShape.GetDimNum();
int64_t e = weightShape.GetDim(0);
int64_t m = xShape.GetDim(0);
int64_t xK = xShape.GetDim(1);
int64_t wK = weightShape.GetDim(WEIGHT_DIM_A8W4 - 2);
int64_t n = weightShape.GetDim(WEIGHT_DIM_A8W4 - 1);
if (!(scaleDim == WEIGHT_DIM_A8W4)) {
OP_LOGW("GMM A8W4 Symmetric Quant: Dim num of scale must be 3, " \
"but current dim num is %ld", scaleDim);
}
if (!(scaleShape.GetDim(0) == e && scaleShape.GetDim(scaleDim - 1) == n)) {
OP_LOGW("GMM A8W4 Symmetric Quant: scale shape is invalid, " \
"the last dim of scale shape is %zu, and last dim of weight shape is %ld", \
scaleShape.GetDim(scaleDim - 1), n);
}
if (!(wK == xK)) {
OP_LOGW("GMM A8W4 Symmetric Quant: dim 1 of x shape must be equal with dim 1 of weight shape" \
"the x shape is (%ld, %ld) and weight shape is (%ld, %ld, %ld)", xShape.GetDim(0), xShape.GetDim(1), e, wK, n);
}
if (hasBias) {
biasShape = (*gmmParams.biasOptional)[0]->GetViewShape();
size_t biasDim = biasShape.GetDimNum();
if (!(biasDim == BIAS_DIM_A8W4 && biasShape.GetDim(0) == e && biasShape.GetDim(1) == n)) {
OP_LOGW("GMM A8W4 Symmetric Quant: bias shape is invalid, " \
"must be (e, n), the current shape is (%ld, %ld)", biasShape.GetDim(0), biasShape.GetDim(1));
}
}
if (!(perTokenScaleShape.GetDim(0) == m)) {
OP_LOGW("GMM A8W4 Symmetric Quant: perTokenScale.shape[0] (%ld) must match x.shape[0] (%ld).", \
perTokenScaleShape.GetDim(0), m);
}
return ACLNN_SUCCESS;
}
static aclnnStatus CheckA8W4SymmQuantParams(const gmm::GroupedMatmulParams &gmmParams) {
if (!(gmmParams.scaleOptional != nullptr && (*gmmParams.scaleOptional)[0] != nullptr)) {
OP_LOGW("GMM A8W4 Symmetric Quant: scale must not be null");
return ACLNN_SUCCESS;
}
if (!(gmmParams.perTokenScaleOptional != nullptr && (*gmmParams.perTokenScaleOptional)[0] != nullptr)) {
OP_LOGW("GMM A8W4 Symmetric Quant: perTokenScale must not be null");
return ACLNN_SUCCESS;
}
DataType scaleDtype = (*gmmParams.scaleOptional)[0]->GetDataType();
if (!(scaleDtype == DataType::DT_UINT64)) {
OP_LOGW("GMM A8W4 Symmetric Quant: scale dtype does not match with required dtype uint64, current dtype is %s.", \
gmm::dTypeToString(scaleDtype).c_str());
}
DataType perTokenScaleDtype = (*gmmParams.perTokenScaleOptional)[0]->GetDataType();
if (!(perTokenScaleDtype == DataType::DT_FLOAT)) {
OP_LOGW("GMM A8W4 Symmetric Quant: perTokenScale dtype does not match with required dtype float32, current dtype is %s.", \
gmm::dTypeToString(perTokenScaleDtype).c_str());
}
if (!(CheckA8W4SymmQuantParamsRelationship(gmmParams) == ACLNN_SUCCESS)) {
OP_LOGW("CheckA8W4SymmQuantParamsRelationship failed.");
}
return ACLNN_SUCCESS;
}
static aclnnStatus CheckA8W4QuantParams(const gmm::GroupedMatmulParams &gmmParams) {
if (gmmParams.groupListType != 1) {
OP_LOGW("GMM A8W4: groupListType only support 1(count), but now is %ld.",
gmmParams.groupListType);
}
if (!isA8W8AsymmetricQuant(gmmParams)) {
CheckA8W4SymmQuantParams(gmmParams);
return ACLNN_SUCCESS;
}
CHECK_COND(CheckA8W4AsymQuantParams(gmmParams) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID,
"CheckA8W4AsymQuantParams failed.");
return ACLNN_SUCCESS;
}
static aclnnStatus CheckA4W4ParamsShape(const gmm::GroupedMatmulParams &gmmParams) {
const op::Shape &scaleShape = (*gmmParams.scaleOptional)[0]->GetViewShape();
int64_t scaleDimNum = scaleShape.GetDimNum();
CHECK_COND(scaleDimNum == PER_CHANNEL_SCALE_DIM || scaleDimNum == PER_GROUP_SCALE_DIM, ACLNN_ERR_PARAM_INVALID,
"GMM A4W4: scale dim only support 2/3 for perchannel/pergroup, but now is %ld.\n",
scaleDimNum);
int64_t n = scaleShape.GetDim(scaleDimNum - 1);
CHECK_COND((n % static_cast<int64_t>(8)) == 0, ACLNN_ERR_PARAM_INVALID,
"A4W4 n axis should align with 16, but now is %ld", n);
if (scaleDimNum == PER_GROUP_SCALE_DIM) {
const op::Shape &inputShape = (*gmmParams.x)[0]->GetViewShape();
int64_t k = inputShape.GetDim(1);
int64_t kGroupNum = scaleShape.GetDim(1);
CHECK_COND(kGroupNum != 0 && (k % kGroupNum) == 0, ACLNN_ERR_PARAM_INVALID,
"Pergroup A4W4: while scale shape is [e, G, n], x shape is [m,k], k needs to be divisible by G,"
"but now k is %ld, G is %ld.\n", k, kGroupNum);
int64_t pergroupNum = static_cast<int64_t>(k / kGroupNum);
CHECK_COND((pergroupNum % 2) == 0, ACLNN_ERR_PARAM_INVALID,
"Pergroup A4W4: while scale shape is [e, G, n], x shape is [m,k],"
"pergroup num(k/G) needs to be divisible by 2, but now pergroupNum is %ld.\n", pergroupNum);
}
return ACLNN_SUCCESS;
}
static aclnnStatus CheckA4W4QuantParams(const gmm::GroupedMatmulParams &gmmParams) {
CHECK_COND(gmmParams.groupListType == 0 || gmmParams.groupListType == 1 || gmmParams.groupListType == 2,
ACLNN_ERR_PARAM_INVALID,
"GMM A4W4: groupListType only support 0(cumsum) or 1(count) or 2(sparsem), but now is %ld.",
gmmParams.groupListType);
DataType yDtype = (*gmmParams.y)[0]->GetDataType();
CHECK_COND(yDtype == DataType::DT_FLOAT16 || yDtype == DataType::DT_BF16, ACLNN_ERR_PARAM_INVALID,
"GMM A4W4: output y dtype should be float16 or bfloat16, current dtype is %s.",
gmm::dTypeToString(yDtype).c_str());
CHECK_COND(gmmParams.offsetOptional == nullptr, ACLNN_ERR_PARAM_INVALID,
"GMM A4W4: offset must be null.");
CHECK_COND(gmmParams.biasOptional == nullptr, ACLNN_ERR_PARAM_INVALID,
"GMM A4W4: bias must be null.");
CHECK_COND(gmmParams.scaleOptional != nullptr, ACLNN_ERR_PARAM_INVALID,
"GMM A4W4: scale must not be null.");
DataType scaleDtype = (*gmmParams.scaleOptional)[0]->GetDataType();
CHECK_COND(scaleDtype == DataType::DT_UINT64, ACLNN_ERR_PARAM_INVALID,
"GMM A4W4: scale dtype does not match with required dtype uint64, current dtype is %s.",
gmm::dTypeToString(scaleDtype).c_str());
bool isPerTokenQuant = gmmParams.perTokenScaleOptional != nullptr;
if (isPerTokenQuant) {
DataType perTokenScaleDtype = (*gmmParams.perTokenScaleOptional)[0]->GetDataType();
CHECK_COND(perTokenScaleDtype == DataType::DT_FLOAT, ACLNN_ERR_PARAM_INVALID,
"GMM A4W4: perTokenScale dtype does not match with required dtype float32, current dtype is %s.",
gmm::dTypeToString(perTokenScaleDtype).c_str());
CHECK_COND(CheckPerTokenScale(gmmParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"GMM A4W4: Check perTokenScale failed!");
}
CHECK_COND(CheckA4W4ParamsShape(gmmParams) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID,
"CheckA4W4ParamsShape failed.");
CHECK_COND(IsGmmAntiQuantEmpty(gmmParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"GMM A4W4: Detected quant, but antiquant inputs are not empty!");
CHECK_COND(gmmParams.groupType == gmm::SPLIT_M && gmmParams.x->Size() == 1 && gmmParams.weight->Size() == 1
&& gmmParams.y->Size() == 1, ACLNN_ERR_PARAM_INVALID,
"A4W4 only support split m, single x, single weight, single y.");
return ACLNN_SUCCESS;
}
bool isActivationAllowed(int64_t act_type) {
return act_type == GMMActType::GMM_ACT_TYPE_RELU ||
act_type == GMMActType::GMM_ACT_TYPE_GELU_TANH ||
act_type == GMMActType::GMM_ACT_TYPE_FAST_GELU ||
act_type == GMMActType::GMM_ACT_TYPE_SILU;
}
bool CheckScaleForInt8Quant(const gmm::GroupedMatmulParams &gmmParams) {
if (gmmParams.scaleOptional == nullptr || gmmParams.scaleOptional->Size() == 0 ||
(*gmmParams.scaleOptional)[0] == nullptr) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID,
"When the activation function is enabled, Scale should not be null, but now is null.");
return false;
}
const op::Shape &scaleShape = (*gmmParams.scaleOptional)[0]->GetViewShape();
DataType scaleDtype = (*gmmParams.scaleOptional)[0]->GetDataType();
if (scaleDtype != DataType::DT_FLOAT && scaleDtype != DataType::DT_BF16) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID,
"When the activation function is enabled, the dtype of Scale should be float32 or bfloat16, but actual is %s.",
gmm::dTypeToString(scaleDtype).c_str());
return false;
}
size_t scaleDimNum = scaleShape.GetDimNum();
if (scaleDimNum != SCALE_TENSOR_EXPECTED_DIMS) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID,
"When the activation function is enabled, the dim of Scale should be 2 for perchannel, but actual is %zu.\n", scaleDimNum);
return false;
}
const op::Shape &weightShape = (*gmmParams.weight)[0]->GetViewShape();
size_t weightDimNum = weightShape.GetDimNum();
int64_t weightNSize = weightShape.GetDim(weightDimNum - 1);
int64_t weightGroupNum = weightShape.GetDim(0);
int64_t scaleNSize = scaleShape.GetDim(scaleDimNum - 1);
int64_t scaleGroupNum = scaleShape.GetDim(0);
if (scaleNSize != weightNSize || scaleGroupNum != weightGroupNum) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID,
"When the activation function is enabled, the shape of scale should be (%ld, %ld), but actual"
" is (%ld, %ld).\n",
weightGroupNum, weightNSize, scaleGroupNum, scaleNSize);
return false;
}
return true;
}
bool CheckInt8StaticTCQuant(const gmm::GroupedMatmulParams &gmmParams) {
if (gmmParams.perTokenScaleOptional != nullptr) {
return false;
}
return CheckScaleForInt8Quant(gmmParams);
}
bool CheckInt8DynamicKCQuant(const gmm::GroupedMatmulParams &gmmParams) {
if (gmmParams.perTokenScaleOptional == nullptr || gmmParams.perTokenScaleOptional->Size() == 0 ||
(*gmmParams.perTokenScaleOptional)[0] == nullptr) {
return false;
}
const op::Shape &perTokenScaleShape = (*gmmParams.perTokenScaleOptional)[0]->GetViewShape();
DataType perTokenScaleDtype = (*gmmParams.perTokenScaleOptional)[0]->GetDataType();
if (perTokenScaleDtype != DataType::DT_FLOAT) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID,
"When the activation function is enabled,"
" the dtype of perTokenScale should be float32, but actual is %s.",
gmm::dTypeToString(perTokenScaleDtype).c_str());
return false;
}
const op::Shape &xShape = (*gmmParams.x)[0]->GetViewShape();
int64_t mSize = xShape.GetDim(xShape.GetDimNum() - 2);
size_t perTokenScaleDim = perTokenScaleShape.GetDimNum();
int64_t perTokenScaleMSize = perTokenScaleShape.GetDim(perTokenScaleDim - 1);
if (perTokenScaleDim != 1) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID,
"When the activation function is enabled, the dim of perTokenScale should be 1, but actual is %ld.\n", perTokenScaleDim);
return false;
}
if (perTokenScaleMSize != mSize) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID,
"When the activation function is enabled and the dim of perTokenScale is 1, the shape of perTokenScale should be (%ld,), but actual"
" is (%ld,).\n",
mSize, perTokenScaleMSize);
return false;
}
return CheckScaleForInt8Quant(gmmParams);
}
bool CheckIsEnabledActive(const gmm::GroupedMatmulParams &gmmParams) {
auto xDtype = gmmParams.xDtype;
auto weightDtype = (*gmmParams.weight)[0]->GetDataType();
bool isInt8Input = (xDtype == DataType::DT_INT8 && weightDtype == DataType::DT_INT8);
if (!isInt8Input) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID,
"When the activation function is enabled, the dtype of x and weight should be DT_INT8,"
" actual is %s and %s.",
op::ToString(xDtype).GetString(),
op::ToString(weightDtype).GetString());
return false;
}
bool isInt8StaticTCQuant = CheckInt8StaticTCQuant(gmmParams);
bool isInt8DynamicKCQuant = CheckInt8DynamicKCQuant(gmmParams);
bool allowActOnDavid = isInt8Input && (isInt8StaticTCQuant || isInt8DynamicKCQuant)
&& isActivationAllowed(gmmParams.activeType);
return allowActOnDavid;
}
static aclnnStatus CheckFunctionParams(const gmm::GroupedMatmulParams &gmmParams) {
DataType weightDtype = (*gmmParams.weight)[0]->GetDataType();
bool isNoActivation = gmmParams.activeType == GMMActType::GMM_ACT_TYPE_NONE;
CHECK_COND(Check310PlatformForFunction(
gmmParams, weightDtype, isNoActivation) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "Check310PlatformForFunction failed.");
if (op::GetCurrentPlatformInfo().GetCurNpuArch() == NpuArch::DAV_3510) {
if (IsQuant(gmmParams.xDtype, weightDtype)) {
CHECK_COND(isNoActivation || CheckIsEnabledActive(gmmParams), ACLNN_ERR_PARAM_INVALID, "On this platform, activation is supported only when the input is INT8"
" and the quant mode is either pertoken-perchannel or pertensor-perchannel; "
" activation is not supported in other scenarios;"
" activeType[%ld] is not supported.", gmmParams.activeType);
return gmm::AclnnGroupedMatmulDAV3510Checker<aclTensorList>(gmmParams).CheckGroupedMatmulDAV3510();
} else if (IsWeightQuant(gmmParams.xDtype, weightDtype)) {
CHECK_COND(isNoActivation, ACLNN_ERR_PARAM_INVALID, "Activation is not supported in weight quant mode now."
" activeType[%ld] is not supported.", gmmParams.activeType);
return gmm::AclnnGroupedMatmulWeightQuantDAV3510Checker(gmmParams).CheckGroupedMatmulWeightQuantDAV3510();
} else {
CHECK_COND(isNoActivation, ACLNN_ERR_PARAM_INVALID, "When input is No-Quant, activation is not supported on this platforms."
" activeType[%ld] is not supported.", gmmParams.activeType);
}
}
if (gmmParams.xDtype == DataType::DT_INT8 && weightDtype == DataType::DT_INT4) {
CHECK_COND(CheckA8W4QuantParams(gmmParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID, "CheckA8W4QuantParams failed.");
return ACLNN_SUCCESS;
}
if (gmmParams.xDtype == DataType::DT_INT4 && weightDtype == DataType::DT_INT4) {
CHECK_COND(CheckA4W4QuantParams(gmmParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID, "CheckA4W4QuantParams failed.");
return ACLNN_SUCCESS;
}
if ((gmmParams.xDtype == DataType::DT_BF16 || gmmParams.xDtype == DataType::DT_FLOAT16 ||
gmmParams.xDtype == DataType::DT_FLOAT) && gmmParams.xDtype == weightDtype) {
if (gmmParams.apiVersion == gmm::GMMApiVersion::V1 && (GetCurrentPlatformInfo().GetSocVersion() == SocVersion::ASCEND910B || GetCurrentPlatformInfo().GetSocVersion() == SocVersion::ASCEND910_93)) {
CHECK_COND(gmmParams.xDtype != DataType::DT_FLOAT, ACLNN_ERR_PARAM_INVALID,
"AclnnGroupedMatmul does not support x or weight dtype float32 on both ASCEND910B and ASCEND910_93 platforms.");
}
CHECK_COND(CheckNonQuantMatmulDataType(gmmParams, weightDtype) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"Check no quant case dtype failed.");
CHECK_COND(isNoActivation, ACLNN_ERR_PARAM_INVALID, "Non quant case dose not support activation.");
return CheckNonQuant(gmmParams);
}
if (gmmParams.xDtype == DataType::DT_INT8 && weightDtype == DataType::DT_INT8) {
DataType yDtype = (*gmmParams.y)[0]->GetDataType();
CHECK_COND(isNoActivation || yDtype != DataType::DT_INT8 || yDtype != DataType::DT_INT32,
ACLNN_ERR_PARAM_INVALID, "Quant case with output dtype int8 or int32 dose not support activation.");
CHECK_COND(CheckFunctionQuantParams(gmmParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"CheckFunctionQuantParams failed.");
return ACLNN_SUCCESS;
}
if ((gmmParams.xDtype == DataType::DT_BF16 || gmmParams.xDtype == DataType::DT_FLOAT16)
&& (weightDtype == DataType::DT_INT8 || weightDtype == DataType::DT_INT4)) {
DataType biasDtype = gmmParams.xDtype == DataType::DT_BF16 ? DataType::DT_FLOAT: DataType::DT_FLOAT16;
CHECK_RET(
CheckMatmulDataType(gmmParams, gmmParams.xDtype, weightDtype, gmmParams.xDtype, biasDtype) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID);
CHECK_COND(isNoActivation, ACLNN_ERR_PARAM_INVALID, "Antiquant case dose not support activation.");
return CheckGroupedMatmulAntiQuant(gmmParams);
}
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "GMM: there is no matching xDtype and weightDtype pattern. "
"case with x dtype %s and weight dtype %s is not supported.",
gmm::dTypeToString(gmmParams.xDtype).c_str(), gmm::dTypeToString(weightDtype).c_str());
return ACLNN_ERR_PARAM_INVALID;
}
static aclnnStatus CheckWeightShapeInnerAxisEven(const aclTensorList *tensorList, const size_t weightSize,
const int64_t innerAxisDimId) {
if ((*tensorList)[0]->GetDataType() == DataType::DT_INT4) {
for (size_t i = 0; i < weightSize; ++i) {
int64_t n = (*tensorList)[i]->GetViewShape().GetDim(innerAxisDimId);
CHECK_COND(n % 2 == 0, ACLNN_ERR_PARAM_INVALID, "Weight's inner axis size[%ld] is not even!", n);
}
}
return ACLNN_SUCCESS;
}
static aclnnStatus SplitMSingleXSingleWeightSingleY(const gmm::GroupedMatmulParams &gmmParams) {
static const std::vector<std::string> TENSOR_X_WEIGHT{"x", "weight", "true"};
static const std::vector<std::string> TENSOR_X_Y{"x", "y", "false"};
static const std::vector<std::string> TENSOR_WEIGHT_Y{"weight", "y", "true"};
CHECK_COND(gmmParams.splitItem == X_SEPARATED || gmmParams.splitItem == NO_SEPARATED, ACLNN_ERR_PARAM_INVALID,
"When y is not separated, splitItem should be 2/3, but current splitItem is %ld.", gmmParams.splitItem);
CHECK_COND(CheckDimNumAndFormat(gmmParams, gmmParams.x, gmm::MIN_FM_DIM, "x") == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"Dim num or format of tensor in tensor list x is invalid.");
CHECK_COND(CheckDimNumAndFormat(gmmParams, gmmParams.weight, gmm::SPLIT_M_SINGLE_WEIGHT_DIM, "weight") == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "Dim num or format of tensor in tensor list weight is invalid.");
CHECK_COND(CheckDimNumAndFormat(gmmParams, gmmParams.y, gmm::MIN_FM_DIM, "y") == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"Dim num or format of tensor in tensor list y is invalid.");
int64_t innerAxisDimId = 1;
CHECK_COND(CheckShapeSameLengthTensorList(gmmParams.x, gmmParams.weight, {1, 1}, innerAxisDimId, TENSOR_X_WEIGHT) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "K dim value of x and weight is not matched.");
CHECK_COND(CheckShapeSameLengthTensorList(gmmParams.x, gmmParams.y, {0, 0}, -1, TENSOR_X_Y) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "M dim value of x and y is not matched.");
innerAxisDimId = !gmmParams.transposeWeight ? 2 : -1;
CHECK_COND(CheckShapeSameLengthTensorList(gmmParams.weight, gmmParams.y, {2, 1}, innerAxisDimId, TENSOR_WEIGHT_Y) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "N dim value of weight and y is not matched.");
CHECK_COND(CheckWeightShapeInnerAxisEven(gmmParams.weight, gmmParams.weight->Size(),
gmmParams.transposeWeight ? 1 : 2) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "W inner axis size should be even when weight is int4 dtype.");
size_t batchSizeWeight = (*gmmParams.weight)[0]->GetViewShape().GetDim(0);
CHECK_COND(CheckGroupListSplitM(gmmParams, true, false, false, batchSizeWeight) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "Invalid groupList.");
return ACLNN_SUCCESS;
}
static aclnnStatus SplitMSingleXSeparatedWeightSingleY(const gmm::GroupedMatmulParams &gmmParams) {
size_t weightSize = gmmParams.weight->Size();
static const std::vector<std::string> TENSOR_WEIGHT_X{"Weight", "x", "true"};
static const std::vector<std::string> TENSOR_X_Y{"x", "y", "false"};
static const std::vector<std::string> TENSOR_WEIGHT_Y{"Weight", "y", "true"};
std::string errorMessage = gmmParams.apiVersion != gmm::GMMApiVersion::V2 ? "When splited axis is M" : "When groupType is 0";
CHECK_COND(gmmParams.splitItem == X_SEPARATED || gmmParams.splitItem == NO_SEPARATED, ACLNN_ERR_PARAM_INVALID,
"When y is not separated, splitItem should be 2/3, but current splitItem is %ld.", gmmParams.splitItem);
CHECK_COND(CheckDimNumAndFormat(gmmParams, gmmParams.x, gmm::MIN_FM_DIM, "x") == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"Dim num or format of tensor in tensor list x is invalid.");
CHECK_COND(CheckDimNumAndFormat(gmmParams, gmmParams.weight, SEPARATED_WEIGHT_DIM, "weight") == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "Dim num or format of tensor in tensor list weight is invalid.");
CHECK_COND(CheckDimNumAndFormat(gmmParams, gmmParams.y, gmm::MIN_FM_DIM, "y") == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"Dim num or format of tensor in tensor list y is invalid.");
int64_t innerAxisDimId = 1;
CHECK_COND(CheckShapeDiffLengthTensorList(gmmParams.weight, gmmParams.x, {0, 1}, innerAxisDimId, TENSOR_WEIGHT_X) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "K dim value of x and weight is not matched.");
CHECK_COND(CheckShapeSameLengthTensorList(gmmParams.x, gmmParams.y, {0, 0}, -1, TENSOR_X_Y) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "M dim value of x and y is not matched.");
innerAxisDimId = !gmmParams.transposeWeight ? 1 : -1;
CHECK_COND(CheckShapeDiffLengthTensorList(gmmParams.weight, gmmParams.y, {1, 1}, innerAxisDimId, TENSOR_WEIGHT_Y) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "N dim value of weight and y is not matched.");
CHECK_COND(CheckWeightShapeInnerAxisEven(gmmParams.weight, weightSize, gmmParams.transposeWeight ? 0 : 1) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "W inner axis size should be even when weight is int4 dtype.");
CHECK_COND(CheckGroupListSplitM(gmmParams, true, false, false, weightSize) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "Invalid groupList.");
return ACLNN_SUCCESS;
}
static aclnnStatus SplitMSingleXSeparatedWeightSeparatedY(const gmm::GroupedMatmulParams &gmmParams) {
size_t ySize = gmmParams.y->Size();
size_t weightSize = gmmParams.weight->Size();
static const std::vector<std::string> TENSOR_WEIGHT_X{"Weight", "x", "true"};
static const std::vector<std::string> TENSOR_Y_X{"y", "x", "false"};
static const std::vector<std::string> TENSOR_WEIGHT_Y{"Weight", "y", "true"};
std::string errorMessage = gmmParams.apiVersion == gmm::GMMApiVersion::V1 ? "When splited axis is M" : "When groupType is 0";
CHECK_COND(gmmParams.splitItem == X_Y_SEPARATED || gmmParams.splitItem == Y_SEPARATED, ACLNN_ERR_PARAM_INVALID,
"When y is separated, splitItem should be 0/1, but current splitItem is %ld.", gmmParams.splitItem);
CHECK_COND(ySize == weightSize, ACLNN_ERR_PARAM_INVALID,
"When y and weight are separated, size of y %lu should equal to size of weight %lu.",
ySize, weightSize);
CHECK_COND(CheckDimNumAndFormat(gmmParams, gmmParams.x, gmm::MIN_FM_DIM, "x") == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"Dim num or format of tensor in tensor list x is invalid.");
CHECK_COND(CheckDimNumAndFormat(gmmParams, gmmParams.weight, SEPARATED_WEIGHT_DIM, "weight") == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "Dim num or format of tensor in tensor list weight is invalid.");
CHECK_COND(CheckDimNumAndFormat(gmmParams, gmmParams.y, gmm::MIN_FM_DIM, "y") == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"Dim num or format of tensor in tensor list y is invalid.");
int64_t innerAxisDimId = 1;
CHECK_COND(CheckShapeDiffLengthTensorList(gmmParams.weight, gmmParams.x, {0, 1}, innerAxisDimId, TENSOR_WEIGHT_X) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "K dim value of x and weight is not matched.");
CHECK_COND(CheckShapeDiffLengthTensorListSplitAxis(gmmParams.y, gmmParams.x, 0, 0, TENSOR_Y_X) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "M dim value of x and y is not matched.");
innerAxisDimId = !gmmParams.transposeWeight ? 1 : -1;
CHECK_COND(CheckShapeSameLengthTensorList(gmmParams.weight, gmmParams.y, {1, 1}, innerAxisDimId, TENSOR_WEIGHT_Y) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "N dim value of weight and y is not matched.");
CHECK_COND(CheckWeightShapeInnerAxisEven(gmmParams.weight, weightSize, gmmParams.transposeWeight ? 0 : 1) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "W inner axis size should be even when weight is int4 dtype.");
CHECK_COND(CheckGroupListSplitM(gmmParams, true, false, true, ySize) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "Invalid groupList.");
return ACLNN_SUCCESS;
}
static aclnnStatus SplitMSeparatedXSeparatedWeightSingleY(const gmm::GroupedMatmulParams &gmmParams) {
size_t xSize = gmmParams.x->Size();
size_t weightSize = gmmParams.weight->Size();
static const std::vector<std::string> TENSOR_WEIGHT_X{"Weight", "x", "true"};
static const std::vector<std::string> TENSOR_X_Y{"x", "y", "false"};
static const std::vector<std::string> TENSOR_WEIGHT_Y{"Weight", "y", "true"};
std::string errorMessage = gmmParams.apiVersion != gmm::GMMApiVersion::V2 ? "When splited axis is M" : "When groupType is 0";
CHECK_COND(gmmParams.splitItem == X_SEPARATED || gmmParams.splitItem == NO_SEPARATED, ACLNN_ERR_PARAM_INVALID,
"When y is not separated, splitItem should be 2/3, but current splitItem is %ld.", gmmParams.splitItem);
CHECK_COND(xSize == weightSize, ACLNN_ERR_PARAM_INVALID,
"When x and weight are separated, size of x %lu should equal to size of weight %lu.",
xSize, weightSize);
CHECK_COND(CheckDimNumAndFormat(gmmParams, gmmParams.x, gmm::MIN_FM_DIM, "x") == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"Dim num or format of tensor in tensor list x is invalid.");
CHECK_COND(CheckDimNumAndFormat(gmmParams, gmmParams.weight, SEPARATED_WEIGHT_DIM, "weight") == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "Dim num or format of tensor in tensor list weight is invalid.");
CHECK_COND(CheckDimNumAndFormat(gmmParams, gmmParams.y, gmm::MIN_FM_DIM, "y") == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"Dim num or format of tensor in tensor list y is invalid.");
int64_t innerAxisDimId = 0;
CHECK_COND(CheckShapeSameLengthTensorList(gmmParams.weight, gmmParams.x, {0, 1}, innerAxisDimId, TENSOR_WEIGHT_X) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "K dim value of x and weight is not matched.");
CHECK_COND(CheckShapeDiffLengthTensorListSplitAxis(gmmParams.x, gmmParams.y, 0, 0, TENSOR_X_Y) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "M dim value of x and y is not matched.");
innerAxisDimId = !gmmParams.transposeWeight ? 1 : -1;
CHECK_COND(CheckShapeDiffLengthTensorList(gmmParams.weight, gmmParams.y, {1, 1}, innerAxisDimId, TENSOR_WEIGHT_Y) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "N dim value of weight and y is not matched.");
CHECK_COND(CheckWeightShapeInnerAxisEven(gmmParams.weight, weightSize, gmmParams.transposeWeight ? 0 : 1) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "W inner axis size should be even when weight is int4 dtype.");
CHECK_COND(CheckGroupListSplitM(gmmParams, false, true, false, xSize) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "Invalid groupList.");
return ACLNN_SUCCESS;
}
static aclnnStatus CheckCaseSplitM(const gmm::GroupedMatmulParams &gmmParams) {
size_t xSize = gmmParams.x->Size();
size_t ySize = gmmParams.y->Size();
size_t weightSize = gmmParams.weight->Size();
bool apiVersionFlag = gmmParams.apiVersion == gmm::GMMApiVersion::WeightNz || gmmParams.apiVersion == gmm::GMMApiVersion::V5
|| gmmParams.apiVersion == gmm::GMMApiVersion::V4 || gmmParams.apiVersion == gmm::GMMApiVersion::V3;
if (xSize == 1UL && weightSize == 1UL && ySize == 1UL) {
CHECK_COND(SplitMSingleXSingleWeightSingleY(gmmParams) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "Split m, single x, single weight, single y case failed.");
return ACLNN_SUCCESS;
}
if (xSize == 1UL && weightSize > 1UL && ySize == 1UL) {
CHECK_COND(SplitMSingleXSeparatedWeightSingleY(gmmParams) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "Split m, single x, separated weight, single y case failed.");
return ACLNN_SUCCESS;
}
if (xSize == 1UL && weightSize > 1UL && ySize > 1UL) {
CHECK_COND(!(apiVersionFlag),
ACLNN_ERR_PARAM_INVALID,
"When grouplist is tensor, split m, single x, separated weight, separated y cases do not support.");
CHECK_COND(SplitMSingleXSeparatedWeightSeparatedY(gmmParams) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "Split m, single x, separated weight, separated y case failed.");
return ACLNN_SUCCESS;
}
if (xSize > 1UL && weightSize > 1UL && ySize == 1UL) {
CHECK_COND(SplitMSeparatedXSeparatedWeightSingleY(gmmParams) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "Split m, separated x, separated weight, single y case failed.");
return ACLNN_SUCCESS;
}
std::string errorMessage = gmmParams.apiVersion != gmm::GMMApiVersion::V2 ? "When splited axis is M" : "When groupType is 0";
if ((apiVersionFlag) && gmmParams.isSingleWeight) {
errorMessage = "When groupType is 0";
}
std::string xStatus = xSize > 1UL ? "separated" : "not separated";
std::string weightStatus = weightSize > 1UL ? "separated" : "not separated";
std::string yStatus = ySize > 1UL ? "separated" : "not separated";
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "%s, current case with x %s, weight %s, y %s is not supported.",
errorMessage.c_str(), xStatus.c_str(), weightStatus.c_str(), yStatus.c_str());
return ACLNN_ERR_PARAM_INVALID;
}
static aclnnStatus CheckCaseSplitK(const gmm::GroupedMatmulParams &gmmParams) {
static const std::vector<std::string> TENSOR_X_WEIGHT{"x", "weight", "true"};
static const std::vector<std::string> TENSOR_X_Y{"x", "y", "false"};
static const std::vector<std::string> TENSOR_WEIGHT_Y{"Weight", "y", "true"};
size_t xSize = gmmParams.x->Size();
size_t ySize = gmmParams.y->Size();
size_t weightSize = gmmParams.weight->Size();
if (xSize == 1UL) {
CHECK_COND(gmmParams.transposeX, ACLNN_ERR_PARAM_INVALID,
"When groupType is 2 and x is not separated, tensor in x should be transposed.");
CHECK_COND(CheckDimNumAndFormat(gmmParams, gmmParams.x, gmm::MIN_FM_DIM, "x") == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"Dim num or format of tensor in tensor list x is invalid.");
CHECK_COND(CheckDimNumAndFormat(gmmParams, gmmParams.weight, gmm::SPLIT_K_SINGLE_WEIGHT_DIM, "weight") == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID,
"Dim num or format of tensor in tensor list weight is invalid.");
if(weightSize == 1UL && ySize == 1UL) {
CHECK_COND(CheckDimNumAndFormat(gmmParams, gmmParams.y, DIMS_THREE_FOR_GMM, "y") == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"Dim num or format of tensor in tensor list y is invalid.");
int64_t innerAxisDimId = 0;
CHECK_COND(CheckShapeSameLengthTensorList(gmmParams.x, gmmParams.weight, {1, 0}, innerAxisDimId, TENSOR_X_WEIGHT) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "K dim value of x and weight is not matched.");
CHECK_COND(CheckShapeSameLengthTensorList(gmmParams.x, gmmParams.y, {0, 1}, -1, TENSOR_X_Y) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "M dim value of x and y is not matched.");
innerAxisDimId = 1;
CHECK_COND(CheckShapeSameLengthTensorList(gmmParams.weight, gmmParams.y, {1, 2}, innerAxisDimId, TENSOR_WEIGHT_Y) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "N dim value of weight and y is not matched.");
size_t batchSizeY = (*gmmParams.y)[0]->GetViewShape().GetDim(0);
CHECK_COND(CheckGroupListSplitK(gmmParams, true, false, false, batchSizeY) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "Invalid groupList.");
}
return ACLNN_SUCCESS;
}
OP_LOGE(ACLNN_ERR_PARAM_INVALID,"When groupType is 2, only support case with unseparated x, y, "
"but now x size is %lu, weight size is %lu, y size is %lu.", xSize, weightSize, ySize);
return ACLNN_ERR_PARAM_INVALID;
}
static aclnnStatus CheckCaseNoSplit(const gmm::GroupedMatmulParams &gmmParams) {
CHECK_COND(gmmParams.splitItem == X_Y_SEPARATED || gmmParams.splitItem == Y_SEPARATED, ACLNN_ERR_PARAM_INVALID,
"When y is separated, splitItem should be 0/1, but current splitItem is %ld.", gmmParams.splitItem);
size_t xSize = gmmParams.x->Size();
size_t ySize = gmmParams.y->Size();
size_t weightSize = gmmParams.weight->Size();
CHECK_COND(xSize == ySize, ACLNN_ERR_PARAM_INVALID,
"When y is separated, size of x %lu should equal to size of y %lu.", xSize, ySize);
CHECK_COND(xSize == weightSize, ACLNN_ERR_PARAM_INVALID,
"When x and weight are separated, size of x %lu should equal to size of weight %lu.",
xSize, weightSize);
CHECK_COND(CheckDimNumAndGroupListNoSplitAndFormat(gmmParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"Dim num or format of tensor in tensor lists or grouplist is invalid.");
for (size_t i = 0; i < xSize; i++) {
size_t xDimNum = (*gmmParams.x)[i]->GetViewShape().GetDimNum();
for (size_t dimIdx = 0UL; dimIdx < xDimNum - 2UL; dimIdx++) {
size_t xDimValue = (*gmmParams.x)[i]->GetViewShape().GetDim(dimIdx);
size_t yDimValue = (*gmmParams.y)[i]->GetViewShape().GetDim(dimIdx);
CHECK_COND(xDimValue == yDimValue, ACLNN_ERR_PARAM_INVALID,
"Y[%lu] dim %lu value %lu should equal to x[%lu] dim %lu value %lu.",
i, dimIdx, xDimValue, i, dimIdx, yDimValue);
}
size_t xKDimValue = (*gmmParams.x)[i]->GetViewShape().GetDim(xDimNum - 1UL);
if (op::GetCurrentPlatformInfo().GetCurNpuArch() != NpuArch::DAV_3510) {
CHECK_COND(xKDimValue <= MAX_INNER_AXIS, ACLNN_ERR_PARAM_INVALID,
"X[%lu] dim %lu value %lu should less or equal to 65535.", i, xDimNum - 1, xKDimValue);
}
size_t weightKDimValue = (*gmmParams.weight)[i]->GetViewShape().GetDim(0);
CHECK_COND(xKDimValue == weightKDimValue, ACLNN_ERR_PARAM_INVALID,
"X[%lu] dim %lu value %lu should equal to weight[%lu] dim 0 value %lu.",
i, xDimNum - 1, xKDimValue, i, weightKDimValue);
size_t weightNDimValue = (*gmmParams.weight)[i]->GetViewShape().GetDim(1);
if (op::GetCurrentPlatformInfo().GetCurNpuArch() != NpuArch::DAV_3510 && !gmmParams.transposeWeight) {
CHECK_COND(weightNDimValue <= MAX_INNER_AXIS, ACLNN_ERR_PARAM_INVALID,
"W[%lu] dim %d value %lu should less or equal to 65535.", i, 1, weightNDimValue);
}
if ((*gmmParams.weight)[0]->GetDataType() == DataType::DT_INT4) {
CHECK_COND(weightNDimValue % 2 == 0, ACLNN_ERR_PARAM_INVALID,
"W[%lu] dim %d value %lu should be even when weight is int4 dtype.", i, 1, weightNDimValue);
}
size_t yNDimValue = (*gmmParams.y)[i]->GetViewShape().GetDim(xDimNum - 1UL);
CHECK_COND(yNDimValue == weightNDimValue, ACLNN_ERR_PARAM_INVALID,
"Y[%lu] dim %lu value %lu should equal to weight[%lu] dim 1 value %lu.",
i, xDimNum - 1, yNDimValue, i, weightNDimValue);
}
return ACLNN_SUCCESS;
}
static aclnnStatus CheckParamDifferentGroupType(const gmm::GroupedMatmulParams &gmmParams) {
CHECK_COND(!(gmmParams.transposeX && gmmParams.transposeWeight), ACLNN_ERR_PARAM_INVALID,
"X and weight can not be transposed at the same time.");
CHECK_COND((gmmParams.groupListOptional == nullptr || gmmParams.groupListOptional->Size() >= 1) &&
(gmmParams.groupTensorOptional == nullptr || gmmParams.groupTensorOptional->GetViewShape().GetDim(0) >= 1),
ACLNN_ERR_PARAM_INVALID, "Size of groupList can not be 0."
"If expected group num is 1, groupList should be nullptr.");
if (GetCurrentPlatformInfo().GetSocVersion() == SocVersion::ASCEND310P && gmmParams.transposeWeight) {
CHECK_COND(gmmParams.groupType == gmm::SPLIT_M && gmmParams.x->Size() == 1 && gmmParams.weight->Size() == 1
&& gmmParams.y->Size() == 1, ACLNN_ERR_PARAM_INVALID,
"When transpose weight, ASCEND310P only support split m, single x, single weight, single y.");
}
DataType weightDtype = (*gmmParams.weight)[0]->GetDataType();
if (op::GetCurrentPlatformInfo().GetCurNpuArch() == NpuArch::DAV_3510 &&
IsWeightQuant(gmmParams.xDtype, weightDtype)) {
if (gmmParams.groupType == gmm::SPLIT_M) {
size_t batchSizeWeight = (*gmmParams.weight)[0]->GetViewShape().GetDim(0);
CHECK_COND(CheckGroupListSplitM(gmmParams, true, false, false, batchSizeWeight) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "Invalid groupList.");
}
return ACLNN_SUCCESS;
}
if (gmmParams.groupType == gmm::NO_SPLIT) {
CHECK_COND(!gmmParams.transposeX, ACLNN_ERR_PARAM_INVALID,
"When x, weight and y are all separated, x can not be transposed.");
CHECK_COND(!(gmmParams.apiVersion == gmm::GMMApiVersion::V1 && gmmParams.transposeWeight) ||
op::GetCurrentPlatformInfo().GetCurNpuArch() == NpuArch::DAV_3510, ACLNN_ERR_PARAM_INVALID,
"In this version, when x, weight and y are all separated, weight can not be transposed.");
CHECK_COND(CheckCaseNoSplit(gmmParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"Invalid inputs!");
} else if (gmmParams.groupType == gmm::SPLIT_M) {
std::string errorMessage = gmmParams.apiVersion != gmm::GMMApiVersion::V2 && !gmmParams.isSingleWeight
? "When splited axis is M" : "When groupType is 0";
CHECK_COND(!gmmParams.transposeX, ACLNN_ERR_PARAM_INVALID,
"%s, x can not be transposed.", errorMessage.c_str());
CHECK_COND(CheckCaseSplitM(gmmParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"Invalid inputs!");
} else if (gmmParams.groupType == gmm::SPLIT_K) {
CHECK_COND(gmmParams.biasOptional == nullptr, ACLNN_ERR_PARAM_INVALID,
"When groupType is 2, bias must be empty.");
CHECK_COND(CheckCaseSplitK(gmmParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"Invalid inputs!");
}
if (gmmParams.biasOptional != nullptr) {
CHECK_COND(CheckOptionalTensorList(gmmParams, gmmParams.biasOptional, "bias") == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "Invalid bias!");
}
return ACLNN_SUCCESS;
}
static aclnnStatus CheckTuningConfig(const gmm::GroupedMatmulParams &gmmParams) {
size_t tensorLength = gmmParams.x->Size();
int64_t maxM = 0;
for (size_t i = 0; i < tensorLength; ++i) {
maxM = std::max(maxM, (*gmmParams.x)[i]->GetViewShape().GetDim(0));
}
if (gmmParams.tuningConfigOptional != nullptr && gmmParams.tuningConfigOptional->Size() > 0) {
CHECK_COND((*gmmParams.tuningConfigOptional)[0] >= 0 && (*gmmParams.tuningConfigOptional)[0] <= maxM,
ACLNN_ERR_PARAM_INVALID,
"Invalid tuningConfigOptional (%ld)! It should be a non-negative num, and less than or equal to maxM: %ld",
(*gmmParams.tuningConfigOptional)[0], maxM);
}
return ACLNN_SUCCESS;
}
static aclnnStatus CheckTensorListLength(const aclTensorList *tensorList) {
size_t groupSize = 0;
if (tensorList != nullptr) {
groupSize = tensorList->Size();
}
CHECK_COND(
groupSize <= MAX_GROUP_LIST_SIZE_ARRAY, ACLNN_ERR_PARAM_INVALID, "Length of tensorList should not exceed %ld, but actually got %lu.",
MAX_GROUP_LIST_SIZE_ARRAY, groupSize);
return ACLNN_SUCCESS;
}
static aclnnStatus CheckGroupSize(const gmm::GroupedMatmulParams &gmmParams) {
if (op::GetCurrentPlatformInfo().GetCurNpuArch() == NpuArch::DAV_3510) {
if ((gmmParams.xDtype == DataType::DT_BF16 || gmmParams.xDtype == DataType::DT_FLOAT16 ||
gmmParams.xDtype == DataType::DT_FLOAT) && gmmParams.xDtype == (*gmmParams.weight)[0]->GetDataType()) {
return gmm::AclnnGroupedMatmulNoQuantDAV3510Checker(gmmParams).CheckGroupedMatmulGroupSizeNoQuantDAV3510();
}
}
CHECK_COND(CheckTensorListLength(gmmParams.x) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"Invalid length of tensorList x.");
CHECK_COND(CheckTensorListLength(gmmParams.weight) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"Invalid length of tensorList weight.");
return ACLNN_SUCCESS;
}
static aclnnStatus CheckUnusedParams(const gmm::GroupedMatmulParams &gmmParams) {
CHECK_COND(gmmParams.activationInputOptional == nullptr, ACLNN_ERR_PARAM_INVALID,
"ActivationInputOptional must be nullptr.");
CHECK_COND(gmmParams.activationQuantScaleOptional == nullptr, ACLNN_ERR_PARAM_INVALID,
"ActivationQuantScaleOptional must be nullptr.");
CHECK_COND(gmmParams.activationQuantOffsetOptional == nullptr, ACLNN_ERR_PARAM_INVALID,
"ActivationQuantOffsetOptional must be nullptr.");
CHECK_COND(gmmParams.activationFeatureOutOptional == nullptr, ACLNN_ERR_PARAM_INVALID,
"ActivationFeatureOutOptional must be nullptr.");
CHECK_COND(gmmParams.dynQuantScaleOutOptional == nullptr, ACLNN_ERR_PARAM_INVALID,
"DynQuantScaleOutOptional must be nullptr.");
return ACLNN_SUCCESS;
}
static aclnnStatus CheckParam(const gmm::GroupedMatmulParams &gmmParams) {
CHECK_COND(CheckUnusedParams(gmmParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID, "Invalid unused params.");
CHECK_RET(CheckFunctionParams(gmmParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID);
CHECK_RET(CheckParamDifferentGroupType(gmmParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID);
CHECK_RET(CheckGroupSize(gmmParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID);
CHECK_RET(CheckTuningConfig(gmmParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID);
return ACLNN_SUCCESS;
}
static void CheckOptionalTensorListEmpty(const aclTensorList *&tensorList) {
if (tensorList != nullptr) {
if (tensorList->Size() == 0) {
tensorList = nullptr;
} else if ((*tensorList)[0] == nullptr) {
tensorList = nullptr;
} else if (tensorList->Size() == 1) {
op::Shape shape = (*tensorList)[0]->GetViewShape();
if (shape.GetDimNum() == 1 && shape.GetDim(0) == 0) {
tensorList = nullptr;
}
}
}
}
static void ResetEmptyTensor(gmm::GroupedMatmulParams &gmmParams) {
if (gmmParams.groupListOptional != nullptr && gmmParams.groupListOptional->Size() == 0) {
gmmParams.groupListOptional = nullptr;
}
CheckOptionalTensorListEmpty(gmmParams.biasOptional);
CheckOptionalTensorListEmpty(gmmParams.scaleOptional);
CheckOptionalTensorListEmpty(gmmParams.offsetOptional);
CheckOptionalTensorListEmpty(gmmParams.antiquantScaleOptional);
CheckOptionalTensorListEmpty(gmmParams.antiquantOffsetOptional);
CheckOptionalTensorListEmpty(gmmParams.perTokenScaleOptional);
CheckOptionalTensorListEmpty(gmmParams.activationInputOptional);
CheckOptionalTensorListEmpty(gmmParams.activationQuantScaleOptional);
CheckOptionalTensorListEmpty(gmmParams.activationQuantOffsetOptional);
CheckOptionalTensorListEmpty(gmmParams.activationFeatureOutOptional);
CheckOptionalTensorListEmpty(gmmParams.dynQuantScaleOutOptional);
}
static void CreateEmptyTensor(const aclDataType dataType, const aclTensorList *&gmmTensorList,
aclTensorList *&tensorList, aclOpExecutor *executor) {
if (gmmTensorList == nullptr) {
FVector<aclTensor*> emptyTensors;
aclTensor *emptyTensor = executor->AllocTensor({0}, static_cast<op::DataType>(dataType));
emptyTensors.emplace_back(emptyTensor);
tensorList = executor->AllocTensorList(emptyTensors.data(), emptyTensors.size());
gmmTensorList = tensorList;
}
}
static aclnnStatus DataContiguous(const aclTensorList *&tensors, aclOpExecutor *executor) {
std::vector<const aclTensor *> tensorsVec;
const aclTensor *contiguousTensor = nullptr;
for (size_t i = 0; i < tensors->Size(); ++i) {
const aclTensor *tensor = (*tensors)[i];
contiguousTensor = l0op::Contiguous(tensor, executor);
CHECK_RET(contiguousTensor != nullptr, ACLNN_ERR_INNER_NULLPTR);
tensorsVec.push_back(contiguousTensor);
}
tensors = executor->AllocTensorList(tensorsVec.data(), tensorsVec.size());
return ACLNN_SUCCESS;
}
static aclnnStatus DataContiguousAndTransFormat(const aclTensor *tensor, const aclTensor *&reformatedTensor,
const op::Format requiredFormat, aclOpExecutor *executor) {
CHECK_RET(tensor != nullptr, ACLNN_ERR_INNER_NULLPTR);
if (!op::IsPrivateFormat(tensor->GetStorageFormat()) && tensor->GetStorageFormat() != op::Format::FORMAT_ND) {
tensor = l0op::ReFormat(tensor, op::Format::FORMAT_ND, executor);
}
if (tensor == nullptr || tensor->GetViewShape().GetDimNum() == 1) {
OP_LOGD("No need to do contiguous process");
reformatedTensor = tensor;
} else {
reformatedTensor = l0op::Contiguous(tensor, executor);
}
CHECK_RET(reformatedTensor != nullptr, ACLNN_ERR_INNER_NULLPTR);
reformatedTensor = l0op::TransData(reformatedTensor, requiredFormat, 1, executor);
CHECK_RET(reformatedTensor != nullptr, ACLNN_ERR_INNER_NULLPTR);
return ACLNN_SUCCESS;
}
static aclnnStatus TransWeightToNzCheckAlign(gmm::GroupedMatmulParams &gmmParams, const aclTensor *weight,
const DataType xDtype)
{
size_t viewDimNum = weight->GetViewShape().GetDimNum();
uint64_t k = gmmParams.transposeWeight ? weight->GetViewShape().GetDim(viewDimNum - 1) :
weight->GetViewShape().GetDim(viewDimNum - gmm::LAST_SECOND_DIM_INDEX);
uint64_t n = gmmParams.transposeWeight ? weight->GetViewShape().GetDim(viewDimNum - gmm::LAST_SECOND_DIM_INDEX) :
weight->GetViewShape().GetDim(viewDimNum - 1);
bool k_align = false;
bool n_align = false;
if (weight->GetDataType() == DataType::DT_INT8) {
k_align = gmmParams.transposeWeight ? k % ALIGN_NZ_INT8_N == 0 : k % ALIGN_NZ_K == 0;
n_align = gmmParams.transposeWeight ? n % ALIGN_NZ_K == 0 : n % ALIGN_NZ_INT8_N == 0;
} else if (weight->GetDataType() == DataType::DT_BF16 || weight->GetDataType() == DataType::DT_FLOAT16) {
k_align = k % ALIGN_NZ_K == 0;
n_align = n % ALIGN_NZ_K == 0;
} else if (weight->GetDataType() == DataType::DT_INT4) {
k_align = gmmParams.transposeWeight ? k % ALIGN_NZ_4BIT_N == 0 : k % ALIGN_NZ_K == 0;
n_align = gmmParams.transposeWeight ? n % ALIGN_NZ_K == 0 : n % ALIGN_NZ_4BIT_N == 0;
} else if (weight->GetDataType() == DataType::DT_FLOAT4_E2M1 &&
(xDtype == DataType::DT_FLOAT16 || xDtype == DataType::DT_BF16 || xDtype == DataType::DT_FLOAT8_E4M3FN)) {
k_align = k % ALIGN_NZ_4BIT_K == 0;
n_align = n % ALIGN_NZ_4BIT_N == 0;
}
CHECK_COND(k_align == true && n_align == true, ACLNN_ERR_PARAM_INVALID,
"When weight(%s) format is FRACTAL_NZ, weight'shape(k[%lu], n[%lu]) should be divisible by the "
"following shape: INT8:[16, 32],BF16/FP16[16, 16],INT4[16, 64],FP4[64,64]). If the weight is transposed,"
"the k/n need to be reversed.",
gmm::dTypeToString(weight->GetDataType()).c_str(), k, n);
return ACLNN_SUCCESS;
}
static aclnnStatus TransWeightToNz(gmm::GroupedMatmulParams &gmmParams, aclOpExecutor *executor) {
bool is310p = GetCurrentPlatformInfo().GetSocVersion() == SocVersion::ASCEND310P;
if (is310p) {
const aclTensorList *&weights = gmmParams.weight;
size_t wLength = weights->Size();
std::vector<const aclTensor*> reformatedWeightVec;
const aclTensor* reformatedWeight = nullptr;
for (size_t i(0); i < wLength; ++i) {
const aclTensor* weight = (*weights)[i];
op::Shape shape = weight->GetViewShape();
int64_t n = gmmParams.transposeWeight ? shape.GetDim(shape.GetDimNum() - 2) : shape.GetDim(shape.GetDimNum() - 1);
CHECK_COND(n % static_cast<int64_t>(32 / std::max<int>(1, op::TypeSize(weight->GetDataType()))) == 0, ACLNN_ERR_PARAM_INVALID,
"Output n axis should align with 32 Bytes, but now is %ld", n);
aclnnStatus ret = DataContiguousAndTransFormat(weight, reformatedWeight,
Format::FORMAT_FRACTAL_NZ, executor);
CHECK_RET(ret == ACLNN_SUCCESS, ret);
reformatedWeightVec.push_back(reformatedWeight);
}
weights = executor->AllocTensorList(reformatedWeightVec.data(), reformatedWeightVec.size());
} else {
const aclTensorList *&weights = gmmParams.weight;
const aclTensorList *&x = gmmParams.x;
CHECK_COND((*x)[0] != nullptr, ACLNN_ERR_PARAM_INVALID, "The first tensor of x is nullptr!");
auto xDtype = (*x)[0]->GetDataType();
size_t wLength = weights->Size();
for (size_t i(0); i < wLength; ++i) {
const aclTensor* weight = (*weights)[i];
if (weight->GetStorageFormat() != op::Format::FORMAT_FRACTAL_NZ &&
weight->GetStorageFormat() != op::Format::FORMAT_FRACTAL_NZ_C0_16 &&
weight->GetStorageFormat() != op::Format::FORMAT_FRACTAL_NZ_C0_32) {
break;
}
if (!(op::GetCurrentPlatformInfo().GetCurNpuArch() == NpuArch::DAV_3510 &&
IsQuant(xDtype, weight->GetDataType()) &&
(*gmmParams.weight)[0]->GetStorageFormat() == op::Format::FORMAT_FRACTAL_NZ)) {
aclnnStatus ret = TransWeightToNzCheckAlign(gmmParams, weight, xDtype);
CHECK_RET(ret == ACLNN_SUCCESS, ret);
}
}
}
return ACLNN_SUCCESS;
}
static aclnnStatus CheckZeroShape(gmm::GroupedMatmulParams ¶ms, uint64_t *workspaceSize) {
bool isEmptyX = true;
bool isEmptyWeight =true;
for (size_t i = 0; i < params.x->Size(); ++i) {
if (!((*params.x)[i]->IsEmpty())) {
isEmptyX = false;
break;
}
}
for (size_t i = 0; i < params.weight->Size(); ++i) {
if (!((*params.weight)[i]->IsEmpty())) {
isEmptyWeight = false;
break;
}
}
if (isEmptyX || isEmptyWeight) {
*workspaceSize = 0UL;
return ACLNN_ERR_PARAM_INVALID;
}
return ACLNN_SUCCESS;
}
static aclnnStatus CheckZeroShapeSplitK(gmm::GroupedMatmulParams ¶ms, uint64_t *workspaceSize) {
bool zeroM = true;
bool zeroN = true;
for (size_t i = 0; i < params.x->Size(); ++i) {
CHECK_COND((*params.x)[i] != nullptr, ACLNN_SUCCESS, "GroupedMatmul x tensor should not be null");
auto xShape = (*params.x)[i]->GetViewShape();
size_t xDimNum = xShape.GetDimNum();
CHECK_COND(xDimNum == gmm::MIN_FM_DIM, ACLNN_SUCCESS,
"When groupType = 2 GroupedMatmul x dim num should be 2, but actual is %zu.", xDimNum);
zeroM = zeroM && (xShape.GetDim(1) == 0);
}
for (size_t i = 0; i < params.weight->Size(); ++i) {
CHECK_COND((*params.weight)[i] != nullptr, ACLNN_SUCCESS,
"GroupedMatmul weight tensor should not be null");
auto wShape = (*params.weight)[i]->GetViewShape();
CHECK_COND(wShape.GetDimNum() == gmm::MIN_FM_DIM, ACLNN_SUCCESS,
"When groupType = 2 GroupedMatmul weight dim num should be 2, but actual %zu.", wShape.GetDimNum());
zeroN = zeroN && (wShape.GetDim(wShape.GetDimNum() - 1) == 0);
}
if (zeroM || zeroN) {
*workspaceSize = 0UL;
return ACLNN_ERR_PARAM_INVALID;
}
return ACLNN_SUCCESS;
}
static void SetAntiQuantParamsTensorEmptyDAV3510(gmm::GroupedMatmulParams ¶ms, aclOpExecutor *executor)
{
DataType weightDtype = (*params.weight)[0]->GetDataType();
if (op::GetCurrentPlatformInfo().GetCurNpuArch() == NpuArch::DAV_3510 &&
IsWeightQuant(params.xDtype, weightDtype)) {
if (params.xDtype == ge::DataType::DT_FLOAT8_E4M3FN) {
DataType yDtype = (*params.y)[0]->GetDataType();
aclTensorList *emptyBiasList = nullptr;
CreateEmptyTensor(ToAclDataType(yDtype), params.biasOptional, emptyBiasList, executor);
aclTensorList *emptyAntiquantOffsetList = nullptr;
CreateEmptyTensor(ToAclDataType(yDtype), params.antiquantOffsetOptional, emptyAntiquantOffsetList,
executor);
} else if (params.xDtype == ge::DataType::DT_INT8) {
aclTensorList *emptyBiasList = nullptr;
CreateEmptyTensor(ToAclDataType(ge::DataType::DT_FLOAT), params.biasOptional, emptyBiasList, executor);
aclTensorList *emptyAntiquantOffsetList = nullptr;
CreateEmptyTensor(ToAclDataType(ge::DataType::DT_FLOAT16), params.antiquantOffsetOptional, emptyAntiquantOffsetList,
executor);
} else {
aclTensorList *emptyBiasList = nullptr;
CreateEmptyTensor(ToAclDataType(params.xDtype), params.biasOptional, emptyBiasList, executor);
aclTensorList *emptyAntiquantOffsetList = nullptr;
CreateEmptyTensor(ToAclDataType(params.xDtype), params.antiquantOffsetOptional, emptyAntiquantOffsetList,
executor);
}
}
}
static void SetParamsTensorEmpty(gmm::GroupedMatmulParams ¶ms, aclOpExecutor *executor) {
aclTensorList *emptyBiasList = nullptr;
DataType weightDtype = (*(params.weight))[0]->GetDataType();
if(params.xDtype == DataType::DT_INT8 && weightDtype == DataType::DT_INT4) {
CreateEmptyTensor(aclDataType::ACL_FLOAT, params.biasOptional, emptyBiasList, executor);
}
else { CreateEmptyTensor(gmm::BIAS_DTYPE.at(params.xDtype), params.biasOptional, emptyBiasList, executor); }
aclTensorList *emptyScaleList = nullptr;
CreateEmptyTensor(aclDataType::ACL_UINT64, params.scaleOptional, emptyScaleList, executor);
aclTensorList *emptyOffsetList = nullptr;
CreateEmptyTensor(aclDataType::ACL_FLOAT, params.offsetOptional, emptyOffsetList, executor);
aclTensorList *emptyAntiquantScaleList = nullptr;
CreateEmptyTensor(aclDataType::ACL_FLOAT16, params.antiquantScaleOptional, emptyAntiquantScaleList, executor);
aclTensorList *emptyAntiquantOffsetList = nullptr;
if(params.xDtype == DataType::DT_BF16 && weightDtype == DataType::DT_INT4) {
CreateEmptyTensor(aclDataType::ACL_BF16, params.antiquantOffsetOptional, emptyAntiquantOffsetList, executor);
} else {
CreateEmptyTensor(aclDataType::ACL_FLOAT16, params.antiquantOffsetOptional, emptyAntiquantOffsetList, executor);
}
aclTensorList *emptyPerTokenScaleList = nullptr;
CreateEmptyTensor(aclDataType::ACL_FLOAT, params.perTokenScaleOptional, emptyPerTokenScaleList, executor);
aclTensorList *emptyActivationInputList = nullptr;
CreateEmptyTensor(aclDataType::ACL_FLOAT, params.activationInputOptional, emptyActivationInputList, executor);
aclTensorList *emptyActivationQuantScaleList = nullptr;
CreateEmptyTensor(aclDataType::ACL_FLOAT, params.activationQuantScaleOptional, emptyActivationQuantScaleList, executor);
aclTensorList *emptyActivationQuantOffsetList = nullptr;
CreateEmptyTensor(aclDataType::ACL_FLOAT, params.activationQuantOffsetOptional, emptyActivationQuantOffsetList, executor);
aclTensorList *emptyActivationFeatureOutList = nullptr;
CreateEmptyTensor(aclDataType::ACL_FLOAT, params.activationFeatureOutOptional, emptyActivationFeatureOutList, executor);
aclTensorList *emptyDynQuantScaleOutList = nullptr;
CreateEmptyTensor(aclDataType::ACL_FLOAT, params.dynQuantScaleOutOptional, emptyDynQuantScaleOutList, executor);
}
static aclnnStatus CheckOutputShape(const aclTensorList* l0Res, const aclTensorList* y) {
CHECK_COND(l0Res->Size() == y->Size(), ACLNN_ERR_PARAM_INVALID, "Output tensor list length is not right.");
for (size_t i = 0; i < y->Size(); ++i) {
auto const &resShape = (*l0Res)[i]->GetViewShape();
auto const &yShape = (*y)[i]->GetViewShape();
if (resShape != yShape) {
if (!(resShape.GetShapeSize() == 1 && yShape.GetShapeSize() == 1)) {
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "Tensorlist Output %lu tensor's shape[%s] is not equal with infered output's shape[%s].", i,
op::ToString(yShape).GetString(), op::ToString(resShape).GetString());
return ACLNN_ERR_PARAM_INVALID;
}
}
}
return ACLNN_SUCCESS;
}
static bool IsPerTileQuantMode(gmm::GroupedMatmulParams ¶ms)
{
if (op::GetCurrentPlatformInfo().GetCurNpuArch() == NpuArch::DAV_3510 &&
IsQuant(params.xDtype, (*params.weight)[0]->GetDataType())) {
gmm::AclnnGroupedMatmulDAV3510Checker<aclTensorList> checker(params);
return checker.IsPerTileQuantMode();
}
return false;
}
static void SetTransposedScaleTensorListContiguous(gmm::GroupedMatmulParams ¶ms, aclOpExecutor *executorPtr,
const bool &isPerTileQuantMode)
{
if (params.scaleOptional != nullptr) {
std::vector<aclTensor *> scaleTensorList;
if ((*params.scaleOptional)[0] != nullptr &&
(*params.scaleOptional)[0]->GetDataType() == DataType::DT_FLOAT8_E8M0) {
gmm::CreateContiguousTensorListForMXTypeMScale(params.scaleOptional, scaleTensorList, executorPtr);
params.scaleOptional = executorPtr->AllocTensorList(scaleTensorList.data(), scaleTensorList.size());
} else if (isPerTileQuantMode) {
gmm::CreateContiguousTensorList(params.scaleOptional, scaleTensorList, executorPtr);
params.scaleOptional = executorPtr->AllocTensorList(scaleTensorList.data(), scaleTensorList.size());
}
}
if ((*params.antiquantScaleOptional)[0] != nullptr &&
((*params.antiquantScaleOptional)[0]->GetViewShape().GetDimNum() == ANTIQUANT_SCALE_3D_DIMS ||
(*params.antiquantScaleOptional)[0]->GetViewShape().GetDimNum() == ANTIQUANT_SCALE_4D_DIMS) &&
op::GetCurrentPlatformInfo().GetCurNpuArch() == NpuArch::DAV_3510 &&
params.apiVersion == gmm::GMMApiVersion::WeightNz) {
std::vector<aclTensor *> antiSTensorList;
if ((*params.antiquantScaleOptional)[0]->GetDataType() == DataType::DT_FLOAT8_E8M0) {
gmm::CreateContiguousTensorListForMXTypeMScale(params.antiquantScaleOptional, antiSTensorList, executorPtr);
} else {
gmm::CreateContiguousTensorList(params.antiquantScaleOptional, antiSTensorList, executorPtr);
}
params.antiquantScaleOptional = executorPtr->AllocTensorList(antiSTensorList.data(), antiSTensorList.size());
}
}
static void SetTransposedTensorListContiguous(gmm::GroupedMatmulParams ¶ms, aclOpExecutor *executorPtr)
{
bool isPerTileQuantMode = IsPerTileQuantMode(params);
DataType weightDtype = (*params.weight)[0]->GetDataType();
if (params.transposeX) {
std::vector<aclTensor*> xTensorList;
gmm::CreateContiguousTensorList(params.x, xTensorList, executorPtr);
params.x = executorPtr->AllocTensorList(xTensorList.data(), xTensorList.size());
if (params.perTokenScaleOptional != nullptr &&
((*params.perTokenScaleOptional)[0]->GetDataType() == DataType::DT_FLOAT8_E8M0 || isPerTileQuantMode)) {
std::vector<aclTensor*> perTokenScaleTensorList;
if (isPerTileQuantMode) {
gmm::CreateContiguousTensorList(params.perTokenScaleOptional, perTokenScaleTensorList, executorPtr);
} else {
gmm::CreateContiguousTensorListForPertoken(params.perTokenScaleOptional, perTokenScaleTensorList, executorPtr);}
params.perTokenScaleOptional = executorPtr->AllocTensorList(perTokenScaleTensorList.data(), perTokenScaleTensorList.size());
}
}
if (params.transposeWeight) {
std::vector<aclTensor *> weightTensorList;
auto nZShape = (*params.weight)[0]->GetStorageShape();
gmm::CreateContiguousTensorList(params.weight, weightTensorList, executorPtr);
params.weight = executorPtr->AllocTensorList(weightTensorList.data(), weightTensorList.size());
if (op::GetCurrentPlatformInfo().GetCurNpuArch() == NpuArch::DAV_3510 &&
((IsQuant(params.xDtype, weightDtype) &&
(*params.weight)[0]->GetStorageFormat() == op::Format::FORMAT_FRACTAL_NZ) ||
(params.apiVersion == gmm::GMMApiVersion::WeightNz && IsWeightQuant(params.xDtype, weightDtype)))) {
(*params.weight)[0]->SetStorageShape(nZShape);
}
SetTransposedScaleTensorListContiguous(params, executorPtr, isPerTileQuantMode);
}
}
static aclnnStatus ParamsDataContiguous(gmm::GroupedMatmulParams ¶ms, aclOpExecutor *executorPtr) {
CHECK_COND(DataContiguous(params.x, executorPtr) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"Contiguous x failed.");
DataType xDtype = (*params.x)[0]->GetDataType();
DataType weightDtype = (*params.weight)[0]->GetDataType();
if (!(op::GetCurrentPlatformInfo().GetCurNpuArch() == NpuArch::DAV_3510 &&
((IsQuant(xDtype, weightDtype) && (*params.weight)[0]->GetStorageFormat() == op::Format::FORMAT_FRACTAL_NZ) ||
(params.apiVersion == gmm::GMMApiVersion::WeightNz && IsWeightQuant(xDtype, weightDtype))))) {
CHECK_COND(DataContiguous(params.weight, executorPtr) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"Contiguous weight failed.");
}
CHECK_COND(DataContiguous(params.biasOptional, executorPtr) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"Contiguous biasOptional failed.");
CHECK_COND(DataContiguous(params.scaleOptional, executorPtr) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"Contiguous scaleOptional failed.");
CHECK_COND(DataContiguous(params.offsetOptional, executorPtr) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"Contiguous offsetOptional failed.");
CHECK_COND(DataContiguous(params.antiquantScaleOptional, executorPtr) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"Contiguous antiquantScaleOptional failed.");
if (params.antiquantOffsetOptional != nullptr) {
CHECK_COND(DataContiguous(params.antiquantOffsetOptional, executorPtr) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"Contiguous antiquantOffsetOptional failed.");
}
CHECK_COND(DataContiguous(params.perTokenScaleOptional, executorPtr) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"Contiguous perTokenScaleOptional failed.");
if (params.groupTensorOptional != nullptr) {
params.groupTensorOptional = l0op::Contiguous(params.groupTensorOptional, executorPtr);
CHECK_COND(params.groupTensorOptional != nullptr, ACLNN_ERR_PARAM_INVALID,
"Contiguous groupTensorOptional failed.");
}
return ACLNN_SUCCESS;
}
static aclnnStatus CheckWeightQuantGMMWeightNz(DataType x1Dtype, DataType weightDtype, DataType yDtype) {
if (x1Dtype == DataType::DT_INT8 && weightDtype == DataType::DT_INT4) {
CHECK_COND(
yDtype == DataType::DT_FLOAT16 || yDtype == DataType::DT_BF16, ACLNN_ERR_PARAM_INVALID,
"The dtypes of x[%s]-weight[%s]-y[%s] do not match with required dtype.The x-weight-y of the antiquant"
"case[A8W4] only supports the following combinations: INT8-INT4-BF16,INT8-INT4-Fp16",
gmm::dTypeToString(x1Dtype).c_str(), gmm::dTypeToString(weightDtype).c_str(),
gmm::dTypeToString(yDtype).c_str());
return ACLNN_SUCCESS;
} else if ((x1Dtype == DataType::DT_FLOAT16 || x1Dtype == DataType::DT_BF16) &&
weightDtype == DataType::DT_FLOAT4_E2M1) {
CHECK_COND(
yDtype == DataType::DT_FLOAT16 || yDtype == DataType::DT_BF16, ACLNN_ERR_PARAM_INVALID,
"The dtypes of x[%s]-weight[%s]-y[%s] do not match with required dtype.The x-weight-y of the antiquant"
"case[A16mxFp4] only supports the following combinations: Fp16-Fp4_e2m1-Fp16,BF16-Fp4_e2m1-BF16",
gmm::dTypeToString(x1Dtype).c_str(), gmm::dTypeToString(weightDtype).c_str(),
gmm::dTypeToString(yDtype).c_str());
return ACLNN_SUCCESS;
} else if (x1Dtype == DataType::DT_FLOAT8_E4M3FN && weightDtype == DataType::DT_FLOAT4_E2M1) {
CHECK_COND(
yDtype == DataType::DT_BF16 || yDtype == DataType::DT_FLOAT16, ACLNN_ERR_PARAM_INVALID,
"The dtypes of x[%s]-weight[%s]-y[%s] do not match with required dtype.The x-weight-y of the antiquant"
"case[MxA8W4] only supports the following combinations: Fp8_e4m3fn-Fp4_e2m1-BF16/Fp16",
gmm::dTypeToString(x1Dtype).c_str(), gmm::dTypeToString(weightDtype).c_str(),
gmm::dTypeToString(yDtype).c_str());
return ACLNN_SUCCESS;
}
return ACLNN_ERR_PARAM_INVALID;
}
static aclnnStatus CheckQuantGMMWeightNz(DataType x1Dtype, DataType weightDtype, DataType yDtype) {
if (x1Dtype == DataType::DT_INT8 && weightDtype == DataType::DT_INT8) {
CHECK_COND(
yDtype == DataType::DT_FLOAT16 || yDtype == DataType::DT_BF16 || yDtype == DataType::DT_INT32,
ACLNN_ERR_PARAM_INVALID,
"The dtypes of x[%s]-weight[%s]-y[%s] do not match with required dtype.The x-weight-y of the quant case"
"only supports the following combinations: INT8-INT8-BF16,INT8-INT8-Fp16,INT8-INT8-INT32",
gmm::dTypeToString(x1Dtype).c_str(), gmm::dTypeToString(weightDtype).c_str(),
gmm::dTypeToString(yDtype).c_str());
return ACLNN_SUCCESS;
} else if (x1Dtype == DataType::DT_INT4 && weightDtype == DataType::DT_INT4) {
CHECK_COND(
yDtype == DataType::DT_FLOAT16 || yDtype == DataType::DT_BF16, ACLNN_ERR_PARAM_INVALID,
"The dtypes of x[%s]-weight[%s]-y[%s] do not match with required dtype.The x-weight-y of the antiquant"
"case[A4W4] only supports the following combinations: INT4-INT4-BF16,INT4-INT4-Fp16",
gmm::dTypeToString(x1Dtype).c_str(), gmm::dTypeToString(weightDtype).c_str(),
gmm::dTypeToString(yDtype).c_str());
return ACLNN_SUCCESS;
} else if (op::GetCurrentPlatformInfo().GetCurNpuArch() == NpuArch::DAV_3510 &&
x1Dtype == DataType::DT_FLOAT8_E4M3FN && weightDtype == DataType::DT_FLOAT8_E4M3FN) {
return ACLNN_SUCCESS;
} else if (op::GetCurrentPlatformInfo().GetCurNpuArch() == NpuArch::DAV_3510 &&
(x1Dtype == DataType::DT_FLOAT4_E2M1 || x1Dtype == DataType::DT_FLOAT4_E1M2) &&
(weightDtype == DataType::DT_FLOAT4_E2M1 || weightDtype == DataType::DT_FLOAT4_E1M2)) {
return ACLNN_SUCCESS;
}
return ACLNN_ERR_PARAM_INVALID;
}
static aclnnStatus CheckNoQuantGMMWeightNz(DataType x1Dtype, DataType weightDtype, DataType yDtype) {
if (x1Dtype == DataType::DT_BF16 && weightDtype == DataType::DT_BF16) {
CHECK_COND(
yDtype == DataType::DT_BF16, ACLNN_ERR_PARAM_INVALID,
"The dtypes of x[%s]-weight[%s]-y[%s] do not match with required dtype."
"The x-weight-y of the antiquant case[BF16] only supports the following combinations: BF16-BF16-BF16",
gmm::dTypeToString(x1Dtype).c_str(), gmm::dTypeToString(weightDtype).c_str(),
gmm::dTypeToString(yDtype).c_str());
return ACLNN_SUCCESS;
} else if (x1Dtype == DataType::DT_FLOAT16 && weightDtype == DataType::DT_FLOAT16) {
CHECK_COND(
yDtype == DataType::DT_FLOAT16, ACLNN_ERR_PARAM_INVALID,
"The dtypes of x[%s]-weight[%s]-y[%s] do not match with required dtype. The x-weight-y of the antiquant"
"case[FLOAT16] only supports the following combinations: FLOAT16-FLOAT16-FLOAT16",
gmm::dTypeToString(x1Dtype).c_str(), gmm::dTypeToString(weightDtype).c_str(),
gmm::dTypeToString(yDtype).c_str());
return ACLNN_SUCCESS;
}
return ACLNN_ERR_PARAM_INVALID;
}
static aclnnStatus ParamsWeightNzDtype(gmm::GroupedMatmulParams ¶ms) {
DataType x1Dtype = params.xDtype;
DataType weightDtype = (*params.weight)[0]->GetDataType();
DataType yDtype = (*params.y)[0]->GetDataType();
if (CheckWeightQuantGMMWeightNz(x1Dtype, weightDtype, yDtype) == ACLNN_SUCCESS ||
CheckQuantGMMWeightNz(x1Dtype, weightDtype, yDtype) == ACLNN_SUCCESS ||
CheckNoQuantGMMWeightNz(x1Dtype, weightDtype, yDtype) == ACLNN_SUCCESS) {
return ACLNN_SUCCESS;
}
OP_LOGE(ACLNN_ERR_PARAM_INVALID,
"The dtypes of x[%s]-weight[%s] do not match with required dtype."
"Only supported x-weight: INT8-INT8, BF16-BF16, FP16-FP16, INT8-INT4, INT4-INT4, FP16/BF16-FP4_E2M1, "
"FP8_E4M3FN-FP4_E2M1, FP8_E4M3FN-FP8_E4M3FN.",
gmm::dTypeToString(x1Dtype).c_str(), gmm::dTypeToString(weightDtype).c_str());
return ACLNN_ERR_PARAM_INVALID;
}
static const aclTensor *SetTensorToNZFormat(const aclTensor *input, op::Shape &shape, aclOpExecutor *executor) {
auto formatTensor = executor->CreateView(input, shape, input->GetViewOffset());
formatTensor->SetStorageFormat(op::Format::FORMAT_FRACTAL_NZ);
formatTensor->SetOriginalFormat(input->GetViewFormat());
formatTensor->SetViewShape(input->GetViewShape());
return formatTensor;
}
static aclnnStatus SetStorageShape(gmm::GroupedMatmulParams ¶ms, op::Shape wqbmmNzShape)
{
DataType xDtype = (*params.x)[0]->GetDataType();
DataType weightDtype = (*params.weight)[0]->GetDataType();
if (op::GetCurrentPlatformInfo().GetCurNpuArch() == NpuArch::DAV_3510 &&
((IsQuant(xDtype, weightDtype) && (*params.weight)[0]->GetStorageFormat() == op::Format::FORMAT_FRACTAL_NZ) ||
(params.apiVersion == gmm::GMMApiVersion::WeightNz && IsWeightQuant(xDtype, weightDtype)))) {
(*params.weight)[0]->SetStorageShape(wqbmmNzShape);
}
return ACLNN_SUCCESS;
}
static aclnnStatus GetGMMResultByL0Api(gmm::GroupedMatmulParams ¶ms, uint64_t *workspaceSize, aclOpExecutor **executor) {
auto uniqueExecutor = CREATE_EXECUTOR();
aclOpExecutor *executorPtr = uniqueExecutor.get();
CHECK_RET(executorPtr != nullptr, ACLNN_ERR_INNER_CREATE_EXECUTOR);
if (params.xDtype != DataType::DT_INT4) {
CHECK_COND(gmm::BIAS_DTYPE.find(params.xDtype) != gmm::BIAS_DTYPE.cend(), ACLNN_ERR_PARAM_INVALID,
"GMM: Cannot find bias dtype match with xDtype[%s]", gmm::dTypeToString(params.xDtype).c_str());
}
SetAntiQuantParamsTensorEmptyDAV3510(params, executorPtr);
SetParamsTensorEmpty(params, executorPtr);
SetTransposedTensorListContiguous(params, executorPtr);
CHECK_COND(ParamsDataContiguous(params, executorPtr) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"ParamsDataContiguous failed.");
if (params.groupType == gmm::SPLIT_K) {
if (CheckZeroShapeSplitK(params, workspaceSize) != ACLNN_SUCCESS) {
uniqueExecutor.ReleaseTo(executor);
return ACLNN_SUCCESS;}
} else {
if (CheckZeroShape(params, workspaceSize) != ACLNN_SUCCESS) {
uniqueExecutor.ReleaseTo(executor);
return ACLNN_SUCCESS;}
}
op::Shape nzShape = (*params.weight)[0]->GetStorageShape();
if (params.apiVersion == gmm::GMMApiVersion::WeightNz ||
(*params.weight)[0]->GetStorageFormat() == op::Format::FORMAT_FRACTAL_NZ) {
CHECK_COND(ParamsWeightNzDtype(params) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID, "ParamsWeightNzDtype failed.");
std::vector<const aclTensor *> tensorsVec;
for (size_t i = 0; i < params.weight->Size(); ++i) {
const aclTensor *tensor = (*params.weight)[i];
op::Shape weightNzShape = tensor->GetViewShape();
tensor = SetTensorToNZFormat(tensor, weightNzShape, executorPtr);
tensorsVec.push_back(tensor);
}
params.weight = executorPtr->AllocTensorList(tensorsVec.data(), tensorsVec.size());
SetStorageShape(params,nzShape);
}
CHECK_COND(TransWeightToNz(params, executorPtr) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID, "TransWeightToNz failed.");
if (params.groupListOptional != nullptr) {
params.groupTensorOptional = uniqueExecutor->ConvertToTensor(params.groupListOptional, op::ToOpDataType(ACL_INT64));
}
auto perTokenScaleOptional = (*params.perTokenScaleOptional)[0]->IsEmpty() ? nullptr : (*params.perTokenScaleOptional)[0];
auto result = l0op::GroupedMatmul(params.x, params.weight, params.biasOptional, params.scaleOptional,
params.offsetOptional, params.antiquantScaleOptional, params.antiquantOffsetOptional,
params.groupTensorOptional, perTokenScaleOptional, params.splitItem,
(*params.y)[0]->GetDataType(), params.transposeWeight, params.transposeX, params.groupType,
params.groupListType, params.activeType, params.tuningConfigOptional, params.y->Size(), executorPtr);
CHECK_RET(result != nullptr, ACLNN_ERR_INNER_NULLPTR);
CHECK_COND(CheckOutputShape(result, params.y) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID, "Check outputShape failed.");
for (size_t i(0); i < params.y->Size(); ++i) {
auto viewCopyResult = l0op::ViewCopy((*result)[i], (*params.y)[i], executorPtr);
CHECK_RET(viewCopyResult != nullptr, ACLNN_ERR_INNER_NULLPTR);
}
*workspaceSize = uniqueExecutor->GetWorkspaceSize();
uniqueExecutor.ReleaseTo(executor);
return ACLNN_SUCCESS;
}
static int64_t CorrectSplitItem(const aclTensorList *x, const aclTensorList *y, int64_t splitItem) {
int64_t splitItemCorrected = splitItem;
if (splitItem == X_Y_SEPARATED || splitItem == Y_SEPARATED) {
splitItemCorrected = x->Size() == y->Size() ? X_Y_SEPARATED : Y_SEPARATED;
}
if (splitItem == X_SEPARATED || splitItem == NO_SEPARATED) {
splitItemCorrected = x->Size() == 1 ? NO_SEPARATED : X_SEPARATED;
}
return splitItemCorrected;
}
static aclnnStatus CheckTransposeStatus(const aclTensorList *x, const aclTensorList *weight, bool &transposeX,
bool &transposeWeight, int64_t groupType) {
CHECK_COND((*x)[0] != nullptr, ACLNN_ERR_PARAM_INVALID, "x[0] is nullptr!");
CHECK_COND((*weight)[0] != nullptr, ACLNN_ERR_PARAM_INVALID, "weight[0] is nullptr!");
if (groupType == gmm::SPLIT_K) {
transposeX = gmm::IsTransposeLastTwoDims((*x)[0]);
transposeWeight = CheckSpecialTranspose(weight);
} else if (groupType == gmm::SPLIT_M || groupType == gmm::NO_SPLIT) {
transposeX = CheckSpecialTranspose(x);
transposeWeight = gmm::IsTransposeLastTwoDims((*weight)[0]);
}
return ACLNN_SUCCESS;
}
static aclnnStatus CheckEmptyTensor(const aclTensorList *x, const aclTensorList *weight)
{
bool zeroM = true;
bool zeroN = true;
bool zeroK = false;
for (size_t i = 0; i < x->Size(); ++i) {
CHECK_COND((*x)[i] != nullptr, ACLNN_ERR_PARAM_INVALID, "GroupedMatmul x tensor should not be null");
auto xShape = (*x)[i]->GetViewShape();
size_t xDimNum = xShape.GetDimNum();
CHECK_COND(xDimNum >= gmm::MIN_FM_DIM && xDimNum <= gmm::MAX_FM_DIM, ACLNN_ERR_PARAM_INVALID,
"GroupedMatmul x dim num should be in the range [2, 6], but actual is %zu.", xDimNum);
uint64_t m = 1;
for (size_t dimIdx = 0; dimIdx < xDimNum - 1; dimIdx++) {
m *= xShape.GetDim(dimIdx);
}
zeroM = zeroM && m == 0;
zeroK = zeroK || (xShape.GetDim(xShape.GetDimNum() - 1) == 0);
}
for (size_t i = 0; i < weight->Size(); ++i) {
CHECK_COND((*weight)[i] != nullptr, ACLNN_ERR_PARAM_INVALID,
"GroupedMatmul weight tensor should not be null");
auto wShape = (*weight)[i]->GetViewShape();
CHECK_COND(wShape.GetDimNum() >= gmm::MIN_FM_DIM, ACLNN_ERR_PARAM_INVALID,
"GroupedMatmul weight dim num should be 2 or 3, but actual %zu.", wShape.GetDimNum());
zeroN = zeroN && (wShape.GetDim(wShape.GetDimNum() - 1) == 0);
}
CHECK_COND(zeroM || zeroN || !zeroK, ACLNN_ERR_PARAM_INVALID,
"GroupedMatmul does not support input K being 0 unless all M/N is 0");
return ACLNN_SUCCESS;
}
static aclnnStatus aclnnGroupedMatmulGetWorkspaceSizeCommon(const aclTensorList *x, const aclTensorList *weight,
const aclTensorList *biasOptional, const aclTensorList *scaleOptional, const aclTensorList *offsetOptional,
const aclTensorList *antiquantScaleOptional, const aclTensorList *antiquantOffsetOptional,
const aclTensorList *perTokenScaleOptional, const aclIntArray *groupListOptional,
const aclTensor *groupTensorOptional, const aclTensorList *activationInputOptional,
const aclTensorList *activationQuantScaleOptional, const aclTensorList *activationQuantOffsetOptional,
int64_t splitItem, int64_t groupType, int64_t groupListType, int64_t actType, aclIntArray *tuningConfigOptional, gmm::GMMApiVersion apiVersion,
const aclTensorList *y, const aclTensorList *activationFeatureOutOptional,
const aclTensorList *dynQuantScaleOutOptional, uint64_t *workspaceSize, aclOpExecutor **executor) {
DataType xDtype = DataType::DT_UNDEFINED;
for (size_t i = 0; i < x->Size(); ++i) {
if ((*x)[i] != nullptr) {
xDtype = (*x)[i]->GetDataType();
break;
}
}
bool isSingleWeight = (weight->Size() == 1 && groupType != gmm::NO_SPLIT);
bool transposeX = false;
bool transposeWeight = false;
if(groupType != gmm::SPLIT_K) {
CHECK_COND(CheckEmptyTensor(x, weight) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "CheckEmptyTensor failed!");
}
CHECK_COND(CheckTransposeStatus(x, weight, transposeX, transposeWeight, groupType) == ACLNN_SUCCESS,
ACLNN_ERR_PARAM_INVALID, "CheckTransposeStatus failed!");
gmm::GroupedMatmulParams gmmParams{x, weight, biasOptional, groupListOptional, groupTensorOptional, scaleOptional,
offsetOptional, antiquantScaleOptional, antiquantOffsetOptional, perTokenScaleOptional,
activationInputOptional, activationQuantScaleOptional, activationQuantOffsetOptional,
splitItem, groupListType, actType, transposeWeight, transposeX, isSingleWeight,
apiVersion, groupType, tuningConfigOptional, y, activationFeatureOutOptional, dynQuantScaleOutOptional,
xDtype};
if (gmmParams.scaleOptional != nullptr) {
for (size_t i = 0; i < gmmParams.scaleOptional->Size(); i++) {
if ((*gmmParams.scaleOptional)[i]->GetDataType() == DataType::DT_INT64) {
(void)const_cast<aclTensor *>((*gmmParams.scaleOptional)[i])->SetDataType(op::DataType::DT_UINT64);
}
}
}
ResetEmptyTensor(gmmParams);
CHECK_RET(CheckParam(gmmParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID);
gmmParams.splitItem = CorrectSplitItem(x, y, splitItem);
aclnnStatus ret = GetGMMResultByL0Api(gmmParams, workspaceSize, executor);
return ret;
}
}
aclnnStatus CheckCommonParam(const aclTensorList *x , const aclTensorList *weight,
const aclTensor *groupListOptional, int64_t splitItem, int64_t groupType, int64_t groupListType,
int64_t actType, const aclTensorList *out) {
auto npuArch = op::GetCurrentPlatformInfo().GetCurNpuArch();
bool is310P = npuArch == NpuArch::DAV_2002;
bool supportedCaseOn310P = x->Size() == 1 && out->Size() == 1 && weight->Size() == 1 && groupType == 0;
CHECK_COND((is310P && supportedCaseOn310P) || !is310P, ACLNN_ERR_PARAM_INVALID,
"Only surpport x, y, weight not separated case with groupType is 0 on ASCEND310P.");
CHECK_COND(PreCheckGroupType(splitItem, groupType) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"PreCheckGroupType failed, groupType is invalid.");
size_t validGroupTensorDimNum = (groupListType == gmm::GROUP_LIST_SPARSE_M) ? 2UL: 1UL;
CHECK_COND(groupListOptional == nullptr || groupListOptional->GetViewShape().GetDimNum() == validGroupTensorDimNum,
ACLNN_ERR_PARAM_INVALID,
"Invalid groupList tensor dim num: expected %zu when groupListType is %ld, but got %zu.",
validGroupTensorDimNum, groupListType, groupListOptional->GetViewShape().GetDimNum());
CHECK_COND(actType >= 0, ACLNN_ERR_PARAM_INVALID, "actType must be larger or equal to 0");
if (actType != GMMActType::GMM_ACT_TYPE_NONE) {
CHECK_COND(actType != GMMActType::GMM_ACT_TYPE_GELU_ERR_FUNC, ACLNN_ERR_PARAM_INVALID,
"Activation function not support GELU_ERR_FUNC now.");
CHECK_COND(actType < END_ACT_TYPE_ENUM, ACLNN_ERR_PARAM_INVALID,
"Activation function only support RELU/GELU_TANH/FASTGELU/SILU.");
}
if (groupListType == gmm::GROUP_LIST_SPARSE_M) {
CHECK_COND(npuArch == NpuArch::DAV_2201 || npuArch == NpuArch::DAV_3510, ACLNN_ERR_PARAM_INVALID,
"This platform not support groupListType is 2.");
} else {
CHECK_COND(groupListType == 0 || groupListType == 1, ACLNN_ERR_PARAM_INVALID, "groupListType shoule be 0 or 1.");
}
return ACLNN_SUCCESS;
}
aclnnStatus aclnnGroupedMatmulWeightNzGetWorkspaceSize(const aclTensorList *x, const aclTensorList *weight,
const aclTensorList *biasOptional, const aclTensorList *scaleOptional, const aclTensorList *offsetOptional,
const aclTensorList *antiquantScaleOptional, const aclTensorList *antiquantOffsetOptional,
const aclTensorList *perTokenScaleOptional, const aclTensor *groupListOptional,
const aclTensorList *activationInputOptional, const aclTensorList *activationQuantScaleOptional,
const aclTensorList *activationQuantOffsetOptional, int64_t splitItem, int64_t groupType, int64_t groupListType,
int64_t actType, aclIntArray *tuningConfigOptional, int64_t quantGroupSize, aclTensorList *out, aclTensorList *activationFeatureOutOptional,
aclTensorList *dynQuantScaleOutOptional, uint64_t *workspaceSize, aclOpExecutor **executor) {
CHECK_COND(CheckNotNull(x, weight, out) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_NULLPTR,
"one of required inputs is nullptr.");
L2_DFX_PHASE_1(aclnnGroupedMatmulWeightNz,
DFX_IN(x, weight, biasOptional, scaleOptional, offsetOptional,
antiquantScaleOptional, antiquantOffsetOptional, perTokenScaleOptional, activationInputOptional,
activationQuantScaleOptional, activationQuantOffsetOptional,
groupListOptional, splitItem, groupType, groupListType, actType, tuningConfigOptional),
DFX_OUT(out, activationFeatureOutOptional, dynQuantScaleOutOptional));
if ((*weight)[0]->GetDataType() == DataType::DT_INT32) {
UnpackB32ToB4(weight, "weight");
if (op::GetCurrentPlatformInfo().GetCurNpuArch() == NpuArch::DAV_3510 &&
IsWeightQuant((*x)[0]->GetDataType(), (*weight)[0]->GetDataType())) {
SetSpecialNZTensorToNormalNZFormat(weight);
}
}
if ((*weight)[0]->GetDataType() == DataType::DT_FLOAT) {
if (op::GetCurrentPlatformInfo().GetCurNpuArch() == NpuArch::DAV_3510 &&
IsWeightQuant((*x)[0]->GetDataType(), (*weight)[0]->GetDataType())) {
UnpackB32ToB4(weight, "weight");
SetSpecialNZTensorToNormalNZFormat(weight);
}
}
if ((*x)[0]->GetDataType() == DataType::DT_INT32) {
if (op::GetCurrentPlatformInfo().GetCurNpuArch() == NpuArch::DAV_3510) {
UnpackB32ToB4(x, "x");
} else {
UnpackB32ToB4(x, "x", true);
}
}
CHECK_COND(groupType != gmm::SPLIT_K, ACLNN_ERR_PARAM_INVALID, "Not support split k dim now, groupType can not be 2.");
CHECK_COND(CheckCommonParam(x, weight, groupListOptional, splitItem, groupType, groupListType, actType, out)
== ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID, "One of required inputs does not meet the requirement.");
(void)quantGroupSize;
return aclnnGroupedMatmulGetWorkspaceSizeCommon(x, weight, biasOptional, scaleOptional, offsetOptional,
antiquantScaleOptional, antiquantOffsetOptional,
perTokenScaleOptional, nullptr, groupListOptional,
activationInputOptional, activationQuantScaleOptional,
activationQuantOffsetOptional, splitItem, groupType, groupListType,
actType, tuningConfigOptional, gmm::GMMApiVersion::WeightNz, out, activationFeatureOutOptional,
dynQuantScaleOutOptional, workspaceSize, executor);
}
aclnnStatus aclnnGroupedMatmulV5GetWorkspaceSize(const aclTensorList *x, const aclTensorList *weight,
const aclTensorList *biasOptional, const aclTensorList *scaleOptional, const aclTensorList *offsetOptional,
const aclTensorList *antiquantScaleOptional, const aclTensorList *antiquantOffsetOptional,
const aclTensorList *perTokenScaleOptional, const aclTensor *groupListOptional,
const aclTensorList *activationInputOptional, const aclTensorList *activationQuantScaleOptional,
const aclTensorList *activationQuantOffsetOptional, int64_t splitItem, int64_t groupType, int64_t groupListType,
int64_t actType, aclIntArray *tuningConfigOptional, aclTensorList *out, aclTensorList *activationFeatureOutOptional,
aclTensorList *dynQuantScaleOutOptional, uint64_t *workspaceSize, aclOpExecutor **executor) {
CHECK_COND(CheckNotNull(x, weight, out) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_NULLPTR,
"One of required inputs is nullptr.");
L2_DFX_PHASE_1(aclnnGroupedMatmulV5,
DFX_IN(x, weight, biasOptional, scaleOptional, offsetOptional,
antiquantScaleOptional, antiquantOffsetOptional, perTokenScaleOptional, activationInputOptional,
activationQuantScaleOptional, activationQuantOffsetOptional,
groupListOptional, splitItem, groupType, groupListType, actType, tuningConfigOptional),
DFX_OUT(out, activationFeatureOutOptional, dynQuantScaleOutOptional));
CHECK_COND(weight->Size() != 0, ACLNN_ERR_PARAM_INVALID, "Weight should not be null tensorlist ");
if ((*weight)[0]->GetDataType() == DataType::DT_INT32) {
UnpackB32ToB4(weight, "weight");
}
if ((*x)[0]->GetDataType() == DataType::DT_INT32) {
if (op::GetCurrentPlatformInfo().GetCurNpuArch() == NpuArch::DAV_3510) {
UnpackB32ToB4(x, "x");
} else {
UnpackB32ToB4(x, "x", true);
}
}
CHECK_COND(CheckCommonParam(x, weight, groupListOptional, splitItem, groupType, groupListType, actType, out)
== ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID, "One of required inputs does not meet the requirement.");
return aclnnGroupedMatmulGetWorkspaceSizeCommon(x, weight, biasOptional, scaleOptional, offsetOptional,
antiquantScaleOptional, antiquantOffsetOptional,
perTokenScaleOptional, nullptr, groupListOptional,
activationInputOptional, activationQuantScaleOptional,
activationQuantOffsetOptional, splitItem, groupType, groupListType,
actType, tuningConfigOptional, gmm::GMMApiVersion::V5, out, activationFeatureOutOptional,
dynQuantScaleOutOptional, workspaceSize, executor);
}
aclnnStatus aclnnGroupedMatmulV4GetWorkspaceSize(const aclTensorList *x, const aclTensorList *weight,
const aclTensorList *biasOptional, const aclTensorList *scaleOptional, const aclTensorList *offsetOptional,
const aclTensorList *antiquantScaleOptional, const aclTensorList *antiquantOffsetOptional,
const aclTensorList *perTokenScaleOptional, const aclTensor *groupListOptional,
const aclTensorList *activationInputOptional, const aclTensorList *activationQuantScaleOptional,
const aclTensorList *activationQuantOffsetOptional, int64_t splitItem, int64_t groupType, int64_t groupListType,
int64_t actType, aclTensorList *out, aclTensorList *activationFeatureOutOptional,
aclTensorList *dynQuantScaleOutOptional, uint64_t *workspaceSize, aclOpExecutor **executor) {
DEPRECATED_API_WARN_ONCE("aclnnGroupedMatmulV4GetWorkspaceSize", "aclnnGroupedMatmulV5GetWorkspaceSize");
CHECK_COND(CheckNotNull(x, weight, out) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_NULLPTR,
"One of required inputs is nullptr.");
L2_DFX_PHASE_1(aclnnGroupedMatmulV4,
DFX_IN(x, weight, biasOptional, scaleOptional, offsetOptional,
antiquantScaleOptional, antiquantOffsetOptional, perTokenScaleOptional, activationInputOptional,
activationQuantScaleOptional, activationQuantOffsetOptional,
groupListOptional, splitItem, groupType, groupListType, actType),
DFX_OUT(out, activationFeatureOutOptional, dynQuantScaleOutOptional));
if ((*weight)[0]->GetDataType() == DataType::DT_INT32) {
UnpackB32ToB4(weight, "weight");
}
if ((*x)[0]->GetDataType() == DataType::DT_INT32) {
if (op::GetCurrentPlatformInfo().GetCurNpuArch() == NpuArch::DAV_3510) {
UnpackB32ToB4(x, "x");
} else {
UnpackB32ToB4(x, "x", true);
}
}
CHECK_COND(CheckCommonParam(x, weight, groupListOptional, splitItem, groupType, groupListType, actType, out)
== ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID, "One of required inputs does not meet the requirement.");
return aclnnGroupedMatmulGetWorkspaceSizeCommon(x, weight, biasOptional, scaleOptional, offsetOptional,
antiquantScaleOptional, antiquantOffsetOptional,
perTokenScaleOptional, nullptr, groupListOptional,
activationInputOptional, activationQuantScaleOptional,
activationQuantOffsetOptional, splitItem, groupType, groupListType,
actType, nullptr, gmm::GMMApiVersion::V4, out,
activationFeatureOutOptional, dynQuantScaleOutOptional, workspaceSize,
executor);
}
aclnnStatus aclnnGroupedMatmulV3GetWorkspaceSize(const aclTensorList *x, const aclTensorList *weight,
const aclTensorList *biasOptional, const aclTensorList *scaleOptional, const aclTensorList *offsetOptional,
const aclTensorList *antiquantScaleOptional, const aclTensorList *antiquantOffsetOptional,
const aclTensor *groupListOptional, int64_t splitItem, int64_t groupType, const aclTensorList *y,
uint64_t *workspaceSize, aclOpExecutor **executor) {
DEPRECATED_API_WARN_ONCE("aclnnGroupedMatmulV3GetWorkspaceSize", "aclnnGroupedMatmulV5GetWorkspaceSize");
CHECK_COND(CheckNotNull(x, weight, y) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_NULLPTR,
"One of required inputs is nullptr.");
L2_DFX_PHASE_1(aclnnGroupedMatmulV3,
DFX_IN(x, weight, biasOptional, scaleOptional, offsetOptional,
antiquantScaleOptional, antiquantOffsetOptional, groupListOptional,
splitItem, groupType),
DFX_OUT(y));
bool is310P = GetCurrentPlatformInfo().GetSocVersion() == SocVersion::ASCEND310P;
bool supportedCaseOn310P = x->Size() == 1 && y->Size() == 1 && weight->Size() == 1 && groupType == 0;
CHECK_COND((is310P && supportedCaseOn310P) || !is310P, ACLNN_ERR_PARAM_INVALID,
"Only surpport x, y, weight not separated case with groupType is 0 on ASCEND310P.");
CHECK_COND(PreCheckGroupType(splitItem, groupType) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"PreCheckGroupType failed, groupType is invalid.");
CHECK_COND(
groupListOptional == nullptr || groupListOptional->GetViewShape().GetDimNum() == 1, ACLNN_ERR_PARAM_INVALID, "When groupList type is tensor, groupList dim only support 1, but now is %lu.",
groupListOptional->GetViewShape().GetDimNum());
return aclnnGroupedMatmulGetWorkspaceSizeCommon(x, weight, biasOptional, scaleOptional, offsetOptional,
antiquantScaleOptional, antiquantOffsetOptional, nullptr, nullptr,
groupListOptional, nullptr, nullptr, nullptr, splitItem, groupType,
0, 0, nullptr, gmm::GMMApiVersion::V3, y, nullptr, nullptr,
workspaceSize, executor);
}
aclnnStatus aclnnGroupedMatmulV2GetWorkspaceSize(const aclTensorList *x, const aclTensorList *weight,
const aclTensorList *biasOptional, const aclTensorList *scaleOptional, const aclTensorList *offsetOptional,
const aclTensorList *antiquantScaleOptional, const aclTensorList *antiquantOffsetOptional,
const aclIntArray *groupListOptional, int64_t splitItem, int64_t groupType, const aclTensorList *y,
uint64_t *workspaceSize, aclOpExecutor **executor) {
DEPRECATED_API_WARN_ONCE("aclnnGroupedMatmulV2GetWorkspaceSize", "aclnnGroupedMatmulV5GetWorkspaceSize");
CHECK_COND(CheckNotNull(x, weight, y) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_NULLPTR,
"One of required inputs is nullptr.");
bool is310P = GetCurrentPlatformInfo().GetSocVersion() == SocVersion::ASCEND310P;
CHECK_COND(!is310P, ACLNN_ERR_PARAM_INVALID,
"Only aclnnGroupedMatmulV3GetWorkspaceSize is supported on ASCEND310P.");
L2_DFX_PHASE_1(aclnnGroupedMatmulV2,
DFX_IN(x, weight, biasOptional, scaleOptional, offsetOptional,
antiquantScaleOptional, antiquantOffsetOptional, groupListOptional,
splitItem, groupType),
DFX_OUT(y));
CHECK_COND(PreCheckGroupType(splitItem, groupType) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
"PreCheckGroupType failed, groupType is invalid.");
return aclnnGroupedMatmulGetWorkspaceSizeCommon(x, weight, biasOptional, scaleOptional, offsetOptional,
antiquantScaleOptional, antiquantOffsetOptional, nullptr,
groupListOptional, nullptr, nullptr, nullptr, nullptr, splitItem,
groupType, 0, 0, nullptr, gmm::GMMApiVersion::V2, y, nullptr, nullptr,
workspaceSize, executor);
}
aclnnStatus aclnnGroupedMatmulGetWorkspaceSize(const aclTensorList *x, const aclTensorList *weight,
const aclTensorList *biasOptional, const aclTensorList *scaleOptional, const aclTensorList *offsetOptional,
const aclTensorList *antiquantScaleOptional, const aclTensorList *antiquantOffsetOptional,
const aclIntArray *groupListOptional, int64_t splitItem, const aclTensorList *y, uint64_t *workspaceSize,
aclOpExecutor **executor) {
DEPRECATED_API_WARN_ONCE("aclnnGroupedMatmulGetWorkspaceSize", "aclnnGroupedMatmulV5GetWorkspaceSize");
CHECK_COND(CheckNotNull(x, weight, y) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_NULLPTR,
"One of required inputs is nullptr.");
bool is310P = GetCurrentPlatformInfo().GetSocVersion() == SocVersion::ASCEND310P;
CHECK_COND(!is310P, ACLNN_ERR_PARAM_INVALID,
"Only aclnnGroupedMatmulV3GetWorkspaceSize is supported on ASCEND310P.");
L2_DFX_PHASE_1(aclnnGroupedMatmul,
DFX_IN(x, weight, biasOptional, scaleOptional, offsetOptional,
antiquantScaleOptional, antiquantOffsetOptional, groupListOptional,
splitItem),
DFX_OUT(y));
int64_t groupType = 0;
if (weight->Size() == 1) {
CHECK_COND(x->Size() == 1 && y->Size() == 1, ACLNN_ERR_PARAM_INVALID,
"Only accept separated weight, but input weight is not separated.");
}
bool xYSeparated = (x->Size() > 1 && y->Size() > 1) ||
(x->Size() == 1 && y->Size() == 1 && weight->Size() == 1);
if (xYSeparated) {
groupType = -1L;
}
if (GetCurrentPlatformInfo().GetSocVersion() != SocVersion::ASCEND310P) {
bool isSingleWeight = (weight->Size() == 1) && !(x->Size() == 1 && xYSeparated);
CHECK_COND(!isSingleWeight, ACLNN_ERR_PARAM_INVALID,
"Only accept separated weight, but input weight is not separated.");
}
return aclnnGroupedMatmulGetWorkspaceSizeCommon(x, weight, biasOptional, scaleOptional, offsetOptional,
antiquantScaleOptional, antiquantOffsetOptional, nullptr,
groupListOptional, nullptr, nullptr, nullptr, nullptr, splitItem,
groupType, 0, 0, nullptr, gmm::GMMApiVersion::V1, y, nullptr, nullptr,
workspaceSize, executor);
}
aclnnStatus aclnnGroupedMatmul(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor,
aclrtStream stream) {
DEPRECATED_API_WARN_ONCE("aclnnGroupedMatmul", "aclnnGroupedMatmulV5");
L2_DFX_PHASE_2(aclnnGroupedMatmul);
CHECK_COND(CommonOpExecutorRun(workspace, workspaceSize, executor, stream) == ACLNN_SUCCESS, ACLNN_ERR_INNER,
"This is an error in GMM launch aicore");
return ACLNN_SUCCESS;
}
aclnnStatus aclnnGroupedMatmulV2(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor,
aclrtStream stream) {
DEPRECATED_API_WARN_ONCE("aclnnGroupedMatmulV2", "aclnnGroupedMatmulV5");
L2_DFX_PHASE_2(aclnnGroupedMatmulV2);
CHECK_COND(CommonOpExecutorRun(workspace, workspaceSize, executor, stream) == ACLNN_SUCCESS, ACLNN_ERR_INNER,
"This is an error in GMM launch aicore");
return ACLNN_SUCCESS;
}
aclnnStatus aclnnGroupedMatmulV3(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor,
aclrtStream stream) {
DEPRECATED_API_WARN_ONCE("aclnnGroupedMatmulV3", "aclnnGroupedMatmulV5");
L2_DFX_PHASE_2(aclnnGroupedMatmulV3);
CHECK_COND(CommonOpExecutorRun(workspace, workspaceSize, executor, stream) == ACLNN_SUCCESS, ACLNN_ERR_INNER,
"This is an error in GMM launch aicore");
return ACLNN_SUCCESS;
}
aclnnStatus aclnnGroupedMatmulV4(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor,
aclrtStream stream) {
DEPRECATED_API_WARN_ONCE("aclnnGroupedMatmulV4", "aclnnGroupedMatmulV5");
L2_DFX_PHASE_2(aclnnGroupedMatmulV4);
CHECK_COND(CommonOpExecutorRun(workspace, workspaceSize, executor, stream) == ACLNN_SUCCESS, ACLNN_ERR_INNER,
"This is an error in GMM launch aicore");
return ACLNN_SUCCESS;
}
aclnnStatus aclnnGroupedMatmulV5(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor,
aclrtStream stream) {
L2_DFX_PHASE_2(aclnnGroupedMatmulV5);
CHECK_COND(CommonOpExecutorRun(workspace, workspaceSize, executor, stream) == ACLNN_SUCCESS, ACLNN_ERR_INNER,
"This is an error in GMM launch aicore");
return ACLNN_SUCCESS;
}
aclnnStatus aclnnGroupedMatmulWeightNz(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor,
aclrtStream stream) {
L2_DFX_PHASE_2(aclnnGroupedMatmulWeightNz);
CHECK_COND(CommonOpExecutorRun(workspace, workspaceSize, executor, stream) == ACLNN_SUCCESS, ACLNN_ERR_INNER,
"This is an error in GMM launch aicore");
return ACLNN_SUCCESS;
}
#ifdef __cplusplus
}
#endif