/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file kernel_vec_reduce_other_whl_check.cpp
* \brief Argument checks for the whole-reduce vector APIs (e.g. WholeReduceMax / WholeReduceMin).
*/
#include "kernel_check_params.h"
#include "model/model_factory_mask.h"
#include "kernel_vec_reduce_other_whl_check.h"
namespace AscendC {
namespace check {
/**
 * Verifies that the destination and source element widths (in bytes) are equal.
 *
 * @param errMsg Prefix placed at the front of the logged error message.
 * @return true when params_.dstDtypeBytes equals params_.src0DtypeBytes; otherwise logs an error
 *         and returns false.
 */
bool TikcppVecReduceOtherWhlCheck::CheckWholeReduceDtypeBytes(const std::string &errMsg)
{
    const uint32_t dstDtypeBytes = params_.dstDtypeBytes;
    const uint32_t srcDtypeBytes = params_.src0DtypeBytes;
    if (dstDtypeBytes == srcDtypeBytes) {
        return true;
    }
    // The values printed are element byte widths; the second one belongs to the source operand
    // (the previous text mislabelled it as "dst src type").
    CHECK_LOG_ERROR("%s, Reduce need dst data type (%u), src data type (%u), should be same",
        errMsg.c_str(), dstDtypeBytes, srcDtypeBytes);
    return false;
}
/**
 * Checks the alignment of the destination address.
 *
 * The required alignment is 2 bytes when the destination element width equals sizeof(half),
 * and 4 bytes for every other width.
 *
 * @return the result of CheckTensorAddrAlign for the "dst" tensor.
 */
bool TikcppVecReduceOtherWhlCheck::CheckAddrAlign()
{
    constexpr uint8_t HALF_ALIGN_BYTE = 2;
    constexpr uint8_t DEFAULT_ALIGN_BYTE = 4;
    const bool dstIsHalfWidth = (params_.dstDtypeBytes == sizeof(half));
    uint8_t alignByte = dstIsHalfWidth ? HALF_ALIGN_BYTE : DEFAULT_ALIGN_BYTE;
    return CheckTensorAddrAlign(params_.dstAddr, params_.dstPos, alignByte, "dst");
}
/**
 * Destination size check used when the mask is interpreted as an element counter.
 *
 * The element count is read from maskArray ([0] when it holds one entry, [1] otherwise), converted
 * to a repeat count by rounding up against the number of dst elements per repeat, and the bytes
 * needed are (repeatTimes - 1) * dstRepeatStride * unit + unit.
 *
 * @param maskArray  Mask values; must not be empty.
 * @param param      API parameters (dstDtypeBytes, dstRepeatStride and dstSize are read).
 * @param unit       Bytes written per repeat.
 * @param tensorName Tensor name forwarded to the size check.
 * @param apiName    API name forwarded to the size check.
 * @return false if an input is unusable or the size check fails, true otherwise.
 */
static bool CheckTensorWhlOverflowLowCounter(std::vector<uint64_t>& maskArray, const VecReduceWhlApiParams& param,
    const uint64_t unit, const std::string& tensorName, const std::string& apiName)
{
    // Reject inputs that would otherwise be indexed out of bounds or used as a zero divisor.
    ASCENDC_CHECK(!maskArray.empty());
    ASCENDC_CHECK(param.dstDtypeBytes != 0);
    const uint64_t oneRepeatNum = ONE_REPEAT_BYTE_SIZE / param.dstDtypeBytes;
    ASCENDC_CHECK(oneRepeatNum != 0);

    const uint64_t elementNum = (maskArray.size() == 1) ? maskArray[0] : maskArray[1];
    // Ceiling division written without the "+ oneRepeatNum - 1" term so it cannot overflow.
    const uint64_t repeatTimes = elementNum / oneRepeatNum + ((elementNum % oneRepeatNum != 0) ? 1 : 0);

    // Stay in 64 bits: narrowing to 32 bits could wrap an oversized requirement into a small value
    // that passes the check. With zero repeats "(repeatTimes - 1)" would wrap, so the requirement
    // is pinned to 0, which is what the previous arithmetic produced for a stride of 1.
    uint64_t needSize = 0;
    if (repeatTimes > 0) {
        needSize = (repeatTimes - 1) * static_cast<uint64_t>(param.dstRepeatStride) * unit + unit;
    }
    ASCENDC_CHECK(CheckTensorSizeOverflow(needSize, param.dstSize, tensorName, apiName, ModeType::COUNTER_MODE));
    return true;
}
/**
 * Destination size check used for the non-counter (normal) mask mode.
 *
 * The bytes needed are (repeatTimes - 1) * dstRepeatStride * unit plus the size of the last result:
 * two dst elements for the value/index orders, one dst element for ORDER_ONLY_VALUE, and
 * sizeof(uint32_t) for ORDER_ONLY_INDEX. Any other order adds nothing.
 *
 * @param param      API parameters (repeatTimes, dstRepeatStride, order, dstDtypeBytes, dstSize are read).
 * @param unit       Bytes between consecutive results per unit of dstRepeatStride.
 * @param tensorName Tensor name forwarded to the size check.
 * @param apiName    API name forwarded to the size check.
 * @return false if the size check fails, true otherwise.
 */
static bool CheckTensorWhlOverflowLowNorm(const VecReduceWhlApiParams& param, const uint64_t unit,
    const std::string& tensorName, const std::string& apiName)
{
    // Stay in 64 bits: narrowing to 32 bits could wrap an oversized requirement into a small value
    // that passes the check. With zero repeats "(repeatTimes - 1)" would wrap, so the requirement
    // is pinned to 0, which is what the previous arithmetic produced for a stride of 1.
    uint64_t needSize = 0;
    if (param.repeatTimes > 0) {
        needSize = (static_cast<uint64_t>(param.repeatTimes) - 1) *
            static_cast<uint64_t>(param.dstRepeatStride) * unit;
        if (param.order == ReduceOrder::ORDER_VALUE_INDEX || param.order == ReduceOrder::ORDER_INDEX_VALUE) {
            needSize += static_cast<uint64_t>(param.dstDtypeBytes) * 2;
        } else if (param.order == ReduceOrder::ORDER_ONLY_VALUE) {
            needSize += param.dstDtypeBytes;
        } else if (param.order == ReduceOrder::ORDER_ONLY_INDEX) {
            needSize += sizeof(uint32_t);
        }
    }
    ASCENDC_CHECK(CheckTensorSizeOverflow(needSize, param.dstSize, tensorName, apiName, ModeType::NORM_MODE));
    return true;
}
/**
 * Runs the destination size check that matches the current mask mode.
 *
 * When ModelFactoryGetMaskMode() returns 1 the counter-mode check is used; any other value
 * selects the normal-mode check.
 *
 * @param maskArray  Mask values, only consulted by the counter-mode check.
 * @param unit       Byte unit forwarded to the selected check.
 * @param tensorName Tensor name forwarded to the selected check.
 * @return the result of the selected check.
 */
bool TikcppVecReduceOtherWhlCheck::CheckTensorWhlOverflowLow(std::vector<uint64_t>& maskArray,
    const uint64_t unit, const std::string& tensorName)
{
    const bool useCounterCheck = (ModelFactoryGetMaskMode() == 1);
    return useCounterCheck ? CheckTensorWhlOverflowLowCounter(maskArray, params_, unit, tensorName, apiName)
                           : CheckTensorWhlOverflowLowNorm(params_, unit, tensorName, apiName);
}
bool TikcppVecReduceOtherWhlCheck::CheckAllLowLevel(std::vector<uint64_t> maskArray)
{
uint32_t maxByteLen = std::max(params_.dstDtypeBytes, params_.src0DtypeBytes);
ASCENDC_CHECK(UpdateMaskArrayAndCheck(maskArray, maxByteLen));
if ((apiName == "WholeReduceMax") || (apiName == "WholeReduceMin")) {
ASCENDC_CHECK(CheckWholeReduceDtypeBytes("Check Whole Reduce data type"));
if (params_.order == ReduceOrder::ORDER_VALUE_INDEX || params_.order == ReduceOrder::ORDER_INDEX_VALUE) {
constexpr uint32_t MULTIPLIE = 2;
ASCENDC_CHECK(CheckTensorWhlOverflowLow(maskArray, MULTIPLIE * params_.dstDtypeBytes, "dstLocal"));
} else if (params_.order == ReduceOrder::ORDER_ONLY_VALUE) {
ASCENDC_CHECK(CheckTensorWhlOverflowLow(maskArray, params_.dstDtypeBytes, "dstLocal"));
} else if (params_.order == ReduceOrder::ORDER_ONLY_INDEX) {
ASCENDC_CHECK(CheckTensorWhlOverflowLow(maskArray, sizeof(uint32_t), "dstLocal"));
}
}
const std::string supportPos = "VECIN / VECOUT / VECCALC";
ASCENDC_CHECK(CheckTensorScope(params_.dstLogicPos, static_cast<uint8_t>(HardWareIndex::UB), "dst", supportPos));
ASCENDC_CHECK(CheckTensorScope(params_.src0LogicPos, static_cast<uint8_t>(HardWareIndex::UB), "src", supportPos));
ASCENDC_CHECK(CheckAddrAlign());
ASCENDC_CHECK(CheckBufferSizeOverFlow(params_.dstSize, GlobalParams::Instance().bufferSizeMap.at(params_.dstPos),
"check dst tensor buffersize failed"));
ASCENDC_CHECK(CheckBufferSizeOverFlow(params_.src0Size, GlobalParams::Instance().bufferSizeMap.at(params_.src0Pos),
"check src tensor buffersize failed"));
TensorOverflowParams params = {params_.src0Size, params_.src0DtypeBytes, static_cast<uint64_t>(params_.repeatTimes),
static_cast<uint64_t>(params_.src0BlockStride), static_cast<uint64_t>(params_.src0RepeatStride), false};
ASCENDC_CHECK(CheckTensorOverflowLow(maskArray, params, "srcLocal"));
return true;
}
}
}