/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file attention_update_tiling.cpp
* \brief Tiling implementation of the AttentionUpdate operator.
*/
#include "attention_update_tiling.h"
namespace optiling {
using namespace Ops::Base;
// Axis indices into a tensor shape.
constexpr uint64_t DIM_0{0};
constexpr uint64_t DIM_1{1};
// Expected ranks: every lse tensor is 1-D, every go tensor is 2-D.
constexpr uint64_t LSE_DIM_NUM{1};
constexpr uint64_t GO_DIM_NUM{2};
// The go H dim (DIM_1) must lie in [D_MIN, D_MAX] and be a multiple of D_DIVIDE_8.
constexpr uint64_t D_MIN{8};
constexpr uint64_t D_MAX{512};
constexpr uint64_t D_DIVIDE_8{8};
// Largest accepted value of the sp attr.
constexpr uint64_t ATTR_SP_MAX{16};
// Inputs are sp lse tensors followed by sp go tensors; the first go tensor sits at INPUT_GO_INDEX * sp.
constexpr uint64_t INPUT_LSE_INDEX{0};
constexpr uint64_t INPUT_GO_INDEX{1};
// Attr positions.
constexpr uint64_t ATTR_UPDATE_TYPE_INDEX{0};
constexpr uint64_t ATTR_SP_INDEX{1};
// Output positions.
constexpr uint64_t OUTPUT_INDEX{0};
constexpr uint64_t OUTPUT_LSE_M_INDEX{1};
// Total input count is 2 * sp (one lse list plus one go list).
constexpr uint64_t ALL_TO_SP_MULTIPLIER{2UL};
constexpr uint64_t NUM_2{2UL};
// Tiling keys: a dedicated key for empty input; otherwise the base value plus update_type.
constexpr uint64_t TILING_KEY_EMPTY{10000UL};
constexpr uint64_t TILING_KEY_INIT_VALUE{20000UL};
// Buffer multiplier applied to the UB footprint of streamed tensors.
constexpr uint64_t DOUBLE_BUFFER_NUM{2UL};
// Workspace requested in GetWorkspaceSize: 16 MiB.
constexpr uint64_t SYS_WORKSPACE_SIZE{16UL * 1024UL * 1024UL};
/**
 * @brief Reports whether this tiling template can handle the current node.
 *
 * Input validation is performed in GetShapeAttrsInfo (CheckInputParams /
 * CheckOutputParams), so no additional constraint is evaluated here.
 *
 * @return Always true.
 */
bool AttentionUpdateTiling::IsCapable()
{
    const bool capable = true;
    return capable;
}
/**
 * @brief Library-API tiling hook. This template computes no library-API tiling,
 *        so the hook does nothing.
 *
 * @return Always ge::GRAPH_SUCCESS.
 */
ge::graphStatus AttentionUpdateTiling::DoLibApiTiling()
{
    const ge::graphStatus status = ge::GRAPH_SUCCESS;
    return status;
}
/**
 * @brief Collects the hardware parameters consumed by DoOpTiling: the UB block size,
 *        the number of vector (AIV) cores and the UB capacity.
 *
 * When the context carries no platform info, the core count and UB size are taken
 * from the compile info (DecodeUpdateCompileInfo) instead.
 *
 * @return ge::GRAPH_SUCCESS on success; ge::GRAPH_FAILED when the compile info is
 *         missing or any collected size is zero.
 */
ge::graphStatus AttentionUpdateTiling::GetPlatformInfo()
{
    ubBlockSize_ = Ops::Base::GetUbBlockSize(context_);
    auto platformInfo = context_->GetPlatformInfo();
    if (platformInfo == nullptr) {
        auto compileInfoPtr = reinterpret_cast<const DecodeUpdateCompileInfo *>(context_->GetCompileInfo());
        OP_CHECK_IF(compileInfoPtr == nullptr, OP_LOGE(context_, "compile info is null"), return ge::GRAPH_FAILED);
        totalCoreNum_ = compileInfoPtr->coreNum;
        ubSize_ = compileInfoPtr->ubSize;
    } else {
        auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);
        totalCoreNum_ = static_cast<uint64_t>(ascendcPlatform.GetCoreNumAiv());
        uint64_t ubSize = 0;
        ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
        ubSize_ = ubSize;
    }
    // Validate on both paths. Previously the compile-info fallback skipped these checks,
    // letting a zero core count / UB size reach the divisions in DoOpTiling. The UB block
    // size is also used as a divisor there (via ubBlockSize_ / goDtypeSize_ and CeilAlign).
    if (totalCoreNum_ == 0) {
        OP_LOGE(context_->GetNodeName(), "coreNum is 0");
        return ge::GRAPH_FAILED;
    }
    if (ubSize_ == 0) {
        OP_LOGE(context_->GetNodeName(), "ubSize is 0");
        return ge::GRAPH_FAILED;
    }
    if (ubBlockSize_ == 0) {
        OP_LOGE(context_->GetNodeName(), "ubBlockSize is 0");
        return ge::GRAPH_FAILED;
    }
    return ge::GRAPH_SUCCESS;
}
/**
 * @brief Validates the dtypes of all lse and go input tensors.
 *
 * go[0] must be FLOAT, FLOAT16 or BF16, and lse[0] must be FLOAT. Every other tensor
 * must match the dtype of the first tensor of its own list. Inputs [0, sp_) are the
 * lse tensors and inputs [sp_, 2 * sp_) are the go tensors.
 *
 * @return ge::GRAPH_SUCCESS when all dtypes are valid, ge::GRAPH_FAILED otherwise.
 */
