/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file nsa_compress_tiling.cpp
* \brief Tiling entry points and parameter validation for the NsaCompress operator.
*/
#include "nsa_compress_tiling.h"
#include "nsa_compress_tiling_common.h"
#include <climits>
#include <graph/utils/type_utils.h>
#include "register/op_impl_registry.h"
#include "log/log.h"
#include "err/ops_err.h"
#include "op_host/tiling_base.h"
using namespace ge;
using namespace AscendC;
using namespace Ops::Transformer::OpTiling;
namespace optiling {
/**
 * Validates the shapes, attributes and actual-sequence-length tensor of NsaCompress before tiling.
 *
 * Checks, in order:
 *   1. input/weight shapes and the attribute container exist;
 *   2. the actSeqLen tensor and all four attribute pointers exist;
 *   3. inputLayout starts with "TND";
 *   4. input has at least 3 dims, weight at least 2, and actSeqLen exactly 1 (so every indexed dim is real);
 *   5. input.shape[1] == weight.shape[1];
 *   6. input.shape[2], compressBlockSize and compressStride are multiples of 16;
 *   7. weight.shape[0] == compressBlockSize and compressStride <= compressBlockSize;
 *   8. actSeqLenType == 0;
 *   9. the actSeqLen data is available, is non-decreasing starting from a value >= 0, and its last
 *      element equals input.shape[0].
 *
 * This function does not rely on IsEmptyInput having been called first: every pointer it
 * dereferences is checked here.
 *
 * @param context tiling context of the NsaCompress node; must not be null.
 * @return ge::GRAPH_SUCCESS when every check passes, ge::GRAPH_FAILED otherwise (an error is logged).
 */
static ge::graphStatus CheckParams(const gert::TilingContext *context)
{
    const auto *inputStorageShape = context->GetInputShape(INPUT_INPUT_INDEX);
    const auto *weightStorageShape = context->GetInputShape(WEIGHT_INPUT_INDEX);
    const auto *attrs = context->GetAttrs();
    if (inputStorageShape == nullptr || weightStorageShape == nullptr || attrs == nullptr) {
        OP_LOGW(context, "fail to get shape or attr from context");
        return ge::GRAPH_FAILED;
    }
    const auto &inputShape = inputStorageShape->GetStorageShape();
    const auto &weightShape = weightStorageShape->GetStorageShape();

    // The sequence-length input is fetched through the optional-input API, so it may be absent.
    const auto *actSeqLenTensor = context->GetOptionalInputTensor(ACT_SEQ_LEN_INPUT_INDEX);
    OP_CHECK_IF((actSeqLenTensor == nullptr),
                OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(), "actSeqLenOptional tensor is null"),
                return ge::GRAPH_FAILED);
    const auto &actSeqLenShape = actSeqLenTensor->GetShape().GetStorageShape();

    // Fetch every attribute pointer first and validate them together before any dereference.
    const char *inputLayout = attrs->GetAttrPointer<char>(INPUTLAYOUT_ATTRS_INDEX);
    const int64_t *compressBlockSizePtr = attrs->GetAttrPointer<int64_t>(COMPRESS_BLOCK_SIZE_ATTRS_INDEX);
    const int64_t *compressStridePtr = attrs->GetAttrPointer<int64_t>(COMPRESS_STRIDE_ATTRS_INDEX);
    const int64_t *actseqlenTypePtr = attrs->GetAttrPointer<int64_t>(ACT_SEQ_LEN_TYPE_ATTRS_INDEX);
    OP_CHECK_IF((inputLayout == nullptr || compressBlockSizePtr == nullptr || compressStridePtr == nullptr ||
                 actseqlenTypePtr == nullptr),
                OPS_REPORT_VECTOR_INNER_ERR(
                    context->GetNodeName(),
                    "fail to get inputLayout, compressBlockSize, compressStride or actSeqLenType from attrs"),
                return ge::GRAPH_FAILED);
    const int64_t inputCompressBlockSize = *compressBlockSizePtr;
    const int64_t inputCompressStride = *compressStridePtr;
    const int64_t actseqlenType = *actseqlenTypePtr;

