/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file conv3d_tiling_algorithm.h
* \brief Declares the Conv3D tiling algorithm class and the L0/L1 tiling helper structures it uses.
*/
#ifndef TILING_CONV_CONV3D_TILING_ALGORITHM_H
#define TILING_CONV_CONV3D_TILING_ALGORITHM_H
#include <cstdint>
#include <map>
#include <vector>
#include "include/adv_api/conv/conv3d/conv3d_tiling_base.h"
namespace Conv3dTilingApi {
// L1 load modes used by the tiling algorithm. INVALID is the "not yet decided" default
// used by L1TilingFlag. Numeric values are spelled out so they stay stable if entries move.
enum class L1TilingMode : uint32_t {
    FULL_LOAD_BL1 = 0,
    FULL_LOAD_AL1 = 1,
    ALL_FULL_LOAD = 2,
    NONE_FULL_LOAD = 3,
    INVALID = 4
};
// Decision flags collected while the L1 tiling is being chosen.
// Both enum members start as INVALID and every boolean starts as false.
struct L1TilingFlag {
    // Selected A/B L1 load mode.
    L1TilingMode abL1Mode{L1TilingMode::INVALID};
    // Selected M/N iteration order (type comes from conv3d_tiling_base.h).
    IterateMNOrder iterateMNOrder{IterateMNOrder::INVALID};
    // NOTE(review): the three flags below presumably record whether the K dimension of
    // A, of B, or of both can be fully loaded into L1 — confirm against the implementation.
    bool kAL1CanFullLoadFlag{false};
    bool kBL1CanFullLoadFlag{false};
    bool kABL1CanFullLoadFlag{false};
    bool isBiasFullLoad{false};
    bool isWeightBypass{false};
};
// Per-buffer "pb" values plus a combined flag word.
// NOTE(review): "pb" presumably stands for ping-pong (buffer count) — confirm.
// Defaults: AL1, AL0 and BL0 start at 2; BL1 and CL0 start at 1; the flag word starts at 0.
struct DoubleBufferTilingValue {
    uint8_t pbAL1{2};
    uint8_t pbBL1{1};
    uint8_t pbAL0{2};
    uint8_t pbBL0{2};
    uint8_t pbCL0{1};
    uint64_t pBufferFlag{0};
};
// Four indices (kAL1, kBL1, mAL1, nBL1) that all start at 0.
// NOTE(review): presumably positions into the matching L1TilingRange vectors — confirm.
struct L1TilingIdx {
    uint64_t kAL1Idx{0};
    uint64_t kBL1Idx{0};
    uint64_t mAL1Idx{0};
    uint64_t nBL1Idx{0};