ge::graphStatus AttentionUpdateTiling::CheckInputDtype()
{
    const bool goDtypeSupported = goType_ == ge::DataType::DT_FLOAT || goType_ == ge::DataType::DT_FLOAT16 ||
                                  goType_ == ge::DataType::DT_BF16;
    if (!goDtypeSupported) {
        const std::string goDtypeStr = ToString(goType_);
        OP_LOGE_FOR_INVALID_DTYPE(context_->GetNodeName(), "go", goDtypeStr.c_str(), "FLOAT, FLOAT16 or BF16");
        return ge::GRAPH_FAILED;
    }
    if (lseType_ != ge::DataType::DT_FLOAT) {
        const std::string lseDtypeStr = ToString(lseType_);
        OP_LOGE_FOR_INVALID_DTYPE(context_->GetNodeName(), "lse", lseDtypeStr.c_str(), "FLOAT");
        return ge::GRAPH_FAILED;
    }

    const uint64_t inputCount = sp_ * ALL_TO_SP_MULTIPLIER;
    for (uint64_t idx = 0; idx < inputCount; ++idx) {
        const bool isGo = idx >= sp_;
        // Position of this tensor inside its own list (lse or go).
        const uint64_t localIdx = isGo ? idx - sp_ : idx;
        if (localIdx == 0UL) {
            continue;  // The first tensor of each list is the reference dtype.
        }
        const auto actualDtype = context_->GetInputDesc(idx)->GetDataType();
        const auto expectedDtype = isGo ? goType_ : lseType_;
        if (actualDtype == expectedDtype) {
            continue;
        }
        const std::string paramMsg = (isGo ? "go[" : "lse[") + std::to_string(localIdx) + "] tensor";
        const std::string listMsg = isGo ? "All tensors in input go must have the same dtype, "
                                         : "All tensors in the lse must have the same dtype, ";
        const std::string reasonMsg = listMsg + "but the dtype of " + paramMsg +
                                      " is different from other dtypes " + ToString(expectedDtype);
        const std::string actualDtypeStr = ToString(actualDtype);
        OP_LOGE_FOR_INVALID_DTYPE_WITH_REASON(context_->GetNodeName(), paramMsg.c_str(),
            actualDtypeStr.c_str(), reasonMsg.c_str());
        return ge::GRAPH_FAILED;
    }
    return ge::GRAPH_SUCCESS;
}
/**
 * @brief Validates ranks and dims of all lse and go input tensors.
 *
 * Checks, in order, for each input i in [0, 2 * sp_):
 *  - lse tensors (i < sp_) are 1-D, go tensors (i >= sp_) are 2-D;
 *  - every go tensor has the same H dim (dim 1) as go[0];
 *  - every tensor has the same batch dim (dim 0) as go[0].
 * Finally, the H dim of go[0] must be in [D_MIN, D_MAX] and a multiple of D_DIVIDE_8.
 *
 * The rank of each tensor is validated before any of its dims is read. The H-range
 * check runs after the structural checks, so a malformed go[0] is reported as a rank
 * error rather than as an out-of-range H dim.
 *
 * @return ge::GRAPH_SUCCESS when all shapes are valid, ge::GRAPH_FAILED otherwise.
 */