    // Short-circuit evaluation stops at the first mismatch, so a layout string shorter than three
    // characters is never read past its terminator.
    // NOTE(review): only the first three characters are compared, so a longer string that merely
    // starts with "TND" is accepted — confirm whether an exact match is intended.
    OP_CHECK_IF((inputLayout[0] != 'T' || inputLayout[1] != 'N' || inputLayout[2] != 'D'),
                OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(),
                                            "The inputLayout currently only supports the 'TND' format"),
                return ge::GRAPH_FAILED);

    // Guard the GetDim(0..2) / GetDim(0..1) / GetDim(Zero) reads below against lower-rank shapes.
    OP_CHECK_IF((inputShape.GetDimNum() < 3 || weightShape.GetDimNum() < 2 || actSeqLenShape.GetDimNum() != 1),
                OPS_REPORT_VECTOR_INNER_ERR(
                    context->GetNodeName(),
                    "input must have at least 3 dims, weight at least 2 dims and actSeqLenOptional exactly 1 dim, "
                    "but got input dims=%zu, weight dims=%zu, actSeqLenOptional dims=%zu",
                    inputShape.GetDimNum(), weightShape.GetDimNum(), actSeqLenShape.GetDimNum()),
                return ge::GRAPH_FAILED);

    OP_CHECK_IF(
        (inputShape.GetDim(1) != weightShape.GetDim(1)),
        OPS_REPORT_VECTOR_INNER_ERR(
            context->GetNodeName(),
            "input.shape[1] must equal weight.shape[1], but got input.shape[1]=%ld, weight.shape[1]=%ld",
            inputShape.GetDim(1), weightShape.GetDim(1)),
        return ge::GRAPH_FAILED);
    OP_CHECK_IF((inputShape.GetDim(2) % 16 != 0),
                OPS_REPORT_VECTOR_INNER_ERR(
                    context->GetNodeName(), "input.shape[2] must be a multiple of 16, but got input.shape[2]=%ld",
                    inputShape.GetDim(2)),
                return ge::GRAPH_FAILED);
    OP_CHECK_IF((weightShape.GetDim(0) != inputCompressBlockSize),
                OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(),
                                            "weight.shape[0] must equal compressBlockSize, but got "
                                            "weight.shape[0]=%ld, compressBlockSize=%ld",
                                            weightShape.GetDim(0), inputCompressBlockSize),
                return ge::GRAPH_FAILED);
    OP_CHECK_IF((inputCompressBlockSize % 16 != 0),
                OPS_REPORT_VECTOR_INNER_ERR(
                    context->GetNodeName(),
                    "compressBlockSize must be a multiple of 16, but got compressBlockSize=%ld",
                    inputCompressBlockSize),
                return ge::GRAPH_FAILED);
    // NOTE(review): a stride of 0 (or a negative multiple of 16) satisfies this check and the next
    // one — confirm the tiling implementations reject or tolerate it.
    OP_CHECK_IF((inputCompressStride % 16 != 0),
                OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(),
                                            "compressStride must be a multiple of 16, but got compressStride=%ld",
                                            inputCompressStride),
                return ge::GRAPH_FAILED);
    OP_CHECK_IF((inputCompressBlockSize < inputCompressStride),
                OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(),
                                            "compressStride can not greater than compressBlockSize, but got "
                                            "compressBlockSize=%ld, compressStride=%ld",
                                            inputCompressBlockSize, inputCompressStride),
                return ge::GRAPH_FAILED);
    OP_CHECK_IF((actseqlenType != 0),
                OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(),
                                            "actseqlenType only support 0, but got actseqlenType=%ld",
                                            actseqlenType),
                return ge::GRAPH_FAILED);

    // The values themselves are needed from here on; without them the loop would read through null.
    const int64_t *actSeqLenValue = actSeqLenTensor->GetData<int64_t>();
    OP_CHECK_IF((actSeqLenValue == nullptr),
                OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(), "actSeqLenOptional tensor data is null"),
                return ge::GRAPH_FAILED);

    // Walk the sequence lengths: each entry must be >= its predecessor (the first must be >= 0).
    const uint32_t batchSize = static_cast<uint32_t>(actSeqLenShape.GetDim(Zero));
    int64_t preSeqLen = 0;
    for (uint32_t i = 0; i < batchSize; ++i) {
        const int64_t valueI = actSeqLenValue[i];
        OP_CHECK_IF((valueI < preSeqLen),
                    OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(), "actSeqLenOptional currently only supports the prefix sum format and requires values to be greater than 0, but actSeqLenOptional[%u]=%ld contains an invalid value",
                                                i, valueI),
                    return ge::GRAPH_FAILED);
        preSeqLen = valueI;
    }
    // After the loop preSeqLen holds the last entry, which must match the leading input dimension.
    OP_CHECK_IF((preSeqLen != inputShape.GetDim(0)),
                OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(),
                                            "input.shape[0] must equal actSeqLenOptional[-1], but got "
                                            "input.shape[0]=%ld, actSeqLenOptional[-1]=%ld",
                                            inputShape.GetDim(0), preSeqLen),
                return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}
/**
 * Reports whether tiling has to be rejected because the operator inputs are missing, empty or
 * unusable.
 *
 * Despite its name the function covers more than emptiness; it returns true when any of the
 * following holds (an error is logged in each case):
 *   - the input shape, weight shape or actSeqLen tensor cannot be obtained from the context;
 *   - the input or the weight storage shape has zero elements;
 *   - the actSeqLen storage shape is not one-dimensional;
 *   - the actSeqLen tensor has no data pointer.
 *
 * Every rejecting path returns `true` explicitly rather than a ge::graphStatus code converted to
 * bool.
 *
 * @param context tiling context of the NsaCompress node; must not be null.
 * @return true when tiling must not proceed, false when all the inputs above are usable.
 */
