* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "grouped_matmul_util.h"
using namespace gmm;
namespace gmm {
/**
 * @brief Tests whether a tensor's view strides describe a layout in which the last two
 *        dimensions are stored in swapped order.
 *
 * The test passes when the second-to-last dimension has stride 1, the last dimension's stride
 * equals the extent of the second-to-last dimension, and every leading dimension's stride equals
 * the product of the extents of all dimensions that follow it.
 *
 * @param tensor Tensor whose view shape and view strides are read. It is dereferenced without a
 *               null check.
 * @return true when the stride pattern holds; false otherwise, including when the tensor has
 *         fewer than MIN_DIM_FOR_TRANSPOSE dimensions.
 */
bool IsTransposeLastTwoDims(const aclTensor *tensor)
{
    auto viewShape = tensor->GetViewShape();
    const auto rank = viewShape.GetDimNum();
    if (rank < MIN_DIM_FOR_TRANSPOSE) {
        return false;
    }

    int64_t lastIdx = rank - 1;
    int64_t penultIdx = rank - 2;
    auto viewStrides = tensor->GetViewStrides();

    // The inner two dimensions must look swapped: unit stride on the second-to-last one, and the
    // last one stepping over a full run of the second-to-last.
    if (viewStrides[penultIdx] != 1 || viewStrides[lastIdx] != viewShape.GetDim(penultIdx)) {
        return false;
    }

    // With exactly the minimum number of dimensions there are no leading strides to verify.
    if (rank == MIN_DIM_FOR_TRANSPOSE) {
        return true;
    }

    // Walk the leading dimensions from innermost to outermost; each one must step over exactly
    // the number of elements spanned by the dimensions after it.
    int64_t expectedStride = viewShape.GetDim(lastIdx) * viewShape.GetDim(penultIdx);
    for (int64_t leadIdx = rank - 3; leadIdx >= 0; --leadIdx) {
        if (viewStrides[leadIdx] != expectedStride) {
            return false;
        }
        expectedStride *= viewShape.GetDim(leadIdx);
    }
    return true;
}
/**
 * @brief Tests whether three trailing view strides of a tensor match the pattern this file
 *        treats as a transposed MX layout.
 *
 * With n the number of view dimensions, the pattern is:
 *   - stride[n - 1] == 1
 *   - stride[n - LAST_THIRD_DIM_INDEX] == MXFP_MULTI_BASE_SIZE
 *   - stride[n - LAST_SECOND_DIM_INDEX] == dim[n - LAST_THIRD_DIM_INDEX] * MXFP_MULTI_BASE_SIZE
 *
 * NOTE(review): the constant names suggest these are the second-to-last and third-to-last
 * dimensions; their values are defined outside this file - confirm.
 *
 * @param tensor Tensor whose view shape and view strides are read. It is dereferenced without a
 *               null check.
 * @return true when all three conditions hold; false otherwise, including when the tensor has
 *         fewer than MX_SPLIT_K_PER_TOKEN_SCALE_DIM dimensions.
 */
bool IsTransposeForMxShape(const aclTensor *tensor)
{
    auto viewShape = tensor->GetViewShape();
    const auto rank = viewShape.GetDimNum();
    if (rank < MX_SPLIT_K_PER_TOKEN_SCALE_DIM) {
        return false;
    }

    int64_t innerIdx = rank - 1;
    int64_t middleIdx = rank - LAST_SECOND_DIM_INDEX;
    int64_t outerIdx = rank - LAST_THIRD_DIM_INDEX;
    auto viewStrides = tensor->GetViewStrides();

    return viewStrides[innerIdx] == 1 && viewStrides[outerIdx] == MXFP_MULTI_BASE_SIZE &&
           viewStrides[middleIdx] == viewShape.GetDim(outerIdx) * MXFP_MULTI_BASE_SIZE;
}
/**
 * @brief For every tensor in the list that has enough dimensions, creates a view whose shape has
 *        the first two dimensions exchanged and appends it to the output vector.
 *
 * The new shape is (dim1, dim0, dim2) of the input's view shape. The view is created through the
 * executor at the input's view offset and inherits the input's storage format.
 *
 * Tensors with fewer than MX_SPLIT_K_PER_TOKEN_SCALE_DIM dimensions are skipped, so the output
 * vector can end up shorter than the input list.
 *
 * @param tensorList    Input tensors; dereferenced without a null check.
 * @param newTensorList Output vector; created views are appended to it, existing entries are kept.
 * @param executor      Executor used to create the views; dereferenced without a null check.
 */
void CreateContiguousTensorListForPertoken(const aclTensorList *tensorList, std::vector<aclTensor *> &newTensorList,
                                           aclOpExecutor *executor)
{
    op::Shape swappedShape;
    const auto tensorCount = tensorList->Size();
    for (uint64_t pos = 0; pos < tensorCount; ++pos) {
        const aclTensor *source = (*tensorList)[pos];
        op::Shape sourceShape = source->GetViewShape();
        // Reset the reused shape object before filling it for this tensor.
        swappedShape.SetScalar();
        if (sourceShape.GetDimNum() >= MX_SPLIT_K_PER_TOKEN_SCALE_DIM) {
            const auto dim0 = sourceShape.GetDim(0);
            const auto dim1 = sourceShape.GetDim(1);
            const auto dim2 = sourceShape.GetDim(2);
            swappedShape.AppendDim(dim1);
            swappedShape.AppendDim(dim0);
            swappedShape.AppendDim(dim2);

            aclTensor *view = executor->CreateView(source, swappedShape, source->GetViewOffset());
            view->SetStorageFormat(source->GetStorageFormat());
            newTensorList.push_back(view);
        }
    }
}
/**
 * @brief For every tensor in the list that has enough dimensions, creates a four-dimensional
 *        view with two inner dimensions exchanged and appends it to the output vector.
 *
 * With n the number of view dimensions of the input, the new shape is
 * (dim[0], dim[n - LAST_SECOND_DIM_INDEX], dim[n - LAST_THIRD_DIM_INDEX], dim[n - 1]).
 * The view is created through the executor at the input's view offset and inherits the input's
 * storage format.
 *
 * Tensors with fewer than MX_SPLIT_M_SCALE_DIM dimensions are skipped, so the output vector can
 * end up shorter than the input list.
 *
 * @param tensorList    Input tensors; dereferenced without a null check.
 * @param newTensorList Output vector; created views are appended to it, existing entries are kept.
 * @param executor      Executor used to create the views; dereferenced without a null check.
 */
void CreateContiguousTensorListForMXTypeMScale(const aclTensorList *tensorList, std::vector<aclTensor *> &newTensorList,
                                               aclOpExecutor *executor)
{
    op::Shape reorderedShape;
    const auto tensorCount = tensorList->Size();
    for (uint64_t pos = 0; pos < tensorCount; ++pos) {
        const aclTensor *source = (*tensorList)[pos];
        op::Shape sourceShape = source->GetViewShape();
        // Reset the reused shape object before filling it for this tensor.
        reorderedShape.SetScalar();

        const auto rank = sourceShape.GetDimNum();
        if (rank < MX_SPLIT_M_SCALE_DIM) {
            continue;
        }

        reorderedShape.AppendDim(sourceShape.GetDim(0));
        reorderedShape.AppendDim(sourceShape.GetDim(rank - LAST_SECOND_DIM_INDEX));
        reorderedShape.AppendDim(sourceShape.GetDim(rank - LAST_THIRD_DIM_INDEX));
        reorderedShape.AppendDim(sourceShape.GetDim(rank - 1));

        aclTensor *view = executor->CreateView(source, reorderedShape, source->GetViewOffset());
        view->SetStorageFormat(source->GetStorageFormat());
        newTensorList.push_back(view);
    }
}
/**
 * @brief For every tensor in the list, creates a view whose shape has the last two dimensions
 *        exchanged and appends it to the output vector.
 *
 * All leading dimensions are copied unchanged; the final two are appended in reverse order. The
 * view is created through the executor at the input's view offset and inherits the input's
 * storage format.
 *
 * Tensors with fewer than two dimensions have no "last two dimensions" to exchange and are
 * skipped, so the output vector can end up shorter than the input list. Without this guard the
 * unsigned expression `dims - 2` would wrap around and the copy loop would read far past the
 * real dimensions.
 *
 * @param tensorList    Input tensors; dereferenced without a null check.
 * @param newTensorList Output vector; created views are appended to it, existing entries are kept.
 * @param executor      Executor used to create the views; dereferenced without a null check.
 */
void CreateContiguousTensorList(const aclTensorList *tensorList, std::vector<aclTensor *> &newTensorList,
                                aclOpExecutor *executor)
{
    // Number of trailing dimensions that get exchanged; also the minimum rank this routine needs.
    constexpr uint32_t kSwappedDimCount = 2;

    op::Shape shape;
    for (uint64_t idx = 0; idx < tensorList->Size(); idx++) {
        const aclTensor *inputTensor = (*tensorList)[idx];
        op::Shape viewShape = inputTensor->GetViewShape();
        uint32_t viewShapeDimsNum = viewShape.GetDimNum();
        if (viewShapeDimsNum < kSwappedDimCount) {
            // Guard against unsigned wrap-around in the subtraction below.
            continue;
        }

        // Reset the reused shape object before filling it for this tensor.
        shape.SetScalar();
        const uint32_t leadingDimCount = viewShapeDimsNum - kSwappedDimCount;
        for (uint32_t i = 0; i < leadingDimCount; ++i) {
            shape.AppendDim(viewShape.GetDim(i));
        }
        shape.AppendDim(viewShape.GetDim(viewShapeDimsNum - 1));
        shape.AppendDim(viewShape.GetDim(viewShapeDimsNum - 2));

        aclTensor *tensor = executor->CreateView(inputTensor, shape, inputTensor->GetViewOffset());
        tensor->SetStorageFormat(inputTensor->GetStorageFormat());
        newTensorList.emplace_back(tensor);
    }
}
/**
 * @brief Returns a printable name for a data type.
 *
 * The DTYPE_STRING table is consulted first; when it has no entry for the given type, the text
 * produced by op::ToString is returned instead.
 *
 * @param dtype Data type to describe.
 * @return The table entry for dtype if present, otherwise the op::ToString text.
 */
std::string dTypeToString(const ge::DataType &dtype)
{
    // One find() replaces the previous count() + at() pair, which searched the table twice.
    const auto entry = DTYPE_STRING.find(dtype);
    if (entry != DTYPE_STRING.end()) {
        return entry->second;
    }
    return std::string(op::ToString(dtype).GetString());
}
}