/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
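/*!
* \file kernel_vec_proposal_check.cpp
* \brief Parameter checks for the proposal/sort vector APIs (Sort32, RpSort16, MrgSort4, MrgSort,
*        ProposalConcat, ProposalExtract, Concat, Extract).
*/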
#include "kernel_check_params.h"
#include "kernel_vec_proposal_check.h"
namespace AscendC {
namespace check {
const uint32_t ONE_REPEAT_CAL_NUM = 16;
const uint32_t PROPOSAL_SIZE = 8;
bool TikcppVecProposalCheck::CheckAddrAlign(const std::string& src0Name)
{
uint8_t alignByte = ONE_BLK_SIZE;
bool dstRes = true;
bool src0Res = true;
if (apiName == "MrgSort") {
alignByte = 8;
dstRes = CheckTensorAddrAlign(param_.dstAddr, param_.dstPos, ONE_BLK_SIZE, "dst");
src0Res = CheckTensorAddrAlign(param_.src0Addr, param_.src0Pos, alignByte, src0Name);
return dstRes && src0Res;
}
if (apiName == "MrgSort4" && param_.dstDtypeBytes == sizeof(half)) {
alignByte = 16;
}
dstRes = CheckTensorAddrAlign(param_.dstAddr, param_.dstPos, alignByte, "dst");
src0Res = CheckTensorAddrAlign(param_.src0Addr, param_.src0Pos, alignByte, src0Name);
return dstRes && src0Res;
}
// Returns the number of bits set to 1 in validBit.
uint8_t TikcppVecProposalCheck::CountBit(uint16_t validBit) const
{
    uint8_t setBits = 0;
    // Each iteration clears the lowest set bit, so the loop runs once per set bit.
    for (; validBit != 0; validBit = static_cast<uint16_t>(validBit & (validBit - 1))) {
        ++setBits;
    }
    return setBits;
}
// Validates the validBit mask.
//
// The accepted values are 3, 7 and 15 (the low two, three or four bits set). For any other value the
// error-logging block is handed to ASCENDC_CHECK_AND_LOG.
//
// \param validBit  mask to validate.
// \return true when validBit is 3, 7 or 15; false otherwise.
bool TikcppVecProposalCheck::CheckValidBit(uint16_t validBit) const
{
    const bool validBitRes = validBit == 3 || validBit == 7 || validBit == 15;
    ASCENDC_CHECK_AND_LOG((validBitRes), {CHECK_LOG_ERROR("Failed to check validBit value in %s, its valid value is "
        "[3, 7, 15], current value is %u.", apiName.c_str(), validBit);});
    // Return the computed verdict instead of a literal `true`: callers wrap this function in ASCENDC_CHECK and
    // expect `false` for a rejected mask, and this keeps that contract independent of whether the macro above
    // performs an early return itself.
    return validBitRes;
}
// Sums the first `count` entries of elementLengths.
//
// \param elementLengths  array of per-source lengths; entries [0, count) are read.
// \param count           number of leading entries to add up.
// \return the sum, accumulated in 64 bits.
uint64_t TikcppVecProposalCheck::CalSortElemPerRep(uint16_t elementLengths[4], uint8_t count) const
{
    uint64_t total = 0;
    const uint16_t* cursor = elementLengths;
    const uint16_t* const stop = elementLengths + count;
    while (cursor != stop) {
        total += *cursor;
        ++cursor;
    }
    return total;
}
// Tells whether repeatTimes has to be taken into account by the merge-sort checks.
//
// \return true only when all of the following hold: the four elementLengths entries are equal,
//         isContinuous is set, isExhausted is not set, and validBit equals 15.
bool TikcppVecProposalCheck::NeedRepeatTimes() const
{
    if (!param_.isContinuous || param_.isExhausted || param_.validBit != 15) {
        return false;
    }
    // All four lengths must match; comparing each one against the first is equivalent to the pairwise chain.
    for (int i = 1; i < 4; ++i) {
        if (param_.elementLengths[i] != param_.elementLengths[0]) {
            return false;
        }
    }
    return true;
}
// Overflow check used for "RpSort16".
//
// Both dst and src0 are checked against repeatTimes * ONE_REPEAT_CAL_NUM * PROPOSAL_SIZE elements.
//
// \return true when both tensors pass CheckTensorOverflowHigh.
bool TikcppVecProposalCheck::Vbs16Check() const
{
    const uint64_t requiredElems = param_.repeatTimes * ONE_REPEAT_CAL_NUM * PROPOSAL_SIZE;
    ASCENDC_CHECK(
        CheckTensorOverflowHigh(param_.dstDtypeBytes, param_.dstSize, requiredElems, "dstLocal"));
    ASCENDC_CHECK(
        CheckTensorOverflowHigh(param_.src0DtypeBytes, param_.src0Size, requiredElems, "srcLocal"));
    return true;
}
// Checks used for "Sort32".
//
// Steps, in order:
//   1. src1 size against the buffer size registered for src1Pos in GlobalParams;
//   2. dstDtypeBytes must be non-zero (it is used as a divisor below);
//   3. dst must hold (ONE_REPEAT_BYTE_SIZE / dstDtypeBytes) * repeatTimes elements;
//   4. src0 and src1 must each hold 32 * repeatTimes elements.
//
// \return true when every step passes, false otherwise.
bool TikcppVecProposalCheck::Vbs32Check() const
{
    ASCENDC_CHECK(CheckBufferSizeOverFlow(
        param_.src1Size, GlobalParams::Instance().bufferSizeMap.at(param_.src1Pos),
        "check src1 tensor buffersize failed"));

    // Guard the division performed for the dst element count.
    if (param_.dstDtypeBytes == 0) {
        CHECK_LOG_ERROR("dst dtype bytes is zero");
        return false;
    }

    const uint32_t srcElemsPerRepeat = 32;
    const uint32_t dstElemsPerRepeat = ONE_REPEAT_BYTE_SIZE / param_.dstDtypeBytes;

    ASCENDC_CHECK(CheckTensorOverflowHigh(
        param_.dstDtypeBytes, param_.dstSize, dstElemsPerRepeat * param_.repeatTimes, "dstLocal"));
    ASCENDC_CHECK(CheckTensorOverflowHigh(
        param_.src0DtypeBytes, param_.src0Size, srcElemsPerRepeat * param_.repeatTimes, "src0Local"));
    ASCENDC_CHECK(CheckTensorOverflowHigh(
        param_.src1DtypeBytes, param_.src1Size, srcElemsPerRepeat * param_.repeatTimes, "src1Local"));
    return true;
}
bool TikcppVecProposalCheck::Vms4Check() const
{
ASCENDC_CHECK(CheckValidBit(param_.validBit));
uint8_t count = CountBit(param_.validBit);
if (param_.srcIndex >= count) {
return true;
}
uint64_t sortElePerRep = CalSortElemPerRep(param_.elementLengths, count);
uint64_t elemPerRep = sortElePerRep * PROPOSAL_SIZE;
uint64_t validRepeatTimes = 1;
if (NeedRepeatTimes()) {
validRepeatTimes = param_.repeatTimes;
ASCENDC_CHECK_VALUE_RANGE(validRepeatTimes, 1, MAX_REPEAT_TIMES, "repeatTimes", "MrgSort4");
}
if (!param_.isExhausted) {
ASCENDC_CHECK(CheckTensorOverflowHigh(param_.dstDtypeBytes, param_.dstSize, elemPerRep * param_.repeatTimes,
"dstLocal"));
}
std::string tensorName = "src" + std::to_string(param_.srcIndex) + " in srcLocal";
uint64_t srcEle = (validRepeatTimes - 1) * elemPerRep + param_.elementLengths[param_.srcIndex] * PROPOSAL_SIZE;
ASCENDC_CHECK(CheckTensorOverflowHigh(param_.src0DtypeBytes, param_.src0Size, srcEle, tensorName));
return true;
}
bool TikcppVecProposalCheck::Vms4v2Check() const
{
ASCENDC_CHECK(CheckValidBit(param_.validBit));
uint8_t count = CountBit(param_.validBit);
if (param_.srcIndex >= count) {
return true;
}
uint64_t sortElePerRep = CalSortElemPerRep(param_.elementLengths, count);
uint64_t bytePerRep = sortElePerRep * PROPOSAL_SIZE;
uint64_t validRepeatTimes = 1;
if (NeedRepeatTimes()) {
validRepeatTimes = param_.repeatTimes;
ASCENDC_CHECK_VALUE_RANGE(validRepeatTimes, 1, MAX_REPEAT_TIMES, "repeatTimes", "MrgSort");
}
if (!param_.isExhausted) {
ASCENDC_CHECK(CheckTensorOverflowHigh(1, param_.dstSize, bytePerRep * param_.repeatTimes, "dstLocal"));
}
std::string tensorName = "src" + std::to_string(param_.srcIndex) + " in srcLocal";
uint64_t srcBytes = (validRepeatTimes - 1) * bytePerRep + param_.elementLengths[param_.srcIndex] * PROPOSAL_SIZE;
ASCENDC_CHECK(CheckTensorOverflowHigh(1, param_.src0Size, srcBytes, tensorName));
return true;
}
// Overflow check used for "ProposalConcat".
//
// src0 must hold repeatTimes * ONE_REPEAT_CAL_NUM elements; dst must hold PROPOSAL_SIZE times that amount.
//
// \return true when both tensors pass CheckTensorOverflowHigh.
bool TikcppVecProposalCheck::VconcatCheck() const
{
    const uint32_t itemCount = param_.repeatTimes * ONE_REPEAT_CAL_NUM;
    ASCENDC_CHECK(
        CheckTensorOverflowHigh(param_.src0DtypeBytes, param_.src0Size, itemCount, "srcLocal"));
    ASCENDC_CHECK(
        CheckTensorOverflowHigh(param_.dstDtypeBytes, param_.dstSize, itemCount * PROPOSAL_SIZE, "dstLocal"));
    return true;
}
// Overflow check used for "ProposalExtract".
//
// dst must hold repeatTimes * ONE_REPEAT_CAL_NUM elements; src0 must hold PROPOSAL_SIZE times that amount.
//
// \return true when both tensors pass CheckTensorOverflowHigh.
bool TikcppVecProposalCheck::VextractCheck() const
{
    const uint32_t itemCount = param_.repeatTimes * ONE_REPEAT_CAL_NUM;
    ASCENDC_CHECK(
        CheckTensorOverflowHigh(param_.src0DtypeBytes, param_.src0Size, itemCount * PROPOSAL_SIZE, "srcLocal"));
    ASCENDC_CHECK(
        CheckTensorOverflowHigh(param_.dstDtypeBytes, param_.dstSize, itemCount, "dstLocal"));
    return true;
}
// Overflow check used for "Concat".
//
// src0 ("srcLocal") must always hold repeatTimes * ONE_REPEAT_CAL_NUM elements. The remaining checks depend on
// __NPU_ARCH__:
//   - 2201 / 3002 / 3102 / 3510 / 5102: dst ("concatLocal") must hold the same element count;
//   - 1001 / 2002: src1 ("tmpLocal") and dst ("concatLocal") must each hold PROPOSAL_SIZE times that count;
//   - any other value (or undefined): no further check.
//
// \return true when every applicable check passes, false otherwise.
bool TikcppVecProposalCheck::ConcatCheck() const
{
    const uint32_t itemCount = param_.repeatTimes * ONE_REPEAT_CAL_NUM;
    ASCENDC_CHECK(
        CheckTensorOverflowHigh(param_.src0DtypeBytes, param_.src0Size, itemCount, "srcLocal"));
#if defined (__NPU_ARCH__) && ((__NPU_ARCH__ == 2201) || (__NPU_ARCH__ == 3002) || \
    (__NPU_ARCH__ == 3102) || (__NPU_ARCH__ == 3510) || (__NPU_ARCH__ == 5102))
    ASCENDC_CHECK(
        CheckTensorOverflowHigh(param_.dstDtypeBytes, param_.dstSize, itemCount, "concatLocal"));
#elif defined (__NPU_ARCH__) && ((__NPU_ARCH__ == 1001) || (__NPU_ARCH__ == 2002))
    ASCENDC_CHECK(
        CheckTensorOverflowHigh(param_.src1DtypeBytes, param_.src1Size, itemCount * PROPOSAL_SIZE, "tmpLocal"));
    ASCENDC_CHECK(
        CheckTensorOverflowHigh(param_.dstDtypeBytes, param_.dstSize, itemCount * PROPOSAL_SIZE, "concatLocal"));
#endif
    return true;
}
// Overflow check used for "Extract". The checks depend on __NPU_ARCH__:
//   - 2201 / 3002 / 3102 / 3510 / 5102:
//       src0 ("sortedLocal") is checked via CheckTensorSizeOverflow against repeatTimes * ONE_REPEAT_BYTE_SIZE;
//       src1 ("dstIndexLocal") and dst ("dstValueLocal") must each hold 32 * repeatTimes elements;
//   - 1001 / 2002:
//       src0 ("sortedLocal") must hold repeatTimes * ONE_REPEAT_CAL_NUM * PROPOSAL_SIZE elements;
//       src1 ("dstIndexLocal") and dst ("dstValueLocal") must each hold repeatTimes * ONE_REPEAT_CAL_NUM elements;
//   - any other value (or undefined): nothing is checked.
//
// \return true when every applicable check passes, false otherwise.
bool TikcppVecProposalCheck::ExtractCheck() const
{
#if defined (__NPU_ARCH__) && ((__NPU_ARCH__ == 2201) || (__NPU_ARCH__ == 3002) || \
    (__NPU_ARCH__ == 3102) || (__NPU_ARCH__ == 3510) || (__NPU_ARCH__ == 5102))
    const uint64_t itemsPerRepeat = 32;
    const uint64_t itemTotal = itemsPerRepeat * param_.repeatTimes;
    ASCENDC_CHECK(CheckTensorSizeOverflow(
        param_.repeatTimes * ONE_REPEAT_BYTE_SIZE, param_.src0Size, "sortedLocal", "Extract"));
    ASCENDC_CHECK(
        CheckTensorOverflowHigh(param_.src1DtypeBytes, param_.src1Size, itemTotal, "dstIndexLocal"));
    ASCENDC_CHECK(
        CheckTensorOverflowHigh(param_.dstDtypeBytes, param_.dstSize, itemTotal, "dstValueLocal"));
#elif defined (__NPU_ARCH__) && ((__NPU_ARCH__ == 1001) || (__NPU_ARCH__ == 2002))
    const uint32_t itemCount = param_.repeatTimes * ONE_REPEAT_CAL_NUM;
    ASCENDC_CHECK(
        CheckTensorOverflowHigh(param_.src0DtypeBytes, param_.src0Size, itemCount * PROPOSAL_SIZE, "sortedLocal"));
    ASCENDC_CHECK(
        CheckTensorOverflowHigh(param_.src1DtypeBytes, param_.src1Size, itemCount, "dstIndexLocal"));
    ASCENDC_CHECK(
        CheckTensorOverflowHigh(param_.dstDtypeBytes, param_.dstSize, itemCount, "dstValueLocal"));
#endif
    return true;
}
// Entry point: runs the common tensor checks and then the API-specific overflow check selected by apiName.
//
// Common checks, in order:
//   1. dst scope against the UB hardware index;
//   2. src0 scope against the UB hardware index (for "MrgSort"/"MrgSort4" the tensor is named
//      "src<srcIndex> in srcLocal" in the forwarded name, otherwise "src0");
//   3. for "Sort32"/"Concat"/"Extract": src1 scope and src1 address alignment to ONE_BLK_SIZE;
//   4. dst/src0 address alignment via CheckAddrAlign;
//   5. dst and src0 sizes against the buffer sizes registered in GlobalParams for their positions.
//
// \return false as soon as one check fails; otherwise the result of the API-specific check, or true when
//         apiName matches none of the known APIs.
bool TikcppVecProposalCheck::CheckAllHighLevel()
{
    const std::string supportPos = "VECIN / VECOUT / VECCALC";
    const uint8_t ubIndex = static_cast<uint8_t>(HardWareIndex::UB);

    ASCENDC_CHECK(CheckTensorScope(param_.dstLogicPos, ubIndex, "dst", supportPos));

    const bool isMergeSort = (apiName == "MrgSort") || (apiName == "MrgSort4");
    std::string src0Name = isMergeSort ? "src" + std::to_string(param_.srcIndex) + " in srcLocal"
                                       : std::string("src0");
    ASCENDC_CHECK(CheckTensorScope(param_.src0LogicPos, ubIndex, src0Name, supportPos));

    const bool usesSrc1 = (apiName == "Sort32") || (apiName == "Concat") || (apiName == "Extract");
    if (usesSrc1) {
        ASCENDC_CHECK(CheckTensorScope(param_.src1LogicPos, ubIndex, "src1", supportPos));
        ASCENDC_CHECK(CheckTensorAddrAlign(param_.src1Addr, param_.src1Pos, ONE_BLK_SIZE, "src1"));
    }

    ASCENDC_CHECK(CheckAddrAlign(src0Name));

    ASCENDC_CHECK(CheckBufferSizeOverFlow(
        param_.dstSize, GlobalParams::Instance().bufferSizeMap.at(param_.dstPos),
        "check dst tensor buffersize failed"));
    std::string src0BufferMsg = "check " + src0Name + " tensor buffersize failed";
    ASCENDC_CHECK(CheckBufferSizeOverFlow(
        param_.src0Size, GlobalParams::Instance().bufferSizeMap.at(param_.src0Pos), src0BufferMsg));

    // API-specific overflow checks; every branch returns, so the order of the tests is irrelevant.
    if (apiName == "Sort32") {
        return Vbs32Check();
    }
    if (apiName == "ProposalConcat") {
        return VconcatCheck();
    }
    if (apiName == "Concat") {
        return ConcatCheck();
    }
    if (apiName == "ProposalExtract") {
        return VextractCheck();
    }
    if (apiName == "Extract") {
        return ExtractCheck();
    }
    if (apiName == "RpSort16") {
        return Vbs16Check();
    }
    if (apiName == "MrgSort4") {
        return Vms4Check();
    }
    if (apiName == "MrgSort") {
        return Vms4v2Check();
    }
    return true;
}
}
}