ge::graphStatus AttentionUpdateTiling::CheckInputDim()
{
    for (uint64_t i = 0; i < ALL_TO_SP_MULTIPLIER * sp_; i++) {
        const auto &currentShape = context_->GetInputShape(i)->GetOriginShape();
        const uint64_t dimNum = currentShape.GetDimNum();
        const bool isGo = i >= sp_;
        const std::string paramMsg =
            isGo ? ("go[" + std::to_string(i - sp_) + "] tensor") : ("lse[" + std::to_string(i) + "] tensor");
        if (isGo) {
            if (dimNum != GO_DIM_NUM) {
                std::string dimNumStr = std::to_string(dimNum);
                OP_LOGE_FOR_INVALID_SHAPEDIM(context_->GetNodeName(), paramMsg.c_str(), dimNumStr.c_str(), "2D");
                return ge::GRAPH_FAILED;
            }
            // Safe to read dim 1 only after the rank is known to be 2.
            const uint64_t currentD = currentShape.GetDim(DIM_1);
            if (currentD != d_) {
                std::string currentShapeStr = ToString(currentShape);
                std::string reasonMsg = "The H dims of all go tensors must be the same, "
                                        "but the H dim of " + paramMsg + " is different from other H dims " +
                                        ToString(goShape_) + ", where H refers to the 1st dim";
                OP_LOGE_FOR_INVALID_SHAPE_WITH_REASON(context_->GetNodeName(), paramMsg.c_str(),
                    currentShapeStr.c_str(), reasonMsg.c_str());
                return ge::GRAPH_FAILED;
            }
        } else if (dimNum != LSE_DIM_NUM) {
            std::string dimNumStr = std::to_string(dimNum);
            OP_LOGE_FOR_INVALID_SHAPEDIM(context_->GetNodeName(), paramMsg.c_str(), dimNumStr.c_str(), "1D");
            return ge::GRAPH_FAILED;
        }
        // Rank is validated (>= 1) for both lists, so dim 0 exists.
        const uint64_t currentBshSize = currentShape.GetDim(DIM_0);
        if (currentBshSize != bshSize_) {
            std::string currentShapeStr = ToString(currentShape);
            std::string reasonMsg = "The batch axis of all go and lse tensors should be the same, "
                                    "but the batch axis of " + paramMsg + " is different from other batch axes " +
                                    ToString(goShape_) + ", where batch refers to the 0th dim";
            OP_LOGE_FOR_INVALID_SHAPE_WITH_REASON(context_->GetNodeName(), paramMsg.c_str(),
                currentShapeStr.c_str(), reasonMsg.c_str());
            return ge::GRAPH_FAILED;
        }
    }
    // d_ comes from go[0]; its rank has been validated by the loop above.
    if (!(d_ >= D_MIN && d_ <= D_MAX && d_ % D_DIVIDE_8 == 0)) {
        std::string shapeStr = ToString(goShape_);
        OP_LOGE_FOR_INVALID_SHAPE_WITH_REASON(context_->GetNodeName(), "go", shapeStr.c_str(),
            "The H dim of input go should be in the range of [8, 512] and divisible by 8, "
            "where H refers to the 1st dim");
        return ge::GRAPH_FAILED;
    }
    return ge::GRAPH_SUCCESS;
}
/**
 * @brief Validates the scalar attrs and all input tensors.
 *
 * The scalar attrs are checked first. sp_ bounds the per-tensor loops in
 * CheckInputDtype and CheckInputDim, so it must be known to be within
 * [1, ATTR_SP_MAX] before those loops run.
 *
 * @return ge::GRAPH_SUCCESS when everything is valid, ge::GRAPH_FAILED otherwise.
 */