static bool IsEmptyInput(gert::TilingContext *context)
{
    const auto *inputShape = context->GetInputShape(INPUT_INPUT_INDEX);
    OP_CHECK_IF(inputShape == nullptr,
                OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(), "NsaCompress input shape is null pointer"),
                return true);
    const auto *weightShape = context->GetInputShape(WEIGHT_INPUT_INDEX);
    OP_CHECK_IF(weightShape == nullptr,
                OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(), "NsaCompress weight shape is null pointer"),
                return true);
    const auto *actSeqLenTensor = context->GetOptionalInputTensor(ACT_SEQ_LEN_INPUT_INDEX);
    OP_CHECK_IF(actSeqLenTensor == nullptr,
                OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(), "NsaCompress actSeqLenTensor is null pointer"),
                return true);

    const int64_t inputShapeSize = inputShape->GetStorageShape().GetShapeSize();
    const int64_t weightShapeSize = weightShape->GetStorageShape().GetShapeSize();
    if (inputShapeSize == 0 || weightShapeSize == 0) {
        // The caller fails tiling on this path, so leave a trace of the reason.
        OP_LOGE(context->GetNodeName(),
                "NsaCompress does not support empty input or weight, input size=%ld, weight size=%ld",
                inputShapeSize, weightShapeSize);
        return true;
    }

    const auto &actSeqLenShape = actSeqLenTensor->GetShape().GetStorageShape();
    OP_CHECK_IF(actSeqLenShape.GetDimNum() != 1,
                OP_LOGE(context->GetNodeName(),
                        "NsaCompress actSeqLenShape is invalid %lu %ld", actSeqLenShape.GetDimNum(), actSeqLenShape.GetDim(0)),
                return true);

    const int64_t *actSeqLenValue = actSeqLenTensor->GetData<int64_t>();
    OP_CHECK_IF(actSeqLenValue == nullptr,
                OP_LOGE(context->GetNodeName(),
                        "NsaCompress actSeqLenTensor data is null pointer"),
                return true);
    return false;
}
/**
 * Tiling entry point for NsaCompress.
 *
 * Rejects the request when IsEmptyInput reports true or when CheckParams does not return
 * ge::GRAPH_SUCCESS; otherwise hands the context to the tiling registry and returns its result.
 *
 * @param context tiling context supplied by the framework.
 * @return ge::GRAPH_FAILED on rejection, otherwise whatever TilingRegistry::DoTilingImpl returns.
 */
ASCENDC_EXTERN_C ge::graphStatus TilingNsaCompress(gert::TilingContext *context)
{
    // Short-circuit keeps the original order: CheckParams runs only if IsEmptyInput returned false.
    const bool rejected = IsEmptyInput(context) || (CheckParams(context) != ge::GRAPH_SUCCESS);
    if (rejected) {
        return ge::GRAPH_FAILED;
    }
    return TilingRegistry::GetInstance().DoTilingImpl(context);
}
/**
 * Tiling-parse hook for NsaCompress: fills NsaCompressCompileInfo from the platform information.
 *
 * Stores the AIV core count in `aivNum` and the UB core memory size in `ubSize`, then rejects the
 * platform if either value is zero.
 *
 * @param context tiling-parse context providing the platform info and the compile-info storage.
 * @return ge::GRAPH_SUCCESS when both values were obtained and are non-zero; ge::GRAPH_FAILED when
 *         the platform info or compile info is null, or either value is zero.
 */
ASCENDC_EXTERN_C ge::graphStatus TilingPrepareForNsaCompress(gert::TilingParseContext *context)
{
    fe::PlatFormInfos *const platform = context->GetPlatformInfo();
    OP_CHECK_IF(platform == nullptr,
                OP_LOGE(context->GetNodeName(), "platformInfoPtr is null"),
                return ge::GRAPH_FAILED);

    auto *const compileInfo = context->GetCompiledInfo<NsaCompressCompileInfo>();
    OP_CHECK_IF(compileInfo == nullptr,
                OP_LOGE(context->GetNodeName(), "compileInfoPtr is null"),
                return ge::GRAPH_FAILED);

    // Query the platform wrapper once and cache what the tiling step needs.
    platform_ascendc::PlatformAscendC ascendc(platform);
    compileInfo->aivNum = ascendc.GetCoreNumAiv();
    ascendc.GetCoreMemSize(platform_ascendc::CoreMemType::UB, compileInfo->ubSize);

    const bool platformUnusable = (compileInfo->aivNum == 0) || (compileInfo->ubSize == 0);
    OP_CHECK_IF(platformUnusable,
                OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(),
                                            "platform info is invalid, aivNum=%u, ubSize=%lu",
                                            compileInfo->aivNum, compileInfo->ubSize),
                return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}
// Registers the tiling callbacks of the NsaCompress operator:
//   - TilingNsaCompress is the tiling function;
//   - the input at ACT_SEQ_LEN_INPUT_INDEX is declared as a tiling data dependency (IsEmptyInput and
//     CheckParams read that tensor's values through GetData, presumably the reason it is listed);
//   - TilingPrepareForNsaCompress is the tiling-parse function, with NsaCompressCompileInfo as its
//     compile-info type.
IMPL_OP_OPTILING(NsaCompress)
.Tiling(TilingNsaCompress)
.TilingInputsDataDependency({ACT_SEQ_LEN_INPUT_INDEX})
.TilingParse<NsaCompressCompileInfo>(TilingPrepareForNsaCompress);
}