* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* ------------------------------------------------------------------------- */
#include "KernelMetaDataParser.h"
#include "utils/InjectLogger.h"
#include "hccl.h"
namespace {
// The literal 8. Within this file it is used in two roles: as the number of bytes in a
// 64-bit field (bounds checks and loop limits), and as the number of bits per byte when
// computing shift amounts while assembling big-endian 64-bit values.
constexpr uint8_t U64_BYTE_SIZE = 8U;
/**
 * Translates a DfxTensorType enumerator into its lower-case textual name.
 *
 * @param type enumerator to translate
 * @return the matching name, or "unknown" for any value not handled by the switch
 */
std::string DfxTensorTypeToString(DfxTensorType type)
{
    // Fallback used when no case below matches.
    const char *name = "unknown";
    // No default label on purpose: the switch lists the enumerators explicitly.
    switch (type) {
        case DfxTensorType::INVALID_TENSOR:
            name = "invalid_tensor";
            break;
        case DfxTensorType::GENERAL_TENSOR:
            name = "general_tensor";
            break;
        case DfxTensorType::INPUT_TENSOR:
            name = "input_tensor";
            break;
        case DfxTensorType::OUTPUT_TENSOR:
            name = "output_tensor";
            break;
        case DfxTensorType::WORKSPACE_TENSOR:
            name = "workspace_tensor";
            break;
        case DfxTensorType::ASCENDC_LOG:
            name = "ascendc_log";
            break;
        case DfxTensorType::MC2_CTX:
            name = "mc2_ctx";
            break;
        case DfxTensorType::TILING_DATA:
            name = "tiling_data";
            break;
        case DfxTensorType::L1:
            name = "l1";
            break;
        case DfxTensorType::L2:
            name = "l2";
            break;
        case DfxTensorType::OVERFLOW_MES:
            name = "overflow_mes";
            break;
        case DfxTensorType::FFTS_ADDRESS:
            name = "ffts_address";
            break;
        case DfxTensorType::SHAPE_TENSOR:
            name = "shape_tensor";
            break;
    }
    return name;
}
/**
 * Tells whether a kernel-argument pointer kind is one this parser handles.
 *
 * @param type pointer kind read from the meta data
 * @return true for LEVEL_1_POINTER, LEVEL_2_POINTER_WITH_SHAPE and SHAPE_TENSOR_PLACEHOLD;
 *         false for every other value (including INVALID_POINTER and LEVEL_2_POINTER)
 */
bool IsSupportDfxPointPtr(DfxPointerType type)
{
    const bool isLevelOne = (type == DfxPointerType::LEVEL_1_POINTER);
    const bool isLevelTwoWithShape = (type == DfxPointerType::LEVEL_2_POINTER_WITH_SHAPE);
    const bool isShapePlaceholder = (type == DfxPointerType::SHAPE_TENSOR_PLACEHOLD);
    return isLevelOne || isLevelTwoWithShape || isShapePlaceholder;
}
}
/**
 * Walks metaData_ as a sequence of records. Each record starts with a 2-byte
 * little-endian type, followed by a 2-byte little-endian length.
 *
 * - DFX_TYPE records are handed to ParseKernelArgs(), which consumes the length field
 *   and the payload itself and advances the cursor.
 * - TIK_TYPE records set memInfo_.isTik; their payload is skipped.
 * - Every other record is skipped using its length field.
 *
 * @return false when a record type is not followed by a readable length field, or when
 *         no input parameter has been collected after a DFX_TYPE record; true otherwise
 *         (including when the meta data is too short to hold a single record type).
 */
bool KernelMetaDataParser::Parse()
{
    const auto shSize = metaData_.size();
    uint32_t index {0U};
    // Two bytes are needed to read a record type. The bound is written as
    // `index + 1 < shSize` rather than `index < shSize - 1`: size() is unsigned, so the
    // subtraction wraps around for empty meta data and would let the loop read
    // metaData_[0] / metaData_[1] out of bounds.
    while (static_cast<size_t>(index) + 1 < shSize) {
        auto type = static_cast<DfxType>(metaData_[index] + (metaData_[index + 1] << 8));
        index += 2;
        // The 2-byte length field must be fully readable before anyone touches it.
        if (static_cast<size_t>(index) + 2 > shSize) {
            WARN_LOG("Parse meta data failed");
            return false;
        }
        if (type == DfxType::DFX_TYPE) {
            // ParseKernelArgs reads the length itself and moves index forward.
            ParseKernelArgs(index);
            if (memInfo_.inputParamsAddrInfos.empty()) {
                WARN_LOG("Parse kernel args failed");
                return false;
            }
            continue;
        }
        if (type == DfxType::TIK_TYPE) {
            memInfo_.isTik = true;
        }
        auto length = static_cast<uint32_t>(metaData_[index] + (metaData_[index + 1] << 8));
        // Skip the length field plus the payload it announces.
        index += (2 + length);
        DEBUG_LOG("Dfx type is %u, length: %u", static_cast<uint32_t>(type), static_cast<uint32_t>(length));
    }
    return true;
}
/**
 * Stores the raw meta data, resets memInfo_ and registers one handler per tensor type
 * in dfxParser_.
 *
 * GENERAL, INPUT, OUTPUT, WORKSPACE and SHAPE tensors share ParseMetaInput; FFTS_ADDRESS,
 * TILING_DATA and MC2_CTX each have a dedicated handler. Tensor types that are not
 * registered here have no entry in dfxParser_.
 *
 * @param metaData raw meta-data bytes used to initialise metaData_
 */