    // Overwrites all four indices in a single call.
    void SetIdx(uint64_t kAl1, uint64_t kBl1, uint64_t mAL1, uint64_t nBL1)
    {
        this->nBL1Idx = nBL1;
        this->mAL1Idx = mAL1;
        this->kBL1Idx = kBl1;
        this->kAL1Idx = kAl1;
    }
};
// Lists of values for the L0 m/k/n tiling parameters; each list starts empty.
struct L0TilingRange {
    std::vector<uint64_t> mL0Range{};
    std::vector<uint64_t> kL0Range{};
    std::vector<uint64_t> nL0Range{};
};
// L0 tiling values (m, k, n) plus orgCoAlignN0; every field starts at 0.
struct L0TilingParams {
    uint64_t mL0{0};
    uint64_t kL0{0};
    uint64_t nL0{0};
    // NOTE(review): presumably the original Cout aligned up to n0 — confirm where it is set.
    uint64_t orgCoAlignN0{0};
};
// Three L0 indices, declared in m, n, k order (kept as-is); every index starts at 0.
struct L0TilingIdx {
    uint64_t mL0Idx{0};
    uint64_t nL0Idx{0};
    uint64_t kL0Idx{0};
};
// Lists of values for the L1 tiling parameters (mAL1, nBL1, kAL1, kBL1); each list starts empty.
struct L1TilingRange {
    std::vector<uint64_t> mAL1ValueRange{};
    std::vector<uint64_t> nBL1ValueRange{};
    std::vector<uint64_t> kAL1Range{};
    std::vector<uint64_t> kBL1Range{};
};
// L1 tiling values: K sizes for A and B with their tail counterparts, plus the
// mAL1/nBL1 values. Every field starts at 0.
struct L1TilingParams {
    uint64_t kAL1{0};
    uint64_t kAL1Tail{0};
    uint64_t kBL1{0};
    uint64_t kBL1Tail{0};
    uint64_t mAL1Value{0};
    uint64_t nBL1Value{0};
};
// Intermediate quantities used for the L1 tiling decision; every field starts at 0.
// NOTE(review): units (bytes vs. elements) are not visible in this header — confirm in the .cpp.
struct L1TilingCalc {
    uint64_t ci0HkWk{0};
    uint64_t alignCinKhKwKd{0};
    // Sizes for input / weight under the different load strategies.
    uint64_t inputFullLoadL1Size{0};
    uint64_t weightFullLoadL1Size{0};
    uint64_t inputKL1FullLoadSize{0};
    uint64_t weightKL1FullLoadSize{0};
    // Minimum-load sizes for input / weight / bias.
    uint64_t inputMinLoadL1Size{0};
    uint64_t weightMinLoadL1Size{0};
    uint64_t biasMinLoadL1Size{0};
};
// Tiling algorithm for Conv3D, operating on a Conv3dTilingBase instance.
// Only declarations are visible in this header; the behaviour of each step is defined in
// the implementation file. Several steps are virtual so that derived classes can replace them.
// Declaration order is kept as-is on purpose: reordering data members or virtual functions
// would change the object and vtable layout.
class Conv3dTilingAlgorithm {
public:
// Takes the tiling instance to work on.
// NOTE(review): presumably stored in tilingIns_ without taking ownership — confirm in the .cpp.
explicit Conv3dTilingAlgorithm(Conv3dTilingBase* tilingIns);
// Only resets the pointer; the pointed-to tiling instance is not deleted here.
virtual ~Conv3dTilingAlgorithm() { tilingIns_ = nullptr; }
// Entry point of the algorithm. The meaning of the int64_t result is defined by the
// implementation (NOTE(review): presumably a status code — confirm).
int64_t Process();
protected:
// L0 helpers that derived classes are allowed to call.
bool FixL0PingpongDecision();
bool CheckL0Buffer(uint64_t currmL0, uint64_t currkL0, uint64_t currnL0);
private:
// Private virtuals: derived classes may override them, but they are only invoked from
// this class's own member functions.
virtual bool CoreL1TilingMinWeightBypass() const;
virtual bool NoneKABL1FullLoadWeightBypass() const;
virtual uint64_t L1NoFullLoadInputSize() const;
// ---- L1 tiling steps ----
int64_t GetL1Tiling();
int64_t PreProcessingL1Tiling();
virtual bool CheckL1Buffer() const;
virtual int64_t InitCalcL1Params();
uint64_t InferHiL1(uint64_t inputHoL1, uint64_t hi) const;
virtual void GetL1TilingRange();
void InitL1Tiling();
void InitABL1TilingMode();
void CoreL1TilingDecision();
void InputL1FullLoadIter();
void WeightL1FullLoadIter();
void InitKL1LoadFlag();
void L1NoFullLoadIter();
void KAL1FullLoadIter();
void KBL1FullLoadIter();
uint64_t KABL1FullLoadIterN();
uint64_t KABL1FullLoadIterM();
void NoneKABL1FullLoadIter();
void BiasL1TilingDecision();
virtual void GetKL0TilingDecision();
virtual void WeightBypassDecision();
void UpdateL1DoubleBuffer();
virtual void SetKAL1KBL1TailRes();
void SetL1TilingRes();
// ---- L0 tiling steps ----
int64_t GetL0Tiling();
void InitPingPong();
void GetL0TilingRange();
void L0TilingDecision();
// Size helpers for a candidate (m, k, n) L0 tile; units are defined by the implementation.
uint64_t CalcAL0Size(uint64_t currmL0, uint64_t currkL0) const;
uint64_t CalcBL0Size(uint64_t currkL0, uint64_t currnL0) const;
uint64_t CalcCL0Size(uint64_t currmL0, uint64_t currnL0) const;
virtual uint64_t CalcL1SizeForL0Tiling(uint64_t currmL0, uint64_t currnL0) const;
virtual uint64_t CalcBTSize(uint64_t currnL0) const;
void CheckL0CDoubleBuffer();
void SetPBufferFlag();
protected:
// State shared with derived classes.
// Tiling instance this algorithm works on; reset to nullptr by the destructor, never deleted here.
Conv3dTilingBase* tilingIns_ = nullptr;
// L0 state: value ranges, values and indices.
L0TilingRange l0TilingRange;
L0TilingParams l0TilingParams;
L0TilingIdx l0TilingIdx;
// L1 state: value ranges, values, intermediate sizes, indices and decision flags.
L1TilingRange l1TilingRange;
L1TilingParams l1TilingParams;
L1TilingCalc l1TilingCalc;
L1TilingIdx l1TilingIdx;
L1TilingFlag l1TilingFlag;
DoubleBufferTilingValue doubleBufferValue;
// Maps an L1 tiling mode to a set of L1 indices.
// NOTE(review): presumably the starting indices for each mode — confirm where it is filled.
std::map<L1TilingMode, L1TilingIdx> l1TilingInitMap;
// Data-type sizes for feature map, weight and bias; all start at 0.
// NOTE(review): presumably in bytes — confirm.
uint64_t fMapDTypeSize = 0;
uint64_t weightDTypeSize = 0;
uint64_t biasDTypeSize = 0;
};
}
#endif