* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "cpu_kernel_cache.h"
#include <climits>
#include "aicpu_engine_struct.h"
#include "cpu_kernel.h"
#include "cpu_kernel_register.h"
#include "cpu_kernel_utils.h"
#include "log.h"
#include "status.h"
using namespace aicpu;
namespace {
// Capacity passed to SetCapacity() by CpuKernelCache::InitParameter().
constexpr uint32_t kMaxLRUCacheNum = 1024U;
// Shift and mask used by ParseExtTopicTypeDeviceType() to extract bit 7 of the topic-type word.
constexpr uint32_t kTopicTypeDeviceTypePostion = 7;
constexpr uint32_t kTopicTypeDeviceTypeMask = 0x0080U;
// Largest rank accepted when a RuntimeTensorDesc shape is read or written in this file.
constexpr int64_t kMaxDimSize = 32;
#pragma pack(push, 1)
// Byte-packed tensor descriptor that this file reads from / writes to raw io addresses.
// With 1-byte packing the fields add up to 1024 bytes.
struct RuntimeTensorDesc {
// Address handed to Tensor::SetData() for the described tensor.
uint64_t data_addr;
int64_t data_offset_size;
int64_t dtype;
// shape[0] holds the rank, shape[1..rank] hold the dim sizes (hence kMaxDimSize + 1 slots).
int64_t shape[kMaxDimSize + 1];
// Same layout as shape; UpdateFWKOutputShape() writes both arrays with identical values.
int64_t original_shape[kMaxDimSize + 1];
int64_t format;
int64_t sub_format;
// Not accessed in this file.
uint8_t reserved[456];
};
#pragma pack(pop)
}
namespace aicpu {
* Init kernel cache.
*/
int32_t CpuKernelCache::InitParameter()
{
KERNEL_LOG_INFO("cpu cache set capacity.");
SetCapacity(kMaxLRUCacheNum);
return 0;
}
* update framework output tensor shape.
*/
uint32_t CpuKernelCache::UpdateFWKOutputShape(ExtInfoMsg& ext_info_msg, const CpuKernelContext& ctx) const
{
if (ext_info_msg.unknown_shape) {
for (size_t i = 0; i < ctx.GetOutputsSize(); ++i) {
Tensor* output = ctx.Output(static_cast<uint32_t>(i));
KERNEL_CHECK_NULLPTR(output, KERNEL_STATUS_PARAM_INVALID, "Get output[%zu] failed.", i)
auto shape = output->GetTensorShape();
KERNEL_CHECK_NULLPTR(shape, KERNEL_STATUS_PARAM_INVALID, "Get output[%zu] shape failed.", i)
for (int32_t index = 0; index < shape->GetDims(); ++index) {
ext_info_msg.output_shape_and_type[i]->dims[index] = shape->GetDimSize(index);
}
}
}
for (auto it = ext_info_msg.unknown_shape_output_index_addr.cbegin();
it != ext_info_msg.unknown_shape_output_index_addr.cend(); ++it) {
Tensor* output = ctx.Output(it->first);
KERNEL_CHECK_NULLPTR(output, KERNEL_STATUS_PARAM_INVALID, "Get output[%u] failed.", it->first)
auto shape = output->GetTensorShape();
KERNEL_CHECK_NULLPTR(shape, KERNEL_STATUS_PARAM_INVALID, "Get output[%u] shape failed.", it->first)
RuntimeTensorDesc* tensor_desc = reinterpret_cast<RuntimeTensorDesc*>(static_cast<uintptr_t>(it->second));
KERNEL_CHECK_FALSE(
(shape->GetDims() <= kMaxDimSize), KERNEL_STATUS_PARAM_INVALID,
"Max shape size[32], but got output[%u] shape size[%d]", it->first, shape->GetDims())
tensor_desc->shape[0] = shape->GetDims();
tensor_desc->original_shape[0] = shape->GetDims();
for (int32_t index = 0; index < shape->GetDims(); ++index) {
tensor_desc->shape[index + 1] = shape->GetDimSize(index);
tensor_desc->original_shape[index + 1] = shape->GetDimSize(index);
}
}
return KERNEL_STATUS_OK;
}
* get shape information from framework.
*/
void CpuKernelCache::GetDimsFromShapeAndType(
const FWKAdapter::ShapeAndType* shape_and_type, std::vector<int64_t>& dims) const
{
for (uint32_t index = 0; index < FWKAdapter::kMaxShapeDims; ++index) {
if (shape_and_type->dims[index] == LLONG_MIN) {
break;
}
int64_t dim_value = shape_and_type->dims[index];
KERNEL_LOG_INFO("Get extend shape[%u] is [%ld]", index, dim_value);
dims.emplace_back(dim_value);
}
}
/**
 * Append len consecutive dim values starting at shape to dims, logging each one.
 *
 * @param shape first dim value to copy.
 * @param len   number of values to copy.
 * @param dims  vector the values are appended to.
 */
void CpuKernelCache::GetDimsFromArrays(const int64_t* shape, size_t len, std::vector<int64_t>& dims) const
{
    size_t pos = 0;
    for (const int64_t* cur = shape; pos < len; ++cur, ++pos) {
        KERNEL_LOG_INFO("Get arrays shape[%zu] is [%ld]", pos, *cur);
        dims.emplace_back(*cur);
    }
}
/**
 * Check that the address list and the parsed shape records match the context's tensor counts.
 *
 * @param io_addrs     io addresses; their number must equal inputs + outputs of ctx.
 * @param ext_info_msg parsed extend info; when unknown_shape is set, the input/output
 *                     shape_and_type lists must match the input/output counts of ctx.
 * @param ctx          kernel context providing the counts.
 * @return KERNEL_STATUS_OK when all counts match, otherwise KERNEL_STATUS_PARAM_INVALID.
 */
uint32_t CpuKernelCache::CheckTensorParam(
    const std::vector<uint64_t>& io_addrs, ExtInfoMsg& ext_info_msg, CpuKernelContext& ctx) const
{
    const bool addr_count_ok = (io_addrs.size() == ctx.GetInputsSize() + ctx.GetOutputsSize());
    if (!addr_count_ok) {
        KERNEL_LOG_ERROR(
            "Addr number[%zu] is not equal to the sum of inputs[%u] and output[%u].", io_addrs.size(),
            ctx.GetInputsSize(), ctx.GetOutputsSize());
        return KERNEL_STATUS_PARAM_INVALID;
    }
    // Shape records are only required for unknown-shape kernels.
    if (!ext_info_msg.unknown_shape) {
        return KERNEL_STATUS_OK;
    }
    const bool inputs_match = (ext_info_msg.input_shape_and_type.size() == ctx.GetInputsSize());
    const bool outputs_match = (ext_info_msg.output_shape_and_type.size() == ctx.GetOutputsSize());
    if (inputs_match && outputs_match) {
        return KERNEL_STATUS_OK;
    }
    KERNEL_LOG_ERROR(
        "Input shape_and_type size error, input size[%u], input "
        "shape_and_type size[%zu], output size[%u], output shape_and_type size[%zu].",
        ctx.GetInputsSize(), ext_info_msg.input_shape_and_type.size(), ctx.GetOutputsSize(),
        ext_info_msg.output_shape_and_type.size());
    return KERNEL_STATUS_PARAM_INVALID;
}
uint32_t CpuKernelCache::UpdateInputTensor(
const std::vector<uint64_t>& io_addrs, ExtInfoMsg& ext_info_msg, CpuKernelContext& ctx, size_t& addr_index) const
{
for (size_t i = 0; i < ctx.GetInputsSize(); ++i, ++addr_index) {
Tensor* input = ctx.Input(static_cast<uint32_t>(i));
KERNEL_CHECK_NULLPTR(input, KERNEL_STATUS_PARAM_INVALID, "Get input[%zu] failed.", i)
auto iter = ext_info_msg.unknown_shape_input_index_addr.find(static_cast<uint32_t>(i));
if (iter != ext_info_msg.unknown_shape_input_index_addr.end()) {
iter->second = io_addrs[addr_index];
RuntimeTensorDesc* tensor_desc =
reinterpret_cast<RuntimeTensorDesc*>(static_cast<uintptr_t>(io_addrs[addr_index]));
std::vector<int64_t> dims;
KERNEL_CHECK_FALSE(
(tensor_desc->shape[0] <= kMaxDimSize), KERNEL_STATUS_PARAM_INVALID,
"Max shape size[%ld], but got input[%zu] shape size[%ld]", kMaxDimSize, i, tensor_desc->shape[0])
GetDimsFromArrays(&(tensor_desc->shape[1]), static_cast<size_t>(tensor_desc->shape[0]), dims);
auto shape = input->GetTensorShape();
KERNEL_CHECK_NULLPTR(shape, KERNEL_STATUS_PARAM_INVALID, "Get input[%zu] shape failed.", i)
shape->SetDimSizes(dims);
input->SetData(reinterpret_cast<void*>(static_cast<uintptr_t>(tensor_desc->data_addr)));
} else {
input->SetData(reinterpret_cast<void*>(static_cast<uintptr_t>(io_addrs[addr_index])));
}
if (ext_info_msg.unknown_shape) {
std::vector<int64_t> dims;
GetDimsFromShapeAndType(ext_info_msg.input_shape_and_type[i], dims);
auto shape = input->GetTensorShape();
shape->SetDimSizes(dims);
}
KERNEL_LOG_INFO("Set input[%zu] addr[%lu] success.", i, io_addrs[addr_index]);
if (io_addrs[addr_index] == 0) {
continue;
}
KERNEL_CHECK_FALSE(
(input->NumElements() >= 0), KERNEL_STATUS_PARAM_INVALID,
"Input[%zu] data elements number must be >= 0, got size[%ld].", i, input->NumElements());
input->SetDataSize(std::max(uint64_t(0), static_cast<uint64_t>(input->CalcDataSizeByShape())));
}
return KERNEL_STATUS_OK;
}
* update tensor information.
*/
uint32_t CpuKernelCache::UpdateTensor(
const std::vector<uint64_t>& io_addrs, ExtInfoMsg& ext_info_msg, CpuKernelContext& ctx) const
{
KERNEL_CHECK_RET(CheckTensorParam(io_addrs, ext_info_msg, ctx) != KERNEL_STATUS_OK, KERNEL_STATUS_PARAM_INVALID);
size_t addr_index = 0;
auto ret = UpdateInputTensor(io_addrs, ext_info_msg, ctx, addr_index);
KERNEL_CHECK_RET(ret != KERNEL_STATUS_OK, ret);
bool no_tiling = ext_info_msg.unknown_shape_output_index_addr.empty();
for (size_t i = 0; i < ctx.GetOutputsSize(); i++, addr_index++) {
Tensor* output = ctx.Output(static_cast<uint32_t>(i));
KERNEL_CHECK_NULLPTR(output, KERNEL_STATUS_PARAM_INVALID, "Get output[%zu] failed.", i)
auto iter = ext_info_msg.unknown_shape_output_index_addr.find(static_cast<uint32_t>(i));
if (iter != ext_info_msg.unknown_shape_output_index_addr.end()) {
iter->second = io_addrs[addr_index];
RuntimeTensorDesc* tensor_desc =
reinterpret_cast<RuntimeTensorDesc*>(static_cast<uintptr_t>(io_addrs[addr_index]));
output->SetData(reinterpret_cast<void*>(static_cast<uintptr_t>(tensor_desc->data_addr)));
} else {
output->SetData(reinterpret_cast<void*>(static_cast<uintptr_t>(io_addrs[addr_index])));
}
if (ext_info_msg.unknown_shape) {
std::vector<int64_t> dims;
GetDimsFromShapeAndType(ext_info_msg.output_shape_and_type[i], dims);
auto shape = output->GetTensorShape();
KERNEL_CHECK_NULLPTR(shape, KERNEL_STATUS_PARAM_INVALID, "Get output[%zu] shape failed.", i)
shape->SetDimSizes(dims);
}
KERNEL_CHECK_FALSE(
(ext_info_msg.unknown_shape || (!no_tiling) || (output->NumElements() >= 0)), KERNEL_STATUS_PARAM_INVALID,
"Output[%zu] data elements number must be >= 0 "
"when known shape, got size[%ld].",
i, output->NumElements());
output->SetDataSize(std::max(uint64_t(0), static_cast<uint64_t>(output->CalcDataSizeByShape())));
}
return KERNEL_STATUS_OK;
}
* parse extend tensor shape types information.
*/
uint32_t CpuKernelCache::ParseExtShapeType(const FWKAdapter::ExtInfo* ext_info, bool& unknown_shape) const
{
if (ext_info->infoLen != sizeof(int32_t)) {
KERNEL_LOG_ERROR(
"Parse extend shape type failed, as info length must be [%zu], but got "
"[%u].",
sizeof(int32_t), ext_info->infoLen);
return KERNEL_STATUS_PARAM_INVALID;
}
unknown_shape = true;
KERNEL_LOG_INFO("Kernel has unknown shape.");
return KERNEL_STATUS_OK;
}
* parse extend tensor shape and types information.
*/
uint32_t CpuKernelCache::ParseExtShapeAndType(
bool unknown_shape, FWKAdapter::ExtInfo* ext_info, std::vector<FWKAdapter::ShapeAndType*>& shape_and_type) const
{
if (!unknown_shape) {
return KERNEL_STATUS_OK;
}
uint32_t size = (ext_info->infoLen) / sizeof(FWKAdapter::ShapeAndType);
KERNEL_LOG_INFO("Parse extend shape and type, size[%u].", size);
uint32_t check = (ext_info->infoLen) % sizeof(FWKAdapter::ShapeAndType);
if (check != 0) {
KERNEL_LOG_ERROR(
"Parse extend info length[%u] failed, must be integer multiple of the "
"[%zu].",
ext_info->infoLen, sizeof(FWKAdapter::ShapeAndType));
return KERNEL_STATUS_PARAM_INVALID;
}
auto shapes = reinterpret_cast<FWKAdapter::ShapeAndType*>(ext_info->infoMsg);
for (uint32_t index = 0; index < size; ++index) {
shape_and_type.emplace_back(&shapes[index]);
}
return KERNEL_STATUS_OK;
}
* parse extend session information.
*/
uint32_t CpuKernelCache::ParseExtSessionInfo(FWKAdapter::ExtInfo* ext_info, uint64_t& kernel_id) const
{
KERNEL_LOG_INFO("Parse extend session info");
auto need_len = sizeof(SessionInfo);
if (ext_info->infoLen != need_len) {
KERNEL_LOG_ERROR(
"Parse extend session info failed, as info length must be "
"[%zu], but got [%u].",
sizeof(SessionInfo), ext_info->infoLen);
return KERNEL_STATUS_PARAM_INVALID;
}
auto session_info = reinterpret_cast<SessionInfo*>(ext_info->infoMsg);
kernel_id = session_info->kernelId;
return KERNEL_STATUS_OK;
}
* get bit status.
*/
bool CpuKernelCache::GetBitStatus(uint64_t num, uint64_t pos) const { return ((num & (1UL << pos)) != 0); }
* parse bitmap information.
*/
uint32_t CpuKernelCache::ParseExtBitMap(const FWKAdapter::ExtInfo* ext_info, bool& unknown_shape) const
{
if (ext_info->infoLen != sizeof(int64_t)) {
KERNEL_LOG_ERROR(
"Parse extend bitmap failed, as info length must be [%zu], but got "
"[%u].",
sizeof(int64_t), ext_info->infoLen);
return KERNEL_STATUS_PARAM_INVALID;
}
uint64_t bit_map_info = static_cast<uint64_t>(*(reinterpret_cast<const int64_t*>(ext_info->infoMsg)));
unknown_shape = (!GetBitStatus(bit_map_info, 0));
KERNEL_LOG_INFO("Unknown_shape is [%d].", unknown_shape);
return KERNEL_STATUS_OK;
}
* parse topictype 16bit devicetype information.
*/
uint32_t CpuKernelCache::ParseExtTopicTypeDeviceType(
const FWKAdapter::ExtInfo* ext_info, bool& devicetype_host_flag) const
{
if (ext_info->infoLen != sizeof(uint32_t)) {
KERNEL_LOG_ERROR(
"Parse extend topictype failed, as info length must be [%zu], but got "
"[%u].",
sizeof(uint32_t), ext_info->infoLen);
return KERNEL_STATUS_PARAM_INVALID;
}
uint32_t topic_type_info = static_cast<uint32_t>(*(reinterpret_cast<const uint32_t*>(ext_info->infoMsg)));
devicetype_host_flag = (topic_type_info & kTopicTypeDeviceTypeMask) >> kTopicTypeDeviceTypePostion;
KERNEL_LOG_INFO("devicetype_host_flag is [%d].", devicetype_host_flag);
return KERNEL_STATUS_OK;
}
/**
 * Parse the async-wait record.
 *
 * @param ext_info  record to parse; its payload must be exactly one FWKAdapter::AsyncWait.
 * @param wait_type receives the waitType field on success.
 * @param wait_id   receives the waitId field on success.
 * @return KERNEL_STATUS_OK on success, KERNEL_STATUS_PARAM_INVALID on a length mismatch.
 */
uint32_t CpuKernelCache::ParseAsyncWait(FWKAdapter::ExtInfo* ext_info, uint8_t& wait_type, uint32_t& wait_id) const
{
    const bool length_ok = (ext_info->infoLen == sizeof(FWKAdapter::AsyncWait));
    if (!length_ok) {
        KERNEL_LOG_ERROR(
            "Parse extend async wait failed, as info length must be [%zu], but got [%u].",
            sizeof(FWKAdapter::AsyncWait), ext_info->infoLen);
        return KERNEL_STATUS_PARAM_INVALID;
    }
    auto wait_info = reinterpret_cast<FWKAdapter::AsyncWait*>(ext_info->infoMsg);
    wait_type = wait_info->waitType;
    wait_id = wait_info->waitId;
    KERNEL_LOG_INFO("async wait type [%u], notify_id[%u].", wait_type, wait_id);
    return KERNEL_STATUS_OK;
}
/**
 * Parse a list of unknown-shape tensor indexes.
 *
 * The payload is an array of uint32_t indexes; each one is stored as a key of
 * unknown_shape_index_addr with the value 0 (an existing entry is reset to 0).
 *
 * @param ext_info                 record to parse.
 * @param unknown_shape_index_addr map that receives the indexes.
 * @return KERNEL_STATUS_OK on success, KERNEL_STATUS_PARAM_INVALID when the payload length is not
 *         a multiple of sizeof(uint32_t).
 */
uint32_t CpuKernelCache::ParseExtUnknownShapeIndex(
    FWKAdapter::ExtInfo* ext_info, std::map<uint32_t, uint64_t>& unknown_shape_index_addr) const
{
    const bool aligned = ((ext_info->infoLen % sizeof(uint32_t)) == 0);
    if (!aligned) {
        KERNEL_LOG_ERROR(
            "Parse unknown shape index extend info length[%u] failed, must be "
            "integer multiple of the [%zu].",
            ext_info->infoLen, sizeof(uint32_t));
        return KERNEL_STATUS_PARAM_INVALID;
    }
    const uint32_t index_num = ext_info->infoLen / sizeof(uint32_t);
    KERNEL_LOG_INFO("Parse extend unknown shape index, size[%u].", index_num);
    const uint32_t* cur = reinterpret_cast<uint32_t*>(ext_info->infoMsg);
    const uint32_t* const last = cur + index_num;
    for (; cur != last; ++cur) {
        unknown_shape_index_addr[*cur] = 0U;
    }
    return KERNEL_STATUS_OK;
}
/**
 * Parse the workspace record.
 *
 * @param ext_info       record to parse; its payload must be exactly one FWKAdapter::WorkSpaceInfo.
 * @param workspace_size receives the size field on success.
 * @param workspace_addr receives the addr field on success.
 * @return KERNEL_STATUS_OK on success, KERNEL_STATUS_PARAM_INVALID on a length mismatch.
 */
uint32_t CpuKernelCache::ParseExtWorkSpaceInfo(
    FWKAdapter::ExtInfo* ext_info, uint64_t& workspace_size, uint64_t& workspace_addr) const
{
    const bool length_ok = (ext_info->infoLen == sizeof(FWKAdapter::WorkSpaceInfo));
    if (!length_ok) {
        KERNEL_LOG_ERROR(
            "Parse extend workspace_size info failed, as info length must be "
            "[%zu], but got [%u].",
            sizeof(FWKAdapter::WorkSpaceInfo), ext_info->infoLen);
        return KERNEL_STATUS_PARAM_INVALID;
    }
    auto info = reinterpret_cast<FWKAdapter::WorkSpaceInfo*>(ext_info->infoMsg);
    workspace_size = info->size;
    workspace_addr = info->addr;
    KERNEL_LOG_DEBUG("workspace size info, workspace_size [%lu].", workspace_size);
    return KERNEL_STATUS_OK;
}
* parse extend information.
*/
uint32_t CpuKernelCache::ParseExtMsg(AicpuParamHead* param_head, ExtInfoMsg& ext_info_msg) const
{
KERNEL_LOG_INFO("Parse extend info and update shape begin");
ext_info_msg.async_flag = false;
char* ext_info_addr = reinterpret_cast<char*>(static_cast<uintptr_t>(param_head->extInfoAddr));
uint32_t offset = 0;
FWKAdapter::ExtInfo* ext_info = nullptr;
while (offset + sizeof(FWKAdapter::ExtInfo) <= param_head->extInfoLength) {
ext_info = reinterpret_cast<FWKAdapter::ExtInfo*>(ext_info_addr + offset);
if (ext_info == nullptr) {
KERNEL_LOG_ERROR(
"Extend info is nullptr, extInfo length[%u], extend info addr[%ld], "
"offset[%u].",
param_head->extInfoLength, param_head->extInfoAddr, offset);
return KERNEL_STATUS_PARAM_INVALID;
}
uint32_t ret = KERNEL_STATUS_OK;
switch (ext_info->infoType) {
case FWKAdapter::FWK_ADPT_EXT_SHAPE_TYPE:
ret = ParseExtShapeType(ext_info, ext_info_msg.unknown_shape);
break;
case FWKAdapter::FWK_ADPT_EXT_INPUT_SHAPE:
ret = ParseExtShapeAndType(ext_info_msg.unknown_shape, ext_info, ext_info_msg.input_shape_and_type);
break;
case FWKAdapter::FWK_ADPT_EXT_OUTPUT_SHAPE:
ret = ParseExtShapeAndType(ext_info_msg.unknown_shape, ext_info, ext_info_msg.output_shape_and_type);
break;
case FWKAdapter::FWK_ADPT_EXT_SESSION_INFO:
ext_info_msg.has_sess_info = true;
ret = ParseExtSessionInfo(ext_info, ext_info_msg.kernel_id);
break;
case FWKAdapter::FWK_ADPT_EXT_BITMAP:
ret = ParseExtBitMap(ext_info, ext_info_msg.unknown_shape);
break;
case FWKAdapter::FWK_ADPT_EXT_TOPIC_TYPE:
ret = ParseExtTopicTypeDeviceType(ext_info, ext_info_msg.devicetype_host_flag);
break;
case FWKAdapter::FWK_ADPT_EXT_ASYNCWAIT: {
ret = ParseAsyncWait(ext_info, ext_info_msg.wait_type, ext_info_msg.wait_id);
bool flag =
((ret == KERNEL_STATUS_OK) &&
(ext_info_msg.wait_type !=
static_cast<uint8_t>(FWKAdapter::FWKExtWaitType::FWK_ADPT_WAIT_TYPE_NULL)) &&
(ext_info_msg.wait_type !=
static_cast<uint8_t>(FWKAdapter::FWKExtWaitType::FWK_ADPT_WAIT_TYPE_INVALID)));
if (flag) {
ext_info_msg.async_flag = true;
}
break;
}
case FWKAdapter::FWK_ADPT_EXT_UNKNOWN_SHAPE_INPUT_INDEX:
ret = ParseExtUnknownShapeIndex(ext_info, ext_info_msg.unknown_shape_input_index_addr);
break;
case FWKAdapter::FWK_ADPT_EXT_UNKNOWN_SHAPE_OUTPUT_INDEX:
ret = ParseExtUnknownShapeIndex(ext_info, ext_info_msg.unknown_shape_output_index_addr);
break;
case FWKAdapter::FWK_ADPT_EXT_WORKSPACE_INFO:
ret = ParseExtWorkSpaceInfo(ext_info, ext_info_msg.workspace_size, ext_info_msg.workspace_addr);
break;
default:
KERNEL_LOG_INFO("Ignore infoType[%d], infoLen[%u].", ext_info->infoType, ext_info->infoLen);
break;
}
if (ret != KERNEL_STATUS_OK) {
return ret;
}
offset += FWKAdapter::kExtInfoHeadSize;
offset += ext_info->infoLen;
}
return KERNEL_STATUS_OK;
}
* parse io address.
*/
uint32_t CpuKernelCache::ParseIoAddr(
AicpuParamHead* param_head, std::vector<uint64_t>& io_addrs, char*& nodedef, uint32_t& nodedef_len) const
{
auto param_base = reinterpret_cast<char*>(param_head);
char* extend_param_base = param_base + sizeof(AicpuParamHead);
uint32_t extend_param_len = param_head->length - sizeof(AicpuParamHead);
if (param_head->ioAddrNum > 0) {
uint32_t addr_len = static_cast<uint32_t>(param_head->ioAddrNum * sizeof(uint64_t));
if (extend_param_len < addr_len) {
KERNEL_LOG_ERROR(
"Extend param is not enough for io addr, ioAddrNum[%u], "
"extend_param_len[%u].",
param_head->ioAddrNum, extend_param_len);
return KERNEL_STATUS_PARAM_INVALID;
}
auto io_addr_base = reinterpret_cast<uint64_t*>(extend_param_base);
for (uint32_t i = 0; i < param_head->ioAddrNum; ++i) {
io_addrs.push_back(io_addr_base[i]);
}
extend_param_base = extend_param_base + addr_len;
extend_param_len -= addr_len;
}
if (extend_param_len < sizeof(uint32_t)) {
KERNEL_LOG_ERROR(
"Extend param is not enough for addr, needLen[%zu], "
"extend_param_len[%u].",
sizeof(uint32_t), extend_param_len);
return KERNEL_STATUS_PARAM_INVALID;
}
nodedef_len = *reinterpret_cast<uint32_t*>(extend_param_base);
extend_param_base += sizeof(uint32_t);
nodedef = extend_param_base;
KERNEL_LOG_INFO("Parse io addr success, io number[%zu], nodedef length[%u].", io_addrs.size(), nodedef_len);
return KERNEL_STATUS_OK;
}
* get cpu kernel context from cache
*/
std::shared_ptr<CpuKernelContext> CpuKernelCache::GetCpuKernelContext(
std::shared_ptr<ExtInfoMsg> ext_info_msg, const char* nodedef, uint32_t nodedef_len,
std::shared_ptr<NodeDef>& nodedef_proto)
{
std::shared_ptr<CpuKernelContext> ctx = nullptr;
bool has_sess_info = ext_info_msg->has_sess_info;
const uint64_t ext_kernel_id = ext_info_msg->kernel_id;
KERNEL_LOG_INFO("Get cpu kernel context begin, kernel id[%lu]", ext_kernel_id);
if (has_sess_info) {
auto cache = GetCache(ext_kernel_id);
if (cache != nullptr) {
if (ext_info_msg->workspace_size > 0UL) {
CpuKernelUtils::UpdateCustWorkSpaceInfo(
cache->context.get(), ext_info_msg->workspace_size, ext_info_msg->workspace_addr);
}
KERNEL_LOG_DEBUG(
"get cache success, workspace addr=%ld.", ext_info_msg->workspace_addr, ext_info_msg->workspace_size);
return cache->context;
}
}
std::string str_data(nodedef, nodedef_len);
nodedef_proto = CpuKernelUtils::CreateNodeDef();
KERNEL_CHECK_NULLPTR(nodedef_proto, std::shared_ptr<CpuKernelContext>(nullptr), "Create node def failed.")
if (!nodedef_proto->ParseFromString(str_data)) {
return std::shared_ptr<CpuKernelContext>(nullptr);
}
auto waitType = CpuKernelUtils::CreateAttrValue();
waitType->SetInt(ext_info_msg->wait_type);
(void)nodedef_proto->AddAttrs("wait_type", waitType.get());
auto waitId = CpuKernelUtils::CreateAttrValue();
waitId->SetInt(ext_info_msg->wait_id);
(void)nodedef_proto->AddAttrs("wait_id", waitId.get());
KERNEL_LOG_INFO("AddAttrs wait info , waitType[%u] waitId[%u].", ext_info_msg->wait_type, ext_info_msg->wait_id);
DeviceType device_type = DEVICE;
if (ext_info_msg->devicetype_host_flag) {
device_type = HOST;
}
CpuKernelContext* tmp = new (std::nothrow) CpuKernelContext(device_type);
KERNEL_CHECK_NULLPTR(tmp, std::shared_ptr<CpuKernelContext>(nullptr), "Create context failed.")
ctx = std::shared_ptr<CpuKernelContext>(tmp);
uint32_t ret = ctx->Init(nodedef_proto.get());
if (ret != KERNEL_STATUS_OK) {
return std::shared_ptr<CpuKernelContext>(nullptr);
}
if (ext_info_msg->workspace_size > 0UL) {
CpuKernelUtils::UpdateCustWorkSpaceInfo(ctx.get(), ext_info_msg->workspace_size, ext_info_msg->workspace_addr);
KERNEL_LOG_DEBUG(
"workspace addr=%ld, workspace size is [%lu].", ext_info_msg->workspace_addr, ext_info_msg->workspace_size);
}
if (has_sess_info) {
CpuCacheData* cache_ptr = new (std::nothrow) CpuCacheData(nodedef_proto, ctx);
KERNEL_CHECK_NULLPTR(cache_ptr, std::shared_ptr<CpuKernelContext>(nullptr), "Create cpu cache data failed.")
std::shared_ptr<CpuCacheData> cache_shared = std::shared_ptr<CpuCacheData>(cache_ptr);
SetCache(ext_kernel_id, cache_shared);
KERNEL_LOG_INFO("Cache cpu kernel data success, kernel id[%lu].", ext_kernel_id);
}
KERNEL_LOG_INFO("Get cpu kernel context success, kernel id[%lu].", ext_kernel_id);
return ctx;
}
* parse io addrs and ext info from kernel param.
*/
int32_t CpuKernelCache::ParseRunKernelParam(
void* param, std::vector<uint64_t>& io_addrs, char*& node_def, uint32_t& node_def_len,
std::shared_ptr<ExtInfoMsg>& ext_info_msg) const
{
AicpuParamHead* param_head = static_cast<AicpuParamHead*>(param);
uint32_t ret = ParseIoAddr(param_head, io_addrs, node_def, node_def_len);
if (ret != KERNEL_STATUS_OK) {
return -1;
}
try {
ext_info_msg = std::make_shared<ExtInfoMsg>();
} catch (std::bad_alloc&) {
KERNEL_LOG_ERROR("Create ExtInfoMsg failed");
return -1;
}
ret = ParseExtMsg(param_head, *ext_info_msg);
if (ret != KERNEL_STATUS_OK) {
return -1;
}
return 0;
}
* dispatch cpu kernel: V2 优先, V1 兜底.
*/
uint32_t CpuKernelCache::DispatchCpuKernel(CpuKernelContext& ctx, ExtInfoMsg& ext_info_msg) const
{
const std::string& op_type = ctx.GetOpType();
const bool hit_v2 = CpuKernelRegister::Instance().IsRegisteredV2(op_type);
if (ext_info_msg.async_flag) {
auto cb = [this, &ctx, &ext_info_msg]() { return UpdateFWKOutputShape(ext_info_msg, ctx); };
return hit_v2 ? CpuKernelRegister::Instance().RunCpuKernelAsyncV2(
ctx, ext_info_msg.wait_type, ext_info_msg.wait_id, cb) :
CpuKernelRegister::Instance().RunCpuKernelAsync(
ctx, ext_info_msg.wait_type, ext_info_msg.wait_id, cb);
}
return hit_v2 ? CpuKernelRegister::Instance().RunCpuKernelV2(ctx) : CpuKernelRegister::Instance().RunCpuKernel(ctx);
}
/**
 * Run a kernel described by a raw parameter buffer.
 *
 * Parses the parameter, obtains the kernel context, binds the tensors and dispatches the kernel.
 * For a synchronous run the output shapes are written back afterwards.
 *
 * @param param raw kernel parameter.
 * @return 0 on success; the parse result when parsing fails; KERNEL_STATUS_INNER_ERROR when no
 *         context can be obtained; the kernel status itself for the fault statuses listed below
 *         (synchronous run) and for KERNEL_STATUS_END_OF_SEQUENCE; -1 for every other failure.
 */
int32_t CpuKernelCache::RunKernel(void* param)
{
    std::vector<uint64_t> io_addrs;
    char* node_def = nullptr;
    uint32_t node_def_len = 0;
    std::shared_ptr<ExtInfoMsg> ext_info_msg = nullptr;
    const int32_t parse_ret = ParseRunKernelParam(param, io_addrs, node_def, node_def_len, ext_info_msg);
    if (parse_ret != 0) {
        return parse_ret;
    }
    std::shared_ptr<NodeDef> node_def_proto = nullptr;
    auto ctx = GetCpuKernelContext(ext_info_msg, node_def, node_def_len, node_def_proto);
    KERNEL_CHECK_NULLPTR(
        ctx, static_cast<int32_t>(KERNEL_STATUS_INNER_ERROR), "Get cpu kernel context from buff failed.")
    if (UpdateTensor(io_addrs, *ext_info_msg, *ctx) != KERNEL_STATUS_OK) {
        return -1;
    }
    uint32_t status = DispatchCpuKernel(*ctx, *ext_info_msg);
    if (!ext_info_msg->async_flag) {
        if (status != KERNEL_STATUS_OK) {
            // Fault statuses are handed to the caller unchanged; anything else collapses to -1.
            const bool is_fault_status =
                (status == KERNEL_STATUS_SILENT_FAULT) || (status == KERNEL_STATUS_DETECT_FAULT) ||
                (status == KERNEL_STATUS_DETECT_FAULT_NORAS) || (status == KERNEL_STATUS_DETECT_LOW_BIT_FAULT) ||
                (status == KERNEL_STATUS_DETECT_LOW_BIT_FAULT_NORAS);
            return is_fault_status ? static_cast<int32_t>(status) : -1;
        }
        status = UpdateFWKOutputShape(*ext_info_msg, *ctx);
    }
    if (status == KERNEL_STATUS_END_OF_SEQUENCE) {
        return static_cast<int32_t>(status);
    }
    return (status == KERNEL_STATUS_OK) ? 0 : -1;
}
* run kernel with blockdim info.
*/
int32_t CpuKernelCache::RunCpuKernelWithBlock(void* param, struct BlkDimInfo* blkdim_info)
{
std::vector<uint64_t> io_addrs;
char* node_def = nullptr;
uint32_t node_def_len = 0;
std::shared_ptr<ExtInfoMsg> ext_info_msg = nullptr;
int32_t pre_ret = ParseRunKernelParam(param, io_addrs, node_def, node_def_len, ext_info_msg);
if (pre_ret != 0) {
return pre_ret;
}
std::shared_ptr<NodeDef> node_def_proto = nullptr;
auto ctx = GetCpuKernelContextWithBlock(ext_info_msg, node_def, node_def_len, node_def_proto, blkdim_info);
KERNEL_CHECK_NULLPTR(
ctx, static_cast<int32_t>(KERNEL_STATUS_INNER_ERROR), "Get cpu kernel context from buff failed.")
uint32_t ret = UpdateTensor(io_addrs, *ext_info_msg, *ctx);
if (ret != KERNEL_STATUS_OK) {
return -1;
}
ret = DispatchCpuKernel(*ctx, *ext_info_msg);
if (!ext_info_msg->async_flag) {
if (ret != KERNEL_STATUS_OK) {
return -1;
}
ret = UpdateFWKOutputShape(*ext_info_msg, *ctx);
}
if (ret != KERNEL_STATUS_OK) {
return -1;
}
return 0;
}
* get cpu kernel context from cache
*/
std::shared_ptr<CpuKernelContext> CpuKernelCache::GetCpuKernelContextWithBlock(
std::shared_ptr<ExtInfoMsg> ext_info_msg, const char* nodedef, uint32_t nodedef_len,
std::shared_ptr<NodeDef>& nodedef_proto, struct BlkDimInfo* blkdim_info)
{
std::shared_ptr<CpuKernelContext> ctx = nullptr;
const uint64_t kernel_id = ext_info_msg->kernel_id;
KERNEL_LOG_INFO("Get cpu kernel context with block info begin. kernel id[%lu].", kernel_id);
if (ext_info_msg->has_sess_info && blkdim_info->block_num == 1) {
auto cache = GetCache(kernel_id);
if (cache != nullptr) {
KERNEL_LOG_INFO("Get kernel from cache success.");
return cache->context;
}
}
std::string str_data(nodedef, nodedef_len);
nodedef_proto = CpuKernelUtils::CreateNodeDef();
KERNEL_CHECK_NULLPTR(
nodedef_proto, std::shared_ptr<CpuKernelContext>(nullptr), "Create node def with block info failed.")
if (!nodedef_proto->ParseFromString(str_data)) {
return std::shared_ptr<CpuKernelContext>(nullptr);
}
if (blkdim_info->block_num != 1U) {
auto block_num = CpuKernelUtils::CreateAttrValue();
block_num->SetInt(blkdim_info->block_num);
(void)nodedef_proto->AddAttrs("block_num", block_num.get());
auto blockid = CpuKernelUtils::CreateAttrValue();
blockid->SetInt(blkdim_info->block_id);
(void)nodedef_proto->AddAttrs("block_id", blockid.get());
KERNEL_LOG_INFO(
"AddAttrs block info , block_num[%u] block_id[%u].", blkdim_info->block_num, blkdim_info->block_id);
}
CpuKernelContext* tmp = new (std::nothrow) CpuKernelContext(DEVICE);
KERNEL_CHECK_NULLPTR(tmp, std::shared_ptr<CpuKernelContext>(nullptr), "Create context with block info failed.")
ctx = std::shared_ptr<CpuKernelContext>(tmp);
uint32_t ret = ctx->Init(nodedef_proto.get());
if (ret != KERNEL_STATUS_OK) {
return std::shared_ptr<CpuKernelContext>(nullptr);
}
if (ext_info_msg->has_sess_info) {
CpuCacheData* cache_ptr = new (std::nothrow) CpuCacheData(nodedef_proto, ctx);
KERNEL_CHECK_NULLPTR(cache_ptr, std::shared_ptr<CpuKernelContext>(nullptr), "Create cpu cache data failed.")
std::shared_ptr<CpuCacheData> cache_shared = std::shared_ptr<CpuCacheData>(cache_ptr);
SetCache(kernel_id, cache_shared);
KERNEL_LOG_INFO("Cache cpu kernel data success. kernel id[%lu]", kernel_id);
}
if (ext_info_msg->workspace_size > 0UL) {
uint64_t per_unit = ext_info_msg->workspace_size / blkdim_info->block_num;
uint64_t start_pos = per_unit * blkdim_info->block_id;
uint64_t block_workspace_size = blkdim_info->block_id < (blkdim_info->block_num - 1) ?
per_unit :
(ext_info_msg->workspace_size - (per_unit * (blkdim_info->block_num - 1)));
CpuKernelUtils::UpdateCustWorkSpaceInfo(
ctx.get(), block_workspace_size, ext_info_msg->workspace_addr + start_pos);
KERNEL_LOG_DEBUG(
"UpdateCustWorkSpaceInfo success, workspace size is [%lu], start_pos is [%lu].", block_workspace_size,
start_pos);
}
KERNEL_LOG_INFO("Get cpu kernel context success. kernel id[%lu].", kernel_id);
return ctx;
}
}