* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "convolution_backward_checker.h"
using namespace op;
using namespace l0op;
using namespace Ops::NN;
#include "op_api/aclnn_util.h"
#ifdef __cplusplus
extern "C" {
#endif
// Checks that inputTensor's dtype is in the support list for the current NPU architecture.
// Regbase architectures use a list that depends on `transposed`; other architectures use the
// SoC-version list. Returns false (the macro logs the failure) when the dtype is unsupported.
bool CheckDtypeValid(const aclTensor *inputTensor, bool transposed) {
    const auto npuArch = GetCurrentPlatformInfo().GetCurNpuArch();
    if (!Ops::NN::AclnnUtil::IsRegbase(npuArch)) {
        auto dtypeSupportList = GetDtypeSupportListBySocVersion();
        OP_CHECK_DTYPE_NOT_SUPPORT(inputTensor, dtypeSupportList, return false);
        return true;
    }
    auto dtypeSupportList = GetDtypeSupportListBySocVersion4ConvBackward(transposed);
    OP_CHECK_DTYPE_NOT_SUPPORT(inputTensor, dtypeSupportList, return false);
    return true;
}
// Returns true when every element of params is 0. A null array counts as all-zero.
bool CheckParamsValueAllZero(const aclIntArray *params) {
    if (params == nullptr) {
        return true;
    }
    const uint64_t count = params->Size();
    for (uint64_t idx = 0; idx < count; ++idx) {
        if ((*params)[idx] != 0) {
            return false;
        }
    }
    return true;
}
bool CheckFormatValid(const aclTensor *inputTensor, const string &tensorName) {
op::Format inputFormat = inputTensor->GetStorageFormat();
std::string inputFormatStr = g_formatToStrTab[inputFormat];
auto inputDim = inputTensor->GetViewShape().GetDimNum();
if (inputDim == CONV1DINPUTDIM) {
OP_CHECK(inputFormat == op::Format::FORMAT_ND || inputFormat == op::Format::FORMAT_NCL,
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "In 1D scenes, the %s format only supports ND and NCL, but received %s.",
tensorName.c_str(), inputFormatStr.c_str()), return false);
} else if (inputDim == CONV2DINPUTDIM) {
OP_CHECK(inputFormat == op::Format::FORMAT_NCHW,
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "In 2D scenes, the %s format only supports NCHW, but received %s.",
tensorName.c_str(), inputFormatStr.c_str()), return false);
} else if (inputDim == CONV3DINPUTDIM) {
OP_CHECK(inputFormat == op::Format::FORMAT_NCDHW,
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "In 3D scenes, the %s format only supports NCDHW, but received %s.",
tensorName.c_str(), inputFormatStr.c_str()), return false);
} else {
OP_LOGE(
ACLNN_ERR_PARAM_INVALID, "The %s tensor dimension of this API only supports 3~5 dimensions.",
tensorName.c_str());
return false;
}
return true;
}
// Returns true when every element of params meets the lower bound: >= 0 when isPad is set
// (padding-style values), otherwise >= 1. A null array is accepted.
bool CheckParamsValue(const aclIntArray *params, bool isPad) {
    if (params == nullptr) {
        return true;
    }
    const int64_t lowerBound = isPad ? 0 : 1;
    for (uint64_t idx = 0; idx < params->Size(); ++idx) {
        if ((*params)[idx] < lowerBound) {
            return false;
        }
    }
    return true;
}
// Stores the channel axis of `shape` in channelIndex. For NDHWC/NHWC it is the last axis;
// for every other format it is axis 1.
void GetChannleIndex(const op::Shape &shape, const op::Format &format, int64_t &channelIndex) {
    const bool channelsLast = (format == op::Format::FORMAT_NDHWC) || (format == op::Format::FORMAT_NHWC);
    channelIndex = channelsLast ? static_cast<int64_t>(shape.GetDimNum()) - 1 : 1;
}
// For shape axis dimIdx, checks that the padded input extent covers the dilated kernel:
//   in + 2 * pad >= (k - 1) * dilation + 1
// padding/dilation are indexed by dimIdx - 2, i.e. the first two shape axes are skipped.
// Logs and returns false on violation.
bool CheckResolutionGEKernelShape(const op::Shape &inputShape, const op::Shape &weightShape,
                                  const ConvolutionBackwardParams &params, int64_t dimIdx) {
    const int64_t attrIdx = dimIdx - 2;
    const int64_t kernelVal = weightShape[dimIdx];
    const int64_t dilationVal = (*params.dilation)[attrIdx];
    const int64_t inputVal = inputShape.GetDim(dimIdx);
    const int64_t padVal = (*params.padding)[attrIdx];

    const int64_t dilatedKernel = (kernelVal - 1) * dilationVal + 1;
    const int64_t slack = inputVal + padVal * 2 - dilatedKernel;
    OP_CHECK(slack >= 0,
             OP_LOGE(ACLNN_ERR_PARAM_INVALID,
                     "(in_dim(%ld) + pad_dim(%ld) * 2) should >= ((weight_shape(%ld) - 1) * dilation(%ld) + 1)",
                     inputVal, padVal, kernelVal, dilationVal),
             return false);
    return true;
}
// Expected output extent along shape axis dimIdx:
//   (in + 2 * pad - ((k - 1) * dilation + 1)) / stride + 1
// padding/dilation/stride are indexed by dimIdx - 2. Integer division is used.
int64_t GetExpectNum(const op::Shape &inputShape, const op::Shape &weightShape,
                     const ConvolutionBackwardParams &params, int64_t dimIdx) {
    const int64_t attrIdx = dimIdx - 2;
    const int64_t dilatedKernel = (weightShape[dimIdx] - 1) * (*params.dilation)[attrIdx] + 1;
    const int64_t paddedSpan = inputShape.GetDim(dimIdx) + (*params.padding)[attrIdx] * 2;
    return (paddedSpan - dilatedKernel) / (*params.stride)[attrIdx] + 1;
}
// Regbase 3-D path: reads D/H/W extents from inputShape (using inputTensor.input's storage format)
// and from weightShape (using inputTensor.weight's storage format). It then checks each padded
// extent against the dilated kernel and fills expectValue with the expected output D/H/W.
// Returns false as soon as one axis fails the resolution check. That helper logs the failure, and
// expectValue is left untouched in that case.
bool GetExpectValueDHW_95(const ConvolutionBackwardInputTensor &inputTensor, const ConvolutionBackwardParams &params,
                          struct ExpectValue &expectValue, const op::Shape &inputShape, const op::Shape &weightShape) {
    int64_t inD = 0;
    int64_t inH = 0;
    int64_t inW = 0;
    const op::Format inFormat = inputTensor.input->GetStorageFormat();
    GetInputShapeSize(inFormat, inputShape, inD, inH, inW);

    int64_t kD = 0;
    int64_t kH = 0;
    int64_t kW = 0;
    const op::Format kFormat = inputTensor.weight->GetStorageFormat();
    GetWeightShapeSize(kFormat, weightShape, kD, kH, kW);

    // Attribute indices 0/1/2 select the D/H/W entries of padding/dilation/stride.
    if (!CheckResolutionGEKernelShape_95(inD, kD, 0, params)) {
        return false;
    }
    if (!CheckResolutionGEKernelShape_95(inH, kH, 1, params)) {
        return false;
    }
    if (!CheckResolutionGEKernelShape_95(inW, kW, 2, params)) {
        return false;
    }
    expectValue.doExpect = GetExpectNum_95(inD, kD, 0, params);
    expectValue.hoExpect = GetExpectNum_95(inH, kH, 1, params);
    expectValue.woExpect = GetExpectNum_95(inW, kW, 2, params);
    return true;
}
// Reads the D/H/W extents of a 5-dim feature-map shape. NCDHW uses the *NCDHWIdx constants.
// Any other format is read with D at axis 1 and H/W at kHDimNDHWCIdx/kWDimNDHWCIdx, which
// presumably describes NDHWC (as the constant names suggest).
void GetInputShapeSize(const op::Format &format, const op::Shape &shape, int64_t &shapeDVal, int64_t &shapeHVal,
                       int64_t &shapeWVal) {
    if (format != op::Format::FORMAT_NCDHW) {
        shapeDVal = shape.GetDim(1);
        shapeHVal = shape.GetDim(kHDimNDHWCIdx);
        shapeWVal = shape.GetDim(kWDimNDHWCIdx);
        return;
    }
    shapeDVal = shape.GetDim(dDimNCDHWIdx);
    shapeHVal = shape.GetDim(hDimNCDHWIdx);
    shapeWVal = shape.GetDim(wDimNCDHWIdx);
}
bool CheckResolutionGEKernelShape_95(int64_t inputVal, int64_t weightVal, int64_t dimOrder, const ConvolutionBackwardParams ¶ms) {
int64_t filterDimDilation = (weightVal - 1) * (*params.dilation)[dimOrder] + 1;
int64_t dimInput = inputVal + (*params.padding)[dimOrder] * 2 - filterDimDilation;
bool dimInputExpect = dimInput >= 0;
OP_CHECK(dimInputExpect,
OP_LOGE(ACLNN_ERR_PARAM_INVALID,
"(in_dim(%ld) + pad_dim(%ld) * 2) should >= ((weight_shape(%ld) - 1) * dilation(%ld) + 1)",
inputVal, (*params.padding)[dimOrder], weightVal, (*params.dilation)[dimOrder]),
return false);
return true;
}
int64_t GetExpectNum_95(int64_t inputVal, int64_t weightVal, int64_t dimOrder, const ConvolutionBackwardParams ¶ms) {
int64_t filterDimDilation = (weightVal - 1) * (*params.dilation)[dimOrder] + 1;
int64_t dimInput = inputVal + (*params.padding)[dimOrder] * 2 - filterDimDilation;
int64_t dimExpect = dimInput / (*params.stride)[dimOrder] + 1;
return dimExpect;
}
// Renders the elements of `array` as a comma-separated list, e.g. "1,2,3".
// Returns an empty string for nullptr or an empty array.
string AclarrayToString(const aclIntArray *array) {
    string joined;
    if (array == nullptr) {
        return joined;
    }
    for (uint64_t idx = 0; idx < array->Size(); ++idx) {
        if (idx != 0) {
            joined += ",";
        }
        joined += to_string((*array)[idx]);
    }
    return joined;
}
// Reads the kernel D/H/W extents from a 5-dim weight shape. NCDHW uses the *NCDHWIdx constants.
// Any other format is read with D at axis 1 and H/W at kHDimNDHWCIdx/kWDimNDHWCIdx.
void GetWeightShapeSize(const op::Format &weightFormat, const op::Shape &weightShape, int64_t &weightDVal,
                        int64_t &weightHVal, int64_t &weightWVal) {
    if (weightFormat != op::Format::FORMAT_NCDHW) {
        weightDVal = weightShape[1];
        weightHVal = weightShape[kHDimNDHWCIdx];
        weightWVal = weightShape[kWDimNDHWCIdx];
        return;
    }
    weightDVal = weightShape[dDimNCDHWIdx];
    weightHVal = weightShape[hDimNCDHWIdx];
    weightWVal = weightShape[wDimNCDHWIdx];
}
// Empty-tensor fallback: writes zeros into each gradient that params.outputMask requests and
// that is itself non-empty.
//   outputMask[1]: gradWeight <- ZerosLike(Contiguous(weight))
//   outputMask[2]: gradBias   <- zeros of shape {biasSizes[0]}, weight's dtype, ND format
// gradInput (outputMask[0]) is not written here.
// Returns ACLNN_SUCCESS on success. Returns ACLNN_ERR_PARAM_NULLPTR / ACLNN_ERR_PARAM_INVALID when
// gradBias is requested but biasSizes is null / empty. Returns ACLNN_ERR_INNER_NULLPTR when an
// l0op call or tensor allocation yields nullptr.
aclnnStatus CalculateConvolutionBackwardWithEmpty(ConvolutionBackwardInputTensor &inputTensor,
                                                  ConvolutionBackwardOutput &outputTensor,
                                                  ConvolutionBackwardParams &params, aclOpExecutor *executor) {
    if ((*params.outputMask)[1] && outputTensor.gradWeight->Size() != 0) {
        auto weightContiguous = l0op::Contiguous(inputTensor.weight, executor);
        // Previously unchecked: a null result would otherwise be handed to ZerosLike.
        OP_CHECK(weightContiguous != nullptr,
                 OP_LOGE(ACLNN_ERR_INNER_NULLPTR,
                         "The calculation with empty tensor failed, weight with Contiguous return nullptr."),
                 return ACLNN_ERR_INNER_NULLPTR);
        auto gradWeightZeros = l0op::ZerosLike(weightContiguous, executor);
        OP_CHECK(gradWeightZeros != nullptr,
                 OP_LOGE(ACLNN_ERR_INNER_NULLPTR,
                         "The calculation with empty tensor failed, weight with ZerosLike return nullptr."),
                 return ACLNN_ERR_INNER_NULLPTR);
        auto result = l0op::ViewCopy(gradWeightZeros, outputTensor.gradWeight, executor);
        OP_CHECK(result != nullptr,
                 OP_LOGE(ACLNN_ERR_INNER_NULLPTR,
                         "The calculation with empty tensor failed, weight with ViewCopy return nullptr."),
                 return ACLNN_ERR_INNER_NULLPTR);
    }
    if ((*params.outputMask)[2] && outputTensor.gradBias->Size() != 0) {
        // biasSizes is not covered by the null checks on the parameter path, so it is validated
        // here before biasSizes[0] is read.
        OP_CHECK(params.biasSizes != nullptr,
                 OP_LOGE(ACLNN_ERR_PARAM_NULLPTR,
                         "The calculation with empty tensor failed, biasSizes is nullptr."),
                 return ACLNN_ERR_PARAM_NULLPTR);
        OP_CHECK(params.biasSizes->Size() > 0,
                 OP_LOGE(ACLNN_ERR_PARAM_INVALID,
                         "The calculation with empty tensor failed, biasSizes is empty."),
                 return ACLNN_ERR_PARAM_INVALID);
        op::Shape biasGradShape = {(*params.biasSizes)[0]};
        auto biasTensor = executor->AllocTensor(biasGradShape, inputTensor.weight->GetDataType());
        // Previously unchecked: the Set*Format calls below would dereference a null tensor.
        OP_CHECK(biasTensor != nullptr,
                 OP_LOGE(ACLNN_ERR_INNER_NULLPTR,
                         "The calculation with empty tensor failed, bias with AllocTensor return nullptr."),
                 return ACLNN_ERR_INNER_NULLPTR);
        biasTensor->SetStorageFormat(op::Format::FORMAT_ND);
        biasTensor->SetViewFormat(op::Format::FORMAT_ND);
        biasTensor->SetOriginalFormat(op::Format::FORMAT_ND);
        auto gradBiasZeros = l0op::ZerosLike(biasTensor, executor);
        OP_CHECK(gradBiasZeros != nullptr,
                 OP_LOGE(ACLNN_ERR_INNER_NULLPTR,
                         "The calculation with empty tensor failed, bias with ZerosLike return nullptr."),
                 return ACLNN_ERR_INNER_NULLPTR);
        auto result = l0op::ViewCopy(gradBiasZeros, outputTensor.gradBias, executor);
        OP_CHECK(result != nullptr,
                 OP_LOGE(ACLNN_ERR_INNER_NULLPTR,
                         "The calculation with empty tensor failed, bias with ViewCopy return nullptr."),
                 return ACLNN_ERR_INNER_NULLPTR);
    }
    return ACLNN_SUCCESS;
}
namespace Ops {
namespace NN {
namespace Conv {
bool ConvolutionBackwardChecker::CheckDataTypeValidForGradInput() {
if (!Ops::NN::AclnnUtil::IsRegbase()) {
return true;
}
OP_CHECK(outputTensor_.gradInput->GetDataType() == inputTensor_.input->GetDataType(),
OP_LOGE(ACLNN_ERR_INNER_NULLPTR, "gradInput data type[%s] should be equal to input data type[%s]",
op::ToString(outputTensor_.gradInput->GetDataType()).GetString(),
op::ToString(inputTensor_.input->GetDataType()).GetString()), return false);
return true;
}
bool ConvolutionBackwardChecker::CheckDataTypeValidForGradWeight() {
if (!Ops::NN::AclnnUtil::IsRegbase()) {
return true;
}
if (inputTensor_.weight->GetDataType() == DataType::DT_HIFLOAT8 ||
inputTensor_.weight->GetDataType() == DataType::DT_FLOAT8_E4M3FN) {
return true;
}
OP_CHECK(outputTensor_.gradWeight->GetDataType() == inputTensor_.weight->GetDataType(),
OP_LOGE(ACLNN_ERR_INNER_NULLPTR, "gradWeight data type[%s] should be equal to weight data type[%s]",
op::ToString(outputTensor_.gradWeight->GetDataType()).GetString(),
op::ToString(inputTensor_.weight->GetDataType()).GetString()), return false);
return true;
}
bool ConvolutionBackwardChecker::CheckDataTypeValidForGradBias() {
if (!Ops::NN::AclnnUtil::IsRegbase()) {
return true;
}
OP_CHECK(outputTensor_.gradBias->GetDataType() == inputTensor_.gradOutput->GetDataType(),
OP_LOGE(ACLNN_ERR_INNER_NULLPTR, "gradBias data type[%s] should be equal to gradOutput data type[%s]",
op::ToString(outputTensor_.gradBias->GetDataType()).GetString(),
op::ToString(inputTensor_.gradOutput->GetDataType()).GetString()), return false);
return true;
}
bool ConvolutionBackwardChecker::CheckParamsValidForBpFilter8bit() {
if (inputTensor_.input->GetDataType() != DataType::DT_HIFLOAT8 &&
inputTensor_.input->GetDataType() != DataType::DT_FLOAT8_E4M3FN) {
return true;
}
if (params_.transposed) {
OP_CHECK(CheckParamsValueAllZero(params_.outputPadding),
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "When transpose is true and the input data type is %s or %s, "
"the value of outputPadding[%s] must be all 0",
op::ToString(DataType::DT_HIFLOAT8).GetString(),
op::ToString(DataType::DT_FLOAT8_E4M3FN).GetString(),
AclarrayToString(params_.outputPadding).c_str()), return false);
}
OP_CHECK(params_.groups == 1,
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "When outputMask[1] = True, %s dtype only supports groups = 1, currently is %ld",
op::ToString(inputTensor_.input->GetDataType()).GetString(), params_.groups),
return false);
return true;
}
bool ConvolutionBackwardChecker::InterceptConvFor8bit()
{
if (params_.transposed && (*params_.outputMask)[0]) {
return true;
}
if((outputTensor_.gradInput != nullptr)&& (outputTensor_.gradWeight !=nullptr)){
if (IsConv8bit(inputTensor_.gradOutput->GetDataType()) || IsConv8bit(inputTensor_.input->GetDataType()) ||
IsConv8bit(inputTensor_.weight->GetDataType()) || IsConv8bit(outputTensor_.gradInput->GetDataType()) ||
IsConv8bit(outputTensor_.gradWeight->GetDataType())) {
OP_LOGE(
ACLNN_ERR_PARAM_INVALID,
"The dtype of DT_HIFLOAT8 or DT_FLOAT8_E4M3FN is not supported now, "
"currently gradOutput is %s, input is %s, weight is %s, gradInput is %s, "
"gradWeight is %s. ",
op::ToString(inputTensor_.gradOutput->GetDataType()).GetString(),
op::ToString(inputTensor_.input->GetDataType()).GetString(),
op::ToString(inputTensor_.weight->GetDataType()).GetString(),
op::ToString(outputTensor_.gradInput->GetDataType()).GetString(),
op::ToString(outputTensor_.gradWeight->GetDataType()).GetString());
return false;
}
}
return true;
}
// True for the 8-bit float dtypes handled by this checker: DT_HIFLOAT8 and DT_FLOAT8_E4M3FN.
bool ConvolutionBackwardChecker::IsConv8bit(const DataType& dType) const {
    switch (dType) {
        case DataType::DT_HIFLOAT8:
        case DataType::DT_FLOAT8_E4M3FN:
            return true;
        default:
            return false;
    }
}
// All-or-nothing rule for dType on the gradInput path. If any of gradOutput, input, weight,
// gradInput (and gradBias when provided) has dtype dType, then all of them must have it.
// Logs and returns false on a mixed combination.
bool ConvolutionBackwardChecker::CheckDtypeValidFor8bit(const DataType& dType) {
    const bool gradOutputMatch = inputTensor_.gradOutput->GetDataType() == dType;
    const bool inputMatch = inputTensor_.input->GetDataType() == dType;
    const bool weightMatch = inputTensor_.weight->GetDataType() == dType;
    const bool gradInputMatch = outputTensor_.gradInput->GetDataType() == dType;
    const bool hasGradBias = outputTensor_.gradBias != nullptr;

    bool anyMatch = gradOutputMatch || inputMatch || weightMatch || gradInputMatch;
    bool allMatch = gradOutputMatch && inputMatch && weightMatch && gradInputMatch;
    if (hasGradBias) {
        const bool gradBiasMatch = outputTensor_.gradBias->GetDataType() == dType;
        anyMatch = anyMatch || gradBiasMatch;
        allMatch = allMatch && gradBiasMatch;
    }
    const bool consistent = !anyMatch || allMatch;

    if (!hasGradBias) {
        OP_CHECK(consistent,
                 OP_LOGE(ACLNN_ERR_PARAM_INVALID, "When any input or output data types is %s, all of them must be %s, "
                         "currently gradOutput is %s, input is %s, weight is %s, gradInput is %s",
                         op::ToString(dType).GetString(), op::ToString(dType).GetString(),
                         op::ToString(inputTensor_.gradOutput->GetDataType()).GetString(),
                         op::ToString(inputTensor_.input->GetDataType()).GetString(),
                         op::ToString(inputTensor_.weight->GetDataType()).GetString(),
                         op::ToString(outputTensor_.gradInput->GetDataType()).GetString()),
                 return false);
        return true;
    }
    OP_CHECK(consistent,
             OP_LOGE(ACLNN_ERR_PARAM_INVALID, "When any input or output data types is %s, all of them must be %s, "
                     "currently gradOutput is %s, input is %s, weight is %s, gradInput is %s, gradBias is %s",
                     op::ToString(dType).GetString(), op::ToString(dType).GetString(),
                     op::ToString(inputTensor_.gradOutput->GetDataType()).GetString(),
                     op::ToString(inputTensor_.input->GetDataType()).GetString(),
                     op::ToString(inputTensor_.weight->GetDataType()).GetString(),
                     op::ToString(outputTensor_.gradInput->GetDataType()).GetString(),
                     op::ToString(outputTensor_.gradBias->GetDataType()).GetString()),
             return false);
    return true;
}
bool ConvolutionBackwardChecker::CheckDtypeValidForBpFilter8bit(const DataType& dType) {
bool isGradOutput8bit = inputTensor_.gradOutput->GetDataType() == dType;
bool isInput8bit = inputTensor_.input->GetDataType() == dType;
bool isWeight8bit = inputTensor_.weight->GetDataType() == dType;
bool isGradWeight32bit = outputTensor_.gradWeight->GetDataType() == DataType::DT_FLOAT;
bool is8bitFlag = isGradOutput8bit || isInput8bit || isWeight8bit;
bool all8bitFlag = isGradOutput8bit && isInput8bit && isWeight8bit && isGradWeight32bit;
if (outputTensor_.gradBias != nullptr) {
bool isGradBias8bit = outputTensor_.gradBias->GetDataType() == dType;
is8bitFlag = is8bitFlag || isGradBias8bit;
all8bitFlag = all8bitFlag && isGradBias8bit;
}
if (outputTensor_.gradBias != nullptr) {
OP_CHECK(!is8bitFlag || all8bitFlag,
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "When outputMask[1] = true, any input or gradBias data types is %s, "
"all input and gradBias data types must be %s, and gradWeight data type must be %s, "
"currently gradOutput is %s, input is %s, weight is %s, gradWeight is %s, gradBias is %s",
op::ToString(dType).GetString(), op::ToString(dType).GetString(), op::ToString(DataType::DT_FLOAT).GetString(),
op::ToString(inputTensor_.gradOutput->GetDataType()).GetString(),
op::ToString(inputTensor_.input->GetDataType()).GetString(),
op::ToString(inputTensor_.weight->GetDataType()).GetString(),
op::ToString(outputTensor_.gradWeight->GetDataType()).GetString(),
op::ToString(outputTensor_.gradBias->GetDataType()).GetString()),
return false);
} else {
OP_CHECK(!is8bitFlag || all8bitFlag,
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "When outputMask[1] = true, any input data types is %s, "
"all input data types must be %s, and gradWeight data type must be %s, "
"currently gradOutput is %s, input is %s, weight is %s, gradWeight is %s",
op::ToString(dType).GetString(), op::ToString(dType).GetString(), op::ToString(DataType::DT_FLOAT).GetString(),
op::ToString(inputTensor_.gradOutput->GetDataType()).GetString(),
op::ToString(inputTensor_.input->GetDataType()).GetString(),
op::ToString(inputTensor_.weight->GetDataType()).GetString(),
op::ToString(outputTensor_.gradWeight->GetDataType()).GetString()),
return false);
}
bool isGradOutputFp8 = inputTensor_.gradOutput->GetDataType() == DataType::DT_FLOAT8_E4M3FN;
bool isInputFp8 = inputTensor_.input->GetDataType() == DataType::DT_FLOAT8_E4M3FN;
bool isWeightFp8 = inputTensor_.weight->GetDataType() == DataType::DT_FLOAT8_E4M3FN;
bool isGradWeightFp8 = outputTensor_.gradWeight->GetDataType() == DataType::DT_FLOAT8_E4M3FN;
bool isGradBiasFp8 = false;
if (outputTensor_.gradBias != nullptr) {
isGradBiasFp8 = outputTensor_.gradBias->GetDataType() == DataType::DT_FLOAT8_E4M3FN;
}
bool isFp8Flag = isGradOutputFp8 || isInputFp8 || isWeightFp8 || isGradWeightFp8 || isGradBiasFp8;
OP_CHECK(!isFp8Flag,
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "When outputMask[1] = true, not support dataType of all input and output being DT_FLOAT8_E4M3FN now"), return false);
return true;
}
// Validates convolution attribute values (their lengths are validated in CheckConvShape):
//   stride >= 1, padding >= 0, dilation >= 1 element-wise;
//   transposed: outputPadding >= 0, and for 3-D inputs outputPadding[i] < stride[i];
//   not transposed: every outputPadding entry must be 0;
//   groups >= 1.
// inputDim is the view rank of input.
bool ConvolutionBackwardChecker::CheckConvParams(size_t inputDim) {
    OP_CHECK(CheckParamsValue(params_.stride, false),
             OP_LOGE(ACLNN_ERR_PARAM_INVALID, "The value of stride[%s] must be greater than or equal to 1.",
                     AclarrayToString(params_.stride).c_str()),
             return false);
    OP_CHECK(CheckParamsValue(params_.padding, true),
             OP_LOGE(ACLNN_ERR_PARAM_INVALID, "The value of padding[%s] must be greater than or equal to 0.",
                     AclarrayToString(params_.padding).c_str()),
             return false);
    OP_CHECK(CheckParamsValue(params_.dilation, false),
             OP_LOGE(ACLNN_ERR_PARAM_INVALID, "The value of dilation[%s] must be greater than or equal to 1.",
                     AclarrayToString(params_.dilation).c_str()),
             return false);
    if (params_.transposed) {
        OP_CHECK(CheckParamsValue(params_.outputPadding, true),
                 OP_LOGE(ACLNN_ERR_PARAM_INVALID,
                         "The value of outputPadding[%s] must be greater than or equal to 0 if transposed",
                         AclarrayToString(params_.outputPadding).c_str()),
                 return false);
        if (inputDim == CONV3DINPUTDIM) {
            // Array lengths are not validated yet at this point, so the loop is bounded by both
            // lengths. This avoids reading past the end of stride when outputPadding is longer.
            const uint64_t outputPaddingSize = params_.outputPadding->Size();
            const uint64_t strideSize = params_.stride->Size();
            for (uint64_t i = 0; i < outputPaddingSize && i < strideSize; ++i) {
                OP_CHECK((*params_.outputPadding)[i] < (*params_.stride)[i],
                         OP_LOGE(ACLNN_ERR_PARAM_INVALID,
                                 "The value of outputPadding[%s] should be smaller than stride[%s]",
                                 AclarrayToString(params_.outputPadding).c_str(),
                                 AclarrayToString(params_.stride).c_str()),
                         return false);
            }
        }
    } else if (params_.outputPadding != nullptr) {
        for (uint64_t i = 0; i < params_.outputPadding->Size(); ++i) {
            OP_CHECK((*params_.outputPadding)[i] == 0,
                     OP_LOGE(ACLNN_ERR_PARAM_INVALID, "The value of outputPadding[%s] must be 0 if not transposed",
                             AclarrayToString(params_.outputPadding).c_str()),
                     return false);
        }
    }
    // groups is widened to int64_t so the format specifier matches regardless of its declared width.
    OP_CHECK(params_.groups >= 1,
             OP_LOGE(ACLNN_ERR_PARAM_INVALID, "The group[%ld] must be greater than or equal to 1.",
                     static_cast<int64_t>(params_.groups)),
             return false);
    return true;
}
// Dtype obtained by promoting gradOutput with input, then promoting the result with weight.
inline DataType ConvolutionBackwardChecker::CalcPromoteType() {
    const DataType gradOutputAndInput =
        op::PromoteType(inputTensor_.gradOutput->GetDataType(), inputTensor_.input->GetDataType());
    return op::PromoteType(gradOutputAndInput, inputTensor_.weight->GetDataType());
}
bool ConvolutionBackwardChecker::CheckCubeMathTypeConvBackward() {
auto promoteType = CalcPromoteType();
return CheckCubeMathType(promoteType, params_.cubeMathType);
}
// Shape/length consistency checks:
//   - gradOutput, input and weight have the same rank; gradOutput and weight are non-empty;
//   - gradInput/gradWeight (when provided) match input/weight shapes;
//   - gradBias (when provided) is 1-D with length equal to gradOutput's channel count;
//   - stride/dilation have rank - 2 entries; padding/outputPadding have rank - 2 entries,
//     or CONV2DINPUTDIM entries when rank == CONV2DINPUTDIM.
bool ConvolutionBackwardChecker::CheckConvShape() {
    const auto gradOutputDim = inputTensor_.gradOutput->GetViewShape().GetDimNum();
    const auto inputDim = inputTensor_.input->GetViewShape().GetDimNum();
    const auto weightDim = inputTensor_.weight->GetViewShape().GetDimNum();

    OP_CHECK(gradOutputDim == inputDim,
             OP_LOGE(ACLNN_ERR_PARAM_INVALID, "The dim of gradOutput and input should be equal."), return false);
    OP_CHECK(inputDim == weightDim,
             OP_LOGE(ACLNN_ERR_PARAM_INVALID, "The dim of input and weight should be equal."), return false);
    const bool bothNonEmpty = inputTensor_.gradOutput->Size() != 0 && inputTensor_.weight->Size() != 0;
    OP_CHECK(bothNonEmpty,
             OP_LOGE(ACLNN_ERR_PARAM_INVALID, "The gradOutput and weight cannot be empty tensor."), return false);

    if (outputTensor_.gradInput != nullptr) {
        OP_CHECK_SHAPE_NOT_EQUAL(inputTensor_.input, outputTensor_.gradInput, return false);
    }
    if (outputTensor_.gradWeight != nullptr) {
        OP_CHECK_SHAPE_NOT_EQUAL(inputTensor_.weight, outputTensor_.gradWeight, return false);
    }

    int64_t channelOutIdx = 0;
    GetChannleIndex(inputTensor_.gradOutput->GetViewShape(), inputTensor_.gradOutput->GetStorageFormat(),
                    channelOutIdx);
    const int64_t cOut = inputTensor_.gradOutput->GetViewShape().GetDim(channelOutIdx);
    if (outputTensor_.gradBias != nullptr) {
        OP_CHECK(outputTensor_.gradBias->GetViewShape().GetDimNum() == 1,
                 OP_LOGE(ACLNN_ERR_PARAM_INVALID, "The dimension of gradBias only support 1."), return false);
        OP_CHECK(outputTensor_.gradBias->GetViewShape().GetDim(0) == cOut,
                 OP_LOGE(ACLNN_ERR_PARAM_INVALID, "The gradBias shape should be equal [%ld].", cOut), return false);
    }

    // One attribute entry per spatial axis (rank minus the two leading axes).
    const int64_t paramsDim = inputDim - 2;
    const int64_t strideDim = params_.stride->Size();
    OP_CHECK(strideDim == paramsDim,
             OP_LOGE(ACLNN_ERR_PARAM_INVALID, "When the input dimension is %ld, the dimension of stride should be %ld.",
                     inputDim, paramsDim),
             return false);
    const int64_t dilationDim = params_.dilation->Size();
    OP_CHECK(dilationDim == paramsDim,
             OP_LOGE(ACLNN_ERR_PARAM_INVALID,
                     "When the input dimension is %ld, the dimension of dilation should be %ld.", inputDim, paramsDim),
             return false);

    const int64_t paddingDim = params_.padding->Size();
    const int64_t outputPaddingDim = params_.outputPadding->Size();
    if (inputDim != CONV2DINPUTDIM) {
        OP_CHECK(paddingDim == paramsDim,
                 OP_LOGE(ACLNN_ERR_PARAM_INVALID,
                         "When the input dimension is %ld, the dimension of padding should be %ld.", inputDim,
                         paramsDim),
                 return false);
        OP_CHECK(outputPaddingDim == paramsDim,
                 OP_LOGE(ACLNN_ERR_PARAM_INVALID,
                         "When the input dimension is %ld, the dimension of outputPadding should be %ld.", inputDim,
                         paramsDim),
                 return false);
        return true;
    }
    // 2-D case: padding/outputPadding may alternatively carry CONV2DINPUTDIM entries.
    OP_CHECK(paddingDim == paramsDim || paddingDim == CONV2DINPUTDIM,
             OP_LOGE(ACLNN_ERR_PARAM_INVALID,
                     "When the input dimension is %ld, the dimension of padding should be %ld or %d.", inputDim,
                     paramsDim, CONV2DINPUTDIM),
             return false);
    OP_CHECK(outputPaddingDim == paramsDim || outputPaddingDim == CONV2DINPUTDIM,
             OP_LOGE(ACLNN_ERR_PARAM_INVALID,
                     "When the input dimension is %ld, the dimension of outputPadding should be %ld or %d.",
                     inputDim, paramsDim, CONV2DINPUTDIM),
             return false);
    return true;
}
bool ConvolutionBackwardChecker::CheckConvChannelAndGroup() {
op::Shape inputShape = params_.transposed ? inputTensor_.gradOutput->GetViewShape() :
inputTensor_.input->GetViewShape();
op::Shape weightShape = inputTensor_.weight->GetViewShape();
op::Shape gradOutShape = params_.transposed ? inputTensor_.input->GetViewShape() :
inputTensor_.gradOutput->GetViewShape();
int64_t inputChannelIdx = 1;
int64_t gradOutputChannelIdx = 1;
int64_t weightCoutIdx = 0;
int64_t weightCinIdx = 1;
if (Ops::NN::AclnnUtil::IsRegbase()) {
op::Format inputFormat = inputTensor_.input->GetStorageFormat();
GetChannleIndex(inputShape, inputFormat, inputChannelIdx);
op::Format gradOutputFormat = inputTensor_.gradOutput->GetStorageFormat();
GetChannleIndex(gradOutShape, gradOutputFormat, gradOutputChannelIdx);
op::Format weightFormat = inputTensor_.weight->GetStorageFormat();
GetChannleIndex(weightShape, weightFormat, weightCinIdx);
}
OP_CHECK(gradOutShape.GetDim(gradOutputChannelIdx) == weightShape.GetDim(weightCoutIdx),
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "gradOutput_channel(%ld) != weight_N_dim(%ld)",
gradOutShape.GetDim(gradOutputChannelIdx), weightShape.GetDim(weightCoutIdx)),
return false);
bool channelCheck = weightShape.GetDim(weightCinIdx) == 0 ||
inputShape.GetDim(inputChannelIdx) % weightShape.GetDim(weightCinIdx) != 0;
OP_CHECK(!channelCheck,
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "input_channel(%ld) %% weight_channel(%ld) != 0",
inputShape.GetDim(inputChannelIdx), weightShape.GetDim(weightCinIdx)),
return false);
int32_t groups = inputShape.GetDim(inputChannelIdx) / weightShape.GetDim(weightCinIdx);
bool groupCheck = groups == params_.groups;
OP_CHECK(groupCheck,
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "input_channel(%ld) / weight_channel(%ld) != groups(%ld)",
inputShape.GetDim(inputChannelIdx), weightShape.GetDim(weightCinIdx), params_.groups),
return false);
if (inputShape.GetDim(inputChannelIdx) == params_.groups && (!Ops::NN::AclnnUtil::IsRegbase())){
auto outChannel = gradOutShape.GetDim(inputChannelIdx);
OP_CHECK(outChannel >= params_.groups,
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "when input_channel(%ld) == groups(%ld), output_channel(%ld) need bigger groups",
inputShape.GetDim(inputChannelIdx), params_.groups, outChannel),
return false);
OP_CHECK(gradOutShape.GetDim(inputChannelIdx) % params_.groups == 0,
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "when input_channel(%ld) == groups(%ld), output_channel(%ld) need k times input_channel",
inputShape.GetDim(inputChannelIdx), params_.groups, outChannel),
return false);
}
return true;
}
bool ConvolutionBackwardChecker::CheckConvShapePlus() {
op::Shape inputShape = params_.transposed ? inputTensor_.gradOutput->GetViewShape() :
inputTensor_.input->GetViewShape();
op::Shape weightShape = inputTensor_.weight->GetViewShape();
op::Shape gradOutShape = params_.transposed ? inputTensor_.input->GetViewShape() :
inputTensor_.gradOutput->GetViewShape();
auto inputDim = inputShape.GetDimNum();
bool expectCheck = false;
struct ExpectValue expectValue = {};
int64_t gradOutputCout = 0;
if (inputDim == CONV3DINPUTDIM) {
if (!Ops::NN::AclnnUtil::IsRegbase()) {
int64_t depthIdx = 2;
int64_t heightIdx = 3;
int64_t widthIdx = 4;
gradOutputCout = gradOutShape.GetDim(1);
if (!CheckResolutionGEKernelShape(inputShape, weightShape, params_, depthIdx) ||
!CheckResolutionGEKernelShape(inputShape, weightShape, params_, heightIdx) ||
!CheckResolutionGEKernelShape(inputShape, weightShape, params_, widthIdx)) {
return false;
}
expectValue.doExpect = GetExpectNum(inputShape, weightShape, params_, depthIdx);
expectValue.hoExpect = GetExpectNum(inputShape, weightShape, params_, heightIdx);
expectValue.woExpect = GetExpectNum(inputShape, weightShape, params_, widthIdx);
expectCheck = expectValue.doExpect == gradOutShape.GetDim(depthIdx) &&
expectValue.hoExpect == gradOutShape.GetDim(heightIdx) &&
expectValue.woExpect == gradOutShape.GetDim(widthIdx);
} else {
OP_CHECK(GetExpectValueDHW_95(inputTensor_, params_, expectValue, inputShape, weightShape),
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "GetExpectValueDHW_95 failed."), return false);
int64_t gradOutputDVal = 0;
int64_t gradOutputHVal = 0;
int64_t gradOutputWVal = 0;
op::Format gradOutputFormat = inputTensor_.gradOutput->GetStorageFormat();
GetInputShapeSize(gradOutputFormat, gradOutShape, gradOutputDVal, gradOutputHVal, gradOutputWVal);
int64_t coutChannelIdx = 0;
GetChannleIndex(gradOutShape, gradOutputFormat, coutChannelIdx);
gradOutputCout = gradOutShape.GetDim(coutChannelIdx);;
expectCheck = (expectValue.doExpect == gradOutputDVal) && (expectValue.hoExpect == gradOutputHVal) &&
(expectValue.woExpect == gradOutputWVal);
}
OP_CHECK(expectCheck, OP_LOGE(ACLNN_ERR_PARAM_INVALID,
"gradOutput's shape%s is not equal with inferred shape[%ld,%ld,%ld,%ld,%ld]",
op::ToString(gradOutShape).GetString(), gradOutShape.GetDim(0), gradOutputCout,
expectValue.doExpect, expectValue.hoExpect, expectValue.woExpect), return false);
}
return true;
}
// Verifies that all mandatory inputs and attributes are non-null and that outputMask has exactly
// 3 entries. Each output selected by outputMask ([0] gradInput, [1] gradWeight, [2] gradBias)
// must also be non-null.
inline bool ConvolutionBackwardChecker::CheckNotNull() {
    // Required input tensors.
    OP_CHECK_NULL(inputTensor_.gradOutput, return false);
    OP_CHECK_NULL(inputTensor_.input, return false);
    OP_CHECK_NULL(inputTensor_.weight, return false);
    // Required convolution attributes.
    OP_CHECK_NULL(params_.stride, return false);
    OP_CHECK_NULL(params_.padding, return false);
    OP_CHECK_NULL(params_.dilation, return false);
    OP_CHECK_NULL(params_.outputPadding, return false);
    OP_CHECK_NULL(params_.outputMask, return false);

    const int64_t maskLength = params_.outputMask->Size();
    OP_CHECK(maskLength == 3, OP_LOGE(ACLNN_ERR_PARAM_INVALID, "The dim of outputMask must be equal 3."),
             return false);

    // An output tensor is only required when its mask entry is set.
    const bool wantGradInput = (*params_.outputMask)[0];
    const bool wantGradWeight = (*params_.outputMask)[1];
    const bool wantGradBias = (*params_.outputMask)[2];
    if (wantGradInput) {
        OP_CHECK_NULL(outputTensor_.gradInput, return false);
    }
    if (wantGradWeight) {
        OP_CHECK_NULL(outputTensor_.gradWeight, return false);
    }
    if (wantGradBias) {
        OP_CHECK_NULL(outputTensor_.gradBias, return false);
    }
    return true;
}
// 8-bit float dtype checks. These run only on regbase architectures, once for each requested gradient:
//   outputMask[0] (gradInput): interception plus all-or-nothing checks for DT_HIFLOAT8 and DT_FLOAT8_E4M3FN;
//   outputMask[1] (gradWeight): interception, DT_HIFLOAT8 dtype rules and 8-bit attribute constraints.
// Returns ACLNN_ERR_PARAM_INVALID on the first failure, otherwise ACLNN_SUCCESS.
aclnnStatus ConvolutionBackwardChecker::CheckParamsFor8Bit()
{
    const bool needGradInput = (*params_.outputMask)[0];
    const bool needGradWeight = (*params_.outputMask)[1];
    if ((!needGradInput && !needGradWeight) || !Ops::NN::AclnnUtil::IsRegbase()) {
        return ACLNN_SUCCESS;
    }
    if (needGradInput) {
        CHECK_RET(InterceptConvFor8bit(), ACLNN_ERR_PARAM_INVALID);
        CHECK_RET(CheckDtypeValidFor8bit(DataType::DT_HIFLOAT8), ACLNN_ERR_PARAM_INVALID);
        CHECK_RET(CheckDtypeValidFor8bit(DataType::DT_FLOAT8_E4M3FN), ACLNN_ERR_PARAM_INVALID);
    }
    if (needGradWeight) {
        CHECK_RET(InterceptConvFor8bit(), ACLNN_ERR_PARAM_INVALID);
        CHECK_RET(CheckDtypeValidForBpFilter8bit(DataType::DT_HIFLOAT8), ACLNN_ERR_PARAM_INVALID);
        CHECK_RET(CheckParamsValidForBpFilter8bit(), ACLNN_ERR_PARAM_INVALID);
    }
    return ACLNN_SUCCESS;
}
// Top-level parameter validation for convolution backward. It runs, in order:
//   null checks, dtype checks (inputs and each provided output), 8-bit rules, storage formats,
//   gradBias ND-format rule, attribute values and cubeMathType, then shape/channel/group checks.
// On DAV_2201 or regbase architectures, an empty input or weight returns early, before the shape
// checks. Returns ACLNN_ERR_PARAM_NULLPTR for missing pointers, ACLNN_ERR_PARAM_INVALID for any
// other failure, and ACLNN_SUCCESS otherwise.
aclnnStatus ConvolutionBackwardChecker::CheckParams() {
    CHECK_RET(CheckNotNull(), ACLNN_ERR_PARAM_NULLPTR);
    CHECK_RET(CheckDtypeValid(inputTensor_.gradOutput, params_.transposed), ACLNN_ERR_PARAM_INVALID);
    CHECK_RET(CheckDtypeValid(inputTensor_.input, params_.transposed), ACLNN_ERR_PARAM_INVALID);
    CHECK_RET(CheckDtypeValid(inputTensor_.weight, params_.transposed), ACLNN_ERR_PARAM_INVALID);
    if (outputTensor_.gradInput != nullptr) {
        CHECK_RET(CheckDtypeValid(outputTensor_.gradInput, params_.transposed), ACLNN_ERR_PARAM_INVALID);
        CHECK_RET(CheckDataTypeValidForGradInput(), ACLNN_ERR_PARAM_INVALID);
    }
    if (outputTensor_.gradWeight != nullptr) {
        CHECK_RET(CheckDtypeValid(outputTensor_.gradWeight, params_.transposed), ACLNN_ERR_PARAM_INVALID);
        CHECK_RET(CheckDataTypeValidForGradWeight(), ACLNN_ERR_PARAM_INVALID);
    }
    if (outputTensor_.gradBias != nullptr) {
        CHECK_RET(CheckDtypeValid(outputTensor_.gradBias, params_.transposed), ACLNN_ERR_PARAM_INVALID);
        CHECK_RET(CheckDataTypeValidForGradBias(), ACLNN_ERR_PARAM_INVALID);
    }
    CHECK_RET(CheckParamsFor8Bit() == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID);
    CHECK_RET(CheckFormatValid(inputTensor_.gradOutput, "gradOutput"), ACLNN_ERR_PARAM_INVALID);
    CHECK_RET(CheckFormatValid(inputTensor_.input, "input"), ACLNN_ERR_PARAM_INVALID);
    CHECK_RET(CheckFormatValid(inputTensor_.weight, "weight"), ACLNN_ERR_PARAM_INVALID);
    if (outputTensor_.gradInput != nullptr) {
        CHECK_RET(CheckFormatValid(outputTensor_.gradInput, "gradInput"), ACLNN_ERR_PARAM_INVALID);
    }
    if (outputTensor_.gradWeight != nullptr) {
        CHECK_RET(CheckFormatValid(outputTensor_.gradWeight, "gradWeight"), ACLNN_ERR_PARAM_INVALID);
    }
    auto curArch = GetCurrentPlatformInfo().GetCurNpuArch();
    const bool isDav2201OrRegbase = curArch == NpuArch::DAV_2201 || Ops::NN::AclnnUtil::IsRegbase(curArch);
    if (!params_.transposed && outputTensor_.gradBias != nullptr && isDav2201OrRegbase) {
        // A wrong format is a parameter error, so it is logged as ACLNN_ERR_PARAM_INVALID
        // (previously mislabeled ACLNN_ERR_INNER_NULLPTR).
        OP_CHECK(outputTensor_.gradBias->GetStorageFormat() == op::Format::FORMAT_ND,
                 OP_LOGE(ACLNN_ERR_PARAM_INVALID, "gradBias format only support ND, but get [%s].",
                         op::ToString(outputTensor_.gradBias->GetStorageFormat()).GetString()),
                 return ACLNN_ERR_PARAM_INVALID);
    }
    CHECK_RET(CheckConvParams(inputTensor_.input->GetViewShape().GetDimNum()), ACLNN_ERR_PARAM_INVALID);
    CHECK_RET(CheckCubeMathTypeConvBackward(), ACLNN_ERR_PARAM_INVALID);
    // Empty input/weight skips the shape checks on these architectures.
    if ((inputTensor_.input->Size() == 0 || inputTensor_.weight->Size() == 0) && isDav2201OrRegbase) {
        return ACLNN_SUCCESS;
    }
    CHECK_RET(CheckConvShape(), ACLNN_ERR_PARAM_INVALID);
    CHECK_RET(CheckConvChannelAndGroup(), ACLNN_ERR_PARAM_INVALID);
    CHECK_RET(CheckConvShapePlus(), ACLNN_ERR_PARAM_INVALID);
    return ACLNN_SUCCESS;
}
bool ConvolutionBackwardChecker::CheckParamsDim() {
auto inputDim = inputTensor_.input->GetViewShape().GetDimNum();
uint64_t paramsDim = inputDim - 2;
auto validDim = [inputDim, paramsDim](bool condition, const char* paramName) -> bool {
OP_CHECK(condition,
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "When the input dimension is %ld, the dimension of %s should be %ld.",
inputDim, paramName, paramsDim),
return false);
return true;
};
bool checkRes = validDim(params_.stride->Size() == paramsDim, "stride") &&
validDim(params_.dilation->Size() == paramsDim, "dilation") &&
validDim(params_.padding->Size() == paramsDim ||
(inputDim == CONV2DINPUTDIM && params_.padding->Size() == inputDim), "padding") &&
validDim(params_.outputPadding->Size() == paramsDim ||
(inputDim == CONV2DINPUTDIM && params_.outputPadding->Size() == inputDim), "outputPadding");
return checkRes;
}
bool ConvolutionBackwardChecker::CheckParamsGroup() {
op::Shape inputShape = params_.transposed ? inputTensor_.gradOutput->GetViewShape() : inputTensor_.input->GetViewShape();
op::Shape weightShape = inputTensor_.weight->GetViewShape();
op::Shape gradOutShape = params_.transposed ? inputTensor_.input->GetViewShape() : inputTensor_.gradOutput->GetViewShape();
auto inputChannel = inputShape.GetDim(1);
auto weightCin = weightShape.GetDim(1);
int64_t weightCout = weightShape.GetDim(0);
OP_CHECK((params_.groups != 0 && weightCout % params_.groups == 0),
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "output_channel(%ld) %% groups(%ld) != 0",
weightCout, params_.groups),
return false);
if (params_.transposed && gradOutShape[0] == 0) {
return true;
}
if (inputChannel != 0 || weightCin != 0) {
bool channelCheck = (weightCin == 0 || inputChannel % weightCin != 0 || inputChannel / weightCin != params_.groups);
OP_CHECK(!channelCheck,
OP_LOGE(ACLNN_ERR_PARAM_INVALID, "%s(%ld) need groups(%ld) times weight_channel(%ld)",
params_.transposed ? "gradOutput_channel": "inputChannel",
inputChannel, params_.groups, weightCin),
return false);
}
return true;
}
/**
 * Non-transposed shape check: for every spatial axis i (i >= 2), the padded input extent
 * must be at least the dilated kernel extent:
 *
 *     input[i] + 2 * padding[i - 2] >= (weight[i] - 1) * dilation[i - 2] + 1
 *
 * Assumes padding/dilation already hold at least (rank - 2) entries. CheckShapeEmpty runs
 * CheckParamsDim before this function.
 *
 * @return true if every spatial axis satisfies the inequality; false (after logging) otherwise.
 */
