* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file decode_update_tiling.cpp
* \brief
*/
#include <algorithm>
#include <cstdint>
#include "register/op_def_registry.h"
#include "platform/platform_info.h"
#include "util/shape_util.h"
#include "log/log.h"
#include "tiling/tiling_api.h"
#include "op_host/tiling_type.h"
#include "op_host/tiling_util.h"
#include "decode_update_tiling.h"
#include "op_host/tiling_templates_registry.h"
#include "attention_update_tiling.h"
namespace {
// Tiling keys returned by GetTilingKey for DT_FLOAT / DT_FLOAT16 / DT_BF16 respectively.
constexpr uint64_t TILING_KEY_GO_FP32_WITHOUT_MAX_OUT = 10010UL;
constexpr uint64_t TILING_KEY_GO_FP16_WITHOUT_MAX_OUT = 10020UL;
constexpr uint64_t TILING_KEY_GO_BF16_WITHOUT_MAX_OUT = 10030UL;

// Size in bytes (16 MiB) written to workspace slot 0 by DecodeUpdateTiling.
constexpr uint64_t SYS_WORKSPACE_SIZE = 16UL * 1024UL * 1024UL;

// Dimension of input 0's storage shape that DecodeUpdateTiling reads as totalLength.
constexpr uint32_t LSE_TOTAL_LENGTH_DIM = 0U;
// Dimension of input `sp`'s storage shape that DecodeUpdateTiling reads as hd.
constexpr uint32_t IN_HD_DIM = 1U;
// Multiplied by `sp` to obtain the index of the input whose data type selects the tiling key.
constexpr uint32_t LOCAL_OUT_IDNEX = 1U;

// Positions of the update-type and sp attributes in the operator's attribute list.
constexpr uint32_t ATTR_UPDATETYPE_INDEX = 0U;
constexpr uint32_t ATTR_SP_INDEX = 1U;

// Divisor used when deriving the block count: ceil(totalLength / MIN_BLOCK_LENGTH), capped by core count.
constexpr uint32_t MIN_BLOCK_LENGTH = 3U;
}  // namespace
namespace optiling {
/**
 * @brief Map the given data type to its tiling key.
 *
 * @param sp       Not used by the key selection; kept so existing callers keep compiling.
 * @param goType_  Data type to dispatch on.
 * @return The FP32 / FP16 / BF16 tiling key for DT_FLOAT / DT_FLOAT16 / DT_BF16,
 *         or 0 for any other data type.
 */
uint64_t GetTilingKey(uint32_t sp, ge::DataType goType_)
{
    switch (goType_) {
        case ge::DataType::DT_FLOAT:
            return TILING_KEY_GO_FP32_WITHOUT_MAX_OUT;
        case ge::DataType::DT_FLOAT16:
            return TILING_KEY_GO_FP16_WITHOUT_MAX_OUT;
        case ge::DataType::DT_BF16:
            return TILING_KEY_GO_BF16_WITHOUT_MAX_OUT;
        default:
            // No key is defined for other data types.
            return 0UL;
    }
}
/**
 * @brief Tiling function registered for the AttentionUpdate operator.
 *
 * If the compile info has is_ascendc set, tiling is delegated to the tiling template
 * registry. Otherwise this function:
 *   - reads the update-type and sp attributes (attribute indices 0 and 1);
 *   - selects the tiling key from the data type of input (LOCAL_OUT_IDNEX * sp);
 *   - reads totalLength from dim 0 of input 0 and hd from dim 1 of input sp;
 *   - derives blockDims = min(ceil(totalLength / MIN_BLOCK_LENGTH), coreNum) and fills
 *     formerNum = totalLength % blockDims, tailNum = blockDims - formerNum,
 *     tailLength = totalLength / blockDims, formerLength = tailLength + 1
 *     (0 when formerNum is 0);
 *   - stores the block dim, the workspace size and the serialized tiling data in the context.
 *
 * @param context Tiling context supplied by the framework.
 * @return ge::GRAPH_SUCCESS on success; ge::GRAPH_FAILED if the context or any required
 *         attribute, input description, shape or buffer is missing, if sp / totalLength / hd
 *         is out of range, or if the data type has no tiling key.
 */
ge::graphStatus DecodeUpdateTiling(gert::TilingContext *context) {
    OP_CHECK_IF(context == nullptr, OP_LOGE("AttentionUpdate", "context is null"),
                return ge::GRAPH_FAILED);
    auto compileInfo = reinterpret_cast<const DecodeUpdateCompileInfo*>(context->GetCompileInfo());
    OP_CHECK_IF(compileInfo == nullptr, OP_LOGE("AttentionUpdate", "compileInfo is null"),
                return ge::GRAPH_FAILED);
    if (compileInfo->is_ascendc) {
        return Ops::Transformer::OpTiling::TilingRegistry::GetInstance().DoTilingImpl(context);
    }

    auto nodeName = context->GetNodeName();
    OP_LOGD(nodeName, "Tiling initing");

    // Attributes: both pointers must be valid before they are dereferenced.
    auto attrs = context->GetAttrs();
    OP_CHECK_IF(attrs == nullptr, OP_LOGE(nodeName, "attrs is null"),
                return ge::GRAPH_FAILED);
    const int64_t* updateTypePtr = attrs->GetAttrPointer<int64_t>(ATTR_UPDATETYPE_INDEX);
    const int64_t* spPtr = attrs->GetAttrPointer<int64_t>(ATTR_SP_INDEX);
    OP_CHECK_IF(updateTypePtr == nullptr || spPtr == nullptr,
                OP_LOGE(nodeName, "attr updateType or sp is null"),
                return ge::GRAPH_FAILED);
    const int64_t updateType = *updateTypePtr;
    const int64_t sp = *spPtr;
    // sp is used as an input index below, so it must be positive.
    OP_CHECK_IF(sp <= 0, OP_LOGE(nodeName, "sp must be positive, but got %ld", sp),
                return ge::GRAPH_FAILED);

    // The tiling key is chosen from the data type of input (LOCAL_OUT_IDNEX * sp).
    const size_t localOutIndex = static_cast<size_t>(LOCAL_OUT_IDNEX) * static_cast<size_t>(sp);
    auto localOutDesc = context->GetInputDesc(localOutIndex);
    OP_CHECK_IF(localOutDesc == nullptr,
                OP_LOGE(nodeName, "input desc is null, sp is %ld", sp),
                return ge::GRAPH_FAILED);
    const ge::DataType goType_ = localOutDesc->GetDataType();
    const uint64_t tilingKey = GetTilingKey(static_cast<uint32_t>(sp), goType_);
    // GetTilingKey returns 0 when the data type has no matching key.
    OP_CHECK_IF(tilingKey == 0UL,
                OP_LOGE(nodeName, "unsupported data type %d", static_cast<int32_t>(goType_)),
                return ge::GRAPH_FAILED);
    context->SetTilingKey(tilingKey);

    // Shapes: validate the int64 dims before narrowing them to uint32_t.
    auto lseStorageShape = context->GetInputShape(0);
    auto inStorageShape = context->GetInputShape(static_cast<size_t>(sp));
    OP_CHECK_IF(lseStorageShape == nullptr || inStorageShape == nullptr,
                OP_LOGE(nodeName, "input shape is null, sp is %ld", sp),
                return ge::GRAPH_FAILED);
    const auto& lseShape = lseStorageShape->GetStorageShape();
    const auto& inShape = inStorageShape->GetStorageShape();
    const int64_t totalLengthDim = lseShape.GetDim(LSE_TOTAL_LENGTH_DIM);
    const int64_t hdDim = inShape.GetDim(IN_HD_DIM);
    // totalLength == 0 would make blockDims 0 and the split below would divide by zero.
    OP_CHECK_IF(totalLengthDim <= 0 || totalLengthDim > static_cast<int64_t>(UINT32_MAX),
                OP_LOGE(nodeName, "totalLength %ld is out of range (0, %u]", totalLengthDim, UINT32_MAX),
                return ge::GRAPH_FAILED);
    OP_CHECK_IF(hdDim < 0 || hdDim > static_cast<int64_t>(UINT32_MAX),
                OP_LOGE(nodeName, "hd %ld is out of range [0, %u]", hdDim, UINT32_MAX),
                return ge::GRAPH_FAILED);
    const uint32_t totalLength = static_cast<uint32_t>(totalLengthDim);
    const uint32_t hd = static_cast<uint32_t>(hdDim);
    OP_LOGD(nodeName, "TotalLength of b*s*hc is %u, sp is %ld, hd is %u",
            totalLength, sp, hd);

    // Block count: ceil(totalLength / MIN_BLOCK_LENGTH), capped by the number of cores.
    const uint32_t coreNum = compileInfo->coreNum;
    const uint32_t blockDims =
        std::min<uint32_t>((totalLength + MIN_BLOCK_LENGTH - 1UL) / MIN_BLOCK_LENGTH, coreNum);
    OP_CHECK_IF(blockDims == 0,
                OP_LOGE(nodeName, "blockDims is 0, totalLength is %u, coreNum is %u", totalLength, coreNum),
                return ge::GRAPH_FAILED);

    // Split totalLength over blockDims: formerNum blocks carry one extra element.
    DecodeUpdateTilingData tilingData;
    tilingData.set_formerNum(static_cast<uint32_t>(totalLength % blockDims));
    tilingData.set_tailNum(static_cast<uint32_t>(blockDims - tilingData.get_formerNum()));
    tilingData.set_tailLength(static_cast<uint32_t>(totalLength / blockDims));
    tilingData.set_formerLength(static_cast<uint32_t>(tilingData.get_formerNum() == 0 ?
                                0 : tilingData.get_tailLength() + 1));
    tilingData.set_hDim(static_cast<uint32_t>(hd));
    tilingData.set_sp(static_cast<uint32_t>(sp));
    tilingData.set_updateType(static_cast<uint32_t>(updateType));
    tilingData.set_totalLength(static_cast<uint32_t>(totalLength));

    OP_LOGD(nodeName, "BlockDims is %u", blockDims);
    OP_LOGD(nodeName, "formerLength is %u", tilingData.get_formerLength());
    OP_LOGD(nodeName, "tailLength is %u", tilingData.get_tailLength());
    OP_LOGD(nodeName, "formerNum is %u", tilingData.get_formerNum());
    OP_LOGD(nodeName, "tailNum is %u", tilingData.get_tailNum());
    OP_LOGD(nodeName, "hDim is %u", tilingData.get_hDim());
    OP_LOGD(nodeName, "sp is %u", tilingData.get_sp());
    OP_LOGD(nodeName, "totalLength is %u", tilingData.get_totalLength());

    // Publish workspace size, block dim and the serialized tiling data.
    size_t* workspaces = context->GetWorkspaceSizes(1);
    OP_CHECK_IF(workspaces == nullptr, OP_LOGE(nodeName, "workspaces is null"),
                return ge::GRAPH_FAILED);
    workspaces[0] = SYS_WORKSPACE_SIZE;
    context->SetBlockDim(blockDims);

    auto rawTilingData = context->GetRawTilingData();
    OP_CHECK_IF(rawTilingData == nullptr, OP_LOGE(nodeName, "raw tiling data is null"),
                return ge::GRAPH_FAILED);
    tilingData.SaveToBuffer(rawTilingData->GetData(), rawTilingData->GetCapacity());
    rawTilingData->SetDataSize(tilingData.GetDataSize());
    return ge::GRAPH_SUCCESS;
}
/**
 * @brief Tiling-parse callback: fills DecodeUpdateCompileInfo from the platform information.
 *
 * Stores the AIV core count, the UB size and the is_ascendc flag (result of
 * IsRegbaseSocVersion) in the compile info, then rejects a zero core count or UB size.
 *
 * @param context Tiling parse context supplied by the framework.
 * @return ge::GRAPH_SUCCESS on success; ge::GRAPH_FAILED if the context, compile info or
 *         platform info is null, or if the core count / UB size is not positive.
 */
static ge::graphStatus TilingPrepare4DecodeUpdate(gert::TilingParseContext* context) {
    OP_CHECK_IF(context == nullptr,
                OP_LOGE("AttentionUpdate", "context is null"),
                return ge::GRAPH_FAILED);
    auto compileInfo = context->GetCompiledInfo<DecodeUpdateCompileInfo>();
    OP_CHECK_IF(compileInfo == nullptr,
                OP_LOGE(context, "compileInfoPtr is null"),
                return ge::GRAPH_FAILED);
    // The platform wrapper is built from this pointer, so it must be valid.
    auto platformInfo = context->GetPlatformInfo();
    OP_CHECK_IF(platformInfo == nullptr,
                OP_LOGE(context, "platformInfo is null"),
                return ge::GRAPH_FAILED);
    auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);