ge::graphStatus AttentionUpdateTiling::CheckInputParams()
{
    if (sp_ == 0 || sp_ > ATTR_SP_MAX) {
        std::string spStr = std::to_string(sp_);
        OP_LOGE_FOR_INVALID_VALUE(context_->GetNodeName(), "sp", spStr.c_str(), "in the range of [1, 16]");
        return ge::GRAPH_FAILED;
    }
    if (updateType_ != 0 && updateType_ != 1) {
        std::string updateTypeStr = std::to_string(updateType_);
        OP_LOGE_FOR_INVALID_VALUE(context_->GetNodeName(), "update_type", updateTypeStr.c_str(), "0 or 1");
        return ge::GRAPH_FAILED;
    }
    if (CheckInputDtype() != ge::GRAPH_SUCCESS) {
        return ge::GRAPH_FAILED;
    }
    if (CheckInputDim() != ge::GRAPH_SUCCESS) {
        return ge::GRAPH_FAILED;
    }
    return ge::GRAPH_SUCCESS;
}
/**
 * @brief Validates output dtypes.
 *
 * The output (OUTPUT_INDEX) must share go's dtype. When update_type is 1, the
 * lse_m output (OUTPUT_LSE_M_INDEX) is also checked and must be FLOAT; for any
 * other update_type it is not inspected.
 *
 * @return ge::GRAPH_SUCCESS when the outputs are valid, ge::GRAPH_FAILED otherwise.
 */
ge::graphStatus AttentionUpdateTiling::CheckOutputParams()
{
    auto outputDesc = context_->GetOutputDesc(OUTPUT_INDEX);
    OP_CHECK_NULL_WITH_CONTEXT(context_, outputDesc);
    const auto outDtype = outputDesc->GetDataType();
    if (outDtype != goType_) {
        const std::string dtypePair = ToString(goType_) + " and " + ToString(outDtype);
        OP_LOGE_FOR_INVALID_DTYPES_WITH_REASON(context_->GetNodeName(), "go and output", dtypePair.c_str(),
            "The dtypes of parameter go and parameter output should be the same");
        return ge::GRAPH_FAILED;
    }

    // lse_m is only validated for update_type == 1.
    if (updateType_ != 1) {
        return ge::GRAPH_SUCCESS;
    }

    auto outputLseMDesc = context_->GetOutputDesc(OUTPUT_LSE_M_INDEX);
    OP_CHECK_NULL_WITH_CONTEXT(context_, outputLseMDesc);
    const auto lseMDtype = outputLseMDesc->GetDataType();
    if (lseMDtype == ge::DataType::DT_FLOAT) {
        return ge::GRAPH_SUCCESS;
    }
    const std::string lseMDtypeStr = ToString(lseMDtype);
    OP_LOGE_FOR_INVALID_DTYPE_WITH_REASON(context_->GetNodeName(), "lse_m", lseMDtypeStr.c_str(),
        "The dtype of output lse_m should be FLOAT when the attr update_type is 1");
    return ge::GRAPH_FAILED;
}
/**
 * @brief Reads attrs, input shapes and dtypes from the tiling context and validates them.
 *
 * Inputs are laid out as sp lse tensors (indices [0, sp)) followed by sp go tensors
 * (indices [sp, 2 * sp)). lse[0] and go[0] supply the reference shapes/dtypes
 * (lseShape_, goShape_, lseType_, goType_). d_ and bshSize_ are taken from dims 1
 * and 0 of go[0].
 *
 * @return ge::GRAPH_SUCCESS on success, ge::GRAPH_FAILED on any missing or invalid
 *         attr, tensor, shape or dtype.
 */
