* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef SELECT_KERNEL_HALF_H
#define SELECT_KERNEL_HALF_H
#include <cstdint>
#include <cmath>
#include <limits>
#include "catlass/detail/alignment.hpp"
#include "platform_info.h"
#include "tiling_params.h"
#include "utils.h"
enum class PaddingTag : uint8_t { PADDING_NONE = 0, PADDING_ND = 1, PADDING_BLOCK_ND = 2, PADDING_NZ = 3};
double GetBandwidth(uint32_t nValue, uint32_t dValue, uint32_t srcDValue) {
double a6 = 0.000000000000020146121020;
double a5 = -0.000000000012456944162142;
double a4 = -0.000000006738536427145036;
double a3 = 0.000007301215580838747961;
double a2 = -0.002146456956750821074703;
double a1 = 0.312849910814454512664184;
double a0 = 0.1;
double unalignBand = a6 * pow(static_cast<double>(dValue), 6) + a5 * pow(static_cast<double>(dValue), 5)
+ a4 * pow(static_cast<double>(dValue), 4) + a3 * pow(static_cast<double>(dValue), 3)
+ a2 * pow(static_cast<double>(dValue), 2) + a1 * static_cast<double>(dValue) + a0;
if (dValue == srcDValue && dValue <= 128) {
if (dValue % 16 == 0) {
unalignBand = 60;
}
}
if (srcDValue >= 65536) {
unalignBand = 1;
}
if (srcDValue % 256 == 0) {
unalignBand = static_cast<double>(100) / 30 * unalignBand;
} else if (srcDValue % 128 == 0) {
unalignBand = static_cast<double>(80) / 30 * unalignBand;
} else if (srcDValue % 64 == 0) {
unalignBand = static_cast<double>(50) / 30 * unalignBand;
} else if (srcDValue % 16 == 0) {
unalignBand = static_cast<double>(40) / 30 * unalignBand;
}
unalignBand = std::min(unalignBand, 80.0);
if (dValue % 256 == 0) {
if (nValue < 16) {
double b2 = -0.003332381309698882569659;
double b1 = 0.113578920178116271610946;
double b0 = 0.016102868630357251855667;
unalignBand = unalignBand * (b2 * pow(nValue, 2) + b1 * nValue + b0);
}
} else if (dValue % 32 == 0) {
if (nValue < 32) {
double b2 = -0.000298086120946179481978;
double b1 = 0.045309519479127147167929;
double b0 = 0.035130178145161221336945;
unalignBand = unalignBand * (b2 * pow(nValue, 2) + b1 * nValue + b0);
}
} else {
if (nValue < 64) {
double b3 = 0.000001809180573350345869;
double b2 = -0.000469676727179688081274;
double b1 = 0.038963259596073690493867;
double b0 = 0.003942641759904389614499;
unalignBand = unalignBand * (b3 * pow(nValue, 3) + b2 * pow(nValue, 2) + b1 * nValue + b0);
}
}
return unalignBand;
}
void GetPaddingTag(TilingParams& tilingParams, PlatformInfo& platformInfo) {
uint32_t m = tilingParams.m;
uint32_t n = tilingParams.n;
uint32_t k = tilingParams.k;
uint32_t m1 = tilingParams.m1;
uint32_t n1 = tilingParams.n1;
uint32_t k1 = tilingParams.k1;
uint32_t splitkFactor = tilingParams.splitkFactor;
uint64_t outterAxisA = m;
uint64_t innerAxisA = k;
uint32_t nValueA = std::min(m, m1);
uint32_t dValueA = std::min(k, k1);
if (static_cast<LayoutTag>(tilingParams.layoutTagA) == LayoutTag::TagColumnMajor) {
outterAxisA = k;
innerAxisA = m;
nValueA = std::min(k, k1);
dValueA = std::min(m, m1);
}
uint64_t outterAxisB = k;
uint64_t innerAxisB = n;
uint32_t nValueB = std::min(k, k1);
uint32_t dValueB = std::min(n, n1);
if (static_cast<LayoutTag>(tilingParams.layoutTagB) == LayoutTag::TagColumnMajor) {
outterAxisB = n;
innerAxisB = k;
nValueB = std::min(n, n1);
dValueB = std::min(k, k1);
}
double aBandwidthAiv = 30;
size_t matrixASize = static_cast<size_t>(m) * k * 2;
if (matrixASize > 192 * 1024 * 1024) {
aBandwidthAiv = 10;
}
double aBandwidthBeforePaddingAic = GetBandwidth(nValueA, dValueA, innerAxisA);
uint32_t tasksAic = CeilDiv(m, m1) * CeilDiv(n, n1) * splitkFactor;
uint32_t blockDimAic = tasksAic > platformInfo.coreNum ? platformInfo.coreNum : tasksAic;
if (CeilDiv(m, m1) < blockDimAic / 2 && k <= k1 && CeilDiv(m, m1) <= 2) {
aBandwidthBeforePaddingAic = aBandwidthBeforePaddingAic / (blockDimAic / CeilDiv(m, m1)) * 1.5;
}
double aBandwidthAfterPaddingAic = 80;
if (nValueA < 16) {
aBandwidthAfterPaddingAic *= (static_cast<double>(nValueA) / 16);
}
double bBandwidthAiv = 30;
size_t matrixBSize = static_cast<size_t>(k) * n * 2;
if (matrixBSize > 192 * 1024 * 1024) {
bBandwidthAiv = 10;
}
double bBandwidthBeforePaddingAic = GetBandwidth(nValueB, dValueB, innerAxisB);
if (CeilDiv(n, n1) < blockDimAic / 2 && k <= k1 && CeilDiv(n, n1) <= 2) {
bBandwidthBeforePaddingAic = bBandwidthBeforePaddingAic / (blockDimAic / CeilDiv(n, n1)) * 1.5;
}
double bBandwidthAfterPaddingAic = 80;
if (nValueB < 16) {
bBandwidthAfterPaddingAic *= (static_cast<double>(nValueB) / 16);
}
uint32_t actualM = std::min(m, m1);
uint32_t actualN = std::min(n, n1);
uint32_t roundMax = CeilDiv(CeilDiv(m, m1) * CeilDiv(n, n1) * splitkFactor, platformInfo.coreNum);
size_t aMaxDataSizeAic = static_cast<size_t>(roundMax) * actualM * CeilDiv(k, splitkFactor) * 2;
size_t bMaxDataSizeAic = static_cast<size_t>(roundMax) * actualN * CeilDiv(k, splitkFactor) * 2;
size_t aMaxDataSizeAiv{0};
uint32_t tasksAivA{0};
{
uint32_t taskRows = 16;
uint32_t taskCols = 48 * 1024 / 2 / taskRows;
if (innerAxisA < taskCols) {
taskCols = innerAxisA;
}
if (outterAxisA < taskRows) {
taskRows = outterAxisA;
}
taskCols = RoundUp(innerAxisA / CeilDiv(innerAxisA, taskCols), 16);
tasksAivA = CeilDiv(outterAxisA, taskRows) * CeilDiv(innerAxisA, taskCols);
uint32_t maxTasksPerCore = CeilDiv(tasksAivA, platformInfo.coreNum * 2);
aMaxDataSizeAiv = maxTasksPerCore * taskCols * taskRows * 2;
}
size_t bMaxDataSizeAiv{0};
uint32_t tasksAivB{0};
{
uint32_t taskRows = 16;
uint32_t taskCols = 48 * 1024 / 2 / taskRows;
if (innerAxisB < taskCols) {
taskCols = innerAxisB;
}
if (outterAxisB < taskRows) {
taskRows = outterAxisB;
}
taskCols = RoundUp(innerAxisB / CeilDiv(innerAxisB, taskCols), 16);
tasksAivB = CeilDiv(outterAxisB, taskRows) * CeilDiv(innerAxisB, taskCols);
uint32_t maxTasksPerCore = CeilDiv(tasksAivB, platformInfo.coreNum * 2);
bMaxDataSizeAiv = maxTasksPerCore * taskCols * taskRows * 2;
}
double headCost = 1 + 7 * static_cast<double>(blockDimAic) / platformInfo.coreNum;
if (splitkFactor > 1) {
headCost = 1;
}
double t00 = static_cast<double>(aMaxDataSizeAic) / aBandwidthBeforePaddingAic / 1000
+ static_cast<double>(bMaxDataSizeAic) / bBandwidthBeforePaddingAic / 1000;
double t01 = static_cast<double>(aMaxDataSizeAic) / aBandwidthBeforePaddingAic / 1000
+ static_cast<double>(bMaxDataSizeAic) / bBandwidthAfterPaddingAic / 1000
+ static_cast<double>(bMaxDataSizeAiv) / bBandwidthAiv / 1000 + headCost;
double t10 = static_cast<double>(aMaxDataSizeAic) / aBandwidthAfterPaddingAic / 1000
+ static_cast<double>(bMaxDataSizeAic) / bBandwidthBeforePaddingAic / 1000
+ static_cast<double>(aMaxDataSizeAiv) / aBandwidthAiv / 1000 + headCost;
double t11 = static_cast<double>(aMaxDataSizeAic) / aBandwidthAfterPaddingAic / 1000
+ static_cast<double>(bMaxDataSizeAic) / bBandwidthAfterPaddingAic / 1000
+ static_cast<double>(aMaxDataSizeAiv) / aBandwidthAiv / 1000
+ static_cast<double>(bMaxDataSizeAiv) / bBandwidthAiv / 1000 + headCost + 2;
double minCost = std::numeric_limits<double>::max();
PaddingTag paddingTagA = PaddingTag::PADDING_NONE;
PaddingTag paddingTagB = PaddingTag::PADDING_NONE;
if (minCost > t00) {
minCost = t00;
}
if (minCost > t01) {
minCost = t01;
paddingTagA = PaddingTag::PADDING_NONE;
paddingTagB = PaddingTag::PADDING_NZ;
}
if (minCost > t10) {
minCost = t10;
paddingTagA = PaddingTag::PADDING_NZ;
paddingTagB = PaddingTag::PADDING_NONE;
}
if (minCost > t11) {
minCost = t11;
paddingTagA = PaddingTag::PADDING_NZ;
paddingTagB = PaddingTag::PADDING_NZ;
}
if ((innerAxisA < 8 || (innerAxisA < 32 && (innerAxisA % 16 != 0))) && outterAxisA > 512) {
paddingTagA = PaddingTag::PADDING_NZ;
}
if ((innerAxisB < 8 || (innerAxisB < 32 && (innerAxisB % 16 != 0))) && outterAxisB > 512) {
paddingTagB = PaddingTag::PADDING_NZ;
}
if (outterAxisA >= 2048 && innerAxisA > 8192 && innerAxisA % 8192 == 0) {
paddingTagA = PaddingTag::PADDING_NZ;
}
if (outterAxisB >= 2048 && innerAxisB > 8192 && innerAxisB % 8192 == 0) {
paddingTagB = PaddingTag::PADDING_NZ;
}
PaddingTag paddingTagC = PaddingTag::PADDING_NONE;
if (static_cast<size_t>(m) * n > 2048 * 2048 && n > 256 && (n % 128 != 0)) {
size_t totalDataSize = static_cast<size_t>(m) * k * CeilDiv(n, n1) * 2
+ static_cast<size_t>(k) * n * CeilDiv(m, m1) * 2 + static_cast<size_t>(m) * n * 2;
if (totalDataSize < 192 * 1024 * 1024) {
paddingTagC = PaddingTag::PADDING_ND;
}
}
tilingParams.paddingTagA = static_cast<uint8_t>(paddingTagA);
tilingParams.paddingTagB = static_cast<uint8_t>(paddingTagB);
tilingParams.paddingTagC = static_cast<uint8_t>(paddingTagC);
uint32_t blockDim = blockDimAic;
uint32_t actualTasksAivA{0};
uint32_t actualTasksAivB{0};
if (tilingParams.paddingTagA && innerAxisA > 192) {
actualTasksAivA = tasksAivA;
}
if (tilingParams.paddingTagB && innerAxisB > 192) {
actualTasksAivB = tasksAivB;
}
uint32_t actualTasksAiv = std::max(actualTasksAivA, actualTasksAivB);
uint32_t blockDimAiv = CeilDiv(actualTasksAiv, 2) > platformInfo.coreNum
? platformInfo.coreNum : CeilDiv(actualTasksAiv, 2);
if (tilingParams.paddingTagA || tilingParams.paddingTagB) {
blockDim = std::max(blockDimAic, blockDimAiv);
}
tilingParams.blockDim = blockDim;
}
bool CommonMatmulB16Handler(TilingParams ¶ms, PlatformInfo& platformInfo)
{
uint8_t kernelSerial = 0;
uint32_t taskBlocks = CeilDiv(params.m, params.m1) * CeilDiv(params.n, params.n1);
params.blockDim = taskBlocks > platformInfo.coreNum ? platformInfo.coreNum : taskBlocks;
params.tilingKey.SetTilingKey(kernelSerial, params.layoutTagA, params.layoutTagB, 0, 0, 0, 0);
return true;
}
bool SmallMatmulB16Handler(TilingParams ¶ms, PlatformInfo& platformInfo)
{
uint8_t kernelSerial = 1;
GetPaddingTag(params, platformInfo);
if (static_cast<PaddingTag>(params.paddingTagA) == PaddingTag::PADDING_NONE
&& static_cast<PaddingTag>(params.paddingTagB) == PaddingTag::PADDING_NONE
&& static_cast<PaddingTag>(params.paddingTagC) == PaddingTag::PADDING_NONE) {
uint32_t taskBlocks = CeilDiv(params.m, params.m1) * CeilDiv(params.n, params.n1);
if (taskBlocks <= platformInfo.coreNum && params.k <= params.k1) {
params.tilingKey.SetTilingKey(kernelSerial, params.layoutTagA, params.layoutTagB, 0, 0, 0, 0);
return true;
}
}
return false;
}
bool LocalPaddingCPaddingCommonMatmulB16Handler(TilingParams ¶ms, PlatformInfo &platformInfo)
{
uint8_t kernelSerial = 10;
uint32_t m = params.m;
uint32_t n = params.n;
uint32_t k = params.k;
uint32_t m1t = 128;
uint32_t n1t = 256;
uint32_t k1t = 256;
if (static_cast<size_t>(m) * n > 2048 * 2048 && n > 256 && (n % 256 != 0)) {
size_t totalDataSize = static_cast<size_t>(m) * k * CeilDiv(n, n1t) * 2
+ static_cast<size_t>(k) * n * CeilDiv(m, m1t) * 2 + static_cast<size_t>(m) * n * 2;
double ratio = static_cast<double>(
static_cast<size_t>(m) * k * CeilDiv(n, n1t) + static_cast<size_t>(k) * n * CeilDiv(m, m1t)
)
/ (static_cast<size_t>(m) * n);
if (totalDataSize >= 192 * 1024 * 1024 && ratio < 16) {
params.m1 = m1t;
params.n1 = n1t;
params.k1 = k1t;
params.n1Factor = 12;
uint32_t taskBlocks = CeilDiv(m, params.m1 * params.m1Factor) * CeilDiv(n, params.n1 * params.n1Factor);
if (taskBlocks < 8 * platformInfo.coreNum && n > 1024) {
params.n1Factor = 1;
}
params.paddingTagC = static_cast<uint8_t>(PaddingTag::PADDING_ND);
params.tilingKey.SetTilingKey(
kernelSerial, params.layoutTagA, params.layoutTagB, 0, params.paddingTagA, params.paddingTagB,
params.paddingTagC
);
return true;
}
}
return false;
}
bool PaddingCommonMatmulB16Handler(TilingParams ¶ms, PlatformInfo& platformInfo)
{
uint8_t kernelSerial = 2;
if (params.paddingTagA || params.paddingTagB || params.paddingTagC) {
params.tilingKey.SetTilingKey(kernelSerial,
params.layoutTagA, params.layoutTagB, 0, params.paddingTagA, params.paddingTagB, params.paddingTagC);
return true;
}
return false;
}
bool PaddingMultiCoreSplitkMatmulB16Handler(TilingParams& params, PlatformInfo& platformInfo)
{
if (!IsCStrideEqualShape(params)) {
return false;
}
uint32_t m = params.m;
uint32_t n = params.n;
uint32_t k = params.k;
uint32_t m1t = 128, n1t = 256, k1t = 256;
LayoutTag layoutTagA = static_cast<LayoutTag>(params.layoutTagA);
LayoutTag layoutTagB = static_cast<LayoutTag>(params.layoutTagB);
bool cond1 = (layoutTagA == LayoutTag::TagColumnMajor && layoutTagB == LayoutTag::TagColumnMajor);
bool cond2 = (layoutTagA == LayoutTag::TagColumnMajor && layoutTagB == LayoutTag::TagRowMajor) && (m > n);
if (cond1 || cond2) {
m1t = 256;
n1t = 128;
}
uint32_t blocks = CeilDiv(m, m1t) * CeilDiv(n, n1t);
uint32_t maxSplitkFactor = 2;
if (k > 1024) {
maxSplitkFactor = 4;
}
if (k > 2048) {
maxSplitkFactor = 8;
}
if (k > 4096) {
maxSplitkFactor = 16;
}
if (k >= 12288) {
maxSplitkFactor = platformInfo.coreNum;
}
if ((blocks <= platformInfo.coreNum / 2 && k > 5120) || (blocks <= 2 && k > 1024)) {
params.m1 = m1t;
params.n1 = n1t;
params.k1 = k1t;
params.splitkFactor = std::min(platformInfo.coreNum / blocks, maxSplitkFactor);
GetPaddingTag(params, platformInfo);
uint8_t kernelSerial = 3;
params.tilingKey.SetTilingKey(kernelSerial,
params.layoutTagA, params.layoutTagB, 0, params.paddingTagA, params.paddingTagB, 0);
return true;
}
return false;
}
bool PaddingStreamkMatmulB16Handler(TilingParams& params, PlatformInfo& platformInfo)
{
uint32_t m = params.m;
uint32_t n = params.n;
uint32_t k = params.k;
uint32_t m1t = 128, n1t = 256, k1t = 256;
LayoutTag layoutTagA = static_cast<LayoutTag>(params.layoutTagA);
LayoutTag layoutTagB = static_cast<LayoutTag>(params.layoutTagB);
bool cond1 = (layoutTagA == LayoutTag::TagColumnMajor && layoutTagB == LayoutTag::TagColumnMajor);
bool cond2 = (layoutTagA == LayoutTag::TagColumnMajor && layoutTagB == LayoutTag::TagRowMajor) && (m > n);
if (cond1 || cond2) {
m1t = 256;
n1t = 128;
}
uint32_t blocks = CeilDiv(m, m1t) * CeilDiv(n, n1t);
uint32_t skBlocks = blocks % platformInfo.coreNum;
if (blocks > platformInfo.coreNum && blocks < 8 * platformInfo.coreNum && skBlocks > 0
&& skBlocks < 0.8 * platformInfo.coreNum && params.k > 3072) {
params.m1 = m1t;
params.n1 = n1t;
params.k1 = k1t;
GetPaddingTag(params, platformInfo);
params.blockDim = platformInfo.coreNum;
uint32_t kernelSerial = 4;
params.tilingKey.SetTilingKey(kernelSerial,
params.layoutTagA, params.layoutTagB, 0, params.paddingTagA, params.paddingTagB, 0);
return true;
}
return false;
}
bool AivMatmulB16Handler(TilingParams& params, PlatformInfo& platformInfo)
{
if (params.k != 1 || !IsAStrideEqualShape(params) || !IsBStrideEqualShape(params)) {
return false;
}
uint32_t aivCoreNums = platformInfo.coreNum * 2;
constexpr uint32_t SCALAR_BUFFER_ELE_NUM = 256;
uint8_t kernelSerial = 9;
uint8_t dispatchPolicyTag = 0;
if (params.m <= params.n || params.n > 64) {
uint32_t nTile = RoundUp(params.n, 16);
if (nTile * 2 > 32768) {
nTile = 32768 / 2;
} else if (nTile * 2 < 128) {
nTile = 128 / 2;
}
uint32_t nCut = CeilDiv(params.n, nTile);
uint32_t mCoreNum = CeilDiv(aivCoreNums, nCut);
uint32_t mTile = CeilDiv(params.m, mCoreNum);
uint32_t ubASize = SCALAR_BUFFER_ELE_NUM * 2;
uint32_t ubBSize = nTile * 2;
uint32_t ubCSize = platformInfo.ubSize - ubASize - ubBSize;
uint32_t mTileLimits = ubCSize / ubBSize > 64 ? 64 : ubCSize / ubBSize;
if (mTile <= mTileLimits && nTile >= params.n) {
dispatchPolicyTag = 1;
}
if (mTile > mTileLimits) {
mTile = nTile >= 8192 ? mTileLimits * 4 : mTileLimits;
}
uint32_t blocks = CeilDiv(params.m, mTile) * CeilDiv(params.n, nTile);
uint32_t loopsTimes = CeilDiv(blocks, aivCoreNums);
uint32_t nextMTile = mTile;
uint32_t nextBlocks = blocks;
uint32_t nextLoopsTimes = loopsTimes;
while(nextLoopsTimes == loopsTimes && nextMTile > 1) {
mTile = nextMTile;
nextMTile = nextMTile - 1;
nextBlocks = CeilDiv(params.m, nextMTile) * CeilDiv(params.n, nTile);
nextLoopsTimes = CeilDiv(nextBlocks, aivCoreNums);
}
uint32_t nextNTile = nTile;
nextLoopsTimes = loopsTimes;
while (nextLoopsTimes == loopsTimes && nextNTile > 64) {
nTile = nextNTile;
nextNTile = nextNTile - 16;
nextBlocks = CeilDiv(params.m, mTile) * CeilDiv(params.n, nextNTile);
nextLoopsTimes = CeilDiv(nextBlocks, aivCoreNums);
}
params.m1 = mTile;
params.n1 = nTile;
params.k1 = 0;
params.tilingKey.SetTilingKey(kernelSerial, dispatchPolicyTag, 0, 0, 0, 0, 0);
} else {
uint32_t nTile = RoundUp(params.n, 16);
uint32_t mTile = params.m;
if (params.m <= 4096) {
mTile = 128;
} else {
mTile = CeilDiv(params.m, aivCoreNums);
}
mTile = RoundUp(mTile, 16);
uint32_t ubScalarSize = SCALAR_BUFFER_ELE_NUM * 2;
uint32_t mTileLimits = (platformInfo.ubSize - ubScalarSize) / 2 / (1 + nTile * 2) / 16 * 16;
if (mTile > mTileLimits) {
mTile = mTileLimits;
}
uint32_t blocks = CeilDiv(params.m, mTile) * CeilDiv(params.n, nTile);
uint32_t loopsTimes = CeilDiv(blocks, aivCoreNums);
uint32_t nextMTile = mTile;
uint32_t nextBlocks = blocks;
uint32_t nextLoopsTimes = loopsTimes;
while (nextLoopsTimes == loopsTimes && nextMTile > 112) {
mTile = nextMTile;
nextMTile = nextMTile - 16;
nextBlocks = CeilDiv(params.m, nextMTile) * CeilDiv(params.n, nTile);
nextLoopsTimes = CeilDiv(nextBlocks, aivCoreNums);
}
params.m1 = mTile;
params.n1 = nTile;
params.k1 = 0;
dispatchPolicyTag = 2;
params.tilingKey.SetTilingKey(kernelSerial, dispatchPolicyTag, 0, 0, 0, 0, 0);
}
uint32_t taskBlocks = CeilDiv(params.m, params.m1) * CeilDiv(params.n, params.n1);
params.blockDim = taskBlocks > (platformInfo.coreNum * 2) ? platformInfo.coreNum * 2 : taskBlocks;
return true;
}
void SetSwizzleParams(TilingParams &tilingParams)
{
if (tilingParams.m > tilingParams.n) {
tilingParams.swizzleOffset = 3;
tilingParams.swizzleDirection = 0;
} else {
tilingParams.swizzleOffset = 3;
tilingParams.swizzleDirection = 1;
}
}
void SelectKernelB16(TilingParams &tilingParams, PlatformInfo& platformInfo)
{
uint8_t layoutTagATmp = tilingParams.layoutTagA;
uint8_t layoutTagBTmp = tilingParams.layoutTagB;
if (static_cast<LayoutTag>(tilingParams.layoutTagA) == LayoutTag::TagColumnMajor &&
tilingParams.m == 1 && tilingParams.strideA == 1) {
tilingParams.layoutTagA = static_cast<uint8_t>(LayoutTag::TagRowMajor);
}
if (static_cast<LayoutTag>(tilingParams.layoutTagB) == LayoutTag::TagRowMajor &&
tilingParams.n == 1 && tilingParams.strideB == 1) {
tilingParams.layoutTagB = static_cast<uint8_t>(LayoutTag::TagColumnMajor);
}
using HandlerPtr = bool (*)(TilingParams& tilingParams, PlatformInfo& platformInfo);
HandlerPtr handlers[] = {
AivMatmulB16Handler,
SmallMatmulB16Handler,
PaddingMultiCoreSplitkMatmulB16Handler,
PaddingStreamkMatmulB16Handler,
LocalPaddingCPaddingCommonMatmulB16Handler,
PaddingCommonMatmulB16Handler,
CommonMatmulB16Handler
};
for (auto handler : handlers) {
if (handler(tilingParams, platformInfo)) {
break;
}
}
tilingParams.layoutTagA = layoutTagATmp;
tilingParams.layoutTagB = layoutTagBTmp;
SetSwizzleParams(tilingParams);
}
#endif