bool ConvolutionBackwardChecker::CheckShape() {
    op::Shape inputShape = inputTensor_.input->GetViewShape();
    op::Shape weightShape = inputTensor_.weight->GetViewShape();
    size_t inputDim = inputShape.GetDimNum();
    // Axes 0 and 1 are batch and channel; attribute index = axis - 2.
    for (size_t i = 2; i < inputDim; i++) {
        auto index = i - 2;
        int64_t filterDimDilation = (weightShape[i] - 1) * (*params_.dilation)[index] + 1;
        int64_t dimInput = inputShape.GetDim(i) + (*params_.padding)[index] * 2 - filterDimDilation;
        // Fixes the "dimendion" typo and "should >=" grammar in the message, and casts the
        // size_t axis index to match %ld.
        OP_CHECK(dimInput >= 0,
                 OP_LOGE(ACLNN_ERR_PARAM_INVALID,
                         "At dimension %ld, (in_dim(%ld) + pad_dim(%ld) * 2) should be >= ((weight_shape(%ld) - 1) * dilation(%ld) + 1)",
                         static_cast<int64_t>(i), inputShape.GetDim(i), (*params_.padding)[index], weightShape[i],
                         (*params_.dilation)[index]),
                 return false);
    }
    return true;
}
/**
 * Transposed shape check: weight dim 0 must be positive and equal to input dim 1.
 *
 * @return true when both conditions hold; false (after logging) otherwise.
 */