KernelMetaDataParser::KernelMetaDataParser(const std::vector<uint8_t> &metaData) : metaData_(metaData)
{
    memInfo_.Clear();

    using Handler = void (KernelMetaDataParser::*)(uint32_t &, uint32_t);
    // Wraps a member handler into a callable bound to this instance.
    const auto bound = [this](Handler handler) {
        return [this, handler](uint32_t &index, uint32_t paramsNo) { (this->*handler)(index, paramsNo); };
    };

    dfxParser_[DfxTensorType::GENERAL_TENSOR] = bound(&KernelMetaDataParser::ParseMetaInput);
    dfxParser_[DfxTensorType::INPUT_TENSOR] = bound(&KernelMetaDataParser::ParseMetaInput);
    dfxParser_[DfxTensorType::OUTPUT_TENSOR] = bound(&KernelMetaDataParser::ParseMetaInput);
    dfxParser_[DfxTensorType::WORKSPACE_TENSOR] = bound(&KernelMetaDataParser::ParseMetaInput);
    dfxParser_[DfxTensorType::FFTS_ADDRESS] = bound(&KernelMetaDataParser::ParseMetaFfts);
    dfxParser_[DfxTensorType::TILING_DATA] = bound(&KernelMetaDataParser::ParseMetaTiling);
    dfxParser_[DfxTensorType::MC2_CTX] = bound(&KernelMetaDataParser::ParseMetaMc2Ctx);
    dfxParser_[DfxTensorType::SHAPE_TENSOR] = bound(&KernelMetaDataParser::ParseMetaInput);
}
/**
 * Decodes the 8-byte big-endian length found at `index` and appends an input-parameter
 * record (address 0, MemInfoSrc::EXTRA, MemInfoDesc::INPUT) to
 * memInfo_.inputParamsAddrInfos.
 *
 * If fewer than 8 bytes are available at `index`, a warning is logged and nothing is
 * appended. `index` itself is never modified here.
 *
 * @param index    offset of the first byte of the length field inside metaData_
 * @param paramsNo parameter number stored in the appended record
 */
void KernelMetaDataParser::ParseMetaInput(uint32_t &index, uint32_t paramsNo)
{
    const bool fieldReadable = metaData_.size() >= (U64_BYTE_SIZE + index);
    if (!fieldReadable) {
        WARN_LOG("Can not Parse Meta input data, index error");
        return;
    }

    // Big-endian decode: every new byte pushes the previous ones one byte up.
    uint64_t length = 0;
    for (size_t offset = 0; offset < U64_BYTE_SIZE; ++offset) {
        length = (length << U64_BYTE_SIZE) | static_cast<uint64_t>(metaData_[index + offset]);
    }

    auto &inputInfos = memInfo_.inputParamsAddrInfos;
    inputInfos.emplace_back(AddrInfo{0U, length, MemInfoSrc::EXTRA, MemInfoDesc::INPUT, paramsNo});
    DEBUG_LOG("Get %zu data, length is %lu", inputInfos.size(), static_cast<uint64_t>(length));
}
/**
 * Registers the MC2 context argument: appends a record of sizeof(HcclCombinOpParam)
 * bytes (address 0, MemInfoSrc::EXTRA, MemInfoDesc::HCCL_MC2_CONTEXT) to
 * memInfo_.inputParamsAddrInfos and sets memInfo_.hasMc2Ctx.
 *
 * No meta-data bytes are read; both parameters exist only to match the handler signature.
 *
 * @param index    unused
 * @param paramsNo unused
 */
