* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License")
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file broadcast_to_tiling_base.h
* \brief tiling data definition and tiling class (core count, tensor size, tiling key) for the BroadcastTo AscendC kernel
*/
#ifndef BROADCASTTO_TILING_NDDMA_H_
#define BROADCASTTO_TILING_NDDMA_H_
#include <array>
#include <cstdint>
#include <string>
#include "broadcast_to_tiling_arch35.h"
#include "register/tilingdata_base.h"
#include "tiling/tiling_api.h"
namespace optiling
{
// Number of dimension slots in the DMA descriptor arrays (xSrcStride / xDstStride / xSize).
constexpr size_t brctoMaxDMADimNum = 5U;
// Capacity of aAxesParams: 8 axes x 3 values per axis
// (presumably brcto::BRCTO_MAX_DIM_NUM x brcto::aParamUnit -- TODO confirm).
constexpr size_t brctoMaxADimNum = size_t{8} * size_t{3};
// Capacity of bAxesParams: 8 axes x 2 values per axis
// (presumably brcto::BRCTO_MAX_DIM_NUM x brcto::bParamUnit -- TODO confirm).
constexpr size_t brctoMaxBDimNum = size_t{8} * size_t{2};
// Tiling parameters produced by brcto::BroadcastToTilingAscendC for the BroadcastTo AscendC kernel.
// Each field has a same-named member (with a trailing underscore) in BroadcastToTilingAscendC.
// NOTE(review): field order and types define the serialized layout; presumably they must stay in sync
// with the kernel-side definition -- do not reorder or retype one side without the other.
BEGIN_TILING_DATA_DEF(BroadcastToTilingData)
TILING_DATA_FIELD_DEF(int64_t, tilingKey);    // presumably one of brcto::TILING_MODE_* -- TODO confirm
TILING_DATA_FIELD_DEF(int64_t, dFactor);
TILING_DATA_FIELD_DEF(uint8_t, doubleMode);   // host member doubleMode_ is int8_t, narrowed/converted on write
TILING_DATA_FIELD_DEF(uint8_t, uAxisCnt);
TILING_DATA_FIELD_DEF(uint8_t, bufferCnt);
TILING_DATA_FIELD_DEF(uint8_t, blockAxis);
TILING_DATA_FIELD_DEF(uint32_t, tensorSize);  // host member tensorSize_ is int64_t
TILING_DATA_FIELD_DEF(int64_t, usedCoreCnt);
// Lengths for the A / B / U axes; "ntc" / "tc" presumably mean non-tail / tail core -- TODO confirm.
TILING_DATA_FIELD_DEF(int64_t, ntcALen);
TILING_DATA_FIELD_DEF(int64_t, tcALen);
TILING_DATA_FIELD_DEF(int64_t, ntcBLen);
TILING_DATA_FIELD_DEF(int64_t, tcBLen);
TILING_DATA_FIELD_DEF(int64_t, ntcULen);
TILING_DATA_FIELD_DEF(int64_t, tcULen);
TILING_DATA_FIELD_DEF(int64_t, aLpUnit);
TILING_DATA_FIELD_DEF(int64_t, uLpUnit);
TILING_DATA_FIELD_DEF(int64_t, uInOffset);
TILING_DATA_FIELD_DEF(int64_t, uOutOffset);
// int32 flags / counts.
TILING_DATA_FIELD_DEF(int32_t, isUNotB);
TILING_DATA_FIELD_DEF(int32_t, isLastDimB);
TILING_DATA_FIELD_DEF(int32_t, aAxesNum);     // presumably number of valid entries in aAxesParams / aParamUnit
TILING_DATA_FIELD_DEF(int32_t, bAxesNum);     // presumably number of valid entries in bAxesParams / bParamUnit
// Per-dimension DMA copy descriptors, brctoMaxDMADimNum slots each.
TILING_DATA_FIELD_DEF_ARR(uint64_t, brctoMaxDMADimNum, xSrcStride);
TILING_DATA_FIELD_DEF_ARR(uint32_t, brctoMaxDMADimNum, xDstStride);
TILING_DATA_FIELD_DEF_ARR(uint32_t, brctoMaxDMADimNum, xSize);
// Flattened per-axis parameter tables for A and B axes.
TILING_DATA_FIELD_DEF_ARR(int64_t, brctoMaxADimNum, aAxesParams);
TILING_DATA_FIELD_DEF_ARR(int64_t, brctoMaxBDimNum, bAxesParams);
END_TILING_DATA_DEF;
// Binds the BroadcastTo op type to this tiling-data class.
REGISTER_TILING_DATA_CLASS(BroadcastTo, BroadcastToTilingData);
// AscendC tiling entry point for BroadcastTo.
//   context     - tiling context supplied by the framework (not owned).
//   inShapePtr  - shape being broadcast from (not owned).
//   outShapePtr - target shape to broadcast to (not owned).
// Returns a ge::graphStatus; presumably ge::GRAPH_SUCCESS on success -- see the implementation.
ge::graphStatus Tiling4BroadcastToAscendC(
    gert::TilingContext* context,
    const gert::Shape* inShapePtr,
    const gert::Shape* outShapePtr);
namespace brcto
{
// Tiling-mode values; presumably the candidates stored in the tilingKey field -- TODO confirm.
constexpr int64_t TILING_MODE_NDDMA = 11000;
constexpr int64_t TILING_MODE_UB_BRC = 11001;
constexpr int64_t TILING_MODE_LAST_DIM_LARGE_A = 11002;
constexpr int64_t TILING_MODE_LAST_DIM_LARGE_B = 11003;
constexpr int64_t TILING_MODE_FULL_NDDMA = 11004;
constexpr int64_t TILING_MODE_LAST_DIM_SMALL_A = 11005;

// Upper limit for the tensor size (0xFF00 == 65280); unit not visible here -- TODO confirm.
constexpr int64_t MAX_TENSOR_SIZE = 0xFF00;

// Dimension limits: MAX_DIM_NUM is the capacity of the per-axis A/B flag array (abInfo).
constexpr size_t MAX_DIM_NUM = 16U;
constexpr size_t BRCTO_MAX_DIM_NUM = 8U;

// Values stored per A axis / per B axis (presumably in aAxesParams / bAxesParams -- TODO confirm).
constexpr size_t aParamUnit = 3U;
constexpr size_t bParamUnit = 2U;

// Miscellaneous tuning constants.
constexpr int64_t nTwo = 2;
constexpr int64_t LAST_DIM_GATE = 8;
constexpr float coreFactor = 0.75F;

// Size constants: 16 * 1024 * 1024 and 128 * 1024 (presumably bytes -- TODO confirm).
constexpr size_t kSyncWorkSpaceSize = size_t{16} * 1024U * 1024U;
constexpr int64_t maxDataSize = int64_t{128} * 1024;
// Shape helpers used by the BroadcastTo tiling; all are defined out of line.
// Descriptions below are inferred from names/signatures only -- confirm against the implementations.

// Fills inShape / outShape (output parameters) from the tiling context.
ge::graphStatus GetShapeInfo(
    const gert::TilingContext* context,
    gert::Shape& inShape,
    gert::Shape& outShape);

// Writes one flag per axis (up to MAX_DIM_NUM) into abInfo; presumably marks A vs B (broadcast) axes.
ge::graphStatus GetABFlag(
    const gert::TilingContext* context,
    const gert::Shape& inShape,
    const gert::Shape& outShape,
    std::array<bool, MAX_DIM_NUM>& abInfo);

// Modifies inShape in place so it has outDimNum dimensions (presumably by padding with size-1 axes).
void AdjustShapesToSameDimNum(gert::Shape& inShape, size_t outDimNum);

// Rewrites both shapes in place; presumably merges adjacent axes of the same kind.
ge::graphStatus MergeAxis(
    const gert::TilingContext* context,
    gert::Shape& inShape,
    gert::Shape& outShape);

// Rewrites both shapes in place; presumably drops axes of size 1.
ge::graphStatus DeleteOneSizeAxis(
    const gert::TilingContext* context,
    gert::Shape& inShape,
    gert::Shape& outShape);
/// Computes the AscendC tiling for BroadcastTo from an input shape and a target output shape and
/// stores the result in a BroadcastToTilingData. Apart from the constructor and GetHardwareInfo,
/// all methods are defined out of line.
class BroadcastToTilingAscendC
{
public:
    /// @param context     tiling context; stored as a raw pointer, not owned -- must outlive this object.
    /// @param inShapePtr  shape being broadcast from; not owned.
    /// @param outShapePtr target shape; not owned.
    explicit BroadcastToTilingAscendC(gert::TilingContext* context, const gert::Shape* inShapePtr,
                                      const gert::Shape* outShapePtr)
        : context_(context), inShapePtr_(inShapePtr), outShapePtr_(outShapePtr)
    {
    }