    // Zero-initialized so that the ubSize check below is meaningful if the query leaves it unset.
    uint64_t ubSize = 0UL;
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
    compileInfo->coreNum = ascendcPlatform.GetCoreNumAiv();
    compileInfo->ubSize = ubSize;
    compileInfo->is_ascendc = Ops::Transformer::OpTiling::IsRegbaseSocVersion(context);

    OP_CHECK_IF(compileInfo->coreNum <= 0,
                OP_LOGE(context->GetNodeName(),
                        "AttentionUpdate GetHardwareInfo Failed, vectorCoreNum: %u",
                        compileInfo->coreNum),
                return ge::GRAPH_FAILED);
    OP_CHECK_IF(compileInfo->ubSize <= 0,
                OP_LOGE(context->GetNodeName(),
                        "AttentionUpdate GetHardwareInfo Failed, ubSize: %lu",
                        compileInfo->ubSize),
                return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}
// Register the AttentionUpdate op's tiling function (DecodeUpdateTiling) and its
// compile-info parser (TilingPrepare4DecodeUpdate, producing DecodeUpdateCompileInfo).
IMPL_OP_OPTILING(AttentionUpdate)
.Tiling(DecodeUpdateTiling)
.TilingParse<DecodeUpdateCompileInfo>(TilingPrepare4DecodeUpdate);
}