* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file conv_operation_impl.cpp
* \brief Parameter validation and tensor-graph construction helpers for convolution operations.
*/
#include "interface/configs/config_manager.h"
#include "interface/inner/pre_def.h"
#include "interface/operation/operation.h"
#include "interface/operation/operation_common.h"
#include "interface/program/program.h"
#include "interface/tensor/logical_tensor.h"
#include "interface/utils/common.h"
#include "interface/utils/operator_tracer.h"
#include "tilefwk/error_code.h"
#include "operation_impl.h"
#include "tilefwk/data_type.h"
#include "tilefwk/tile_shape.h"
#include "tilefwk/platform.h"
namespace npu {
namespace tile_fwk {
namespace Conv {
// Out-of-class definitions of the LoadStoreConvOpAttributeKey static string members.
// Each member holds the literal attribute-key text shown on its line.
const std::string LoadStoreConvOpAttributeKey::cutW = "CUT_W";
const std::string LoadStoreConvOpAttributeKey::copyInMode = "COPY_IN_MODE";
const std::string LoadStoreConvOpAttributeKey::copyOutMode = "COPY_OUT_MODE";
const std::string LoadStoreConvOpAttributeKey::isFmap = "IS_FMAP";
const std::string LoadStoreConvOpAttributeKey::isConv3D = "IS_CONV3D";
/**
 * \brief Tell whether the current platform's NPU architecture is DAV_2201.
 * \return true when Platform reports NPUArch::DAV_2201, false otherwise.
 */
bool IsArch32Platform()
{
    const auto currentArch = Platform::Instance().GetSoc().GetNPUArch();
    return currentArch == NPUArch::DAV_2201;
}
/**
 * \brief Return a copy of \p input rotated left by \p shift positions.
 *
 * Element `shift` of the input becomes element 0 of the result. The shift is reduced modulo the
 * vector length, so a shift larger than the vector no longer forms an out-of-range iterator
 * (undefined behaviour); an empty input is returned unchanged.
 *
 * \param input vector to rotate (not modified)
 * \param shift number of positions to rotate to the left
 * \return the rotated copy
 */
std::vector<int64_t> rotateVector(const std::vector<int64_t>& input, size_t shift)
{
    std::vector<int64_t> result = input;
    if (result.empty()) {
        return result;
    }
    // Keep the pivot inside [begin, end): begin() + shift is only valid for shift <= size().
    const size_t offset = shift % result.size();
    std::rotate(
        result.begin(), result.begin() + static_cast<std::vector<int64_t>::difference_type>(offset), result.end());
    return result;
}
/**
 * \brief Assert that \p value lies in the closed interval [\p min, \p max].
 *
 * The failure text names the value, its range and, when supplied, the formula it was derived from.
 *
 * \param value   value to validate
 * \param name    name printed in the failure message
 * \param min     inclusive lower bound
 * \param max     inclusive upper bound
 * \param formula optional description of how the value was computed; omitted from the message when empty
 */
void CheckValueRange(int64_t value, const std::string& name, int64_t min, int64_t max, const std::string& formula = "")
{
    std::ostringstream message;
    message << "Invalid " << name << ":" << value;
    message << ", expected range [" << min << "," << max << "].";
    if (!formula.empty()) {
        message << "Formula: " << formula;
    }
    ASSERT(ConvOperationError::INPUT_INVALID, value >= min && value <= max) << message.str();
}
/**
 * \brief Compute the output height of the convolution.
 *
 * Uses hOut = (hin + padTop + padBottom - dilationH * (kh - 1) - 1) / strideH + 1.
 *
 * \param inputTensor  feature map; its H extent is read at NCDHW_H_IDX (3D) or NCHW_H_IDX
 * \param weightTensor weight; its kernel height is read at the same index
 * \param attrParam    strides / paddings / dilations and the 1D / 3D flags
 * \return 1 for Conv1D and for a zero stride (avoids a division by zero; the stride range itself is
 *         validated by CheckLoad3dShape); 0 when the dilated kernel is larger than the padded input;
 *         otherwise the value of the formula above.
 */
int64_t ConvComputeHo(const Tensor& inputTensor, const Tensor& weightTensor, const ConvAttrParam& attrParam)
{
    if (attrParam.isConv1D) {
        return 1;
    }
    // Bind by reference: the attribute vectors are only read here, no copy is needed.
    const std::vector<int64_t>& strides = attrParam.strides;
    const int64_t strideH = strides[PAD_STRIDE_H];
    if (strideH == 0) {
        return 1;
    }
    const std::vector<int64_t>& paddings = attrParam.paddings;
    const std::vector<int64_t>& dilations = attrParam.dilations;
    const uint32_t indexH = attrParam.isConv3D ? NCDHW_H_IDX : NCHW_H_IDX;
    const int64_t padTop = paddings[PAD_TOP_INDEX];
    const int64_t padBottom = paddings[PAD_BOTTOM_INDEX];
    const int64_t dilationH = dilations[PAD_STRIDE_H];
    const int64_t hin = inputTensor.GetShape()[indexH];
    const int64_t kh = weightTensor.GetShape()[indexH];
    const int64_t span = hin + padTop + padBottom - dilationH * (kh - 1) - 1;
    if (span < 0) {
        // C++ division truncates toward zero, so e.g. -1 / 2 + 1 would report a bogus output size of 1.
        // Report 0 instead so that range checks on the result reject the configuration.
        return 0;
    }
    return span / strideH + 1;
}
/**
 * \brief Compute the output width of the convolution.
 *
 * Uses wOut = (win + padLeft + padRight - dilationW * (kw - 1) - 1) / strideW + 1.
 * For Conv1D the single spatial axis is read from shape index NCHW_H_IDX and from the first
 * stride / dilation / padding slot (PAD_STRIDE_H).
 *
 * \param inputTensor  feature map
 * \param weightTensor weight
 * \param attrParam    strides / paddings / dilations and the 1D / 3D flags
 * \return 1 for a zero stride (avoids a division by zero; the stride range itself is validated by
 *         CheckLoad3dShape); 0 when the dilated kernel is larger than the padded input; otherwise the
 *         value of the formula above.
 */
int64_t ConvComputeWo(const Tensor& inputTensor, const Tensor& weightTensor, const ConvAttrParam& attrParam)
{
    const uint32_t indexAttr = attrParam.isConv1D ? PAD_STRIDE_H : PAD_STRIDE_W;
    // Bind by reference: the attribute vectors are only read here, no copy is needed.
    const std::vector<int64_t>& strides = attrParam.strides;
    const int64_t strideW = strides[indexAttr];
    if (strideW == 0) {
        return 1;
    }
    const std::vector<int64_t>& paddings = attrParam.paddings;
    const std::vector<int64_t>& dilations = attrParam.dilations;
    const uint32_t indexW = attrParam.isConv3D ? NCDHW_W_IDX : (attrParam.isConv1D ? NCHW_H_IDX : NCHW_W_IDX);
    const int64_t dilationW = dilations[indexAttr];
    // Paddings are stored as consecutive (begin, end) pairs per axis.
    const int64_t padLeft = paddings[2 * indexAttr];
    const int64_t padRight = paddings[2 * indexAttr + 1];
    const int64_t win = inputTensor.GetShape()[indexW];
    const int64_t kw = weightTensor.GetShape()[indexW];
    const int64_t span = win + padLeft + padRight - dilationW * (kw - 1) - 1;
    if (span < 0) {
        // C++ division truncates toward zero, so e.g. -1 / 2 + 1 would report a bogus output size of 1.
        // Report 0 instead so that range checks on the result reject the configuration.
        return 0;
    }
    return span / strideW + 1;
}
/**
 * \brief Compute the output depth of a 3D convolution.
 *
 * Uses dOut = (din + padHead + padTail - dilationD * (kd - 1) - 1) / strideD + 1, reading the depth
 * extents at NCDHW_D_IDX.
 *
 * \param inputTensor  feature map
 * \param weightTensor weight
 * \param attrParam    strides / paddings / dilations
 * \return 1 for a zero stride (avoids a division by zero; the stride range itself is validated by
 *         CheckLoad3dShape); 0 when the dilated kernel is larger than the padded input; otherwise the
 *         value of the formula above.
 */
int64_t ConvComputeDo(const Tensor& inputTensor, const Tensor& weightTensor, const ConvAttrParam& attrParam)
{
    // Bind by reference: the attribute vectors are only read here, no copy is needed.
    const std::vector<int64_t>& strides = attrParam.strides;
    const int64_t strideD = strides[PAD_STRIDE_D];
    if (strideD == 0) {
        return 1;
    }
    const std::vector<int64_t>& paddings = attrParam.paddings;
    const std::vector<int64_t>& dilations = attrParam.dilations;
    const int64_t padHead = paddings[PAD_HEAD_INDEX];
    const int64_t padTail = paddings[PAD_TAIL_INDEX];
    const int64_t dilationD = dilations[PAD_STRIDE_D];
    const int64_t din = inputTensor.GetShape()[NCDHW_D_IDX];
    const int64_t kd = weightTensor.GetShape()[NCDHW_D_IDX];
    const int64_t span = din + padHead + padTail - dilationD * (kd - 1) - 1;
    if (span < 0) {
        // C++ division truncates toward zero, so e.g. -1 / 2 + 1 would report a bogus output size of 1.
        // Report 0 instead so that range checks on the result reject the configuration.
        return 0;
    }
    return span / strideD + 1;
}
/**
 * \brief Assert that every computed output extent lies in [NUM1, MAX_SIZE].
 *
 * Height and width are always checked; depth is checked only for 3D convolutions.
 *
 * \param inputTensor  feature map
 * \param weightTensor weight
 * \param attrParam    convolution attributes
 */
void CheckOutputShape(const Tensor& inputTensor, const Tensor& weightTensor, const ConvAttrParam& attrParam)
{
    const int64_t hOut = ConvComputeHo(inputTensor, weightTensor, attrParam);
    CheckValueRange(
        hOut, "hOut", NUM1, MAX_SIZE, "hOut = (hin + 2 * pad_h - (kh - 1) * dilation_h - 1) / stride_h + 1");

    const int64_t wOut = ConvComputeWo(inputTensor, weightTensor, attrParam);
    CheckValueRange(
        wOut, "wOut", NUM1, MAX_SIZE, "wOut = (win + 2 * pad_w - (kw - 1) * dilation_w - 1) / stride_w + 1");

    if (!attrParam.isConv3D) {
        return;
    }
    const int64_t dOut = ConvComputeDo(inputTensor, weightTensor, attrParam);
    CheckValueRange(
        dOut, "dOut", NUM1, MAX_SIZE, "dOut = (din + 2 * pad_d - (kd - 1) * dilation_d - 1) / stride_d + 1");
}
/**
 * \brief Assert that \p value is a multiple of \p alignment.
 *
 * A zero \p alignment is itself reported as an error before the modulo is evaluated.
 *
 * \param value     value to validate
 * \param alignment required multiple, in elements
 * \param valueName name printed in the failure messages
 */
void CheckAlignment(int64_t value, int64_t alignment, const std::string& valueName)
{
    // Guard the modulo below against a zero divisor.
    ASSERT(ConvOperationError::INPUT_INVALID, alignment != 0)
        << "Error in alignment check for " << valueName << ".";
    ASSERT(ConvOperationError::INPUT_INVALID, value % alignment == 0)
        << "Invalid " << valueName << ":" << value
        << ", requires " << alignment << "-element alignment.";
}
/**
 * \brief Round \p a up to a multiple of \p b.
 * \param a value to round up
 * \param b alignment unit
 * \return ((a + b - 1) / b) * b, or 0 when \p b is 0
 */
int64_t ConvAlignB(int64_t a, int64_t b)
{
    if (b != 0) {
        const int64_t unitCount = (a + b - 1) / b;
        return unitCount * b;
    }
    // A zero unit cannot be divided by; the result is defined as 0.
    return 0;
}
/**
 * \brief Validate the L1 output tile sizes (tileHout / tileWout) against the computed output shape.
 *
 * Checks, in order: tileHout must be 1 when wOut is not a multiple of 16; tileHout is in [1, hOut];
 * tileWout is in [1, wOut rounded up to 16]; tileWout is a multiple of 16.
 *
 * \param inputTensor  feature map
 * \param weightTensor weight
 * \param attrParam    convolution attributes
 */
void CheckHowoTile(const Tensor& inputTensor, const Tensor& weightTensor, const ConvAttrParam& attrParam)
{
    const int64_t hOut = ConvComputeHo(inputTensor, weightTensor, attrParam);
    const int64_t wOut = ConvComputeWo(inputTensor, weightTensor, attrParam);

    auto& l1Tile = TileShape::Current().GetConvTile().tileL1Info;
    const int64_t tileHout = l1Tile.tileHout;
    const int64_t tileWout = l1Tile.tileWout;

    if (wOut % NUM16 != 0) {
        ASSERT(ConvOperationError::INPUT_INVALID, tileHout == 1)
            << "When wOut is not a multiple of 16, tileHout should be 1.";
    }
    CheckValueRange(tileHout, "tileHout", NUM1, hOut);
    const int64_t alignedWout = ConvAlignB(wOut, NUM16);
    CheckValueRange(tileWout, "tileWout", NUM1, alignedWout);
    CheckAlignment(tileWout, NUM16, "tileWout");
}
/**
 * \brief Assert that a tile of tile1 x tile2 x tile3 elements fits into a cache of \p cacheSize bytes.
 *
 * \param tile1,tile2,tile3 tile extents whose product is the element count
 * \param dtypeSize         size of one element in bytes
 * \param cacheSize         capacity of the cache in bytes
 * \param cacheName         cache name used in the failure message
 * \param dim1Name,dim2Name,dim3Name names of the three extents used in the failure message
 */
void ValidateL0Constraint(
    int64_t tile1, int64_t tile2, int64_t tile3, size_t dtypeSize, size_t cacheSize, const std::string& cacheName,
    const std::string& dim1Name, const std::string& dim2Name, const std::string& dim3Name)
{
    ASSERT(ConvOperationError::OVER_BUFFER_LIMIT, tile1 * tile2 * tile3 * dtypeSize <= cacheSize)
        << "Shape does not satisfy " << cacheName << " load constraints, "
        << dim1Name << ":" << tile1 << ", "
        << dim2Name << ":" << tile2 << ", "
        << dim3Name << ":" << tile3
        << ", which must satisfy " << dim1Name << " × " << dim2Name << " × " << dim3Name
        << " × dtypesize ≤ " << cacheName << "Size(" << cacheSize << ").";
}
/**
 * \brief Validate the L0 tile configuration (tileH / tileW / tileK / tileN) of the current conv tile.
 *
 * Checks alignment and ranges of the L0 tiles against the L1 tiles, requires tileK to divide both
 * L1 K extents, emits a warning when the estimated number of L0 loops exceeds MAX_LOOP, and finally
 * asserts that the tiles fit into the L0A / L0B / L0C memories reported by the platform.
 *
 * The loop estimate divides by the L0 tile sizes, so it is evaluated only after those sizes have
 * passed their range checks, and it is kept in 64 bits because it is a product of 64-bit extents.
 *
 * \param outType      data type used to derive k0 (32 bytes / element size) and the L0A/L0B footprint
 * \param attrParam    convolution attributes (groups, 1D / 3D flags, strides, paddings, dilations)
 * \param weightTensor weight
 * \param inputTensor  feature map
 */
void CheckL0TileTiling(
    DataType outType, const ConvAttrParam& attrParam, const Tensor& weightTensor, const Tensor& inputTensor)
{
    auto& convTile = TileShape::Current().GetConvTile();
    const int64_t tileH = convTile.tileL0Info.tileH;
    const int64_t tileW = convTile.tileL0Info.tileW;
    const int64_t tileN = convTile.tileL0Info.tileN;
    const int64_t tileK = convTile.tileL0Info.tileK;
    const int64_t tileHout = convTile.tileL1Info.tileHout;
    const int64_t tileWout = convTile.tileL1Info.tileWout;
    const int64_t tileCout = convTile.tileL1Info.tileN;
    const int64_t tileCinFmap = convTile.tileL1Info.tileCinFmap;
    const int64_t tileCinWeight = convTile.tileL1Info.tileCinWeight;
    const int64_t k0 = ALIGN_SIZE_32 / BytesOf(outType);

    const uint32_t indexH = attrParam.isConv3D ? NCDHW_H_IDX : NCHW_H_IDX;
    const uint32_t indexW = attrParam.isConv3D ? NCDHW_W_IDX : (attrParam.isConv1D ? NCHW_H_IDX : NCHW_W_IDX);
    const int64_t kh = attrParam.isConv1D ? 1 : weightTensor.GetShape()[indexH];
    const int64_t kw = weightTensor.GetShape()[indexW];
    const int64_t cin = inputTensor.GetShape()[NCHW_C_IDX];
    const int64_t cout = weightTensor.GetShape()[NCHW_N_IDX];
    const int64_t batch = inputTensor.GetShape()[NCHW_N_IDX];
    const int64_t groups = attrParam.groups;

    // K extents: channel count rounded up to k0, times the kernel window (and kernel depth for 3D).
    int64_t kAL1 = ConvAlignB(tileCinFmap, k0) * kh * kw;
    int64_t kBL1 = ConvAlignB(tileCinWeight, k0) * kh * kw;
    int64_t oriK = ConvAlignB(cin, k0) * kh * kw;
    int64_t dout = 1;
    if (attrParam.isConv3D) {
        const int64_t kd = weightTensor.GetShape()[NCDHW_D_IDX];
        dout = ConvComputeDo(inputTensor, weightTensor, attrParam);
        kAL1 *= kd;
        kBL1 *= kd;
        oriK *= kd;
    }

    const int64_t minKL1 = std::min(kAL1, kBL1);
    CheckAlignment(tileK, k0, "tileK");
    CheckValueRange(tileH, "tileH", NUM1, tileHout);
    CheckValueRange(tileW, "tileW", NUM1, tileWout);
    CheckValueRange(tileK, "tileK", NUM1, minKL1);
    CheckAlignment(tileN, MKN_N_VALUE, "tileL0Info.tileN");
    CheckAlignment(tileW, MKN_N_VALUE, "tileW");
    CheckValueRange(tileN, "tileL0Info.tileN", NUM1, ConvAlignB(tileCout, MKN_N_VALUE));
    ASSERT(ConvOperationError::INPUT_INVALID, kAL1 % tileK == 0 && kBL1 % tileK == 0)
        << "Invalid tileK: " << tileK << ", must be a factor of both kAL1:" << kAL1 << " and kBL1:" << kBL1;

    // All divisors below (tileN, tileH, tileW, tileK) are known to be >= 1 at this point.
    const int64_t numTileL0 = batch * groups * CeilDiv(cout / groups, tileN) * CeilDiv(tileHout, tileH) *
                              CeilDiv(tileWout, tileW) * dout;
    if (numTileL0 > MAX_LOOP || CeilDiv(oriK, tileK) > MAX_LOOP) {
        CONV_LOGW("Suggestion: Consider increasing tile size to reduce compilation time.");
    }

    Platform& platform = Platform::Instance();
    const size_t l0aSize = platform.GetAICCore().GetMemorySize(MemoryType::MEM_L0A);
    const size_t l0bSize = platform.GetAICCore().GetMemorySize(MemoryType::MEM_L0B);
    const size_t l0cSize = platform.GetAICCore().GetMemorySize(MemoryType::MEM_L0C);
    ValidateL0Constraint(tileH, tileW, tileK, BytesOf(outType), l0aSize, "L0A", "tileH", "tileW", "tileK");
    ValidateL0Constraint(tileK, tileN, 1, BytesOf(outType), l0bSize, "L0B", "tileK", "tileN", "");
    // The L0C footprint is always computed with the FP32 element size.
    ValidateL0Constraint(tileH, tileW, tileN, BytesOf(DataType::DT_FP32), l0cSize, "L0C", "tileH", "tileW", "tileN");
}
/**
 * \brief Assert that \p divisor is non-zero and divides \p value exactly.
 * \param value       dividend
 * \param divisor     divisor; zero is reported as an error before the modulo is evaluated
 * \param valueName   name of the dividend used in the failure message
 * \param divisorName name of the divisor used in the failure messages
 */
void CheckDivisible(int64_t value, int64_t divisor, const std::string& valueName, const std::string& divisorName)
{
    ASSERT(ConvOperationError::INPUT_INVALID, divisor != 0) << divisorName << " cannot be zero.";
    ASSERT(ConvOperationError::INPUT_INVALID, value % divisor == 0)
        << "The value of " << divisorName << " (" << divisor << ") does not divide "
        << valueName << "(" << value << "). Adjusting " << divisorName
        << " to the nearest value such that " << valueName << " % " << divisorName << " == 0.";
}
/**
 * \brief Validate the L1 tile configuration of the current conv tile, then the L0 tiles if they are set.
 *
 * Checks run in this order: tileHin, tileBatch (must be exactly NUM1), tileWin and tileN ranges,
 * tileN alignment, the output tiles (CheckHowoTile), the k0 alignment of both Cin tiles, and, when
 * setL0Tile is set, CheckL0TileTiling.
 *
 * \param outType      data type used to derive k0 (32 bytes / element size)
 * \param inputTensor  feature map
 * \param weightTensor weight
 * \param attrParam    convolution attributes
 */
void CheckTileTiling(
    DataType outType, const Tensor& inputTensor, const Tensor& weightTensor, const ConvAttrParam& attrParam)
{
    const auto& convTile = TileShape::Current().GetConvTile();
    const auto& l1Info = convTile.tileL1Info;
    const int64_t tileHin = l1Info.tileHin;
    const int64_t tileWin = l1Info.tileWin;
    const int64_t tileBatch = l1Info.tileBatch;
    const int64_t tileN = l1Info.tileN;
    const int64_t tileCinFmap = l1Info.tileCinFmap;
    const int64_t tileCinWeight = l1Info.tileCinWeight;
    const int64_t groups = attrParam.groups;

    const uint32_t indexH = attrParam.isConv3D ? NCDHW_H_IDX : NCHW_H_IDX;
    const uint32_t indexW = attrParam.isConv3D ? NCDHW_W_IDX : (attrParam.isConv1D ? NCHW_H_IDX : NCHW_W_IDX);
    const int64_t cOut = weightTensor.GetShape()[NCHW_N_IDX];
    // Conv1D has no height axis; it is treated as 1.
    const int64_t hin = attrParam.isConv1D ? 1 : inputTensor.GetShape()[indexH];
    const int64_t win = inputTensor.GetShape()[indexW];

    CheckValueRange(tileHin, "tileHin", NUM1, hin);
    CheckValueRange(tileBatch, "tileBatch", NUM1, NUM1);
    CheckValueRange(tileWin, "tileWin", NUM1, win);
    CheckValueRange(tileN, "tileL1Info.tileN", NUM1, ConvAlignB(cOut / groups, MKN_N_VALUE));
    CheckAlignment(tileN, MKN_N_VALUE, "tileL1Info.tileN");
    CheckHowoTile(inputTensor, weightTensor, attrParam);

    const int64_t k0 = ALIGN_SIZE_32 / BytesOf(outType);
    CheckAlignment(tileCinFmap, k0, "tileCinFmap");
    CheckAlignment(tileCinWeight, k0, "tileCinWeight");

    if (convTile.setL0Tile) {
        CheckL0TileTiling(outType, attrParam, weightTensor, inputTensor);
    }
}
/**
 * \brief Infer how many input rows are needed to produce \p inputHoL1 output rows.
 * \param inputHoL1 number of output rows
 * \param khDilated dilated kernel height
 * \param hi        total input height; the result never exceeds it
 * \param strideH   stride along H
 * \return min((inputHoL1 - 1) * strideH + khDilated, hi)
 */
uint64_t Conv2DInferHiL1(uint64_t inputHoL1, uint64_t khDilated, uint64_t hi, uint64_t strideH)
{
    const uint64_t requiredRows = (inputHoL1 - 1) * strideH + khDilated;
    return std::min(requiredRows, hi);
}
/**
 * \brief Assert that the minimum L1 footprint of one tile (bias + feature map + weight) fits into L1.
 *
 * Each part is rounded up to 32 bytes:
 *  - bias:   tileN elements, only when a bias tensor is present;
 *  - weight: (tileCinWeight * kh * kw rounded up to k0) * tileN elements;
 *  - fmap:   hiAL1 * wiAL1 * tileCinFmap elements, where hiAL1 / wiAL1 are the input extents needed
 *            for tileHout / tileWout output positions, capped at the real input size.
 *
 * \param outType      data type whose element size is used for all three parts
 * \param inputTensor  feature map
 * \param weightTensor weight
 * \param biasTensor   bias; may be empty
 * \param attrParam    convolution attributes
 */
void CheckL1SizeTiling(
    DataType outType, const Tensor& inputTensor, const Tensor& weightTensor, const Tensor& biasTensor,
    const ConvAttrParam& attrParam)
{
    const auto& l1Info = TileShape::Current().GetConvTile().tileL1Info;
    // NOTE(review): the L1 size is read from the AIV core here while the L0 checks in this file read
    // from the AIC core — confirm this is intended.
    const uint64_t l1Size = Platform::Instance().GetAIVCore().GetMemorySize(MemoryType::MEM_L1);

    const bool is1D = attrParam.isConv1D;
    const uint32_t indexH = attrParam.isConv3D ? NCDHW_H_IDX : NCHW_H_IDX;
    const uint32_t indexW = attrParam.isConv3D ? NCDHW_W_IDX : (attrParam.isConv1D ? NCHW_H_IDX : NCHW_W_IDX);
    const uint32_t indexAttrW = is1D ? PAD_STRIDE_H : PAD_STRIDE_W;

    // Conv1D has no height axis: kernel height, input height, stride and dilation along H are all 1.
    const uint64_t kh = is1D ? 1 : weightTensor.GetShape()[indexH];
    const uint64_t hin = is1D ? 1 : inputTensor.GetShape()[indexH];
    const uint64_t kw = weightTensor.GetShape()[indexW];
    const uint64_t win = inputTensor.GetShape()[indexW];
    const uint64_t k0 = ALIGN_SIZE_32 / BytesOf(outType);

    const std::vector<int64_t>& strides = attrParam.strides;
    const std::vector<int64_t>& dilations = attrParam.dilations;
    const uint64_t strideH = is1D ? 1 : strides[PAD_STRIDE_H];
    const uint64_t strideW = strides[indexAttrW];
    const uint64_t dilationH = is1D ? 1 : dilations[PAD_STRIDE_H];
    const uint64_t dilationW = dilations[indexAttrW];

    const uint64_t tileN = l1Info.tileN;
    uint64_t biasL1Size = 0;
    if (!biasTensor.IsEmpty()) {
        biasL1Size = ConvAlignB(tileN * BytesOf(outType), ALIGN_SIZE_32);
    }

    const uint64_t tileCinFmap = l1Info.tileCinFmap;
    const uint64_t tileCinWeight = l1Info.tileCinWeight;
    const uint64_t kBL1 = ConvAlignB(tileCinWeight * kh * kw, k0);
    const uint64_t weightL1Size = ConvAlignB(kBL1 * tileN * BytesOf(outType), ALIGN_SIZE_32);

    const uint64_t tileHout = l1Info.tileHout;
    const uint64_t tileWout = l1Info.tileWout;
    const uint64_t khDilated = (kh - 1) * dilationH + 1;
    const uint64_t kwDilated = (kw - 1) * dilationW + 1;
    const uint64_t hiAL1 = std::min((tileHout - 1) * strideH + khDilated, hin);
    const uint64_t wiAL1 = std::min((tileWout - 1) * strideW + kwDilated, win);
    const uint64_t inputL1Size = ConvAlignB(hiAL1 * wiAL1 * tileCinFmap * BytesOf(outType), ALIGN_SIZE_32);

    const uint64_t minL1LoadSize = biasL1Size + inputL1Size + weightL1Size;
    ASSERT(ConvOperationError::OVER_BUFFER_LIMIT, minL1LoadSize <= l1Size)
        << "MinL1LoadSize > L1size, current MinL1LoadSize: " << minL1LoadSize << ", L1size: " << l1Size << ".";
}
/**
 * \brief Validate the group count against the channel counts.
 *
 * groups must be in [NUM1, SHAPE_INNER_AXIS_MAX_SIZE], must divide both Cin and Cout, and the
 * feature-map Cin must equal the weight Cin multiplied by groups.
 *
 * \param cinFmap   input channels of the feature map
 * \param cinWeight input channels of the weight
 * \param cOut      output channels
 * \param groups    number of groups
 */
void CheckGroupsShape(const int64_t cinFmap, const int64_t cinWeight, const int64_t cOut, const int64_t groups)
{
    CheckValueRange(groups, "groups", NUM1, SHAPE_INNER_AXIS_MAX_SIZE);
    CheckDivisible(cinFmap, groups, "Cin", "groups");
    CheckDivisible(cOut, groups, "Cout", "groups");
    ASSERT(ConvOperationError::INPUT_INVALID, cinFmap == cinWeight * groups)
        << "Fmap Cin (" << cinFmap << ") != weight Cin (" << cinWeight
        << ") * groups (" << groups << ").";
}
/**
 * \brief Assert that an attribute vector has exactly \p expectedDim entries.
 * \param vec         attribute values
 * \param name        attribute name used in the failure message
 * \param expectedDim required number of entries
 */
void CheckDimParam(const std::vector<int64_t>& vec, const std::string& name, int expectedDim)
{
    ASSERT(ConvOperationError::INPUT_INVALID, vec.size() == static_cast<size_t>(expectedDim))
        << "Input attr " << name
        << " dim: " << vec.size() << " != " << expectedDim << ".";
}
/**
 * \brief Assert that every entry of \p vec lies in [\p minVal, \p maxVal].
 *
 * Entries are checked in index order, so the first offending index is the one reported.
 *
 * \param vec    values to validate
 * \param name   name used in the failure message
 * \param minVal inclusive lower bound
 * \param maxVal inclusive upper bound
 */
void CheckDimensionRange(const std::vector<int64_t>& vec, const std::string& name, int minVal, int maxVal)
{
    const size_t count = vec.size();
    for (size_t i = 0; i != count; ++i) {
        ASSERT(ConvOperationError::INPUT_INVALID, vec[i] >= minVal && vec[i] <= maxVal)
            << "The value of the " << i << "-th dimension of " << name
            << " must be in the range [" << minVal << "," << maxVal
            << "].Current value:" << vec[i] << ".";
    }
}
/**
 * \brief Validate attribute values and kernel extents against the Load3D limits.
 *
 * Paddings must be in [0, MAX_PAD_KERNEL]; dilations and strides in [NUM1, MAX_DILATION_STRIDE];
 * kh and kw must not exceed MAX_PAD_KERNEL; kh * kw * k0 must not exceed SHAPE_INNER_AXIS_MAX_SIZE.
 * For 3D convolutions the attribute vectors are rotated first, which only affects the dimension
 * index printed in a failure message.
 *
 * \param outType      data type used to derive k0 (32 bytes / element size)
 * \param weightTensor weight
 * \param attrParam    convolution attributes
 */
void CheckLoad3dShape(DataType outType, const Tensor& weightTensor, const ConvAttrParam& attrParam)
{
    // A shift of 0 leaves the vector unchanged, so the non-3D case goes through the same call.
    const size_t padShift = attrParam.isConv3D ? NUM4 : 0;
    const size_t axisShift = attrParam.isConv3D ? NUM2 : 0;
    const std::vector<int64_t> paddings = rotateVector(attrParam.paddings, padShift);
    const std::vector<int64_t> dilations = rotateVector(attrParam.dilations, axisShift);
    const std::vector<int64_t> strides = rotateVector(attrParam.strides, axisShift);

    CheckDimensionRange(paddings, "paddings", 0, MAX_PAD_KERNEL);
    CheckDimensionRange(dilations, "dilations", NUM1, MAX_DILATION_STRIDE);
    CheckDimensionRange(strides, "strides", NUM1, MAX_DILATION_STRIDE);

    const uint32_t indexH = attrParam.isConv3D ? NCDHW_H_IDX : NCHW_H_IDX;
    const uint32_t indexW = attrParam.isConv3D ? NCDHW_W_IDX : (attrParam.isConv1D ? NCHW_H_IDX : NCHW_W_IDX);
    const int64_t kw = weightTensor.GetShape()[indexW];
    // Conv1D has no kernel height; it is treated as 1.
    const int64_t kh = attrParam.isConv1D ? 1 : weightTensor.GetShape()[indexH];
    ASSERT(ConvOperationError::INPUT_INVALID, kh <= MAX_PAD_KERNEL && kw <= MAX_PAD_KERNEL)
        << "Weight shapes do not satisfy Load3D's" << (attrParam.isConv1D ? " limit: kw=" : " limits: kh=")
        << (attrParam.isConv1D ? kw : kh) << (attrParam.isConv1D ? "" : ", kw=" + std::to_string(kw))
        << ", which must <= " << MAX_PAD_KERNEL << ".";

    const int64_t k0 = ALIGN_SIZE_32 / BytesOf(outType);
    ASSERT(ConvOperationError::INPUT_INVALID, kh * kw * k0 <= SHAPE_INNER_AXIS_MAX_SIZE)
        << "Weight shapes do not satisfy Load3D's limits: kh*kw*k0=" << kh * kw * k0
        << "(k0 = 32 bytes / dtypesize), which must <=" << SHAPE_INNER_AXIS_MAX_SIZE << ".";
}
/**
 * \brief Validate the convolution attributes against the operand shapes.
 *
 * Checks, in order: the lengths of paddings / dilations / strides for the convolution rank; that
 * each padding is strictly smaller than the kernel extent on the same axis; the group constraints
 * (CheckGroupsShape); the Load3D limits (CheckLoad3dShape); and, on the DAV_2201 architecture with
 * groups > 1, that the weight N extent is a multiple of C0 (32 bytes / weight element size).
 *
 * Kernel and padding values are compared as int64_t so that large values are not truncated before
 * the comparison.
 *
 * \param outType      data type forwarded to CheckLoad3dShape
 * \param inputTensor  feature map
 * \param weightTensor weight
 * \param attrParam    convolution attributes
 */
void CheckAttrShape(
    DataType outType, const Tensor& inputTensor, const Tensor& weightTensor, const ConvAttrParam& attrParam)
{
    // Number of spatial axes: 3 for Conv3D, 1 for Conv1D, otherwise 2.
    const uint32_t index = attrParam.isConv3D ? SHAPE_DIM3 : (attrParam.isConv1D ? SHAPE_DIM1 : SHAPE_DIM2);
    CheckDimParam(attrParam.paddings, "paddings", index * NUM2);
    CheckDimParam(attrParam.dilations, "dilations", index);
    CheckDimParam(attrParam.strides, "strides", index);

    const int64_t groups = attrParam.groups;
    const int64_t cinFmap = inputTensor.GetShape()[NCHW_C_IDX];
    const int64_t cinWeight = weightTensor.GetShape()[NCHW_C_IDX];
    const int64_t cOut = weightTensor.GetShape()[NCHW_N_IDX];

    std::vector<int64_t> paddings = attrParam.paddings;
    if (attrParam.isConv3D) {
        // Rotate so that the padding pairs line up with the weight's spatial axes (D, H, W) below.
        paddings = rotateVector(paddings, NUM4);
    }
    const std::vector<std::string> dimNames = attrParam.isConv1D ? std::vector<std::string>{"L"} :
                                              attrParam.isConv3D ? std::vector<std::string>{"D", "H", "W"} :
                                                                   std::vector<std::string>{"H", "W"};
    for (size_t i = 0; i < paddings.size() / NUM2; ++i) {
        // Spatial axes of the weight start after the two leading (N, C) axes.
        const int64_t weightVal = weightTensor.GetShape()[i + NUM2];
        const int64_t paddingLeft = paddings[i * NUM2];
        const int64_t paddingRight = paddings[i * NUM2 + 1];
        ASSERT(ConvOperationError::INPUT_INVALID, paddingLeft < weightVal && paddingRight < weightVal)
            << "The value of the " << dimNames[i]
            << " dimension of weight must be > padding.Current weight value:" << weightVal
            << ",padding value:" << paddingLeft << " and " << paddingRight << ".";
    }

    CheckGroupsShape(cinFmap, cinWeight, cOut, groups);
    CheckLoad3dShape(outType, weightTensor, attrParam);

    if (IsArch32Platform() && groups > 1) {
        const int64_t c0 = ALIGN_SIZE_32 / BytesOf(weightTensor.GetStorage()->Datatype());
        ASSERT(ConvOperationError::INPUT_INVALID, weightTensor.GetShape()[NCHW_N_IDX] % c0 == 0)
            << "Input illegal weight shape N:" << weightTensor.GetShape()[NCHW_N_IDX]
            << ", which must be a multiple of C0:" << c0 << ".";
    }
}
/**
 * \brief Validate the raw operand shapes.
 *
 * Every feature-map and weight extent must be in [NUM1, MAX_SIZE]. When a bias is present, its
 * first extent must equal the weight's N extent (Cout).
 *
 * \param inputTensor  feature map
 * \param weightTensor weight
 * \param biasTensor   bias; may be empty
 */
void CheckOriginShape(const Tensor& inputTensor, const Tensor& weightTensor, const Tensor& biasTensor)
{
    CheckDimensionRange(inputTensor.GetShape(), "fmap", NUM1, MAX_SIZE);
    CheckDimensionRange(weightTensor.GetShape(), "weight", NUM1, MAX_SIZE);

    if (!biasTensor.IsEmpty()) {
        const int64_t cOut = weightTensor.GetShape()[NCHW_N_IDX];
        ASSERT(ConvOperationError::INPUT_INVALID, biasTensor.GetShape()[0] == cOut)
            << "Input illegal bias shape:" << biasTensor.GetShape()[0]
            << ", which must equal to Cout:" << cOut << ".";
    }
}
/**
 * \brief Validate all operands and attributes of a convolution and classify it as 1D / 2D / 3D.
 *
 * Sets attrParam.isConv1D when both tensors have CONV1D_INPUT_DIM dimensions and attrParam.isConv3D
 * when both have CONV3D_INPUT_DIM dimensions; otherwise neither flag is touched.
 *
 * The attribute check runs before the output-shape check: the output-shape computation indexes the
 * strides / paddings / dilations vectors, and their lengths are only validated by CheckAttrShape.
 *
 * \param outType      requested output data type; must be DT_FP32, DT_FP16 or DT_BF16
 * \param inputTensor  feature map
 * \param weightTensor weight
 * \param biasTensor   bias; may be empty
 * \param attrParam    convolution attributes; the 1D / 3D flags are updated as described above
 */
void CheckConvOperands(
    DataType outType, const Tensor& inputTensor, const Tensor& weightTensor, const Tensor& biasTensor,
    ConvAttrParam& attrParam)
{
    ASSERT(
        ConvOperationError::INPUT_INVALID,
        outType == DataType::DT_FP32 || outType == DataType::DT_FP16 || outType == DataType::DT_BF16)
        << "Unsupported output data type. Only DT_FP32, DT_FP16, DT_BF16 are supported.";
    if (inputTensor.Dim() == CONV1D_INPUT_DIM && weightTensor.Dim() == CONV1D_INPUT_DIM) {
        attrParam.isConv1D = true;
    } else if (inputTensor.Dim() == CONV3D_INPUT_DIM && weightTensor.Dim() == CONV3D_INPUT_DIM) {
        attrParam.isConv3D = true;
    }
    CheckOriginShape(inputTensor, weightTensor, biasTensor);
    // Validate attribute vector lengths and value ranges before anything indexes into them.
    CheckAttrShape(outType, inputTensor, weightTensor, attrParam);
    CheckOutputShape(inputTensor, weightTensor, attrParam);
    CheckTileTiling(outType, inputTensor, weightTensor, attrParam);
    CheckL1SizeTiling(outType, inputTensor, weightTensor, biasTensor, attrParam);
}
/**
 * \brief Copy the convolution attributes and the original operand shapes onto \p op.
 * \param op            operation receiving the attributes
 * \param inputTensor   feature map whose shape is stored under CONV_ORI_FMAP_SHAPE_ATTR
 * \param weightTensor  weight whose shape is stored under CONV_ORI_WEIGHT_SHAPE_ATTR
 * \param resTensor     result whose shape is stored under CONV_ORI_RES_SHAPE_ATTR
 * \param convAttrParam source of the bias / groups / paddings / strides / dilations / 3D-flag values
 */
void SetTensorOpAttr(
    Operation& op, const LogicalTensorPtr& inputTensor, const LogicalTensorPtr& weightTensor,
    const LogicalTensorPtr& resTensor, const ConvAttrParam& convAttrParam)
{
    // Convolution hyper-parameters.
    op.SetAttribute(CONV_BIAS_ATTR, convAttrParam.hasBias);
    op.SetAttribute(CONV_GROUPS_ATTR, convAttrParam.groups);
    op.SetAttribute(CONV_PADDINGS_ATTR, convAttrParam.paddings);
    op.SetAttribute(CONV_STRIDES_ATTR, convAttrParam.strides);
    op.SetAttribute(CONV_DILATIONS_ATTR, convAttrParam.dilations);
    op.SetAttribute(CONV_3D_FLAG, convAttrParam.isConv3D);

    // Shapes of the three tensors as passed to this function.
    const auto& fmapShape = inputTensor->GetShape();
    const auto& weightShape = weightTensor->GetShape();
    const auto& resShape = resTensor->GetShape();
    op.SetAttribute(CONV_ORI_FMAP_SHAPE_ATTR, fmapShape);
    op.SetAttribute(CONV_ORI_WEIGHT_SHAPE_ATTR, weightShape);
    op.SetAttribute(CONV_ORI_RES_SHAPE_ATTR, resShape);
}
std::vector<LogicalTensorPtr> GetOperandVecIn(
std::vector<LogicalTensorPtr> operandVecIn, const ConvAttrParam& convAttrParam)
{
int64_t cin0 = ALIGN_SIZE_32 / BytesOf(operandVecIn[INPUT_FMAP_IDX]->Datatype());
int64_t batch = operandVecIn[INPUT_FMAP_IDX]->GetShape()[NCHW_N_IDX];
int64_t hi = convAttrParam.isConv3D ? operandVecIn[INPUT_FMAP_IDX]->GetShape()[NCDHW_H_IDX] :
operandVecIn[INPUT_FMAP_IDX]->GetShape()[NCHW_H_IDX];
int64_t wi = convAttrParam.isConv3D ? operandVecIn[INPUT_FMAP_IDX]->GetShape()[NCDHW_W_IDX] :
operandVecIn[INPUT_FMAP_IDX]->GetShape()[NCHW_W_IDX];
int64_t cout = operandVecIn[INPUT_WEIGHT_IDX]->GetShape()[NCHW_N_IDX];
int64_t cinPerGroup = operandVecIn[INPUT_WEIGHT_IDX]->GetShape()[NCHW_C_IDX];
int64_t kh = convAttrParam.isConv3D ? operandVecIn[INPUT_WEIGHT_IDX]->GetShape()[NCDHW_H_IDX] :
operandVecIn[INPUT_WEIGHT_IDX]->GetShape()[NCHW_H_IDX];
int64_t kw = convAttrParam.isConv3D ? operandVecIn[INPUT_WEIGHT_IDX]->GetShape()[NCDHW_W_IDX] :
operandVecIn[INPUT_WEIGHT_IDX]->GetShape()[NCHW_W_IDX];
int64_t cin1PerGroup = CeilDiv(cinPerGroup, cin0);
int64_t cout1PerGroup = CeilDiv(cout / convAttrParam.groups, MKN_N_VALUE);
std::vector<int64_t> inputNzShape = {batch, convAttrParam.groups * cin1PerGroup, hi, wi, cin0};
std::vector<int64_t> weightFzShape =
{convAttrParam.groups * cin1PerGroup * kh * kw, cout1PerGroup, MKN_N_VALUE, cin0};
TileOpFormat inputNzFormat = TileOpFormat::TILEOP_NC1HWC0;
TileOpFormat weightFzFormat = TileOpFormat::TILEOP_FRACTAL_Z;
if (convAttrParam.isConv3D) {
inputNzFormat = TileOpFormat::TILEOP_NDC1HWC0;
weightFzFormat = TileOpFormat::TILEOP_FRACTAL_Z_3D;
int64_t din = operandVecIn[INPUT_FMAP_IDX]->GetShape()[NCDHW_D_IDX];
int64_t kd = operandVecIn[INPUT_WEIGHT_IDX]->GetShape()[NCDHW_D_IDX];
inputNzShape = {batch, din, convAttrParam.groups * cin1PerGroup, hi, wi, cin0};
weightFzShape = {convAttrParam.groups * kd * cin1PerGroup * kh * kw, cout1PerGroup, MKN_N_VALUE, cin0};
}
Tensor inputNzTensor(operandVecIn[INPUT_FMAP_IDX]->Datatype(), inputNzShape, "TensorInputNz", inputNzFormat);
Tensor weightFzTensor(operandVecIn[INPUT_FMAP_IDX]->Datatype(), weightFzShape, "TensorWeightFz", weightFzFormat);
return {inputNzTensor.GetStorage(), weightFzTensor.GetStorage()};
}
/**
 * \brief Build the conv tensor graph with OP_FAKE_TRANS operations around the convolution.
 *
 * Adds, in order: an OP_FAKE_TRANS from the feature map and one from the weight to the tensors
 * created by GetOperandVecIn; the OP_CONV2D / OP_CONV3D operation writing to \p resTensor; an
 * OP_FAKE_TRANS from \p resTensor to a new "TensorOut" tensor; and, for Conv1D, an OP_RESHAPE from
 * "TensorOut" to a 3-dimensional "TensorOut3Dim" tensor.
 *
 * \param functionPtr   function the operations are added to
 * \param operandVecIn  feature map, weight and (when convAttrParam.hasBias) bias storages
 * \param resTensor     tensor the convolution writes to
 * \param convAttrParam convolution attributes
 * \return "TensorOut3Dim" for Conv1D, otherwise "TensorOut"
 */
Tensor ConstructTensorGraphNZ2NZ(
Function* functionPtr, std::vector<LogicalTensorPtr> operandVecIn, const Tensor& resTensor,
const ConvAttrParam& convAttrParam)
{
std::vector<LogicalTensorPtr> operandVecOut = {resTensor.GetStorage()};
std::vector<LogicalTensorPtr> operandVecInNZ = GetOperandVecIn(operandVecIn, convAttrParam);
// Feature map: original tensor -> tensor built by GetOperandVecIn, recording both formats on the op.
auto& inputTransOp = functionPtr->AddOperation(Opcode::OP_FAKE_TRANS, {operandVecIn[INPUT_FMAP_IDX]},
{operandVecInNZ[INPUT_FMAP_IDX]});
inputTransOp.SetAttribute(FAKE_TRANS_IN_FORMAT_ATTR, static_cast<int64_t>(operandVecIn[INPUT_FMAP_IDX]->Format()));
inputTransOp.SetAttribute(
FAKE_TRANS_OUT_FORMAT_ATTR, static_cast<int64_t>(operandVecInNZ[INPUT_FMAP_IDX]->Format()));
// Weight: same treatment as the feature map.
auto& weightTransOp = functionPtr->AddOperation(Opcode::OP_FAKE_TRANS, {operandVecIn[INPUT_WEIGHT_IDX]},
{operandVecInNZ[INPUT_WEIGHT_IDX]});
weightTransOp.SetAttribute(
FAKE_TRANS_IN_FORMAT_ATTR, static_cast<int64_t>(operandVecIn[INPUT_WEIGHT_IDX]->Format()));
weightTransOp.SetAttribute(
FAKE_TRANS_OUT_FORMAT_ATTR, static_cast<int64_t>(operandVecInNZ[INPUT_WEIGHT_IDX]->Format()));
// The bias is passed to the conv op unchanged, as the third input.
if (convAttrParam.hasBias) {
operandVecInNZ.push_back(operandVecIn[INPUT_BIAS_IDX]);
}
Opcode convOpCode = convAttrParam.isConv3D ? Opcode::OP_CONV3D : Opcode::OP_CONV2D;
auto& op = functionPtr->AddOperation(convOpCode, operandVecInNZ, operandVecOut);
// Shape of the final output: N / H / W (and D for 3D) are read from resTensor at the NC1HWC0 /
// NDC1HWC0 indices, the channel extent is the weight's N extent.
std::vector<int64_t> orgOutShape = {resTensor.GetShape()[NC1HWC0_N_IDX],
operandVecIn[INPUT_WEIGHT_IDX]->GetShape()[NCHW_N_IDX], resTensor.GetShape()[NC1HWC0_H_IDX],
resTensor.GetShape()[NC1HWC0_W_IDX]};
if (convAttrParam.isConv3D) {
orgOutShape = {resTensor.GetShape()[NDC1HWC0_N_IDX], operandVecIn[INPUT_WEIGHT_IDX]->GetShape()[NCHW_N_IDX],
resTensor.GetShape()[NDC1HWC0_D_IDX], resTensor.GetShape()[NDC1HWC0_H_IDX],
resTensor.GetShape()[NDC1HWC0_W_IDX]};
}
// NOTE(review): finalResTensor is constructed without a format argument — presumably it takes the
// Tensor default format; confirm.
Tensor finalResTensor(resTensor.GetStorage()->Datatype(), orgOutShape, "TensorOut");
auto& orgResOp = functionPtr->AddOperation(Opcode::OP_FAKE_TRANS, operandVecOut, {finalResTensor.GetStorage()});
orgResOp.SetAttribute(FAKE_TRANS_IN_FORMAT_ATTR, static_cast<int64_t>(resTensor.Format()));
orgResOp.SetAttribute(FAKE_TRANS_OUT_FORMAT_ATTR, static_cast<int64_t>(finalResTensor.Format()));
// The conv op records the shapes of the original feature map / weight and of the final output.
SetTensorOpAttr(op, operandVecIn[INPUT_FMAP_IDX], operandVecIn[INPUT_WEIGHT_IDX], finalResTensor.GetStorage(), convAttrParam);
if (convAttrParam.isConv1D) {
// Conv1D: drop the H axis again by reshaping the 4-dimensional output to {N, Cout, W}.
orgOutShape = {resTensor.GetShape()[NC1HWC0_N_IDX], operandVecIn[INPUT_WEIGHT_IDX]->GetShape()[NCHW_N_IDX],
resTensor.GetShape()[NC1HWC0_W_IDX]};
Tensor finalRes3DimTensor(resTensor.GetStorage()->Datatype(), orgOutShape, "TensorOut3Dim");
auto& reshapeResOp = functionPtr->AddOperation(Opcode::OP_RESHAPE, {finalResTensor.GetStorage()},
{finalRes3DimTensor.GetStorage()});
reshapeResOp.SetAttribute("isConv", true);
return finalRes3DimTensor;
}
return finalResTensor;
}
/**
 * \brief Add the convolution and its helper reshape operations to the current function.
 *
 * For Conv1D the feature map and weight are reshaped to 4 dimensions (a height axis of 1 is
 * inserted) and the conv writes to a 4-dimensional result. A non-empty bias is reshaped to
 * {1, bias length} and appended as the third input; convAttrParam.hasBias is set in that case.
 * On the DAV_2201 architecture graph construction is delegated to ConstructTensorGraphNZ2NZ.
 *
 * \param inputTensor   feature map
 * \param weightTensor  weight
 * \param biasTensor    bias; may be empty
 * \param resTensor     result tensor
 * \param convAttrParam convolution attributes; hasBias is updated as described above
 * \return \p resTensor, or the tensor returned by ConstructTensorGraphNZ2NZ on DAV_2201
 */
Tensor ConstructTensorGraph(
const Tensor& inputTensor, const Tensor& weightTensor, const Tensor& biasTensor, const Tensor& resTensor,
ConvAttrParam& convAttrParam)
{
Function* functionPtr = Program::GetInstance().GetCurrentFunction();
ASSERT(ConvExpandFuncError::EXPANDFUNC_TILE_OP_NULLPTR, functionPtr != nullptr) << "functionPtr is nullptr.";
std::vector<LogicalTensorPtr> operandVecIn = {inputTensor.GetStorage(), weightTensor.GetStorage()};
std::vector<LogicalTensorPtr> operandVecOut = {resTensor.GetStorage()};
if (convAttrParam.isConv1D) {
// Conv1D operands are 3-dimensional; insert a height axis of 1 so the 2D conv op can be used.
std::vector<int64_t> fmap4DimShape{inputTensor.GetShape()[NCHW_N_IDX], inputTensor.GetShape()[NCHW_C_IDX], 1,
inputTensor.GetShape()[NCHW_H_IDX]};
Tensor fmap4DimTensor(inputTensor.GetStorage()->Datatype(), fmap4DimShape, "", inputTensor.Format());
std::vector<int64_t> weight4DimShape{weightTensor.GetShape()[NCHW_N_IDX], weightTensor.GetShape()[NCHW_C_IDX],
1, weightTensor.GetShape()[NCHW_H_IDX]};
Tensor weigth4DimTensor(weightTensor.GetStorage()->Datatype(), weight4DimShape, "", weightTensor.Format());
auto& reshapeFmapOp =
functionPtr->AddOperation(Opcode::OP_RESHAPE, {inputTensor.GetStorage()}, {fmap4DimTensor.GetStorage()});
auto& reshapeWeightOp =
functionPtr->AddOperation(Opcode::OP_RESHAPE, {weightTensor.GetStorage()}, {weigth4DimTensor.GetStorage()});
reshapeFmapOp.SetAttribute("isConv", true);
reshapeWeightOp.SetAttribute("isConv", true);
operandVecIn = {fmap4DimTensor.GetStorage(), weigth4DimTensor.GetStorage()};
// The conv writes to a 4-dimensional result of shape {N, Cout, 1, result length}.
std::vector<int64_t> res4DimShape{inputTensor.GetShape()[NCHW_N_IDX], weightTensor.GetShape()[NCHW_N_IDX], 1,
resTensor.GetShape()[NCHW_H_IDX]};
Tensor res4DimTensor(resTensor.GetStorage()->Datatype(), res4DimShape, "", resTensor.Format());
operandVecOut = {res4DimTensor.GetStorage()};
}
if (!biasTensor.IsEmpty()) {
convAttrParam.hasBias = true;
// The bias is fed to the conv op as a {1, length} tensor.
std::vector<int64_t> bias2DimShape{1, biasTensor.GetShape()[0]};
Tensor bias2DimTensor(biasTensor.GetStorage()->Datatype(), bias2DimShape, "", biasTensor.Format());
auto& reshapeBiasOp =
functionPtr->AddOperation(Opcode::OP_RESHAPE, {biasTensor.GetStorage()}, {bias2DimTensor.GetStorage()});
reshapeBiasOp.SetAttribute("isConv", true);
operandVecIn.push_back(bias2DimTensor.GetStorage());
}
if (IsArch32Platform()) {
// NOTE(review): this path receives resTensor rather than the Conv1D 4-dimensional result held in
// operandVecOut — confirm this is intended.
return ConstructTensorGraphNZ2NZ(functionPtr, operandVecIn, resTensor, convAttrParam);
}
Opcode convOpCode = convAttrParam.isConv3D ? Opcode::OP_CONV3D : Opcode::OP_CONV2D;
auto& op = functionPtr->AddOperation(convOpCode, operandVecIn, operandVecOut);
SetTensorOpAttr(op, operandVecIn[INPUT_FMAP_IDX], operandVecIn[INPUT_WEIGHT_IDX], operandVecOut[0], convAttrParam);
if (convAttrParam.isConv1D) {
// Reshape the 4-dimensional conv result back into the caller's result tensor.
auto& reshapeResOp = functionPtr->AddOperation(Opcode::OP_RESHAPE, operandVecOut, {resTensor.GetStorage()});
reshapeResOp.SetAttribute("isConv", true);
}
return resTensor;
}
/**
 * \brief Fill \p convAttrParam from the attributes stored on \p op.
 *
 * Optional attributes fall back to defaults when absent: isConv3D and hasBias to false, groups to 1,
 * paddings / strides / dilations to the 2D or 3D default lists depending on isConv3D.
 * isInOutTensorNZ is always set to false. The three original-shape attributes (feature map, weight,
 * result) are mandatory; each one is asserted to be present before it is read.
 *
 * \param op            operation carrying the convolution attributes
 * \param convAttrParam structure to fill
 */
void SetConvAttrParam(const Operation& op, ConvAttrParam& convAttrParam)
{
    // isConv3D is resolved first because it selects the defaults of the vector attributes below.
    convAttrParam.isConv3D = (op.HasAttr(CONV_3D_FLAG)) ? op.GetBoolAttribute(CONV_3D_FLAG) : false;
    convAttrParam.paddings = (op.HasAttr(CONV_PADDINGS_ATTR)) ? op.GetVectorIntAttribute(CONV_PADDINGS_ATTR) :
                             convAttrParam.isConv3D           ? CONV3D_ATTR_DEFAULT_LIST :
                                                                CONV2D_PAD_ATTR_DEFAULT_LIST;
    convAttrParam.strides = (op.HasAttr(CONV_STRIDES_ATTR)) ? op.GetVectorIntAttribute(CONV_STRIDES_ATTR) :
                            convAttrParam.isConv3D          ? CONV3D_ATTR_DEFAULT_LIST :
                                                              CONV2D_ATTR_DEFAULT_LIST;
    convAttrParam.dilations = (op.HasAttr(CONV_DILATIONS_ATTR)) ? op.GetVectorIntAttribute(CONV_DILATIONS_ATTR) :
                              convAttrParam.isConv3D            ? CONV3D_ATTR_DEFAULT_LIST :
                                                                  CONV2D_ATTR_DEFAULT_LIST;
    convAttrParam.groups = (op.HasAttr(CONV_GROUPS_ATTR)) ? op.GetIntAttribute(CONV_GROUPS_ATTR) : 1;
    convAttrParam.hasBias = (op.HasAttr(CONV_BIAS_ATTR)) ? op.GetBoolAttribute(CONV_BIAS_ATTR) : false;
    convAttrParam.isInOutTensorNZ = false;

    ASSERT(ConvExpandFuncError::EXPANDFUNC_TENSOR_ATTR_GET_FAILED, op.HasAttr(CONV_ORI_FMAP_SHAPE_ATTR))
        << "Conv ori fmapshape should be set when InOut Tensor NZ mode.";
    ASSERT(ConvExpandFuncError::EXPANDFUNC_TENSOR_ATTR_GET_FAILED, op.HasAttr(CONV_ORI_WEIGHT_SHAPE_ATTR))
        << "Conv ori weightshape should be set when InOut Tensor NZ mode.";
    // The result shape is read unconditionally below, so it gets the same presence check.
    ASSERT(ConvExpandFuncError::EXPANDFUNC_TENSOR_ATTR_GET_FAILED, op.HasAttr(CONV_ORI_RES_SHAPE_ATTR))
        << "Conv ori resshape should be set when InOut Tensor NZ mode.";
    convAttrParam.oriFmapShape = op.GetVectorIntAttribute(CONV_ORI_FMAP_SHAPE_ATTR);
    convAttrParam.oriWeightShape = op.GetVectorIntAttribute(CONV_ORI_WEIGHT_SHAPE_ATTR);
    convAttrParam.oriResShape = op.GetVectorIntAttribute(CONV_ORI_RES_SHAPE_ATTR);
}
/**
 * \brief Store the conv operands and the result tensor in \p tensorGraphNodes.
 *
 * \p operandVec must hold exactly SHAPE_DIM2 entries, plus one when convAttrParam.hasBias is set.
 * The feature map, weight and result pointers must be non-null.
 *
 * \param operandVec       feature map, weight and optional bias, in that order
 * \param cTensorPtr       result tensor
 * \param convAttrParam    convolution attributes (only hasBias is read)
 * \param tensorGraphNodes structure to fill
 */
void SetTensorGraphNodes(
    const std::vector<LogicalTensorPtr>& operandVec, const LogicalTensorPtr& cTensorPtr,
    const ConvAttrParam& convAttrParam, ConvGraphNodes& tensorGraphNodes)
{
    const size_t operandVecSize = SHAPE_DIM2 + static_cast<size_t>(convAttrParam.hasBias);
    ASSERT(ConvExpandFuncError::EXPANDFUNC_PARAMS_INVALID, operandVec.size() == operandVecSize)
        << "Operand vector size mismatch: "
        << "Expected size: " << operandVecSize
        << ", actual size: " << operandVec.size()
        << ", Conv Common Input: " << SHAPE_DIM2
        << ", hasBias: " << convAttrParam.hasBias;

    tensorGraphNodes.fmapTensorPtr = operandVec[INPUT_FMAP_IDX];
    tensorGraphNodes.weightTensorPtr = operandVec[INPUT_WEIGHT_IDX];
    if (convAttrParam.hasBias) {
        tensorGraphNodes.biasTensorPtr = operandVec[INPUT_BIAS_IDX];
    }

    ASSERT(
        ConvExpandFuncError::EXPANDFUNC_TILE_OP_NULLPTR,
        tensorGraphNodes.fmapTensorPtr != nullptr && tensorGraphNodes.weightTensorPtr != nullptr)
        << "Expected aTensorPtr and bTensorPtr to be non-nullptr.";
    ASSERT(ConvExpandFuncError::EXPANDFUNC_TILE_OP_NULLPTR, cTensorPtr != nullptr) << "cTensorPtr is nullptr.";
    tensorGraphNodes.resTensorPtr = cTensorPtr;
}
void SetConvShapeInfo(
const TileShape& tileShape, const ConvGraphNodes& tensorGraphNodes, const ConvAttrParam& convAttrParam,
ConvTileInfo& convTileInfo)
{
convTileInfo.orgBatch =
convAttrParam.isConv3D ? convAttrParam.oriFmapShape[NCDHW_N_IDX] : convAttrParam.oriFmapShape[NCHW_N_IDX];
convTileInfo.orgHin =
convAttrParam.isConv3D ? convAttrParam.oriFmapShape[NCDHW_H_IDX] : convAttrParam.oriFmapShape[NCHW_H_IDX];
convTileInfo.orgWin =
convAttrParam.isConv3D ? convAttrParam.oriFmapShape[NCDHW_W_IDX] : convAttrParam.oriFmapShape[NCHW_W_IDX];
convTileInfo.orgCin =
convAttrParam.isConv3D ? convAttrParam.oriFmapShape[NCDHW_C_IDX] : convAttrParam.oriFmapShape[NCHW_C_IDX];
convTileInfo.orgHout =
convAttrParam.isConv3D ? convAttrParam.oriResShape[NCDHW_H_IDX] : convAttrParam.oriResShape[NCHW_H_IDX];
convTileInfo.orgWout =
convAttrParam.isConv3D ? convAttrParam.oriResShape[NCDHW_W_IDX] : convAttrParam.oriResShape[NCHW_W_IDX];
convTileInfo.orgDin = convAttrParam.isConv3D ? convAttrParam.oriFmapShape[NCDHW_D_IDX] : 1;
convTileInfo.orgDout = convAttrParam.isConv3D ? convAttrParam.oriResShape[NCDHW_D_IDX] : 1;
convTileInfo.cin0 = ALIGN_SIZE_32 / BytesOf(tensorGraphNodes.fmapTensorPtr->Datatype());
convTileInfo.orgCout =
convAttrParam.isConv3D ? convAttrParam.oriWeightShape[NCDHW_N_IDX] : convAttrParam.oriWeightShape[NCHW_N_IDX];
convTileInfo.orgKh =
convAttrParam.isConv3D ? convAttrParam.oriWeightShape[NCDHW_H_IDX] : convAttrParam.oriWeightShape[NCHW_H_IDX];
convTileInfo.orgKw =
convAttrParam.isConv3D ? convAttrParam.oriWeightShape[NCDHW_W_IDX] : convAttrParam.oriWeightShape[NCHW_W_IDX];
convTileInfo.orgKd = convAttrParam.isConv3D ? convAttrParam.oriWeightShape[NCDHW_D_IDX] : 1;
int64_t cinPerGroup = convTileInfo.orgCin / convAttrParam.groups;
convTileInfo.orgHoutWout = convTileInfo.orgHout * convTileInfo.orgWout;
convTileInfo.kPerGroup = ConvAlignB(cinPerGroup, convTileInfo.cin0) * convTileInfo.orgKh * convTileInfo.orgKw;
convTileInfo.coutPerGroup = convTileInfo.orgCout / convAttrParam.groups;
auto& convTile = tileShape.GetConvTile();
convTileInfo.kAL1 = convTile.tileL1Info.tileCinFmap * convTileInfo.orgKh * convTileInfo.orgKw;
convTileInfo.kBL1 = convTile.tileL1Info.tileCinWeight * convTileInfo.orgKh * convTileInfo.orgKw;
convTileInfo.nBL1 = convTile.tileL1Info.tileN;
convTileInfo.hAL1In = convTile.tileL1Info.tileHin;
convTileInfo.wAL1In = convTile.tileL1Info.tileWin;
convTileInfo.hAL1Out = convTile.tileL1Info.tileHout;
convTileInfo.wAL1Out = convTile.tileL1Info.tileWout;
convTileInfo.kL0 = convTile.tileL0Info.tileK;
convTileInfo.hL0 = convTile.tileL0Info.tileH;
convTileInfo.wL0 = convTile.tileL0Info.tileW;
convTileInfo.nL0 = convTile.tileL0Info.tileN;
}
/**
 * \brief Create the bias tile for the current N-L0 slice through two chained OP_VIEW operations.
 *
 * The first view goes from the bias tensor to a {1, nL0Size} tensor tagged MEM_L1; the second goes
 * from that tensor to a {1, nL0Size} DT_FP32 tensor tagged MEM_BT.
 *
 * \param function         Function that receives the new tensors and operations.
 * \param tensorGraphNodes Provides the bias tensor (dereferenced without a null check).
 * \param iterInfo         Current iteration offsets and sizes (read only here).
 * \param convTileInfo     Provides coutPerGroup (read only here).
 * \return The MEM_BT bias tile tensor.
 */
LogicalTensorPtr ConstructBiasTile(
    Function& function, const ConvGraphNodes& tensorGraphNodes, ConvIterInfo& iterInfo, ConvTileInfo& convTileInfo)
{
    const auto& biasSrc = tensorGraphNodes.biasTensorPtr;

    // Stage 1: view on the bias tensor, tagged MEM_L1.
    std::vector<int64_t> l1Shape{1, iterInfo.nL0Size};
    std::vector<int64_t> l1Offset{
        0, iterInfo.groupOffset * convTileInfo.coutPerGroup + iterInfo.nL1Offset + iterInfo.nL0Offset};
    LogicalTensorPtr biasL1Tile = std::make_shared<LogicalTensor>(
        function, biasSrc->Datatype(), l1Shape, SymbolicScalar::FromConcrete(l1Shape), biasSrc->Format(),
        "biasL1Tensor");
    biasL1Tile->UpdateDynValidShape(SymbolicScalar::FromConcrete(l1Shape));
    auto& l1ViewOp = function.AddOperation(Opcode::OP_VIEW, {biasSrc}, {biasL1Tile});
    auto l1ViewAttr = std::make_shared<ViewOpAttribute>(
        l1Offset, MemoryType::MEM_L1, SymbolicScalar::FromConcrete(l1Offset), biasL1Tile->GetDynValidShape());
    l1ViewOp.SetOpAttribute(l1ViewAttr);
    l1ViewOp.SetAttribute(Matrix::A_MUL_B_COPY_IN_MODE, static_cast<int64_t>(Matrix::CopyInMode::ND2ND));

    // Stage 2: view on the stage-1 tensor, tagged MEM_BT and always DT_FP32.
    std::vector<int64_t> btShape{1, iterInfo.nL0Size};
    std::vector<int64_t> btOffset{0, iterInfo.nL0Offset};
    LogicalTensorPtr biasBtTile = std::make_shared<LogicalTensor>(
        function, DataType::DT_FP32, btShape, SymbolicScalar::FromConcrete(btShape), biasSrc->Format(),
        "biasBtTensor");
    biasBtTile->UpdateDynValidShape(SymbolicScalar::FromConcrete(btShape));
    auto& btViewOp = function.AddOperation(Opcode::OP_VIEW, {biasL1Tile}, {biasBtTile});
    auto btViewAttr = std::make_shared<ViewOpAttribute>(
        btOffset, MemoryType::MEM_BT, SymbolicScalar::FromConcrete(btOffset), biasBtTile->GetDynValidShape());
    btViewOp.SetOpAttribute(btViewAttr);
    return biasBtTile;
}
/**
 * \brief Attach the stride, dilation, filter, padding and start-point attributes to a load3d operation.
 *
 * Each padding is the part of the window needed by the current L1 tile that lies outside the
 * original input: a negative start offset gives top/left padding, and an end position beyond
 * orgHin/orgWin gives bottom/right padding. Otherwise the padding is 0.
 *
 * \param load3dOpAl0   Operation that receives the attributes.
 * \param convAttrParam Provides strides, dilations and isConv3D.
 * \param iterInfo      Current iteration state (read only here).
 * \param convTileInfo  Provides kernel sizes, original input sizes and kAL1.
 */
void SetImg2ColAttr(
    Operation& load3dOpAl0, const ConvAttrParam& convAttrParam, ConvIterInfo& iterInfo,
    const ConvTileInfo& convTileInfo)
{
    int64_t strideH = convAttrParam.strides[0];
    int64_t strideW = convAttrParam.strides[1];
    int64_t dilationH = convAttrParam.dilations[0];
    int64_t dilationW = convAttrParam.dilations[1];
    // Kernel extent once dilation is applied.
    int64_t kernelSpanH = (convTileInfo.orgKh - 1) * dilationH + 1;
    int64_t kernelSpanW = (convTileInfo.orgKw - 1) * dilationW + 1;

    load3dOpAl0.SetAttribute(OpAttributeKey::strideH, strideH);
    load3dOpAl0.SetAttribute(OpAttributeKey::strideW, strideW);
    load3dOpAl0.SetAttribute(OpAttributeKey::dilationH, dilationH);
    load3dOpAl0.SetAttribute(OpAttributeKey::dilationW, dilationW);
    load3dOpAl0.SetAttribute(OpAttributeKey::filterH, convTileInfo.orgKh);
    load3dOpAl0.SetAttribute(OpAttributeKey::filterW, convTileInfo.orgKw);

    // Top padding: rows requested above the first input row.
    if (iterInfo.hL1InOffset < 0) {
        load3dOpAl0.SetAttribute(OpAttributeKey::paddingTop, 0 - iterInfo.hL1InOffset);
    } else {
        load3dOpAl0.SetAttribute(OpAttributeKey::paddingTop, 0);
    }

    // Bottom padding: rows requested below the last input row.
    int64_t rowsNeeded = (iterInfo.houtL1Size - 1) * strideH + kernelSpanH;
    int64_t rowEnd = iterInfo.hL1InOffset + rowsNeeded;
    if (rowEnd <= convTileInfo.orgHin) {
        load3dOpAl0.SetAttribute(OpAttributeKey::paddingBottom, 0);
    } else {
        load3dOpAl0.SetAttribute(OpAttributeKey::paddingBottom, rowEnd - convTileInfo.orgHin);
    }

    // Left padding: columns requested before the first input column.
    if (iterInfo.wL1InOffset < 0) {
        load3dOpAl0.SetAttribute(OpAttributeKey::paddingLeft, 0 - iterInfo.wL1InOffset);
    } else {
        load3dOpAl0.SetAttribute(OpAttributeKey::paddingLeft, 0);
    }

    // Right padding: columns requested after the last input column.
    int64_t colsNeeded = (iterInfo.woutL1Size - 1) * strideW + kernelSpanW;
    int64_t colEnd = iterInfo.wL1InOffset + colsNeeded;
    if (colEnd <= convTileInfo.orgWin) {
        load3dOpAl0.SetAttribute(OpAttributeKey::paddingRight, 0);
    } else {
        load3dOpAl0.SetAttribute(OpAttributeKey::paddingRight, colEnd - convTileInfo.orgWin);
    }

    // Start point of this L0 tile inside the L1 tile, along M and K.
    int64_t mStart = iterInfo.hL0Offset * iterInfo.woutL1Size + iterInfo.wL0Offset;
    int64_t kStart = iterInfo.kL0Offset % convTileInfo.kAL1;
    load3dOpAl0.SetAttribute(OpAttributeKey::postM, mStart);
    load3dOpAl0.SetAttribute(OpAttributeKey::postK, kStart);
    load3dOpAl0.SetAttribute(OpAttributeKey::padValue, 0);
    load3dOpAl0.SetAttribute(OpAttributeKey::repeatStride, iterInfo.repeatStride);
    load3dOpAl0.SetAttribute(OpAttributeKey::repeatTime, iterInfo.repeatTime);
    load3dOpAl0.SetAttribute(OpAttributeKey::wStride, iterInfo.wStride);
    load3dOpAl0.SetAttribute("isConv", true);
    load3dOpAl0.SetAttribute(Conv::LoadStoreConvOpAttributeKey::isConv3D, convAttrParam.isConv3D);
}
/**
 * \brief Configure the copy-in operation that loads a feature-map tile into L1.
 *
 * Sets the conv marker attributes, chooses the copy-in mode by platform (COPY_MOD_NZ2NZ when
 * IsArch32Platform() is true, COPY_MOD_DN2NZ otherwise), builds the source offset and shape, and
 * installs a CopyOpAttribute. Clears iterInfo.aL1UpadateFlag at the end.
 *
 * \param copyInOpAl1      Operation to configure.
 * \param tensorGraphNodes Provides the fmap tensor whose dynamic raw shape is recorded.
 * \param convTileInfo     Original dimensions, cin0 and kPerGroup.
 * \param iterInfo         Current iteration state; aL1UpadateFlag is written.
 * \param convAttrParam    Provides groups, dilations and isConv3D.
 * \param dstAL1Shape      Shape of the destination L1 tile.
 * \param srcGmValidShape  Source region shape; element [1] is used as the channel count.
 * \param srcCinOffset     Channel offset inside the current group.
 */
void SetCopyInAL1Op(
    Operation& copyInOpAl1, const ConvGraphNodes& tensorGraphNodes, const ConvTileInfo& convTileInfo,
    ConvIterInfo& iterInfo, const ConvAttrParam& convAttrParam, const std::vector<int64_t>& dstAL1Shape,
    const std::vector<int64_t>& srcGmValidShape, const int64_t& srcCinOffset)
{
    const bool is3D = convAttrParam.isConv3D;
    copyInOpAl1.SetAttribute("isConv", true);
    copyInOpAl1.SetAttribute(LoadStoreConvOpAttributeKey::isFmap, true);
    copyInOpAl1.SetAttribute(LoadStoreConvOpAttributeKey::isConv3D, convAttrParam.isConv3D);
    copyInOpAl1.SetAttribute("src_d_stride", convAttrParam.isConv3D ? convAttrParam.dilations[NUM2] : 1);

    // Start coordinates of the source region; negative H/W starts are clamped to 0.
    const auto cinPerGroup = convTileInfo.orgCin / convAttrParam.groups;
    const int64_t batchStart = iterInfo.batchOffset;
    const int64_t channelStart = iterInfo.groupOffset * cinPerGroup + srcCinOffset;
    int64_t depthStart = 0;
    if (is3D) {
        depthStart =
            iterInfo.dinL1Offset + (iterInfo.kL0Offset / convTileInfo.kPerGroup) * convAttrParam.dilations[NUM2];
    }
    const int64_t rowStart = iterInfo.hL1InOffset > 0 ? iterInfo.hL1InOffset : 0;
    const int64_t colStart = iterInfo.wL1InOffset > 0 ? iterInfo.wL1InOffset : 0;

    std::vector<int64_t> srcGmOffset;
    std::vector<int64_t> srcGmShape;
    if (IsArch32Platform()) {
        copyInOpAl1.SetAttribute(
            LoadStoreConvOpAttributeKey::copyInMode, static_cast<int64_t>(CopyInMode::COPY_MOD_NZ2NZ));
        // Channels are addressed in blocks of cin0, with a trailing cin0 axis in offset and shape.
        int64_t cin1PerGroup = CeilDiv(cinPerGroup, convTileInfo.cin0);
        int64_t cin1Start = iterInfo.groupOffset * cin1PerGroup + srcCinOffset / convTileInfo.cin0;
        const auto cin1Count = CeilDiv(srcGmValidShape[1], convTileInfo.cin0);
        if (is3D) {
            srcGmOffset = {batchStart, depthStart, cin1Start, rowStart, colStart, 0};
            srcGmShape = {1, iterInfo.dkAL1Size, cin1Count, iterInfo.hinL1Size,
                          iterInfo.winL1Size, convTileInfo.cin0};
        } else {
            srcGmOffset = {batchStart, cin1Start, rowStart, colStart, 0};
            srcGmShape = {1, cin1Count, iterInfo.hinL1Size, iterInfo.winL1Size, convTileInfo.cin0};
        }
    } else {
        copyInOpAl1.SetAttribute(
            LoadStoreConvOpAttributeKey::copyInMode, static_cast<int64_t>(CopyInMode::COPY_MOD_DN2NZ));
        if (is3D) {
            srcGmOffset = {batchStart, channelStart, depthStart, rowStart, colStart};
        } else {
            srcGmOffset = {batchStart, channelStart, rowStart, colStart};
        }
        srcGmShape = srcGmValidShape;
    }

    auto copyAttr = std::make_shared<CopyOpAttribute>(
        OpImmediate::Specified(srcGmOffset), MemoryType::MEM_L1, OpImmediate::Specified(srcGmShape),
        OpImmediate::Specified(tensorGraphNodes.fmapTensorPtr->tensor->GetDynRawShape()),
        OpImmediate::Specified(dstAL1Shape));
    copyInOpAl1.SetOpAttribute(copyAttr);
    copyInOpAl1.SetAttribute("l1_tile_shape", SymbolicScalar::FromConcrete(dstAL1Shape));
    copyInOpAl1.SetAttribute(OpAttributeKey::srcGmConvValidShape, SymbolicScalar::FromConcrete(srcGmShape));
    // The L1 tile is now scheduled; clear the reload request.
    iterInfo.aL1UpadateFlag = false;
}
/**
 * \brief Create the feature-map L1 tile tensor and the OP_L1_COPY_IN_CONV operation that fills it.
 *
 * Updates iterInfo.kAL1Size (and iterInfo.dkAL1Size for 3D conv), derives the L1 tile shape and the
 * source region shape, then delegates attribute setup to SetCopyInAL1Op.
 *
 * \param function         Function that receives the new tensor and operation.
 * \param tensorGraphNodes Provides the fmap tensor.
 * \param convTileInfo     Tile sizes and original dimensions.
 * \param iterInfo         Current iteration state; kAL1Size / dkAL1Size are written.
 * \param dstAL1TensorPtr  Output: the newly created L1 tile tensor.
 * \param convAttrParam    Provides groups and isConv3D.
 */
static void ConstructFmapL1Tile(
    Function& function, const ConvGraphNodes& tensorGraphNodes, const ConvTileInfo& convTileInfo,
    ConvIterInfo& iterInfo, LogicalTensorPtr& dstAL1TensorPtr, const ConvAttrParam& convAttrParam)
{
    const auto kernelHW = convTileInfo.orgKh * convTileInfo.orgKw;
    // K extent of this L1 tile: what is left of the reduction, capped by the configured kAL1.
    iterInfo.kAL1Size = std::min((convTileInfo.kPerGroup * iterInfo.dkL1Size - iterInfo.kL0Offset), convTileInfo.kAL1);
    int64_t cin1AL1Size = (iterInfo.kAL1Size / convTileInfo.cin0) / kernelHW;
    int64_t srcCinOffset = (iterInfo.kL0Offset % convTileInfo.kPerGroup) / kernelHW;
    // Computed before any 3D adjustment of srcCinOffset below.
    int64_t srcGmCin = std::min(
        convTileInfo.orgCin / convAttrParam.groups - srcCinOffset, convTileInfo.kAL1 / kernelHW);

    std::vector<int64_t> dstAL1Shape;
    std::vector<int64_t> srcGmValidShape;
    if (convAttrParam.isConv3D) {
        iterInfo.dkAL1Size = 1;
        if (iterInfo.kAL1Size > convTileInfo.kPerGroup) {
            // The tile covers more than one kPerGroup: restart channels at 0 and split K across
            // dkAL1Size depth entries.
            srcCinOffset = 0;
            iterInfo.dkAL1Size = iterInfo.kAL1Size / convTileInfo.kPerGroup;
            cin1AL1Size = (iterInfo.kAL1Size / (iterInfo.dkAL1Size * convTileInfo.cin0)) / kernelHW;
        }
        dstAL1Shape = {
            1, iterInfo.dkAL1Size, cin1AL1Size, iterInfo.hinL1Size, iterInfo.winL1Size, convTileInfo.cin0};
        srcGmValidShape = {1, srcGmCin, iterInfo.dkAL1Size, iterInfo.hinL1Size, iterInfo.winL1Size};
    } else {
        dstAL1Shape = {1, cin1AL1Size, iterInfo.hinL1Size, iterInfo.winL1Size, convTileInfo.cin0};
        srcGmValidShape = {1, srcGmCin, iterInfo.hinL1Size, iterInfo.winL1Size};
    }

    const auto& fmapSrc = tensorGraphNodes.fmapTensorPtr;
    dstAL1TensorPtr = std::make_shared<LogicalTensor>(
        function, fmapSrc->Datatype(), dstAL1Shape, SymbolicScalar::FromConcrete(dstAL1Shape), fmapSrc->Format(),
        "aL1Tensor");
    dstAL1TensorPtr->UpdateDynValidShape(SymbolicScalar::FromConcrete(dstAL1Shape));
    auto& l1CopyInOp = function.AddOperation(Opcode::OP_L1_COPY_IN_CONV, {fmapSrc}, {dstAL1TensorPtr});
    l1CopyInOp.SetAttribute("isConv", true);
    SetCopyInAL1Op(
        l1CopyInOp, tensorGraphNodes, convTileInfo, iterInfo, convAttrParam, dstAL1Shape, srcGmValidShape,
        srcCinOffset);
}
/**
 * \brief Build the feature-map L0 tile for the current K step, refreshing the L1 tile when needed.
 *
 * A new L1 tile is built when iterInfo.aL1UpadateFlag is set, which this function forces whenever
 * kL0Offset is a multiple of kAL1. The L0 tile is then produced by an OP_LOAD3D_CONV operation
 * reading from the L1 tile.
 *
 * \param function         Function that receives the new tensors and operations.
 * \param tensorGraphNodes Provides the fmap tensor (data type and format).
 * \param convTileInfo     Tile sizes.
 * \param iterInfo         Current iteration state; aL1UpadateFlag may be written.
 * \param dstAL1TensorPtr  In/out: L1 tile, replaced when a reload happens.
 * \param convAttrParam    Conv attributes forwarded to the helpers.
 * \return The L0 tile tensor.
 */
LogicalTensorPtr ConstructFmapTile(
    Function& function, const ConvGraphNodes& tensorGraphNodes, const ConvTileInfo& convTileInfo,
    ConvIterInfo& iterInfo, LogicalTensorPtr& dstAL1TensorPtr, const ConvAttrParam& convAttrParam)
{
    const bool atL1Boundary = (iterInfo.kL0Offset % convTileInfo.kAL1) == 0;
    if (atL1Boundary) {
        iterInfo.aL1UpadateFlag = true;
    }
    if (iterInfo.aL1UpadateFlag) {
        ConstructFmapL1Tile(function, tensorGraphNodes, convTileInfo, iterInfo, dstAL1TensorPtr, convAttrParam);
    }

    const auto& fmapSrc = tensorGraphNodes.fmapTensorPtr;
    // The M extent is aligned up to MKN_M_VALUE for the tile shape.
    std::vector<int64_t> l0Shape{ConvAlignB(iterInfo.mL0Size, MKN_M_VALUE), iterInfo.kL0Size};
    LogicalTensorPtr l0Tile = std::make_shared<LogicalTensor>(
        function, fmapSrc->Datatype(), l0Shape, SymbolicScalar::FromConcrete({iterInfo.mL0Size, iterInfo.kL0Size}),
        fmapSrc->Format(), "aL0Tensor");
    l0Tile->UpdateDynValidShape(SymbolicScalar::FromConcrete(l0Shape));

    auto& load3dOp = function.AddOperation(Opcode::OP_LOAD3D_CONV, {dstAL1TensorPtr}, {l0Tile});
    load3dOp.SetAttribute("l0_tile_shape", SymbolicScalar::FromConcrete(l0Shape));
    SetImg2ColAttr(load3dOp, convAttrParam, iterInfo, convTileInfo);
    return l0Tile;
}
/**
 * \brief Configure the copy-in operation that loads a weight tile into L1.
 *
 * Sets the conv marker attributes, chooses the copy-in mode by platform (COPY_MOD_NZ2NZ when
 * IsArch32Platform() is true, COPY_MOD_DN2NZ otherwise), builds the source offset and shape, and
 * installs a CopyOpAttribute. Clears iterInfo.bL1UpadateFlag at the end.
 *
 * \param copyInOpBl1      Operation to configure.
 * \param tensorGraphNodes Provides the weight tensor whose dynamic raw shape is recorded.
 * \param convTileInfo     Original dimensions, cin0, kPerGroup and coutPerGroup.
 * \param iterInfo         Current iteration state; bL1UpadateFlag is written.
 * \param convAttrParam    Provides groups, strides, paddings and isConv3D.
 * \param dstBL1Shape      Shape of the destination L1 tile.
 * \param srcGmValidShape  Source region shape; element [1] is used as the channel count.
 * \param srcCinOffset     Channel offset inside the current group.
 */
void SetCopyInBL1Op(
    Operation& copyInOpBl1, const ConvGraphNodes& tensorGraphNodes, const ConvTileInfo& convTileInfo,
    ConvIterInfo& iterInfo, const ConvAttrParam& convAttrParam, const std::vector<int64_t>& dstBL1Shape,
    const std::vector<int64_t>& srcGmValidShape, const int64_t& srcCinOffset)
{
    const bool is3D = convAttrParam.isConv3D;
    copyInOpBl1.SetAttribute("isConv", true);
    copyInOpBl1.SetAttribute(LoadStoreConvOpAttributeKey::isFmap, false);
    copyInOpBl1.SetAttribute(LoadStoreConvOpAttributeKey::isConv3D, convAttrParam.isConv3D);

    const int64_t coutStart = iterInfo.groupOffset * convTileInfo.coutPerGroup + iterInfo.nL1Offset;
    const int64_t cinStart = srcCinOffset;
    int64_t kdStart = 0;
    if (is3D) {
        const auto kdStep = iterInfo.kL0Offset / convTileInfo.kPerGroup;
        const auto dinStart = iterInfo.doL1Offset * convAttrParam.strides[NUM2] - convAttrParam.paddings[NUM4];
        kdStart = kdStep;
        if (dinStart < 0) {
            // The depth window starts before the input: shift the kernel-depth index by
            // orgKd - dkBL1SrcOffset.
            kdStart = convTileInfo.orgKd - iterInfo.dkBL1SrcOffset + kdStep;
        }
    }

    std::vector<int64_t> srcGmOffset;
    std::vector<int64_t> srcGmShape;
    if (IsArch32Platform()) {
        copyInOpBl1.SetAttribute(
            LoadStoreConvOpAttributeKey::copyInMode, static_cast<int64_t>(CopyInMode::COPY_MOD_NZ2NZ));
        int64_t cout1Start = iterInfo.nL1Offset / MKN_N_VALUE;
        int64_t cin1Start = cinStart / convTileInfo.cin0;
        int64_t kernelHW = convTileInfo.orgKh * convTileInfo.orgKw;
        int64_t cin1PerGroup = CeilDiv(convTileInfo.orgCin / convAttrParam.groups, convTileInfo.cin0);
        const auto cin1Count = CeilDiv(srcGmValidShape[1], convTileInfo.cin0);
        const auto cout1Count = CeilDiv(iterInfo.nL1Size, MKN_N_VALUE);
        // Only the leading axis differs between 2D and 3D; it folds group, kernel depth (3D only),
        // cin1 and kh*kw.
        int64_t leadStart = 0;
        int64_t leadCount = 0;
        if (is3D) {
            leadStart = ((iterInfo.groupOffset * convTileInfo.orgKd + kdStart) * cin1PerGroup + cin1Start) * kernelHW;
            leadCount = cin1Count * iterInfo.dkBL1Size * kernelHW;
        } else {
            leadStart = (iterInfo.groupOffset * cin1PerGroup + cin1Start) * kernelHW;
            leadCount = cin1Count * kernelHW;
        }
        srcGmOffset = {leadStart, cout1Start, 0, 0};
        srcGmShape = {leadCount, cout1Count, MKN_N_VALUE, convTileInfo.cin0};
    } else {
        copyInOpBl1.SetAttribute(
            LoadStoreConvOpAttributeKey::copyInMode, static_cast<int64_t>(CopyInMode::COPY_MOD_DN2NZ));
        if (is3D) {
            srcGmOffset = {coutStart, cinStart, kdStart, 0, 0};
        } else {
            srcGmOffset = {coutStart, cinStart, 0, 0};
        }
        srcGmShape = srcGmValidShape;
    }

    auto copyAttr = std::make_shared<CopyOpAttribute>(
        OpImmediate::Specified(srcGmOffset), MemoryType::MEM_L1, OpImmediate::Specified(srcGmShape),
        OpImmediate::Specified(tensorGraphNodes.weightTensorPtr->tensor->GetDynRawShape()),
        OpImmediate::Specified(dstBL1Shape));
    copyInOpBl1.SetOpAttribute(copyAttr);
    copyInOpBl1.SetAttribute("l1_tile_shape", SymbolicScalar::FromConcrete(dstBL1Shape));
    copyInOpBl1.SetAttribute(OpAttributeKey::srcGmConvValidShape, SymbolicScalar::FromConcrete(srcGmShape));
    // The L1 tile is now scheduled; clear the reload request.
    iterInfo.bL1UpadateFlag = false;
}
/**
 * \brief Create the weight L1 tile tensor and the OP_L1_COPY_IN_CONV operation that fills it.
 *
 * Updates iterInfo.kBL1Size (and iterInfo.dkBL1Size for 3D conv), derives the L1 tile shape and the
 * source region shape, then delegates attribute setup to SetCopyInBL1Op.
 *
 * \param function         Function that receives the new tensor and operation.
 * \param tensorGraphNodes Provides the weight tensor.
 * \param convTileInfo     Tile sizes and original dimensions.
 * \param iterInfo         Current iteration state; kBL1Size / dkBL1Size are written.
 * \param dstBL1TensorPtr  Output: the newly created L1 tile tensor.
 * \param convAttrParam    Provides groups and isConv3D.
 */
static void ConstructWeightL1Tile(
    Function& function, const ConvGraphNodes& tensorGraphNodes, const ConvTileInfo& convTileInfo,
    ConvIterInfo& iterInfo, LogicalTensorPtr& dstBL1TensorPtr, const ConvAttrParam& convAttrParam)
{
    const auto kernelHW = convTileInfo.orgKh * convTileInfo.orgKw;
    // K extent of this L1 tile: what is left of the reduction, capped by the configured kBL1.
    iterInfo.kBL1Size = std::min(convTileInfo.kPerGroup * iterInfo.dkL1Size - iterInfo.kL0Offset, convTileInfo.kBL1);
    // The L1 tile shape is the same for 2D and 3D conv.
    std::vector<int64_t> dstBL1Shape{
        iterInfo.kBL1Size / convTileInfo.cin0, CeilDiv(iterInfo.nL1Size, MKN_N_VALUE), MKN_N_VALUE, convTileInfo.cin0};
    int64_t srcCinOffset = (iterInfo.kL0Offset % convTileInfo.kPerGroup) / kernelHW;
    // Computed before any 3D adjustment of srcCinOffset below.
    int64_t srcGmCin = std::min(
        convTileInfo.orgCin / convAttrParam.groups - srcCinOffset, convTileInfo.kBL1 / kernelHW);

    std::vector<int64_t> srcGmValidShape;
    if (convAttrParam.isConv3D) {
        iterInfo.dkBL1Size = 1;
        if (iterInfo.kBL1Size > convTileInfo.kPerGroup) {
            // The tile covers more than one kPerGroup: restart channels at 0 and record how many
            // depth entries it covers.
            srcCinOffset = 0;
            iterInfo.dkBL1Size = iterInfo.kBL1Size / convTileInfo.kPerGroup;
        }
        srcGmValidShape = {iterInfo.nL1Size, srcGmCin, iterInfo.dkBL1Size, convTileInfo.orgKh, convTileInfo.orgKw};
    } else {
        srcGmValidShape = {iterInfo.nL1Size, srcGmCin, convTileInfo.orgKh, convTileInfo.orgKw};
    }

    const auto& weightSrc = tensorGraphNodes.weightTensorPtr;
    dstBL1TensorPtr = std::make_shared<LogicalTensor>(
        function, weightSrc->Datatype(), dstBL1Shape, SymbolicScalar::FromConcrete(dstBL1Shape), weightSrc->Format(),
        "bL1Tensor");
    dstBL1TensorPtr->UpdateDynValidShape(SymbolicScalar::FromConcrete(dstBL1Shape));
    auto& l1CopyInOp = function.AddOperation(Opcode::OP_L1_COPY_IN_CONV, {weightSrc}, {dstBL1TensorPtr});
    l1CopyInOp.SetAttribute("isConv", true);
    SetCopyInBL1Op(
        l1CopyInOp, tensorGraphNodes, convTileInfo, iterInfo, convAttrParam, dstBL1Shape, srcGmValidShape,
        srcCinOffset);
}
/**
 * \brief Build the weight L0 tile for the current K step, refreshing the L1 tile when needed.
 *
 * A new L1 tile is built when iterInfo.bL1UpadateFlag is set, which this function forces whenever
 * kL0Offset is a multiple of kBL1. The L0 tile is then produced by an OP_LOAD2D_CONV operation
 * reading from the L1 tile.
 *
 * \param function         Function that receives the new tensors and operations.
 * \param tensorGraphNodes Provides the weight tensor (data type and format).
 * \param convTileInfo     Tile sizes.
 * \param iterInfo         Current iteration state; bL1UpadateFlag may be written.
 * \param dstBL1TensorPtr  In/out: L1 tile, replaced when a reload happens.
 * \param convAttrParam    Conv attributes forwarded to the L1 helper.
 * \return The L0 tile tensor.
 */
LogicalTensorPtr ConstructWeightTile(
    Function& function, const ConvGraphNodes& tensorGraphNodes, const ConvTileInfo& convTileInfo,
    ConvIterInfo& iterInfo, LogicalTensorPtr& dstBL1TensorPtr, const ConvAttrParam& convAttrParam)
{
    const bool atL1Boundary = (iterInfo.kL0Offset % convTileInfo.kBL1) == 0;
    if (atL1Boundary) {
        iterInfo.bL1UpadateFlag = true;
    }
    if (iterInfo.bL1UpadateFlag) {
        ConstructWeightL1Tile(function, tensorGraphNodes, convTileInfo, iterInfo, dstBL1TensorPtr, convAttrParam);
    }

    const auto& weightSrc = tensorGraphNodes.weightTensorPtr;
    // The N extent is aligned up to MKN_N_VALUE for the tile shape.
    std::vector<int64_t> l0Shape{iterInfo.kL0Size, ConvAlignB(iterInfo.nL0Size, MKN_N_VALUE)};
    LogicalTensorPtr l0Tile = std::make_shared<LogicalTensor>(
        function, weightSrc->Datatype(), l0Shape, SymbolicScalar::FromConcrete({iterInfo.kL0Size, iterInfo.nL0Size}),
        weightSrc->Format(), "bL0Tensor");
    l0Tile->UpdateDynValidShape(SymbolicScalar::FromConcrete(l0Shape));

    auto& load2dOp = function.AddOperation(Opcode::OP_LOAD2D_CONV, {dstBL1TensorPtr}, {l0Tile});
    // Start point of this L0 tile inside the L1 tile, along K and N.
    load2dOp.SetAttribute(OpAttributeKey::postK, iterInfo.kL0Offset % convTileInfo.kBL1);
    load2dOp.SetAttribute(OpAttributeKey::postN, iterInfo.nL0Offset);
    load2dOp.SetAttribute("l0_tile_shape", SymbolicScalar::FromConcrete(l0Shape));
    load2dOp.SetAttribute("isConv", true);
    return l0Tile;
}
/**
 * \brief Attach the matmul attributes to a mmad operation created for conv.
 *
 * The NZ attribute packs the formats of the fmap, weight and result tensors into bits 0, 1 and 2.
 * The bias flag is only set on OP_A_MUL_B operations.
 *
 * \param tensorGraphNodes Provides the fmap, weight, result and optional bias tensors.
 * \param convTileInfo     Provides the L0 tile sizes used as M / K / N.
 * \param op               Operation to configure.
 */
void SetAMulBAttr(const ConvGraphNodes& tensorGraphNodes, const ConvTileInfo& convTileInfo, Operation& op)
{
    ASSERT(
        ConvExpandFuncError::EXPANDFUNC_TENSOR_OP_NULLPTR, tensorGraphNodes.fmapTensorPtr != nullptr &&
                                                               tensorGraphNodes.weightTensorPtr != nullptr &&
                                                               tensorGraphNodes.resTensorPtr != nullptr)
        << "Expected fmapTensorPtr, weightTensorPtr, and resTensorPtr to be non-nullptr.";

    const int64_t fmapBits = static_cast<int64_t>(tensorGraphNodes.fmapTensorPtr->Format());
    const int64_t weightBits = static_cast<int64_t>(tensorGraphNodes.weightTensorPtr->Format()) << 1;
    const int64_t resBits = static_cast<int64_t>(tensorGraphNodes.resTensorPtr->Format()) << 2;
    int64_t nzAttr = fmapBits | weightBits | resBits;

    op.SetAttribute("isConv", true);
    op.SetAttribute(MATMUL_NZ_ATTR, nzAttr);
    op.SetAttribute(A_MUL_B_ACT_M, convTileInfo.hL0 * convTileInfo.wL0);
    op.SetAttribute(A_MUL_B_ACT_K, convTileInfo.kL0);
    op.SetAttribute(A_MUL_B_ACT_N, convTileInfo.nL0);
    if (op.GetOpcode() != Opcode::OP_A_MUL_B) {
        return;
    }
    op.SetAttribute(A_MUL_B_BIAS_ATTR, tensorGraphNodes.biasTensorPtr != nullptr);
}
/**
 * \brief Add the matmul operation for the current K step and return its output tensor.
 *
 * The first K step uses "TILE_A_MUL_B" with inputs {fmap, weight[, bias]}. Later steps use
 * "TILE_A_MULACC_B" with inputs {fmap, weight, previous partial sum}. The last K step writes to
 * tileGraphNodes.resTensorPtr; any other step writes to a freshly created DT_FP32 partial-sum
 * tensor stored in tileGraphNodes.cL0PartialSumPtr.
 *
 * \param function         Function that receives the new tensor and operation.
 * \param convAttrParam    Provides hasBias.
 * \param tensorGraphNodes Tensor-graph nodes, used for the matmul attributes.
 * \param tileGraphNodes   Tile-level nodes; cL0PartialSumPtr is updated on non-final K steps.
 * \param convTileInfo     Tile sizes used for the matmul attributes.
 * \param iterInfo         Current iteration state (isFirstK, isLastK, mL0Size, nL0Size).
 * \return The output tensor of the matmul operation that was added.
 */
LogicalTensorPtr DoMmad(
    Function& function, const ConvAttrParam& convAttrParam, const ConvGraphNodes& tensorGraphNodes,
    ConvGraphNodes& tileGraphNodes, const ConvTileInfo& convTileInfo, const ConvIterInfo& iterInfo)
{
    ASSERT(
        ConvExpandFuncError::EXPANDFUNC_TILE_OP_NULLPTR, tileGraphNodes.fmapTensorPtr != nullptr &&
                                                             tileGraphNodes.weightTensorPtr != nullptr &&
                                                             tileGraphNodes.resTensorPtr != nullptr)
        << "Inputs and res must be non-nullptr.";
    std::vector<LogicalTensorPtr> mmadInputs;
    std::vector<LogicalTensorPtr> mmadOutputs;
    const std::string MmadOpStr = iterInfo.isFirstK ? "TILE_A_MUL_B" : "TILE_A_MULACC_B";
    if (iterInfo.isFirstK) {
        mmadInputs = {tileGraphNodes.fmapTensorPtr, tileGraphNodes.weightTensorPtr};
        if (convAttrParam.hasBias) {
            ASSERT(ConvExpandFuncError::EXPANDFUNC_TILE_OP_NULLPTR, tileGraphNodes.biasTensorPtr != nullptr)
                << "bias must be non-nullptr when hasBias Flag.";
            mmadInputs.push_back(tileGraphNodes.biasTensorPtr);
        }
    } else {
        // The accumulate form consumes the partial sum of the previous K step; validate it the same
        // way as the other inputs instead of passing a null operand to the operation.
        ASSERT(ConvExpandFuncError::EXPANDFUNC_TILE_OP_NULLPTR, tileGraphNodes.cL0PartialSumPtr != nullptr)
            << "cL0PartialSum must be non-nullptr when accumulating after the first K step.";
        mmadInputs = {tileGraphNodes.fmapTensorPtr, tileGraphNodes.weightTensorPtr, tileGraphNodes.cL0PartialSumPtr};
    }
    if (iterInfo.isLastK) {
        mmadOutputs = {tileGraphNodes.resTensorPtr};
    } else {
        // Intermediate result: a new partial-sum tensor with M and N aligned up.
        std::vector<int64_t> cL0PartialSumShape = {
            ConvAlignB(iterInfo.mL0Size, MKN_M_VALUE), ConvAlignB(iterInfo.nL0Size, MKN_N_VALUE)};
        tileGraphNodes.cL0PartialSumPtr = std::make_shared<LogicalTensor>(
            function, DataType::DT_FP32, cL0PartialSumShape,
            SymbolicScalar::FromConcrete({iterInfo.mL0Size, iterInfo.nL0Size}), TileOpFormat::TILEOP_NZ,
            "cL0PartialSumTensor");
        tileGraphNodes.cL0PartialSumPtr->UpdateDynValidShape(SymbolicScalar::FromConcrete(cL0PartialSumShape));
        mmadOutputs = {tileGraphNodes.cL0PartialSumPtr};
    }
    auto& aMulBOp = function.AddOperation(MmadOpStr, mmadInputs, mmadOutputs);
    SetAMulBAttr(tensorGraphNodes, convTileInfo, aMulBOp);
    return mmadOutputs[0];
}
/**
 * \brief Compute the destination offset in the result tensor for the current L0 output tile.
 *
 * When IsArch32Platform() is true the offset carries the channel axis in blocks plus a trailing 0;
 * otherwise it is a plain {n, c, [d,] h, w} offset.
 *
 * NOTE(review): the block size used for the output channel axis is convTileInfo.cin0, which is
 * derived from the fmap data type - confirm this is also right for the result tensor.
 *
 * \param convAttrParam Provides isConv3D.
 * \param convTileInfo  Provides coutPerGroup and cin0.
 * \param iterInfo      Current iteration offsets.
 * \return Offset vector with 4, 5 or 6 entries depending on platform and conv rank.
 */
std::vector<int64_t> GetCopyOutDstOffset(
    const ConvAttrParam& convAttrParam, const ConvTileInfo& convTileInfo, const ConvIterInfo& iterInfo)
{
    const int64_t batchStart = iterInfo.batchOffset;
    const int64_t depthStart = iterInfo.doL1Offset;
    const int64_t rowStart = iterInfo.hL1OutOffset + iterInfo.hL0Offset;
    const int64_t colStart = iterInfo.wL1OutOffset + iterInfo.wL0Offset;

    if (IsArch32Platform()) {
        int64_t cout1PerGroup = CeilDiv(convTileInfo.coutPerGroup, convTileInfo.cin0);
        int64_t cout1Start =
            iterInfo.groupOffset * cout1PerGroup + (iterInfo.nL1Offset + iterInfo.nL0Offset) / convTileInfo.cin0;
        if (convAttrParam.isConv3D) {
            return {batchStart, depthStart, cout1Start, rowStart, colStart, 0};
        }
        return {batchStart, cout1Start, rowStart, colStart, 0};
    }

    const int64_t channelStart =
        iterInfo.groupOffset * convTileInfo.coutPerGroup + iterInfo.nL1Offset + iterInfo.nL0Offset;
    if (convAttrParam.isConv3D) {
        return {batchStart, channelStart, depthStart, rowStart, colStart};
    }
    return {batchStart, channelStart, rowStart, colStart};
}
/**
 * \brief Add the OP_L0C_COPY_OUT_CONV operation that writes an L0 result tile to the result tensor.
 *
 * The copy-out mode is COPY_MOD_NZ2NZ when IsArch32Platform() is true and COPY_MOD_NZ2DN otherwise.
 * In the NZ2NZ case the valid N recorded in l0cValidMN is aligned up to MKN_N_VALUE.
 *
 * \param function         Function that receives the new operation.
 * \param convAttrParam    Provides isConv3D.
 * \param tensorGraphNodes Provides the result tensor.
 * \param convTileInfo     Provides wL0 and the values needed for the destination offset.
 * \param iterInfo         Current iteration state.
 * \param resCl0TensorPtr  L0 result tile to copy out; its dynamic valid shape is updated here.
 */
void ConstrucCopyOutTile(
    Function& function, const ConvAttrParam& convAttrParam, const ConvGraphNodes& tensorGraphNodes,
    const ConvTileInfo& convTileInfo, const ConvIterInfo& iterInfo, const LogicalTensorPtr& resCl0TensorPtr)
{
    const auto& resGm = tensorGraphNodes.resTensorPtr;
    auto& copyOutOp = function.AddOperation(Opcode::OP_L0C_COPY_OUT_CONV, {resCl0TensorPtr}, {resGm});
    copyOutOp.SetAttribute("isConv", true);
    copyOutOp.SetAttribute(LoadStoreConvOpAttributeKey::isConv3D, convAttrParam.isConv3D);

    // Aligned extents of the L0 tile, reused for the valid shape and the copy attribute.
    const auto alignedM = ConvAlignB(iterInfo.mL0Size, MKN_M_VALUE);
    const auto alignedN = ConvAlignB(iterInfo.nL0Size, MKN_N_VALUE);
    resCl0TensorPtr->UpdateDynValidShape({alignedM, alignedN});

    // Width handled by this tile: the remaining output width in L1, capped by wL0.
    int64_t cutW = std::min(iterInfo.woutL1Size - iterInfo.wL0Offset, convTileInfo.wL0);
    copyOutOp.SetAttribute(LoadStoreConvOpAttributeKey::cutW, cutW);
    copyOutOp.SetAttribute("res_tile_shape", resGm->GetDynValidShape());

    if (IsArch32Platform()) {
        copyOutOp.SetAttribute(
            LoadStoreConvOpAttributeKey::copyOutMode, static_cast<int64_t>(CopyOutMode::COPY_MOD_NZ2NZ));
        copyOutOp.SetAttribute(
            OpAttributeKey::l0cValidMN, SymbolicScalar::FromConcrete({iterInfo.mL0Size, alignedN}));
    } else {
        copyOutOp.SetAttribute(
            LoadStoreConvOpAttributeKey::copyOutMode, static_cast<int64_t>(CopyOutMode::COPY_MOD_NZ2DN));
        copyOutOp.SetAttribute(
            OpAttributeKey::l0cValidMN, SymbolicScalar::FromConcrete({iterInfo.mL0Size, iterInfo.nL0Size}));
    }

    std::vector<int64_t> dstOffset = GetCopyOutDstOffset(convAttrParam, convTileInfo, iterInfo);
    auto copyAttr = std::make_shared<CopyOpAttribute>(
        MemoryType::MEM_L1, OpImmediate::Specified(dstOffset), OpImmediate::Specified(resGm->tensor->GetRawShape()),
        OpImmediate::Specified(resGm->tensor->GetDynRawShape()), OpImmediate::Specified({alignedM, alignedN}));
    copyOutOp.SetOpAttribute(copyAttr);
}
/**
 * \brief Compute how many kernel-depth entries contribute to the current output depth position.
 *
 * For non-3D conv, dkL1Size is set to 1 and nothing else is touched. For 3D conv, dkL1Size starts
 * at orgKd and is reduced by the kernel-depth entries that fall before the input (negative depth
 * start) or past orgDin. dinL1Offset ends up as the first depth position that lies inside the
 * input, and dkBL1SrcOffset is written only when the depth start is negative.
 *
 * \param convTileInfo  Provides orgKd and orgDin.
 * \param iterInfo      Reads doL1Offset; writes dkL1Size, dinL1Offset and dkBL1SrcOffset.
 * \param convAttrParam Provides isConv3D, strides[NUM2], paddings[NUM4] and dilations[NUM2].
 */
void Cal3DDkL1Size(const ConvTileInfo& convTileInfo, ConvIterInfo& iterInfo, const ConvAttrParam& convAttrParam)
{
    iterInfo.dkL1Size = 1;
    if (!convAttrParam.isConv3D) {
        return;
    }

    const auto dilationD = convAttrParam.dilations[NUM2];
    iterInfo.dkL1Size = convTileInfo.orgKd;
    iterInfo.dinL1Offset = iterInfo.doL1Offset * convAttrParam.strides[NUM2] - convAttrParam.paddings[NUM4];

    int64_t firstValidDin = iterInfo.dinL1Offset;
    if (iterInfo.dinL1Offset < 0) {
        // Kernel-depth entries that land before the first input plane.
        int64_t skippedFront = CeilDiv(-iterInfo.dinL1Offset, dilationD);
        iterInfo.dkL1Size -= skippedFront;
        iterInfo.dkBL1SrcOffset = iterInfo.dkL1Size;
        firstValidDin = iterInfo.dinL1Offset + skippedFront * dilationD;
    }

    // One past the last depth position touched by the full kernel.
    int64_t kernelDepthEnd = iterInfo.dinL1Offset + (convTileInfo.orgKd - 1) * dilationD + 1;
    if (kernelDepthEnd > convTileInfo.orgDin) {
        // Kernel-depth entries that land after the last input plane.
        int64_t skippedBack = CeilDiv(kernelDepthEnd - convTileInfo.orgDin, dilationD);
        iterInfo.dkL1Size -= skippedBack;
    }
    iterInfo.dinL1Offset = firstValidDin;
}
void UpdateL1IterInfo(const ConvTileInfo& convTileInfo, ConvIterInfo& iterInfo, const ConvAttrParam& convAttrParam)
{
iterInfo.houtL1Size = std::min(convTileInfo.orgHout - iterInfo.hL1OutOffset, convTileInfo.hAL1Out);
iterInfo.hL1InOffset = iterInfo.hL1OutOffset * convAttrParam.strides[0] - convAttrParam.paddings[0];
int64_t needHL1Size = (iterInfo.houtL1Size - 1) * convAttrParam.strides[0] +
(convTileInfo.orgKh - 1) * convAttrParam.dilations[0] + 1;
if (iterInfo.hL1InOffset < 0) {
iterInfo.hinL1Size = needHL1Size + iterInfo.hL1InOffset;
if (iterInfo.hL1InOffset + needHL1Size <= 0) {
iterInfo.hinL1Size = 0;
}
if (iterInfo.hinL1Size > convTileInfo.orgHin) {
iterInfo.hinL1Size = convTileInfo.orgHin;
}
} else if (convTileInfo.orgHin - iterInfo.hL1InOffset <= 0) {
iterInfo.hinL1Size = 0;
} else {
iterInfo.hinL1Size = std::min(convTileInfo.orgHin - iterInfo.hL1InOffset, needHL1Size);
}
iterInfo.woutL1Size = std::min(convTileInfo.orgWout - iterInfo.wL1OutOffset, convTileInfo.wAL1Out);
iterInfo.wL1InOffset = iterInfo.wL1OutOffset * convAttrParam.strides[1] - convAttrParam.paddings[NUM2];
int64_t needWL1Size = (iterInfo.woutL1Size - 1) * convAttrParam.strides[1] +
(convTileInfo.orgKw - 1) * convAttrParam.dilations[1] + 1;
if (iterInfo.wL1InOffset < 0) {
iterInfo.winL1Size = needWL1Size + iterInfo.wL1InOffset;
if (iterInfo.wL1InOffset + needWL1Size <= 0) {
iterInfo.winL1Size = 0;
}
if (iterInfo.winL1Size > convTileInfo.orgWin) {
iterInfo.winL1Size = convTileInfo.orgWin;
}
} else if (convTileInfo.orgWin - iterInfo.wL1InOffset <= 0) {
iterInfo.winL1Size = 0;
} else {
iterInfo.winL1Size = std::min(convTileInfo.orgWin - iterInfo.wL1InOffset, needWL1Size);
}
iterInfo.nL1Size = std::min(convTileInfo.coutPerGroup - iterInfo.nL1Offset, convTileInfo.nBL1);
Cal3DDkL1Size(convTileInfo, iterInfo, convAttrParam);
}
/**
 * \brief Refresh the K extent and the first/last flags for the current K-L0 step.
 *
 * The total K range is kPerGroup * dkL1Size. kL0Size is the remaining part of that range capped
 * by kL0; isFirstK is set at offset 0; isLastK is set when this step reaches the end of the range.
 *
 * \param convTileInfo Provides kPerGroup and kL0.
 * \param iterInfo     Reads kL0Offset and dkL1Size; writes kL0Size, isFirstK and isLastK.
 */
void UpdateL0IterInfo(const ConvTileInfo& convTileInfo, ConvIterInfo& iterInfo)
{
    const auto kTotal = convTileInfo.kPerGroup * iterInfo.dkL1Size;
    iterInfo.kL0Size = std::min(kTotal - iterInfo.kL0Offset, convTileInfo.kL0);
    iterInfo.isFirstK = (iterInfo.kL0Offset == 0);
    iterInfo.isLastK = (iterInfo.kL0Offset + convTileInfo.kL0 >= kTotal);
}
/**
 * \brief Expand one L1 tile into L0 tiles and emit the load, matmul and copy-out operations.
 *
 * Loops over N, H and W at L0 granularity. For every output tile it optionally builds the bias
 * tile, creates the result tile tensor, runs the K loop (fmap tile, weight tile, matmul) and
 * finally emits the copy-out of the last matmul output.
 *
 * \param function         Function that receives the new tensors and operations.
 * \param iterInfo         Iteration state; the L0 offsets and sizes are written here.
 * \param convTileInfo     Tile sizes.
 * \param convAttrParam    Conv attributes.
 * \param tensorGraphNodes Tensor-graph nodes (fmap / weight / bias / result).
 * \param tileGraphNodes   Tile-level nodes, rewritten for every tile.
 */
void IterL0ExpandFunc(
    Function& function, ConvIterInfo& iterInfo, ConvTileInfo& convTileInfo, const ConvAttrParam& convAttrParam,
    const ConvGraphNodes& tensorGraphNodes, ConvGraphNodes& tileGraphNodes)
{
    // The L1 tiles live across K steps and output tiles; they are replaced only on reload.
    LogicalTensorPtr fmapL1Tile = nullptr;
    LogicalTensorPtr weightL1Tile = nullptr;
    LogicalTensorPtr lastMmadOutput = nullptr;

    for (iterInfo.nL0Offset = 0; iterInfo.nL0Offset < iterInfo.nL1Size; iterInfo.nL0Offset += convTileInfo.nL0) {
        iterInfo.nL0Size = std::min(iterInfo.nL1Size - iterInfo.nL0Offset, convTileInfo.nL0);
        for (iterInfo.hL0Offset = 0; iterInfo.hL0Offset < iterInfo.houtL1Size;
             iterInfo.hL0Offset += convTileInfo.hL0) {
            for (iterInfo.wL0Offset = 0; iterInfo.wL0Offset < iterInfo.woutL1Size;
                 iterInfo.wL0Offset += convTileInfo.wL0) {
                int64_t tileRows = std::min(convTileInfo.hL0, iterInfo.houtL1Size - iterInfo.hL0Offset);
                int64_t tileCols = std::min(convTileInfo.wL0, iterInfo.woutL1Size - iterInfo.wL0Offset);
                iterInfo.mL0Size = tileRows * tileCols;

                if (tileRows <= 1 || convTileInfo.wL0 == convTileInfo.wAL1Out) {
                    // Single repeat. NOTE(review): repeatStride is not written on this path and keeps
                    // its previous value - confirm it is ignored when repeatTime is 1.
                    iterInfo.repeatTime = 1;
                    iterInfo.wStride = ConvAlignB(iterInfo.mL0Size, MKN_M_VALUE);
                } else {
                    // Several rows of a W-split tile: one repeat per row, stepping by the L1 output width.
                    iterInfo.repeatStride = iterInfo.woutL1Size;
                    iterInfo.repeatTime = tileRows;
                    iterInfo.wStride = tileCols;
                }

                if (convAttrParam.hasBias) {
                    tileGraphNodes.biasTensorPtr =
                        ConstructBiasTile(function, tensorGraphNodes, iterInfo, convTileInfo);
                }

                std::vector<int64_t> resTileShape{
                    ConvAlignB(iterInfo.mL0Size, MKN_M_VALUE), ConvAlignB(iterInfo.nL0Size, MKN_N_VALUE)};
                tileGraphNodes.resTensorPtr = std::make_shared<LogicalTensor>(
                    function, tensorGraphNodes.fmapTensorPtr->Datatype(), resTileShape,
                    SymbolicScalar::FromConcrete(resTileShape), tensorGraphNodes.fmapTensorPtr->Format(),
                    "cL0Tensor");

                // K reduction over kPerGroup * dkL1Size in steps of kL0.
                const auto kTotal = convTileInfo.kPerGroup * iterInfo.dkL1Size;
                for (iterInfo.kL0Offset = 0; iterInfo.kL0Offset < kTotal; iterInfo.kL0Offset += convTileInfo.kL0) {
                    UpdateL0IterInfo(convTileInfo, iterInfo);
                    tileGraphNodes.fmapTensorPtr = ConstructFmapTile(
                        function, tensorGraphNodes, convTileInfo, iterInfo, fmapL1Tile, convAttrParam);
                    tileGraphNodes.weightTensorPtr = ConstructWeightTile(
                        function, tensorGraphNodes, convTileInfo, iterInfo, weightL1Tile, convAttrParam);
                    lastMmadOutput =
                        DoMmad(function, convAttrParam, tensorGraphNodes, tileGraphNodes, convTileInfo, iterInfo);
                }
                ConstrucCopyOutTile(function, convAttrParam, tensorGraphNodes, convTileInfo, iterInfo, lastMmadOutput);
            }
        }
    }
}
/**
 * \brief Iterate one batch entry over output depth, N, H and W at L1 granularity.
 *
 * The weight reload flag is raised at the start of every N-L1 step and the fmap reload flag at the
 * start of every W-L1 step; each innermost step refreshes the L1 iteration info and then expands
 * the tile into L0 work.
 *
 * \param function         Function that receives the generated tensors and operations.
 * \param iterInfo         Iteration state; the L1 offsets and reload flags are written here.
 * \param convTileInfo     Tile sizes and original dimensions.
 * \param convAttrParam    Conv attributes.
 * \param tensorGraphNodes Tensor-graph nodes.
 * \param tileGraphNodes   Tile-level nodes.
 */
void IterOneBatchFunc(
    Function& function, ConvIterInfo& iterInfo, ConvTileInfo& convTileInfo, const ConvAttrParam& convAttrParam,
    const ConvGraphNodes& tensorGraphNodes, ConvGraphNodes& tileGraphNodes)
{
    // Output depth advances one position at a time.
    for (iterInfo.doL1Offset = 0; iterInfo.doL1Offset < convTileInfo.orgDout; ++iterInfo.doL1Offset) {
        for (iterInfo.nL1Offset = 0; iterInfo.nL1Offset < convTileInfo.coutPerGroup;
             iterInfo.nL1Offset += convTileInfo.nBL1) {
            // A new N-L1 step needs a fresh weight tile.
            iterInfo.bL1UpadateFlag = true;
            for (iterInfo.hL1OutOffset = 0; iterInfo.hL1OutOffset < convTileInfo.orgHout;
                 iterInfo.hL1OutOffset += convTileInfo.hAL1Out) {
                for (iterInfo.wL1OutOffset = 0; iterInfo.wL1OutOffset < convTileInfo.orgWout;
                     iterInfo.wL1OutOffset += convTileInfo.wAL1Out) {
                    // A new spatial L1 step needs a fresh fmap tile.
                    iterInfo.aL1UpadateFlag = true;
                    UpdateL1IterInfo(convTileInfo, iterInfo, convAttrParam);
                    IterL0ExpandFunc(
                        function, iterInfo, convTileInfo, convAttrParam, tensorGraphNodes, tileGraphNodes);
                }
            }
        }
    }
}
/**
 * \brief Entry point that expands a conv operation into its tile-level graph.
 *
 * Reads the conv attributes from the operation, binds the operand and result tensors, derives the
 * shape and tile information, then iterates over every group and every batch entry.
 *
 * \param function   Function that receives the generated tensors and operations.
 * \param tileShape  Tile configuration.
 * \param operandVec Conv operands (fmap, weight and optional bias).
 * \param cTensorPtr Result tensor.
 * \param op         Conv operation carrying the attributes.
 */
void ConstructTileGraph(
    Function& function, const TileShape& tileShape, const std::vector<LogicalTensorPtr>& operandVec,
    const LogicalTensorPtr& cTensorPtr, const Operation& op)
{
    // Static description of the convolution.
    ConvAttrParam convAttrParam;
    SetConvAttrParam(op, convAttrParam);

    ConvGraphNodes tensorGraphNodes;
    SetTensorGraphNodes(operandVec, cTensorPtr, convAttrParam, tensorGraphNodes);

    ConvTileInfo convTileInfo;
    SetConvShapeInfo(tileShape, tensorGraphNodes, convAttrParam, convTileInfo);

    // Mutable state shared by all nested iteration helpers.
    ConvIterInfo iterInfo;
    ConvGraphNodes tileGraphNodes;
    for (iterInfo.groupOffset = 0; iterInfo.groupOffset < convAttrParam.groups; ++iterInfo.groupOffset) {
        for (iterInfo.batchOffset = 0; iterInfo.batchOffset < convTileInfo.orgBatch; ++iterInfo.batchOffset) {
            IterOneBatchFunc(function, iterInfo, convTileInfo, convAttrParam, tensorGraphNodes, tileGraphNodes);
        }
    }
}
/**
 * \brief Compute the shape of the convolution result tensor.
 *
 * When IsArch32Platform() is true the channel dimension is split into two parts,
 * cOut0 = ALIGN_SIZE_32 / BytesOf(outType) and
 * cOut1 = groups * CeilDiv(cOut / groups, cOut0):
 *   - 3D : {batch, dOut, cOut1, hOut, wOut, cOut0}
 *   - else: {batch, cOut1, hOut, wOut, cOut0}
 * On every other platform:
 *   - 3D : {batch, cOut, dOut, hOut, wOut}
 *   - 1D : {batch, cOut, wOut}
 *   - else: {batch, cOut, hOut, wOut}
 * If both isConv3D and isConv1D were set, the 3D layout wins.
 *
 * \param outType       Element type of the result; only used for the cOut0 computation.
 * \param inputTensor   Input tensor; its shape entry at NCHW_N_IDX is the output batch.
 * \param weightTensor  Weight tensor; its shape entry at NCHW_N_IDX is the output channel count.
 * \param convAttrParam Convolution attributes (groups, isConv1D, isConv3D and the values used by ConvCompute*).
 * \return The result tensor shape as described above.
 */
std::vector<int64_t> GetResTensorShape(
    DataType outType, const Tensor& inputTensor, const Tensor& weightTensor, const ConvAttrParam& convAttrParam)
{
    const int64_t batch = inputTensor.GetShape()[NCHW_N_IDX];
    const int64_t channels = weightTensor.GetShape()[NCHW_N_IDX];
    // Height and width are computed unconditionally, even for layouts that end up not using the height.
    const int64_t height = ConvComputeHo(inputTensor, weightTensor, convAttrParam);
    const int64_t width = ConvComputeWo(inputTensor, weightTensor, convAttrParam);

    if (IsArch32Platform()) {
        const int64_t channels0 = ALIGN_SIZE_32 / BytesOf(outType);
        const int64_t channels1 = convAttrParam.groups * CeilDiv(channels / convAttrParam.groups, channels0);
        if (convAttrParam.isConv3D) {
            const int64_t depth = ConvComputeDo(inputTensor, weightTensor, convAttrParam);
            return {batch, depth, channels1, height, width, channels0};
        }
        return {batch, channels1, height, width, channels0};
    }

    // 3D is tested first because it takes precedence over 1D.
    if (convAttrParam.isConv3D) {
        const int64_t depth = ConvComputeDo(inputTensor, weightTensor, convAttrParam);
        return {batch, channels, depth, height, width};
    }
    if (convAttrParam.isConv1D) {
        return {batch, channels, width};
    }
    return {batch, channels, height, width};
}
/**
 * \brief Front-end entry point that builds a convolution: normalizes the attributes, checks the operands,
 *        creates the result tensor and hands everything to ConstructTensorGraph.
 *
 * Attribute normalization:
 *   - Paddings are converted from symbolic to concrete values via SymbolicScalar::Concrete(paddings, 0).
 *   - When strides and dilations both have CONV3D_INPUT_DIM - NUM2 entries and paddings has
 *     NUM2 * (CONV3D_INPUT_DIM - NUM2) entries, strides/dilations are rotated left by one entry and
 *     paddings by NUM2 entries, i.e. the leading entries move to the back.
 *     NOTE(review): presumably this converts a depth-first attribute order into a depth-last one - confirm.
 *   - For 1D convolutions, after the result shape has been computed, NUM2 zero paddings and a unit
 *     stride/dilation are inserted at the front of the respective attribute vectors.
 *
 * \param outType      Element type of the result tensor.
 * \param inputTensor  Input tensor.
 * \param weightTensor Weight tensor.
 * \param strides      Strides as supplied by the caller.
 * \param paddings     Paddings as supplied by the caller (symbolic).
 * \param dilations    Dilations as supplied by the caller.
 * \param extendParam  Extended parameters; only biasTensor is read here.
 * \param groups       Group count passed on to ConvAttrParam.
 * \return The tensor returned by ConstructTensorGraph.
 */
Tensor Conv(
    DataType outType, const Tensor& inputTensor, const Tensor& weightTensor, const std::vector<int64_t>& strides,
    const std::vector<SymbolicScalar>& paddings, const std::vector<int64_t>& dilations,
    const ConvExtendParam& extendParam, const int64_t groups)
{
    std::vector<int64_t> finalPaddings = SymbolicScalar::Concrete(paddings, 0);
    std::vector<int64_t> finalDilations = dilations;
    std::vector<int64_t> finalStrides = strides;
    if (dilations.size() == CONV3D_INPUT_DIM - NUM2 && strides.size() == CONV3D_INPUT_DIM - NUM2 &&
        paddings.size() == NUM2 * (CONV3D_INPUT_DIM - NUM2)) {
        finalDilations = rotateVector(dilations, 1);
        finalStrides = rotateVector(strides, 1);
        // Reuse the paddings that were already made concrete above instead of converting them a second time.
        finalPaddings = rotateVector(finalPaddings, NUM2);
    }

    const Tensor& biasTensor = extendParam.biasTensor;
    ConvAttrParam convAttrParam(finalPaddings, finalStrides, finalDilations, groups);
    CheckConvOperands(outType, inputTensor, weightTensor, biasTensor, convAttrParam);

    // The result shape must be derived before the 1D attributes are widened below.
    std::vector<int64_t> resTensorShape = GetResTensorShape(outType, inputTensor, weightTensor, convAttrParam);
    if (convAttrParam.isConv1D) {
        convAttrParam.paddings.insert(convAttrParam.paddings.begin(), NUM2, 0);
        convAttrParam.strides.insert(convAttrParam.strides.begin(), 1);
        convAttrParam.dilations.insert(convAttrParam.dilations.begin(), 1);
    }

    // ND by default; on the Arch32 platform the format matches the split-channel shape from GetResTensorShape.
    TileOpFormat outFormat = TileOpFormat::TILEOP_ND;
    if (IsArch32Platform()) {
        outFormat = convAttrParam.isConv3D ? TileOpFormat::TILEOP_NDC1HWC0 : TileOpFormat::TILEOP_NC1HWC0;
    }

    Tensor resTensor(outType, resTensorShape, "TensorC", outFormat);
    resTensor.GetStorage()->UpdateDynValidShape(SymbolicScalar::FromConcrete(resTensorShape));
    return ConstructTensorGraph(inputTensor, weightTensor, biasTensor, resTensor, convAttrParam);
}
}
}
}