void KernelMetaDataParser::ParseMetaMc2Ctx(uint32_t &index, uint32_t paramsNo)
{
    static_cast<void>(metaData_);
    static_cast<void>(index);
    static_cast<void>(paramsNo);

    DEBUG_LOG("has mc2ctx input");

    auto &inputInfos = memInfo_.inputParamsAddrInfos;
    inputInfos.emplace_back(
        AddrInfo{0U, sizeof(HcclCombinOpParam), MemInfoSrc::EXTRA, MemInfoDesc::HCCL_MC2_CONTEXT});
    memInfo_.hasMc2Ctx = true;
}
/**
 * Handler for FFTS_ADDRESS entries: forwards memInfo_ to SetFftsInfo().
 * NOTE(review): SetFftsInfo is defined outside this file; what it records is not
 * visible here.
 *
 * @param index    unused; present only to match the handler signature
 * @param paramsNo unused; present only to match the handler signature
 */
void KernelMetaDataParser::ParseMetaFfts(uint32_t &index, uint32_t paramsNo)
{
    static_cast<void>(paramsNo);
    static_cast<void>(index);
    SetFftsInfo(memInfo_);
}
/**
 * Decodes the 8-byte big-endian value found at `index` and stores it in
 * memInfo_.tilingDataSize.
 *
 * If fewer than 8 bytes are available at `index`, a warning is logged and
 * memInfo_.tilingDataSize is left untouched. `index` itself is never modified here.
 *
 * @param index    offset of the first byte of the size field inside metaData_
 * @param paramsNo unused; present only to match the handler signature
 */
void KernelMetaDataParser::ParseMetaTiling(uint32_t &index, uint32_t paramsNo)
{
    static_cast<void>(paramsNo);

    const bool fieldReadable = metaData_.size() >= (U64_BYTE_SIZE + index);
    if (!fieldReadable) {
        WARN_LOG("Can not Parse Meta tiling data, index error");
        return;
    }

    // Big-endian decode: every new byte pushes the previous ones one byte up.
    uint64_t tilingSize = 0;
    for (size_t offset = 0; offset < U64_BYTE_SIZE; ++offset) {
        tilingSize = (tilingSize << U64_BYTE_SIZE) | static_cast<uint64_t>(metaData_[index + offset]);
    }
    memInfo_.tilingDataSize = tilingSize;

    DEBUG_LOG("Get tiling data, data length is %lu", static_cast<uint64_t>(memInfo_.tilingDataSize));
}
void KernelMetaDataParser::ParseKernelArgs(uint32_t &index)
{
auto dfxLen = static_cast<uint32_t>(metaData_[index] + (metaData_[index + 1] << 8));
index += 2;
uint32_t dfxRange = dfxLen + index;
uint32_t paramsIdx{};
while (index < dfxRange && (index + 7 < metaData_.size())) {
auto numOfUint64 = static_cast<uint32_t>((metaData_[index + 2] << 8) + (metaData_[index + 3]));
if (numOfUint64 == 0) {
DEBUG_LOG("number Of uint64 is 0");
index += 4;
continue;
}
index += 4;
auto type = static_cast<DfxTensorType>((metaData_[index + 6] << 8) + metaData_[index + 7]);
auto ptrType = static_cast<DfxPointerType>((metaData_[index + 4] << 8) + metaData_[index + 5]);
DEBUG_LOG("Get Meta data, data type: %s, number Of uint64 is: %u",
DfxTensorTypeToString(type).c_str(), static_cast<uint32_t>(numOfUint64));
auto funcIter = dfxParser_.find(type);
if (funcIter == dfxParser_.end()) {
index += (numOfUint64 * U64_BYTE_SIZE);
DEBUG_LOG("Cannot find dfx type, type is %u", static_cast<uint32_t>(type));
break;
}
if (!IsSupportDfxPointPtr(ptrType)) {
memInfo_.Clear();
INFO_LOG("unsupported level pointer, type is %u", static_cast<uint32_t>(ptrType));
return;
}
if (ptrType == DfxPointerType::LEVEL_2_POINTER_WITH_SHAPE) {
DEBUG_LOG("the %u parameter of the kernel is a double pointer", paramsIdx);
uint64_t dtypeBytes = 0;
for (size_t i = 0; i < U64_BYTE_SIZE; ++i) {
dtypeBytes |= static_cast<uint64_t>(metaData_[i + index + 16]) << ((U64_BYTE_SIZE - 1 - i) * U64_BYTE_SIZE);
}
SecondPtrInfo secondPtrInfo{};
secondPtrInfo.dtypeBytes = dtypeBytes;
memInfo_.secondPtrAddrInfos.insert({paramsIdx, secondPtrInfo});
}
index += 8;
funcIter->second(index, paramsIdx + 1);
index += ((numOfUint64 - 1) * U64_BYTE_SIZE);
paramsIdx++;
}
}