ge::graphStatus AttentionUpdateTiling::GetShapeAttrsInfo()
{
    OP_CHECK_IF(context_ == nullptr, OP_LOGE("AttentionUpdate", "context is null"), return ge::GRAPH_FAILED);
    auto attrs = context_->GetAttrs();
    OP_CHECK_NULL_WITH_CONTEXT(context_, attrs);
    const int64_t *spPtr = attrs->GetAttrPointer<int64_t>(ATTR_SP_INDEX);
    OP_CHECK_IF(spPtr == nullptr, OP_LOGE(context_->GetNodeName(), "spPtr is null"), return ge::GRAPH_FAILED);
    // Range-check sp before using it to size the input loops. Otherwise, sp == 0 leaves
    // the null-check loops below empty, and GetInputShape(0)/GetInputDesc(0) would be
    // dereferenced unchecked. A negative sp would wrap to a huge uint64 value.
    if (*spPtr < 1 || *spPtr > static_cast<int64_t>(ATTR_SP_MAX)) {
        std::string spStr = std::to_string(*spPtr);
        OP_LOGE_FOR_INVALID_VALUE(context_->GetNodeName(), "sp", spStr.c_str(), "in the range of [1, 16]");
        return ge::GRAPH_FAILED;
    }
    sp_ = static_cast<uint64_t>(*spPtr);

    uint32_t allTensorCount = context_->GetComputeNodeInputNum();
    if (allTensorCount != sp_ * NUM_2) {
        std::string tensorCountStr = std::to_string(static_cast<int64_t>(allTensorCount));
        std::string reasonMsg =
            "The number of tensors in input lse and go should be twice the attr sp, where sp is "
            + std::to_string(sp_);
        OP_LOGE_FOR_INVALID_TENSORNUMS_WITH_REASON(context_->GetNodeName(), "lse and go",
            tensorCountStr.c_str(), reasonMsg.c_str());
        return ge::GRAPH_FAILED;
    }
    // Every input must expose both a shape and a descriptor; later checks index them freely.
    for (uint64_t i = 0; i < NUM_2 * sp_; i++) {
        OP_CHECK_NULL_WITH_CONTEXT(context_, context_->GetInputShape(i));
        OP_CHECK_NULL_WITH_CONTEXT(context_, context_->GetInputDesc(i));
    }
    goShape_ = context_->GetInputShape(INPUT_GO_INDEX * sp_)->GetOriginShape();
    lseShape_ = context_->GetInputShape(INPUT_LSE_INDEX)->GetOriginShape();
    goType_ = context_->GetInputDesc(INPUT_GO_INDEX * sp_)->GetDataType();
    lseType_ = context_->GetInputDesc(INPUT_LSE_INDEX)->GetDataType();
    d_ = goShape_.GetDim(DIM_1);
    bshSize_ = goShape_.GetDim(DIM_0);

    const int64_t *updateTypePtr = attrs->GetAttrPointer<int64_t>(ATTR_UPDATE_TYPE_INDEX);
    OP_CHECK_IF(updateTypePtr == nullptr, OP_LOGE(context_->GetNodeName(), "updateTypePtr is null"),
        return ge::GRAPH_FAILED);
    updateType_ = *updateTypePtr;

    goDtypeSize_ = GetSizeByDataType(goType_);
    if (goDtypeSize_ == 0) {
        std::string dtypeStr = Ops::Base::ToString(goType_);
        OP_LOGE(context_->GetNodeName(), "UpdateAttention get go dtype[%s] size is 0.", dtypeStr.c_str());
        return ge::GRAPH_FAILED;
    }
    OP_CHECK_IF(CheckInputParams() != ge::GRAPH_SUCCESS,
        OP_LOGE(context_->GetNodeName(), "AttentionUpdate CheckInputParams FAILED."), return ge::GRAPH_FAILED);
    OP_CHECK_IF(CheckOutputParams() != ge::GRAPH_SUCCESS,
        OP_LOGE(context_->GetNodeName(), "AttentionUpdate CheckOutputParams FAILED."), return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}
/**
 * @brief Computes the multi-core split and the per-loop UB tiling.
 *
 * The bshSize_ rows are spread over at most totalCoreNum_ cores. Every used core gets
 * perCoreCount_ rows, except the last, which gets lastCoreCount_.
 *
 * Within a core, rows are consumed perCorePerLoopCount_ at a time (bounded by the lse
 * buffers and floor-aligned to whole UB blocks of float). bshInLoop_
 * (<= perCorePerLoopCount_) bounds how many go rows fit in UB together with their
 * output; presumably this is the kernel's inner step — TODO confirm against the kernel.
 *
 * @return ge::GRAPH_SUCCESS, or ge::GRAPH_FAILED when UB cannot hold the reserved blocks
 *         or, for non-empty input, a single row.
 */
ge::graphStatus AttentionUpdateTiling::DoOpTiling()
{
    perCoreCount_ = Ops::Base::CeilDiv(bshSize_, totalCoreNum_);
    usedCoreNum_ = Ops::Base::CeilDiv(bshSize_, perCoreCount_);
    lastCoreCount_ = bshSize_ - (usedCoreNum_ - 1) * perCoreCount_;

    // One UB block per sp is held back. Work on a local copy instead of shrinking
    // ubSize_ in place: the member keeps the platform value, and repeated calls do not
    // reserve the blocks twice. Guard against unsigned underflow.
    const uint64_t reservedUb = sp_ * ubBlockSize_;
    if (ubSize_ < reservedUb) {
        OP_LOGE(context_->GetNodeName(), "ubSize %lu is smaller than the reserved size %lu.",
            static_cast<uint64_t>(ubSize_), reservedUb);
        return ge::GRAPH_FAILED;
    }
    const uint64_t availableUb = ubSize_ - reservedUb;

    const uint64_t goBlockNum = ubBlockSize_ / goDtypeSize_;
    const uint64_t dAlign = Ops::Base::CeilAlign(d_, goBlockNum);
    // Per-row byte costs (all unsigned to avoid mixing int64_t and uint64_t):
    //  goRowBytes   - one block-aligned go/output row, double buffered;
    //  lseInBytes   - sp lse input values, double buffered;
    //  lseCalcBytes - sp float values of scratch space;
    //  lseOutBytes  - one float output value, double buffered.
    const uint64_t goRowBytes = dAlign * goDtypeSize_ * DOUBLE_BUFFER_NUM;
    const uint64_t lseInBytes = sp_ * sizeof(float) * DOUBLE_BUFFER_NUM;
    const uint64_t lseCalcBytes = sp_ * sizeof(float);
    const uint64_t lseOutBytes = sizeof(float) * DOUBLE_BUFFER_NUM;
    const uint64_t inputFactor = lseInBytes + sp_ * goRowBytes;
    const uint64_t calcFactor = lseCalcBytes;
    const uint64_t outputFactor = lseOutBytes + goRowBytes;
    bshInLoop_ = availableUb / (inputFactor + calcFactor + outputFactor);

    // Whatever UB is left after the go input/output rows of bshInLoop_ rows goes to lse buffers.
    uint64_t bshPerLoop = (availableUb - (sp_ + 1) * bshInLoop_ * goRowBytes) /
                          (lseInBytes + lseCalcBytes + lseOutBytes);
    const uint64_t lseUbBlockCount = ubBlockSize_ / sizeof(float);
    if (bshPerLoop > lseUbBlockCount) {
        bshPerLoop = Ops::Base::FloorAlign(bshPerLoop, lseUbBlockCount);
    }
    bshInLoop_ = std::min(bshInLoop_, bshPerLoop);
    // A zero step would make the kernel process no rows per loop for a non-empty input.
    if (bshSize_ != 0 && bshInLoop_ == 0) {
        OP_LOGE(context_->GetNodeName(), "UB size %lu is too small to hold one row (sp=%lu, d=%lu).",
            availableUb, static_cast<uint64_t>(sp_), static_cast<uint64_t>(d_));
        return ge::GRAPH_FAILED;
    }

    perCorePerLoopCount_ = bshPerLoop;
    perCoreLoops_ = Ops::Base::CeilDiv(perCoreCount_, bshPerLoop);
    lastCoreLoops_ = Ops::Base::CeilDiv(lastCoreCount_, bshPerLoop);
    perCoreLastLoopCount_ = perCoreCount_ - (perCoreLoops_ - 1) * bshPerLoop;
    lastCoreLastLoopCount_ = lastCoreCount_ - (lastCoreLoops_ - 1) * bshPerLoop;
    // An empty input still launches one core.
    usedCoreNum_ = std::max(usedCoreNum_, static_cast<uint64_t>(1));
    return ge::GRAPH_SUCCESS;
}
/**
 * @brief Selects the kernel tiling key.
 *
 * @return TILING_KEY_EMPTY when the batch size is zero; otherwise
 *         TILING_KEY_INIT_VALUE offset by update_type.
 */
uint64_t AttentionUpdateTiling::GetTilingKey() const
{
    if (bshSize_ == 0) {
        return TILING_KEY_EMPTY;
    }
    return TILING_KEY_INIT_VALUE + updateType_;
}
/**
 * @brief Records the workspace requirement: the fixed SYS_WORKSPACE_SIZE (16 MiB).
 *        No extra operator-specific workspace is added.
 *
 * @return Always ge::GRAPH_SUCCESS.
 */
ge::graphStatus AttentionUpdateTiling::GetWorkspaceSize()
{
    const uint64_t requiredBytes = SYS_WORKSPACE_SIZE;
    workspaceSize_ = requiredBytes;
    return ge::GRAPH_SUCCESS;
}
/**
 * @brief Publishes the tiling result to the context.
 *
 * Steps:
 *  - sets the single workspace size;
 *  - sets the block dim to usedCoreNum_;
 *  - fills tilingData_;
 *  - serialises tilingData_ into the raw tiling buffer.
 *
 * @return ge::GRAPH_SUCCESS, or ge::GRAPH_FAILED when the workspace array or raw tiling
 *         buffer is missing, or the tiling data does not fit into the buffer.
 */
ge::graphStatus AttentionUpdateTiling::PostTiling()
{
    auto workspaces = context_->GetWorkspaceSizes(1);
    OP_CHECK_NULL_WITH_CONTEXT(context_, workspaces);
    workspaces[0] = workspaceSize_;
    context_->SetBlockDim(usedCoreNum_);

    tilingData_.set_sp(sp_);
    tilingData_.set_d(d_);
    tilingData_.set_usedCoreNum(usedCoreNum_);
    tilingData_.set_perCoreCount(perCoreCount_);
    tilingData_.set_lastCoreCount(lastCoreCount_);
    tilingData_.set_perCoreLoops(perCoreLoops_);
    tilingData_.set_lastCoreLoops(lastCoreLoops_);
    tilingData_.set_perCorePerLoopCount(perCorePerLoopCount_);
    tilingData_.set_perCoreLastLoopCount(perCoreLastLoopCount_);
    tilingData_.set_lastCoreLastLoopCount(lastCoreLastLoopCount_);
    tilingData_.set_bshInLoop(bshInLoop_);

    // The raw buffer was previously dereferenced three times without a null check.
    auto rawTilingData = context_->GetRawTilingData();
    OP_CHECK_NULL_WITH_CONTEXT(context_, rawTilingData);
    const size_t dataSize = tilingData_.GetDataSize();
    const size_t capacity = rawTilingData->GetCapacity();
    if (dataSize > capacity) {
        OP_LOGE(context_->GetNodeName(), "tiling data size %zu exceeds raw tiling buffer capacity %zu.",
            dataSize, capacity);
        return ge::GRAPH_FAILED;
    }
    tilingData_.SaveToBuffer(rawTilingData->GetData(), capacity);
    rawTilingData->SetDataSize(dataSize);
    return ge::GRAPH_SUCCESS;
}
/**
 * @brief Logs every computed tiling field at info level, one "name: value" per line.
 *
 * Uses '\n' rather than std::endl: flushing an ostringstream is unnecessary, and the
 * produced text is the same.
 */
void AttentionUpdateTiling::DumpTilingInfo()
{
    std::ostringstream oss;
    oss << "sp: " << sp_ << '\n'
        << "d: " << d_ << '\n'
        << "usedCoreNum: " << usedCoreNum_ << '\n'
        << "perCoreCount: " << perCoreCount_ << '\n'
        << "lastCoreCount: " << lastCoreCount_ << '\n'
        << "perCoreLoops: " << perCoreLoops_ << '\n'
        << "lastCoreLoops: " << lastCoreLoops_ << '\n'
        << "perCorePerLoopCount: " << perCorePerLoopCount_ << '\n'
        << "perCoreLastLoopCount: " << perCoreLastLoopCount_ << '\n'
        << "lastCoreLastLoopCount: " << lastCoreLastLoopCount_ << '\n'
        << "bshInLoop: " << bshInLoop_ << '\n';
    const std::string dump = oss.str();
    OP_LOGI(context_->GetNodeName(), "%s", dump.c_str());
}
// Register AttentionUpdateTiling as a tiling template for the AttentionUpdate op.
// NOTE(review): the trailing argument (1) is presumably the template priority — confirm against the macro definition.
REGISTER_OPS_TILING_TEMPLATE(AttentionUpdate, AttentionUpdateTiling, 1);
}