* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file convolution_util.cpp
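* \brief Helper routines for Conv2D split-W eligibility checks, Conv2D-to-Conv3D view conversion and DMA/L1 size checks.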
*/
#include <memory>
#include <vector>
#include <string>
#include "aclnn_kernels/cast.h"
#include "aclnn_kernels/common/op_error_check.h"
#include "aclnn_kernels/contiguous.h"
#include "level0/squeeze.h"
#include "level0/unsqueeze.h"
#include "opdev/tensor_view_utils.h"
#include "opdev/op_dfx.h"
#include "opdev/op_executor.h"
#include "aclnn_kernels/transdata.h"
#include "aclnn_kernels/reshape.h"
#include "convolution_util.h"
using namespace op;
using namespace ge;
using namespace l0op;
// Cached UB size: GetUbSize() stores the platform-reported UB size minus RESERVED_SIZE_8K here.
// A value of 0 means "not cached yet" and makes GetUbSize() query the platform again.
static uint64_t g_ubSize = 0;
// Cached L1 size: GetL1Size() writes the platform-reported L1 size here.
// A value of 0 means "not cached yet" and makes GetL1Size() query the platform again.
static uint64_t g_l1Size = 0;
// Constants and lookup tables shared by the Conv2D split-W checks, the 2D/3D view helpers and the DMA limit checks.
namespace SplitWInfo {
// Threshold compared against both strideH and strideW in CheckConv2DTbeOptFlag.
constexpr int STRIDEH_MAX = 63;
// Threshold compared against dilationH and dilationW in CheckConv2DTbeOptFlag.
constexpr int DILATION_MAX = 255;
// Threshold compared against padU/padD/padL/padR in CheckConv2DTbeOptFlag.
constexpr int PAD_MAX = 255;
// Threshold compared against the kernel sizes kh and kw in CheckConv2DTbeOptFlag.
constexpr int WEIGHT_MAX = 511;
// Shape dimension read as the height (hi / kh) of input and weight; also the axis used by the
// unsqueeze/squeeze helpers when converting between 4D and 5D views.
constexpr int HI_INDEX = 2;
// Shape dimension read as the width (wi / kw) of input and weight, and as the output width.
constexpr int WI_INDEX = 3;
// Positions inside the padding array that are read as padL and padR.
constexpr int LEFT_INDEX_ATTR = 2;
constexpr int RIGHT_INDEX_ATTR = 3;
// Last slot of the 3-element attribute built by View2dAs3dForAttr.
constexpr int W_INDEX_ATTR_CONV3D = 2;
// Element count of the attribute built by View2dAs3dForAttr; CONV3D_ATTR_NUM - 1 is the element count
// required of a 2D stride/dilation attribute.
constexpr int CONV3D_ATTR_NUM = 3;
// Number of dimensions View4DSwapHWForTensor requires of its input.
constexpr size_t CONV_2D_DIM_SIZE = 4;
// Added to BLK_M / outputW when CheckLoadL1InSplitW estimates the minimum number of output rows.
constexpr int EXTRA_NUM = 2;
// Block sizes used as m0 / n0 in the L1 and UB footprint estimates.
constexpr int64_t BLK_M = 16;
constexpr int64_t BLK_N = 16;
// Upper bound for kh * kw * k0 (see CheckLoad3dIns and CheckDmaLimits).
constexpr int64_t POSK_LIMIT = 65535;
// Divided by the element size to obtain k0, and used as the alignment passed to ConvAlignB.
constexpr int64_t BLK_LEN = 32;
// Multiplier applied to kh in CheckConv2DInput.
constexpr int64_t CONST_DOUBLE = 2;
// Element size per data type; the values match the byte widths of the listed types.
// NOTE(review): the table is mutable and is queried with operator[] in this file, so looking up a data type
// that is not listed inserts it with size 0.
std::map<op::DataType, uint32_t> gDataTypeSizeTab = {{op::DataType::DT_FLOAT16, 2}, {op::DataType::DT_FLOAT, 4},
{op::DataType::DT_BF16, 2}, {op::DataType::DT_INT8, 1},
{op::DataType::DT_UINT8, 1}, {op::DataType::DT_INT64, 8},
{op::DataType::DT_UINT64, 8}, {op::DataType::DT_INT32, 4}};
}
namespace ConvolutionUtil {
void Conv2DSplitWInfo::InitConv2DSplitWInfo(const aclTensor* input, const aclTensor* weight, const aclIntArray* stride,
const aclIntArray* padding, const aclIntArray* dilation)
{
hi = input->GetViewShape().GetDim(SplitWInfo::HI_INDEX);
wi = input->GetViewShape().GetDim(SplitWInfo::WI_INDEX);
kh = weight->GetViewShape().GetDim(SplitWInfo::HI_INDEX);
kw = weight->GetViewShape().GetDim(SplitWInfo::WI_INDEX);
strideH = (*stride)[0];
strideW = (*stride)[1];
dilationH = (*dilation)[0];
dilationW = (*dilation)[1];
padU = (*padding)[0];
padD = (*padding)[1];
padL = (*padding)[SplitWInfo::LEFT_INDEX_ATTR];
padR = (*padding)[SplitWInfo::RIGHT_INDEX_ATTR];
}
/**
 * Decides whether the Conv2D described by the cached members may use the split-W path.
 *
 * The checks run in a fixed order and the first failing one ends the evaluation with `false`:
 * basic info (groups / SoC / dtypes), padding symmetry, stride-vs-kernel height, TBE-optimised cases,
 * the load3d posK limit and finally the L1 footprint check.
 *
 * Side effects: updates biasTypeSize (only when bias is given) and k0.
 *
 * @param bias    optional bias tensor; may be nullptr
 * @param output  output tensor forwarded to CheckLoadL1InSplitW
 * @param groups  convolution group count
 * @param opInfo  data types of the operator inputs
 * @return true when every check passes
 */
bool Conv2DSplitWInfo::CanSwitchSplitW(const aclTensor* bias, aclTensor* output, int64_t groups,
                                       const ConvolutionOpInfo& opInfo)
{
    const bool basicInfoOk = CheckBasicInfoInSplitW(groups, opInfo);
    if (!basicInfoOk) {
        OP_LOGD("Conv2d splitw only support groups is 1, dtypes are [FP16/BF16/FP32] and soc is A2 or A3.\n");
        return false;
    }

    const bool asymmetricPad = CheckConv2DPad();
    if (asymmetricPad) {
        OP_LOGD("Conv2d splitw satisfying padU/D or padL/R not same.\n");
        return false;
    }

    const bool strideTooLarge = CheckConv2DInput();
    if (strideTooLarge) {
        OP_LOGD("Conv2d splitw satisfying strideh that is greater than or equal to 2*kernelh.\n");
        return false;
    }

    const bool handledByTbe = CheckConv2DTbeOptFlag(opInfo);
    if (handledByTbe) {
        OP_LOGD("Conv2d splitw satisfying TBE optimization.\n");
        return false;
    }

    // The bias element size is only refreshed when a bias tensor is actually supplied.
    if (bias != nullptr) {
        biasTypeSize = SplitWInfo::gDataTypeSizeTab[opInfo.biasDtype];
    }
    // k0 is BLK_LEN divided by the input element size.
    k0 = SplitWInfo::BLK_LEN / SplitWInfo::gDataTypeSizeTab[opInfo.inputDtype];

    const bool load3dOk = CheckLoad3dIns();
    if (!load3dOk) {
        OP_LOGD("Conv2d splitw exceeding load3d ins posk limit %ld.\n", SplitWInfo::POSK_LIMIT);
        return false;
    }

    // The L1 check logs its own reason on failure.
    return CheckLoadL1InSplitW(bias, output);
}
/**
 * Reports whether the cached padding is asymmetric.
 *
 * @return true when padU differs from padD or padL differs from padR
 */
bool Conv2DSplitWInfo::CheckConv2DPad() const
{
    const bool verticalSymmetric = (padU == padD);
    const bool horizontalSymmetric = (padL == padR);
    return !(verticalSymmetric && horizontalSymmetric);
}
/**
 * Reports whether the cached H stride reaches CONST_DOUBLE (2) times the kernel height.
 *
 * @return true when strideH >= 2 * kh
 */
bool Conv2DSplitWInfo::CheckConv2DInput() const
{
    const auto strideLimit = SplitWInfo::CONST_DOUBLE * kh;
    return strideH >= strideLimit;
}
/**
 * Detects the configurations that CanSwitchSplitW treats as already covered by a TBE optimisation.
 *
 * Returns true in any of these cases:
 *  - 1x1 kernel, unit strides, hi != 1, no padding, and both input and weight are BF16;
 *  - hi == 1 with kh == 1 and no vertical padding;
 *  - any pad, stride, dilation or kernel size above its PAD_MAX / STRIDEH_MAX / DILATION_MAX / WEIGHT_MAX limit.
 *
 * @param opInfo data types of the operator inputs
 * @return true when one of the cases above applies, false otherwise
 */
bool Conv2DSplitWInfo::CheckConv2DTbeOptFlag(const ConvolutionOpInfo& opInfo)
{
    const bool noVerticalPad = (padU == 0) && (padD == 0);
    const bool noHorizontalPad = (padL == 0) && (padR == 0);

    // Case 1: the shape pattern labelled "load2d" together with BF16 input and weight.
    const bool pointwiseKernel = (kh == 1) && (kw == 1);
    const bool unitStride = (strideH == 1) && (strideW == 1);
    const bool bothBf16 =
        (opInfo.inputDtype == op::DataType::DT_BF16) && (opInfo.weightDtype == op::DataType::DT_BF16);
    if (pointwiseKernel && unitStride && (hi != 1) && noVerticalPad && noHorizontalPad && bothBf16) {
        return true;
    }

    // Case 2: the shape pattern labelled "conv1d" (single input row, single kernel row, no vertical padding).
    if ((hi == 1) && (kh == 1) && noVerticalPad) {
        return true;
    }

    // Case 3: at least one attribute exceeds its limit.
    const bool padOverLimit = (padU > SplitWInfo::PAD_MAX) || (padD > SplitWInfo::PAD_MAX) ||
                              (padL > SplitWInfo::PAD_MAX) || (padR > SplitWInfo::PAD_MAX);
    const bool strideOverLimit = (strideH > SplitWInfo::STRIDEH_MAX) || (strideW > SplitWInfo::STRIDEH_MAX);
    const bool dilationOverLimit = (dilationH > SplitWInfo::DILATION_MAX) || (dilationW > SplitWInfo::DILATION_MAX);
    const bool kernelOverLimit = (kh > SplitWInfo::WEIGHT_MAX) || (kw > SplitWInfo::WEIGHT_MAX);
    return padOverLimit || strideOverLimit || dilationOverLimit || kernelOverLimit;
}
/**
 * Checks the basic preconditions of the split-W path.
 *
 * @param groups convolution group count; must be 1
 * @param opInfo data types of the operator inputs; input and weight must each be FLOAT, FLOAT16 or BF16
 * @return true when groups is 1, the current SoC is ASCEND910_93 or ASCEND910B, and both dtypes are supported
 */
bool Conv2DSplitWInfo::CheckBasicInfoInSplitW(int64_t groups, const ConvolutionOpInfo& opInfo)
{
    if (groups != 1) {
        return false;
    }

    const auto socVersion = GetCurrentPlatformInfo().GetSocVersion();
    const bool socSupported = (socVersion == SocVersion::ASCEND910_93) || (socVersion == SocVersion::ASCEND910B);
    if (!socSupported) {
        return false;
    }

    const std::unordered_set<op::DataType> supportedDtypes{op::DataType::DT_FLOAT, op::DataType::DT_FLOAT16,
                                                           op::DataType::DT_BF16};
    const bool inputOk = supportedDtypes.find(opInfo.inputDtype) != supportedDtypes.end();
    const bool weightOk = supportedDtypes.find(opInfo.weightDtype) != supportedDtypes.end();
    return inputOk && weightOk;
}
/**
 * Checks the cached kernel footprint against the posK limit.
 *
 * @return true when kh * kw * k0 does not exceed POSK_LIMIT
 */
bool Conv2DSplitWInfo::CheckLoad3dIns()
{
    const auto posK = kh * kw * k0;
    return !(posK > SplitWInfo::POSK_LIMIT);
}
/**
 * Compares two minimum L1 footprints with fixed buffer sizes to decide whether split-W is worthwhile.
 *
 * Step 1 ("m mode"): the footprint is computed from the full input width. If it already fits into
 * l1BufferSizeForTbe, split-W is rejected.
 * Step 2 ("hw mode"): the footprint is computed from a width window derived from 16 output columns. If it
 * exceeds l1BufferSizeForAsc, split-W is rejected as well.
 *
 * @param bias   optional bias; when present BLK_N * biasTypeSize is added to both footprints
 * @param output tensor whose view-shape dim WI_INDEX is taken as the output width
 * @return true only when the m-mode footprint is too large and the hw-mode footprint fits;
 *         false also when the output width is 0
 */
bool Conv2DSplitWInfo::CheckLoadL1InSplitW(const aclTensor* bias, aclTensor* output)
{
    constexpr int64_t l1BufferSizeForTbe = 524288;
    constexpr int64_t l1BufferSizeForAsc = 524032;
    constexpr int64_t woAL1Min = 16;

    const int64_t outputW = output->GetViewShape().GetDim(SplitWInfo::WI_INDEX);
    if (outputW == 0) {
        // outputW is a divisor below.
        return false;
    }

    // The bias contribution is the same for both footprints.
    int64_t biasL1Bytes = 0;
    if (bias != nullptr) {
        biasL1Bytes += SplitWInfo::BLK_N * biasTypeSize;
    }

    const int64_t hkDilation = (kh - 1) * dilationH + 1;

    // Step 1: footprint with the whole input width resident.
    const int64_t hoAL1Min = SplitWInfo::BLK_M / outputW + SplitWInfo::EXTRA_NUM;
    const int64_t hiForMMode = std::min(((hoAL1Min - 1) * strideH + hkDilation), hi);
    const int64_t mModeL1Size = hiForMMode * wi * SplitWInfo::BLK_LEN + biasL1Bytes;
    if (mModeL1Size <= l1BufferSizeForTbe) {
        OP_LOGD("Conv2d splitw minL1Size less than L1 Buffer in m mode.\n");
        return false;
    }

    // Step 2: footprint with only a width window resident.
    const int64_t hiForHwMode = std::min(hkDilation, hi);
    const int64_t wkDilation = (kw - 1) * dilationW + 1;
    const int64_t wiForHwMode = std::min(((woAL1Min - 1) * strideW + wkDilation), wi);
    const int64_t hwModeL1Size = hiForHwMode * wiForHwMode * SplitWInfo::BLK_LEN + biasL1Bytes;
    if (hwModeL1Size > l1BufferSizeForAsc) {
        OP_LOGD("Conv2d splitw minL1Size greater than L1 Buffer in hw mode.\n");
        return false;
    }
    return true;
}
/**
 * Builds a 3-element attribute array from a 2D convolution attribute.
 *
 * The result is {expendValue, intArray[0], w}, where w is intArray[1] for stride/dilation style attributes
 * and intArray[LEFT_INDEX_ATTR] for padding.
 *
 * @param intArray    source attribute; for non-padding attributes it must hold exactly 2 elements, for padding
 *                    it must hold more than LEFT_INDEX_ATTR elements so that the left padding can be read
 * @param expendValue value placed in slot 0 of the result
 * @param executor    executor used to allocate the new array
 * @param isPad       true when intArray is a padding attribute
 * @return the newly allocated array, or nullptr when an argument is null, the size requirement is not met,
 *         or the allocation fails
 */
aclIntArray* View2dAs3dForAttr(const aclIntArray* intArray, int64_t expendValue, aclOpExecutor* executor, bool isPad)
{
    if (intArray == nullptr || executor == nullptr) {
        return nullptr;
    }

    const uint64_t size = intArray->Size();
    // Validate the length for both kinds of attribute so that no index below reads past the end.
    const bool sizeOk = isPad ? (size > static_cast<uint64_t>(SplitWInfo::LEFT_INDEX_ATTR)) :
                                (size == static_cast<uint64_t>(SplitWInfo::CONV3D_ATTR_NUM - 1));
    if (!sizeOk) {
        return nullptr;
    }

    int64_t data[SplitWInfo::CONV3D_ATTR_NUM];
    data[0] = expendValue;
    data[1] = (*intArray)[0];
    data[SplitWInfo::W_INDEX_ATTR_CONV3D] = isPad ? (*intArray)[SplitWInfo::LEFT_INDEX_ATTR] : (*intArray)[1];
    return executor->AllocIntArray(data, SplitWInfo::CONV3D_ATTR_NUM);
}
/**
 * Returns a copy of a two-element attribute with its two values exchanged.
 *
 * @param intArray source attribute; must hold exactly CONV3D_ATTR_NUM - 1 (= 2) elements
 * @param executor executor used to allocate the new array
 * @return the newly allocated {intArray[1], intArray[0]} array, or nullptr when the size does not match
 *         (the allocation result is returned as is)
 */
aclIntArray* View2DSwapHWForAttr(const aclIntArray* intArray, aclOpExecutor* executor)
{
    const uint64_t elemCount = intArray->Size();
    const uint64_t expectedCount = static_cast<uint64_t>(SplitWInfo::CONV3D_ATTR_NUM - 1);
    if (elemCount != expectedCount) {
        return nullptr;
    }

    const int64_t swapped[] = {(*intArray)[1], (*intArray)[0]};
    return executor->AllocIntArray(swapped, 2);
}
/**
 * Reshapes a 4D tensor so that the sizes of dims 2 and 3 are exchanged.
 *
 * The change is made with l0op::Reshape using the shape {d0, d1, d3, d2}; no transpose is issued here.
 *
 * @param input    tensor whose view shape must have exactly CONV_2D_DIM_SIZE (4) dimensions
 * @param executor executor used for the shape array and the reshape
 * @return the reshaped tensor, or nullptr when the rank is not 4 or any step fails
 */
const aclTensor* View4DSwapHWForTensor(const aclTensor* input, aclOpExecutor* executor)
{
    const auto rank = input->GetViewShape().GetDimNum();
    CHECK_RET(rank == SplitWInfo::CONV_2D_DIM_SIZE, nullptr);

    auto srcDims = op::ToShapeVector(input->GetViewShape());
    FVector<int64_t> swappedDims = {srcDims[0], srcDims[1], srcDims[SplitWInfo::WI_INDEX],
                                    srcDims[SplitWInfo::HI_INDEX]};
    aclIntArray* swappedShape = executor->AllocIntArray(swappedDims.data(), swappedDims.size());
    CHECK_RET(swappedShape != nullptr, nullptr);

    const aclTensor* reshaped = l0op::Reshape(input, swappedShape, executor);
    CHECK_RET(reshaped != nullptr, nullptr);
    return reshaped;
}
/**
 * Turns a 4D tensor into a 5D one by inserting a new axis at index HI_INDEX (2) and tagging the result NCDHW.
 *
 * Pipeline: Contiguous -> UnsqueezeNd(axis 2) -> ReFormat(FORMAT_NCDHW).
 *
 * @param input    tensor to convert
 * @param executor executor used by every step
 * @return the converted tensor, or nullptr as soon as one step fails
 */
const aclTensor* View4dAs5dForInput(const aclTensor* input, aclOpExecutor* executor)
{
    const aclTensor* result = l0op::Contiguous(input, executor);
    CHECK_RET(result != nullptr, nullptr);

    // The axis array is allocated only after the contiguous step succeeded.
    const int64_t depthAxis = SplitWInfo::HI_INDEX;
    aclIntArray* axisArray = executor->AllocIntArray(&depthAxis, 1);
    CHECK_RET(axisArray != nullptr, nullptr);

    result = l0op::UnsqueezeNd(result, axisArray, executor);
    CHECK_RET(result != nullptr, nullptr);

    result = l0op::ReFormat(result, op::Format::FORMAT_NCDHW);
    CHECK_RET(result != nullptr, nullptr);
    return result;
}
/**
 * Replaces the 2D stride, dilation and padding attributes with their 3-element counterparts.
 *
 * Stride and dilation get the value 1 in slot 0, padding gets 0 (see View2dAs3dForAttr).
 * The three references are updated together and only after every conversion succeeded, so a failure
 * leaves all of the caller's pointers untouched instead of nulling or half-converting them.
 *
 * @param stride   in/out: 2-element stride attribute
 * @param padding  in/out: padding attribute
 * @param dilation in/out: 2-element dilation attribute
 * @param executor executor used to allocate the new arrays
 * @return ACLNN_SUCCESS on success, ACLNN_ERR_INNER_NULLPTR when an argument is null or a conversion fails
 */
aclnnStatus ChangeConv2dAttrToConv3d(const aclIntArray* &stride, const aclIntArray* &padding,
                                     const aclIntArray* &dilation, aclOpExecutor* executor)
{
    CHECK_RET(executor != nullptr, ACLNN_ERR_INNER_NULLPTR);
    CHECK_RET(stride != nullptr, ACLNN_ERR_INNER_NULLPTR);
    CHECK_RET(dilation != nullptr, ACLNN_ERR_INNER_NULLPTR);
    CHECK_RET(padding != nullptr, ACLNN_ERR_INNER_NULLPTR);

    const aclIntArray* stride3d = View2dAs3dForAttr(stride, 1, executor, false);
    CHECK_RET(stride3d != nullptr, ACLNN_ERR_INNER_NULLPTR);
    const aclIntArray* dilation3d = View2dAs3dForAttr(dilation, 1, executor, false);
    CHECK_RET(dilation3d != nullptr, ACLNN_ERR_INNER_NULLPTR);
    const aclIntArray* padding3d = View2dAs3dForAttr(padding, 0, executor, true);
    CHECK_RET(padding3d != nullptr, ACLNN_ERR_INNER_NULLPTR);

    // Commit all three results at once.
    stride = stride3d;
    dilation = dilation3d;
    padding = padding3d;
    return ACLNN_SUCCESS;
}
/**
 * Replaces the 4D input and weight tensors with their 5D views (see View4dAs5dForInput).
 *
 * Both references are updated together and only after both conversions succeeded, so a failure leaves
 * the caller's pointers untouched instead of nulling or half-converting them.
 *
 * @param input    in/out: input tensor
 * @param weight   in/out: weight tensor
 * @param executor executor used for the conversions
 * @return ACLNN_SUCCESS on success, ACLNN_ERR_INNER_NULLPTR when an argument is null or a conversion fails
 */
aclnnStatus ChangeConv2dInputToConv3d(const aclTensor* &input, const aclTensor* &weight, aclOpExecutor* executor)
{
    CHECK_RET(executor != nullptr, ACLNN_ERR_INNER_NULLPTR);
    CHECK_RET(input != nullptr, ACLNN_ERR_INNER_NULLPTR);
    CHECK_RET(weight != nullptr, ACLNN_ERR_INNER_NULLPTR);

    const aclTensor* input5d = View4dAs5dForInput(input, executor);
    CHECK_RET(input5d != nullptr, ACLNN_ERR_INNER_NULLPTR);
    const aclTensor* weight5d = View4dAs5dForInput(weight, executor);
    CHECK_RET(weight5d != nullptr, ACLNN_ERR_INNER_NULLPTR);

    // Commit both results at once.
    input = input5d;
    weight = weight5d;
    return ACLNN_SUCCESS;
}
/**
 * Turns a 5D tensor back into a 4D one by squeezing the axis at index HI_INDEX (2) and tagging the result NCHW.
 *
 * Pipeline: Contiguous -> SqueezeNd(axis 2) -> ReFormat(FORMAT_NCHW).
 *
 * @param input    tensor to convert
 * @param executor executor used by every step
 * @return the converted tensor, or nullptr as soon as one step fails
 */
const aclTensor* View5dAs4dForOutput(const aclTensor* input, aclOpExecutor* executor)
{
    const aclTensor* result = l0op::Contiguous(input, executor);
    CHECK_RET(result != nullptr, nullptr);

    // The axis array is allocated only after the contiguous step succeeded.
    const int64_t depthAxis = SplitWInfo::HI_INDEX;
    aclIntArray* axisArray = executor->AllocIntArray(&depthAxis, 1);
    CHECK_RET(axisArray != nullptr, nullptr);

    result = l0op::SqueezeNd(result, axisArray, executor);
    CHECK_RET(result != nullptr, nullptr);

    result = l0op::ReFormat(result, op::Format::FORMAT_NCHW);
    CHECK_RET(result != nullptr, nullptr);
    return result;
}
/**
 * Checks whether the first `dims` view strides of `input` equal the first `dims` entries of `newStrides`.
 *
 * @param input      tensor whose view strides are inspected
 * @param newStrides expected strides
 * @param dims       number of leading dimensions to compare
 * @return true when all compared strides match; false on the first mismatch, and also when `dims` exceeds
 *         the number of entries available in either stride list (an error is logged in that case)
 */
bool CheckDisContinuousStride(const aclTensor* input, const std::vector<int64_t>& newStrides, uint32_t dims)
{
    auto viewStrides = input->GetViewStrides();
    const size_t comparableDims = std::min<size_t>(viewStrides.size(), newStrides.size());
    if (dims > comparableDims) {
        OP_LOGE(ACLNN_ERR_RUNTIME_ERROR, "Invalid dims");
        // Stop here: continuing would index past the end of the shorter stride list.
        return false;
    }
    for (size_t i = 0; i < dims; i++) {
        if (viewStrides[i] != newStrides[i]) {
            return false;
        }
    }
    return true;
}
/**
 * Fills the g_ubSize cache with the platform-reported UB size minus RESERVED_SIZE_8K.
 *
 * Does nothing when the cache is already non-zero. The cache stays 0 when the platform info is
 * unavailable or when the reported size is smaller than the reserved amount; the latter guard prevents
 * the unsigned subtraction from wrapping around to a huge value.
 */
void GetUbSize()
{
    if (g_ubSize != 0) {
        return;
    }
    auto platformInfo = GetCurrentPlatformInfo().GetPlatformInfos();
    if (platformInfo == nullptr) {
        OP_LOGW("Platform info is null, ubSize fallback to 0.");
        return;
    }
    uint64_t ubSize = 0;
    platformInfo->GetLocalMemSize(fe::LocalMemType::UB, ubSize);
    OP_LOGD("GetUbSize returned: %lu", ubSize);

    const uint64_t reservedSize = static_cast<uint64_t>(RESERVED_SIZE_8K);
    if (ubSize < reservedSize) {
        OP_LOGW("UB size %lu is smaller than the reserved size, ubSize fallback to 0.", ubSize);
        return;
    }
    g_ubSize = ubSize - reservedSize;
}
/**
 * Fills the g_l1Size cache with the platform-reported L1 size.
 *
 * Does nothing when the cache is already non-zero. The cache stays 0 when the platform info is unavailable.
 */
void GetL1Size()
{
    if (g_l1Size != 0) {
        return;
    }
    auto platformInfo = GetCurrentPlatformInfo().GetPlatformInfos();
    if (platformInfo == nullptr) {
        // The message names the L1 size (the previous text was copied from the UB query).
        OP_LOGW("Platform info is null, l1Size fallback to 0.");
        return;
    }
    platformInfo->GetLocalMemSize(fe::LocalMemType::L1, g_l1Size);
    OP_LOGD("GetL1Size returned: %lu", g_l1Size);
}
/**
 * Evaluates posK and L1 footprint limits for a Conv2D problem.
 *
 * Returns true when
 *  - kh * kw * k0 exceeds POSK_LIMIT, or
 *  - the minimum bias + input + weight L1 load exceeds the L1 size cached by GetL1Size(), or
 *  - CheckL1SizeLimitsDma() returns false;
 * and false otherwise.
 *
 * NOTE(review): callers presumably read `true` as "a limit was hit" — confirm against the call sites.
 * NOTE(review): opInfo, input, weight, stride, padding and dilation are dereferenced without null checks;
 * padding is read at indices LEFT_INDEX_ATTR and RIGHT_INDEX_ATTR, stride and dilation at indices 0 and 1.
 */
bool CheckDmaLimits(const struct ConvolutionOpInfo* opInfo, const aclTensor* input, const aclTensor* weight,
const aclIntArray* stride, const aclIntArray* padding, const aclIntArray* dilation, const aclTensor* bias)
{
// Kernel and input spatial sizes come from dims HI_INDEX / WI_INDEX of the view shapes.
int64_t orgKh = static_cast<int64_t>(weight->GetViewShape().GetDim(SplitWInfo::HI_INDEX));
int64_t orgKw = static_cast<int64_t>(weight->GetViewShape().GetDim(SplitWInfo::WI_INDEX));
int64_t hin = static_cast<int64_t>(input->GetViewShape().GetDim(SplitWInfo::HI_INDEX));
int64_t win = static_cast<int64_t>(input->GetViewShape().GetDim(SplitWInfo::WI_INDEX));
int64_t strideH = (*stride)[0];
int64_t strideW = (*stride)[1];
int64_t dilationH = (*dilation)[0];
int64_t dilationW = (*dilation)[1];
int64_t padLeft = (*padding)[SplitWInfo::LEFT_INDEX_ATTR];
int64_t padRight = (*padding)[SplitWInfo::RIGHT_INDEX_ATTR];
int64_t m0 = SplitWInfo::BLK_M;
int64_t n0 = SplitWInfo::BLK_N;
// k0 is BLK_LEN divided by the weight element size.
// NOTE(review): gDataTypeSizeTab is read with operator[], which yields 0 for a dtype missing from the table;
// such a weight dtype would make this a division by zero.
int64_t k0 = SplitWInfo::BLK_LEN / SplitWInfo::gDataTypeSizeTab[opInfo->weightDtype];
uint32_t inputDtypeSize = SplitWInfo::gDataTypeSizeTab[opInfo->inputDtype];
uint32_t weightDtypeSize = SplitWInfo::gDataTypeSizeTab[opInfo->weightDtype];
uint32_t biasDtypeSize = SplitWInfo::gDataTypeSizeTab[opInfo->biasDtype];
uint64_t nBL1min = n0;
// First limit: the kernel footprint kh * kw * k0 must not exceed POSK_LIMIT.
if (orgKh * orgKw * k0 > SplitWInfo::POSK_LIMIT) {
return true;
}
// Bias contributes to the L1 load only when a bias tensor is supplied; sizes are aligned up to BLK_LEN.
uint64_t biasL1Size = 0;
if (bias != nullptr) {
biasL1Size = ConvAlignB(nBL1min * biasDtypeSize, SplitWInfo::BLK_LEN);
}
uint64_t kBL1min = k0 * orgKh * orgKw;
uint64_t weightL1Size = ConvAlignB(kBL1min * nBL1min * weightDtypeSize, SplitWInfo::BLK_LEN);
uint64_t inputL1Size = 0;
// Width (orgWo) computed from the input width, left/right padding, dilated kernel width and stride.
// NOTE(review): strideW == 0 divides by zero here, and a negative numerator wraps when stored as uint64_t.
uint64_t orgWo = (win + padLeft + padRight - (dilationW * (orgKw - 1) + 1)) / strideW + 1;
// ceil(m0 / orgWo) rows when orgWo is smaller than m0, otherwise a single row.
// NOTE(review): orgWo == 0 takes the first branch and divides by zero.
uint64_t hoAL1min = orgWo < m0 ? (m0 + orgWo - 1) / orgWo : 1;
uint64_t khDilated = (orgKh - 1) * dilationH + 1;
// Conv2DInferHiL1 returns (rows - 1) * stride + dilated kernel size, capped at the input size.
uint64_t hiAL1min = Conv2DInferHiL1(hoAL1min, khDilated, hin, strideH);
uint64_t kAL1min = k0;
uint64_t woAL1min = m0;
uint64_t kwDilated = (orgKw - 1) * dilationW + 1;
// The same helper is reused along the width axis with m0 output columns.
uint64_t wiAL1min = Conv2DInferHiL1(woAL1min, kwDilated, win, strideW);
inputL1Size = ConvAlignB(hiAL1min * wiAL1min * kAL1min * inputDtypeSize, SplitWInfo::BLK_LEN);
// Make sure g_l1Size is populated before comparing against it.
GetL1Size();
// Second limit: the combined minimum L1 load must not exceed the cached L1 size.
uint64_t minL1LoadSize = biasL1Size + inputL1Size + weightL1Size;
if (minL1LoadSize > g_l1Size) {
return true;
}
// Third limit: the result is the negation of CheckL1SizeLimitsDma().
return !CheckL1SizeLimitsDma(inputDtypeSize, biasL1Size, weightDtypeSize, k0);
}
/**
 * Checks that the smallest DMA tile fits into UB and that the smallest L1 load fits into L1.
 *
 * The minimum tile is one output row of BLK_M columns with k0 channels. The UB and L1 sizes come from the
 * caches filled by GetUbSize() and GetL1Size(); the L1 size is only queried when the UB check has passed.
 *
 * @param inputDtypeSize  element size of the input
 * @param biasL1Size      L1 bytes already reserved for the bias
 * @param weightDtypeSize element size of the weight
 * @param k0              channel block size
 * @return false when the minimum input tile exceeds the UB size or the minimum L1 load exceeds the L1 size,
 *         true otherwise
 */
bool CheckL1SizeLimitsDma(uint32_t inputDtypeSize, uint64_t biasL1Size, uint32_t weightDtypeSize, int64_t k0)
{
    constexpr uint64_t minOutputRows = 1;

    GetUbSize();
    const uint64_t minInputUbBytes =
        ConvAlignB(minOutputRows * SplitWInfo::BLK_M * k0 * inputDtypeSize, SplitWInfo::BLK_LEN);
    if (minInputUbBytes > g_ubSize) {
        OP_LOGD("DMA min ub size not enough: ubSize=%lu, inputUbSizeMin=%lu.", g_ubSize, minInputUbBytes);
        return false;
    }

    // Smallest weight block: k0 x BLK_M elements.
    const uint64_t minWeightN = SplitWInfo::BLK_M;
    const uint64_t minWeightK = k0;
    const uint64_t weightL1Bytes = ConvAlignB(minWeightK * minWeightN * weightDtypeSize, SplitWInfo::BLK_LEN);

    // Smallest input block: one row of BLK_M columns with k0 channels.
    const uint64_t minInputK = k0;
    const uint64_t minOutputCols = SplitWInfo::BLK_M;
    const uint64_t inputL1Bytes =
        ConvAlignB(minOutputRows * minOutputCols * minInputK * inputDtypeSize, SplitWInfo::BLK_LEN);

    const uint64_t requiredL1Bytes = biasL1Size + inputL1Bytes + weightL1Bytes;
    GetL1Size();
    return !(requiredL1Bytes > g_l1Size);
}
/**
 * Computes how many input rows are needed for a number of output rows, capped at the input height.
 *
 * @param inputHoL1 number of output rows
 * @param khDilated dilated kernel height
 * @param hi        input height; upper bound of the result
 * @param strideH   stride along the height axis
 * @return (inputHoL1 - 1) * strideH + khDilated, or hi when that value is larger than hi
 */
uint64_t Conv2DInferHiL1(uint64_t inputHoL1, uint64_t khDilated, uint64_t hi, uint64_t strideH)
{
    const uint64_t requiredRows = (inputHoL1 - 1) * strideH + khDilated;
    return (requiredRows > hi) ? hi : requiredRows;
}
/**
 * Rounds `a` up to the next multiple of `b`.
 *
 * @param a value to align
 * @param b alignment; 0 is accepted and yields 0
 * @return the smallest multiple of b that is not less than a, or 0 when b is 0
 */
uint64_t ConvAlignB(uint64_t a, uint64_t b)
{
    if (b == 0) {
        // Avoid dividing by zero.
        return 0;
    }
    const uint64_t unitCount = (a + b - 1) / b;
    return unitCount * b;
}
}