bool ConvolutionBackwardChecker::CheckShapeTransposed() {
    const int64_t expectedCout = inputTensor_.input->GetViewShape().GetDim(1);
    const int64_t actualCout = inputTensor_.weight->GetViewShape().GetDim(0);

    OP_CHECK(actualCout > 0,
             OP_LOGE(ACLNN_ERR_PARAM_INVALID, "weight Cout should > 0."),
             return false);
    OP_CHECK(actualCout == expectedCout,
             OP_LOGE(ACLNN_ERR_PARAM_INVALID, "weight Cout should be %ld, but get %ld", expectedCout, actualCout),
             return false);
    return true;
}
/**
 * Shape validation used on the empty-tensor path.
 *
 * First checks attribute lengths (CheckParamsDim) and group consistency (CheckParamsGroup).
 * It then dispatches to the transposed or non-transposed shape check.
 *
 * @return true only if every applicable check passes.
 */
bool ConvolutionBackwardChecker::CheckShapeEmpty() {
    const bool attributesOk = CheckParamsDim() && CheckParamsGroup();
    if (!attributesOk) {
        return false;
    }
    return params_.transposed ? CheckShapeTransposed() : CheckShape();
}
/**
 * Validation for the empty-tensor computation path.
 *
 * - outputMask[2] set: biasSizes must be non-null with exactly one element. That element
 *   must equal gradOutput dim 1 (non-transposed) or weight dim 1 * groups (transposed).
 * - outputMask[0] set: input must have rank > 2, and its batch (dim 0) or channel (dim 1)
 *   must be zero.
 * - On DAV_2201 / regbase architectures, CheckShapeEmpty() is additionally applied.
 *
 * @return true if all applicable checks pass; false (after logging) otherwise.
 */
