* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file batch_matmul_utils.h
* \brief Compile-time scheduler selectors, context structs and offset helpers used by batch matmul.
*/
#if !defined(__ASCENDC_INCLUDE_INTERNAL_HEADERS__)
#pragma message( \
"impl/adv_api/detail/matmul/utils/batch_matmul_utils.h is an internal header file and must not be used directly. Functions or variables defined in this file may be removed in the future. Please use \"#include \"adv_api/matmul/matmul.h\"\" and use public functions or variables defined in interface headers files.")
#define __ASCENDC_INCLUDE_INTERNAL_HEADERS__
#define __UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_DETAIL_MATMUL_UTILS_BATCH_MATMUL_UTILS_H__
#endif
#ifndef IMPL_MATMUL_UTILS_BATCH_MATMUL_UTILS_H
#define IMPL_MATMUL_UTILS_BATCH_MATMUL_UTILS_H
#include "matmul_config_utils.h"
#include "matmul_type_def.h"
namespace AscendC {
// True when MM_CFG is a "norm" matmul config (per DoMatmulNorm) and one of the following holds:
//   - A_TYPE::layout is anything but NONE and batchMode is BATCH_LESS_THAN_L1, or
//   - A_TYPE::layout is NORMAL and batchMode is BATCH_LARGE_THAN_L1 or SINGLE_LARGE_THAN_L1.
template <typename A_TYPE, const auto& MM_CFG>
constexpr bool IsBmmEnableScheduler =
    DoMatmulNorm(MM_CFG) &&
    ((A_TYPE::layout != LayoutMode::NONE && ToMatmulConfig(MM_CFG).batchMode == BatchMode::BATCH_LESS_THAN_L1) ||
     (A_TYPE::layout == LayoutMode::NORMAL &&
      (ToMatmulConfig(MM_CFG).batchMode == BatchMode::BATCH_LARGE_THAN_L1 ||
       ToMatmulConfig(MM_CFG).batchMode == BatchMode::SINGLE_LARGE_THAN_L1)));
// True when MM_CFG is a "norm" matmul config (per DoMatmulNorm), A_TYPE carries a layout other than NONE,
// and the configured batchMode is anything except SINGLE_LARGE_THAN_L1.
template <typename A_TYPE, const auto& MM_CFG>
constexpr bool IsBmmBatchScheduler =
    DoMatmulNorm(MM_CFG) &&
    A_TYPE::layout != LayoutMode::NONE &&
    !(ToMatmulConfig(MM_CFG).batchMode == BatchMode::SINGLE_LARGE_THAN_L1);
// True when MM_CFG is a "norm" matmul config (per DoMatmulNorm), A_TYPE::layout is NORMAL,
// and the configured batchMode is SINGLE_LARGE_THAN_L1.
template <typename A_TYPE, const auto& MM_CFG>
constexpr bool IsBmmSingleScheduler =
    DoMatmulNorm(MM_CFG) &&
    A_TYPE::layout == LayoutMode::NORMAL &&
    ToMatmulConfig(MM_CFG).batchMode == BatchMode::SINGLE_LARGE_THAN_L1;
// Per-operand (A, B, bias) mod / divisor / align values plus a flag recording whether bias is set.
// NOTE(review): the exact formula that consumes mod/divisor/align lives outside this header;
// presumably they are used to turn a batch index into an operand offset — confirm against the users.
//
// Every member now has a default initializer, so a default-constructed instance is fully
// defined (all integers 0, setBiasFlag false) instead of carrying indeterminate values.
// Member order is unchanged, so positional aggregate initialization keeps working.
struct BatchOffsetInfo {
    // Values for operand A.
    int32_t modA{0};
    int32_t divisorA{0};
    int32_t alignA{0};
    // Values for operand B.
    int32_t modB{0};
    int32_t divisorB{0};
    int32_t alignB{0};
    // Values for the bias operand.
    int32_t modBias{0};
    int32_t divisorBias{0};
    int32_t alignBias{0};
    // Defaults to false.
    bool setBiasFlag{false};
};
// Plain aggregate of five signed 16-bit split values. Members have no default initializers,
// so a default-constructed instance holds indeterminate values until it is assigned.
// NOTE(review): "L1"/"L0" presumably refer to the on-chip buffer levels and the values are
// presumably element or block counts — confirm the unit against the code that fills this struct.
struct SplitParams {
// Length of the (non-k) axis at the L1 level.
int16_t axisL1Len;
// Length of the k axis at the L1 level.
int16_t kAxisL1Len;
// Offset of the (non-k) axis at the L1 level.
int16_t axisL1Offset;
// Offset of the k axis at the L1 level.
int16_t kAxisL1Offset;
// Length of the (non-k) axis at the L0 level.
int16_t axisL0Len;
};
// State bundle for the batch scheduler: one offset per operand (A, B, bias), the reduce-G
// settings, and a SplitParams instance for each of A and B.
// Members have no default initializers, so a default-constructed instance holds indeterminate
// values until it is assigned.
// NOTE(review): units of the offsets and the exact meaning of "reduce G" are not visible in this
// header — confirm against the scheduler implementation.
struct BatchSchedulerContext {
// Offsets for operands A, B and bias.
int32_t offsetA;
int32_t offsetB;
int32_t offsetBias;
// Reduce-G count and the flag saying whether reduce-G applies.
uint32_t reduceGNum;
bool isReduceG;
// Split parameters for operand A and operand B (named "L0" by the members).
SplitParams aL0Params;
SplitParams bL0Params;
};
// 64-bit offsets for the A, B and C operands of a batch matmul; each one starts at zero.
struct BmmOffset {
    uint64_t offA{0};
    uint64_t offB{0};
    uint64_t offC{0};
};
// Integer division of num1 by num2, rounded up.
// num2 must be positive (checked with ASSERT).
__aicore__ inline uint16_t CeilDiv(uint16_t num1, uint16_t num2)
{
    ASSERT(num2 > 0);
    // Both operands are promoted before the addition, so the biased numerator cannot wrap at 16 bits.
    const auto biasedNumerator = num1 + num2 - 1;
    return static_cast<uint16_t>(biasedNumerator / num2);
}
// Rounds num1 up to the nearest multiple of num2.
// num2 must be positive (checked with ASSERT). The result is narrowed to uint16_t on return.
__aicore__ inline uint16_t CeilAlign(uint16_t num1, uint16_t num2)
{
    ASSERT(num2 > 0);
    const uint16_t chunkCount = CeilDiv(num1, num2);
    return static_cast<uint16_t>(chunkCount * num2);
}
// Computes the offset, scaled by sizeof(INPUT_TYPE::T), at which batch iteration `loopIdx`
// starts for an operand described by INPUT_TYPE (its `layout`, `format` and element type `T`).
//
// batchValue   number of batches handled per iteration; loopIdx * batchValue is the first batch index
// loopIdx      index of the current iteration
// layoutInfoN  N dimension of the layout (only used for BSNGD)
// layoutInfoG  G dimension of the layout (only used for BSNGD)
// layoutInfoD  D dimension of the layout; rounded up to a multiple of the C0 count for ND_ALIGN
// layoutInfoS  S dimension of the layout (used for BNGS1S2 / NORMAL / BSNGD)
//
// Returns 0 for layouts that none of the branches below handle.
//
// All arithmetic is carried out in 64 bits: the inputs are 32-bit, and multiplying them before
// widening would silently wrap for large shapes. For the same reason the ND_ALIGN rounding is
// done inline rather than through the 16-bit CeilAlign helper, which would truncate layoutInfoD.
template <class INPUT_TYPE>
__aicore__ inline uint64_t CalcNBatchoffset(
    uint32_t batchValue, uint32_t loopIdx, uint32_t layoutInfoN, uint32_t layoutInfoG, uint32_t layoutInfoD,
    uint32_t layoutInfoS)
{
    using ElemT = typename INPUT_TYPE::T;

    uint64_t alignedSingleCoreN = layoutInfoD;
    if constexpr (INPUT_TYPE::format == CubeFormat::ND_ALIGN) {
        // The C0 count is expected to be positive for every supported element type.
        const uint64_t c0Count = static_cast<uint64_t>(AscendCUtils::GetC0Count(sizeof(ElemT)));
        alignedSingleCoreN = (alignedSingleCoreN + c0Count - 1) / c0Count * c0Count;
    }

    // Index of the first batch covered by this iteration.
    const uint64_t batchStart = static_cast<uint64_t>(loopIdx) * batchValue;

    uint64_t offset = 0;
    if constexpr (INPUT_TYPE::layout == LayoutMode::BNGS1S2 || INPUT_TYPE::layout == LayoutMode::NORMAL) {
        offset = alignedSingleCoreN * layoutInfoS * batchStart * sizeof(ElemT);
    } else if constexpr (INPUT_TYPE::layout == LayoutMode::SBNGD) {
        offset = alignedSingleCoreN * batchStart * sizeof(ElemT);
    } else if constexpr (INPUT_TYPE::layout == LayoutMode::BSNGD) {
        // Split the batch index into a B index and a position inside the N*G group.
        const uint64_t layoutNG = static_cast<uint64_t>(layoutInfoN) * layoutInfoG;
        const uint64_t layoutBIdx = batchStart / layoutNG;
        const uint64_t layoutNGOff = batchStart % layoutNG;
        offset = (layoutBIdx * alignedSingleCoreN * layoutInfoS * layoutNG + layoutNGOff * alignedSingleCoreN) *
                 sizeof(ElemT);
    }
    return offset;
}
// Returns the batch count of the output C.
//
// The base value is the larger of batchA and batchB. When C's G dimension is 1 while at least
// one of the inputs has a G dimension other than 1, that value is divided by the larger of the
// two input G dimensions.
//
// NOTE(review): in that case the divisor is zero only if both aLayoutInfoG and bLayoutInfoG are
// zero; callers are presumably expected never to pass that combination — confirm.
__aicore__ inline uint64_t GetBatchCNum(
    uint32_t batchA, uint32_t batchB, uint32_t aLayoutInfoG, uint32_t bLayoutInfoG, uint32_t cLayoutInfoG)
{
    uint32_t batchC = batchA > batchB ? batchA : batchB;
    const bool onlyInputsHaveG = cLayoutInfoG == 1 && (aLayoutInfoG != 1 || bLayoutInfoG != 1);
    if (onlyInputsHaveG) {
        // Kept unsigned: both candidates are uint32_t and the value divides an unsigned count.
        const uint32_t layoutG = bLayoutInfoG > aLayoutInfoG ? bLayoutInfoG : aLayoutInfoG;
        batchC /= layoutG;
    }
    return batchC;
}
}
#endif
#if defined(__UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_DETAIL_MATMUL_UTILS_BATCH_MATMUL_UTILS_H__)
#undef __ASCENDC_INCLUDE_INTERNAL_HEADERS__
#undef __UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_DETAIL_MATMUL_UTILS_BATCH_MATMUL_UTILS_H__
#endif