    ge::graphStatus DoTiling();

    /// Loads platform parameters from the compile info (interpreted as T) and the element size of input 0.
    /// @tparam T compile-info type exposing ubSize, coreNum, blockSize, clSize and vRegSize.
    /// @return ge::GRAPH_FAILED if any platform value is non-positive or the input dtype size is
    ///         non-positive; returns early via OP_CHECK_NULL_WITH_CONTEXT if the compile info or the
    ///         input 0 descriptor is null; ge::GRAPH_SUCCESS otherwise.
    template <typename T>
    ge::graphStatus GetHardwareInfo()
    {
        auto compileInfo = reinterpret_cast<const T*>(context_->GetCompileInfo());
        OP_CHECK_NULL_WITH_CONTEXT(context_, compileInfo);
        ubSize_ = static_cast<int64_t>(compileInfo->ubSize);
        coreNum_ = static_cast<int64_t>(compileInfo->coreNum);
        blockSize_ = static_cast<int64_t>(compileInfo->blockSize);
        cacheLine_ = static_cast<int64_t>(compileInfo->clSize);
        vlSize_ = static_cast<int64_t>(compileInfo->vRegSize);
        if (coreNum_ <= 0 || ubSize_ <= 0 || blockSize_ <= 0 || cacheLine_ <= 0 || vlSize_ <= 0) {
            std::string valueMsg = "coreNum=" + std::to_string(coreNum_) + ", ubSize=" + std::to_string(ubSize_) +
                                   ", blockSize=" + std::to_string(blockSize_) +
                                   ", cacheLine=" + std::to_string(cacheLine_) +
                                   ", vlSize=" + std::to_string(vlSize_);
            std::string reasonMsg = "BroadcastTo GetHardwareInfo failed, all values must be positive.";
            OP_LOGE_FOR_INVALID_VALUE_WITH_REASON(
                context_->GetNodeName(), "hardwareInfo", valueMsg.c_str(), reasonMsg.c_str());
            return ge::GRAPH_FAILED;
        }
        // GetInputDesc can return nullptr; guard it the same way as the compile info above
        // instead of dereferencing it unconditionally.
        auto inputDesc = context_->GetInputDesc(0);
        OP_CHECK_NULL_WITH_CONTEXT(context_, inputDesc);
        auto dtype = inputDesc->GetDataType();
        dtypeSize_ = GetSizeByDataType(dtype);
        // A non-positive element size (unsupported/unknown dtype) would corrupt every later size computation.
        if (dtypeSize_ <= 0) {
            std::string valueMsg = "dtypeSize=" + std::to_string(dtypeSize_);
            std::string reasonMsg = "BroadcastTo GetHardwareInfo failed, the input dtype size must be positive.";
            OP_LOGE_FOR_INVALID_VALUE_WITH_REASON(
                context_->GetNodeName(), "dtypeSize", valueMsg.c_str(), reasonMsg.c_str());
            return ge::GRAPH_FAILED;
        }
        OP_LOGI(context_->GetNodeName(), "The ub size is: %ld", ubSize_);
        return ge::GRAPH_SUCCESS;
    }

private:
    // Tiling steps; all defined out of line.
    void CalcTilingData();
    int64_t CalcDimSize(const gert::Shape*& shapePtr, size_t begDim, size_t endDim);
    int64_t UpdateTensorSize(int64_t tensorSize);
    void GetUAxisInfo();
    void GetDMAAxesParams();
    void GetABAxesParams();
    void GetAxesInfo();
    void GetMCTilingInfo();
    void OptimizeMCTilingForDoubleMode();
    void UpdateDimSize(int64_t& aDims, int64_t& bDims, int64_t& brwAxis, int64_t& outLastDim);
    void CheckBrwd(int64_t& aDims, int64_t& bDims, int64_t& brwAxis, bool& isBrwd);
    int64_t CalcTensorSize4Brwd(int64_t aDims, int64_t bDims, int64_t brwAxis);
    int64_t CalcTensorSize4NBrwd(int64_t aDims, int64_t bDims, int64_t outLastDim);
    void AdjustBrwdSize(int64_t& brwSize, int64_t uAxis);
    void CalcTensorSize();
    void CalcTilingKey();
    void CalcDBMode();
    void UpdateTilingKey();
    ge::graphStatus WriteTilingData();
    ge::graphStatus SetBlockCnt();
    std::string PrintTilingData();
    uint32_t CalcAxisWeight(int64_t lpCnt);

private:
    // Inputs (non-owning).
    gert::TilingContext* context_ = nullptr;
    const gert::Shape* inShapePtr_ = nullptr;
    const gert::Shape* outShapePtr_ = nullptr;
    BroadcastToTilingData tilingData_;