bool ConvolutionBackwardChecker::CheckEmptyTensor() {
    if ((*params_.outputMask)[2]) {
        OP_CHECK(params_.biasSizes != nullptr,
                 OP_LOGE(ACLNN_ERR_PARAM_INVALID, "The biasSizes cannot be nullptr with empty tensor calculation."),
                 return false);
        OP_CHECK(params_.biasSizes->Size() == 1,
                 OP_LOGE(ACLNN_ERR_PARAM_INVALID, "The biasSizes size must be 1, actually is %ld.",
                         params_.biasSizes->Size()),
                 return false);
        const auto requestedBias = (*params_.biasSizes)[0];
        if (params_.transposed) {
            auto expectedBias = inputTensor_.weight->GetViewShape().GetDim(1) * params_.groups;
            OP_CHECK(requestedBias == expectedBias,
                     OP_LOGE(ACLNN_ERR_PARAM_INVALID, "When transposed=true, the biasSizes should be equal %ld.",
                             expectedBias),
                     return false);
        } else {
            constexpr int64_t kChannelAxis = 1;
            int64_t gradOutputChannel = inputTensor_.gradOutput->GetViewShape().GetDim(kChannelAxis);
            OP_CHECK(requestedBias == gradOutputChannel,
                     OP_LOGE(ACLNN_ERR_PARAM_INVALID, "The biasSizes should be equal %ld.", gradOutputChannel),
                     return false);
        }
    }

    if ((*params_.outputMask)[0]) {
        auto inShape = inputTensor_.input->GetViewShape();
        OP_CHECK(inShape.GetDimNum() > 2,
                 OP_LOGE(ACLNN_ERR_PARAM_INVALID, "The input shape must be greater than 2, but now is %ld",
                         inShape.GetDimNum()),
                 return false);
        const bool zeroBatchOrChannel = inShape[0] == 0 || inShape[1] == 0;
        OP_CHECK(zeroBatchOrChannel,
                 OP_LOGE(ACLNN_ERR_PARAM_INVALID,
                         "When the input tensors contain an empty tensor, aclnnConvolutionBackward only support zero batch or zero channel with input, but got input shape is %s",
                         op::ToString(inShape).GetString()),
                 return false);
    }

    auto arch = GetCurrentPlatformInfo().GetCurNpuArch();
    const bool needsShapeCheck = (arch == NpuArch::DAV_2201) || Ops::NN::AclnnUtil::IsRegbase(arch);
    return needsShapeCheck ? CheckShapeEmpty() : true;
}
}
}
}
#ifdef __cplusplus
}
#endif