/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file aclnn_convolution.cpp
 * \brief
 */

#include "aclnn_convolution.h"
#include "convolution_util.h"

#include <map>
#include <memory>
#include <vector>
#include <string>

#include "aclnn/aclnn_base.h"
#include "aclnn_kernels/common/op_error_check.h"
#include "opdev/data_type_utils.h"
#include "opdev/format_utils.h"
#include "opdev/op_dfx.h"
#include "opdev/op_executor.h"
#include "opdev/op_log.h"
#include "op_common/log/log.h"
#include "opdev/platform.h"
#include "opdev/tensor_view_utils.h"

#include "level0/add.h"
#include "level0/broadcast_to.h"
#include "aclnn_kernels/cast.h"
#include "aclnn_kernels/contiguous.h"
#include "convolution.h"
#include "level0/padv3.h"
#include "aclnn_kernels/reshape.h"
#include "level0/squeeze.h"
#include "aclnn_kernels/transdata.h"
#include "aclnn_kernels/transpose.h"
#include "level0/unsqueeze.h"
#include "matmul/common/op_host/op_api/cube_util.h"
#include "../../../../matmul/common/op_host/op_api/matmul_util.h"
#include "../../../../matmul/common/op_host/op_api/batch_matmul_util.h"

using namespace op;
using namespace ge;
using namespace l0op;
using namespace ConvolutionUtil;
using namespace Ops::NN;

namespace op {
// Renders a signed 64-bit integer as decimal text wrapped in a ge::AscendString.
static inline ge::AscendString ToString(const std::int64_t value)
{
    const std::string decimalText = std::to_string(value);
    return ge::AscendString(decimalText.c_str());
}
} // namespace op

namespace op {
// NOTE(review): the meaning of specialStride / specialChannelIndex / SMALL_CHANNEL is not documented here;
// confirm it at their use sites.
static constexpr int64_t specialStride = 63;
static constexpr int64_t specialChannelIndex = 3;
static constexpr int64_t SMALL_CHANNEL = 4;
// Tensor ranks of conv2d (4-D) and conv3d (5-D); CONV2D_SHAPE_SIZE gates the ND path in L0FuncWarperByOpType.
static constexpr int64_t CONV2D_SHAPE_SIZE = 4;
static constexpr int64_t CONV3D_SHAPE_SIZE = 5;
// NOTE(review): despite its name this holds the "constant" mode string — confirm against its users.
static const std::string REFLECTION_MODE = "constant";

// Per the API definition, every dtype that bias may take must be listed here.
static constexpr const std::initializer_list<op::DataType> BIAS_SUPPORT_LIST = {
    op::DataType::DT_FLOAT, op::DataType::DT_FLOAT16, op::DataType::DT_BF16};

// Bias dtypes accepted on ASCEND310P (no BF16); GetBiasDtypeSupportListBySocVersion chooses between the lists.
static constexpr const std::initializer_list<op::DataType> BIAS_SUPPORT_LIST_ASCEND310P = {
    op::DataType::DT_FLOAT, op::DataType::DT_FLOAT16};
} // namespace op
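/**
 * --------------------------------------L0函数注册机制start------------------------------------------------
 * The helpers below register L0 functions into a std::map that each convXX class instantiates.
 * L0FUNCTION is the generic, type-erased L0 function type used as the map's value type; logically it is only a
 * placeholder and must be cast back to the concrete signature before it is called.
 * The XXX_FUNCTION aliases are the concrete function-pointer types of the different L0 categories.
 * GET_FUNC_ID builds a unique ID from the input/output dtype and format; that ID is the map key.
 * REG_L0_FUNCTION registers a function into the map (REG_L0_FUNCTION_BY_OPTYPE keys it by an op-type string).
 * FUNCTION_CALL performs the actual call through the L0 adapter ConvL0Warper
 * (FUNCTION_CALL_BY_OPTYPE goes through L0FuncWarperByOpType instead).
 */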
// Type-erased placeholder stored in the L0 function maps. It is never called as-is: the dispatchers
// reinterpret_cast it back to one of the concrete signatures below before invoking it.
using L0FUNCTION = void (*)();

// Convolution L0 function without the useHf32 flag.
using CONV_FUNCTION = const aclTensor* (*)(const aclTensor* input, const aclTensor* weight, const aclTensor* bias,
                                           const aclIntArray* stride, const aclIntArray* padding,
                                           const aclIntArray* dilation, int groups, aclOpExecutor* executor);

// Convolution L0 function that additionally takes the useHf32 flag.
using CONV_WITHFLAG_FUNCTION = const aclTensor* (*)(const aclTensor* input, const aclTensor* weight,
                                                    const aclTensor* bias, const aclIntArray* stride,
                                                    const aclIntArray* padding, const aclIntArray* dilation, int groups,
                                                    bool useHf32, aclOpExecutor* executor);

// Transposed-convolution L0 function (takes outputPadding) without the useHf32 flag.
using CONVTRANSPOSE_FUNCTION = const aclTensor* (*)(const aclTensor* input, const aclTensor* weight,
                                                    const aclTensor* bias, const aclIntArray* stride,
                                                    const aclIntArray* padding, const aclIntArray* dilation, int groups,
                                                    const aclIntArray* outputPadding, aclOpExecutor* executor);

// Transposed-convolution L0 function (takes outputPadding) with the useHf32 flag.
using CONVTRANSPOSE_WITHFLAG_FUNCTION = const aclTensor* (*)(const aclTensor* input, const aclTensor* weight,
                                                             const aclTensor* bias, const aclIntArray* stride,
                                                             const aclIntArray* padding, const aclIntArray* dilation,
                                                             int groups, const aclIntArray* outputPadding, bool useHf32,
                                                             aclOpExecutor* executor);

// Convolution L0 function taking an explicit output dtype plus the useHf32 flag; this is the signature
// L0FuncWarperByOpType casts to.
using CONV_WITHDTYPE_FUNCTION = const aclTensor* (*)(const aclTensor* input, const aclTensor* weight,
                                                     const aclTensor* bias, op::DataType outputDtype,
                                                     const aclIntArray* stride, const aclIntArray* padding,
                                                     const aclIntArray* dilation, int groups, bool useHf32,
                                                     aclOpExecutor* executor);

namespace op {
// Copies a C string into a std::string.
// A nullptr yields an empty string: constructing std::string from nullptr is undefined behaviour
// (libstdc++ throws std::logic_error), so guard it here instead of relying on every caller.
std::string CharToString(const char* a)
{
    return (a == nullptr) ? std::string() : std::string(a);
}
} // namespace op

// Builds the registry key of an L0 function by concatenating the string forms of the input dtype, input
// format, output dtype and output format, in that order (no separator is inserted between the parts).
#define GET_FUNC_ID(inputDtype, inputFormat, outputDtype, outputFormat)                                         \
    (CharToString(op::ToString(inputDtype).GetString()) + CharToString(op::ToString(inputFormat).GetString()) + \
     CharToString(op::ToString(outputDtype).GetString()) + CharToString(op::ToString(outputFormat).GetString()))

// Registers `function` into `map` under the GET_FUNC_ID key of the given dtype/format combination.
// The pointer is stored type-erased as L0FUNCTION and must be cast back to its real signature before use.
// With std::map, emplace keeps the existing entry if the key is already registered.
#define REG_L0_FUNCTION(map, function, inputDtype, inputFormat, outputDtype, outputFormat) \
    ((map).emplace(                                                                        \
        (GET_FUNC_ID((inputDtype), (inputFormat), (outputDtype), (outputFormat))), (L0FUNCTION(&(function)))))

namespace op {

// Looks up the L0 function registered for opInfo's (input dtype, input format, output dtype, output format)
// combination and invokes it with the matching concrete signature:
//   * FP16 / BF16 / HIFLOAT8 / FLOAT8_E4M3FN inputs use the signatures without the useHf32 flag;
//     every other input dtype uses the *_WITHFLAG signatures;
//   * `transposed` selects the transposed-conv signature, which additionally takes outputPadding.
// Returns nullptr (after logging) when nothing is registered for the combination, otherwise whatever the
// L0 function returns.
// The registry is taken by const reference so a call no longer copies the whole map.
// NOTE(review): calling through the reinterpret_cast'ed pointer is only well-defined if the function was
// registered with exactly the signature selected here — keep registrations in sync with this dispatch.
static const aclTensor* ConvL0Warper(
    const std::map<std::string, L0FUNCTION>& l0Functions, ConvolutionOpInfo& opInfo, const aclTensor* input,
    const aclTensor* weight, const aclTensor* bias, const aclIntArray* stride, const aclIntArray* padding,
    const aclIntArray* dilation, const bool transposed, const aclIntArray* outputPadding, const int64_t groups,
    bool useHf32, aclOpExecutor* executor)
{
    const std::string funcId =
        GET_FUNC_ID(opInfo.inputDtype, opInfo.inputFormat, opInfo.outputDtype, opInfo.outputFormat);
    // Single lookup instead of find() followed by at().
    const auto funcIt = l0Functions.find(funcId);
    if (funcIt == l0Functions.end()) {
        OP_LOGE(
            ACLNN_ERR_PARAM_INVALID,
            "Not support the given data type and format combination: "
            "inputDtype: %s, outputDtype: %s, inputFormat:%s, outputFormat:%s",
            op::ToString(opInfo.inputDtype).GetString(), op::ToString(opInfo.outputDtype).GetString(),
            op::ToString(opInfo.inputFormat).GetString(), op::ToString(opInfo.outputFormat).GetString());

        return nullptr;
    }

    L0FUNCTION fn = funcIt->second;

    OP_LOGI("The opInfo.inputDtype is %s", op::ToString(opInfo.inputDtype).GetString());
    // These input dtypes are dispatched to the L0 signatures that do not carry the useHf32 flag.
    const bool withoutHf32Flag =
        opInfo.inputDtype == op::DataType::DT_FLOAT16 || opInfo.inputDtype == op::DataType::DT_BF16 ||
        opInfo.inputDtype == op::DataType::DT_HIFLOAT8 || opInfo.inputDtype == op::DataType::DT_FLOAT8_E4M3FN;
    if (withoutHf32Flag) {
        if (!transposed) {
            return (reinterpret_cast<CONV_FUNCTION>(fn))(input, weight, bias, stride, padding, dilation, groups,
                                                         executor);
        }
        return (reinterpret_cast<CONVTRANSPOSE_FUNCTION>(fn))(
            input, weight, bias, stride, padding, dilation, groups, outputPadding, executor);
    }

    if (!transposed) {
        return (reinterpret_cast<CONV_WITHFLAG_FUNCTION>(fn))(
            input, weight, bias, stride, padding, dilation, groups, useHf32, executor);
    }
    return (reinterpret_cast<CONVTRANSPOSE_WITHFLAG_FUNCTION>(fn))(
        input, weight, bias, stride, padding, dilation, groups, outputPadding, useHf32, executor);
}

// Thin forwarding macro: dispatches to ConvL0Warper with the arguments in the same order.
#define FUNCTION_CALL(                                                                                               \
    l0Functions, opInfo, input, weight, bias, stride, padding, dilation, transposed, outputPadding, groups, useHf32, \
    executor)                                                                                                        \
    ConvL0Warper(                                                                                                    \
        l0Functions, opInfo, input, weight, bias, stride, padding, dilation, transposed, outputPadding, groups,      \
        useHf32, executor)
} // namespace op

// Registers `function` (type-erased to L0FUNCTION) into `map` under the caller-supplied key `opType`.
#define REG_L0_FUNCTION_BY_OPTYPE(map, function, opType) ((map).emplace(opType, (L0FUNCTION(&(function)))))

namespace op {

// Reports whether the current NPU architecture is DAV_3510, the only architecture this file treats as
// ND-capable.
static bool IsSupportND()
{
    const auto currentArch = GetCurrentPlatformInfo().GetCurNpuArch();
    return currentArch == NpuArch::DAV_3510;
}

// Looks up the L0 function registered under `functionType` and, when applicable, calls it with the
// CONV_WITHDTYPE_FUNCTION signature.
// Returns nullptr (after logging) when no function is registered under `functionType`. It also returns
// nullptr, without logging, unless the platform is ND-capable, `input` is 4-D (conv2d) and `transposed` is
// false — callers must handle that case themselves.
// The registry and key are taken by const reference so a call does not copy them.
static const aclTensor* L0FuncWarperByOpType(
    const std::map<std::string, L0FUNCTION>& l0Functions, const std::string& functionType, const aclTensor* input,
    const aclTensor* weight, const aclTensor* bias, op::DataType outputDtype, const aclIntArray* stride,
    const aclIntArray* padding, const aclIntArray* dilation, const bool transposed, const int64_t groups, bool useHf32,
    aclOpExecutor* executor)
{
    // Single lookup instead of find() followed by at().
    const auto funcIt = l0Functions.find(functionType);
    if (funcIt == l0Functions.end()) {
        OP_LOGE(ACLNN_ERR_RUNTIME_ERROR, "no matched L0 function");
        return nullptr;
    }
    // Same evaluation order as before: platform check, then input rank, then the transposed flag.
    if (!op::IsSupportND() || input->GetViewShape().GetDimNum() != op::CONV2D_SHAPE_SIZE || transposed) {
        return nullptr;
    }
    L0FUNCTION fn = funcIt->second;
    return (reinterpret_cast<CONV_WITHDTYPE_FUNCTION>(fn))(
        input, weight, bias, outputDtype, stride, padding, dilation, groups, useHf32, executor);
}

// Forwarding macro for L0FuncWarperByOpType. `outputPadding` is accepted so the macro mirrors FUNCTION_CALL,
// but it is NOT forwarded to the wrapper.
#define FUNCTION_CALL_BY_OPTYPE(                                                                                       \
    l0Functions, functionType, input, weight, bias, outputDtype, stride, padding, dilation, transposed, outputPadding, \
    groups, useHf32, executor)                                                                                         \
    L0FuncWarperByOpType(                                                                                              \
        l0Functions, functionType, input, weight, bias, outputDtype, stride, padding, dilation, transposed, groups,    \
        useHf32, executor)

} // namespace op

/* --------------------------------------L0函数注册机制end------------------------------------------------ */

/* --------------------------------------公共check能力start------------------------------------------------ */

namespace {

// Binary predicate: lhs equals rhs.
template <typename T>
inline bool Equal(T lhs, T rhs)
{
    return lhs == rhs;
}

// Binary predicate: lhs is not greater than rhs.
template <typename T>
inline bool LessEqual(T lhs, T rhs)
{
    return lhs <= rhs;
}

// True when pred(value, c) holds for at least one candidate c (false for an empty candidate list).
// Candidates are tried left to right; evaluation stops at the first match.
template <typename T, typename Func, typename... Candidates>
inline bool Any(T value, Func pred, Candidates... candidates)
{
    return (pred(value, candidates) || ...);
}

// True when pred(value, c) holds for every candidate c (true for an empty candidate list).
// Candidates are tried left to right; evaluation stops at the first failure.
template <typename T, typename Func, typename... Candidates>
inline bool All(T value, Func pred, Candidates... candidates)
{
    return (pred(value, candidates) && ...);
}
} // namespace

namespace {
// Raw arguments of one convolution API call, bundled so the checkers and the shape inference (through
// ConvEngine) can share them.
struct ConvParams {
    const aclTensor* input;            // reported as "x" in error messages
    const aclTensor* weight;           // reported as "filter" in error messages
    const aclTensor* bias;             // optional; nullptr when there is no bias
    const aclIntArray* stride;
    const aclIntArray* padding;
    const aclIntArray* dilation;
    const bool transposed;             // true selects transposed-convolution semantics
    const aclIntArray* outputPadding;  // only read when transposed is true
    const int64_t groups;
    aclTensor* output;                 // reported as "y" in error messages
    int8_t cubeMathType;               // NOTE(review): presumably the aclnn cube math mode selector — confirm
    uint64_t* workspaceSize;           // NOTE(review): presumably the aclnn workspace-size out-parameter — confirm
    aclOpExecutor** executor;          // NOTE(review): presumably the aclnn executor out-parameter — confirm
};

// Tensor ranks of Conv1d (3-D), Conv2d (4-D) and Conv3d (5-D).
constexpr size_t CONV_1D_DIM_SIZE = 3;
constexpr size_t CONV_2D_DIM_SIZE = 4;
constexpr size_t CONV_3D_DIM_SIZE = 5;
// Generic small constants, e.g. the number of leading non-spatial dims (2) in rank arithmetic.
constexpr size_t CONST_VALUE_TWO = 2;
constexpr size_t CONST_VALUE_FOUR = 4;
// Note: this is 2^16 (one past the largest uint16 value); it is used as an exclusive bound via `>=`.
static constexpr uint64_t MAX_UINT16 = 65536;

struct TensorMeta {
public:
    op::Format format;
    op::DataType dataType;
    FVector<int64_t> shape;
    op::Shape tensorShape;
    TensorMeta() = default;
    void SetFromTensor(const aclTensor* tensor)
    {
        if (tensor == nullptr) {
            return;
        }
        dataType = tensor->GetDataType();
        format = tensor->GetViewFormat();
        tensorShape = tensor->GetViewShape();
        shape = op::ToFVector(tensorShape);
        string formatStr = op::ToString(format).GetString();

        // 未定义的shape,默认为1,实际不会被用到,将所哟类型包括一起查询,未找到的为npos
        auto len = shape.size();
        auto npos = formatStr.npos;
        auto index = formatStr.find('N');
        nIdx_ = index;
        n_ = (index == npos || index >= len) ? 1 : shape[index];

        index = formatStr.find('C');
        cIdx_ = index;
        c_ = (index == npos || index >= len) ? 1 : shape[index];

        index = formatStr.find('D');
        dIdx_ = index;
        d_ = (index == npos || index >= len) ? 1 : shape[index];

        index = formatStr.find('H');
        hIdx_ = index;
        h_ = (index == npos || index >= len) ? 1 : shape[index];

        index = formatStr.find('W');
        wIdx_ = index;
        w_ = (index == npos || index >= len) ? 1 : shape[index];

        index = formatStr.find('L');
        lIdx_ = index;
        l_ = (index == npos || index >= len) ? 1 : shape[index];

        // formatStr.endswith('C')
        channelLast_ = formatStr.find('C') == formatStr.length() - 1;
        isZeroTensor_ = tensor->IsEmpty();
    }
    explicit TensorMeta(const aclTensor* tensor)
    {
        this->SetFromTensor(tensor);
    }
    // candy access functions
    int64_t N() const
    {
        return n_;
    }
    int64_t C() const
    {
        return c_;
    }
    int64_t D() const
    {
        return d_;
    }
    int64_t H() const
    {
        return h_;
    }
    int64_t W() const
    {
        return w_;
    }
    int64_t L() const
    {
        return l_;
    }
    size_t NIdx() const { return nIdx_; }
    size_t CIdx() const { return cIdx_; }
    size_t DIdx() const { return dIdx_; }
    size_t HIdx() const { return hIdx_; }
    size_t WIdx() const { return wIdx_; }
    size_t LIdx() const { return lIdx_; }
    bool ChannelLast() const
    {
        return channelLast_;
    }
    bool IsZeroTensor() const
    {
        return isZeroTensor_;
    }

private:
    int64_t n_ = 0;
    int64_t c_ = 0;
    int64_t d_ = 0;
    int64_t h_ = 0;
    int64_t w_ = 0;
    int64_t l_ = 0;
    size_t nIdx_ = 0;
    size_t cIdx_ = 0;
    size_t dIdx_ = 0;
    size_t hIdx_ = 0;
    size_t wIdx_ = 0;
    size_t lIdx_ = 0;
    bool channelLast_ = false;
    bool isZeroTensor_ = false;
};
} // namespace

namespace {

// Picks the dtype the convolution is computed in from the input (a) and weight (b) dtypes.
// The rules are checked in order:
//   * DOUBLE if either is DOUBLE;
//   * otherwise FLOAT if either is FLOAT;
//   * when both are the same one of BF16 / FLOAT16 / HIFLOAT8 / FLOAT8_E4M3FN, that shared dtype;
//   * every other combination falls back to FLOAT.
op::DataType GetUpperFloatDataType(op::DataType a, op::DataType b)
{
    OP_LOGD("The input dtype is %s and %s", op::ToString(a).GetString(), op::ToString(b).GetString());
    const auto eitherIs = [a, b](op::DataType target) { return a == target || b == target; };
    if (eitherIs(op::DataType::DT_DOUBLE)) {
        return op::DataType::DT_DOUBLE;
    }
    if (eitherIs(op::DataType::DT_FLOAT)) {
        return op::DataType::DT_FLOAT;
    }
    const bool sameLowPrecision =
        (a == b) && (a == op::DataType::DT_BF16 || a == op::DataType::DT_FLOAT16 ||
                     a == op::DataType::DT_HIFLOAT8 || a == op::DataType::DT_FLOAT8_E4M3FN);
    // Inputs are expected to be floating point; anything else is computed in FP32.
    return sameLowPrecision ? a : op::DataType::DT_FLOAT;
}

struct ConvMeta {
public:
    TensorMeta input;
    TensorMeta weight;
    TensorMeta bias;
    TensorMeta output;
    // stride、dilation 按照空间分布,3维DHW,2维HW,1维L
    FVector<int64_t> stride;
    FVector<int64_t> dilation;
    // padding outputpadding 按照方向维度分布,3维3个值,代表前后、上下、左右,2维度上下、左右,1维度左右
    FVector<int64_t> padding;
    FVector<int64_t> outputPadding;
    op::DataType calculatDataType;
    void FromParams(ConvParams& params)
    {
        input.SetFromTensor(params.input);
        weight.SetFromTensor(params.weight);
        output.SetFromTensor(params.output);
        if (params.bias) {
            bias.format = params.bias->GetViewFormat();
            bias.dataType = params.bias->GetDataType();
            bias.tensorShape = params.bias->GetViewShape();
            bias.shape = op::ToFVector(bias.tensorShape);
        }

        stride = ToVector(params.stride);
        dilation = ToVector(params.dilation);
        padding = ToVector(params.padding);
        if (params.transposed) {
            outputPadding = ToVector(params.outputPadding);
        }
        calculatDataType = GetUpperFloatDataType(input.dataType, weight.dataType);
    }

private:
    FVector<int64_t> ToVector(const aclIntArray* array) const
    {
        FVector<int64_t> v;
        if (array != nullptr) {
            for (uint64_t i = 0; i < array->Size(); ++i) {
                v.push_back((*array)[i]);
            }
        }
        return v;
    }
};

// Padding-list lengths recognised by ConstructPad: one value per spatial dim, or (for conv2d) the explicit
// four-element form.
constexpr size_t CONV_1D_PAD_DIM = 1;
constexpr size_t CONV_2D_PAD_DIM = 2;
constexpr size_t CONV_4D_PAD_DIM = 4;
// Positions inside a four-element conv2d padding list [top, bottom, left, right].
constexpr size_t PAD_TOP_INDEX = 0;
constexpr size_t PAD_BOTTOM_INDEX = 1;
constexpr size_t PAD_LEFT_INDEX = 2;
constexpr size_t PAD_RIGHT_INDEX = 3;
// NOTE(review): presumably the positions of the L and C axes in a TBC-layout tensor — confirm at use sites.
constexpr size_t CONVTBC_L_INDEX = 2;
constexpr size_t CONVTBC_C_INDEX = 1;
namespace {

// Folds a padding list into per-spatial-dim total padding (begin + end) for conv1d and conv2d; other ranks
// keep the list unchanged:
//   * 3-D input (conv1d): {p} -> {2p}; {left, right} -> {left + right}; any other length -> {0}.
//   * 4-D input (conv2d): {h, w} -> {2h, 2w}; {top, bottom, left, right} -> {top + bottom, left + right};
//     any other length -> {0, 0}.
//   * any other rank: a copy of oldPad is returned unchanged.
// Both arguments are read-only, so they are taken by const reference.
static FVector<int64_t> ConstructPad(const FVector<int64_t>& oldPad, const FVector<int64_t>& inputShape)
{
    FVector<int64_t> newPad;
    if (inputShape.size() == CONV_1D_DIM_SIZE) {
        if (oldPad.size() == 1) {
            newPad = {oldPad[0] + oldPad[0]};
        } else if (oldPad.size() == CONV_2D_PAD_DIM) {
            newPad = {oldPad[0] + oldPad[1]};
        } else {
            newPad = {0};
        }
    } else if (inputShape.size() == CONV_2D_DIM_SIZE) {
        if (oldPad.size() == CONV_2D_PAD_DIM) {
            newPad = {(oldPad[0] + oldPad[0]), (oldPad[1] + oldPad[1])};
        } else if (oldPad.size() == CONV_4D_PAD_DIM) {
            newPad = {
                (oldPad[PAD_TOP_INDEX] + oldPad[PAD_BOTTOM_INDEX]), (oldPad[PAD_LEFT_INDEX] + oldPad[PAD_RIGHT_INDEX])};
        } else {
            newPad = {0, 0};
        }
    } else {
        return oldPad;
    }
    return newPad;
}

struct ConvEngine {
public:
    ConvParams params;
    ConvMeta meta;
    std::string entityName;
    explicit ConvEngine(ConvParams& convParams) : params(convParams)
    {
        meta.FromParams(params);
    }
    FVector<int64_t> CalcOutputShape()
    {
        return InferShape();
    }

private:
    FVector<int64_t> InferShape()
    {
        FVector<int64_t> output;
        FVector<int64_t> inputShape = meta.input.shape;
        int64_t inputSpaceDimIndex =
            meta.input.ChannelLast() ? 1 : 2;                   // 空间维度在shape中的起始位置,C维度后置时为1,否则为2
        int64_t inputSpaceDimNum = meta.input.shape.size() - 2; // 空间维度大小,1d卷积时为1,2d为2,3d为3
        FVector<int64_t> weightShape = meta.weight.shape;
        int64_t weightSpaceDimIndex =
            meta.weight.ChannelLast() ? 1 : 2; // 空间维度在shape中的起始位置,C维度后置时为1,否则为2
        // step 1: put nOut in the first place of shape; for conv and transpose mode
        output.push_back(meta.input.N());
        int64_t cOut = meta.weight.N();
        // step 2: calc spaceDim size and push back to shape
        if (!params.transposed) {
            if (inputShape.size() == CONV_1D_DIM_SIZE || inputShape.size() == CONV_2D_DIM_SIZE) {
                auto newPad = ConstructPad(meta.padding, inputShape);
                for (int64_t i = 0; i < inputSpaceDimNum; ++i) {
                    int64_t xOut = (inputShape[i + inputSpaceDimIndex] + newPad[i] -
                                    meta.dilation[i] * (weightShape[i + weightSpaceDimIndex] - 1) - 1) /
                                       meta.stride[i] +
                                   1;
                    output.push_back(xOut);
                }
            } else {
                for (int64_t i = 0; i < inputSpaceDimNum; ++i) {
                    int64_t xOut = (inputShape[i + inputSpaceDimIndex] + CONV_2D_PAD_DIM * meta.padding[i] -
                                    meta.dilation[i] * (weightShape[i + weightSpaceDimIndex] - 1) - 1) /
                                       meta.stride[i] +
                                   1;
                    output.push_back(xOut);
                }
            }
        } else {
            cOut = meta.weight.C() * params.groups;
            if (inputShape.size() == CONV_2D_DIM_SIZE) {
                auto newPad = ConstructPad(meta.padding, inputShape);
                for (int64_t i = 0; i < inputSpaceDimNum; ++i) {
                    int64_t xOut = (inputShape[i + inputSpaceDimIndex] - 1) * meta.stride[i] - newPad[i] +
                                   meta.dilation[i] * (weightShape[i + weightSpaceDimIndex] - 1) +
                                   meta.outputPadding[i] + 1;
                    output.push_back(xOut);
                }
            } else {
                for (int64_t i = 0; i < inputSpaceDimNum; ++i) {
                    int64_t xOut = (inputShape[i + inputSpaceDimIndex] - 1) * meta.stride[i] - 2 * meta.padding[i] +
                                   meta.dilation[i] * (weightShape[i + weightSpaceDimIndex] - 1) +
                                   meta.outputPadding[i] + 1;
                    output.push_back(xOut);
                }
            }
        }
        // last step : put cOut in right place
        if (meta.input.ChannelLast()) {
            output.push_back(cOut);
        } else {
            output.insert(output.begin() + 1, cOut);
        }
        return output;
    }
};

} // namespace

// Abstract interface for one validation step run against a ConvEngine. The implementations in this file
// return ACLNN_SUCCESS when the check passes, and log and return an aclnn error status otherwise.
class ConvolutionChecker {
public:
    ConvolutionChecker() = default;
    // Virtual so derived checkers can be destroyed through a base pointer.
    virtual ~ConvolutionChecker() = default;
    // Runs the check against `engine`.
    virtual aclnnStatus Check(ConvEngine& engine) = 0;
};

} // namespace

namespace {

static bool CheckConvParamsDtype(const std::string& entityName, const DataType& currDtype, const string& dtypeStr,
    const std::initializer_list<op::DataType>& dtypeSupportList)
{
    for (auto item : dtypeSupportList) {
        if (currDtype == item) {
            return true;
        }
    }
    string dtypeSupportListStr = "[";
    uint32_t cnt = 0;
    for (auto item : dtypeSupportList) {
        dtypeSupportListStr += op::ToString(item).GetString();
        if (cnt != dtypeSupportList.size() - 1) {
            dtypeSupportListStr += ", ";
        } else {
            dtypeSupportListStr += "]";
        }
        cnt++;
    }
    std::string correctDtype = "one of " + dtypeSupportListStr;
    OP_LOGE_FOR_INVALID_DTYPE(entityName, dtypeStr, GeDtypeToString(currDtype), correctDtype);
    return false;
}

} // namespace

/* --------------------------------------公共check能力end------------------------------------------------ */

static const std::initializer_list<op::DataType>& GetBiasDtypeSupportListBySocVersion()
{
    SocVersion socVersion = GetCurrentPlatformInfo().GetSocVersion();
    if (socVersion == SocVersion::ASCEND310P) {
        return op::BIAS_SUPPORT_LIST_ASCEND310P;
    }

    return op::BIAS_SUPPORT_LIST;
}

// Returns true when every element of `array` equals `value` (vacuously true for an empty array).
// A nullptr array cannot prove the property, so the function conservatively returns false instead of
// dereferencing it.
static bool CheckPointWise(const aclIntArray* array, int64_t value)
{
    if (array == nullptr) {
        return false;
    }
    for (uint64_t i = 0; i < array->Size(); ++i) {
        if ((*array)[i] != value) {
            return false;
        }
    }
    return true;
}

// The pointwise path applies only when all of the following hold:
//   * groups == 1;
//   * dilation and stride are all 1 and padding is all 0;
//   * every weight dim from index 2 onward is 1.
static bool NeedPointWiseKernel(
    const aclTensor* weight, const aclIntArray* stride, const aclIntArray* padding, const aclIntArray* dilation,
    const int64_t groups)
{
    const bool unitAttributes = (groups == 1) && CheckPointWise(dilation, 1) && CheckPointWise(stride, 1) &&
                                CheckPointWise(padding, 0);
    if (!unitAttributes) {
        return false;
    }

    const auto& kernelShape = weight->GetViewShape();
    const size_t kernelRank = kernelShape.GetDimNum();
    for (size_t dim = CONST_VALUE_TWO; dim < kernelRank; ++dim) {
        if (kernelShape.GetDim(dim) != 1) {
            return false;
        }
    }
    return true;
}

// Returns true when the product of fmap's dims from index 2 up to index 4 reaches MAX_UINT16 (65536). For an
// NCDHW fmap those dims are presumably D, H, W — confirm at call sites.
// The loop is clamped to the tensor's real rank, so a lower-rank fmap no longer reads dims it does not have.
// For 5-D input the result is unchanged.
static bool PointWiseKernelBeyondLimits(const aclTensor* fmap)
{
    const auto& fmapShape = fmap->GetViewShape();
    const size_t dimNum = fmapShape.GetDimNum();
    const size_t endIdx = dimNum < CONV_3D_DIM_SIZE ? dimNum : CONV_3D_DIM_SIZE;
    uint64_t dihiwi = 1;
    for (size_t idx = CONST_VALUE_TWO; idx < endIdx; ++idx) {
        dihiwi = dihiwi * fmapShape.GetDim(idx);
    }
    return dihiwi >= MAX_UINT16;
}

namespace {

// Rank and attribute-length validation for regular conv1d/2d/3d (normal and transposed).
class DimChecker : public ConvolutionChecker {
public:
    DimChecker() = default;
    ~DimChecker() override = default;

    // Rejects any rank other than 3 (conv1d), 4 (conv2d) or 5 (conv3d) for the tensor named `inStr`.
    aclnnStatus CheckDim(const std::string& entityName, const string& inStr, size_t inDim) const
    {
        if (inDim != CONV_1D_DIM_SIZE && inDim != CONV_2D_DIM_SIZE && inDim != CONV_3D_DIM_SIZE) {
            OP_LOGE_FOR_INVALID_SHAPEDIM(entityName, inStr, std::to_string(inDim), "one of {" +
                std::to_string(CONV_1D_DIM_SIZE) + ", " + std::to_string(CONV_2D_DIM_SIZE) + ", " +
                std::to_string(CONV_3D_DIM_SIZE) + "}");
            return ACLNN_ERR_PARAM_INVALID;
        }
        return ACLNN_SUCCESS;
    }

    // x, filter and y must each have a supported rank, and filter and y must have the same rank as x.
    aclnnStatus CheckTensorDimensions(const ConvEngine& engine) const
    {
        size_t inputDim = engine.meta.input.shape.size();
        aclnnStatus ret = CheckDim(engine.entityName, "x", inputDim);
        if (ret != ACLNN_SUCCESS) {
            return ret;
        }

        size_t weightDim = engine.meta.weight.shape.size();
        ret = CheckDim(engine.entityName, "filter", weightDim);
        if (ret != ACLNN_SUCCESS) {
            return ret;
        }

        size_t outputDim = engine.meta.output.shape.size();
        ret = CheckDim(engine.entityName, "y", outputDim);
        if (ret != ACLNN_SUCCESS) {
            return ret;
        }

        if (weightDim != inputDim) {
            std::string reason = "the shape dims of x and filter must be the same";
            OP_LOGE_FOR_INVALID_SHAPEDIMS_WITH_REASON(engine.entityName, "x, filter",
                std::to_string(inputDim) + ", " + std::to_string(weightDim), reason);
            return ACLNN_ERR_PARAM_INVALID;
        }
        if (outputDim != inputDim) {
            std::string reason = "the shape dims of x and y must be the same";
            OP_LOGE_FOR_INVALID_SHAPEDIMS_WITH_REASON(engine.entityName, "x, y",
                std::to_string(inputDim) + ", " + std::to_string(outputDim), reason);
            return ACLNN_ERR_PARAM_INVALID;
        }

        return ACLNN_SUCCESS;
    }

    // Optional bias: it must be 1-D. Its length must equal filter C * groups when transposed, and filter N
    // otherwise.
    aclnnStatus CheckBiasDimensions(ConvEngine& engine) const
    {
        if (engine.params.bias == nullptr) {
            return ACLNN_SUCCESS;
        }

        size_t biasDim = engine.meta.bias.shape.size();
        // bias must be 1-D.
        if (biasDim != 1) {
            OP_LOGE_FOR_INVALID_SHAPEDIM(engine.entityName, "bias", std::to_string(biasDim), "1");
            return ACLNN_ERR_PARAM_INVALID;
        }

        // Compare in signed 64-bit arithmetic. Converting groups or dims to size_t would turn a negative value
        // into a huge unsigned one, corrupting both the comparison and the groups value in the error message.
        const int64_t biasSize = engine.meta.bias.shape[0];
        const int64_t groupsValue = engine.params.groups;
        const int64_t weightNValue = engine.meta.weight.N();
        const int64_t weightCValue = engine.meta.weight.C();

        // transposed: bias length must be filter C * groups.
        if (engine.params.transposed && biasSize != weightCValue * groupsValue) {
            OP_LOGE_FOR_INVALID_SHAPES_WITH_REASON(engine.entityName, "bias, filter",
                op::FVectorToString(engine.meta.bias.shape) + ", " +
                op::FVectorToString(engine.meta.weight.shape),
                "When transposed is true, shape[0] of bias must be equal to shape[" +
                std::to_string(engine.meta.weight.CIdx()) + "] of filter multiplied by attribute groups("+
                std::to_string(groupsValue) + ")");
            return ACLNN_ERR_PARAM_INVALID;
        }

        // not transposed: bias length must be filter N.
        if (!engine.params.transposed && biasSize != weightNValue) {
            OP_LOGE_FOR_INVALID_SHAPES_WITH_REASON(engine.entityName, "bias, filter",
                op::FVectorToString(engine.meta.bias.shape) + ", " +
                op::FVectorToString(engine.meta.weight.shape),
                "When transposed is false, shape[0] of bias must be equal to shape[" +
                std::to_string(engine.meta.weight.NIdx()) + "] of filter");
            return ACLNN_ERR_PARAM_INVALID;
        }

        return ACLNN_SUCCESS;
    }

    // Attribute-list lengths:
    //   * strides / dilations: one entry per spatial dim;
    //   * pads: one per spatial dim, or two per spatial dim for normal conv1d/conv2d and transposed conv2d;
    //   * outputPadding (transposed only): one per spatial dim.
    aclnnStatus CheckParameterSizes(ConvEngine& engine) const
    {
        size_t inputDim = engine.meta.input.shape.size();

        auto strideSize = engine.meta.stride.size();
        if (strideSize != inputDim - CONST_VALUE_TWO) {
            OP_LOGE_FOR_INVALID_LISTSIZE(engine.entityName, "strides", std::to_string(strideSize),
                std::to_string(inputDim - CONST_VALUE_TWO));
            return ACLNN_ERR_PARAM_INVALID;
        }

        auto dilationSize = engine.meta.dilation.size();
        if (dilationSize != inputDim - CONST_VALUE_TWO) {
            OP_LOGE_FOR_INVALID_LISTSIZE(engine.entityName, "dilations", std::to_string(dilationSize),
                std::to_string(inputDim - CONST_VALUE_TWO));
            return ACLNN_ERR_PARAM_INVALID;
        }

        auto paddingSize = engine.meta.padding.size();
        if (((inputDim == CONV_1D_DIM_SIZE || inputDim == CONV_2D_DIM_SIZE) && !engine.params.transposed) ||
            (inputDim == CONV_2D_DIM_SIZE && engine.params.transposed)) {
            if (!Any(paddingSize, Equal<size_t>, inputDim - CONST_VALUE_TWO, inputDim * CONST_VALUE_TWO - CONST_VALUE_FOUR)) {
                OP_LOGE_FOR_INVALID_LISTSIZE(engine.entityName, "pads", std::to_string(paddingSize),
                    std::to_string(inputDim - CONST_VALUE_TWO) + " or " + std::to_string(inputDim * CONST_VALUE_TWO - CONST_VALUE_FOUR));
                return ACLNN_ERR_PARAM_INVALID;
            }
        } else {
            if (paddingSize != inputDim - CONST_VALUE_TWO) {
                OP_LOGE_FOR_INVALID_LISTSIZE(engine.entityName, "pads", std::to_string(paddingSize),
                    std::to_string(inputDim - CONST_VALUE_TWO));
                return ACLNN_ERR_PARAM_INVALID;
            }
        }

        if (engine.params.transposed) {
            auto outputPaddingSize = engine.meta.outputPadding.size();
            if (outputPaddingSize != inputDim - CONST_VALUE_TWO) {
                OP_LOGE_FOR_INVALID_LISTSIZE(engine.entityName, "outputPadding", std::to_string(outputPaddingSize),
                    std::to_string(inputDim - CONST_VALUE_TWO));
                return ACLNN_ERR_PARAM_INVALID;
            }
        }

        return ACLNN_SUCCESS;
    }

    // Runs the tensor-rank, bias and attribute-length checks in order, stopping at the first failure.
    aclnnStatus Check(ConvEngine& engine) override
    {
        aclnnStatus ret = CheckTensorDimensions(engine);
        if (ret != ACLNN_SUCCESS) {
            return ret;
        }

        ret = CheckBiasDimensions(engine);
        if (ret != ACLNN_SUCCESS) {
            return ret;
        }

        return CheckParameterSizes(engine);
    }
};

/**
 * Rank checker for the conv_tbc entry.
 *
 * x, filter and y must each have CONV_1D_DIM_SIZE dimensions and bias must be 1-d.
 * Returns ACLNN_ERR_PARAM_INVALID (after logging) on the first violation, ACLNN_SUCCESS otherwise.
 */
class DimCheckerTbc : public ConvolutionChecker {
public:
    DimCheckerTbc() = default;
    ~DimCheckerTbc() override = default;

    aclnnStatus Check(ConvEngine& engine) override
    {
        const size_t xRank = engine.meta.input.shape.size();
        if (xRank != CONV_1D_DIM_SIZE) {
            OP_LOGE_FOR_INVALID_SHAPEDIM(engine.entityName, "x", std::to_string(xRank),
                std::to_string(CONV_1D_DIM_SIZE));
            return ACLNN_ERR_PARAM_INVALID;
        }

        const size_t filterRank = engine.meta.weight.shape.size();
        if (filterRank != CONV_1D_DIM_SIZE) {
            OP_LOGE_FOR_INVALID_SHAPEDIM(engine.entityName, "filter", std::to_string(filterRank),
                std::to_string(CONV_1D_DIM_SIZE));
            return ACLNN_ERR_PARAM_INVALID;
        }

        const size_t yRank = engine.meta.output.shape.size();
        if (yRank != CONV_1D_DIM_SIZE) {
            OP_LOGE_FOR_INVALID_SHAPEDIM(engine.entityName, "y", std::to_string(yRank),
                std::to_string(CONV_1D_DIM_SIZE));
            return ACLNN_ERR_PARAM_INVALID;
        }

        // NOTE(review): unlike DtypeCheckerTbc, bias metadata is read here without checking
        // engine.params.bias for nullptr, so a missing bias is reported as a rank mismatch.
        // Presumably bias is mandatory for conv_tbc — confirm against the API contract.
        constexpr size_t expectedBiasRank = 1;
        const size_t biasRank = engine.meta.bias.shape.size();
        if (biasRank != expectedBiasRank) {
            OP_LOGE_FOR_INVALID_SHAPEDIM(engine.entityName, "bias", std::to_string(biasRank),
                std::to_string(expectedBiasRank));
            return ACLNN_ERR_PARAM_INVALID;
        }

        return ACLNN_SUCCESS;
    }
};

/**
 * Rank and list-size checker for depthwise 2d convolution.
 *
 * x, filter and y must each have CONV_2D_DIM_SIZE dimensions. A bias, when supplied, must be 1-d.
 * strides, dilations and pads must each hold (rank of x - CONST_VALUE_TWO) entries, presumably
 * one per spatial axis. Checks run in that order; the first violation is logged and reported
 * as ACLNN_ERR_PARAM_INVALID.
 */
class DimCheckerDepthwise2d : public ConvolutionChecker {
public:
    DimCheckerDepthwise2d() = default;
    ~DimCheckerDepthwise2d() override = default;

    aclnnStatus Check(ConvEngine& engine) override
    {
        const size_t xRank = engine.meta.input.shape.size();
        if (xRank != CONV_2D_DIM_SIZE) {
            OP_LOGE_FOR_INVALID_SHAPEDIM(engine.entityName, "x", std::to_string(xRank),
                std::to_string(CONV_2D_DIM_SIZE));
            return ACLNN_ERR_PARAM_INVALID;
        }

        const size_t filterRank = engine.meta.weight.shape.size();
        if (filterRank != CONV_2D_DIM_SIZE) {
            OP_LOGE_FOR_INVALID_SHAPEDIM(engine.entityName, "filter", std::to_string(filterRank),
                std::to_string(CONV_2D_DIM_SIZE));
            return ACLNN_ERR_PARAM_INVALID;
        }

        const size_t yRank = engine.meta.output.shape.size();
        if (yRank != CONV_2D_DIM_SIZE) {
            OP_LOGE_FOR_INVALID_SHAPEDIM(engine.entityName, "y", std::to_string(yRank),
                std::to_string(CONV_2D_DIM_SIZE));
            return ACLNN_ERR_PARAM_INVALID;
        }

        // Bias is optional here; only validate its rank when it was provided.
        if (engine.params.bias != nullptr) {
            constexpr size_t expectedBiasRank = 1;
            const size_t biasRank = engine.meta.bias.shape.size();
            if (biasRank != expectedBiasRank) {
                OP_LOGE_FOR_INVALID_SHAPEDIM(engine.entityName, "bias", std::to_string(biasRank),
                    std::to_string(expectedBiasRank));
                return ACLNN_ERR_PARAM_INVALID;
            }
        }

        // Every attribute list must have (rank of x - 2) entries.
        const auto expectedListSize = xRank - CONST_VALUE_TWO;

        const auto strideCount = engine.meta.stride.size();
        if (strideCount != expectedListSize) {
            OP_LOGE_FOR_INVALID_LISTSIZE(engine.entityName, "strides", std::to_string(strideCount),
                std::to_string(expectedListSize));
            return ACLNN_ERR_PARAM_INVALID;
        }

        const auto dilationCount = engine.meta.dilation.size();
        if (dilationCount != expectedListSize) {
            OP_LOGE_FOR_INVALID_LISTSIZE(engine.entityName, "dilations", std::to_string(dilationCount),
                std::to_string(expectedListSize));
            return ACLNN_ERR_PARAM_INVALID;
        }

        const auto padCount = engine.meta.padding.size();
        if (padCount != expectedListSize) {
            OP_LOGE_FOR_INVALID_LISTSIZE(engine.entityName, "pads", std::to_string(padCount),
                std::to_string(expectedListSize));
            return ACLNN_ERR_PARAM_INVALID;
        }

        return ACLNN_SUCCESS;
    }
};

/**
 * Dtype checker for the generic convolution entry.
 *
 * An optional bias must be in op::BIAS_SUPPORT_LIST. x, filter and y must each be in the
 * SoC-dependent support list; the transposed path uses the list returned by
 * GetDtypeSupportListBySocVersion4ConvBackward instead.
 */
class DtypeChecker : public ConvolutionChecker {
public:
    DtypeChecker() = default;
    ~DtypeChecker() override = default;

    aclnnStatus Check(ConvEngine& engine) override
    {
        if (engine.params.bias != nullptr) {
            DataType biasType = engine.meta.bias.dataType;
            if (!CheckConvParamsDtype(engine.entityName, biasType, "bias", op::BIAS_SUPPORT_LIST)) {
                return ACLNN_ERR_PARAM_INVALID;
            }
        }

        auto supportedTypes = GetDtypeSupportListBySocVersion();
        if (engine.params.transposed) {
            // NOTE(review): inside this branch `!engine.params.transposed` is always false, so the
            // helper is effectively called with `false` — confirm this is the intended argument.
            supportedTypes = GetDtypeSupportListBySocVersion4ConvBackward(!engine.params.transposed);
        }

        DataType xType = engine.meta.input.dataType;
        DataType filterType = engine.meta.weight.dataType;
        DataType yType = engine.meta.output.dataType;
        // Short-circuits in x -> filter -> y order, so only the first failing tensor is reported.
        if (!CheckConvParamsDtype(engine.entityName, xType, "x", supportedTypes) ||
            !CheckConvParamsDtype(engine.entityName, filterType, "filter", supportedTypes) ||
            !CheckConvParamsDtype(engine.entityName, yType, "y", supportedTypes)) {
            return ACLNN_ERR_PARAM_INVALID;
        }
        return ACLNN_SUCCESS;
    }
};

/**
 * Dtype checker for the conv_tbc entry.
 *
 * An optional bias must be in the SoC bias support list, and x must be in the SoC support list.
 * filter must have the same dtype as x, and y must have the same dtype as filter.
 */
class DtypeCheckerTbc : public ConvolutionChecker {
public:
    DtypeCheckerTbc() = default;
    ~DtypeCheckerTbc() override = default;

    aclnnStatus Check(ConvEngine& engine) override
    {
        if (engine.params.bias != nullptr) {
            DataType biasType = engine.meta.bias.dataType;
            if (!CheckConvParamsDtype(engine.entityName, biasType, "bias", GetBiasDtypeSupportListBySocVersion())) {
                return ACLNN_ERR_PARAM_INVALID;
            }
        }

        DataType xType = engine.meta.input.dataType;
        if (!CheckConvParamsDtype(engine.entityName, xType, "x", GetDtypeSupportListBySocVersion())) {
            return ACLNN_ERR_PARAM_INVALID;
        }

        // filter and y are not checked against the list directly; they must equal x's dtype.
        DataType filterType = engine.meta.weight.dataType;
        if (xType != filterType) {
            OP_LOGE_FOR_INVALID_DTYPE(engine.entityName, "x", GeDtypeToString(xType),
                GeDtypeToString(filterType));
            return ACLNN_ERR_PARAM_INVALID;
        }

        DataType yType = engine.meta.output.dataType;
        if (filterType != yType) {
            OP_LOGE_FOR_INVALID_DTYPE(engine.entityName, "filter", GeDtypeToString(filterType),
                GeDtypeToString(yType));
            return ACLNN_ERR_PARAM_INVALID;
        }

        return ACLNN_SUCCESS;
    }
};

/**
 * Dtype checker for depthwise 2d convolution.
 *
 * An optional bias must be in the SoC bias support list and, unless x is DT_HIFLOAT8, must share
 * x's dtype. x, filter and y must each be in the SoC support list, and all three must have
 * the same dtype.
 */
class DtypeCheckerDepthwise2d : public ConvolutionChecker {
public:
    DtypeCheckerDepthwise2d() = default;
    ~DtypeCheckerDepthwise2d() override = default;

    aclnnStatus Check(ConvEngine& engine) override
    {
        DataType xType = engine.meta.input.dataType;
        DataType filterType = engine.meta.weight.dataType;
        DataType yType = engine.meta.output.dataType;

        if (engine.params.bias != nullptr) {
            DataType biasType = engine.meta.bias.dataType;
            if (!CheckConvParamsDtype(engine.entityName, biasType, "bias", GetBiasDtypeSupportListBySocVersion())) {
                return ACLNN_ERR_PARAM_INVALID;
            }
            // HIFLOAT8 inputs are exempt from the x/bias dtype-equality requirement.
            const bool biasMustMatchX = (xType != DataType::DT_HIFLOAT8);
            if (biasMustMatchX && xType != biasType) {
                OP_LOGE_FOR_INVALID_DTYPE(engine.entityName, "x", GeDtypeToString(xType),
                    GeDtypeToString(biasType));
                return ACLNN_ERR_PARAM_INVALID;
            }
        }

        const auto supportedTypes = GetDtypeSupportListBySocVersion();
        // Evaluated left to right with short-circuit, so only the first unsupported tensor is logged.
        const bool allSupported = CheckConvParamsDtype(engine.entityName, xType, "x", supportedTypes) &&
            CheckConvParamsDtype(engine.entityName, filterType, "filter", supportedTypes) &&
            CheckConvParamsDtype(engine.entityName, yType, "y", supportedTypes);
        if (!allSupported) {
            return ACLNN_ERR_PARAM_INVALID;
        }

        if (xType != filterType) {
            OP_LOGE_FOR_INVALID_DTYPE(engine.entityName, "x", GeDtypeToString(xType),
                GeDtypeToString(filterType));
            return ACLNN_ERR_PARAM_INVALID;
        }
        if (filterType != yType) {
            OP_LOGE_FOR_INVALID_DTYPE(engine.entityName, "filter", GeDtypeToString(filterType),
                GeDtypeToString(yType));
            return ACLNN_ERR_PARAM_INVALID;
        }

        return ACLNN_SUCCESS;
    }
};

/**
 * Null-pointer checker for the generic convolution entry.
 *
 * Verifies that x, filter, strides, pads, dilations, y, workspaceSize and executor are non-null,
 * and additionally outputPadding when transposed is set. bias is not checked here; other checkers
 * in this file only inspect bias metadata after testing engine.params.bias for nullptr.
 */
class NullptrChecker : public ConvolutionChecker {
public:
    NullptrChecker() = default;
    ~NullptrChecker() override = default;
    aclnnStatus Check(ConvEngine& engine) override
    {
        // CHECK_PARAM_NULLPTR is defined elsewhere; presumably it logs and returns an error status on
        // nullptr (verify in its definition). The order below decides which missing parameter is reported.
        CHECK_PARAM_NULLPTR(engine.entityName, engine.params.input, "x");
        CHECK_PARAM_NULLPTR(engine.entityName, engine.params.weight, "filter");
        CHECK_PARAM_NULLPTR(engine.entityName, engine.params.stride, "strides");
        CHECK_PARAM_NULLPTR(engine.entityName, engine.params.padding, "pads");
        CHECK_PARAM_NULLPTR(engine.entityName, engine.params.dilation, "dilations");
        CHECK_PARAM_NULLPTR(engine.entityName, engine.params.output, "y");
        CHECK_PARAM_NULLPTR(engine.entityName, engine.params.workspaceSize, "workspaceSize");
        CHECK_PARAM_NULLPTR(engine.entityName, engine.params.executor, "executor");
        // outputPadding is only required for transposed convolution.
        if (engine.params.transposed) {
            CHECK_PARAM_NULLPTR(engine.entityName, engine.params.outputPadding, "outputPadding");
        }
        return ACLNN_SUCCESS;
    }
};

/**
 * Format checker for the generic convolution entry.
 *
 * x and y must share a format. An optional bias must be NCL, NCHW, NCDHW or ND; for transposed
 * convolution a non-ND bias only triggers a warning. The x/filter/y formats are then checked
 * per convolution rank (1d/2d/3d), dispatched on the rank of x. Any other rank is rejected.
 */
class FormatChecker : public ConvolutionChecker {
private:
    // 1d: x, filter and y must all be NCL.
    aclnnStatus CheckConv1d(ConvEngine& engine)
    {
        auto inputFormat = engine.meta.input.format;
        auto weightFormat = engine.meta.weight.format;
        auto outputFormat = engine.meta.output.format;
        if (!All(Format::FORMAT_NCL, Equal<op::Format>, inputFormat, weightFormat, outputFormat)) {
            OP_LOGE_FOR_INVALID_FORMATS_WITH_REASON(engine.entityName, "x, filter, y",
                GeFormatToString(inputFormat) + ", " + GeFormatToString(weightFormat) + ", " +
                GeFormatToString(outputFormat), "formats of x, filter and y should all be " +
                GeFormatToString(Format::FORMAT_NCL));
            return ACLNN_ERR_PARAM_INVALID;
        }
        return ACLNN_SUCCESS;
    }

    // 2d: x and y must be NCHW in both modes. Non-transposed: filter must be NCHW when
    // op::IsSupportND() is true, otherwise NCHW or FRACTAL_Z. Transposed: filter must be NCHW.
    aclnnStatus CheckConv2d(ConvEngine& engine)
    {
        auto inputFormat = engine.meta.input.format;
        auto weightFormat = engine.meta.weight.format;
        auto outputFormat = engine.meta.output.format;
        OP_LOGD("conv2d transpose: [%d]", engine.params.transposed);
        if (!engine.params.transposed) {
            if (op::IsSupportND()) {
                if (!Any(weightFormat, Equal<op::Format>, Format::FORMAT_NCHW)) {
                    OP_LOGE_FOR_INVALID_FORMAT(engine.entityName, "filter", GeFormatToString(weightFormat),
                        GeFormatToString(Format::FORMAT_NCHW));
                    return ACLNN_ERR_PARAM_INVALID;
                }
            } else {
                if (!Any(weightFormat, Equal<op::Format>, Format::FORMAT_NCHW, Format::FORMAT_FRACTAL_Z)) {
                    OP_LOGE_FOR_INVALID_FORMAT(engine.entityName, "filter", GeFormatToString(weightFormat),
                        GeFormatToString(Format::FORMAT_NCHW) + " or " +
                        GeFormatToString(Format::FORMAT_FRACTAL_Z));
                    return ACLNN_ERR_PARAM_INVALID;
                }
            }
            if (!Any(inputFormat, Equal<op::Format>, Format::FORMAT_NCHW)) {
                OP_LOGE_FOR_INVALID_FORMAT(engine.entityName, "x", GeFormatToString(inputFormat),
                    GeFormatToString(Format::FORMAT_NCHW));
                return ACLNN_ERR_PARAM_INVALID;
            }
            if (!Any(outputFormat, Equal<op::Format>, Format::FORMAT_NCHW)) {
                OP_LOGE_FOR_INVALID_FORMAT(engine.entityName, "y", GeFormatToString(outputFormat),
                    GeFormatToString(Format::FORMAT_NCHW));
                return ACLNN_ERR_PARAM_INVALID;
            }
        } else {
            if (!Any(inputFormat, Equal<op::Format>, Format::FORMAT_NCHW)) {
                OP_LOGE_FOR_INVALID_FORMAT(engine.entityName, "x", GeFormatToString(inputFormat),
                    GeFormatToString(Format::FORMAT_NCHW));
                return ACLNN_ERR_PARAM_INVALID;
            }
            if (!Any(weightFormat, Equal<op::Format>, Format::FORMAT_NCHW)) {
                OP_LOGE_FOR_INVALID_FORMAT(engine.entityName, "filter", GeFormatToString(weightFormat),
                    GeFormatToString(Format::FORMAT_NCHW));
                return ACLNN_ERR_PARAM_INVALID;
            }
            if (!Any(outputFormat, Equal<op::Format>, Format::FORMAT_NCHW)) {
                OP_LOGE_FOR_INVALID_FORMAT(engine.entityName, "y", GeFormatToString(outputFormat),
                    GeFormatToString(Format::FORMAT_NCHW));
                return ACLNN_ERR_PARAM_INVALID;
            }
        }
        return ACLNN_SUCCESS;
    }

    // 3d: when op::IsSupportND() is true, x, filter and y must all be NCDHW; otherwise each may be
    // NCDHW or NDHWC.
    aclnnStatus CheckConv3d(ConvEngine& engine)
    {
        auto inputFormat = engine.meta.input.format;
        auto weightFormat = engine.meta.weight.format;
        auto outputFormat = engine.meta.output.format;
        if (op::IsSupportND()) {
            if (!Any(inputFormat, Equal<op::Format>, Format::FORMAT_NCDHW)) {
                OP_LOGE_FOR_INVALID_FORMAT(engine.entityName, "x", GeFormatToString(inputFormat),
                    GeFormatToString(Format::FORMAT_NCDHW));
                return ACLNN_ERR_PARAM_INVALID;
            }
            if (!Any(weightFormat, Equal<op::Format>, Format::FORMAT_NCDHW)) {
                OP_LOGE_FOR_INVALID_FORMAT(engine.entityName, "filter", GeFormatToString(weightFormat),
                    GeFormatToString(Format::FORMAT_NCDHW));
                return ACLNN_ERR_PARAM_INVALID;
            }
            if (!Any(outputFormat, Equal<op::Format>, Format::FORMAT_NCDHW)) {
                OP_LOGE_FOR_INVALID_FORMAT(engine.entityName, "y", GeFormatToString(outputFormat),
                    GeFormatToString(Format::FORMAT_NCDHW));
                return ACLNN_ERR_PARAM_INVALID;
            }
        } else {
            if (!Any(inputFormat, Equal<op::Format>, Format::FORMAT_NCDHW, Format::FORMAT_NDHWC)) {
                OP_LOGE_FOR_INVALID_FORMAT(engine.entityName, "x", GeFormatToString(inputFormat),
                    GeFormatToString(Format::FORMAT_NCDHW) + " or " +
                    GeFormatToString(Format::FORMAT_NDHWC));
                return ACLNN_ERR_PARAM_INVALID;
            }
            if (!Any(weightFormat, Equal<op::Format>, Format::FORMAT_NCDHW, Format::FORMAT_NDHWC)) {
                OP_LOGE_FOR_INVALID_FORMAT(engine.entityName, "filter", GeFormatToString(weightFormat),
                    GeFormatToString(Format::FORMAT_NCDHW) + " or " +
                    GeFormatToString(Format::FORMAT_NDHWC));
                return ACLNN_ERR_PARAM_INVALID;
            }
            if (!Any(outputFormat, Equal<op::Format>, Format::FORMAT_NCDHW, Format::FORMAT_NDHWC)) {
                OP_LOGE_FOR_INVALID_FORMAT(engine.entityName, "y", GeFormatToString(outputFormat),
                    GeFormatToString(Format::FORMAT_NCDHW) + " or " +
                    GeFormatToString(Format::FORMAT_NDHWC));
                return ACLNN_ERR_PARAM_INVALID;
            }
        }
        return ACLNN_SUCCESS;
    }
public:
    FormatChecker() = default;
    ~FormatChecker() override = default;
    aclnnStatus Check(ConvEngine& engine) override
    {
        auto inputFormat = engine.meta.input.format;
        auto outputFormat = engine.meta.output.format;
        if (inputFormat != outputFormat) {
            OP_LOGE_FOR_INVALID_FORMATS_WITH_REASON(engine.entityName, "x, y",
                GeFormatToString(inputFormat) + ", " + GeFormatToString(outputFormat),
                "the formats of these parameters must be the same");
            return ACLNN_ERR_PARAM_INVALID;
        }
        if (engine.params.bias != nullptr) {
            auto biasFormat = engine.meta.bias.format;
            if (!Any(biasFormat, Equal<op::Format>, Format::FORMAT_NCL, Format::FORMAT_NCHW, Format::FORMAT_NCDHW,
                    Format::FORMAT_ND)) {
                std::string correctFormat = "one of {" + GeFormatToString(Format::FORMAT_NCL) + ", " +
                    GeFormatToString(Format::FORMAT_NCHW) + ", " + GeFormatToString(Format::FORMAT_NCDHW) + ", " +
                    GeFormatToString(Format::FORMAT_ND) + "}";
                OP_LOGE_FOR_INVALID_FORMAT(engine.entityName, "bias", GeFormatToString(biasFormat), correctFormat);
                return ACLNN_ERR_PARAM_INVALID;
            }

            if (engine.params.transposed && biasFormat != Format::FORMAT_ND) {
                // GeFormatToString yields a string object (it is concatenated with '+' throughout this
                // file). A printf-style "%s" needs a C string, and passing a class object through
                // varargs is undefined behavior, so pass .c_str().
                OP_LOGW(
                    "Please set bias format to %s, other formats may cause precision issues.",
                    GeFormatToString(Format::FORMAT_ND).c_str());
            }
        }
        size_t inputDimNum = engine.meta.input.shape.size();
        switch (inputDimNum) {
            case CONV_1D_DIM_SIZE: {
                return CheckConv1d(engine);
            }
            case CONV_2D_DIM_SIZE: {
                return CheckConv2d(engine);
            }
            case CONV_3D_DIM_SIZE: {
                return CheckConv3d(engine);
            }
            default:
                OP_LOGE_FOR_INVALID_SHAPEDIM(engine.entityName, "x", std::to_string(inputDimNum), "one of {" +
                    std::to_string(CONV_1D_DIM_SIZE) + ", " + std::to_string(CONV_2D_DIM_SIZE) + ", " +
                    std::to_string(CONV_3D_DIM_SIZE) + "}");
                return ACLNN_ERR_PARAM_INVALID;
        }
    }
};

/**
 * Format checker for the conv_tbc entry: x, filter and y must each be ND or NCL.
 * The first offending tensor (checked in x, filter, y order) is logged and
 * ACLNN_ERR_PARAM_INVALID is returned.
 */
class FormatCheckerTbc : public ConvolutionChecker {
public:
    FormatCheckerTbc() = default;
    ~FormatCheckerTbc() override = default;

    aclnnStatus Check(ConvEngine& engine) override
    {
        // conv_tbc accepts only the ND and NCL layouts for all three tensors.
        const auto isNdOrNcl = [](op::Format fmt) {
            return Any(fmt, Equal<op::Format>, Format::FORMAT_ND, Format::FORMAT_NCL);
        };

        const auto xFormat = engine.meta.input.format;
        if (!isNdOrNcl(xFormat)) {
            OP_LOGE_FOR_INVALID_FORMAT(engine.entityName, "x", GeFormatToString(xFormat),
                GeFormatToString(Format::FORMAT_ND) + " or " + GeFormatToString(Format::FORMAT_NCL));
            return ACLNN_ERR_PARAM_INVALID;
        }

        const auto filterFormat = engine.meta.weight.format;
        if (!isNdOrNcl(filterFormat)) {
            OP_LOGE_FOR_INVALID_FORMAT(engine.entityName, "filter", GeFormatToString(filterFormat),
                GeFormatToString(Format::FORMAT_ND) + " or " + GeFormatToString(Format::FORMAT_NCL));
            return ACLNN_ERR_PARAM_INVALID;
        }

        const auto yFormat = engine.meta.output.format;
        if (!isNdOrNcl(yFormat)) {
            OP_LOGE_FOR_INVALID_FORMAT(engine.entityName, "y", GeFormatToString(yFormat),
                GeFormatToString(Format::FORMAT_ND) + " or " + GeFormatToString(Format::FORMAT_NCL));
            return ACLNN_ERR_PARAM_INVALID;
        }

        return ACLNN_SUCCESS;
    }
};

/**
 * Format checker for depthwise 2d convolution: x, filter and y must all be NCHW.
 * The first offending tensor (checked in x, filter, y order) is logged and
 * ACLNN_ERR_PARAM_INVALID is returned.
 */
class FormatCheckerDepthwise2d : public ConvolutionChecker {
public:
    FormatCheckerDepthwise2d() = default;
    ~FormatCheckerDepthwise2d() override = default;
    aclnnStatus Check(ConvEngine& engine) override
    {
        auto inputFormat = engine.meta.input.format;
        auto weightFormat = engine.meta.weight.format;
        auto outputFormat = engine.meta.output.format;

        if (!Any(inputFormat, Equal<op::Format>, Format::FORMAT_NCHW)) {
            OP_LOGE_FOR_INVALID_FORMAT(engine.entityName, "x", GeFormatToString(inputFormat),
                GeFormatToString(Format::FORMAT_NCHW));
            return ACLNN_ERR_PARAM_INVALID;
        }
        if (!Any(weightFormat, Equal<op::Format>, Format::FORMAT_NCHW)) {
            OP_LOGE_FOR_INVALID_FORMAT(engine.entityName, "filter", GeFormatToString(weightFormat),
                GeFormatToString(Format::FORMAT_NCHW));
            return ACLNN_ERR_PARAM_INVALID;
        }
        if (!Any(outputFormat, Equal<op::Format>, Format::FORMAT_NCHW)) {
            OP_LOGE_FOR_INVALID_FORMAT(engine.entityName, "y", GeFormatToString(outputFormat),
                GeFormatToString(Format::FORMAT_NCHW));
            return ACLNN_ERR_PARAM_INVALID;
        }

        // x and y are both NCHW at this point, so the requirement that their formats match is
        // already satisfied; a separate equality check could never fail. If more formats are ever
        // allowed above, reinstate an explicit inputFormat == outputFormat check.
        return ACLNN_SUCCESS;
    }
};

/**
 * Extracts spatial-dimension layout information from a tensor's metadata.
 *
 * @param tensor        source tensor metadata
 * @param channelLast   out: tensor.ChannelLast()
 * @param spaceDimIndex out: index of the first spatial axis in the shape
 *                      (1 when the C axis is last, otherwise 2)
 * @param spaceDimNum   out: number of spatial axes, i.e. rank - 2 (1 for 1d, 2 for 2d, 3 for 3d
 *                      convolution); 0 for tensors of rank < 2
 * @param shape         out: copy of tensor.shape
 */
inline void GetSpatialDimInfo(
    const TensorMeta& tensor, bool& channelLast, int64_t& spaceDimIndex, size_t& spaceDimNum, FVector<int64_t>& shape)
{
    // The two non-spatial axes are the batch and channel dimensions.
    constexpr size_t nonSpatialDimNum = 2;
    shape = tensor.shape;
    channelLast = tensor.ChannelLast();
    spaceDimIndex = channelLast ? 1 : 2;
    // Guard the unsigned subtraction: a rank < 2 tensor would otherwise wrap to a huge count.
    spaceDimNum = (shape.size() > nonSpatialDimNum) ? (shape.size() - nonSpatialDimNum) : 0;
}

/**
 * Checks that the N and C extents of both x (input) and filter (weight) are non-negative.
 *
 * @param entityName name used as the log entity
 * @param input      metadata of x
 * @param weight     metadata of filter
 * @return ACLNN_SUCCESS if all four extents are >= 0, otherwise logs and returns
 *         ACLNN_ERR_PARAM_INVALID.
 */
inline aclnnStatus CheckShapeByDim(const std::string& entityName, TensorMeta& input, TensorMeta& weight)
{
    const int64_t xBatch = input.N();
    const int64_t xChannel = input.C();
    const int64_t filterBatch = weight.N();
    const int64_t filterChannel = weight.C();

    const bool allNonNegative = All(0L, LessEqual<int64_t>, xBatch, xChannel, filterBatch, filterChannel);
    if (allNonNegative) {
        return ACLNN_SUCCESS;
    }

    const std::string reason = "Shape[" + std::to_string(input.NIdx()) + "], Shape[" +
        std::to_string(input.CIdx()) + "] of x and Shape[" + std::to_string(weight.NIdx()) + "], Shape[" +
        std::to_string(weight.CIdx()) + "] of filter must be >= 0";
    OP_LOGE_FOR_INVALID_SHAPES_WITH_REASON(entityName, "x, filter",
        op::FVectorToString(input.shape) + ", " + op::FVectorToString(weight.shape), reason);
    return ACLNN_ERR_PARAM_INVALID;
}

class ValueChecker : public ConvolutionChecker {
public:
    ValueChecker() = default;
    ~ValueChecker() override = default;

    aclnnStatus Check(ConvEngine& engine) override
    {
        if (CheckShape(engine.entityName, engine.meta.input, engine.meta.weight,
            engine.params.transposed) != ACLNN_SUCCESS) {
            return ACLNN_ERR_PARAM_INVALID;
        }
        // check stride
        if (CheckVectorValueGt0(engine.entityName, "strides", engine.meta.stride) != ACLNN_SUCCESS) {
            return ACLNN_ERR_PARAM_INVALID;
        }
        // check dilation
        if (CheckVectorValueGt0(engine.entityName, "dilations", engine.meta.dilation) != ACLNN_SUCCESS) {
            return ACLNN_ERR_PARAM_INVALID;
        }
        // check pad
        if (CheckPad(engine.entityName, engine.meta.input, engine.meta.weight, engine.meta.dilation,
                engine.meta.padding, engine.params.transposed) != ACLNN_SUCCESS) {
            return ACLNN_ERR_PARAM_INVALID;
        }
        // check channel_value (bias, groups)
        if (engine.params.groups <= 0) {
            OP_LOGE_FOR_INVALID_VALUE_WITH_REASON(engine.entityName, "groups", std::to_string(engine.params.groups),
                "the value should be greater than 0");
            return ACLNN_ERR_PARAM_INVALID;
        }

        // check channel and groups
        int64_t outChannel = -1L;
        if (CheckChannelAndGroups(engine, outChannel) != ACLNN_SUCCESS) {
            return ACLNN_ERR_PARAM_INVALID;
        }

        if (engine.params.bias != nullptr) {
            if (CheckConvBias(engine.entityName, engine.meta.bias, engine.meta.input, outChannel) != ACLNN_SUCCESS) {
                return ACLNN_ERR_PARAM_INVALID;
            }
        }

        if (ACLNN_SUCCESS != CheckEmptyTensor(engine.entityName, engine.meta.input, engine.params.transposed)) {
            return ACLNN_ERR_PARAM_INVALID;
        }

        // transposed=true时,空tensor判断
        if (engine.params.transposed && ACLNN_SUCCESS != CheckEmptyTensorTransposed(engine)) {
            return ACLNN_ERR_PARAM_INVALID;
        }

        // 针对 2d transpose error msg is: backprop pad value invalid 提前拦截
        if (GetCurrentPlatformInfo().GetCurNpuArch() != NpuArch::DAV_3510 && !(padBinaryValid(engine))) {
            return ACLNN_ERR_PARAM_INVALID;
        }

        return ACLNN_SUCCESS;
    };

private:
    /** 空tensor判断逻辑
     * input:
     * 在ValueChecker时,保证加上pad后,空间维度也大于0
     * 此处校验针对transpose的情况,仅支持输入的n为0,因此仅需要校验C维度是否为0
     * weight: Cout和K不为0,在ValueChecker已完成校验
     */
    static aclnnStatus CheckEmptyTensor(const std::string& entityName, const TensorMeta& input, bool transposed)
    {
        if (transposed && input.C() == 0) {
            OP_LOGE_FOR_INVALID_SHAPE_WITH_REASON(entityName, "x", op::FVectorToString(input.shape),
                "When transposed is true, shape[" + std::to_string(input.CIdx()) +
                "] of this parameter must be greater than 0");
            return ACLNN_ERR_PARAM_INVALID;
        }
        return ACLNN_SUCCESS;
    }

    static aclnnStatus CheckTransposedFormat(ConvEngine& engine)
    {
        if (GetCurrentPlatformInfo().GetCurNpuArch() != NpuArch::DAV_3510) {
            return ACLNN_SUCCESS;
        }
        auto inputFormat = engine.meta.input.format;
        auto weightFormat = engine.meta.weight.format;
        auto outputFormat = engine.meta.output.format;
        if (!Any(inputFormat, Equal<op::Format>, Format::FORMAT_NCHW, Format::FORMAT_NCL, Format::FORMAT_NCDHW)) {
            OP_LOGE_FOR_INVALID_FORMAT(engine.entityName, "x", GeFormatToString(inputFormat),
                GeFormatToString(Format::FORMAT_NCHW) + " or " + GeFormatToString(Format::FORMAT_NCL) +
                " or " + GeFormatToString(Format::FORMAT_NCDHW));
            return ACLNN_ERR_PARAM_INVALID;
        }
        if (!Any(weightFormat, Equal<op::Format>, Format::FORMAT_NCHW, Format::FORMAT_NCL, Format::FORMAT_NCDHW)) {
            OP_LOGE_FOR_INVALID_FORMAT(engine.entityName, "filter", GeFormatToString(weightFormat),
                GeFormatToString(Format::FORMAT_NCHW) + " or " + GeFormatToString(Format::FORMAT_NCL) + " or " +
                GeFormatToString(Format::FORMAT_NCDHW));
            return ACLNN_ERR_PARAM_INVALID;
        }
        if (!Any(outputFormat, Equal<op::Format>, Format::FORMAT_NCHW, Format::FORMAT_NCL, Format::FORMAT_NCDHW)) {
            OP_LOGE_FOR_INVALID_FORMAT(engine.entityName, "y", GeFormatToString(outputFormat),
                GeFormatToString(Format::FORMAT_NCHW) + " or " + GeFormatToString(Format::FORMAT_NCL) + " or " +
                GeFormatToString(Format::FORMAT_NCDHW));
            return ACLNN_ERR_PARAM_INVALID;
        }
        if (inputFormat != outputFormat || inputFormat != weightFormat || weightFormat != outputFormat) {
            OP_LOGE_FOR_INVALID_FORMATS_WITH_REASON(engine.entityName, "x, filter, y",
                GeFormatToString(inputFormat) + ", " + GeFormatToString(weightFormat) + ", " +
                GeFormatToString(outputFormat), "the formats of these parameters must be the same");
            return ACLNN_ERR_PARAM_INVALID;
        }
        return ACLNN_SUCCESS;
    }

    static aclnnStatus CheckEmptyTensorTransposed(ConvEngine& engine)
    {
        if (GetCurrentPlatformInfo().GetCurNpuArch() != NpuArch::DAV_3510) {
            return ACLNN_SUCCESS;
        }

        if (CheckTransposedFormat(engine) != ACLNN_SUCCESS) {
            return ACLNN_ERR_PARAM_INVALID;
        }

        FVector<int64_t> inputShape = engine.meta.input.shape;
        FVector<int64_t> weightShape = engine.meta.weight.shape;
        FVector<int64_t> outputShape = engine.meta.output.shape;
        // NCL,NCHW,NCDHW
        if (weightShape[engine.meta.weight.NIdx()] <= 0) {
            OP_LOGE_FOR_INVALID_SHAPE_WITH_REASON(engine.entityName, "filter",
                op::FVectorToString(engine.meta.weight.shape), "When transposed is true, shape[" +
                std::to_string(engine.meta.weight.NIdx()) + "] of this parameter must be greater than 0");
            return ACLNN_ERR_PARAM_INVALID;
        }
        for (size_t i = 1; i < inputShape.size(); ++i) {
            // input仅可以N等于0
            if (inputShape[i] <= 0) {
                OP_LOGE_FOR_INVALID_SHAPE_WITH_REASON(engine.entityName, "x",
                    op::FVectorToString(engine.meta.input.shape), "shape[" + std::to_string(i) +
                    "] of this parameter must be greater than 0");
                return ACLNN_ERR_PARAM_INVALID;
            }
            if (weightShape[i] < 0) {
                OP_LOGE_FOR_INVALID_SHAPE_WITH_REASON(engine.entityName, "filter",
                    op::FVectorToString(engine.meta.weight.shape), "shape[" + std::to_string(i) +
                    "] of this parameter cannot be a negative value");
                return ACLNN_ERR_PARAM_INVALID;
            }
            // weight: Cin,D,H,W可以为0,仅当output对应维度为0
            if (weightShape[i] == 0 && outputShape[i] != 0) {
                OP_LOGE_FOR_INVALID_SHAPE_WITH_REASON(engine.entityName, "filter",
                    op::FVectorToString(engine.meta.weight.shape), "shape[" + std::to_string(i) +
                    "] of filter can be 0 only when shape[" + std::to_string(i) + "] of y is 0");
                return ACLNN_ERR_PARAM_INVALID;
            }
        }
        return ACLNN_SUCCESS;
    }

    // 针对 卷积2d transpose: error msg is: backprop pad value invalid 提前拦截
    // Checks the "reverse pad" of a transposed convolution against the binary-kernel limit.
    // As computed below, the reverse pad of a spatial dim is: (filterSize - 1) * dilation - pad.
    // NOTE(review): the original note here read
    //   (weightW - 1) * dilationW + 1 - padLeft/Right <= 255, (weightH - 1) * dilationH + 1 - padUp/down
    // i.e. it contains a "+ 1" term that the code below does not apply — confirm which is intended.
    // Parameters:
    //   engine - convolution engine; its weight shape, dilation and padding are read.
    //   ss     - receives a human-readable description of the enforced condition. It is written for
    //            every handled weight rank, whether or not the check passes.
    // Returns true when every checked reverse pad is within the limit. Returns false (and leaves ss
    // empty) when the weight rank is none of CONV_1D_DIM_SIZE / CONV_2D_DIM_SIZE / CONV_3D_DIM_SIZE.
    static bool GetPadBinaryValidFlag(ConvEngine& engine, std::stringstream& ss)
    {
        // 255 is the current threshold: the binary kernels do not support a reverse pad above 255.
        constexpr int64_t padBinValue = 255;
        int64_t dilationH = engine.meta.dilation[0];
        // When dilation/padding holds a single element, that element is used for every checked dim.
        size_t dilationWIdx = (engine.meta.dilation.size() == 1) ? 0 : 1;
        int64_t dilationW = engine.meta.dilation[dilationWIdx];
        int64_t padTop = engine.meta.padding[0];
        size_t padLeftIdx = (engine.meta.padding.size() == 1) ? 0 : 1;
        int64_t padLeft = engine.meta.padding[padLeftIdx];
        auto weightShape = engine.meta.weight.tensorShape;
        int64_t weightW = engine.meta.weight.W();
        int64_t weightH = engine.meta.weight.H();
        int64_t weightL = engine.meta.weight.L();
        bool padValueValid = false;
        if (weightShape.GetDimNum() == CONV_1D_DIM_SIZE) {
            // 1D: only the L extent is checked, using dilation[dilationWIdx] and padding[padLeftIdx].
            padValueValid = ((weightL - 1) * dilationW - padLeft) <= padBinValue;
            ss << "These parameters must meet the following condition: "
               << "(filter[" << std::to_string(engine.meta.weight.LIdx()) << "] - 1) * dilations["
               << std::to_string(dilationWIdx) << "] - pads[" << std::to_string(padLeftIdx)
               << "] <= " << std::to_string(padBinValue);
        } else if (weightShape.GetDimNum() == CONV_2D_DIM_SIZE) { // 4-D weight: NCHW / NHWC layout
            // 2D: both W and H are checked; only the upper bound is enforced.
            padValueValid = (((weightW - 1) * dilationW - padLeft) <= padBinValue) &&
                            (((weightH - 1) * dilationH - padTop) <= padBinValue);
            ss << "These parameters must meet the following condition: "
               << "(filter[" << std::to_string(engine.meta.weight.WIdx()) << "] - 1) * dilations["
               << std::to_string(dilationWIdx) << "] - pads[" << std::to_string(padLeftIdx)
               << "] <= " << std::to_string(padBinValue) << " and "
               << "(filter[" << std::to_string(engine.meta.weight.HIdx()) << "] - 1) * dilations[0] - pads[0] <= "
               << std::to_string(padBinValue);
        } else if (weightShape.GetDimNum() == CONV_3D_DIM_SIZE) {
            // 3D: only H and W are checked (the D extent is not), and unlike 1D/2D the reverse pad
            // must also be non-negative.
            size_t dilationLastIdx = (engine.meta.dilation.size() == 1) ? 0 : 2;
            int64_t dilationLast = engine.meta.dilation[dilationLastIdx];
            size_t padRightIdx = (engine.meta.padding.size() == 1) ? 0 : 2;
            int64_t padRight = engine.meta.padding[padRightIdx];
            // In this branch dilationW / padLeft (index 1) actually hold the H-dimension dilation/pad,
            // while dilationLast / padRight (index 2) hold the W-dimension values.
            padValueValid = (((weightW - 1) * dilationLast - padRight) <= padBinValue) &&
                            (((weightW - 1) * dilationLast - padRight) >= 0) &&
                            (((weightH - 1) * dilationW - padLeft) <= padBinValue) &&
                            (((weightH - 1) * dilationW - padLeft) >= 0);
            ss << "These parameters must meet the following condition: "
               << "0 <= (filter[" << std::to_string(engine.meta.weight.WIdx()) << "] - 1) * dilations["
               << std::to_string(dilationLastIdx) << "] - pads[" << std::to_string(padRightIdx)
               << "] <= " << std::to_string(padBinValue) << " and "
               << "0 <= (filter[" << std::to_string(engine.meta.weight.HIdx()) << "] - 1) * dilations["
               << std::to_string(dilationWIdx) << "] - pads[" << std::to_string(padLeftIdx)
               << "] <= " << std::to_string(padBinValue);
        }
        return padValueValid;
    }

    // Enforces the binary-kernel reverse-pad limit (see GetPadBinaryValidFlag) for transposed
    // convolution. Non-transposed convolution, and any empty x / filter / y tensor, skip the check.
    // Returns true when the check is skipped or passes. Otherwise it logs an error naming
    // filter/dilations/pads and returns false.
    static bool padBinaryValid(ConvEngine& engine)
    {
        const bool needCheck = engine.params.transposed && !engine.meta.input.IsZeroTensor() &&
                               !engine.meta.weight.IsZeroTensor() && !engine.meta.output.IsZeroTensor();
        if (!needCheck) {
            return true;
        }

        std::stringstream reason;
        const bool withinLimit = GetPadBinaryValidFlag(engine, reason);
        if (!withinLimit) {
            const std::string offendingValues = op::FVectorToString(engine.meta.weight.shape) + ", " +
                                                op::FVectorToString(engine.meta.dilation) + ", " +
                                                op::FVectorToString(engine.meta.padding);
            OP_LOGE_FOR_INVALID_SHAPES_WITH_REASON(engine.entityName, "filter, dilations, pads",
                offendingValues, reason.str());
        }
        return withinLimit;
    }

    // Requires every element of `param` to be strictly positive.
    // On the first non-positive element it logs an error that names that index and returns
    // ACLNN_ERR_PARAM_INVALID. An empty vector, or one with all elements > 0, yields ACLNN_SUCCESS.
    static inline aclnnStatus CheckVectorValueGt0(const std::string& entityName, const std::string& paramName,
        FVector<int64_t>& param)
    {
        size_t idx = 0;
        while (idx < param.size()) {
            if (param[idx] > 0) {
                ++idx;
                continue;
            }
            std::string reason = "shape[" + std::to_string(idx) + "] of this parameter must be greater than 0";
            OP_LOGE_FOR_INVALID_SHAPE_WITH_REASON(entityName, paramName, op::FVectorToString(param), reason);
            return ACLNN_ERR_PARAM_INVALID;
        }
        return ACLNN_SUCCESS;
    }

    // Verifies that every spatial dimension of x stays non-negative after padding and after
    // subtracting the dilated filter extent, i.e. that the filter window fits at least once:
    //   1D/2D x : x[s] + newpad[i]   - dilation[i] * (filter[s'] - 1) - 1 >= 0,
    //             where newpad = ConstructPad(padding, inputShape)
    //   other   : x[s] + padding[i]*2 - dilation[i] * (filter[s'] - 1) - 1 >= 0
    // (s / s' are the positions of spatial dim i in x / filter.)
    // Only non-transposed convolution with a non-empty x is checked; everything else passes.
    // Returns ACLNN_SUCCESS, or ACLNN_ERR_PARAM_INVALID after logging the violated condition.
    static aclnnStatus CheckPad(
        const std::string& entityName, TensorMeta& input, TensorMeta& weight, FVector<int64_t>& dilation,
        FVector<int64_t>& padding, bool transposed)
    {
        // Both conditions are invariant across spatial dims. When either holds, no dim is checked,
        // so return before extracting spatial info, constructing the pad, or indexing padding/dilation.
        if (transposed || input.IsZeroTensor()) {
            return ACLNN_SUCCESS;
        }

        FVector<int64_t> inputShape, weightShape;
        bool inputChannelLast, weightChannelLast;
        int64_t inputSpaceDimIndex, weightSpaceDimIndex;
        size_t inputSpaceDimNum, weightSpaceDimNum;

        GetSpatialDimInfo(input, inputChannelLast, inputSpaceDimIndex, inputSpaceDimNum, inputShape);
        GetSpatialDimInfo(weight, weightChannelLast, weightSpaceDimIndex, weightSpaceDimNum, weightShape);

        // 1D/2D inputs use the pad returned by ConstructPad, added once in the formula.
        // Other ranks use the raw padding value, counted twice.
        const bool useConstructedPad =
            (input.shape.size() == CONV_1D_DIM_SIZE || input.shape.size() == CONV_2D_DIM_SIZE);
        auto newpad = ConstructPad(padding, inputShape);
        for (size_t i = 0; i < inputSpaceDimNum; ++i) {
            auto inputShapeValue = inputShape[i + inputSpaceDimIndex];
            auto weightShapeValue = weightShape[i + weightSpaceDimIndex];
            auto dilationValue = dilation[i];

            int64_t inputShapeValueAfterPad = -1;
            std::string formula;
            if (useConstructedPad) {
                inputShapeValueAfterPad =
                    (inputShapeValue + newpad[i] - dilationValue * (weightShapeValue - 1L) - 1L);
                formula = "x[" + std::to_string(i + inputSpaceDimIndex) +"] + pads[" + std::to_string(i) +
                    "] - dilations[" + std::to_string(i) + "] * (filter[" +
                    std::to_string(i + weightSpaceDimIndex) + "] - 1) - 1 >= 0";
            } else {
                inputShapeValueAfterPad =
                    (inputShapeValue + padding[i] * CONST_VALUE_TWO -
                     dilationValue * (weightShapeValue - 1L) - 1L);
                formula = "x[" + std::to_string(i + inputSpaceDimIndex) +"] + pads[" + std::to_string(i) +
                    "] * 2 - dilations[" + std::to_string(i) + "] * (filter[" +
                    std::to_string(i + weightSpaceDimIndex) + "] - 1) - 1 >= 0";
            }

            if (inputShapeValueAfterPad < 0) {
                OP_LOGE_FOR_INVALID_SHAPE_WITH_REASON(entityName, "x", op::FVectorToString(input.shape),
                    "These parameters must meet the following condition: " + formula);
                return ACLNN_ERR_PARAM_INVALID;
            }
        }

        return ACLNN_SUCCESS;
    }

    // Validates the bias shape against the convolution output channel count.
    // The channel axis of bias is index 0 for a 1-D bias. Otherwise it is the position of 'C' in
    // the name of x's format. That axis must equal outChannel, and every other axis must be 1.
    // NOTE(review): if the format name contains no 'C', find() yields npos. In that case no axis
    // is compared against outChannel and only the "== 1" rule applies — confirm this cannot happen.
    static aclnnStatus CheckConvBias(const std::string& entityName, TensorMeta& bias,
        TensorMeta& input, int64_t outChannel)
    {
        const auto& biasShape = bias.shape;
        const size_t biasDimNum = biasShape.size();

        size_t channelAxis = 0;
        if (biasDimNum != static_cast<size_t>(1)) {
            const std::string formatName(op::ToString(input.format).GetString());
            channelAxis = formatName.find('C');
        }

        for (size_t dim = 0; dim < biasDimNum; ++dim) {
            const bool isChannelAxis = (dim == channelAxis);
            if (isChannelAxis && biasShape[dim] != outChannel) {
                OP_LOGE_FOR_INVALID_SHAPES_WITH_REASON(entityName, "bias",
                    op::FVectorToString(bias.shape),
                    "shape[" + std::to_string(channelAxis) + "] of bias must be equal to output channel(" +
                    std::to_string(outChannel) + ")");
                return ACLNN_ERR_PARAM_INVALID;
            }
            if (!isChannelAxis && biasShape[dim] != 1) {
                OP_LOGE_FOR_INVALID_SHAPE_WITH_REASON(entityName, "bias", op::FVectorToString(bias.shape),
                    "shape[" + std::to_string(dim) + "] of bias must be equal to 1");
                return ACLNN_ERR_PARAM_INVALID;
            }
        }

        return ACLNN_SUCCESS;
    }

    // Cross-checks x's channel count, filter's N/C and the groups attribute, and writes the derived
    // output channel count into `outChannel`.
    //   transposed     : filter.N must equal x.C; outChannel = filter.C * groups. Each
    //                    outputPadding[i] must be smaller than stride[i] or dilation[i] (at least
    //                    one of them) and must be >= 0.
    //   not transposed : filter.C * groups must equal x.C; outChannel = filter.N.
    // In both cases groups must be > 0 and filter.N must be divisible by groups.
    // outChannel is assigned before the checks run, so it is set even when an error is returned.
    // Returns ACLNN_SUCCESS, or ACLNN_ERR_PARAM_INVALID after logging.
    static aclnnStatus CheckChannelAndGroups(ConvEngine& engine, int64_t& outChannel)
    {
        int64_t inChannel = engine.meta.input.C();
        if (engine.params.transposed) {
            outChannel = engine.meta.weight.C() * engine.params.groups;
            if (engine.meta.weight.N() != inChannel) {
                OP_LOGE_FOR_INVALID_SHAPES_WITH_REASON(engine.entityName, "x, filter",
                    op::FVectorToString(engine.meta.input.shape) + ", " +
                    op::FVectorToString(engine.meta.weight.shape),
                    "When transposed is true, shape[" + std::to_string(engine.meta.weight.NIdx()) +
                    "] of filter must be equal to shape[" + std::to_string(engine.meta.input.CIdx()) + "] of x");
                return ACLNN_ERR_PARAM_INVALID;
            }
            // outputPadding value check: negative values are not supported.
            for (size_t i = 0; i < engine.meta.outputPadding.size(); i++) {
                auto outputPaddingValue = engine.meta.outputPadding[i];
                // Rejected only when it reaches BOTH stride[i] and dilation[i]; being smaller than
                // either one of them is sufficient, and the message states that rule.
                if (outputPaddingValue >= engine.meta.stride[i] && outputPaddingValue >= engine.meta.dilation[i]) {
                    OP_LOGE_FOR_INVALID_VALUES_WITH_REASON(engine.entityName, "outputPadding, strides, dilations",
                        op::FVectorToString(engine.meta.outputPadding) + ", " + op::FVectorToString(engine.meta.stride) + ", " +
                        op::FVectorToString(engine.meta.dilation), "outputPadding[" + std::to_string(i) +
                        "] must be less than either stride[" + std::to_string(i) + "] or dilation[" +
                        std::to_string(i) + "]");
                    return ACLNN_ERR_PARAM_INVALID;
                }

                OP_CHECK(
                    outputPaddingValue >= 0,
                    OP_LOGE_FOR_INVALID_VALUE_WITH_REASON(engine.entityName, "outputPadding",
                        std::to_string(outputPaddingValue),
                        "outputPadding[" + std::to_string(i) + "] must be greater than or equal to 0"),
                    return ACLNN_ERR_PARAM_INVALID);
            }
        } else {
            outChannel = engine.meta.weight.N();
            if (engine.meta.weight.C() * engine.params.groups != inChannel) {
                OP_LOGE_FOR_INVALID_SHAPES_WITH_REASON(engine.entityName, "x, filter",
                    op::FVectorToString(engine.meta.input.shape) + ", " + op::FVectorToString(engine.meta.weight.shape),
                    "shape[" + std::to_string(engine.meta.input.CIdx()) + "] of x must be equal to shape[" +
                    std::to_string(engine.meta.weight.CIdx()) + "] of filter multiplied by attribute groups(" +
                    std::to_string(engine.params.groups) + ")");
                return ACLNN_ERR_PARAM_INVALID;
            }
        }

        // groups is the divisor of the modulo below: zero would be undefined behavior, and a
        // negative value is never a valid group count.
        if (engine.params.groups <= 0) {
            OP_LOGE_FOR_INVALID_VALUE_WITH_REASON(engine.entityName, "groups",
                std::to_string(engine.params.groups), "attribute groups must be greater than 0");
            return ACLNN_ERR_PARAM_INVALID;
        }

        if (engine.meta.weight.N() % engine.params.groups != 0) {
            OP_LOGE_FOR_INVALID_SHAPES_WITH_REASON(engine.entityName, "filter",
                op::FVectorToString(engine.meta.weight.shape),
                "shape[" + std::to_string(engine.meta.weight.NIdx()) +
                "] of filter must be exactly divided by attribute groups(" +
                std::to_string(engine.params.groups) + ")");
            return ACLNN_ERR_PARAM_INVALID;
        }

        return ACLNN_SUCCESS;
    }

    // Checks the spatial extents (D/H/W, or L for 1D) of x and filter after the rank checks done by
    // CheckShapeByDim:
    //   - every spatial dim of x must be >= 0;
    //   - every spatial dim of filter must be > 0, except that the filter check is skipped on
    //     NpuArch::DAV_3510 for a transposed convolution whose filter is an empty tensor.
    // Returns ACLNN_SUCCESS, or ACLNN_ERR_PARAM_INVALID after logging.
    static aclnnStatus CheckShape(const std::string& entityName, TensorMeta& input, TensorMeta& weight,
        bool transposed = false)
    {
        if (CheckShapeByDim(entityName, input, weight) != ACLNN_SUCCESS) {
            return ACLNN_ERR_PARAM_INVALID;
        }

        const FVector<int64_t>& xShape = input.shape;
        const FVector<int64_t>& filterShape = weight.shape;
        // Spatial dims start at index 1 for channel-last layouts and at index 2 otherwise.
        const int64_t xSpatialBegin = input.ChannelLast() ? 1 : 2;
        const int64_t filterSpatialBegin = weight.ChannelLast() ? 1 : 2;
        // Number of spatial dims: 1 for 1D conv, 2 for 2D, 3 for 3D.
        const size_t spatialDims = input.shape.size() - 2;

        for (size_t d = 0; d < spatialDims; ++d) {
            const int64_t xExtent = xShape[d + xSpatialBegin];
            if (xExtent < 0) {
                OP_LOGE_FOR_INVALID_SHAPE_WITH_REASON(entityName, "x", op::FVectorToString(input.shape),
                    "shape[" + std::to_string(d + xSpatialBegin) + "] of x must be greater than or equal to 0");
                return ACLNN_ERR_PARAM_INVALID;
            }
            const bool skipFilter = GetCurrentPlatformInfo().GetCurNpuArch() == NpuArch::DAV_3510 && transposed &&
                                    weight.IsZeroTensor();
            if (skipFilter) {
                continue;
            }
            const int64_t filterExtent = filterShape[d + filterSpatialBegin];
            if (filterExtent > 0) {
                continue;
            }
            OP_LOGE_FOR_INVALID_SHAPE_WITH_REASON(entityName, "filter", op::FVectorToString(weight.shape),
                "shape[" + std::to_string(d + filterSpatialBegin) + "] of filter must be greater than 0");
            return ACLNN_ERR_PARAM_INVALID;
        }
        return ACLNN_SUCCESS;
    }
};

// Value checker for the TBC convolution path; shapes are indexed with the CONVTBC_*_INDEX constants.
class ValueCheckerTbc : public ConvolutionChecker {
public:
    ValueCheckerTbc() = default;
    ~ValueCheckerTbc() override = default;

    // Validates x / filter / y / bias / padding for the TBC path. The checks run in a fixed order;
    // the first failing one logs an error and ACLNN_ERR_PARAM_INVALID is returned.
    aclnnStatus Check(ConvEngine& engine) override
    {
        const auto& name = engine.entityName;
        auto& x = engine.meta.input;
        auto& filter = engine.meta.weight;

        // 1. N/C/L of x, filter and y must be non-negative.
        if (CheckShapeTbc(name, x, filter, engine.meta.output) != ACLNN_SUCCESS) {
            return ACLNN_ERR_PARAM_INVALID;
        }
        // 2. bias[0] must equal the filter's N (the output channel count).
        if (CheckConvBiasTbc(name, engine.meta.bias, filter.N()) != ACLNN_SUCCESS) {
            return ACLNN_ERR_PARAM_INVALID;
        }
        // 3. The kernel length must be positive.
        const int64_t kernelLen = filter.shape[CONVTBC_L_INDEX];
        if (kernelLen <= 0) {
            OP_LOGE_FOR_INVALID_SHAPE_WITH_REASON(name, "filter",
                op::FVectorToString(filter.shape), "Shape[" +
                std::to_string(CONVTBC_L_INDEX) + "] of filter must be greater than 0");
            return ACLNN_ERR_PARAM_INVALID;
        }
        // 4. The padded sequence length must accommodate the kernel.
        if (CheckPadTbc(name, x, filter, engine.meta.padding) != ACLNN_SUCCESS) {
            return ACLNN_ERR_PARAM_INVALID;
        }
        // 5. x and filter must agree on the channel dimension.
        if (filter.shape[CONVTBC_C_INDEX] == x.shape[CONVTBC_C_INDEX]) {
            return ACLNN_SUCCESS;
        }
        OP_LOGE_FOR_INVALID_SHAPES_WITH_REASON(name, "x, filter",
            op::FVectorToString(x.shape) + ", " +
            op::FVectorToString(filter.shape),
            "shape[" + std::to_string(CONVTBC_C_INDEX) + "] of x must be equal to shape[" +
            std::to_string(CONVTBC_C_INDEX) + "] of filter");
        return ACLNN_ERR_PARAM_INVALID;
    }

private:
    // Requires the N, C and L extents of x, filter and y to all be >= 0.
    static aclnnStatus CheckShapeTbc(const std::string& entityName, TensorMeta& input, TensorMeta& weight,
        TensorMeta& output)
    {
        const int64_t xN = input.N();
        const int64_t xC = input.C();
        const int64_t xL = input.L();
        const int64_t wN = weight.N();
        const int64_t wC = weight.C();
        const int64_t wL = weight.L();
        const int64_t yN = output.N();
        const int64_t yC = output.C();
        const int64_t yL = output.L();
        if (All(0L, LessEqual<int64_t>, xN, xC, xL, wN, wC, wL, yN, yC, yL)) {
            return ACLNN_SUCCESS;
        }
        OP_LOGE_FOR_INVALID_SHAPES_WITH_REASON(entityName, "x, filter, y",
            op::FVectorToString(input.shape) + ", " + op::FVectorToString(weight.shape) + ", " +
            op::FVectorToString(output.shape), "Shape[" + std::to_string(input.NIdx()) +
            "], Shape[" + std::to_string(input.CIdx()) + "], Shape[" + std::to_string(input.LIdx()) +
            "] of x, Shape[" + std::to_string(weight.NIdx()) + "], Shape[" + std::to_string(weight.CIdx()) +
            "], Shape[" + std::to_string(weight.LIdx()) + "] of filter and Shape[" + std::to_string(output.NIdx()) +
            "], Shape[" + std::to_string(output.CIdx()) + "], Shape[" + std::to_string(output.LIdx()) +
            "] of y must be >= 0");
        return ACLNN_ERR_PARAM_INVALID;
    }

    // Requires bias.shape[0] to equal outChannel.
    static aclnnStatus CheckConvBiasTbc(const std::string& entityName, TensorMeta& bias, int64_t outChannel)
    {
        const auto& biasShape = bias.shape;
        if (biasShape[0] == outChannel) {
            return ACLNN_SUCCESS;
        }
        OP_LOGE_FOR_INVALID_SHAPES_WITH_REASON(entityName, "bias",
            op::FVectorToString(bias.shape),
            "shape[0] of bias must be equal to output channel(" + std::to_string(outChannel) + ")");
        return ACLNN_ERR_PARAM_INVALID;
    }

    // For a non-empty x, requires x[L] + pads[0] - 1 * (filter[L] - 1) - 1 >= 0, where
    // L = CONVTBC_L_INDEX and the dilation is fixed at 1 on this path.
    // NOTE(review): pads[0] is added once here; confirm whether TBC padding is meant to apply to
    // both ends of the sequence.
    static aclnnStatus CheckPadTbc(const std::string& entityName, TensorMeta& input, TensorMeta& weight,
        FVector<int64_t>& padding)
    {
        if (input.IsZeroTensor()) {
            return ACLNN_SUCCESS;
        }
        constexpr int64_t dilationValue = 1;
        const int64_t seqLen = input.shape[CONVTBC_L_INDEX];
        const int64_t kernelLen = weight.shape[CONVTBC_L_INDEX];
        const int64_t remainAfterPad = seqLen + padding[0] - dilationValue * (kernelLen - 1) - 1;
        if (remainAfterPad >= 0) {
            return ACLNN_SUCCESS;
        }
        const std::string incorrectShapes = op::FVectorToString(input.shape) + ", " +
            op::FVectorToString(weight.shape) + ", " +
            op::FVectorToString(padding);
        const std::string reason = std::string("These parameters must meet the following condition: ") +
            "x[" + std::to_string(CONVTBC_L_INDEX) + "] + pads[0] - 1 * (filter[" +
            std::to_string(CONVTBC_L_INDEX) + "] - 1) - 1 >= 0";
        OP_LOGE_FOR_INVALID_SHAPES_WITH_REASON(entityName, "x, filter, pads",
            incorrectShapes, reason);
        return ACLNN_ERR_PARAM_INVALID;
    }
};

// Value checker for depthwise 2D convolution.
// It checks spatial extents, strides, dilations and padding. It then requires the filter's C
// dimension to be 0 or 1, the filter's N to be a multiple of x's C (when x's C is non-zero), and,
// when a bias is given, bias[0] to equal the filter's N.
class ValueCheckerDepthwise2d : public ConvolutionChecker {
public:
    ValueCheckerDepthwise2d() = default;
    ~ValueCheckerDepthwise2d() override = default;

    // Runs the checks above in order; the first failure logs an error and returns
    // ACLNN_ERR_PARAM_INVALID.
    aclnnStatus Check(ConvEngine& engine) override
    {
        if (CheckShapeDepthwise2d(engine.entityName, engine.meta.input, engine.meta.weight) != ACLNN_SUCCESS) {
            return ACLNN_ERR_PARAM_INVALID;
        }
        // check stride
        if (CheckVectorValueGt0Depthwise2d(engine.entityName, "strides", engine.meta.stride) != ACLNN_SUCCESS) {
            return ACLNN_ERR_PARAM_INVALID;
        }
        // check dilation
        if (CheckVectorValueGt0Depthwise2d(engine.entityName, "dilations", engine.meta.dilation) != ACLNN_SUCCESS) {
            return ACLNN_ERR_PARAM_INVALID;
        }
        // check pad
        if (CheckPadDepthwise2d(
                engine.entityName, engine.meta.input, engine.meta.weight, engine.meta.dilation, engine.meta.padding,
                engine.params.transposed) != ACLNN_SUCCESS) {
            return ACLNN_ERR_PARAM_INVALID;
        }
        // check channel: the output channel count is the filter's N
        int64_t inChannel = engine.meta.input.C();
        const int64_t outChannel = engine.meta.weight.N();
        if (engine.meta.weight.C() != 0L && engine.meta.weight.C() != 1L) {
            OP_LOGE_FOR_INVALID_SHAPE_WITH_REASON(engine.entityName, "filter",
                op::FVectorToString(engine.meta.weight.shape),
                "shape[" + std::to_string(engine.meta.weight.CIdx()) + "] of this parameter must be 0 or 1");
            return ACLNN_ERR_PARAM_INVALID;
        }

        // inChannel == 0 is excluded so the modulo never divides by zero.
        if (inChannel != 0 && outChannel % inChannel != 0) {
            OP_LOGE_FOR_INVALID_SHAPES_WITH_REASON(engine.entityName, "x, filter",
                op::FVectorToString(engine.meta.input.shape) + ", " +
                op::FVectorToString(engine.meta.weight.shape),
                "shape[" + std::to_string(engine.meta.weight.NIdx()) + "] of filter must be exactly divided by shape[" +
                std::to_string(engine.meta.input.CIdx()) + "] of x");
            return ACLNN_ERR_PARAM_INVALID;
        }

        if (engine.params.bias != nullptr) {
            if (CheckConvBiasDepthwise2d(engine.entityName, engine.meta.bias, outChannel) != ACLNN_SUCCESS) {
                return ACLNN_ERR_PARAM_INVALID;
            }
        }

        return ACLNN_SUCCESS;
    }

private:
    // Rank checks via CheckShapeByDim, then: every spatial dim of x must be >= 0 and every spatial
    // dim of filter must be > 0.
    static aclnnStatus CheckShapeDepthwise2d(const std::string& entityName, TensorMeta& input, TensorMeta& weight)
    {
        if (CheckShapeByDim(entityName, input, weight) != ACLNN_SUCCESS) {
            return ACLNN_ERR_PARAM_INVALID;
        }

        // check space(d h w or l)
        FVector<int64_t> inputShape, weightShape;
        bool inputChannelLast, weightChannelLast;
        int64_t inputSpaceDimIndex, weightSpaceDimIndex;
        size_t inputSpaceDimNum, weightSpaceDimNum;

        GetSpatialDimInfo(input, inputChannelLast, inputSpaceDimIndex, inputSpaceDimNum, inputShape);
        GetSpatialDimInfo(weight, weightChannelLast, weightSpaceDimIndex, weightSpaceDimNum, weightShape);
        for (size_t i = 0; i < inputSpaceDimNum; ++i) {
            int64_t inputShapeSpace = inputShape[i + inputSpaceDimIndex]; // spatial extent of x
            if (inputShapeSpace < 0) {
                OP_LOGE_FOR_INVALID_SHAPE_WITH_REASON(entityName, "x", op::FVectorToString(input.shape),
                    "shape[" + std::to_string(i + inputSpaceDimIndex) +
                    "] of this parameter must be greater than or equal to 0");
                return ACLNN_ERR_PARAM_INVALID;
            }
            int64_t weightShapeSpace = weightShape[i + weightSpaceDimIndex];
            if (weightShapeSpace <= 0) {
                OP_LOGE_FOR_INVALID_SHAPE_WITH_REASON(entityName, "filter", op::FVectorToString(weight.shape),
                    "shape[" + std::to_string(i + weightSpaceDimIndex) + "] of this parameter must be greater than 0");
                return ACLNN_ERR_PARAM_INVALID;
            }
        }
        return ACLNN_SUCCESS;
    }

    // Requires every element of `param` to be strictly positive; logs the first offending index.
    static inline aclnnStatus CheckVectorValueGt0Depthwise2d(const std::string& entityName, const std::string& paramName, FVector<int64_t>& param)
    {
        for (size_t i = 0; i < param.size(); ++i) {
            if (param[i] <= 0) {
                OP_LOGE_FOR_INVALID_SHAPE_WITH_REASON(entityName, paramName, op::FVectorToString(param),
                    "shape[" + std::to_string(i) + "] of this parameter must be greater than 0");
                return ACLNN_ERR_PARAM_INVALID;
            }
        }
        return ACLNN_SUCCESS;
    }

    // For a non-transposed convolution with a non-empty x, requires for every spatial dim i:
    //   x[s] + padding[i] * 2 - dilation[i] * (filter[s'] - 1) - 1 >= 0.
    // Transposed convolution and an empty x are not checked.
    static aclnnStatus CheckPadDepthwise2d(
        const std::string& entityName, const TensorMeta& input, const TensorMeta& weight, const FVector<int64_t>& dilation,
        const FVector<int64_t>& padding, bool transposed)
    {
        // Loop-invariant guard: nothing below is checked in these cases, so skip the spatial
        // extraction and per-dimension indexing altogether.
        if (transposed || input.IsZeroTensor()) {
            return ACLNN_SUCCESS;
        }

        FVector<int64_t> inputShape, weightShape;
        bool inputChannelLast, weightChannelLast;
        int64_t inputSpaceDimIndex, weightSpaceDimIndex;
        size_t inputSpaceDimNum, weightSpaceDimNum;

        GetSpatialDimInfo(input, inputChannelLast, inputSpaceDimIndex, inputSpaceDimNum, inputShape);
        GetSpatialDimInfo(weight, weightChannelLast, weightSpaceDimIndex, weightSpaceDimNum, weightShape);

        for (size_t i = 0; i < inputSpaceDimNum; ++i) {
            auto inputShapeValue = inputShape[i + inputSpaceDimIndex];
            auto weightShapeValue = weightShape[i + weightSpaceDimIndex];
            auto paddingValueFront = padding[i];
            auto dilationValue = dilation[i];

            int64_t inputShapeValueAfterPad =
                (inputShapeValue + paddingValueFront * 2 - dilationValue * (weightShapeValue - 1) - 1);
            if (inputShapeValueAfterPad < 0) {
                std::string incorrectShapes = op::FVectorToString(input.shape) + ", " +
                    op::FVectorToString(weight.shape) + ", " +
                    op::FVectorToString(padding) + ", " + op::FVectorToString(dilation);
                std::stringstream ss;
                ss << "These parameters must meet the following condition: "
                   << "x[" << std::to_string(i + inputSpaceDimIndex) << "] + pads[" << std::to_string(i)
                   << "] * 2 - dilations[" << std::to_string(i) << "] * (filter[" <<
                   std::to_string(i + weightSpaceDimIndex) << "] - 1) - 1 >= 0";
                OP_LOGE_FOR_INVALID_SHAPES_WITH_REASON(entityName, "x, filter, pads, dilations",
                    incorrectShapes, ss.str());
                return ACLNN_ERR_PARAM_INVALID;
            }
        }

        return ACLNN_SUCCESS;
    }

    // Requires bias.shape[0] to equal outChannel.
    static aclnnStatus CheckConvBiasDepthwise2d(const std::string& entityName, TensorMeta& bias, int64_t outChannel)
    {
        const auto& biasShape = bias.shape;
        if (biasShape[0] != outChannel) {
            OP_LOGE_FOR_INVALID_SHAPE_WITH_REASON(entityName, "bias", op::FVectorToString(bias.shape),
                "shape[0] of bias must be equal to output channel(" + std::to_string(outChannel) + ")");
            return ACLNN_ERR_PARAM_INVALID;
        }
        return ACLNN_SUCCESS;
    }
};

// Verifies, for the TBC path, that y's shape matches the shape inferred by
// ConvEngine::CalcOutputShape().
class ConvXdCheckerTbc : public ConvolutionChecker {
public:
    ConvXdCheckerTbc() = default;
    ~ConvXdCheckerTbc() override = default;
    // Returns ACLNN_SUCCESS when every inferred dim equals the corresponding dim of y. Otherwise it
    // logs an error and returns ACLNN_ERR_PARAM_INVALID.
    aclnnStatus Check(ConvEngine& engine) override
    {
        FVector<int64_t> outputShape = engine.CalcOutputShape();
        // y must have at least as many dims as the inferred shape; otherwise the element-wise
        // comparison below would index past the end of y's shape.
        if (engine.meta.output.shape.size() < outputShape.size()) {
            OP_LOGE_FOR_INVALID_SHAPE_WITH_REASON(engine.entityName, "y",
                op::FVectorToString(engine.meta.output.shape),
                "the dimension number of y must be equal to that of its inferred shape(" +
                std::to_string(outputShape.size()) + ")");
            return ACLNN_ERR_PARAM_INVALID;
        }
        for (size_t i = 0; i < outputShape.size(); i++) {
            if (outputShape[i] != engine.meta.output.shape[i]) {
                OP_LOGE_FOR_INVALID_SHAPE_WITH_REASON(engine.entityName, "y",
                    op::FVectorToString(engine.meta.output.shape),
                    "shape[" + std::to_string(i) + "] of y must be equal to its inferred value(" +
                    std::to_string(outputShape[i]) + ")");
                return ACLNN_ERR_PARAM_INVALID;
            }
        }
        return ACLNN_SUCCESS;
    }
};

// Verifies that y's shape equals the shape inferred by ConvEngine::CalcOutputShape(). When x or
// filter is an empty tensor, it also requires y to be empty and the inferred dims to be >= 0.
class ConvXdChecker : public ConvolutionChecker {
public:
    ConvXdChecker() = default;
    ~ConvXdChecker() override = default;
    // Returns ACLNN_SUCCESS when all checks pass. Otherwise it logs an error and returns
    // ACLNN_ERR_PARAM_INVALID.
    aclnnStatus Check(ConvEngine& engine) override
    {
        FVector<int64_t> outputShape = engine.CalcOutputShape();
        // y must have at least as many dims as the inferred shape; otherwise the element-wise
        // comparison below would index past the end of y's shape.
        if (engine.meta.output.shape.size() < outputShape.size()) {
            OP_LOGE_FOR_INVALID_SHAPE_WITH_REASON(engine.entityName, "y",
                op::FVectorToString(engine.meta.output.shape),
                "the dimension number of y must be equal to that of its inferred shape(" +
                std::to_string(outputShape.size()) + ")");
            return ACLNN_ERR_PARAM_INVALID;
        }
        for (size_t i = 0; i < outputShape.size(); i++) {
            if (outputShape[i] != engine.meta.output.shape[i]) {
                OP_LOGE_FOR_INVALID_SHAPE_WITH_REASON(engine.entityName, "y",
                    op::FVectorToString(engine.meta.output.shape),
                    "shape[" + std::to_string(i) + "] of y must be equal to its inferred value(" +
                    std::to_string(outputShape[i]) + ")");
                return ACLNN_ERR_PARAM_INVALID;
            }
        }

        if (engine.meta.input.IsZeroTensor() || engine.meta.weight.IsZeroTensor()) {
            if (!engine.meta.output.IsZeroTensor()) {
                // Only build the explanation when it is actually reported.
                std::stringstream ss;
                ss << "if x or filter is an empty tensor, output must be an empty tensor. ";
                ss << "Empty tensor definition: A dimension in the shape is 0. ";
                OP_LOGE_FOR_INVALID_SHAPES_WITH_REASON(engine.entityName, "x, filter, y",
                    op::FVectorToString(engine.meta.input.shape) + ", " +
                    op::FVectorToString(engine.meta.weight.shape) + ", " +
                    op::FVectorToString(engine.meta.output.shape),
                    ss.str());
                return ACLNN_ERR_PARAM_INVALID;
            }
            for (size_t i = 0; i < outputShape.size(); i++) {
                if (outputShape[i] < 0) {
                    OP_LOGE_FOR_INVALID_SHAPE_WITH_REASON(engine.entityName, "y",
                        op::FVectorToString(engine.meta.output.shape),
                        "shape[" + std::to_string(i) + "] of y must be greater than or equal to 0");
                    return ACLNN_ERR_PARAM_INVALID;
                }
            }
        }

        return ACLNN_SUCCESS;
    }
};

// Hardware capability check for the regular convolution path: the requested cube math type must
// be accepted by CheckCubeMathType for the calculation dtype (engine.meta.calculatDataType).
class HardwareLimitChecker : public ConvolutionChecker {
public:
    HardwareLimitChecker() = default;
    ~HardwareLimitChecker() override = default;
    // Returns ACLNN_ERR_PARAM_INVALID when CheckCubeMathType rejects the combination, else ACLNN_SUCCESS.
    aclnnStatus Check(ConvEngine& engine) override
    {
        DataType upperDtype = engine.meta.calculatDataType;
        CHECK_RET(CheckCubeMathType(upperDtype, engine.params.cubeMathType), ACLNN_ERR_PARAM_INVALID);
        return ACLNN_SUCCESS;
    }
};

// Hardware capability check for the TBC path: the requested cube math type must be accepted by
// CheckCubeMathType for x's dtype.
class HardwareLimitCheckerTbc : public ConvolutionChecker {
public:
    HardwareLimitCheckerTbc() = default;
    ~HardwareLimitCheckerTbc() override = default;
    // Returns ACLNN_ERR_PARAM_INVALID when CheckCubeMathType rejects the combination, else ACLNN_SUCCESS.
    aclnnStatus Check(ConvEngine& engine) override
    {
        // Assumes x and filter share the same dtype; this checker does not verify that.
        DataType inputDtype = engine.meta.input.dataType;
        CHECK_RET(CheckCubeMathType(inputDtype, engine.params.cubeMathType), ACLNN_ERR_PARAM_INVALID);
        return ACLNN_SUCCESS;
    }
};

// Temporary software limitations on the regular convolution path.
// NpuArch::DAV_3510 always passes. Otherwise the SoC must be one of:
//   - ASCEND910, ASCEND910B, ASCEND910_93 or ASCEND310P;
//   - ASCEND310B, restricted to non-transposed convolution with an x of rank CONV_2D_DIM_SIZE.
// Any other SoC is rejected with ACLNN_ERR_PARAM_INVALID.
class TemporarySoftwareLimitChecker : public ConvolutionChecker {
public:
    TemporarySoftwareLimitChecker() = default;
    ~TemporarySoftwareLimitChecker() override = default;
    aclnnStatus Check(ConvEngine& engine) override
    {
        if (GetCurrentPlatformInfo().GetCurNpuArch() == NpuArch::DAV_3510) {
            OP_LOGD("Get Current NpuArch: DAV_3510.");
            return ACLNN_SUCCESS;
        }
        SocVersion socVersion = GetCurrentPlatformInfo().GetSocVersion();
        switch (socVersion) {
            case SocVersion::ASCEND910:
            case SocVersion::ASCEND910B:
            case SocVersion::ASCEND910_93:
            case SocVersion::ASCEND310P:
                break;
            case SocVersion::ASCEND310B: {
                const std::string socName(op::ToString(socVersion).GetString());
                // Report the transposed restriction on its own, so that the rank message below is
                // only emitted when the rank of x is actually wrong.
                if (engine.params.transposed) {
                    OP_LOGE_FOR_INVALID_VALUE_WITH_REASON(engine.entityName, "transposed", std::string("true"),
                        "When the SoC version is " + socName + ", transposed convolution is not supported");
                    return ACLNN_ERR_PARAM_INVALID;
                }
                size_t inputDim = engine.meta.input.shape.size();
                if (inputDim != CONV_2D_DIM_SIZE) {
                    OP_LOGE_FOR_INVALID_SHAPEDIM(engine.entityName, "x",
                        std::to_string(inputDim),
                        "When the SoC version is " + socName +
                        ", the value of this parameter can only be " + std::to_string(CONV_2D_DIM_SIZE));
                    return ACLNN_ERR_PARAM_INVALID;
                }
                break;
            }
            default: {
                OP_LOGE_FOR_INVALID_VALUE(engine.entityName, "SoC version",
                    std::string(op::ToString(socVersion).GetString()), "one of {" +
                    std::string(op::ToString(SocVersion::ASCEND910).GetString()) + ", " +
                    std::string(op::ToString(SocVersion::ASCEND910B).GetString()) + ", " +
                    std::string(op::ToString(SocVersion::ASCEND910_93).GetString()) + ", " +
                    std::string(op::ToString(SocVersion::ASCEND310P).GetString()) + ", " +
                    std::string(op::ToString(SocVersion::ASCEND310B).GetString()) + ", " +
                    std::string(op::ToString(SocVersion::ASCEND950).GetString()) +
                    "}");
                return ACLNN_ERR_PARAM_INVALID;
            }
        }
        return ACLNN_SUCCESS;
    }
};

// Temporary SoC whitelist for the TBC convolution path.
// NpuArch::DAV_3510 always passes. Otherwise only ASCEND910, ASCEND910B, ASCEND910_93 and
// ASCEND310P are accepted; any other SoC is logged and rejected with ACLNN_ERR_PARAM_INVALID.
// (An earlier note said 3D is not supported yet; no tensor rank is checked in this class.)
class TemporarySoftwareLimitCheckerTbc : public ConvolutionChecker {
public:
    TemporarySoftwareLimitCheckerTbc() = default;
    ~TemporarySoftwareLimitCheckerTbc() override = default;
    aclnnStatus Check(ConvEngine& engine) override
    {
        if (GetCurrentPlatformInfo().GetCurNpuArch() == NpuArch::DAV_3510) {
            OP_LOGD("Get Current NpuArch: DAV_3510.");
            return ACLNN_SUCCESS;
        }

        const SocVersion socVersion = GetCurrentPlatformInfo().GetSocVersion();
        const bool isSupportedSoc = (socVersion == SocVersion::ASCEND910) ||
                                    (socVersion == SocVersion::ASCEND910B) ||
                                    (socVersion == SocVersion::ASCEND910_93) ||
                                    (socVersion == SocVersion::ASCEND310P);
        if (isSupportedSoc) {
            return ACLNN_SUCCESS;
        }

        OP_LOGE_FOR_INVALID_VALUE(engine.entityName, "SoC version",
            std::string(op::ToString(socVersion).GetString()), "one of {" +
            std::string(op::ToString(SocVersion::ASCEND910).GetString()) + ", " +
            std::string(op::ToString(SocVersion::ASCEND910B).GetString()) + ", " +
            std::string(op::ToString(SocVersion::ASCEND910_93).GetString()) + ", " +
            std::string(op::ToString(SocVersion::ASCEND310P).GetString()) + ", " +
            std::string(op::ToString(SocVersion::ASCEND950).GetString()) +
            "}");
        return ACLNN_ERR_PARAM_INVALID;
    }
};

} // namespace

namespace {
// Axis-index / rank constants used by the shape helpers in this file.
constexpr int64_t DIM_DHW_NUM = 3;            // count of D, H, W axes
constexpr int64_t CI_DIM_CO_CI_DHW_INDEX = 1; // Ci axis of a [Co, Ci, D, H, W] filter
constexpr int64_t D_DIM_NCDHW_INDEX = 2;
constexpr int64_t H_DIM_NCDHW_INDEX = 3;
constexpr int64_t W_DIM_NCDHW_INDEX = 4;
constexpr int64_t N_DIM_NCHW_INDEX = 0;
constexpr int64_t C_DIM_NCHW_INDEX = 1;
constexpr int64_t H_DIM_NCHW_INDEX = 2;
constexpr int64_t W_DIM_NCHW_INDEX = 3;
// NOTE(review): specific C / W extents, presumably matched by a ConvTranspose1D special case — confirm at use site.
constexpr int64_t C_DIM_NCHW_VALUE_TRANSPOSE1D = 768;
constexpr int64_t W_DIM_NCHW_VALUE_TRANSPOSE1D = 4096;

constexpr int64_t N_DIM_NCL_INDEX = 0;
constexpr int64_t C_DIM_NCL_INDEX = 1;
constexpr int64_t L_DIM_NCL_INDEX = 2;

// Operand bundle for a batch-matmul call.
// Members carry default initializers so a default-constructed instance never holds indeterminate
// pointers or flags; aggregate initialization with all six values still works (C++14 and later).
struct BatchMatmulInput {
    const aclTensor* leftData = nullptr;
    const aclTensor* rightData = nullptr;
    const aclTensor* biasData = nullptr;
    const aclTensor* outputData = nullptr;
    bool isLeftTranspose = false;
    bool isRightTranspose = false;
};
} // namespace

namespace {

static aclnnStatus CheckConvParams(ConvEngine& engine)
{
    // Checkers run in the order they are registered; the first non-success status is returned.
    std::vector<unique_ptr<ConvolutionChecker>> checkers;

    // Mathematical validity: nullptr, dtype, rank, format and value checks.
    checkers.emplace_back(make_unique<NullptrChecker>());
    checkers.emplace_back(make_unique<DtypeChecker>());
    checkers.emplace_back(make_unique<DimChecker>());
    checkers.emplace_back(make_unique<FormatChecker>());
    checkers.emplace_back(make_unique<ValueChecker>());
    // Convolution-kind specific checks (e.g. shape inference).
    checkers.emplace_back(make_unique<ConvXdChecker>());

    // Implementation limits: hardware restrictions, then temporary software restrictions.
    checkers.emplace_back(make_unique<HardwareLimitChecker>());
    checkers.emplace_back(make_unique<TemporarySoftwareLimitChecker>());

    for (const auto& checker : checkers) {
        const aclnnStatus ret = checker->Check(engine);
        CHECK_RET(ret == ACLNN_SUCCESS, ret);
    }
    return ACLNN_SUCCESS;
}

static aclnnStatus CheckConvTbcParams(ConvEngine& engine)
{
    // TBC variant of the checker pipeline. There is no nullptr checker here; evaluation stops at
    // the first checker that does not return ACLNN_SUCCESS.
    std::vector<unique_ptr<ConvolutionChecker>> checkers;

    // Mathematical validity: dtype, rank, format and value checks.
    checkers.emplace_back(make_unique<DtypeCheckerTbc>());
    checkers.emplace_back(make_unique<DimCheckerTbc>());
    checkers.emplace_back(make_unique<FormatCheckerTbc>());
    checkers.emplace_back(make_unique<ValueCheckerTbc>());
    // Convolution-kind specific checks (e.g. shape inference).
    checkers.emplace_back(make_unique<ConvXdCheckerTbc>());

    // Implementation limits: hardware restrictions, then temporary software restrictions.
    checkers.emplace_back(make_unique<HardwareLimitCheckerTbc>());
    checkers.emplace_back(make_unique<TemporarySoftwareLimitCheckerTbc>());

    for (const auto& checker : checkers) {
        const aclnnStatus ret = checker->Check(engine);
        CHECK_RET(ret == ACLNN_SUCCESS, ret);
    }
    return ACLNN_SUCCESS;
}

/**
 * Verifies that kernelSize agrees with the filter's H and W extents.
 * Filters in FORMAT_NCL are accepted without comparison (kernelSize is not read in that case).
 * \param engine     conv engine providing filter metadata and the entity name used in logs.
 * \param kernelSize kernel size; element [0] is compared with filter H and element [1] with filter W.
 * \return ACLNN_SUCCESS when both match; an error status (after logging) otherwise.
 */
static aclnnStatus CheckConvDepthwise2dKernelSize(ConvEngine& engine, const aclIntArray* kernelSize)
{
    if (engine.meta.weight.format == op::Format::FORMAT_NCL) {
        return ACLNN_SUCCESS;
    }
    // kernelSize[0] and kernelSize[1] are read below: reject a null or too-short array instead of
    // dereferencing it / reading out of bounds.
    CHECK_PARAM_NULLPTR(engine.entityName, kernelSize, "kernelSize");
    constexpr uint64_t minKernelSizeLen = 2;
    if (kernelSize->Size() < minKernelSizeLen) {
        OP_LOGE_FOR_INVALID_VALUE(engine.entityName, "size of kernelSize",
            std::to_string(kernelSize->Size()), "at least " + std::to_string(minKernelSizeLen));
        return ACLNN_ERR_PARAM_INVALID;
    }

    int64_t weightH = engine.meta.weight.H();
    int64_t weightW = engine.meta.weight.W();
    int64_t kernelH = static_cast<int64_t>((*kernelSize)[0]);
    int64_t kernelW = static_cast<int64_t>((*kernelSize)[1]);
    std::string kernelSizeStr = "[" + std::to_string(kernelH) + ", " + std::to_string(kernelW) + "]";
    if (kernelH != weightH) {
        OP_LOGE_FOR_INVALID_SHAPES_WITH_REASON(engine.entityName, "filter, kernelSize",
            op::FVectorToString(engine.meta.weight.shape) + ", " + kernelSizeStr,
            "shape[" + std::to_string(engine.meta.weight.HIdx()) +
            "] of filter must be equal to the shape[0] of kernelSize");
        return ACLNN_ERR_PARAM_INVALID;
    }
    if (kernelW != weightW) {
        OP_LOGE_FOR_INVALID_SHAPES_WITH_REASON(engine.entityName, "filter, kernelSize",
            op::FVectorToString(engine.meta.weight.shape) + ", " + kernelSizeStr,
            "shape[" + std::to_string(engine.meta.weight.WIdx()) +
            "] of filter must be equal to the shape[1] of kernelSize");
        return ACLNN_ERR_PARAM_INVALID;
    }
    return ACLNN_SUCCESS;
}

static aclnnStatus CheckConvDepthwise2dParams(ConvEngine& engine)
{
    // Depthwise-2d pipeline: depthwise-specific dtype/dim/format/value checkers combined with the
    // generic nullptr, conv-kind, hardware and software-limit checkers. Stops at the first failure.
    std::vector<unique_ptr<ConvolutionChecker>> checkers;

    // Mathematical validity.
    checkers.emplace_back(make_unique<NullptrChecker>());
    checkers.emplace_back(make_unique<DtypeCheckerDepthwise2d>());
    checkers.emplace_back(make_unique<DimCheckerDepthwise2d>());
    checkers.emplace_back(make_unique<FormatCheckerDepthwise2d>());
    checkers.emplace_back(make_unique<ValueCheckerDepthwise2d>());
    // Convolution-kind specific checks (e.g. shape inference).
    checkers.emplace_back(make_unique<ConvXdChecker>());

    // Implementation limits: hardware restrictions, then temporary software restrictions.
    checkers.emplace_back(make_unique<HardwareLimitChecker>());
    checkers.emplace_back(make_unique<TemporarySoftwareLimitChecker>());

    for (const auto& checker : checkers) {
        const aclnnStatus ret = checker->Check(engine);
        CHECK_RET(ret == ACLNN_SUCCESS, ret);
    }
    return ACLNN_SUCCESS;
}

static inline aclnnStatus CheckParamsNullptrTbc(const string& entityName,
    const aclTensor* self, const aclTensor* weight, const aclTensor* bias, const aclTensor* output)
{
    // Every tensor is mandatory for the TBC variant (bias included). Each one is reported under its
    // operator-level name, checked in the order x, filter, bias, y.
    CHECK_PARAM_NULLPTR(entityName, self, "x");        // input feature map
    CHECK_PARAM_NULLPTR(entityName, weight, "filter"); // convolution weight
    CHECK_PARAM_NULLPTR(entityName, bias, "bias");     // bias (required here)
    CHECK_PARAM_NULLPTR(entityName, output, "y");      // result tensor
    return ACLNN_SUCCESS;
}

static aclnnStatus CheckOutputBiasShape(const string& entityName, const aclTensor* output, const aclTensor* bias)
{
    size_t outputDimNum = output->GetViewShape().GetDimNum();
    OP_CHECK_WRONG_DIMENSION(output, CONV_1D_DIM_SIZE, return false);

    for (size_t i = 0; i < outputDimNum; i++) {
        if (output->GetViewShape()[i] < 0) {
            OP_LOGE_FOR_INVALID_SHAPE_WITH_REASON(entityName, "y",
                op::FVectorToString(op::ToFVector(output->GetViewShape())),
                "shape[" + std::to_string(i) + "] of y must be greater than or equal to 0");
            return false;
        }
    }
    OP_CHECK_WRONG_DIMENSION(bias, 1, return false);

    if (bias->GetViewShape()[0] != output->GetViewShape()[L_DIM_NCL_INDEX]) {
        OP_LOGE_FOR_INVALID_SHAPES_WITH_REASON(entityName, "bias, y",
            op::FVectorToString(op::ToFVector(bias->GetViewShape())) + ", " +
            op::FVectorToString(op::ToFVector(output->GetViewShape())),
            "shape[0] of bias must be equal to shape[" + std::to_string(L_DIM_NCL_INDEX) + "] of y");
        return false;
    }
    return true;
}

static aclnnStatus CheckOutputBiasDtype(const aclTensor* output, const aclTensor* bias)
{
    auto dtypeSupportList = GetDtypeSupportListBySocVersion();
    OP_CHECK_DTYPE_NOT_SUPPORT(output, dtypeSupportList, return false);
    OP_CHECK_DTYPE_NOT_SUPPORT(bias, dtypeSupportList, return false);
    return true;
}

static aclnnStatus CheckOutputBiasFormat(const string& entityName, const aclTensor* output, const aclTensor* bias)
{
    if ((output->GetViewFormat() != op::Format::FORMAT_ND && output->GetViewFormat() != op::Format::FORMAT_NCL)) {
        OP_LOGE_FOR_INVALID_FORMAT(entityName, "y", GeFormatToString(output->GetViewFormat()),
            GeFormatToString(op::Format::FORMAT_ND) + " or " + GeFormatToString(op::Format::FORMAT_NCL));
        return false;
    }

    if (bias->GetViewFormat() != op::Format::FORMAT_ND) {
        OP_LOGE_FOR_INVALID_FORMAT(entityName, "bias", GeFormatToString(bias->GetViewFormat()),
            GeFormatToString(op::Format::FORMAT_ND));
        return false;
    }
    return true;
}

static inline aclnnStatus CheckParamsEmpty(const string& entityName, const aclTensor* output, const aclTensor* bias)
{
    // Despite the name, this does not test for empty tensors. It validates the output and bias
    // tensors in three stages and returns ACLNN_ERR_PARAM_INVALID from the first stage that fails.

    // Stage 1: ranks, non-negative output extents, bias length.
    CHECK_RET(CheckOutputBiasShape(entityName, output, bias), ACLNN_ERR_PARAM_INVALID);
    // Stage 2: dtypes against the SoC-specific support list.
    CHECK_RET(CheckOutputBiasDtype(output, bias), ACLNN_ERR_PARAM_INVALID);
    // Stage 3: view formats.
    CHECK_RET(CheckOutputBiasFormat(entityName, output, bias), ACLNN_ERR_PARAM_INVALID);

    return ACLNN_SUCCESS;
}

static aclnnStatus ProcessBias(
    const aclTensor*& bias, const aclTensor* contiguousBias, const ConvolutionOpInfo& opInfo, bool transposed,
    aclOpExecutor* executor)
{
    // No bias: nothing to prepare, and the in/out `bias` stays nullptr.
    if (bias == nullptr) {
        return ACLNN_SUCCESS;
    }

    // Cast the (already contiguous) bias to the dtype chosen in opInfo.
    const aclTensor* biasInTargetDtype = l0op::Cast(contiguousBias, opInfo.biasDtype, executor);
    CHECK_RET(biasInTargetDtype != nullptr, ACLNN_ERR_INNER_NULLPTR);

    // Forward conv keeps the cast tensor as-is; transposed conv also relabels its format.
    if (transposed) {
        bias = l0op::ReFormat(biasInTargetDtype, opInfo.biasFormat);
        CHECK_RET(bias != nullptr, ACLNN_ERR_INNER_NULLPTR);
    } else {
        bias = biasInTargetDtype;
    }
    return ACLNN_SUCCESS;
}

// 实现公共数据预处理,将数据准备为L0可接受的形式
// Shared pre-processing: turns input / weight / bias into tensors the L0 operators accept.
// Pipeline per tensor: optional Contiguous -> Cast to the opInfo dtype -> optional TransData
// (bias goes through ProcessBias instead of TransData). Results are written back through the
// reference parameters.
static aclnnStatus CommonPreProcess(const aclTensor*& input, const aclTensor*& weight, const aclTensor*& bias,
    const int64_t groups, const bool transposed, const ConvolutionOpInfo& opInfo, bool changeFormat, bool contiguous,
    aclOpExecutor* executor, bool inputDisContinuous = false)
{
    auto inputSrc = input;
    auto weightSrc = weight;
    auto biasSrc = bias;

    // Contiguous copies, only when requested. The input is skipped when the caller flags it as
    // discontinuous, and a weight whose storage is already FRACTAL_Z is left untouched.
    if (contiguous && !inputDisContinuous) {
        inputSrc = l0op::Contiguous(input, executor);
        CHECK_RET(inputSrc != nullptr, ACLNN_ERR_INNER_NULLPTR);
    }
    if (contiguous && op::GetPrimaryFormat(weight->GetStorageFormat()) != op::Format::FORMAT_FRACTAL_Z) {
        weightSrc = l0op::Contiguous(weight, executor);
        CHECK_RET(weightSrc != nullptr, ACLNN_ERR_INNER_NULLPTR);
    }
    if (contiguous && bias != nullptr) {
        biasSrc = l0op::Contiguous(bias, executor);
        CHECK_RET(biasSrc != nullptr, ACLNN_ERR_INNER_NULLPTR);
    }

    // Input: cast, then (optionally) TransData to the target input format.
    auto inputCast = l0op::Cast(inputSrc, opInfo.inputDtype, executor);
    CHECK_RET(inputCast != nullptr, ACLNN_ERR_INNER_NULLPTR);
    input = changeFormat ? l0op::TransData(inputCast, opInfo.inputFormat, groups, executor) : inputCast;
    CHECK_RET(input != nullptr, ACLNN_ERR_INNER_NULLPTR);

    // Weight: cast, then (optionally) TransData to the target weight format.
    auto weightCast = l0op::Cast(weightSrc, opInfo.weightDtype, executor);
    CHECK_RET(weightCast != nullptr, ACLNN_ERR_INNER_NULLPTR);
    weight = changeFormat ? l0op::TransData(weightCast, opInfo.weightFormat, groups, executor) : weightCast;
    CHECK_RET(weight != nullptr, ACLNN_ERR_INNER_NULLPTR);

    // Bias: cast (and reformat for transposed conv).
    const aclnnStatus biasStatus = ProcessBias(bias, biasSrc, opInfo, transposed, executor);
    CHECK_RET(biasStatus == ACLNN_SUCCESS, biasStatus);

    return ACLNN_SUCCESS;
}

// 实现公共数据预处理,将数据准备为L0可接受的形式  C04特殊分支
// Shared pre-processing for the C04 branch: prepares input / weight / bias for the L0 operators.
// Unlike CommonPreProcess, when changeFormat is set the weight is rebuilt into FRACTAL_Z_C04 through
// a pad -> transpose -> reshape-to-ND -> TransData(FRACTAL_NZ) -> storage-format relabel workaround.
// Results are written back through the reference parameters.
static aclnnStatus CommonPreProcessC04(
    const aclTensor*& input, const aclTensor*& weight, const aclTensor*& bias, const int64_t groups,
    const bool transposed, const ConvolutionOpInfo& opInfo, bool changeFormat, bool contiguous, aclOpExecutor* executor)
{
    auto contiguousInput = input;
    auto contiguousWeight = weight;
    auto contiguousBias = bias;
    if (contiguous) {
        contiguousInput = l0op::Contiguous(input, executor);
        CHECK_RET(contiguousInput != nullptr, ACLNN_ERR_INNER_NULLPTR);
        contiguousWeight = l0op::Contiguous(weight, executor);
        CHECK_RET(contiguousWeight != nullptr, ACLNN_ERR_INNER_NULLPTR);

        if (bias != nullptr) {
            contiguousBias = l0op::Contiguous(bias, executor);
            CHECK_RET(contiguousBias != nullptr, ACLNN_ERR_INNER_NULLPTR);
        }
    }

    auto castedInput = l0op::Cast(contiguousInput, opInfo.inputDtype, executor);
    CHECK_RET(castedInput != nullptr, ACLNN_ERR_INNER_NULLPTR);
    if (changeFormat) {
        // input format transdata
        input = l0op::TransData(castedInput, opInfo.inputFormat, groups, executor);
        CHECK_RET(input != nullptr, ACLNN_ERR_INNER_NULLPTR);
    } else {
        input = castedInput;
    }

    // Weight workaround: NCHW(1, 1, 3, 3) -> pad -> NCHW(1, 4, 3, 3) -> transpose -> NHWC(1, 3, 3, 4)
    //                    -> ND(1, 36) -> transdata -> FZ_NZ -> set storage format to FZ_C04
    auto castedWeight = l0op::Cast(contiguousWeight, opInfo.weightDtype, executor);
    CHECK_RET(castedWeight != nullptr, ACLNN_ERR_INNER_NULLPTR);
    OP_LOGD("Before c04 weight is %s", weight->ToString().GetString());
    auto initialWeightViewShape = weight->GetViewShape();
    if (changeFormat) {
        // NOTE(review): this branch transforms the caller's original `weight`, not
        // `castedWeight` / `contiguousWeight`, so the cast / contiguous results above are unused here,
        // and the const_cast setters below mutate the caller-provided tensor. Confirm that this path is
        // only reached when weight already has opInfo.weightDtype and is contiguous.
        constexpr int64_t c04Channels = 4;      // C04 layout needs exactly 4 channels
        constexpr uint64_t padV3PaddingLen = 8; // total padding length passed to PadV3
        constexpr uint64_t nchwRank = 4;

        // NCHW (a, b, c, d) -> padV3 -> NCHW(a, 4, c, d): make sure the C dimension is padded to 4
        if (weight->GetViewShape().GetDim(1) != c04Channels) {
            const_cast<aclTensor*>(weight)->SetStorageShape(weight->GetViewShape());
            const_cast<aclTensor*>(weight)->SetOriginalShape(weight->GetViewShape());
            OP_LOGD("Before padv3 weight is %s", weight->ToString().GetString());
            // Padding [0, 0, 0, 4 - b, 0, 0, 0, 0]; total padding length is 8.
            int64_t paddingArray[padV3PaddingLen] = {};
            // C04 needs 4 channels: for a rank-4 (NCHW) weight, entry 3 pads dim 1 (C) up to 4.
            paddingArray[3] = c04Channels - weight->GetViewShape().GetDim(1);
            aclIntArray* paddingArrayRes = executor->AllocIntArray(paddingArray, padV3PaddingLen);
            CHECK_RET(paddingArrayRes != nullptr, ACLNN_ERR_INNER_NULLPTR);
            auto paddingTensor = executor->ConvertToTensor(paddingArrayRes, DataType::DT_INT32);
            CHECK_RET(paddingTensor != nullptr, ACLNN_ERR_INNER_NULLPTR);
            auto padValueScalar = executor->AllocScalar(0);
            CHECK_RET(padValueScalar != nullptr, ACLNN_ERR_INNER_NULLPTR);
            auto constantValues = executor->ConvertToTensor(padValueScalar, weight->GetDataType());
            CHECK_RET(constantValues != nullptr, ACLNN_ERR_INNER_NULLPTR);
            // NOTE(review): REFLECTION_MODE is passed together with a zero constant value; confirm the
            // intended pad mode for growing C from b to 4.
            weight = l0op::PadV3(weight, paddingTensor, constantValues, op::REFLECTION_MODE, true, executor);
            CHECK_RET(weight != nullptr, ACLNN_ERR_INNER_NULLPTR);
            OP_LOGD("After padv3 weight is %s", weight->ToString().GetString());
        }

        // NCHW(a, b, c, d) -> transpose -> NHWC(a, c, d, b): permutation {N, H, W, C} = {0, 2, 3, 1}
        int64_t valuePerm[nchwRank] = {0, 2, 3, 1};
        auto perm = executor->AllocIntArray(valuePerm, nchwRank);
        CHECK_RET(perm != nullptr, ACLNN_ERR_INNER_NULLPTR);
        weight = l0op::Transpose(weight, perm, executor);
        CHECK_RET(weight != nullptr, ACLNN_ERR_INNER_NULLPTR);
        const_cast<aclTensor*>(weight)->SetViewFormat(Format::FORMAT_NHWC);
        const_cast<aclTensor*>(weight)->SetStorageFormat(Format::FORMAT_NHWC);
        const_cast<aclTensor*>(weight)->SetOriginalFormat(Format::FORMAT_NHWC);
        OP_LOGD("After transpose weight is %s", weight->ToString().GetString());

        // NHWC (a, b, c, d) -> ND (a, b * c * d)
        auto weightShape = weight->GetViewShape();
        int64_t newFormatRes = weightShape.GetDim(1) * weightShape.GetDim(2) * weightShape.GetDim(3);
        op::Shape ndShape = op::Shape({weightShape.GetDim(0), newFormatRes});
        const_cast<aclTensor*>(weight)->SetStorageFormat(Format::FORMAT_ND);
        const_cast<aclTensor*>(weight)->SetOriginalFormat(Format::FORMAT_ND);
        const_cast<aclTensor*>(weight)->SetStorageShape(ndShape);
        const_cast<aclTensor*>(weight)->SetOriginalShape(ndShape);
        OP_LOGD("After update format weight is %s", weight->ToString().GetString());

        // transdata: ND (reshaped NHWC) -> FZ_NZ
        weight = l0op::TransData(weight, FORMAT_FRACTAL_NZ, groups, executor);
        CHECK_RET(weight != nullptr, ACLNN_ERR_INNER_NULLPTR);
        OP_LOGD("After transdata weight is %s", weight->ToString().GetString());
        const_cast<aclTensor*>(weight)->SetOriginalShape(initialWeightViewShape);
        const_cast<aclTensor*>(weight)->SetOriginalFormat(Format::FORMAT_NCHW);
        auto storageShape = weight->GetStorageShape();

        // Reformat: a new view over the same data, relabelled as FRACTAL_Z_C04 storage.
        auto weightFormatC04 = executor->CreateView(weight, weight->GetViewShape(), weight->GetViewOffset());
        CHECK_RET(weightFormatC04 != nullptr, ACLNN_ERR_INNER_NULLPTR);
        weight = weightFormatC04;
        const_cast<aclTensor*>(weight)->SetStorageFormat(Format::FORMAT_FRACTAL_Z_C04);
        const_cast<aclTensor*>(weight)->SetStorageShape(storageShape);
        const_cast<aclTensor*>(weight)->SetOriginalShape(initialWeightViewShape);
        const_cast<aclTensor*>(weight)->SetOriginalFormat(Format::FORMAT_NCHW);
        OP_LOGD("After reformat weight is %s", weight->ToString().GetString());
    } else {
        weight = castedWeight;
    }

    // bias
    auto ret = ProcessBias(bias, contiguousBias, opInfo, transposed, executor);
    CHECK_RET(ret == ACLNN_SUCCESS, ret);

    return ACLNN_SUCCESS;
}

// Selects whether a forward convolution is lowered to a (batch) matmul.
// NOTE(review): the *_FEATURE_MAP_EQ_FILTER name suggests "feature map extent equals filter extent";
// confirm at the call sites.
enum class ConvToBmmMode : std::uint8_t {
    CONV_NO_MM = 0,                    // keep the convolution kernel
    CONV_MM_FEATURE_MAP_EQ_FILTER = 1, // lower to matmul
};

// Same selector for ConvTranspose1D.
enum class ConvTranspose1DToBmmMode : std::uint8_t {
    CONVTRANSPOSE1D_NO_MM = 0,                    // keep the transposed-convolution kernel
    CONVTRANSPOSE1D_MM_FEATURE_MAP_EQ_FILTER = 1, // lower to matmul
};

// 实现公共数据后处理,将数据转换为L2输出,但并不做viewcopy
// Shared post-processing: converts the raw convolution result back toward the L2 output.
// With changeFormat, the result is first TransData'd to output's storage format. It is then cast to
// output's dtype and stored back into convOut. No ViewCopy into `output` is performed here.
static aclnnStatus CommonPostProcess(
    const int64_t groups, bool changeFormat, const aclTensor* output, const aclTensor*& convOut,
    aclOpExecutor* executor)
{
    const aclTensor* reformatted = convOut;
    if (changeFormat) {
        reformatted = l0op::TransData(convOut, output->GetStorageFormat(), groups, executor);
    }
    CHECK_RET(reformatted != nullptr, ACLNN_ERR_INNER_NULLPTR);

    const aclTensor* recast = l0op::Cast(reformatted, output->GetDataType(), executor);
    CHECK_RET(recast != nullptr, ACLNN_ERR_INNER_NULLPTR);

    convOut = recast;
    return ACLNN_SUCCESS;
}

static void UpdateOutputDtype(
    const aclTensor* output, struct ConvolutionOpInfo& opInfo, int8_t cubeMathType, DataType& upperDtype,
    const bool transposed)
{
    if (!transposed) {
        opInfo.outputDtype = upperDtype; // 目前conv2d算子底层二进制仅支持输入输出相同,暂不支持16进32出的场景
    } else {
        opInfo.outputDtype = (output->GetDataType() == op::DataType::DT_FLOAT) ? output->GetDataType() : upperDtype;
    }
    // ASCEND910 + ASCEND310P 仅支持fp16的卷积,或者USE_FP16场景必走FP16, 因此必须转为fp16实现
    SocVersion socVersion = GetCurrentPlatformInfo().GetSocVersion();
    if (socVersion == SocVersion::ASCEND910 || socVersion == SocVersion::ASCEND310P || cubeMathType == USE_FP16) {
        opInfo.outputDtype = op::DataType::DT_FLOAT16; // 目前底层二进制暂不支持16进32出的场景,故设为FP16运算
    }
    if ((upperDtype == op::DataType::DT_HIFLOAT8) && op::IsSupportND()) {
        opInfo.outputDtype =
            op::DataType::DT_HIFLOAT8; // In conv2d/3d/3dtranspose hif8 case, the output dtype should be hif8.
    }
    if ((upperDtype == op::DataType::DT_FLOAT8_E4M3FN) && op::IsSupportND()) {
        opInfo.outputDtype =
            op::DataType::DT_FLOAT8_E4M3FN; // In conv2d/3d/3dtranspose hif8 case, the output dtype should be hif8.
    }
}

static void UpdateInputDtype(
    const aclTensor* input, const aclTensor* bias, struct ConvolutionOpInfo& opInfo, DataType& upperDtype,
    const bool transposed)
{
    opInfo.inputDtype = upperDtype;
    opInfo.weightDtype = upperDtype;
    if (bias != nullptr) {
        SocVersion socVersion = GetCurrentPlatformInfo().GetSocVersion();
        bool isDAV3510 = GetCurrentPlatformInfo().GetCurNpuArch() == NpuArch::DAV_3510;
        if (transposed) {
            if (socVersion == SocVersion::ASCEND910B || socVersion == SocVersion::ASCEND910_93) {
                upperDtype = GetUpperFloatDataType(opInfo.outputDtype, bias->GetDataType());
            } else if (isDAV3510) {
                // Ascend950 force convert bias to float
                // L0C [conv3dBpInput(xDtype) + bias(float)] --> CAST [out(xDtype)]
                upperDtype = op::DataType::DT_FLOAT;
            }
        }
        opInfo.biasDtype = upperDtype;
        // 因为bias二进制不支持为BF16,所以得转成FP32
        if (upperDtype == op::DataType::DT_BF16 &&
            (!(transposed && input->GetViewShape().GetDimNum() == CONV_3D_DIM_SIZE)) && !op::IsSupportND()) {
            OP_LOGD("Since bias does not support BF16, change the dtype of bias to fp32.");
            opInfo.biasDtype = op::DataType::DT_FLOAT;
        }
        if ((upperDtype == op::DataType::DT_HIFLOAT8 || upperDtype == op::DataType::DT_FLOAT8_E4M3FN) &&
            (!transposed) && op::IsSupportND()) {
            OP_LOGD("Bias dtype must be fp32 in hifloat8 or float8_e4m3fn scene for conv forward.");
            opInfo.biasDtype = op::DataType::DT_FLOAT;
        }
    }
}

void GetConvolutionOpDtype(
    const aclTensor* input, const aclTensor* weight, const aclTensor* bias, aclTensor* output,
    struct ConvolutionOpInfo& opInfo, const bool transposed, int8_t cubeMathType)
{
    OP_LOGD("Get into GetConvolutionOpDtype Function.");
    DataType upperDtype = GetUpperFloatDataType(input->GetDataType(), weight->GetDataType());
    if (op::IsSupportND()) {
        upperDtype = CalcPromoteTypeCubeMathTypeNew(upperDtype, cubeMathType);
    } else {
        upperDtype = CalcPromoteTypeCubemathtype(upperDtype, cubeMathType);
    }
    UpdateOutputDtype(output, opInfo, cubeMathType, upperDtype, transposed);
    UpdateInputDtype(input, bias, opInfo, upperDtype, transposed);
}

} // namespace

namespace {

// Limits and table-layout constants for the Conv2D helpers below.
// NOTE(review): the *_DMA values look like upper bounds for a DMA-based conv path — confirm at use sites.
constexpr int STRIDEH_DMA{63};
constexpr int DILATION_DMA{255};
constexpr int PAD_DMA{255};
constexpr int weight_DMA{511};
constexpr int CONV_2D_DIMS_NUM{4};
constexpr uint32_t CONV_1D_DIMS{3};
constexpr uint32_t CONV_2D_DIMS{4};
// Entries per CONV2D_WHITE_LIST row:
// dtype (1) + input shape (4) + filter shape (4) + padding (4) + dilation (2) + groups (1) = 16.
constexpr size_t CONV2D_WHITE_LIST_CASE_SIZE{16};
constexpr int STRIDE_WHITE_LIST_SIZE{2};

// Bundle of Conv2D arguments (input, weight, padding, dilation, groups) used by the helpers in
// this namespace. Members carry default initializers so a default-constructed instance never holds
// indeterminate pointers; aggregate initialization with all five values still works (C++14 and later).
struct Conv2DParams {
  const aclTensor *input = nullptr;
  const aclTensor *weight = nullptr;
  const aclIntArray *padding = nullptr;
  const aclIntArray *dilation = nullptr;
  int64_t groups = 0;
};

const vector<vector<int64_t>> CONV2D_WHITE_LIST =
{
  {
    DataType::DT_BF16, // input data type
    1, 3, 1600, 3840,     // input shape
    64, 3, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 3, 1600, 4000,     // input shape
    64, 3, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 3, 1760, 4000,     // input shape
    64, 3, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 3, 1600, 4160,     // input shape
    64, 3, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 3, 1920, 4160,     // input shape
    64, 3, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 3, 2048, 4160,     // input shape
    64, 3, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 3, 2144, 4160,     // input shape
    64, 3, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 3, 3520, 4736,     // input shape
    64, 3, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 3, 2144, 4672,     // input shape
    64, 3, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 3, 2592, 6368,     // input shape
    64, 3, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 3, 2912, 5696,     // input shape
    64, 3, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 3, 2880, 3840,     // input shape
    64, 3, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 3, 1280, 3264,     // input shape
    64, 3, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 3, 1760, 2336,     // input shape
    64, 3, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 3, 2048, 2048,     // input shape
    64, 3, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 3, 2016, 2016,     // input shape
    64, 3, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 3, 960, 4288,     // input shape
    64, 3, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 3, 2720, 1504,     // input shape
    64, 3, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 3, 2720, 1536,     // input shape
    64, 3, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 3, 832, 4928,     // input shape
    64, 3, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 3, 1568, 2624,     // input shape
    64, 3, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 3, 1536, 2720,     // input shape
    64, 3, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 3, 2336, 1760,     // input shape
    64, 3, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 3, 1440, 2880,     // input shape
    64, 3, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 3, 2048, 2016,     // input shape
    64, 3, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 3, 1088, 3744,     // input shape
    64, 3, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 896, 656,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 1008, 784,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 640, 1200,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 688, 1056,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 352, 304,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 1024, 2496,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 1072, 2096,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 512, 1504,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 2304, 1056,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 784, 496,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 1472, 944,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 1456, 1936,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 256, 112,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 480, 240,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 320, 320,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 1632, 1152,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 816, 1888,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 512, 1088,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 288, 80,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 656, 304,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 432, 752,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 1072, 736,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 1712, 592,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 1056, 848,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 1200, 1968,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 1168, 1936,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 2192, 1232,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 416, 752,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 624, 1488,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 2144, 1264,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 80, 256,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 1280, 3248,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 176, 1456,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 1456, 1248,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 208, 496,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 272, 240,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 736, 2720,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 1248, 560,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 384, 944,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 608, 960,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 960, 656,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 640, 496,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 832, 1056,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 1728, 1200,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 512, 688,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 1376, 592,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 1680, 1280,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 1008, 1968,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 768, 1696,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 784, 1136,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 1616, 1008,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 976, 944,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 752, 1104,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 992, 3008,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 512, 800,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 1760, 432,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 896, 784,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 656, 480,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 2336, 1312,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 1088, 560,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 432, 1104,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 512, 704,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  },
  {
    DataType::DT_BF16, // input data type
    1, 256, 592, 880,     // input shape
    256, 256, 3, 3,        // filter shape
    1, 1, 1, 1,         // padding
    1, 1,               // dilation
    1,                  // groups
  }
};

static bool isNotDMAFromPad(bool isDMASpec, const aclIntArray* padding)
{
    if (padding->Size() == CONV_2D_PAD_DIM) {
        isDMASpec = isDMASpec || ((*padding)[0] > PAD_DMA) || ((*padding)[1] > PAD_DMA);
    } else if (padding->Size() == CONV_4D_PAD_DIM) {
        isDMASpec = isDMASpec || ((*padding)[0] > PAD_DMA) || ((*padding)[1] > PAD_DMA) ||
                    ((*padding)[PAD_LEFT_INDEX] > PAD_DMA) || ((*padding)[PAD_RIGHT_INDEX] > PAD_DMA);
    }
    return isDMASpec;
}
// Decides whether a conv2d is a "non-DMA" case that may use the C04 path.
// padding = [pad_top, pad_bottom, pad_left, pad_right]
// Returns true only when both hold:
//   1. the case does not meet the DMA spec (stride/dilation/padding/filter within the load3d limits);
//   2. the minimal load3d split fits in L1.
// When IsSupportND() is true, step 2 is replaced by !CheckDmaLimits(...).
// NOTE(review): CanSwitchC04 calls this without opInfo, so CheckDmaLimits then receives nullptr —
// confirm it tolerates a null opInfo.
static bool isNotDMA(const aclTensor* input, const aclTensor* weight, const aclTensor* bias, aclTensor* output,
    const aclIntArray* stride, const aclIntArray* padding, const aclIntArray* dilation,
    struct ConvolutionOpInfo* opInfo = nullptr)
{
    int64_t inputHeight = static_cast<int64_t>(input->GetViewShape().GetDim(2));
    int64_t inputWidth = static_cast<int64_t>(input->GetViewShape().GetDim(3));
    int64_t weightH = static_cast<int64_t>(weight->GetViewShape().GetDim(2));
    int64_t weightW = static_cast<int64_t>(weight->GetViewShape().GetDim(3));
    int64_t outputSize = static_cast<int64_t>(output->GetViewShape().GetDimNum());
    int64_t outputW = static_cast<int64_t>(output->GetViewShape().GetDim(2));
    if (outputSize == CONV_2D_DIM_SIZE) {
        outputW = static_cast<int64_t>(output->GetViewShape().GetDim(3)); // NCH(W) -> 012(3)
    }

    // M and K block sizes of the FP16 cube unit
    constexpr int64_t BLK_M = 16;
    constexpr int64_t BLK_K = 16;
    // Extra L1 bytes counted when a bias is present
    constexpr int64_t BIT_L12BT_MIN_BIAS = 64;

    // 1. DMA spec check
    int64_t strideH = (*stride)[0];
    int64_t strideW = (*stride)[1];
    int64_t dilationH = (*dilation)[0];
    int64_t dilationW = (*dilation)[1];
    // align(weightH * weightW * 4, BLK_K) must not exceed 65535
    bool alignResult = ((weightH * weightW * 4 + BLK_K - 1) / BLK_K * BLK_K) <= 65535;
    OP_LOGD("alignResult is %d", alignResult);

    // Non-DMA requires all of: stride wh <= 63, dilation wh <= 255, padding <= 255, weight wh <= 511,
    // align(weightH * weightW * 4, BLK_K) <= 65535. Exceeding any limit means the case meets the DMA spec.
    bool isDMASpec = (strideH > STRIDEH_DMA) || (strideW > STRIDEH_DMA) || (dilationH > DILATION_DMA) ||
                     (dilationW > DILATION_DMA) || (weightH > weight_DMA) || (weightW > weight_DMA);
    isDMASpec = isNotDMAFromPad(isDMASpec, padding);
    if (isDMASpec) {
        OP_LOGD("Fulfill DMA requirement,return False");
        return false;
    }

    if (IsSupportND()) {
        return !CheckDmaLimits(opInfo, input, weight, stride, padding, dilation, bias);
    }

    if (!alignResult) {
        return false;
    }

    // 2. The minimal load3d split must fit in L1.
    // Guard the division below: with a non-positive output width no tile can be formed
    // (and BLK_M / 0 would be undefined behavior), so C04 is not used.
    if (outputW <= 0) {
        OP_LOGD("outputW[%ld] is not positive, return False", outputW);
        return false;
    }
    int64_t hoNum = BLK_M / outputW + 2;
    int64_t hkDilation = (weightH - 1) * dilationH + 1;
    int64_t hiNum = std::min(((hoNum - 1) * strideH + hkDilation), inputHeight);
    int64_t wiL1 = inputWidth;
    int64_t hiL1 = hiNum;

    // conv1d case: input_height = 1 & weight_height = 1 & pad_top = 0 & pad_bottom = 0
    // NOTE(review): for a 2-element padding [pad_h, pad_w], index 1 is pad_w rather than pad_bottom —
    // confirm callers pass the 4-element form here.
    bool isConv1d = (inputHeight == 1) && (weightH == 1) && ((*padding)[0] == 0) && ((*padding)[1] == 0);
    OP_LOGD("isConv1d is %d", isConv1d);
    if (isConv1d) {
        int64_t woNum = BLK_M;
        int64_t wkDilation = (weightW - 1) * dilationW + 1;
        wiL1 = std::min(((woNum - 1) * strideW + wkDilation), inputWidth);
    }

    // Outside the conv1d case, C04 requires width <= 32767.
    constexpr int64_t WIDTH_THRESSHOLD = 32767;
    if (!isConv1d && inputWidth > WIDTH_THRESSHOLD) {
        OP_LOGD("when not conv1d scene, inputWidth[%ld] > 32767", inputWidth);
        return false;
    }
    int64_t hiwiMul = wiL1 * hiL1;
    constexpr int64_t c0OnL1 = 4;
    uint64_t maxL1Size = static_cast<uint64_t>(hiwiMul * c0OnL1 * 2); // 2 bytes per FP16 element
    maxL1Size = (bias != nullptr) ? maxL1Size + BIT_L12BT_MIN_BIAS : maxL1Size;
    OP_LOGD("maxL1Size is %lu", maxL1Size);

    // hardwareInfo.l1size on 910B is 524288 bytes
    return maxL1Size <= 524288U;
}

// Returns true when the filter dtype is BF16 and the SoC is ASCEND910_93 or ASCEND910B.
// Only these platforms may take the BF16 C04 path.
static bool CanSwitchC04InBF16Scene(const struct ConvolutionOpInfo& opInfo)
{
    if (opInfo.weightDtype != op::DataType::DT_BF16) {
        return false;
    }
    const auto socVersion = GetCurrentPlatformInfo().GetSocVersion();
    return (socVersion == SocVersion::ASCEND910_93) || (socVersion == SocVersion::ASCEND910B);
}

// Returns true when the filter dtype is FLOAT16 and IsCubeSupportFp32() holds.
// The original note ties the latter to the 910B / 910_93 SoCs.
static bool CanSwitchC04InF16Scene(const struct ConvolutionOpInfo& opInfo)
{
    const bool isFp16Weight = (opInfo.weightDtype == op::DataType::DT_FLOAT16);
    return isFp16Weight && IsCubeSupportFp32();
}

// Shape/layout gate for the C04 path. Returns true only when all of these hold:
//   1. the conv is not transposed and the input view format is NCHW;
//   2. groups == 1;
//   3. 0 < Cin <= op::SMALL_CHANNEL;
//   4. isNotDMA accepts the case (non-DMA and the minimal load3d split fits in L1).
// The dtype/SoC requirements are not checked here; callers combine this with
// CanSwitchC04InF16Scene / CanSwitchC04InBF16Scene.
static bool CanSwitchC04(
    const aclTensor* input, const aclTensor* weight, const aclTensor* bias, aclTensor* output,
    const aclIntArray* stride, const aclIntArray* padding, const aclIntArray* dilation, int64_t groups, bool transposed)
{
    const bool isPlainNchw = !transposed && input->GetViewFormat() == Format::FORMAT_NCHW;
    if (!isPlainNchw) {
        OP_LOGD("input is not NCHW or is transposed, thus no C04");
        return false;
    }

    const int64_t cin = input->GetViewShape().GetDim(1);
    const bool isSmallChannel = (cin > 0) && (cin <= op::SMALL_CHANNEL);
    if (groups != 1 || !isSmallChannel) {
        OP_LOGD("Not fulfill the requirements for C04");
        return false;
    }
    // groups == 1 and a small Cin: the remaining requirement is the non-DMA check.
    return isNotDMA(input, weight, bias, output, stride, padding, dilation);
}

// 非C04场景的更新 卷积format
void GetConvolutionOpFormat(
    struct ConvolutionOpInfo& opInfo, const aclTensor* input, const aclTensor* weight, const aclTensor* output)
{
    opInfo.weightFormat = Format::FORMAT_FRACTAL_Z;
    opInfo.inputFormat = Format::FORMAT_NC1HWC0;
    opInfo.outputFormat = Format::FORMAT_NC1HWC0;
    opInfo.biasFormat = Format::FORMAT_ND;
    if (op::IsSupportND()) {
        opInfo.weightFormat =
            weight->GetStorageFormat() == Format::FORMAT_NHWC ? Format::FORMAT_NHWC : Format::FORMAT_NCHW;
        opInfo.inputFormat =
            input->GetStorageFormat() == Format::FORMAT_NHWC ? Format::FORMAT_NHWC : Format::FORMAT_NCHW;
        opInfo.outputFormat =
            output->GetStorageFormat() == Format::FORMAT_NHWC ? Format::FORMAT_NHWC : Format::FORMAT_NCHW;
    }
}

void GetConvolution3dOpFormat(
    struct ConvolutionOpInfo& opInfo, const aclTensor* input, const aclTensor* weight, const aclTensor* output)
{
    opInfo.weightFormat = Format::FORMAT_FRACTAL_Z_3D;
    opInfo.inputFormat = Format::FORMAT_NDC1HWC0;
    opInfo.outputFormat = Format::FORMAT_NDC1HWC0;
    opInfo.biasFormat = Format::FORMAT_ND;
    if (op::IsSupportND()) {
        opInfo.weightFormat =
            weight->GetStorageFormat() == Format::FORMAT_NDHWC ? Format::FORMAT_NDHWC : Format::FORMAT_NCDHW;
        opInfo.inputFormat =
            input->GetStorageFormat() == Format::FORMAT_NDHWC ? Format::FORMAT_NDHWC : Format::FORMAT_NCDHW;
        opInfo.outputFormat =
            output->GetStorageFormat() == Format::FORMAT_NDHWC ? Format::FORMAT_NDHWC : Format::FORMAT_NCDHW;
    }
}

// Selects the C04 formats: weight FRACTAL_Z_C04, input and output NC1HWC0, bias ND.
void GetConvolutionOpFormatC04(struct ConvolutionOpInfo& opInfo)
{
    opInfo.biasFormat = Format::FORMAT_ND;
    opInfo.inputFormat = Format::FORMAT_NC1HWC0;
    opInfo.outputFormat = opInfo.inputFormat;
    opInfo.weightFormat = Format::FORMAT_FRACTAL_Z_C04;
}

// Appends every element of `seg` to `caseInfo`, preserving order.
static void AddAclIntArrayToCaseInfo(const aclIntArray &seg, vector<int64_t> &caseInfo)
{
    const size_t count = seg.Size();
    size_t idx = 0;
    while (idx < count) {
        caseInfo.push_back(seg[idx]);
        ++idx;
    }
}

// Appends every dimension of `seg`'s view shape to `caseInfo`, preserving order.
static void AddTensorShapeToCaseInfo(const aclTensor &seg, vector<int64_t> &caseInfo)
{
    // Bind by const reference. If GetViewShape() returns a reference, this avoids copying the Shape;
    // if it returns by value, lifetime extension keeps the temporary alive for the loop.
    const auto& segShape = seg.GetViewShape();
    const size_t dimNum = segShape.GetDimNum();
    for (size_t i = 0; i < dimNum; i++) {
        caseInfo.push_back(segShape.GetDim(i));
    }
}

// Builds the whitelist lookup key for a conv2d case, laid out as
// [input dtype, input dims..., filter dims..., padding..., dilation..., groups].
static void ConstructCaseInfo(const Conv2DParams &params, vector<int64_t> &caseInfo)
{
    caseInfo.reserve(CONV2D_WHITE_LIST_CASE_SIZE);
    const int64_t dtypeKey = static_cast<int64_t>(params.input->GetDataType());
    caseInfo.push_back(dtypeKey);
    // Shapes first (input, then filter), then the integer attributes.
    AddTensorShapeToCaseInfo(*params.input, caseInfo);
    AddTensorShapeToCaseInfo(*params.weight, caseInfo);
    AddAclIntArrayToCaseInfo(*params.padding, caseInfo);
    AddAclIntArrayToCaseInfo(*params.dilation, caseInfo);
    caseInfo.push_back(params.groups);
}

// Returns true when `caseInfo` exactly matches an entry of `whiteList` and the stride is (1, 1) or (2, 2).
// A stride array whose size differs from STRIDE_WHITE_LIST_SIZE never matches.
static bool IsConv2DWhiteListCase(const vector<int64_t> &caseInfo, const vector<vector<int64_t>> &whiteList, const aclIntArray &stride)
{
    if (stride.Size() != STRIDE_WHITE_LIST_SIZE) {
        return false;
    }
    const int64_t h = stride[0];
    const int64_t w = stride[1];
    const bool isStrideRight = (h == 1 && w == 1) || (h == 2 && w == 2);
    // The result is false whenever the stride is not whitelisted, so skip the linear
    // scan over the whitelist entirely in that case.
    if (!isStrideRight) {
        return false;
    }
    for (const auto& entry : whiteList) {
        if (entry == caseInfo) {
            return true;
        }
    }
    return false;
}

// 更新convolution所需要的dtype format
void GetConvOpInfo(
    const aclTensor* input, const aclTensor* weight, const aclTensor* bias, aclTensor* output,
    struct ConvolutionOpInfo& opInfo, const bool transposed, int64_t groups, const aclIntArray* stride,
    const aclIntArray* padding, const aclIntArray* dilation, int8_t cubeMathType)
{
    GetConvolutionOpDtype(input, weight, bias, output, opInfo, transposed, cubeMathType);
    // 支持C04 + NCHW + 非transposed的场景
    op::Shape inputSpecialShape = op::Shape({320, 3, 224, 224}); // 客户专用场景
    op::Shape weightSpecialShape = op::Shape({768, 3, 32, 32});  // 客户专用场景
    vector<int64_t> caseInfo2d;
    Conv2DParams params = {input, weight, padding, dilation, groups};
    ConstructCaseInfo(params, caseInfo2d);
    bool isConv2DWhiteListCase = IsConv2DWhiteListCase(caseInfo2d, CONV2D_WHITE_LIST, *stride);
    if (!isConv2DWhiteListCase && (weight->GetViewShape() == weightSpecialShape) &&
        (input->GetViewShape() == inputSpecialShape) &&
        CanSwitchC04InF16Scene(opInfo) &&
        CanSwitchC04(input, weight, bias, output, stride, padding, dilation, groups, transposed)) {
        OP_LOGD("Entering float16 C04 branch");
        GetConvolutionOpFormatC04(opInfo);
    } else if (
        !isConv2DWhiteListCase &&
        CanSwitchC04InBF16Scene(opInfo) &&
        CanSwitchC04(input, weight, bias, output, stride, padding, dilation, groups, transposed)) {
        OP_LOGD("Entering bfloat16 C04 branch");
        GetConvolutionOpFormatC04(opInfo);
    } else {
        Conv2DSplitWInfo conv2dInfo;
        conv2dInfo.InitConv2DSplitWInfo(input, weight, stride, padding, dilation);
        if (isConv2DWhiteListCase || conv2dInfo.CanSwitchSplitW(bias, output, groups, opInfo)) {
            OP_LOGD("Entering splitW branch");
            GetConvolution3dOpFormat(opInfo, input, weight, output);
        } else {
            OP_LOGD("Entering normal C04 branch");
            GetConvolutionOpFormat(opInfo, input, weight, output);
        }
    }
}

// Fills opInfo for a 3D convolution. The dtypes come from GetConvolutionOpDtype; the formats are
// always the 3D ones from GetConvolution3dOpFormat (no C04 or split-W selection here).
void GetConv3dOpInfo(
    const aclTensor* input, const aclTensor* weight, const aclTensor* bias, aclTensor* output,
    struct ConvolutionOpInfo& opInfo, const bool transposed, int8_t cubeMathType)
{
    // Step 1: dtypes.
    GetConvolutionOpDtype(input, weight, bias, output, opInfo, transposed, cubeMathType);
    // Step 2: formats. The 3D format choice does not read the dtype fields set above.
    GetConvolution3dOpFormat(opInfo, input, weight, output);
}

// Expands a 1-element conv1d padding [p] into [top, bottom, left, right] = [0, 0, p, p].
// Returns the array from executor->AllocIntArray as-is (not null-checked here).
static aclIntArray* ViewConv1dPad1dAs4d(const aclIntArray* intArray, aclOpExecutor* executor)
{
    constexpr uint64_t padDimNum = 4;
    const int64_t padW = (*intArray)[0];
    int64_t pads[padDimNum] = {0, 0, 0, 0};
    pads[PAD_LEFT_INDEX] = padW;
    pads[PAD_RIGHT_INDEX] = padW;
    return executor->AllocIntArray(pads, padDimNum);
}

// Expands a 2-element conv1d padding [left, right] into [top, bottom, left, right] = [0, 0, left, right].
// Returns the array from executor->AllocIntArray as-is (not null-checked here).
static aclIntArray* ViewConv1dPad2dAs4d(const aclIntArray* intArray, aclOpExecutor* executor)
{
    constexpr uint64_t padDimNum = 4;
    const int64_t padLeft = (*intArray)[0];
    const int64_t padRight = (*intArray)[1];
    int64_t pads[padDimNum] = {0, 0, 0, 0};
    pads[PAD_LEFT_INDEX] = padLeft;
    pads[PAD_RIGHT_INDEX] = padRight;
    return executor->AllocIntArray(pads, padDimNum);
}

// Expands a 2-element conv2d padding [pad_h, pad_w] into
// [top, bottom, left, right] = [pad_h, pad_h, pad_w, pad_w].
// Returns the array from executor->AllocIntArray as-is (not null-checked here).
static aclIntArray* ViewConv2dPad2dAs4d(const aclIntArray* intArray, aclOpExecutor* executor)
{
    constexpr uint64_t padDimNum = 4;
    const int64_t padH = (*intArray)[0];
    const int64_t padW = (*intArray)[1];
    int64_t pads[padDimNum] = {padH, padH, 0, 0};
    pads[PAD_LEFT_INDEX] = padW;
    pads[PAD_RIGHT_INDEX] = padW;
    return executor->AllocIntArray(pads, padDimNum);
}

// Turns a 1-element int array [v] into the 2-element array [expandValue, v].
// Returns nullptr if `intArray` does not hold exactly one element; otherwise returns the
// executor-allocated array.
aclIntArray* View1dAs2d(const aclIntArray* intArray, int64_t expandValue, aclOpExecutor* executor)
{
    constexpr uint64_t pairSize = 2;
    if (intArray->Size() != static_cast<uint64_t>(1)) {
        return nullptr;
    }
    const int64_t expanded[pairSize] = {expandValue, (*intArray)[0]};
    return executor->AllocIntArray(expanded, pairSize);
}

// Turns a 1-element int array [v] into the 2-element array [v, expandValue], placing the
// original value first (the H slot) and expandValue second (the W slot).
// Returns nullptr if `intArray` does not hold exactly one element; otherwise returns the
// executor-allocated array.
aclIntArray* View1dAs2dw(const aclIntArray* intArray, int64_t expandValue, aclOpExecutor* executor)
{
    constexpr uint64_t pairSize = 2;
    if (intArray->Size() != static_cast<uint64_t>(1)) {
        return nullptr;
    }
    const int64_t expanded[pairSize] = {(*intArray)[0], expandValue};
    return executor->AllocIntArray(expanded, pairSize);
}

// Wraps a single scalar in a 1-element int array allocated by the executor.
// The result of AllocIntArray is returned without a null check.
aclIntArray* ViewValueAs1d(const int64_t value, aclOpExecutor* executor)
{
    const int64_t single[1] = {value};
    return executor->AllocIntArray(single, 1);
}

// Lifts a 1-D tensor to 4-D NCHW by inserting size-1 axes at output dims 0, 2 and 3.
// For a tensor of length C the result is presumably [1, C, 1, 1].
// Pipeline: contiguous -> unsqueeze(dims 0, 2, 3) -> reformat NCHW.
// Returns nullptr if any step fails.
const aclTensor* View1dAs4d(const aclTensor* input, aclOpExecutor* executor)
{
    // Make the input contiguous first.
    auto contiguousInput = l0op::Contiguous(input, executor);
    CHECK_RET(contiguousInput != nullptr, nullptr);

    // Insert size-1 axes at dims 0, 2 and 3.
    constexpr int64_t appendDim[] = {0, 2, 3};
    aclIntArray* dim = executor->AllocIntArray(appendDim, 3);
    // Guard the allocation, as View4dAs3d does, before handing it to UnsqueezeNd.
    CHECK_RET(dim != nullptr, nullptr);
    auto unsqueezedInput = l0op::UnsqueezeNd(contiguousInput, dim, executor);
    CHECK_RET(unsqueezedInput != nullptr, nullptr);

    // Tag the 4-D result as NCHW.
    auto reformatInput = l0op::ReFormat(unsqueezedInput, op::Format::FORMAT_NCHW);
    CHECK_RET(reformatInput != nullptr, nullptr);

    return reformatInput;
}

// Lifts a 3-D NCL tensor to 4-D NCHW by inserting a size-1 H axis at dim 2 ([N, C, L] -> [N, C, 1, L]).
// Pipeline: contiguous -> unsqueeze(2) -> reformat NCHW.
// Returns nullptr if any step fails.
static const aclTensor* View3dAs4d(const aclTensor* input, aclOpExecutor* executor)
{
    // Make the input contiguous first.
    auto contiguousInput = l0op::Contiguous(input, executor);
    CHECK_RET(contiguousInput != nullptr, nullptr);

    // unsqueeze(2)
    constexpr int64_t appendDim[] = {2};
    aclIntArray* dim = executor->AllocIntArray(appendDim, 1);
    // Guard the allocation, as View4dAs3d does, before handing it to UnsqueezeNd.
    CHECK_RET(dim != nullptr, nullptr);
    auto unsqueezedInput = l0op::UnsqueezeNd(contiguousInput, dim, executor);
    CHECK_RET(unsqueezedInput != nullptr, nullptr);

    // Tag the 4-D result as NCHW.
    auto reformatInput = l0op::ReFormat(unsqueezedInput, op::Format::FORMAT_NCHW);
    CHECK_RET(reformatInput != nullptr, nullptr);

    return reformatInput;
}

// Lifts a 3-D NCL tensor to 4-D NCHW with L on the H axis and a size-1 W axis.
// Pipeline: contiguous -> unsqueeze(2) giving [N, C, 1, L] -> reshape to [N, C, L, 1] -> reformat NCHW.
// Returns nullptr if any step fails, or if the unsqueezed rank is not CONV_2D_DIMS_NUM.
static const aclTensor* View3dAs4dw(const aclTensor* input, aclOpExecutor* executor)
{
    // Make the input contiguous first.
    auto contiguousInput = l0op::Contiguous(input, executor);
    CHECK_RET(contiguousInput != nullptr, nullptr);

    // unsqueeze(2): insert a size-1 axis at dim 2.
    constexpr int64_t appendDim[] = {2};
    aclIntArray* dim = executor->AllocIntArray(appendDim, 1);
    // Guard the allocation, as View4dAs3d does, before handing it to UnsqueezeNd.
    CHECK_RET(dim != nullptr, nullptr);
    auto unsqueezedInput = l0op::UnsqueezeNd(contiguousInput, dim, executor);
    CHECK_RET(unsqueezedInput != nullptr, nullptr);

    // Swap the last two dims so that the size-1 axis becomes W.
    auto dims = unsqueezedInput->GetViewShape().GetDimNum();
    CHECK_RET(dims == CONV_2D_DIMS_NUM, nullptr);
    auto shape = op::ToShapeVector(unsqueezedInput->GetViewShape());
    FVector<int64_t> newShape = {shape[0], shape[1], shape[3], shape[2]};
    aclIntArray* shapeArray = executor->AllocIntArray(newShape.data(), newShape.size());
    CHECK_RET(shapeArray != nullptr, nullptr);
    unsqueezedInput = l0op::Reshape(unsqueezedInput, shapeArray, executor);
    CHECK_RET(unsqueezedInput != nullptr, nullptr);

    // Tag the 4-D result as NCHW.
    auto reformatInput = l0op::ReFormat(unsqueezedInput, op::Format::FORMAT_NCHW);
    CHECK_RET(reformatInput != nullptr, nullptr);

    return reformatInput;
}

// Collapses a 4-D tensor to 3-D NCL by squeezing dim 2 (the H axis, expected to be size 1).
// Pipeline: contiguous -> squeeze(2) -> reformat NCL.
// Returns nullptr if any step fails.
static const aclTensor* View4dAs3d(const aclTensor* input, aclOpExecutor* executor)
{
    auto dense = l0op::Contiguous(input, executor);
    CHECK_RET(dense != nullptr, nullptr);

    constexpr int64_t squeezeAxes[] = {2};
    aclIntArray* axes = executor->AllocIntArray(squeezeAxes, 1);
    CHECK_RET(axes != nullptr, nullptr);

    auto squeezed = l0op::SqueezeNd(dense, axes, executor);
    CHECK_RET(squeezed != nullptr, nullptr);

    // Tag the 3-D result as NCL.
    auto nclTensor = l0op::ReFormat(squeezed, op::Format::FORMAT_NCL);
    CHECK_RET(nclTensor != nullptr, nullptr);
    return nclTensor;
}

// Inverse of View3dAs4dw: collapses a 4-D tensor [N, C, A, B] into 3-D NCL.
// Pipeline: contiguous -> reshape to [N, C, B, A] (swap the last two dims) -> squeeze(2) -> reformat NCL.
// The squeezed axis is therefore the original last (W) dim; it is presumably size 1 — TODO confirm at callers.
// Returns nullptr if any step fails, or if the input rank is not CONV_2D_DIMS_NUM.
static const aclTensor* View4dAs3dw(const aclTensor* input, aclOpExecutor* executor)
{
    auto dense = l0op::Contiguous(input, executor);
    CHECK_RET(dense != nullptr, nullptr);

    auto rank = input->GetViewShape().GetDimNum();
    CHECK_RET(rank == CONV_2D_DIMS_NUM, nullptr);

    // Swap the last two dims so that W moves to dim 2.
    auto dimsVec = op::ToShapeVector(dense->GetViewShape());
    FVector<int64_t> swapped = {dimsVec[0], dimsVec[1], dimsVec[3], dimsVec[2]};
    aclIntArray* swappedArray = executor->AllocIntArray(swapped.data(), swapped.size());
    CHECK_RET(swappedArray != nullptr, nullptr);
    dense = l0op::Reshape(dense, swappedArray, executor);
    CHECK_RET(dense != nullptr, nullptr);

    // Drop dim 2 (formerly the last dim).
    constexpr int64_t squeezeAxes[] = {2};
    aclIntArray* axes = executor->AllocIntArray(squeezeAxes, 1);
    CHECK_RET(axes != nullptr, nullptr);
    auto squeezed = l0op::SqueezeNd(dense, axes, executor);
    CHECK_RET(squeezed != nullptr, nullptr);

    // Tag the 3-D result as NCL.
    auto nclTensor = l0op::ReFormat(squeezed, op::Format::FORMAT_NCL);
    CHECK_RET(nclTensor != nullptr, nullptr);
    return nclTensor;
}

// Makes `input` contiguous, then transposes it with the permutation `dims`.
// Returns nullptr if any step fails.
static const aclTensor* Permute(const aclTensor* input, FVector<int64_t> dims, aclOpExecutor* executor)
{
    auto dense = l0op::Contiguous(input, executor);
    CHECK_RET(dense != nullptr, nullptr);

    auto* permArray = executor->AllocIntArray(dims.data(), dims.size());
    CHECK_RET(permArray != nullptr, nullptr);

    auto* transposed = l0op::Transpose(dense, permArray, executor);
    CHECK_RET(transposed != nullptr, nullptr);
    return transposed;
}

// Returns `tensor` itself when both its view shape and its storage shape already equal `shape`.
// Otherwise returns executor->CreateView over the same tensor with `shape` and the same view offset.
static inline const aclTensor* ViewWithShape(const aclTensor* tensor, const op::Shape& shape, aclOpExecutor* executor)
{
    const bool needsNewView = !(shape == tensor->GetViewShape()) || !(shape == tensor->GetStorageShape());
    if (needsNewView) {
        return executor->CreateView(tensor, shape, tensor->GetViewOffset());
    }
    return tensor;
}

// Validates a filter whose storage format is FRACTAL_Z:
//   - filters in any other storage format pass unchecked;
//   - FRACTAL_Z is accepted only on ASCEND310P;
//   - with FRACTAL_Z, x and filter must have the same dtype.
// Returns ACLNN_SUCCESS, or ACLNN_ERR_PARAM_INVALID after logging the reason.
static aclnnStatus CheckConv2dWithWeightFZ(const string& entityName, const aclTensor* input, const aclTensor* weight)
{
    const bool isFractalZ = (weight->GetStorageFormat() == Format::FORMAT_FRACTAL_Z);
    if (!isFractalZ) {
        return ACLNN_SUCCESS;
    }

    const bool is310P = (GetCurrentPlatformInfo().GetSocVersion() == SocVersion::ASCEND310P);
    if (!is310P) {
        OP_LOGE_FOR_INVALID_FORMAT_WITH_REASON(entityName, "filter",
            GeFormatToString(Format::FORMAT_FRACTAL_Z),
            "The value of this parameter can be " + GeFormatToString(Format::FORMAT_FRACTAL_Z) +
            " only when the SoC version is " + std::string(op::ToString(SocVersion::ASCEND310P).GetString()));
        return ACLNN_ERR_PARAM_INVALID;
    }

    const auto xDtype = input->GetDataType();
    const auto filterDtype = weight->GetDataType();
    if (xDtype == filterDtype) {
        return ACLNN_SUCCESS;
    }
    OP_LOGE_FOR_INVALID_DTYPES_WITH_REASON(entityName, "x, filter",
        GeDtypeToString(xDtype) + ", " + GeDtypeToString(filterDtype),
        "the dtypes of x and filter must be the same");
    return ACLNN_ERR_PARAM_INVALID;
}

// Whether the input may use the HWNC layout path. All of the following must hold:
//   - IsSupportND() is true;
//   - groups <= 1;
//   - the input dtype is FLOAT16, BF16 or FLOAT;
//   - batch * Cin (taken from the NCHW view shape) does not exceed UINT16_MAX.
bool isSupportInputHWNC(const aclTensor* input, const ConvolutionOpInfo& opInfo, const int64_t groups)
{
    if (!IsSupportND() || groups > 1) {
        return false;
    }

    const auto dtype = opInfo.inputDtype;
    const bool isFloatLike = (dtype == ge::DataType::DT_FLOAT16) || (dtype == ge::DataType::DT_BF16) ||
                             (dtype == ge::DataType::DT_FLOAT);
    if (!isFloatLike) {
        return false;
    }

    const auto& viewShape = input->GetViewShape();
    const int64_t batch = viewShape.GetDim(N_DIM_NCHW_INDEX);
    const int64_t cin = viewShape.GetDim(C_DIM_NCHW_INDEX);
    return batch * cin <= UINT16_MAX;
}

// Writes a batch-matmul result back into `output`. It views convOut with output's view shape,
// reformats it to output's view format, then ViewCopies it into output.
// Returns ACLNN_SUCCESS, or ACLNN_ERR_INNER_NULLPTR if any intermediate step yields nullptr.
static aclnnStatus CommonPostProcessForBmm(const aclTensor* output, const aclTensor*& convOut, aclOpExecutor* executor)
{
    // Flat BMM result -> output's view shape
    auto viewOutput = ViewWithShape(convOut, output->GetViewShape(), executor);
    CHECK_RET(viewOutput != nullptr, ACLNN_ERR_INNER_NULLPTR);
    // output reformat
    auto formatOutput = l0op::ReFormat(viewOutput, output->GetViewFormat(), executor);
    CHECK_RET(formatOutput != nullptr, ACLNN_ERR_INNER_NULLPTR);

    // A ViewCopy failure is an internal error, not a null user parameter, so report it
    // the same way as the steps above.
    auto result = l0op::ViewCopy(formatOutput, output, executor);
    CHECK_RET(result != nullptr, ACLNN_ERR_INNER_NULLPTR);

    return ACLNN_SUCCESS;
}

/**
 * Re-views `tensor` with the given dims and relabels the result as FORMAT_ND.
 *
 * @param tensor   tensor to re-view.
 * @param shape    new dims, e.g. {1, N, K}.
 * @param executor op executor used for the view.
 * @return the ND-labelled view, or nullptr if the view could not be created.
 */
static const aclTensor* ViewWithShapeAndReformatND(
    const aclTensor* tensor, const std::initializer_list<int64_t>& shape, aclOpExecutor* executor)
{
    op::Shape targetShape(shape);
    auto reshaped = ViewWithShape(tensor, targetShape, executor);
    CHECK_RET(reshaped != nullptr, nullptr);
    return l0op::ReFormat(reshaped, op::Format::FORMAT_ND);
}

/**
 * Returns the product of the extents of `dataShape` along the listed axes (1 for an empty list).
 */
static int64_t CalcCountByAxisVec(const op::Shape& dataShape, const vector<int64_t>& axisVec)
{
    int64_t product = 1;
    for (size_t i = 0; i < axisVec.size(); ++i) {
        product *= dataShape[axisVec[i]];
    }
    return product;
}

/**
 * Derives the convolution dtypes for the engine's tensors and runs CommonPreProcess on them
 * (needChangeFormat = false), forcing the bias format to ND for the matmul-based paths.
 *
 * @param engine   engine whose params are passed to GetConvolutionOpDtype / CommonPreProcess.
 * @param executor op executor.
 * @return the status returned by CommonPreProcess.
 */
static aclnnStatus GetAndCastConvolutionOpDtype(ConvEngine& engine, aclOpExecutor* executor)
{
    auto& params = engine.params;
    ConvolutionOpInfo opInfo = {};
    GetConvolutionOpDtype(
        params.input, params.weight, params.bias, params.output, opInfo, params.transposed, params.cubeMathType);

    // The BMM path consumes bias as a plain ND tensor, overriding whatever format was derived above.
    opInfo.biasFormat = Format::FORMAT_ND;
    OP_LOGD("Reset bias format=%s", op::ToString(opInfo.biasFormat).GetString());

    return CommonPreProcess(
        params.input, params.weight, params.bias, params.groups, params.transposed, opInfo, false, true, executor);
}

/**
 * Builds the batch-matmul operands for a convolution that reduces to a single matmul.
 *
 * For CONV_MM_FEATURE_MAP_EQ_FILTER (input spatial dims equal the weight's, see GetConvToBmmMode):
 *   left  = input  viewed as [1, N, Ci*(D)*H*W], not transposed
 *   right = weight viewed as [1, Co, Ci*(D)*H*W], transposed
 *   bias  = bias relabelled as ND (when present and not already ND)
 *   out   = output viewed as [1, N, Co]
 * The D dimension is absent for NCHW input. For any other mode bmmInput is left untouched.
 *
 * @param engine        taken by value; its params are passed through GetAndCastConvolutionOpDtype first.
 * @param convToBmmMode mode selected by GetConvToBmmMode.
 * @param bmmInput      receives the operands.
 * @param executor      op executor used to create views.
 * @return ACLNN_SUCCESS, the preprocessing error code, or ACLNN_ERR_INNER_NULLPTR when a view fails.
 */
static aclnnStatus GenInOutByConvToBmm(
    ConvEngine engine, const ConvToBmmMode& convToBmmMode, BatchMatmulInput& bmmInput, aclOpExecutor* executor)
{
    auto ret = GetAndCastConvolutionOpDtype(engine, executor);
    // Propagate the real preprocessing failure instead of reporting it as a null-pointer error.
    CHECK_RET(ret == ACLNN_SUCCESS, ret);
    if (convToBmmMode != ConvToBmmMode::CONV_MM_FEATURE_MAP_EQ_FILTER) {
        // Nothing to build; bmmInput may hold no tensors, so do not touch (or log) its fields.
        return ACLNN_SUCCESS;
    }

    const vector<int64_t> cidhwIdxUnionVec{
        CI_DIM_CO_CI_DHW_INDEX, D_DIM_NCDHW_INDEX, H_DIM_NCDHW_INDEX, W_DIM_NCDHW_INDEX};
    const vector<int64_t> cihwIdxUnionVec{CI_DIM_CO_CI_DHW_INDEX, H_DIM_NCHW_INDEX, W_DIM_NCHW_INDEX};
    const auto& dimIdxUnionVec =
        (engine.meta.input.format == op::Format::FORMAT_NCHW) ? cihwIdxUnionVec : cidhwIdxUnionVec;
    // weight --> [1, Co, CiDHW]
    std::initializer_list<int64_t> weightShapeVec = {
        1, engine.meta.weight.N(), CalcCountByAxisVec(engine.meta.weight.tensorShape, dimIdxUnionVec)};
    auto weightND = ViewWithShapeAndReformatND(engine.params.weight, weightShapeVec, executor);
    CHECK_RET(weightND != nullptr, ACLNN_ERR_INNER_NULLPTR);
    // input --> [1, N, CiDHW]
    std::initializer_list<int64_t> inputShapeVec = {
        1, engine.meta.input.N(), CalcCountByAxisVec(engine.meta.input.tensorShape, dimIdxUnionVec)};
    auto inputND = ViewWithShapeAndReformatND(engine.params.input, inputShapeVec, executor);
    CHECK_RET(inputND != nullptr, ACLNN_ERR_INNER_NULLPTR);
    auto biasND = engine.params.bias;
    if (biasND != nullptr && engine.meta.bias.format != op::Format::FORMAT_ND) {
        biasND = l0op::ReFormat(engine.params.bias, op::Format::FORMAT_ND);
        CHECK_RET(biasND != nullptr, ACLNN_ERR_INNER_NULLPTR);
    }
    // output --> [1, N, Co]
    std::initializer_list<int64_t> outputShapeVec = {1, engine.meta.output.N(), engine.meta.output.C()};
    auto outputND = ViewWithShapeAndReformatND(engine.params.output, outputShapeVec, executor);
    CHECK_RET(outputND != nullptr, ACLNN_ERR_INNER_NULLPTR);

    bmmInput.leftData = inputND;
    bmmInput.isLeftTranspose = false;
    bmmInput.rightData = weightND;
    bmmInput.isRightTranspose = true;
    bmmInput.biasData = biasND;
    bmmInput.outputData = outputND;

    OP_LOGD(
        "convolution to batchmatmul op leftDataType: %s, rightDataType: %s.",
        op::ToString(bmmInput.leftData->GetDataType()).GetString(),
        op::ToString(bmmInput.rightData->GetDataType()).GetString());

    return ACLNN_SUCCESS;
}

/**
 * Builds the matmul operands for a transposed 1-D convolution that reduces to a single matmul.
 *
 * For CONVTRANSPOSE1D_MM_FEATURE_MAP_EQ_FILTER (input length 1, output length == kernel length, see
 * GetConvTranspose1DToBmmMode), with NCL shapes:
 *   left  = input  [n, cin, 1]  viewed as [n, cin]
 *   right = weight [cin, cout, l] viewed as [cin, cout * l]
 *   bias  = bias relabelled as ND (when present and not already ND)
 *   out   = output [n, cout, l] viewed as [n, cout * l]
 * Neither operand is transposed. For any other mode bmmInput is left untouched.
 *
 * @param engine                   taken by value; its params go through GetAndCastConvolutionOpDtype first.
 * @param convTranspose1DToBmmMode mode selected by GetConvTranspose1DToBmmMode.
 * @param bmmInput                 receives the operands.
 * @param executor                 op executor used to create views.
 * @return ACLNN_SUCCESS, the preprocessing error code, or ACLNN_ERR_INNER_NULLPTR when a view fails.
 */
static aclnnStatus GenInOutByConvTranspose1DToBmm(
    ConvEngine engine, const ConvTranspose1DToBmmMode& convTranspose1DToBmmMode, BatchMatmulInput& bmmInput,
    aclOpExecutor* executor)
{
    auto ret = GetAndCastConvolutionOpDtype(engine, executor);
    // Propagate the real preprocessing failure instead of reporting it as a null-pointer error.
    CHECK_RET(ret == ACLNN_SUCCESS, ret);
    if (convTranspose1DToBmmMode != ConvTranspose1DToBmmMode::CONVTRANSPOSE1D_MM_FEATURE_MAP_EQ_FILTER) {
        return ACLNN_SUCCESS;
    }

    // The format of 1d is NCL, get n, cin, cout, l
    auto n = engine.meta.input.shape[N_DIM_NCL_INDEX];
    auto inChannels = engine.meta.input.shape[C_DIM_NCL_INDEX];
    auto outChannels = engine.meta.output.shape[C_DIM_NCL_INDEX];
    auto l = engine.meta.weight.shape[L_DIM_NCL_INDEX];
    // input shape [n, cin, 1] reshape to [n, cin]
    auto inputND = ViewWithShapeAndReformatND(engine.params.input, {n, inChannels}, executor);
    CHECK_RET(inputND != nullptr, ACLNN_ERR_INNER_NULLPTR);
    // weight shape [cin, cout, l] reshape to [cin, cout * l]
    auto weightND = ViewWithShapeAndReformatND(engine.params.weight, {inChannels, outChannels * l}, executor);
    CHECK_RET(weightND != nullptr, ACLNN_ERR_INNER_NULLPTR);
    auto biasND = engine.params.bias;
    if (biasND != nullptr && engine.meta.bias.format != op::Format::FORMAT_ND) {
        biasND = l0op::ReFormat(engine.params.bias, op::Format::FORMAT_ND);
        CHECK_RET(biasND != nullptr, ACLNN_ERR_INNER_NULLPTR);
    }
    // output shape [n, cout, l] reshape to [n, cout * l]
    auto outputND = ViewWithShapeAndReformatND(engine.params.output, {n, outChannels * l}, executor);
    CHECK_RET(outputND != nullptr, ACLNN_ERR_INNER_NULLPTR);

    bmmInput.leftData = inputND;
    bmmInput.isLeftTranspose = false;
    bmmInput.rightData = weightND;
    bmmInput.isRightTranspose = false;
    bmmInput.biasData = biasND;
    bmmInput.outputData = outputND;
    return ACLNN_SUCCESS;
}

/**
 * Returns true when a transposed 1-D convolution may be executed as a matmul.
 *
 * Requires: DAV_3510 architecture, NCL input, no HIFLOAT8/FLOAT8_E4M3FN input on ND-capable platforms,
 * transposed with groups == 1, zero padding and output padding, unit dilation, and—when a bias is
 * present—a kernel length of 1.
 */
static bool IsSupportConvTranspose1DToBmm(ConvEngine engine)
{
    const auto& meta = engine.meta;
    const auto& params = engine.params;

    const bool archAndLayoutOk = GetCurrentPlatformInfo().GetCurNpuArch() == NpuArch::DAV_3510 &&
                                 meta.input.format == op::Format::FORMAT_NCL;
    if (!archAndLayoutOk) {
        return false;
    }

    const bool isFp8Input = meta.input.dataType == op::DataType::DT_HIFLOAT8 ||
                            meta.input.dataType == op::DataType::DT_FLOAT8_E4M3FN;
    if (IsSupportND() && isFp8Input) {
        return false;
    }

    const bool plainTransposed = params.transposed && params.groups == 1;
    if (!plainTransposed) {
        return false;
    }

    const bool trivialGeometry =
        meta.padding[0] == 0 && meta.outputPadding[0] == 0 && meta.dilation[0] == 1;
    if (!trivialGeometry) {
        return false;
    }

    const bool unitKernel = meta.weight.shape[L_DIM_NCL_INDEX] == 1;
    if (!unitKernel && params.bias != nullptr) {
        OP_LOGD("L_DIM of weight is not 1");
        return false;
    }
    return true;
}

/**
 * Selects the matmul mode for a transposed 1-D convolution.
 *
 * Returns CONVTRANSPOSE1D_MM_FEATURE_MAP_EQ_FILTER when the conversion is supported, the input length is 1
 * and the output length equals the kernel length; otherwise CONVTRANSPOSE1D_NO_MM.
 */
static ConvTranspose1DToBmmMode GetConvTranspose1DToBmmMode(ConvEngine engine)
{
    const bool featureMapEqFilter =
        IsSupportConvTranspose1DToBmm(engine) && engine.meta.input.shape[L_DIM_NCL_INDEX] == 1 &&
        engine.meta.output.shape[L_DIM_NCL_INDEX] == engine.meta.weight.shape[L_DIM_NCL_INDEX];
    return featureMapEqFilter ? ConvTranspose1DToBmmMode::CONVTRANSPOSE1D_MM_FEATURE_MAP_EQ_FILTER
                              : ConvTranspose1DToBmmMode::CONVTRANSPOSE1D_NO_MM;
}

/**
 * Returns true when a (non-transposed) convolution may be executed as a batch matmul.
 *
 * Platform: ASCEND910B / ASCEND910_93 SoC or the DAV_3510 architecture. On the 910 SoCs, and on ND-capable
 * platforms, the input must be NCDHW or NCHW; ND-capable platforms additionally reject HIFLOAT8 input.
 * Geometry: all paddings 0, all dilations 1, not transposed, groups == 1.
 */
static bool IsSupportConvToBmm(ConvEngine engine)
{
    const SocVersion socVersion = GetCurrentPlatformInfo().GetSocVersion();
    const bool is910Soc = socVersion == SocVersion::ASCEND910B || socVersion == SocVersion::ASCEND910_93;
    const bool isDav3510 = GetCurrentPlatformInfo().GetCurNpuArch() == NpuArch::DAV_3510;
    if (!is910Soc && !isDav3510) {
        return false;
    }

    const auto inputFormat = engine.meta.input.format;
    const bool isNcdhwOrNchw =
        inputFormat == op::Format::FORMAT_NCDHW || inputFormat == op::Format::FORMAT_NCHW;
    if (is910Soc && !isNcdhwOrNchw) {
        return false;
    }

    if (op::IsSupportND() &&
        (!isNcdhwOrNchw || engine.meta.input.dataType == op::DataType::DT_HIFLOAT8)) {
        return false;
    }

    const auto& pads = engine.meta.padding;
    for (size_t i = 0; i < pads.size(); ++i) {
        if (pads[i] != 0) {
            return false;
        }
    }

    const auto& dilations = engine.meta.dilation;
    for (size_t i = 0; i < dilations.size(); ++i) {
        if (dilations[i] != 1) {
            return false;
        }
    }

    return !engine.params.transposed && engine.params.groups == 1;
}

/**
 * Selects the batch-matmul mode for a convolution.
 *
 * Returns CONV_MM_FEATURE_MAP_EQ_FILTER when the conversion is supported and every spatial dim of the input
 * (H, W for NCHW; D, H, W otherwise) equals the weight's; otherwise CONV_NO_MM.
 */
static ConvToBmmMode GetConvToBmmMode(ConvEngine engine)
{
    if (!IsSupportConvToBmm(engine)) {
        return ConvToBmmMode::CONV_NO_MM;
    }

    const std::vector<int64_t> spatialDims2d{H_DIM_NCHW_INDEX, W_DIM_NCHW_INDEX};
    const std::vector<int64_t> spatialDims3d{D_DIM_NCDHW_INDEX, H_DIM_NCDHW_INDEX, W_DIM_NCDHW_INDEX};
    const bool is2d = engine.meta.input.format == op::Format::FORMAT_NCHW;
    const std::vector<int64_t>& spatialDims = is2d ? spatialDims2d : spatialDims3d;

    for (size_t i = 0; i < spatialDims.size(); ++i) {
        const int64_t dim = spatialDims[i];
        if (engine.meta.input.shape[dim] != engine.meta.weight.shape[dim]) {
            return ConvToBmmMode::CONV_NO_MM;
        }
    }
    return ConvToBmmMode::CONV_MM_FEATURE_MAP_EQ_FILTER;
}

} // namespace

namespace AclnnConvolution {

/**
 * Registers the conv2d level-0 kernels into `l0Functions`.
 *
 * Registered NC1HWC0 entries, presumably keyed as (input dtype, input format, output dtype, output format):
 *   FP16 -> FP16, FP32 -> FP32, FP16 -> FP32, BF16 -> BF16.
 * Also registers Conv2dV2NCHW under the op type "Conv2DV2", which is the key later passed to
 * FUNCTION_CALL_BY_OPTYPE on ND-capable platforms.
 *
 * @param l0Functions kernel registry, modified in place.
 */
static inline void RegisterConv2dL0Functions(std::map<std::string, L0FUNCTION>& l0Functions)
{
    REG_L0_FUNCTION(
        l0Functions, Conv2d5HdFp16, op::DataType::DT_FLOAT16, op::Format::FORMAT_NC1HWC0, op::DataType::DT_FLOAT16,
        op::Format::FORMAT_NC1HWC0);
    REG_L0_FUNCTION(
        l0Functions, Conv2d5HdFp32, op::DataType::DT_FLOAT, op::Format::FORMAT_NC1HWC0, op::DataType::DT_FLOAT,
        op::Format::FORMAT_NC1HWC0);
    // Mixed-precision entry: FP16 on the first dtype slot, FP32 on the second.
    REG_L0_FUNCTION(
        l0Functions, Conv2d5HdFp1625HdFp32, op::DataType::DT_FLOAT16, op::Format::FORMAT_NC1HWC0,
        op::DataType::DT_FLOAT, op::Format::FORMAT_NC1HWC0);
    REG_L0_FUNCTION(
        l0Functions, Conv2d5HdBf16, op::DataType::DT_BF16, op::Format::FORMAT_NC1HWC0, op::DataType::DT_BF16,
        op::Format::FORMAT_NC1HWC0);
    // Op-type keyed entry used by the ND (Conv2DV2) dispatch path.
    REG_L0_FUNCTION_BY_OPTYPE(l0Functions, Conv2dV2NCHW, "Conv2DV2");
}

/**
 * Runs the registered convolution kernel and stores the result in `convOut`.
 *
 * On ND-capable platforms the kernel is dispatched by op type "Conv2DV2" with opInfo.outputDtype; otherwise
 * it is selected from `l0Functions` using `opInfo`.
 *
 * @param errorMsg message logged when the kernel returns nullptr.
 * @return ACLNN_SUCCESS, or ACLNN_ERR_RUNTIME_ERROR when no output tensor was produced.
 */
static inline aclnnStatus CommonConvImpl(
    std::map<std::string, L0FUNCTION>& l0Functions, ConvolutionOpInfo& opInfo, const aclTensor* input,
    const aclTensor* weight, const aclTensor* bias, const aclIntArray* stride, const aclIntArray* padding,
    const aclIntArray* dilation, bool transposed, const aclIntArray* outputPadding, int64_t groups, bool useHf32,
    aclOpExecutor* executor, const aclTensor*& convOut, const char* errorMsg)
{
    if (!op::IsSupportND()) {
        convOut = FUNCTION_CALL(
            l0Functions, opInfo, input, weight, bias, stride, padding, dilation, transposed, outputPadding, groups,
            useHf32, executor);
    } else {
        convOut = FUNCTION_CALL_BY_OPTYPE(
            l0Functions, "Conv2DV2", input, weight, bias, opInfo.outputDtype, stride, padding, dilation, transposed,
            outputPadding, groups, useHf32, executor);
    }

    if (convOut != nullptr) {
        return ACLNN_SUCCESS;
    }
    OP_LOGE(ACLNN_ERR_RUNTIME_ERROR, "%s", errorMsg);
    return ACLNN_ERR_RUNTIME_ERROR;
}

/**
 * Abstract base of the convolution execution strategies (conv1d, conv2d, conv_tbc, ...).
 *
 * A strategy is driven through PreProcess(), Impl() and PostProcess(). The base only stores the operator
 * arguments; subclasses may rebind the pointer members (e.g. to reshaped views) while preprocessing.
 * The base does not own any of the pointed-to objects.
 */
class ConvolutionImpl {
public:
    virtual aclnnStatus PreProcess() = 0;
    virtual aclnnStatus Impl() = 0;
    virtual aclnnStatus PostProcess() = 0;
    ConvolutionImpl(
        const aclTensor* inputParam, const aclTensor* weightParam, const aclTensor* biasParam,
        const aclIntArray* strideParam, const aclIntArray* paddingParam, const aclIntArray* dilationParam,
        const bool transposedParam, const aclIntArray* outputPaddingParam, const int64_t groupsParam,
        aclTensor* outputParam, bool useHf32Param, int8_t cubeMathTypeParam, aclOpExecutor* executorParam,
        const std::string& entityNameParam)
        : input(inputParam),
          weight(weightParam),
          bias(biasParam),
          stride(strideParam),
          padding(paddingParam),
          dilation(dilationParam),
          transposed(transposedParam),
          outputPadding(outputPaddingParam),
          groups(groupsParam),
          output(outputParam),
          useHf32(useHf32Param),
          cubeMathType(cubeMathTypeParam),
          executor(executorParam),
          entityName(entityNameParam)
    {}
    // Non-owning pointers only, so there is nothing to release.
    virtual ~ConvolutionImpl() = default;

protected:
    const aclTensor* input;
    const aclTensor* weight;
    const aclTensor* bias;
    const aclIntArray* stride;
    const aclIntArray* padding;
    const aclIntArray* dilation;
    const bool transposed;
    const aclIntArray* outputPadding;
    const int64_t groups;
    aclTensor* output;
    const bool useHf32;
    int8_t cubeMathType;
    uint64_t* workspaceSize = nullptr;
    aclOpExecutor* executor;
    const std::string entityName;
    // Result of Impl(), consumed by PostProcess().
    const aclTensor* convOut = nullptr;
    // Format/dtype info precomputed for all pre- and post-processing. Value-initialized so no field is ever read
    // indeterminate before GetConvOpInfo/GetConvolutionOpDtype fill it in.
    ConvolutionOpInfo opInfo = {};
    // Level-0 kernel registry, filled by the subclasses' PreProcess().
    std::map<std::string, L0FUNCTION> l0Functions;
};

/**
 * Runs a convolution strategy: PreProcess(), then Impl(), then PostProcess(), stopping at the first stage
 * that does not return ACLNN_SUCCESS and returning its status.
 *
 * @param convImpl strategy to execute; ACLNN_ERR_INNER is returned when it is null.
 */
static inline aclnnStatus ExecuteConvImpl(const std::shared_ptr<AclnnConvolution::ConvolutionImpl>& convImpl)
{
    if (!convImpl) {
        return ACLNN_ERR_INNER;
    }

    using Stage = aclnnStatus (ConvolutionImpl::*)();
    const Stage stages[] = {&ConvolutionImpl::PreProcess, &ConvolutionImpl::Impl, &ConvolutionImpl::PostProcess};
    for (const Stage stage : stages) {
        // Calling through a pointer-to-member still dispatches virtually.
        const aclnnStatus status = ((*convImpl).*stage)();
        if (status != ACLNN_SUCCESS) {
            return status;
        }
    }
    return ACLNN_SUCCESS;
}

/**
 * Registers the transposed-convolution level-0 kernels for one data format.
 *
 * Each entry uses the same dtype on both dtype slots and `dataFormat` on both format slots, for
 * FP16, FP32, BF16, HIFLOAT8 and FLOAT8_E4M3FN.
 *
 * @param l0Functions kernel registry, modified in place.
 * @param dataFormat  format used for every registered entry.
 */
static inline void RegisterTransposedConvL0Functions(
    std::map<std::string, L0FUNCTION>& l0Functions, op::Format dataFormat)
{
    REG_L0_FUNCTION(
        l0Functions, ConvTranspose2d5HdFp16, op::DataType::DT_FLOAT16, dataFormat, op::DataType::DT_FLOAT16,
        dataFormat);
    REG_L0_FUNCTION(
        l0Functions, ConvTranspose2d5HdFp32, op::DataType::DT_FLOAT, dataFormat, op::DataType::DT_FLOAT, dataFormat);
    REG_L0_FUNCTION(
        l0Functions, ConvTranspose2d5HdBf16, op::DataType::DT_BF16, dataFormat, op::DataType::DT_BF16, dataFormat);
    // 8-bit float variants.
    REG_L0_FUNCTION(
        l0Functions, ConvTranspose2d5HdHif8, op::DataType::DT_HIFLOAT8, dataFormat, op::DataType::DT_HIFLOAT8,
        dataFormat);
    REG_L0_FUNCTION(
        l0Functions, ConvTranspose2d5HdF8e4m3fn, op::DataType::DT_FLOAT8_E4M3FN, dataFormat,
        op::DataType::DT_FLOAT8_E4M3FN, dataFormat);
}

// CONV_CONSTRUCTOR(type) expands, inside a class body, to the constructor of Conv<type>Impl
// (e.g. CONV_CONSTRUCTOR(2d) -> Conv2dImpl(...)) that forwards every argument unchanged to ConvolutionImpl.
// It uses CONCAT, which is defined after it; that is fine because macros are expanded at their point of use.
#define CONV_CONSTRUCTOR(type)                                                                               \
    CONCAT(Conv##type, Impl)(                                                                                \
        const aclTensor* inputParam, const aclTensor* weightParam, const aclTensor* biasParam,               \
        const aclIntArray* strideParam, const aclIntArray* paddingParam, const aclIntArray* dilationParam,   \
        const bool transposedParam, const aclIntArray* outputPaddingParam, const int64_t groupsParam,        \
        aclTensor* outputParam, bool useHf32Param, int8_t cubeMathTypeParam, aclOpExecutor* executorParam, const std::string& entityNameParam)   \
        : ConvolutionImpl(                                                                                   \
              inputParam, weightParam, biasParam, strideParam, paddingParam, dilationParam, transposedParam, \
              outputPaddingParam, groupsParam, outputParam, useHf32Param, cubeMathTypeParam, executorParam, entityNameParam)  \
    {}

// CONCAT(a, b) pastes two tokens into one; CONV_CONSTRUCTOR uses it to form the class name.
#define CONCAT(a, b) a##b

/**
 * Conv2d strategy. Depending on the derived opInfo it takes the float16 C04 branch, the splitW branch
 * (conv2d re-expressed as conv3d on NDC1HWC0 data), or the normal branch.
 */
class Conv2dImpl : public ConvolutionImpl {
public:
    CONV_CONSTRUCTOR(2d)

    /**
     * Validates the padding list length (CONV_2D_PAD_DIM or CONV_4D_PAD_DIM) and views a 2-value padding in
     * its 4-value form. It runs CheckConv2dWithWeightFZ and the unsupported-dtype check, then derives opInfo
     * and selects the branch. An HWNC-strided input with zero offset is kept without a contiguous copy when
     * the HWNC path is supported. Finally it runs CommonPreProcess.
     */
    aclnnStatus PreProcess() override
    {
        RegisterConv2dL0Functions(l0Functions);
        if (padding->Size() != CONV_2D_PAD_DIM && padding->Size() != CONV_4D_PAD_DIM) {
            OP_LOGE_FOR_INVALID_LISTSIZE(entityName, "pads", std::to_string(padding->Size()),
                std::to_string(CONV_2D_PAD_DIM) + " or " + std::to_string(CONV_4D_PAD_DIM));
            return ACLNN_ERR_RUNTIME_ERROR;
        }
        if (padding->Size() == CONV_2D_PAD_DIM) {
            padding = ViewConv2dPad2dAs4d(padding, executor);
            CHECK_RET(padding != nullptr, ACLNN_ERR_INNER_NULLPTR);
        }
        auto ret = CheckConv2dWithWeightFZ(entityName, input, weight);
        CHECK_RET(ret == ACLNN_SUCCESS, ret);
        CHECK_RET(!CheckUnSupportDtype(input, weight), ACLNN_ERR_INNER_NULLPTR);
        GetConvOpInfo(input, weight, bias, output, opInfo, transposed, groups, stride, padding, dilation, cubeMathType);
        OP_LOGD(
            "convolution aclnn op inputDtype: %s, outputDtype: %s, biasDtype: %s, useHf32: %d.",
            op::ToString(opInfo.inputDtype).GetString(), op::ToString(opInfo.outputDtype).GetString(),
            op::ToString(opInfo.biasDtype).GetString(), useHf32);
        // float16 weights in FRACTAL_Z_C04 format switch to the dedicated C04 convolution branch.
        if (opInfo.weightFormat == Format::FORMAT_FRACTAL_Z_C04 && weight->GetDataType() == op::DataType::DT_FLOAT16) {
            OP_LOGD("Conv2d entering float16 C04 branch");
            return CommonPreProcessC04(input, weight, bias, groups, transposed, opInfo, true, true, executor);
        }
        if (opInfo.weightFormat == Format::FORMAT_FRACTAL_Z_C04 && weight->GetDataType() == op::DataType::DT_BF16) {
            OP_LOGD("Conv2d entering bfloat16 C04 branch");
        } else if (opInfo.inputFormat == Format::FORMAT_NDC1HWC0) {
            OP_LOGD("Conv2d entering splitW branch");
            auto changeRes = ChangeConv2dAttrToConv3d(stride, padding, dilation, executor);
            CHECK_RET(changeRes == ACLNN_SUCCESS, ACLNN_ERR_INNER_NULLPTR);
            changeRes = ChangeConv2dInputToConv3d(input, weight, executor);
            CHECK_RET(changeRes == ACLNN_SUCCESS, ACLNN_ERR_INNER_NULLPTR);
            REG_L0_FUNCTION(
                l0Functions, Conv3dv26HdFp16, op::DataType::DT_FLOAT16, op::Format::FORMAT_NDC1HWC0,
                op::DataType::DT_FLOAT16, op::Format::FORMAT_NDC1HWC0);
            REG_L0_FUNCTION(
                l0Functions, Conv3dv26HdBf16, op::DataType::DT_BF16, op::Format::FORMAT_NDC1HWC0, op::DataType::DT_BF16,
                op::Format::FORMAT_NDC1HWC0);
            REG_L0_FUNCTION(
                l0Functions, Conv3dv26HdFp32, op::DataType::DT_FLOAT, op::Format::FORMAT_NDC1HWC0,
                op::DataType::DT_FLOAT, op::Format::FORMAT_NDC1HWC0);
            OP_LOGD(
                "convolution aclnn op inputDtype: %s, outputDtype: %s, biasDtype: %s, useHf32: %d.",
                op::ToString(opInfo.inputDtype).GetString(), op::ToString(opInfo.outputDtype).GetString(),
                op::ToString(opInfo.biasDtype).GetString(), useHf32);
        } else {
            OP_LOGD("Conv2d entering normal branch");
        }

        bool inputDisContinuous = false;
        auto viewShape = input->GetViewShape();
        std::vector<int64_t> newStrides = {viewShape[C_DIM_NCHW_INDEX], 1,
            viewShape[N_DIM_NCHW_INDEX] * viewShape[C_DIM_NCHW_INDEX] * viewShape[W_DIM_NCHW_INDEX],
            viewShape[N_DIM_NCHW_INDEX] * viewShape[C_DIM_NCHW_INDEX]};
        bool strideFlag = CheckDisContinuousStride(input, newStrides, CONV_2D_DIMS);
        if (strideFlag && input->GetViewOffset() == 0 && isSupportInputHWNC(input, opInfo, groups) &&
            isNotDMA(input, weight, bias, output, stride, padding, dilation, &opInfo)) {
            OP_LOGD("Conv2d entering disContinuous branch");
            op::Shape newStorageShapeOp = op::Shape({viewShape[H_DIM_NCHW_INDEX], viewShape[W_DIM_NCHW_INDEX],
                                                    viewShape[N_DIM_NCHW_INDEX], viewShape[C_DIM_NCHW_INDEX]});
            input = executor->CreateView(input, input->GetViewShape(), newStorageShapeOp, input->GetViewStrides(), 0);
            // CreateView can fail; check before the const_cast setters below dereference the result.
            CHECK_RET(input != nullptr, ACLNN_ERR_INNER_NULLPTR);
            const_cast<aclTensor*>(input)->SetStorageShape(newStorageShapeOp);
            const_cast<aclTensor*>(input)->SetOriginalShape(viewShape);
            const_cast<aclTensor*>(input)->SetStorageFormat(Format::FORMAT_NCHW);
            const_cast<aclTensor*>(input)->SetOriginalFormat(Format::FORMAT_NCHW);
            OP_LOGD("Conv2DV2: Skip discontiguous to contiguous conversion for HWNC input.");
            inputDisContinuous = true;
        }
        // Run the shared static preprocessing.
        bool needChangeFormat = !op::IsSupportND();
        return CommonPreProcess(input, weight, bias, groups, transposed, opInfo, needChangeFormat, true, executor,
                                inputDisContinuous);
    };

    /** Runs the convolution kernel; same dispatch as the other strategies via CommonConvImpl. */
    aclnnStatus Impl() override
    {
        return CommonConvImpl(
            l0Functions, opInfo, input, weight, bias, stride, padding, dilation, transposed, outputPadding, groups,
            useHf32, executor, convOut, "conv2d raise an unknown error");
    };

    /** Converts convOut back to the user output (via a conv3d-style fake output in splitW mode) and copies it. */
    aclnnStatus PostProcess() override
    {
        if (opInfo.inputFormat != Format::FORMAT_NDC1HWC0) {
            bool needChangeOutFormat = !op::IsSupportND();
            auto res = CommonPostProcess(groups, needChangeOutFormat, output, convOut, executor);
            CHECK_RET(res == ACLNN_SUCCESS, res);
        } else {
            // In splitW mode conv2d was executed as conv3d, so post-process the output the conv3d way first.
            auto fakeOutput3d =
                executor->AllocTensor(output->GetDataType(), op::Format::FORMAT_NCDHW, op::Format::FORMAT_NCDHW);
            CHECK_RET(fakeOutput3d != nullptr, ACLNN_ERR_INNER_NULLPTR);
            auto res = CommonPostProcess(groups, true, fakeOutput3d, convOut, executor);
            CHECK_RET(res == ACLNN_SUCCESS, res);
            convOut = View5dAs4dForOutput(convOut, executor);
            CHECK_RET(convOut != nullptr, ACLNN_ERR_INNER_NULLPTR);
        }
        auto result = l0op::ViewCopy(convOut, output, executor);
        CHECK_RET(result != nullptr, ACLNN_ERR_PARAM_NULLPTR);

        return ACLNN_SUCCESS;
    };
    ~Conv2dImpl() override = default;
};

/**
 * conv_tbc strategy: executed as a conv2d over 4-D views; the 3-D result is permuted with dims {2, 0, 1}
 * into the [T, B, C] output.
 */
class ConvTbcImpl : public ConvolutionImpl {
public:
    CONV_CONSTRUCTOR(Tbc)

    /** Views the 1-D attributes as 2-D and input/weight/bias as 4-D, derives opInfo, then runs CommonPreProcess. */
    aclnnStatus PreProcess() override
    {
        RegisterConv2dL0Functions(l0Functions);

        // conv1d is implemented by 2d, so first change view of input, weight, bias
        stride = View1dAs2d(stride, 1, executor);
        CHECK_RET(stride != nullptr, ACLNN_ERR_INNER_NULLPTR);

        padding = View1dAs2d(padding, 0, executor);
        CHECK_RET(padding != nullptr, ACLNN_ERR_INNER_NULLPTR);

        dilation = View1dAs2d(dilation, 1, executor);
        CHECK_RET(dilation != nullptr, ACLNN_ERR_INNER_NULLPTR);

        input = View3dAs4d(input, executor);
        CHECK_RET(input != nullptr, ACLNN_ERR_INNER_NULLPTR);

        weight = View3dAs4d(weight, executor);
        CHECK_RET(weight != nullptr, ACLNN_ERR_INNER_NULLPTR);

        bias = View1dAs4d(bias, executor);
        CHECK_RET(bias != nullptr, ACLNN_ERR_INNER_NULLPTR);
        CHECK_RET(!CheckUnSupportDtype(input, weight), ACLNN_ERR_INNER_NULLPTR);
        GetConvOpInfo(input, weight, bias, output, opInfo, transposed, groups, stride, padding, dilation, cubeMathType);
        OP_LOGD(
            "convolution aclnn op inputDtype: %s, outputDtype: %s, biasDtype: %s, useHf32: %d.",
            op::ToString(opInfo.inputDtype).GetString(), op::ToString(opInfo.outputDtype).GetString(),
            op::ToString(opInfo.biasDtype).GetString(), useHf32);
        // Run the shared static preprocessing.
        return CommonPreProcess(input, weight, bias, groups, transposed, opInfo, true, false, executor);
    };

    aclnnStatus Impl() override
    {
        // conv1d is implement by conv2d
        return CommonConvImpl(
            l0Functions, opInfo, input, weight, bias, stride, padding, dilation, transposed, outputPadding, groups,
            useHf32, executor, convOut, "convTbc raise an unknown error");
    };

    /** Post-processes as conv2d, views the result as 3-D, permutes it to [T, B, C] and copies it to the output. */
    aclnnStatus PostProcess() override
    {
        // Only NCL conv1d is supported, so the conv2d equivalent uses the NCHW format.
        auto fakeOutput2d =
            executor->AllocTensor(output->GetDataType(), op::Format::FORMAT_NCHW, op::Format::FORMAT_NCHW);
        CHECK_RET(fakeOutput2d != nullptr, ACLNN_ERR_INNER_NULLPTR);

        // Run the shared static post-processing.
        auto res = CommonPostProcess(groups, true, fakeOutput2d, convOut, executor);
        CHECK_RET(res == ACLNN_SUCCESS, res);
        // The conv ran as conv2d, so view the 4-D output back as 3-D.
        convOut = View4dAs3d(convOut, executor);
        CHECK_RET(convOut != nullptr, ACLNN_ERR_PARAM_NULLPTR);
        // permute to [T, B, C]
        FVector<int64_t> permuteDim{2, 0, 1};
        auto permuteConvTbc = Permute(convOut, permuteDim, executor);
        CHECK_RET(permuteConvTbc != nullptr, ACLNN_ERR_PARAM_NULLPTR);
        // view copy
        auto castConvTbc = l0op::ReFormat(permuteConvTbc, Format::FORMAT_ND);
        CHECK_RET(castConvTbc != nullptr, ACLNN_ERR_INNER_NULLPTR);
        auto result = l0op::ViewCopy(castConvTbc, output, executor);
        CHECK_RET(result != nullptr, ACLNN_ERR_PARAM_NULLPTR);

        return ACLNN_SUCCESS;
    };
    ~ConvTbcImpl() override = default;
};

/**
 * Conv1d strategy. Normally executed as conv2d over 4-D views. When isSpecialConv1d holds (and groups == 1),
 * it is executed as a single matmul instead.
 */
class Conv1dImpl : public ConvolutionImpl {
public:
    CONV_CONSTRUCTOR(1d)

    /**
     * Views stride/padding/dilation in 2-D form and weight/bias as 4-D. The input becomes either a 4-D view of
     * an HWNC-strided tensor (no contiguous copy) or View3dAs4d(input). This step derives opInfo, decides
     * whether the matmul special path is used, and then runs CommonPreProcess.
     */
    aclnnStatus PreProcess() override
    {
        RegisterConv2dL0Functions(l0Functions);

        // conv1d is implemented by 2d, so first change view of input, weight, bias
        stride = View1dAs2d(stride, 1, executor);
        CHECK_RET(stride != nullptr, ACLNN_ERR_INNER_NULLPTR);

        if (padding->Size() == 1) {
            padding = ViewConv1dPad1dAs4d(padding, executor);
            CHECK_RET(padding != nullptr, ACLNN_ERR_INNER_NULLPTR);
        } else if (padding->Size() == CONV_2D_PAD_DIM) {
            padding = ViewConv1dPad2dAs4d(padding, executor);
            CHECK_RET(padding != nullptr, ACLNN_ERR_INNER_NULLPTR);
        } else {
            OP_LOGE_FOR_INVALID_LISTSIZE(entityName, "pads", std::to_string(padding->Size()), "1 or 2");
            return ACLNN_ERR_INNER_NULLPTR;
        }

        dilation = View1dAs2d(dilation, 1, executor);
        CHECK_RET(dilation != nullptr, ACLNN_ERR_INNER_NULLPTR);

        weight = View3dAs4d(weight, executor);
        CHECK_RET(weight != nullptr, ACLNN_ERR_INNER_NULLPTR);

        if (bias != nullptr) {
            if (bias->GetViewShape().GetDimNum() == 3) { // bias rank is 3
                bias = View3dAs4d(bias, executor);
            } else {
                // bias rank is 1; other ranks are rejected during parameter checking
                bias = View1dAs4d(bias, executor);
            }
            CHECK_RET(bias != nullptr, ACLNN_ERR_INNER_NULLPTR);
        }

        GetConvolutionOpDtype(input, weight, bias, output, opInfo, transposed, cubeMathType);
        auto viewShape = input->GetViewShape();
        bool inputDisContinuous = false;
        std::vector<int64_t> newStrides =
            {viewShape[C_DIM_NCL_INDEX], 1, viewShape[N_DIM_NCL_INDEX] * viewShape[C_DIM_NCL_INDEX]};
        bool strideFlag = CheckDisContinuousStride(input, newStrides, CONV_1D_DIMS);
        if (strideFlag && input->GetViewOffset() == 0 && isSupportInputHWNC(input, opInfo, groups)) {
            op::Shape viewShape2d =
                op::Shape({viewShape[N_DIM_NCL_INDEX], viewShape[C_DIM_NCL_INDEX], 1, viewShape[L_DIM_NCL_INDEX]});
            op::Shape storageShape2d =
                op::Shape({1, viewShape[L_DIM_NCL_INDEX], viewShape[N_DIM_NCL_INDEX], viewShape[C_DIM_NCL_INDEX]});
            op::Strides newStridesOp({viewShape[C_DIM_NCL_INDEX], 1,
                viewShape[N_DIM_NCL_INDEX] * viewShape[C_DIM_NCL_INDEX] * viewShape[L_DIM_NCL_INDEX],
                viewShape[N_DIM_NCL_INDEX] * viewShape[C_DIM_NCL_INDEX]});

            auto inputView = executor->CreateView(input, viewShape2d, input->GetViewShape(), newStridesOp, 0);
            // If CreateView fails, fall back to the regular View3dAs4d path instead of dereferencing nullptr.
            if (inputView != nullptr &&
                isNotDMA(inputView, weight, bias, output, stride, padding, dilation, &opInfo)) {
                const_cast<aclTensor*>(inputView)->SetStorageShape(storageShape2d);
                const_cast<aclTensor*>(inputView)->SetOriginalShape(viewShape2d);
                const_cast<aclTensor*>(inputView)->SetViewFormat(Format::FORMAT_NCHW);
                const_cast<aclTensor*>(inputView)->SetStorageFormat(Format::FORMAT_NCHW);
                const_cast<aclTensor*>(inputView)->SetOriginalFormat(Format::FORMAT_NCHW);
                input = inputView;
                OP_LOGD("Conv1D: Skip discontiguous to contiguous conversion for HWNC input.");
                inputDisContinuous = true;
            }
        }

        if (!inputDisContinuous) {
            input = View3dAs4d(input, executor);
            CHECK_RET(input != nullptr, ACLNN_ERR_INNER_NULLPTR);
        }

        CHECK_RET(!CheckUnSupportDtype(input, weight), ACLNN_ERR_INNER_NULLPTR);
        GetConvOpInfo(input, weight, bias, output, opInfo, transposed, groups, stride, padding, dilation, cubeMathType);
        OP_LOGD(
            "convolution aclnn op inputDtype: %s, outputDtype: %s, biasDtype: %s, useHf32: %d.",
            op::ToString(opInfo.inputDtype).GetString(), op::ToString(opInfo.outputDtype).GetString(),
            op::ToString(opInfo.biasDtype).GetString(), useHf32);
        // NOTE(review): the matmul special path re-views the input with ViewWithShape, which appears to assume
        // contiguous NCL element order, so it stays off when the HWNC-strided input was kept without a copy.
        specialConv1d =
            !inputDisContinuous && isSpecialConv1d(input, weight, stride, padding, dilation) && (groups == 1);
        // Run the shared static preprocessing.
        bool needChangeFormat = op::IsSupportND() ? false : !specialConv1d;
        return CommonPreProcess(input, weight, bias, groups, transposed, opInfo, needChangeFormat, true, executor,
                                inputDisContinuous);
    };

    aclnnStatus Impl() override
    {
        if (specialConv1d) {
            return ImplSpecialConv1dByMatmul();
        }
        // conv1d is implement by conv2d
        return CommonConvImpl(
            l0Functions, opInfo, input, weight, bias, stride, padding, dilation, transposed, outputPadding, groups,
            useHf32, executor, convOut, "conv1d raise an unknown error");
    };

    /** Post-processes as conv2d (format change skipped for the special path), views back to 3-D and copies out. */
    aclnnStatus PostProcess() override
    {
        // conv1d runs as conv2d, so post-process the output the conv2d way first.
        // Only NCL conv1d is supported, so the conv2d equivalent uses the NCHW format.
        auto fakeOutput2d =
            executor->AllocTensor(output->GetDataType(), op::Format::FORMAT_NCHW, op::Format::FORMAT_NCHW);
        CHECK_RET(fakeOutput2d != nullptr, ACLNN_ERR_INNER_NULLPTR);

        // Run the shared static post-processing.
        auto res = CommonPostProcess(groups, !specialConv1d, fakeOutput2d, convOut, executor);
        CHECK_RET(res == ACLNN_SUCCESS, res);
        // The regular path ran as conv2d, so view the 4-D output back as 3-D.
        if (!specialConv1d) {
            convOut = View4dAs3d(convOut, executor);
            CHECK_RET(convOut != nullptr, ACLNN_ERR_PARAM_NULLPTR);
        }

        auto result = l0op::ViewCopy(convOut, output, executor);
        CHECK_RET(result != nullptr, ACLNN_ERR_PARAM_NULLPTR);

        return ACLNN_SUCCESS;
    };

    ~Conv1dImpl() override = default;

private:
    /**
     * True when the conv1d can be computed as one matmul. The width stride (index 1) must exceed
     * op::specialStride and equal weight dim op::specialChannelIndex. Both paddings must be 0, the width
     * dilation 1, and input channels 1. The input width must also be a multiple of the stride so that
     * folding the batch into [N * W / s, s] never mixes elements of different batches.
     */
    bool isSpecialConv1d(
        const aclTensor* inputParam, const aclTensor* weightParam, const aclIntArray* strideParam,
        const aclIntArray* paddingParam, const aclIntArray* dilationParam) const
    {
        const int64_t strideW = (*strideParam)[1];
        // strideW > specialStride (checked first) also keeps the modulo below away from division by zero.
        return strideW > op::specialStride &&
               strideW == weightParam->GetViewShape()[op::specialChannelIndex] &&
               (*paddingParam)[PAD_LEFT_INDEX] == 0 && (*paddingParam)[PAD_RIGHT_INDEX] == 0 &&
               (*dilationParam)[1] == 1 && inputParam->GetViewShape()[1] == 1 &&
               inputParam->GetViewShape()[3] % strideW == 0;
    }

    /**
     * Special path: with x [n, 1, 1, w] and weight [Cout, 1, 1, s], computes
     * x viewed as [n*w/s, s] x weight^T [s, Cout] -> [n*w/s, Cout] -> [n, w/s, Cout] -> permuted to
     * [n, Cout, w/s] (NCL). Stores the result in convOut.
     */
    aclnnStatus ImplSpecialConv1dByMatmul()
    {
        const int64_t strideW = (*stride)[1];
        const int64_t batch = input->GetViewShape()[0];
        const int64_t width = input->GetViewShape()[3];
        const int64_t outChannels = weight->GetViewShape()[0];

        // assert x and weight format = NCHW, C=H=1
        // x view to shape n*w/s s, the batch dim is fold to the dim 0
        op::Shape inputShape2d = op::Shape({batch * width / strideW, strideW});
        auto input2d = ViewWithShape(input, inputShape2d, executor);
        CHECK_RET(input2d != nullptr, ACLNN_ERR_INNER_NULLPTR);

        // weight reshape to shape Cout s
        op::Shape weightShape2d = op::Shape({outChannels, strideW});
        auto weight2d = ViewWithShape(weight, weightShape2d, executor);
        CHECK_RET(weight2d != nullptr, ACLNN_ERR_INNER_NULLPTR);

        // weight permute to shape s Cout
        FVector<int64_t> dims{1, 0};
        auto permutWeight = Permute(weight2d, dims, executor);
        CHECK_RET(permutWeight != nullptr, ACLNN_ERR_INNER_NULLPTR);

        auto input2dND = l0op::ReFormat(input2d, op::Format::FORMAT_ND);
        CHECK_RET(input2dND != nullptr, ACLNN_ERR_INNER_NULLPTR);
        auto permutWeightND = l0op::ReFormat(permutWeight, op::Format::FORMAT_ND);
        CHECK_RET(permutWeightND != nullptr, ACLNN_ERR_INNER_NULLPTR);
        // NOTE(review): no bias is passed to the matmul and isSpecialConv1d ignores bias; confirm that bias is
        // applied elsewhere (e.g. CommonPreProcess/CommonPostProcess) for this path.
        // matmul (x,weight) to shape n*w/s Cout
        auto mmOut = ExecMmOp(input2dND, permutWeightND, nullptr, 0, executor);
        CHECK_RET(mmOut != nullptr, ACLNN_ERR_INNER_NULLPTR);

        // matmul output reshape to shape n w/s Cout
        op::Shape mmOut3dShape = op::Shape({batch, width / strideW, outChannels});
        auto mmOut3d = ViewWithShape(mmOut, mmOut3dShape, executor);
        CHECK_RET(mmOut3d != nullptr, ACLNN_ERR_INNER_NULLPTR);
        auto mmOut3dNCL = l0op::ReFormat(mmOut3d, op::Format::FORMAT_NCL);
        CHECK_RET(mmOut3dNCL != nullptr, ACLNN_ERR_INNER_NULLPTR);

        // matmul output contiguous
        auto contiguousMmOut3d = l0op::Contiguous(mmOut3dNCL, executor);
        CHECK_RET(contiguousMmOut3d != nullptr, ACLNN_ERR_INNER_NULLPTR);

        // matmul output permute to shape n Cout w/s
        dims = {0, 2, 1};
        auto permutMmOut3d = Permute(contiguousMmOut3d, dims, executor);
        CHECK_RET(permutMmOut3d != nullptr, ACLNN_ERR_INNER_NULLPTR);

        auto output3dNCL = l0op::ReFormat(permutMmOut3d, op::Format::FORMAT_NCL);
        CHECK_RET(output3dNCL != nullptr, ACLNN_ERR_INNER_NULLPTR);
        convOut = output3dNCL;

        return ACLNN_SUCCESS;
    }

    bool specialConv1d = false;
};

class Conv3dTo2dImpl : public ConvolutionImpl {
public:
    CONV_CONSTRUCTOR(3dTo2d)

    /**
     * Recasts the conv3d as an equivalent conv2d.
     *
     * This implementation is selected by CanConv3dToConv2dOn310P, which requires:
     * - Ascend310P;
     * - NCDHW tensors;
     * - input/weight/output D == 1;
     * - padD == 0.
     *
     * Under those conditions the depth axis can be dropped:
     * - stride and dilation keep only their H/W entries;
     * - the symmetric [padD, padH, padW] padding becomes conv2d's [top, bottom, left, right];
     * - input and weight are viewed as 4-D.
     *
     * Conv2d L0 kernels are registered into l0Functions.
     */
    aclnnStatus PreProcess() override
    {
        OP_LOGD("Conv3d on 310P entering Conv2d branch (D==1, Kd==1, padD==0).");

        RegisterConv2dL0Functions(l0Functions);

        constexpr uint64_t conv3dAttrDim = 3;
        constexpr uint64_t conv2dAttrDim = 2;
        constexpr uint64_t conv2dPadDim = 4;

        CHECK_RET(stride != nullptr, ACLNN_ERR_INNER_NULLPTR);
        CHECK_RET(dilation != nullptr, ACLNN_ERR_INNER_NULLPTR);
        CHECK_RET(padding != nullptr, ACLNN_ERR_INNER_NULLPTR);
        CHECK_RET(stride->Size() == conv3dAttrDim, ACLNN_ERR_PARAM_INVALID);
        CHECK_RET(dilation->Size() == conv3dAttrDim, ACLNN_ERR_PARAM_INVALID);
        CHECK_RET(padding->Size() == conv3dAttrDim, ACLNN_ERR_PARAM_INVALID);

        // Keep the H and W components; index 0 is the depth component.
        int64_t stride2d[conv2dAttrDim] = {(*stride)[1], (*stride)[2]};
        stride = executor->AllocIntArray(stride2d, conv2dAttrDim);
        CHECK_RET(stride != nullptr, ACLNN_ERR_INNER_NULLPTR);

        int64_t dilation2d[conv2dAttrDim] = {(*dilation)[1], (*dilation)[2]};
        dilation = executor->AllocIntArray(dilation2d, conv2dAttrDim);
        CHECK_RET(dilation != nullptr, ACLNN_ERR_INNER_NULLPTR);

        // conv3d padding: [padD, padH, padW] (symmetric); conv2d padding expects [padTop, padBottom, padLeft, padRight]
        int64_t padding4d[conv2dPadDim] = {(*padding)[1], (*padding)[1], (*padding)[2], (*padding)[2]};
        padding = executor->AllocIntArray(padding4d, conv2dPadDim);
        CHECK_RET(padding != nullptr, ACLNN_ERR_INNER_NULLPTR);

        // NCDHW (D==1) -> NCHW for both input and weight.
        input = View5dAs4dForOutput(input, executor);
        CHECK_RET(input != nullptr, ACLNN_ERR_INNER_NULLPTR);
        weight = View5dAs4dForOutput(weight, executor);
        CHECK_RET(weight != nullptr, ACLNN_ERR_INNER_NULLPTR);

        // 4-D stand-in for the 5-D user output; the real output is written in PostProcess.
        output2d = executor->AllocTensor(output->GetDataType(), op::Format::FORMAT_NCHW, op::Format::FORMAT_NCHW);
        CHECK_RET(output2d != nullptr, ACLNN_ERR_INNER_NULLPTR);

        CHECK_RET(!CheckUnSupportDtype(input, weight), ACLNN_ERR_INNER_NULLPTR);
        GetConvOpInfo(
            input, weight, bias, output2d, opInfo, transposed, groups, stride, padding, dilation, cubeMathType);
        OP_LOGD(
            "convolution aclnn op (conv3d->conv2d) inputDtype: %s, outputDtype: %s, biasDtype: %s, useHf32: %d.",
            op::ToString(opInfo.inputDtype).GetString(), op::ToString(opInfo.outputDtype).GetString(),
            op::ToString(opInfo.biasDtype).GetString(), useHf32);

        bool needChangeFormat = !op::IsSupportND();
        return CommonPreProcess(input, weight, bias, groups, transposed, opInfo, needChangeFormat, true, executor);
    };

    /** Runs the registered conv2d kernel on the reshaped problem. */
    aclnnStatus Impl() override
    {
        return CommonConvImpl(
            l0Functions, opInfo, input, weight, bias, stride, padding, dilation, transposed, outputPadding, groups,
            useHf32, executor, convOut, "conv3d->conv2d raise an unknown error");
    };

    /** Post-processes the conv2d result, views it back as 5-D (D == 1) and copies it into the user output. */
    aclnnStatus PostProcess() override
    {
        bool needChangeOutFormat = !op::IsSupportND();
        auto res = CommonPostProcess(groups, needChangeOutFormat, output2d, convOut, executor);
        CHECK_RET(res == ACLNN_SUCCESS, res);

        // NCHW -> NCDHW (D==1).
        convOut = View4dAs5dForInput(convOut, executor);
        CHECK_RET(convOut != nullptr, ACLNN_ERR_INNER_NULLPTR);

        auto result = l0op::ViewCopy(convOut, output, executor);
        CHECK_RET(result != nullptr, ACLNN_ERR_PARAM_NULLPTR);

        return ACLNN_SUCCESS;
    };

    ~Conv3dTo2dImpl() override = default;

private:
    // 4-D (NCHW) stand-in for the user's 5-D output, passed to GetConvOpInfo and CommonPostProcess.
    // The kernel table is the inherited l0Functions member, as in the sibling implementations. Do not
    // redeclare it here, because a redeclaration would shadow the base-class member.
    aclTensor* output2d = nullptr;
};

class Conv3dImpl : public ConvolutionImpl {
public:
    CONV_CONSTRUCTOR(3d)

    /**
     * Registers the conv3d L0 kernels, derives the conv3d op info and decides whether the point-wise kernel
     * path is used, then delegates to CommonPreProcess.
     */
    aclnnStatus PreProcess() override
    {
        // NDC1HWC0 (6HD) kernels.
        REG_L0_FUNCTION(
            l0Functions, Conv3d6HdFp16, op::DataType::DT_FLOAT16, op::Format::FORMAT_NDC1HWC0, op::DataType::DT_FLOAT16,
            op::Format::FORMAT_NDC1HWC0);
        REG_L0_FUNCTION(
            l0Functions, Conv3dv26HdFp32, op::DataType::DT_FLOAT, op::Format::FORMAT_NDC1HWC0, op::DataType::DT_FLOAT,
            op::Format::FORMAT_NDC1HWC0);
        REG_L0_FUNCTION(
            l0Functions, Conv3dv26HdBf16, op::DataType::DT_BF16, op::Format::FORMAT_NDC1HWC0, op::DataType::DT_BF16,
            op::Format::FORMAT_NDC1HWC0);
        // Plain NCDHW kernels.
        REG_L0_FUNCTION(
            l0Functions, Conv3dv2NCDHWFp16, op::DataType::DT_FLOAT16, op::Format::FORMAT_NCDHW,
            op::DataType::DT_FLOAT16, op::Format::FORMAT_NCDHW);
        REG_L0_FUNCTION(
            l0Functions, Conv3dv2NCDHWFp32, op::DataType::DT_FLOAT, op::Format::FORMAT_NCDHW, op::DataType::DT_FLOAT,
            op::Format::FORMAT_NCDHW);
        REG_L0_FUNCTION(
            l0Functions, Conv3dv2NCDHWBf16, op::DataType::DT_BF16, op::Format::FORMAT_NCDHW, op::DataType::DT_BF16,
            op::Format::FORMAT_NCDHW);
        REG_L0_FUNCTION(
            l0Functions, Conv3dv2NCDHWHif8, op::DataType::DT_HIFLOAT8, op::Format::FORMAT_NCDHW,
            op::DataType::DT_HIFLOAT8, op::Format::FORMAT_NCDHW);
        CHECK_RET(!CheckUnSupportDtype(input, weight), ACLNN_ERR_INNER_NULLPTR);
        GetConv3dOpInfo(input, weight, bias, output, opInfo, transposed, cubeMathType);

        // Point-wise path: only without ND support, with conv3d v2 support, when NeedPointWiseKernel says so and
        // PointWiseKernelBeyondLimits(input) is false.
        isPointWiseKernelFlag = !op::IsSupportND() && IsSupportConv3DToConv3DV2() &&
                                NeedPointWiseKernel(weight, stride, padding, dilation, groups) &&
                                !PointWiseKernelBeyondLimits(input);
        if (isPointWiseKernelFlag) {
            // The point-wise kernel runs directly on NCDHW tensors and takes an fp32 bias.
            opInfo.inputFormat = Format::FORMAT_NCDHW;
            opInfo.weightFormat = Format::FORMAT_NCDHW;
            opInfo.outputFormat = Format::FORMAT_NCDHW;
            opInfo.biasDtype = op::DataType::DT_FLOAT;
            OP_LOGD("Entering PointWise branch.");
        }
        OP_LOGD(
            "convolution aclnn op inputDtype: %s, outputDtype: %s, biasDtype: %s, useHf32: %d.",
            op::ToString(opInfo.inputDtype).GetString(), op::ToString(opInfo.outputDtype).GetString(),
            op::ToString(opInfo.biasDtype).GetString(), useHf32);
        // Format change only on non-ND platforms, and never for the point-wise path.
        const bool needChangeFormat = !op::IsSupportND() && !isPointWiseKernelFlag;
        return CommonPreProcess(input, weight, bias, groups, transposed, opInfo, needChangeFormat, true, executor);
    };

    /** Dispatches to the registered L0 kernel; a null result is reported as a runtime error. */
    aclnnStatus Impl() override
    {
        convOut = FUNCTION_CALL(
            l0Functions, opInfo, input, weight, bias, stride, padding, dilation, transposed, outputPadding, groups,
            useHf32, executor);
        if (convOut != nullptr) {
            return ACLNN_SUCCESS;
        }
        OP_LOGE(ACLNN_ERR_RUNTIME_ERROR, "conv3d raise an unknown error");
        return ACLNN_ERR_RUNTIME_ERROR;
    };

    /** Runs the shared post-processing and copies the result into the user output. */
    aclnnStatus PostProcess() override
    {
        // NOTE(review): unlike PreProcess, this flag is not gated on op::IsSupportND(); confirm that is intended.
        const bool needChangeOutFormat = !isPointWiseKernelFlag;
        auto postStatus = CommonPostProcess(groups, needChangeOutFormat, output, convOut, executor);
        CHECK_RET(postStatus == ACLNN_SUCCESS, postStatus);

        auto copied = l0op::ViewCopy(convOut, output, executor);
        CHECK_RET(copied != nullptr, ACLNN_ERR_PARAM_NULLPTR);

        return ACLNN_SUCCESS;
    };
    ~Conv3dImpl() override = default;

private:
    // Set in PreProcess; selects the NCDHW point-wise kernel path.
    bool isPointWiseKernelFlag = false;
};

/**
 * Rejects a bias combined with an fp8-family output dtype (HIFLOAT8 or FLOAT8_E4M3FN).
 *
 * @param bias        optional bias tensor; nothing is checked when it is null.
 * @param opInfo      op info whose outputDtype is inspected.
 * @param entityName  API name used as the log entity of the error message.
 * @return ACLNN_SUCCESS if the combination is allowed; otherwise logs an invalid-dtype error for "y" and
 *         returns ACLNN_ERR_PARAM_INVALID.
 */
inline aclnnStatus PreProcessCheckOutputDtype(const aclTensor* bias, const ConvolutionOpInfo& opInfo,
    const std::string& entityName)
{
    if (bias == nullptr) {
        return ACLNN_SUCCESS;
    }
    const bool isFp8Output =
        opInfo.outputDtype == op::DataType::DT_HIFLOAT8 || opInfo.outputDtype == op::DataType::DT_FLOAT8_E4M3FN;
    if (!isFp8Output) {
        return ACLNN_SUCCESS;
    }

    // The stream is only written to, so an output-only stream is sufficient.
    std::ostringstream reason;
    reason << "When transposed is true and bias is not nullptr, "
           << "the dtype of this parameter cannot be "
           << GeDtypeToString(op::DataType::DT_HIFLOAT8) << " or "
           << GeDtypeToString(op::DataType::DT_FLOAT8_E4M3FN);
    OP_LOGE_FOR_INVALID_DTYPE_WITH_REASON(entityName, "y", GeDtypeToString(opInfo.outputDtype), reason.str());
    return ACLNN_ERR_PARAM_INVALID;
}

class ConvTransposed1dImpl : public ConvolutionImpl {
public:
    CONV_CONSTRUCTOR(Transposed1d)

    /**
     * Lifts the 1-D transposed convolution to a 2-D one, in these steps:
     * - registers the transposed-conv L0 kernels;
     * - views the 1-D attributes and the 3-D input/weight as 2-D/4-D, using the "w" view helpers when
     *   isConvTransposed1dSwitchHW() selects the switched layout;
     * - checks the output dtype against the bias;
     * - runs the shared pre-processing.
     */
    aclnnStatus PreProcess() override
    {
        op::Format dataFormat = op::IsSupportND() ? op::Format::FORMAT_NCHW : op::Format::FORMAT_NC1HWC0;
        RegisterTransposedConvL0Functions(l0Functions, dataFormat);

        ConvTranspose1dSwapHW = isConvTransposed1dSwitchHW();
        if (ConvTranspose1dSwapHW) {
            stride = View1dAs2dw(stride, 1, executor);
            CHECK_RET(stride != nullptr, ACLNN_ERR_INNER_NULLPTR);

            padding = View1dAs2dw(padding, 0, executor);
            CHECK_RET(padding != nullptr, ACLNN_ERR_INNER_NULLPTR);

            dilation = View1dAs2dw(dilation, 1, executor);
            CHECK_RET(dilation != nullptr, ACLNN_ERR_INNER_NULLPTR);

            outputPadding = View1dAs2dw(outputPadding, 0, executor);
            CHECK_RET(outputPadding != nullptr, ACLNN_ERR_INNER_NULLPTR);

            input = View3dAs4dw(input, executor);
            CHECK_RET(input != nullptr, ACLNN_ERR_INNER_NULLPTR);

            weight = View3dAs4dw(weight, executor);
            CHECK_RET(weight != nullptr, ACLNN_ERR_INNER_NULLPTR);
        } else {
            stride = View1dAs2d(stride, 1, executor);
            CHECK_RET(stride != nullptr, ACLNN_ERR_INNER_NULLPTR);

            padding = View1dAs2d(padding, 0, executor);
            CHECK_RET(padding != nullptr, ACLNN_ERR_INNER_NULLPTR);

            dilation = View1dAs2d(dilation, 1, executor);
            CHECK_RET(dilation != nullptr, ACLNN_ERR_INNER_NULLPTR);

            outputPadding = View1dAs2d(outputPadding, 0, executor);
            CHECK_RET(outputPadding != nullptr, ACLNN_ERR_INNER_NULLPTR);

            input = View3dAs4d(input, executor);
            CHECK_RET(input != nullptr, ACLNN_ERR_INNER_NULLPTR);

            weight = View3dAs4d(weight, executor);
            CHECK_RET(weight != nullptr, ACLNN_ERR_INNER_NULLPTR);
        }

        GetConvOpInfo(input, weight, bias, output, opInfo, transposed, groups, stride, padding, dilation, cubeMathType);
        if (PreProcessCheckOutputDtype(bias, opInfo, entityName) != ACLNN_SUCCESS) {
            return ACLNN_ERR_PARAM_INVALID;
        }
        // Shared pre-processing; the format is only changed on platforms without ND support.
        bool needChangeFormat = !op::IsSupportND();
        return CommonPreProcess(input, weight, bias, groups, transposed, opInfo, needChangeFormat, false, executor);
    };

    /** Runs the lifted (2-D) transposed convolution through the registered L0 kernel. */
    aclnnStatus Impl() override
    {
        convOut = FUNCTION_CALL(
            l0Functions, opInfo, input, weight, bias, stride, padding, dilation, transposed, outputPadding, groups,
            useHf32, executor);
        if (convOut == nullptr) {
            // Name the actual operation, consistent with the convTranspose2d/3d implementations.
            OP_LOGE(ACLNN_ERR_RUNTIME_ERROR, "convTranspose1d raise an unknown error");
            return ACLNN_ERR_RUNTIME_ERROR;
        }
        return ACLNN_SUCCESS;
    };

    /**
     * Adds the bias on ND-capable platforms, runs the shared conv2d-style post-processing, views the 4-D
     * result back as 3-D (undoing whichever lifting PreProcess applied) and copies it into the user output.
     */
    aclnnStatus PostProcess() override
    {
        // Handle the bias the conv2d way: reshape to [1, C, 1, 1] and add.
        if (bias != nullptr && op::IsSupportND()) {
            int64_t biasLength = bias->GetViewShape().GetDim(0);
            bias = l0op::Reshape(bias, {1, biasLength, 1, 1}, executor);
            CHECK_RET(bias != nullptr, ACLNN_ERR_INNER_NULLPTR);
            convOut = l0op::Add(convOut, bias, executor);
            CHECK_RET(convOut != nullptr, ACLNN_ERR_INNER_NULLPTR);
        }
        // The conv1d was computed as conv2d. Conv1d only supports NCL, so the conv2d stand-in output is NCHW.
        auto fakeOutput2d =
            executor->AllocTensor(output->GetDataType(), op::Format::FORMAT_NCHW, op::Format::FORMAT_NCHW);
        CHECK_RET(fakeOutput2d != nullptr, ACLNN_ERR_INNER_NULLPTR);
        bool needChangeFormat = !op::IsSupportND();
        auto res = CommonPostProcess(groups, needChangeFormat, fakeOutput2d, convOut, executor);
        CHECK_RET(res == ACLNN_SUCCESS, res);
        // Bring the 4-D conv2d result back to 3-D, matching the view used in PreProcess.
        if (ConvTranspose1dSwapHW) {
            convOut = View4dAs3dw(convOut, executor);
        } else {
            convOut = View4dAs3d(convOut, executor);
        }
        CHECK_RET(convOut != nullptr, ACLNN_ERR_PARAM_NULLPTR);
        auto result = l0op::ViewCopy(convOut, output, executor);
        CHECK_RET(result != nullptr, ACLNN_ERR_PARAM_NULLPTR);

        return ACLNN_SUCCESS;
    };
    ~ConvTransposed1dImpl() override = default;
private:
    // Set in PreProcess; selects the "w" view helpers for the lifting and its inverse.
    bool ConvTranspose1dSwapHW = false;

    /**
     * Targeted optimization, used only without ND support. The original tuning note gives
     * outW > 4096, N == 1 and inC <= 768. The thresholds actually used are W_DIM_NCHW_VALUE_TRANSPOSE1D
     * and C_DIM_NCHW_VALUE_TRANSPOSE1D.
     */
    bool isConvTransposed1dSwitchHW() const
    {
        return !op::IsSupportND() &&
               output->GetViewShape().GetDim(L_DIM_NCL_INDEX) > W_DIM_NCHW_VALUE_TRANSPOSE1D &&
               input->GetViewShape().GetDim(N_DIM_NCL_INDEX) == 1 &&
               input->GetViewShape().GetDim(C_DIM_NCL_INDEX) <= C_DIM_NCHW_VALUE_TRANSPOSE1D;
    }
};
class ConvTransposed2dImpl : public ConvolutionImpl {
public:
    CONV_CONSTRUCTOR(Transposed2d)

    /**
     * Prepares the 2-D transposed convolution, in these steps:
     * - registers the transposed-conv L0 kernels plus the NHWC variants;
     * - optionally swaps H and W of every tensor and attribute (see isConvTransposed2dSwitchHW);
     * - checks the output dtype against the bias;
     * - runs the shared pre-processing.
     */
    aclnnStatus PreProcess() override
    {
        op::Format dataFormat = op::IsSupportND() ? op::Format::FORMAT_NCHW : op::Format::FORMAT_NC1HWC0;
        RegisterTransposedConvL0Functions(l0Functions, dataFormat);
        REG_L0_FUNCTION(
            l0Functions, ConvTranspose2d5HdFp16, op::DataType::DT_FLOAT16, op::Format::FORMAT_NHWC,
            op::DataType::DT_FLOAT16, op::Format::FORMAT_NHWC);
        REG_L0_FUNCTION(
            l0Functions, ConvTranspose2d5HdFp32, op::DataType::DT_FLOAT, op::Format::FORMAT_NHWC,
            op::DataType::DT_FLOAT, op::Format::FORMAT_NHWC);
        REG_L0_FUNCTION(
            l0Functions, ConvTranspose2d5HdBf16, op::DataType::DT_BF16, op::Format::FORMAT_NHWC, op::DataType::DT_BF16,
            op::Format::FORMAT_NHWC);
        ConvTransposed2dSwitchHW = isConvTransposed2dSwitchHW();
        if (ConvTransposed2dSwitchHW) {
            // The view helpers can return nullptr (their results are checked everywhere else in this file);
            // stop before GetConvOpInfo dereferences a null tensor or array.
            input = View4DSwapHWForTensor(input, executor);
            CHECK_RET(input != nullptr, ACLNN_ERR_INNER_NULLPTR);
            weight = View4DSwapHWForTensor(weight, executor);
            CHECK_RET(weight != nullptr, ACLNN_ERR_INNER_NULLPTR);
            stride = View2DSwapHWForAttr(stride, executor);
            CHECK_RET(stride != nullptr, ACLNN_ERR_INNER_NULLPTR);
            padding = View2DSwapHWForAttr(padding, executor);
            CHECK_RET(padding != nullptr, ACLNN_ERR_INNER_NULLPTR);
            dilation = View2DSwapHWForAttr(dilation, executor);
            CHECK_RET(dilation != nullptr, ACLNN_ERR_INNER_NULLPTR);
            outputPadding = View2DSwapHWForAttr(outputPadding, executor);
            CHECK_RET(outputPadding != nullptr, ACLNN_ERR_INNER_NULLPTR);
        }

        GetConvOpInfo(input, weight, bias, output, opInfo, transposed, groups, stride, padding, dilation, cubeMathType);
        if (PreProcessCheckOutputDtype(bias, opInfo, entityName) != ACLNN_SUCCESS) {
            return ACLNN_ERR_PARAM_INVALID;
        }
        // Shared pre-processing; the format is only changed on platforms without ND support.
        bool needChangeFormat = !op::IsSupportND();
        return CommonPreProcess(input, weight, bias, groups, transposed, opInfo, needChangeFormat, true, executor);
    };

    /** Runs the transposed convolution through the registered L0 kernel. */
    aclnnStatus Impl() override
    {
        convOut = FUNCTION_CALL(
            l0Functions, opInfo, input, weight, bias, stride, padding, dilation, transposed, outputPadding, groups,
            useHf32, executor);
        if (convOut == nullptr) {
            OP_LOGE(ACLNN_ERR_RUNTIME_ERROR, "convTranspose2d raise an unknown error");
            return ACLNN_ERR_RUNTIME_ERROR;
        }
        return ACLNN_SUCCESS;
    };

    /**
     * Adds the bias on ND-capable platforms, runs the shared post-processing, undoes the optional H/W swap
     * and copies the result into the user output.
     */
    aclnnStatus PostProcess() override
    {
        if (bias != nullptr && op::IsSupportND()) {
            int64_t biasLength = bias->GetViewShape().GetDim(0);
            // Broadcast the bias along the channel axis of the output's storage layout.
            if (output->GetStorageFormat() == op::Format::FORMAT_NHWC) {
                bias = l0op::Reshape(bias, {1, 1, 1, biasLength}, executor);
            } else {
                bias = l0op::Reshape(bias, {1, biasLength, 1, 1}, executor);
            }
            CHECK_RET(bias != nullptr, ACLNN_ERR_INNER_NULLPTR);

            convOut = l0op::Add(convOut, bias, executor);
            CHECK_RET(convOut != nullptr, ACLNN_ERR_INNER_NULLPTR);
        }

        bool needChangeFormat = !op::IsSupportND();
        auto res = CommonPostProcess(groups, needChangeFormat, output, convOut, executor);
        CHECK_RET(res == ACLNN_SUCCESS, res);

        if (ConvTransposed2dSwitchHW) {
            convOut = View4DSwapHWForTensor(convOut, executor);
            CHECK_RET(convOut != nullptr, ACLNN_ERR_PARAM_NULLPTR);
        }

        auto result = l0op::ViewCopy(convOut, output, executor);
        CHECK_RET(result != nullptr, ACLNN_ERR_PARAM_NULLPTR);

        return ACLNN_SUCCESS;
    };
    ~ConvTransposed2dImpl() override = default;
private:
    // Set in PreProcess; when true, H and W are swapped in PreProcess and swapped back in PostProcess.
    bool ConvTransposed2dSwitchHW = false;

    /**
     * Targeted optimization, used only without ND support. The conditions below must all hold (the original
     * tuning note gives outW > 4096, N == 1 and inC <= 768):
     * - stride[0] == 1, padding[0] == 0, dilation[0] == 1, outputPadding[0] == 0;
     * - output W above W_DIM_NCHW_VALUE_TRANSPOSE1D;
     * - N == 1, input H == 1 and weight H == 1;
     * - input C at most C_DIM_NCHW_VALUE_TRANSPOSE1D.
     */
    bool isConvTransposed2dSwitchHW() const
    {
        return !op::IsSupportND() && (*stride)[0] == 1 && (*padding)[0] == 0 && (*dilation)[0] == 1 &&
               (*outputPadding)[0] == 0 &&
               output->GetViewShape().GetDim(W_DIM_NCHW_INDEX) > W_DIM_NCHW_VALUE_TRANSPOSE1D &&
               input->GetViewShape().GetDim(N_DIM_NCHW_INDEX) == 1 &&
               input->GetViewShape().GetDim(H_DIM_NCHW_INDEX) == 1 &&
               weight->GetViewShape().GetDim(H_DIM_NCHW_INDEX) == 1 &&
               input->GetViewShape().GetDim(C_DIM_NCHW_INDEX) <= C_DIM_NCHW_VALUE_TRANSPOSE1D;
    }
};
class ConvTransposed3dImpl : public ConvolutionImpl {
public:
    CONV_CONSTRUCTOR(Transposed3d)

    /**
     * Prepares the 3-D transposed convolution, in these steps:
     * - registers the L0 kernels (NCDHW on ND-capable platforms, NDC1HWC0 otherwise, plus NDHWC variants);
     * - derives the op info;
     * - rejects a bias combined with an fp8-family output;
     * - runs the shared pre-processing.
     */
    aclnnStatus PreProcess() override
    {
        op::Format dataFormat = op::IsSupportND() ? op::Format::FORMAT_NCDHW : op::Format::FORMAT_NDC1HWC0;
        REG_L0_FUNCTION(
            l0Functions, ConvTranspose3d6HdFp16, op::DataType::DT_FLOAT16, dataFormat, op::DataType::DT_FLOAT16,
            dataFormat);
        REG_L0_FUNCTION(
            l0Functions, ConvTranspose3d6HdFp32, op::DataType::DT_FLOAT, dataFormat, op::DataType::DT_FLOAT,
            dataFormat);
        REG_L0_FUNCTION(
            l0Functions, ConvTranspose3d6HdBf16, op::DataType::DT_BF16, dataFormat, op::DataType::DT_BF16, dataFormat);
        REG_L0_FUNCTION(
            l0Functions, ConvTranspose3d6HdHif8, op::DataType::DT_HIFLOAT8, dataFormat, op::DataType::DT_HIFLOAT8,
            dataFormat);
        REG_L0_FUNCTION(
            l0Functions, ConvTranspose3d6HdF8e4m3fn, op::DataType::DT_FLOAT8_E4M3FN, dataFormat,
            op::DataType::DT_FLOAT8_E4M3FN, dataFormat);
        REG_L0_FUNCTION(
            l0Functions, ConvTranspose3d6HdFp16, op::DataType::DT_FLOAT16, op::Format::FORMAT_NDHWC,
            op::DataType::DT_FLOAT16, op::Format::FORMAT_NDHWC);
        REG_L0_FUNCTION(
            l0Functions, ConvTranspose3d6HdFp32, op::DataType::DT_FLOAT, op::Format::FORMAT_NDHWC,
            op::DataType::DT_FLOAT, op::Format::FORMAT_NDHWC);
        REG_L0_FUNCTION(
            l0Functions, ConvTranspose3d6HdBf16, op::DataType::DT_BF16, op::Format::FORMAT_NDHWC, op::DataType::DT_BF16,
            op::Format::FORMAT_NDHWC);
        GetConv3dOpInfo(input, weight, bias, output, opInfo, transposed, cubeMathType);
        if (PreProcessCheckOutputDtype(bias, opInfo, entityName) != ACLNN_SUCCESS) {
            return ACLNN_ERR_PARAM_INVALID;
        }
        bool needChangeFormat = !op::IsSupportND();
        return CommonPreProcess(input, weight, bias, groups, transposed, opInfo, needChangeFormat, true, executor);
    };

    /** Runs the transposed convolution through the registered L0 kernel. */
    aclnnStatus Impl() override
    {
        convOut = FUNCTION_CALL(
            l0Functions, opInfo, input, weight, bias, stride, padding, dilation, transposed, outputPadding, groups,
            useHf32, executor);
        if (convOut == nullptr) {
            OP_LOGE(ACLNN_ERR_RUNTIME_ERROR, "convTranspose3d raise an unknown error");
            return ACLNN_ERR_RUNTIME_ERROR;
        }
        return ACLNN_SUCCESS;
    };

    /**
     * Finishes the kernel output, in this order:
     * - without ND support, TransData it to dstFormat;
     * - add the bias explicitly unless the NPU arch is DAV_3510;
     * - then either cast to the output dtype (no ND support) or run CommonPostProcess without a format
     *   change (ND support).
     *
     * @param dstFormat target format. It also selects the bias broadcast shape: [1, 1, 1, 1, C] for NDHWC,
     *                  [1, C, 1, 1, 1] otherwise.
     */
    aclnnStatus ProcessBias(op::Format dstFormat)
    {
        const bool ndSupported = op::IsSupportND();
        if (!ndSupported) {
            // output format transdata
            convOut = l0op::TransData(convOut, dstFormat, groups, executor);
            CHECK_RET(convOut != nullptr, ACLNN_ERR_INNER_NULLPTR);
        }

        // NOTE(review): presumably the DAV_3510 kernel applies the bias itself; confirm.
        const bool isDAV3510 = GetCurrentPlatformInfo().GetCurNpuArch() == NpuArch::DAV_3510;
        if (bias != nullptr && !isDAV3510) {
            int64_t biasLength = bias->GetViewShape().GetDim(0);
            if (dstFormat == op::Format::FORMAT_NDHWC) {
                bias = l0op::Reshape(bias, {1, 1, 1, 1, biasLength}, executor);
            } else {
                bias = l0op::Reshape(bias, {1, biasLength, 1, 1, 1}, executor);
            }
            CHECK_RET(bias != nullptr, ACLNN_ERR_INNER_NULLPTR);

            convOut = l0op::Add(convOut, bias, executor);
            CHECK_RET(convOut != nullptr, ACLNN_ERR_INNER_NULLPTR);
        }

        if (ndSupported) {
            // Format changes are only requested without ND support, so none is requested on this path.
            auto res = CommonPostProcess(groups, false, output, convOut, executor);
            CHECK_RET(res == ACLNN_SUCCESS, res);
        } else {
            // output cast
            convOut = l0op::Cast(convOut, output->GetDataType(), executor);
            CHECK_RET(convOut != nullptr, ACLNN_ERR_INNER_NULLPTR);
        }

        return ACLNN_SUCCESS;
    }

    /**
     * "N2H" optimization: when convOut is stored as NDHWC but viewed as NCDHW, transposes it with perm
     * {2, 4, 1, 0, 3} and relabels the result's original/storage/view formats as NCDHW.
     */
    aclnnStatus ProcessResult()
    {
        OP_LOGD("Check N2H available.");
        auto storageFormat = convOut->GetStorageFormat();
        auto viewFormat = convOut->GetViewFormat();
        if (storageFormat == op::Format::FORMAT_NDHWC && viewFormat == op::Format::FORMAT_NCDHW) {
            OP_LOGD("Start N2H optimize.");
            FVector<int64_t> shapeDims = {2, 4, 1, 0, 3};
            auto perm = executor->AllocIntArray(shapeDims.data(), shapeDims.size());
            CHECK_RET(perm != nullptr, ACLNN_ERR_INNER_NULLPTR);
            convOut = l0op::Transpose(convOut, perm, executor);
            CHECK_RET(convOut != nullptr, ACLNN_ERR_INNER_NULLPTR);
            // change output format
            const_cast<aclTensor *>(convOut)->SetOriginalFormat(Format::FORMAT_NCDHW);
            const_cast<aclTensor *>(convOut)->SetStorageFormat(Format::FORMAT_NCDHW);
            const_cast<aclTensor *>(convOut)->SetViewFormat(Format::FORMAT_NCDHW);
        }
        return ACLNN_SUCCESS;
    }

    /** Hook for subclasses that must reshape the result further; by default finishes into the output's storage format. */
    virtual aclnnStatus ProcessConvOut()
    {
        return ProcessBias(output->GetStorageFormat());
    }

    /** Applies the N2H optimization, finishes the result via ProcessConvOut and copies it into the user output. */
    aclnnStatus PostProcess() override
    {
        auto n2h = ProcessResult();
        CHECK_RET(n2h == ACLNN_SUCCESS, n2h);

        auto res = ProcessConvOut();
        CHECK_RET(res == ACLNN_SUCCESS, res);

        auto result = l0op::ViewCopy(convOut, output, executor);
        CHECK_RET(result != nullptr, ACLNN_ERR_PARAM_NULLPTR);

        return ACLNN_SUCCESS;
    };

    ~ConvTransposed3dImpl() override = default;
};

class ConvTranspose2dTo3dImpl : public ConvTransposed3dImpl {
public:
    ConvTranspose2dTo3dImpl(
        const aclTensor* inputParam, const aclTensor* weightParam, const aclTensor* biasParam,
        const aclIntArray* strideParam, const aclIntArray* paddingParam, const aclIntArray* dilationParam,
        const bool transposedParam, const aclIntArray* outputPaddingParam, const int64_t groupsParam,
        aclTensor* outputParam, bool useHf32Param, int8_t cubeMathTypeParam, aclOpExecutor* executorParam,
        const std::string& entityNameParam)
        : ConvTransposed3dImpl(
              inputParam, weightParam, biasParam, strideParam, paddingParam, dilationParam, transposedParam,
              outputPaddingParam, groupsParam, outputParam, useHf32Param, cubeMathTypeParam, executorParam,
              entityNameParam)
    {}

    /**
     * Lifts the 2-D transposed convolution to 3-D, then continues with ConvTransposed3dImpl::PreProcess.
     *
     * - Padding becomes a 6-element array [0, 0, H pair, W pair]. A 4-element padding is copied as-is; any
     *   other padding is treated as symmetric [padH, padW].
     * - outputPadding/stride/dilation are viewed as 3-D via View2dAs3dForAttr.
     * - Input and weight are converted by ChangeConv2dInputToConv3d.
     */
    aclnnStatus PreProcess() override
    {
        CHECK_RET(padding != nullptr, ACLNN_ERR_INNER_NULLPTR);
        std::vector<int64_t> pad3d;
        if (padding->Size() == CONV_4D_PAD_DIM) {
            pad3d = {0, 0, (*padding)[0], (*padding)[1], (*padding)[2], (*padding)[3]};
        } else {
            pad3d = {0, 0, (*padding)[0], (*padding)[0], (*padding)[1], (*padding)[1]};
        }
        padding = executor->AllocIntArray(pad3d.data(), pad3d.size());
        CHECK_RET(padding != nullptr, ACLNN_ERR_INNER_NULLPTR);
        outputPadding = View2dAs3dForAttr(outputPadding, 0, executor, false);
        CHECK_RET(outputPadding != nullptr, ACLNN_ERR_INNER_NULLPTR);
        stride = View2dAs3dForAttr(stride, 1, executor, false);
        CHECK_RET(stride != nullptr, ACLNN_ERR_INNER_NULLPTR);
        dilation = View2dAs3dForAttr(dilation, 1, executor, false);
        CHECK_RET(dilation != nullptr, ACLNN_ERR_INNER_NULLPTR);
        // Propagate the helper's own status instead of replacing it with a generic null-pointer error.
        auto changeRes = ChangeConv2dInputToConv3d(input, weight, executor);
        CHECK_RET(changeRes == ACLNN_SUCCESS, changeRes);
        return ConvTransposed3dImpl::PreProcess();
    };

    /** Finishes the result in NCDHW, then views the 5-D output back as 4-D. */
    aclnnStatus ProcessConvOut() override
    {
        auto res = ConvTransposed3dImpl::ProcessBias(op::Format::FORMAT_NCDHW);
        CHECK_RET(res == ACLNN_SUCCESS, res);
        convOut = View5dAs4dForOutput(convOut, executor);
        CHECK_RET(convOut != nullptr, ACLNN_ERR_PARAM_NULLPTR);

        return ACLNN_SUCCESS;
    };

    ~ConvTranspose2dTo3dImpl() override = default;
};

class ConvTransposed1dTo3dImpl : public ConvTranspose2dTo3dImpl {
public:
    ConvTransposed1dTo3dImpl(
        const aclTensor* inputParam, const aclTensor* weightParam, const aclTensor* biasParam,
        const aclIntArray* strideParam, const aclIntArray* paddingParam, const aclIntArray* dilationParam,
        const bool transposedParam, const aclIntArray* outputPaddingParam, const int64_t groupsParam,
        aclTensor* outputParam, bool useHf32Param, int8_t cubeMathTypeParam, aclOpExecutor* executorParam,
        const std::string& entityNameParam)
        : ConvTranspose2dTo3dImpl(
              inputParam, weightParam, biasParam, strideParam, paddingParam, dilationParam, transposedParam,
              outputPaddingParam, groupsParam, outputParam, useHf32Param, cubeMathTypeParam, executorParam,
              entityNameParam)
    {}

    /**
     * First lifts the 1-D problem to 2-D: attributes via View1dAs2d, 3-D input/weight via View3dAs4d.
     * ConvTranspose2dTo3dImpl::PreProcess then lifts it from 2-D to 3-D.
     */
    aclnnStatus PreProcess() override
    {
        // 1-D transposed conv: padding supports a single dimension only.
        outputPadding = View1dAs2d(outputPadding, 0, executor);
        CHECK_RET(outputPadding != nullptr, ACLNN_ERR_INNER_NULLPTR);
        stride = View1dAs2d(stride, 1, executor);
        CHECK_RET(stride != nullptr, ACLNN_ERR_INNER_NULLPTR);
        padding = View1dAs2d(padding, 0, executor);
        CHECK_RET(padding != nullptr, ACLNN_ERR_INNER_NULLPTR);
        dilation = View1dAs2d(dilation, 1, executor);
        CHECK_RET(dilation != nullptr, ACLNN_ERR_INNER_NULLPTR);
        input = View3dAs4d(input, executor);
        CHECK_RET(input != nullptr, ACLNN_ERR_INNER_NULLPTR);
        weight = View3dAs4d(weight, executor);
        CHECK_RET(weight != nullptr, ACLNN_ERR_INNER_NULLPTR);
        return ConvTranspose2dTo3dImpl::PreProcess();
    };

    /** Undoes both liftings: the 2-D stage returns a 4-D result, which View4dAs3d turns back into 3-D. */
    aclnnStatus ProcessConvOut() override
    {
        const auto stageStatus = ConvTranspose2dTo3dImpl::ProcessConvOut();
        CHECK_RET(stageStatus == ACLNN_SUCCESS, stageStatus);
        convOut = View4dAs3d(convOut, executor);
        CHECK_RET(convOut != nullptr, ACLNN_ERR_PARAM_NULLPTR);

        return ACLNN_SUCCESS;
    };

    ~ConvTransposed1dTo3dImpl() override = default;
};

// True when the tensor is non-null and both its view format and its storage format are NCDHW.
static bool CheckTensorFormatNCDHW(const aclTensor* tensor)
{
    if (tensor == nullptr) {
        return false;
    }
    const bool viewIsNcdhw = tensor->GetViewFormat() == op::Format::FORMAT_NCDHW;
    return viewIsNcdhw && tensor->GetStorageFormat() == op::Format::FORMAT_NCDHW;
}

// True when the tensor is non-null, its view shape has CONV_3D_DIM_SIZE dims, and the dim at
// D_DIM_NCDHW_INDEX equals 1.
static bool CheckConv3dTensorShape(const aclTensor* tensor)
{
    if (tensor == nullptr) {
        return false;
    }
    const auto& viewShape = tensor->GetViewShape();
    return viewShape.GetDimNum() == CONV_3D_DIM_SIZE && viewShape.GetDim(D_DIM_NCDHW_INDEX) == 1;
}

/**
 * Decides whether a conv3d can be executed as a conv2d (Conv3dTo2dImpl). All of the following must hold:
 * - the conv is not transposed and the SoC is Ascend310P;
 * - all arguments are non-null;
 * - input, weight and output are NCDHW in both view and storage;
 * - input, weight and output have D == 1;
 * - padding has DIM_DHW_NUM entries with padD (index 0) == 0.
 */
static bool CanConv3dToConv2dOn310P(
    const aclTensor* input, const aclTensor* weight, const aclIntArray* padding, const bool transposed,
    const aclTensor* output)
{
    const bool isForwardOn310P =
        !transposed && GetCurrentPlatformInfo().GetSocVersion() == SocVersion::ASCEND310P;
    if (!isForwardOn310P) {
        return false;
    }

    const bool allPresent = input != nullptr && weight != nullptr && padding != nullptr && output != nullptr;
    if (!allPresent) {
        return false;
    }

    // Squeezing/unsqueezing D == 1 only preserves the layout for NCDHW<->NCHW when storage is NCDHW too.
    const bool allNcdhw =
        CheckTensorFormatNCDHW(input) && CheckTensorFormatNCDHW(weight) && CheckTensorFormatNCDHW(output);
    if (!allNcdhw) {
        return false;
    }

    // Correctness requires input D == 1, weight Kd == 1 (NCDHW index 2) and output D == 1.
    const bool depthIsOne =
        CheckConv3dTensorShape(input) && CheckConv3dTensorShape(weight) && CheckConv3dTensorShape(output);
    if (!depthIsOne) {
        return false;
    }

    // conv3d padding is the symmetric [padD, padH, padW]; the depth padding must be zero.
    return padding->Size() == DIM_DHW_NUM && (*padding)[0] == 0;
}

std::shared_ptr<ConvolutionImpl> CreateConvolutionImpl(
    const aclTensor* input, const aclTensor* weight, const aclTensor* bias, const aclIntArray* stride,
    const aclIntArray* padding, const aclIntArray* dilation, const bool transposed, const bool tbc,
    const aclIntArray* outputPadding, const int64_t groups, int8_t cubeMathType, aclTensor* output,
    aclOpExecutor* executor, const std::string& entityName)
{
    // 存疑:是否按照原来的只看input dtype
    auto promoteType = op::PromoteType(input->GetDataType(), weight->GetDataType());
    // In conv hif8 case, do not consider bias dtype for useHf32
    if ((bias != nullptr) && (promoteType != DataType::DT_HIFLOAT8 && promoteType != DataType::DT_FLOAT8_E4M3FN)) {
        promoteType = op::PromoteType(promoteType, bias->GetDataType());
    }
    bool useHf32 = NeedCubeGoHF32(promoteType, cubeMathType);

    size_t inputDim = input->GetViewShape().GetDimNum();
    if (tbc) {
        return std::make_shared<ConvTbcImpl>(
            input, weight, bias, stride, padding, dilation, transposed, outputPadding, groups, output, useHf32,
            cubeMathType, executor, entityName);
    }
    if (!transposed) {
        switch (inputDim) {
            case CONV_1D_DIM_SIZE: {
                return std::make_shared<Conv1dImpl>(
                    input, weight, bias, stride, padding, dilation, transposed, outputPadding, groups, output, useHf32,
                    cubeMathType, executor, entityName);
            }
            case CONV_2D_DIM_SIZE: {
                return std::make_shared<Conv2dImpl>(
                    input, weight, bias, stride, padding, dilation, transposed, outputPadding, groups, output, useHf32,
                    cubeMathType, executor, entityName);
            }
            case CONV_3D_DIM_SIZE: {
                if (CanConv3dToConv2dOn310P(input, weight, padding, transposed, output)) {
                    return std::make_shared<Conv3dTo2dImpl>(
                        input, weight, bias, stride, padding, dilation, transposed, outputPadding, groups, output,
                        useHf32, cubeMathType, executor, entityName);
                }
                return std::make_shared<Conv3dImpl>(
                    input, weight, bias, stride, padding, dilation, transposed, outputPadding, groups, output, useHf32,
                    cubeMathType, executor, entityName);
            }
            default:
                return nullptr;
        }
    }
    switch (inputDim) {
        case CONV_1D_DIM_SIZE: {
            if (IsSupportConv1DTransposeTo3D()) {
                return std::make_shared<ConvTransposed1dTo3dImpl>(
                    input, weight, bias, stride, padding, dilation, transposed, outputPadding, groups, output, useHf32,
                    cubeMathType, executor, entityName);
            }
            return std::make_shared<ConvTransposed1dImpl>(
                input, weight, bias, stride, padding, dilation, transposed, outputPadding, groups, output, useHf32,
                cubeMathType, executor, entityName);
        }
        case CONV_2D_DIM_SIZE: {
            if (IsSupportConv2DTransposeTo3D(
                    input, weight, bias, stride, padding, dilation, outputPadding, groups, output)) {
                return std::make_shared<ConvTranspose2dTo3dImpl>(
                    input, weight, bias, stride, padding, dilation, transposed, outputPadding, groups, output, useHf32,
                    cubeMathType, executor, entityName);
            }
            return std::make_shared<ConvTransposed2dImpl>(
                input, weight, bias, stride, padding, dilation, transposed, outputPadding, groups, output, useHf32,
                cubeMathType, executor, entityName);
        }
        case CONV_3D_DIM_SIZE: {
            return std::make_shared<ConvTransposed3dImpl>(
                input, weight, bias, stride, padding, dilation, transposed, outputPadding, groups, output, useHf32,
                cubeMathType, executor, entityName);
        }
        default:
            return nullptr;
    }
}

} // namespace AclnnConvolution

#ifdef __cplusplus
extern "C" {
#endif

/**
 * Phase 1 of aclnnConvolution: validates the arguments, builds the computation on a freshly created executor
 * and reports the workspace size that executor needs.
 *
 * When none of input / weight / output is empty, exactly one lowering is chosen:
 *   1. GetConvToBmmMode() != CONV_NO_MM: the convolution is rewritten as a BatchMatmul.
 *   2. else GetConvTranspose1DToBmmMode() != CONVTRANSPOSE1D_NO_MM: the 1D transposed convolution is
 *      rewritten as a BatchMatmul.
 *   3. otherwise: a ConvolutionImpl from CreateConvolutionImpl() runs PreProcess -> Impl -> PostProcess.
 * When any of them is empty, nothing is built; only a debug log is emitted.
 *
 * @param cubeMathType  forwarded to the BatchMatmul call and to the convolution impl.
 * @param workspaceSize [out] set to the executor's GetWorkspaceSize() on success.
 * @param executor      [out] receives the built executor via ReleaseTo() on success.
 * @return ACLNN_SUCCESS, or the first non-success status produced by a check or a build step.
 */
aclnnStatus aclnnConvolutionGetWorkspaceSize(
    const aclTensor* input, const aclTensor* weight, const aclTensor* bias, const aclIntArray* stride,
    const aclIntArray* padding, const aclIntArray* dilation, bool transposed, const aclIntArray* outputPadding,
    const int64_t groups, aclTensor* output, int8_t cubeMathType, uint64_t* workspaceSize, aclOpExecutor** executor)
{
    L2_DFX_PHASE_1(
        aclnnConvolution,
        DFX_IN(input, weight, bias, stride, padding, dilation, transposed, outputPadding, groups, cubeMathType),
        DFX_OUT(output));

    // Bundle the arguments into the engine consumed by the parameter checks and the BMM-lowering helpers.
    ConvParams params = {input,         weight, bias,   stride,       padding,       dilation, transposed,
                         outputPadding, groups, output, cubeMathType, workspaceSize, executor};
    ConvEngine convEngine(params);
    convEngine.entityName = "aclnnConvolutionGetWorkspaceSize";
    auto status = CheckConvParams(convEngine);
    CHECK_RET_CODE(status, "Check Param failed");

    // Create the OpExecutor that owns everything built below.
    auto uniqueExecutor = CREATE_EXECUTOR();
    CHECK_RET(uniqueExecutor.get() != nullptr, ACLNN_ERR_INNER_CREATE_EXECUTOR);
    auto* const graphExecutor = uniqueExecutor.get();

    const bool anyEmpty = input->IsEmpty() || weight->IsEmpty() || output->IsEmpty();
    if (anyEmpty) {
        OP_LOGD("Input is zero tensor.");
    } else {
        // Both modes are queried up front; the first one that is not "no matmul" wins.
        auto bmmMode = GetConvToBmmMode(convEngine);
        auto transpose1DBmmMode = GetConvTranspose1DToBmmMode(convEngine);
        const bool useConvBmm = (bmmMode != ConvToBmmMode::CONV_NO_MM);
        const bool useTranspose1DBmm = (transpose1DBmmMode != ConvTranspose1DToBmmMode::CONVTRANSPOSE1D_NO_MM);

        if (useConvBmm) {
            OP_LOGD("Aclnn convolution entering batch matmul branch.");
            BatchMatmulInput bmmArgs = {nullptr, nullptr, nullptr, nullptr, false, false};
            status = GenInOutByConvToBmm(convEngine, bmmMode, bmmArgs, graphExecutor);
            CHECK_RET(status == ACLNN_SUCCESS, status);

            auto bmmOut = ExecBatchMatmulOpWithBiasAndAttrs(
                bmmArgs.leftData, bmmArgs.rightData, bmmArgs.biasData, bmmArgs.outputData, bmmArgs.isLeftTranspose,
                bmmArgs.isRightTranspose, cubeMathType, graphExecutor);
            OP_CHECK(
                bmmOut != nullptr, OP_LOGE(ACLNN_ERR_INNER, "The BatchMatmul in Conv Return Nullptr."),
                return ACLNN_ERR_INNER_NULLPTR);
            auto postStatus = CommonPostProcessForBmm(output, bmmOut, graphExecutor);
            CHECK_RET(postStatus == ACLNN_SUCCESS, postStatus);
        } else if (useTranspose1DBmm) {
            BatchMatmulInput bmmArgs = {nullptr, nullptr, nullptr, nullptr, false, false};
            status = GenInOutByConvTranspose1DToBmm(convEngine, transpose1DBmmMode, bmmArgs, graphExecutor);
            CHECK_RET(status == ACLNN_SUCCESS, status);

            auto bmmOut = ExecBatchMatmulOpWithBiasAndAttrs(
                bmmArgs.leftData, bmmArgs.rightData, bmmArgs.biasData, bmmArgs.outputData, bmmArgs.isLeftTranspose,
                bmmArgs.isRightTranspose, cubeMathType, graphExecutor);
            OP_CHECK(
                bmmOut != nullptr,
                OP_LOGE(ACLNN_ERR_INNER, "The BatchMatmul in ConvTranspose1D Return Nullptr."),
                return ACLNN_ERR_INNER_NULLPTR);
            auto postStatus = CommonPostProcessForBmm(output, bmmOut, graphExecutor);
            CHECK_RET(postStatus == ACLNN_SUCCESS, postStatus);
        } else {
            std::shared_ptr<AclnnConvolution::ConvolutionImpl> impl = AclnnConvolution::CreateConvolutionImpl(
                input, weight, bias, stride, padding, dilation, transposed, false, outputPadding, groups, cubeMathType,
                output, graphExecutor, convEngine.entityName);
            if (impl == nullptr) {
                return ACLNN_ERR_INNER;
            }
            // Each stage runs only if every earlier stage succeeded; the first failure is returned.
            status = impl->PreProcess();
            if (status == ACLNN_SUCCESS) {
                status = impl->Impl();
            }
            if (status == ACLNN_SUCCESS) {
                status = impl->PostProcess();
            }
            if (status != ACLNN_SUCCESS) {
                return status;
            }
        }
    }

    *workspaceSize = graphExecutor->GetWorkspaceSize();
    uniqueExecutor.ReleaseTo(executor);
    return ACLNN_SUCCESS;
}

/**
 * Phase 1 of aclnnConvTbc: a 1D convolution on TBC-laid-out tensors. TBC presumably means
 * (time, batch, channel), as in PyTorch's conv_tbc; confirm against the op spec.
 *
 * How the generic convolution path is reused:
 *   - self and output are permuted with {1, 2, 0}, and weight with {2, 1, 0}. An empty tensor is only
 *     re-viewed with the permuted shape.
 *   - The convolution then runs with stride [1], dilation [1], padding [pad], groups 1 and no transposition.
 *
 * Cases:
 *   - self, weight and output all non-empty: build a ConvolutionImpl and execute it.
 *   - self or weight empty while output is not: output = Cast(BroadcastTo(bias, output shape)).
 *   - output empty: nothing to compute.
 *
 * @param pad           padding applied on the convolved (time) axis.
 * @param workspaceSize [out] set to the executor's GetWorkspaceSize() on success.
 * @param executor      [out] receives the built executor via ReleaseTo() on success.
 * @return ACLNN_SUCCESS, or the first failing status. Internal allocation/build failures return
 *         ACLNN_ERR_INNER_NULLPTR; a missing impl returns ACLNN_ERR_INNER, as in the other entry points.
 */
aclnnStatus aclnnConvTbcGetWorkspaceSize(
    const aclTensor* self, const aclTensor* weight, const aclTensor* bias, const int64_t pad, aclTensor* output,
    int8_t cubeMathType, uint64_t* workspaceSize, aclOpExecutor** executor)
{
    L2_DFX_PHASE_1(aclnnConvTbc, DFX_IN(self, weight, bias, pad, cubeMathType), DFX_OUT(output));
    std::string entityName = "aclnnConvTbcGetWorkspaceSize";
    // Create the OpExecutor
    auto uniqueExecutor = CREATE_EXECUTOR();
    CHECK_RET(uniqueExecutor.get() != nullptr, ACLNN_ERR_INNER_CREATE_EXECUTOR);

    auto ret = CheckParamsNullptrTbc(entityName, self, weight, bias, output);
    CHECK_RET(ret == ACLNN_SUCCESS, ret);

    if (GetCurrentPlatformInfo().GetCurNpuArch() == NpuArch::DAV_3510) {
        ret = CheckParamsEmpty(entityName, output, bias);
        CHECK_RET(ret == ACLNN_SUCCESS, ret);
    }
    // Build the padding, stride and dilation arrays used for TBC.
    aclIntArray* padding = ViewValueAs1d(pad, uniqueExecutor.get()); // [pad]
    CHECK_RET(padding != nullptr, ACLNN_ERR_INNER_NULLPTR);
    // 1: unit stride / dilation value
    FVector<int64_t> unitValue{1};
    const aclIntArray* strideArray = (uniqueExecutor.get())->AllocIntArray(unitValue.data(), 1);   // [1]
    const aclIntArray* dilationArray = (uniqueExecutor.get())->AllocIntArray(unitValue.data(), 1); // [1]
    CHECK_RET(strideArray != nullptr && dilationArray != nullptr, ACLNN_ERR_INNER_NULLPTR);
    // Permute input, weight and output into the layout expected by the generic convolution path.
    FVector<int64_t> permuteDimsAll{1, 2, 0};
    FVector<int64_t> permuteDimsWeight{2, 1, 0};
    op::Shape inputShape = op::Shape({self->GetViewShape()[1], self->GetViewShape()[2], self->GetViewShape()[0]});
    op::Shape outputShape =
        op::Shape({output->GetViewShape()[1], output->GetViewShape()[2], output->GetViewShape()[0]});
    op::Shape weightShape =
        op::Shape({weight->GetViewShape()[2], weight->GetViewShape()[1], weight->GetViewShape()[0]});

    // Empty tensors are only re-viewed with the permuted shape; non-empty ones are actually permuted.
    auto* permuteInput = self->IsEmpty() ? ViewWithShape(self, inputShape, uniqueExecutor.get()) :
                                           Permute(self, permuteDimsAll, uniqueExecutor.get());
    auto* permuteOutput = output->IsEmpty() ? ViewWithShape(output, outputShape, uniqueExecutor.get()) :
                                              Permute(output, permuteDimsAll, uniqueExecutor.get());
    auto* permuteWeight = weight->IsEmpty() ? ViewWithShape(weight, weightShape, uniqueExecutor.get()) :
                                              Permute(weight, permuteDimsWeight, uniqueExecutor.get());
    auto permuteOutputT = const_cast<aclTensor*>(permuteOutput);
    // Construct the params and the convolution engine.
    ConvParams params = {permuteInput,  permuteWeight, bias,    strideArray, padding,
                         dilationArray, false,         nullptr, 1,           permuteOutputT,
                         cubeMathType,  workspaceSize, executor};
    ConvEngine convEngine(params);
    convEngine.entityName = entityName;
    // conv_tbc parameter check
    ret = CheckConvTbcParams(convEngine);
    CHECK_RET(ret == ACLNN_SUCCESS, ret);
    if (!(self->IsEmpty() || weight->IsEmpty() || output->IsEmpty())) {
        // The impl receives the original (un-permuted) `output`; the 8th argument is `true` here, whereas the
        // other entry points pass `false` (presumably the TBC flag - confirm in CreateConvolutionImpl).
        std::shared_ptr<AclnnConvolution::ConvolutionImpl> convImpl = AclnnConvolution::CreateConvolutionImpl(
            permuteInput, permuteWeight, bias, strideArray, padding, dilationArray, false, true, nullptr, 1,
            cubeMathType, output, uniqueExecutor.get(), convEngine.entityName);
        // Same guard as aclnnConvolutionGetWorkspaceSize / aclnnConvDepthwise2dGetWorkspaceSize.
        if (convImpl == nullptr) {
            return ACLNN_ERR_INNER;
        }
        ret = ExecuteConvImpl(convImpl);
        if (ret != ACLNN_SUCCESS) {
            return ret;
        }
    } else if ((self->IsEmpty() || weight->IsEmpty()) && !output->IsEmpty()) {
        // No contribution from input * weight: the result is bias broadcast to output's shape.
        OP_LOGD("Input is zero tensor, and output is non-zero tensor.");
        auto biasContiguous = l0op::Contiguous(bias, uniqueExecutor.get());
        CHECK_RET(biasContiguous != nullptr, ACLNN_ERR_INNER_NULLPTR);
        op::FVector<int64_t, op::MAX_DIM_NUM> broadcastDims = op::ToShapeVector(output->GetViewShape());
        auto shapes = (uniqueExecutor.get())->AllocIntArray(broadcastDims.data(), output->GetViewShape().GetDimNum());
        CHECK_RET(shapes != nullptr, ACLNN_ERR_INNER_NULLPTR);
        auto out = l0op::BroadcastTo(biasContiguous, shapes, uniqueExecutor.get());
        CHECK_RET(out != nullptr, ACLNN_ERR_INNER_NULLPTR);
        auto outCast = l0op::Cast(out, output->GetDataType(), uniqueExecutor.get());
        CHECK_RET(outCast != nullptr, ACLNN_ERR_INNER_NULLPTR);
        auto viewCopyOut = l0op::ViewCopy(outCast, output, uniqueExecutor.get());
        CHECK_RET(viewCopyOut != nullptr, ACLNN_ERR_INNER_NULLPTR);
    } else {
        OP_LOGD("Output is zero tensor.");
    }

    *workspaceSize = (uniqueExecutor.get())->GetWorkspaceSize();
    uniqueExecutor.ReleaseTo(executor);
    return ACLNN_SUCCESS;
}

/**
 * Phase 2 of aclnnConvolution.
 *
 * Runs `executor` (presumably the one produced by aclnnConvolutionGetWorkspaceSize) on `stream`, using the
 * given workspace buffer of size `workspaceSize`.
 *
 * @return whatever CommonOpExecutorRun reports.
 */
aclnnStatus aclnnConvolution(void* workspace, const uint64_t workspaceSize, aclOpExecutor* executor, aclrtStream stream)
{
    L2_DFX_PHASE_2(aclnnConvolution);
    const aclnnStatus runStatus = CommonOpExecutorRun(workspace, workspaceSize, executor, stream);
    return runStatus;
}

/**
 * Phase 2 of aclnnConvTbc.
 *
 * Runs `executor` (presumably the one produced by aclnnConvTbcGetWorkspaceSize) on `stream`, using the given
 * workspace buffer of size `workspaceSize`.
 *
 * @return whatever CommonOpExecutorRun reports.
 */
aclnnStatus aclnnConvTbc(void* workspace, const uint64_t workspaceSize, aclOpExecutor* executor, aclrtStream stream)
{
    L2_DFX_PHASE_2(aclnnConvTbc);
    const aclnnStatus runStatus = CommonOpExecutorRun(workspace, workspaceSize, executor, stream);
    return runStatus;
}

/**
 * Phase 1 of aclnnConvDepthwise2d: validates the arguments and kernelSize, then builds a convolution whose group
 * count equals the size of self's channel dimension. That dimension is dim 1, or op::specialChannelIndex when
 * self's view format is NHWC.
 *
 * Parameter checks run with groups == 1; the real group count is only computed afterwards for the build step.
 * When self, weight or out is empty, nothing is built.
 *
 * @param workspaceSize [out] set to the executor's GetWorkspaceSize() on success.
 * @param executor      [out] receives the built executor via ReleaseTo() on success.
 * @return ACLNN_SUCCESS, or the first failing status.
 */
aclnnStatus aclnnConvDepthwise2dGetWorkspaceSize(
    const aclTensor* self, const aclTensor* weight, const aclIntArray* kernelSize, const aclTensor* bias,
    const aclIntArray* stride, const aclIntArray* padding, const aclIntArray* dilation, aclTensor* out,
    int8_t cubeMathType, uint64_t* workspaceSize, aclOpExecutor** executor)
{
    L2_DFX_PHASE_1(
        aclnnConvDepthwise2d, DFX_IN(self, weight, kernelSize, bias, stride, padding, dilation, cubeMathType),
        DFX_OUT(out));
    int64_t groups = 1;
    // Bundle the arguments (non-transposed, no output padding) into the engine used by the checks.
    ConvParams params = {self,    weight, bias, stride,       padding,       dilation, false,
                         nullptr, groups, out,  cubeMathType, workspaceSize, executor};
    ConvEngine convEngine(params);
    convEngine.entityName = "aclnnConvDepthwise2dGetWorkspaceSize";
    auto status = CheckConvDepthwise2dParams(convEngine);
    CHECK_RET_CODE(status, "Check Param failed");
    status = CheckConvDepthwise2dKernelSize(convEngine, kernelSize);
    CHECK_RET_CODE(status, "Check kernelSize failed");

    // Create the OpExecutor.
    auto uniqueExecutor = CREATE_EXECUTOR();
    CHECK_RET(uniqueExecutor.get() != nullptr, ACLNN_ERR_INNER_CREATE_EXECUTOR);
    auto* const graphExecutor = uniqueExecutor.get();

    if (self->IsEmpty() || weight->IsEmpty() || out->IsEmpty()) {
        // Empty tensors: the caller has already set out's shape, dtype and format, so nothing needs to be built.
        OP_LOGD("Input is zero tensor.");
    } else {
        const bool isNhwc = (self->GetViewFormat() == op::Format::FORMAT_NHWC);
        groups = isNhwc ? self->GetViewShape().GetDim(op::specialChannelIndex) : self->GetViewShape().GetDim(1);

        std::shared_ptr<AclnnConvolution::ConvolutionImpl> impl = AclnnConvolution::CreateConvolutionImpl(
            self, weight, bias, stride, padding, dilation, false, false, nullptr, groups, cubeMathType, out,
            graphExecutor, convEngine.entityName);
        if (impl == nullptr) {
            return ACLNN_ERR_INNER;
        }
        // Each stage runs only if every earlier stage succeeded; the first failure is returned.
        status = impl->PreProcess();
        if (status == ACLNN_SUCCESS) {
            status = impl->Impl();
        }
        if (status == ACLNN_SUCCESS) {
            status = impl->PostProcess();
        }
        if (status != ACLNN_SUCCESS) {
            return status;
        }
    }
    *workspaceSize = graphExecutor->GetWorkspaceSize();
    uniqueExecutor.ReleaseTo(executor);
    return ACLNN_SUCCESS;
}

/**
 * Phase 2 of aclnnConvDepthwise2d.
 *
 * Runs `executor` (presumably the one produced by aclnnConvDepthwise2dGetWorkspaceSize) on `stream`, using the
 * given workspace buffer of size `workspaceSize`.
 *
 * @return whatever CommonOpExecutorRun reports.
 */
aclnnStatus aclnnConvDepthwise2d(
    void* workspace, const uint64_t workspaceSize, aclOpExecutor* executor, aclrtStream stream)
{
    L2_DFX_PHASE_2(aclnnConvDepthwise2d);
    const aclnnStatus runStatus = CommonOpExecutorRun(workspace, workspaceSize, executor, stream);
    return runStatus;
}

#ifdef __cplusplus
}
#endif