    // Platform parameters filled by GetHardwareInfo.
    int64_t coreNum_{0};
    int64_t ubSize_{0};
    int64_t vlSize_{0};
    int64_t cacheLine_{0};
    int64_t blockSize_{0};

    // Intermediate state of the tiling computation.
    int64_t aAxisLen_{1};
    int64_t bAxisLen_{1};
    int64_t uAxisLen_{1};
    int64_t dtypeSize_{1};      // element size of input 0, set by GetHardwareInfo
    size_t uAxis_{0};
    bool isDMABrcA_{false};
    std::array<bool, MAX_DIM_NUM> abInfo_{};  // all false
    int64_t maxTensorSize_{0};
    int64_t minTensorSize_{0};

    // Tiling results; these mirror the same-named BroadcastToTilingData fields
    // (presumably copied over by WriteTilingData).
    int64_t tilingKey_{0};
    int64_t dFactor_{1};
    int8_t doubleMode_{0};
    int8_t blockAxis_{0};
    int8_t uAxisCnt_{1};
    int8_t bufferCnt_{nTwo};
    int64_t tensorSize_{0};
    int64_t usedCoreCnt_{0};
    int64_t ntcALen_{1};
    int64_t tcALen_{0};
    int64_t ntcBLen_{1};
    int64_t tcBLen_{0};
    int64_t ntcULen_{1};
    int64_t tcULen_{0};
    int64_t aLpUnit_{1};
    int64_t uLpUnit_{1};
    int64_t uInOffset_{0};
    int64_t uOutOffset_{0};
    int32_t isUNotB_{1};
    int32_t isLastDimB_{0};
    int32_t aAxesNum_{0};
    int32_t bAxesNum_{0};
    uint64_t xSrcStride_[brctoMaxDMADimNum]{0};
    uint32_t xDstStride_[brctoMaxDMADimNum]{0};
    uint32_t xSize_[brctoMaxDMADimNum]{1, 1, 1, 1, 1};
    int64_t aAxesParams_[brctoMaxADimNum]{0};
    int64_t bAxesParams_[brctoMaxBDimNum]{0};
};
}
}
#endif