* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "ffts_update_kernel.h"
#include "acl/acl_rt.h"
#include "register/ffts_plus_task_update.h"
#include "register/ffts_node_calculater_registry.h"
#include "kernel/kernel_log.h"
#include "exe_graph/runtime/gert_tensor_data.h"
#include "register/op_tiling.h"
#include "graph/ge_error_codes.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/tensor_utils.h"
#include "register/kernel_registry_impl.h"
#include "common/checker.h"
#include "common/sgt_slice_type.h"
#include "engine/aicore/fe_rt2_common.h"
#include "engine/ffts_plus/converter/ffts_plus_proto_transfer.h"
#include "engine/ffts_plus/converter/ffts_plus_common.h"
namespace gert {
namespace kernel {
namespace {
/**
 * Materializes a node's FFTS+ task info in its host buffer from the pre-built task image.
 *
 * Inputs:
 *   0 - uint8_t* pointing at the pre-built TransTaskInfo image.
 *   1 - NodeMemPara* describing the node's host/device buffer (host_addr, dev_addr, size).
 * Output:
 *   0 - non-owning pointer to the rt_task_info embedded in the node's host buffer (no deleter is installed).
 *
 * Steps: copy the TransTaskInfo header plus one rtFftsPlusSqe_t, point fftsPlusSqe at the start of args,
 * then relocate the descriptor buffer to an offset whose *device* address has been aligned with
 * AddrAlignBy128, and record that new offset in offsets[kDescBuf].
 */
ge::graphStatus FftsTaskInfoPreProc(KernelContext *context) {
  auto task_data = context->GetInputValue<uint8_t*>(static_cast<size_t>(0));
  FE_ASSERT_NOTNULL(task_data);
  auto task_info_para = context->GetInputValue<NodeMemPara*>(static_cast<size_t>(1));
  FE_ASSERT_NOTNULL(task_info_para);
  const auto task_info_data = reinterpret_cast<const TransTaskInfo*>(task_data);
  auto task_info = reinterpret_cast<TransTaskInfo*>(task_info_para->host_addr);
  // Header = TransTaskInfo followed by the SQE that lives at the beginning of args.
  size_t head_len = sizeof(TransTaskInfo) + sizeof(rtFftsPlusSqe_t);
  // memcpy_s rejects a null destination or one smaller than head_len, so on success
  // task_info_para->size > sizeof(TransTaskInfo) and the subtractions below cannot underflow.
  if (memcpy_s(task_info, task_info_para->size, task_info_data, head_len) != EOK) {
    KLOGE("Copy task info data[%zu] to dst[%zu] failed.", head_len, task_info_para->size);
    return ge::GRAPH_FAILED;
  }
  auto *const ffts_plus_sqe = ge::PtrToPtr<uint8_t, rtFftsPlusSqe_t>(task_info->args);
  task_info->rt_task_info.fftsPlusSqe = ffts_plus_sqe;
  size_t buf_offset = task_info_data->offsets[static_cast<size_t>(InfoStType::kDescBuf)];
  auto dev_task_info = reinterpret_cast<TransTaskInfo*>(task_info_para->dev_addr);
  // %p with void*: passing pointers to %lx is a format/argument mismatch.
  GELOGD("Base add host[%p], dev[%p].", static_cast<void *>(task_info_para->host_addr),
         static_cast<void *>(task_info_para->dev_addr));
  GE_ASSERT_TRUE(buf_offset < (task_info_para->size - sizeof(TransTaskInfo)));
  // Alignment is computed on the device address because that is where the descriptor buffer is consumed;
  // the same relative offset is then used inside the host copy.
  uintptr_t dev_buf_base = reinterpret_cast<uintptr_t>(&dev_task_info->args[buf_offset]);
  uintptr_t align_base = AddrAlignBy128(dev_buf_base);
  size_t new_offset = reinterpret_cast<uint8_t*>(align_base) - dev_task_info->args;
  GE_ASSERT_TRUE(new_offset < (task_info_para->size - sizeof(TransTaskInfo)));
  task_info->rt_task_info.descBuf = &task_info->args[new_offset];
  // Source descriptor bytes stay at their original (unaligned) offset in the pre-built image.
  const void* pre_data_base = &task_info_data->args[buf_offset];
  size_t buf_len = task_info->rt_task_info.descBufLen;
  size_t left_len = task_info_para->size - sizeof(TransTaskInfo) - new_offset;
  GELOGD("Mem base:%lx, align_base:%lx, offset:%zu, new offset:%zu, buf_len[%zu], left len[%zu].",
         dev_buf_base, align_base, buf_offset, new_offset, buf_len, left_len);
  // left_len bounds the copy so an oversized descBufLen fails instead of overrunning the node buffer.
  if (memcpy_s(const_cast<void*>(task_info->rt_task_info.descBuf), left_len, pre_data_base, buf_len) != EOK) {
    KLOGE("Failed to copy op_desc buffer.");
    return ge::GRAPH_FAILED;
  }
  task_info->offsets[static_cast<size_t>(InfoStType::kDescBuf)] = new_offset;
  auto out_ret = context->GetOutput(0U);
  FE_ASSERT_NOTNULL(out_ret);
  out_ret->Set(&task_info->rt_task_info, nullptr);
  return ge::GRAPH_SUCCESS;
}
REGISTER_KERNEL(FftsTaskInfoPreProc).RunFunc(FftsTaskInfoPreProc);
/**
 * Binds a node to its slice of the shared host/device buffers and arms a memory guard right after the slice.
 *
 * Inputs (MemParaInKey):
 *   PRE_PARA  - MemPrePara with the slice offset/size and the length of optional pre-filled data.
 *   PRE_DATA  - optional bytes copied into the start of the node's host slice (may be null).
 *   DEV_ADDR / HOST_ADDR - base device/host buffers the slice is carved from.
 * Outputs (MemParaOutKey):
 *   NODE_PARA - the node's host_addr/dev_addr (base + offset) and size.
 *   MEM_GUARD - guard records later verified by CheckMemGuard.
 *
 * The guard vector behaves as a bounded list: when it is full, recording restarts at index 0 and the size is
 * reset to 1, so only guards armed since the wrap remain visible to CheckMemGuard.
 */
ge::graphStatus NodeMemParaAssign(KernelContext *context) {
  auto node_para = context->GetOutputPointer<NodeMemPara>(static_cast<size_t>(MemParaOutKey::NODE_PARA));
  FE_ASSERT_NOTNULL(node_para);
  auto mem_guard = context->GetOutputPointer<gert::ContinuousVector>(static_cast<size_t>(MemParaOutKey::MEM_GUARD));
  FE_ASSERT_NOTNULL(mem_guard);
  auto pre_para = context->GetInputPointer<MemPrePara>(static_cast<size_t>(MemParaInKey::PRE_PARA));
  FE_ASSERT_NOTNULL(pre_para);
  auto pre_data = context->GetInputValue<uint8_t*>(static_cast<size_t>(MemParaInKey::PRE_DATA));
  auto dev_addr = context->GetInputValue<gert::GertTensorData *>(static_cast<size_t>(MemParaInKey::DEV_ADDR));
  FE_ASSERT_NOTNULL(dev_addr);
  auto host_addr = context->GetInputValue<gert::GertTensorData *>(static_cast<size_t>(MemParaInKey::HOST_ADDR));
  FE_ASSERT_NOTNULL(host_addr);
  void* dev_addr_base = dev_addr->GetAddr();
  void* host_addr_base = host_addr->GetAddr();
  // The guard sentinel below is written through the host base, so a null base would crash the write.
  FE_ASSERT_NOTNULL(host_addr_base);
  node_para->dev_addr = static_cast<uint8_t*>(dev_addr_base) + pre_para->offset;
  node_para->host_addr = static_cast<uint8_t*>(host_addr_base) + pre_para->offset;
  node_para->size = pre_para->size;
  auto guard_vec = reinterpret_cast<MemGuard*>(mem_guard->MutableData());
  size_t cur_idx = mem_guard->GetSize();
  // Full vector: restart at index 0 (SetSize(1) below drops the older entries from the checked range).
  cur_idx = (cur_idx == mem_guard->GetCapacity()) ? 0 : cur_idx;
  if (mem_guard->SetSize(cur_idx + 1) != ge::GRAPH_SUCCESS) {
    KLOGE("Resize mem guard size failed.");
    return ge::GRAPH_FAILED;
  }
  // The sentinel sits immediately past the node's slice.
  // NOTE(review): assumes the host allocation reserves room for it after offset + size -- TODO confirm.
  guard_vec[cur_idx].guard_ptr = static_cast<uint8_t*>(host_addr_base) + pre_para->offset + pre_para->size;
  guard_vec[cur_idx].guard_val = rand();
  *reinterpret_cast<int64_t*>(guard_vec[cur_idx].guard_ptr) = guard_vec[cur_idx].guard_val;
  // %p with void*: passing pointers to %lx is a format/argument mismatch.
  GELOGD("Base addr:%p, node addr base:%p, size:%zu, offset:%zu, guard[%ld].", host_addr_base,
         static_cast<void *>(node_para->host_addr), node_para->size, pre_para->offset, guard_vec[cur_idx].guard_val);
  if (pre_data && (pre_para->pre_size > 0)) {
    // Bounded by the slice size, so pre-filled data cannot overwrite the guard.
    if (memcpy_s(node_para->host_addr, node_para->size, pre_data, pre_para->pre_size) != EOK) {
      KLOGE("Failed to copy pre data.");
      return ge::GRAPH_FAILED;
    }
  }
  return ge::GRAPH_SUCCESS;
}
/**
 * Outputs creator for NodeMemParaAssign: allocates the NodeMemPara stored in output 0.
 *
 * @param node    unused.
 * @param context kernel context whose output 0 receives the new NodeMemPara via SetWithDefaultDeleter.
 * @return GRAPH_SUCCESS, or a failure status when output 0 is missing or the allocation fails.
 */
ge::graphStatus CreateNodeMemPara(const ge::FastNode *node, KernelContext *context) {
  (void)node;
  auto mem_para = context->GetOutput(0);
  FE_ASSERT_NOTNULL(mem_para);
  // new (std::nothrow) yields nullptr on failure; report it here rather than storing a null payload
  // that would only be detected later by NodeMemParaAssign.
  auto *node_mem_para = new (std::nothrow) NodeMemPara();
  FE_ASSERT_NOTNULL(node_mem_para);
  mem_para->SetWithDefaultDeleter(node_mem_para);
  return ge::GRAPH_SUCCESS;
}
REGISTER_KERNEL(NodeMemParaAssign).RunFunc(NodeMemParaAssign).OutputsCreator(CreateNodeMemPara);
/**
 * Transfers the host-side task/args image to device memory with an asynchronous ACL copy.
 *
 * Inputs (H2DInKey):
 *   LAUNCH_FLAG - when 0, nothing is copied and the kernel succeeds immediately.
 *   STREAM_ID   - stream the copy is enqueued on.
 *   DEV_ADDR / HOST_ADDR - destination and source buffers; the whole host buffer is copied.
 * @return GRAPH_SUCCESS, or the failure produced by GE_CHK_RT_RET when aclrtMemcpyAsync fails.
 */
ge::graphStatus FFTSTaskAndArgsCopy(KernelContext *context) {
  auto need_launch = context->GetInputValue<uint32_t>(static_cast<size_t>(H2DInKey::LAUNCH_FLAG));
  if (need_launch == 0U) {
    GELOGD("No need to transfer arguments to the device.");
    return ge::GRAPH_SUCCESS;
  }
  auto stream = context->GetInputValue<rtStream_t>(static_cast<size_t>(H2DInKey::STREAM_ID));
  auto dev_addr = context->GetInputValue<gert::GertTensorData *>(static_cast<size_t>(H2DInKey::DEV_ADDR));
  FE_ASSERT_NOTNULL(dev_addr);
  auto host_addr = context->GetInputValue<gert::GertTensorData *>(static_cast<size_t>(H2DInKey::HOST_ADDR));
  FE_ASSERT_NOTNULL(host_addr);
  void* dev_addr_base = dev_addr->GetAddr();
  void* host_addr_base = host_addr->GetAddr();
  // %p: the bases are void*; passing them to %lx is a format/argument mismatch.
  GELOGD("H2D{%p}{%p}{%zu}{%zu}.", dev_addr_base, host_addr_base, dev_addr->GetSize(), host_addr->GetSize());
  // destMax is the device buffer size and count is the host buffer size.
  GE_CHK_RT_RET(aclrtMemcpyAsync(dev_addr_base, dev_addr->GetSize(), host_addr_base,
                                 host_addr->GetSize(), ACL_MEMCPY_HOST_TO_BUF_TO_DEVICE, stream));
  return ge::GRAPH_SUCCESS;
}
std::vector<std::string> CheckMemGuard(const KernelContext *context) {
std::stringstream ss;
std::vector<std::string> msgs;
auto mem_guard = context->GetInputPointer<gert::ContinuousVector>(static_cast<size_t>(H2DInKey::MEM_GUARD));
if (mem_guard == nullptr) {
GELOGW("No memory guard set up.");
return msgs;
}
auto guard_vec = reinterpret_cast<const MemGuard*>(mem_guard->GetData());
for (size_t i = 0; i < mem_guard->GetSize(); ++i) {
auto guard = guard_vec[i];
auto cur_val = *static_cast<int64_t*>(guard.guard_ptr);
if (cur_val != guard.guard_val) {
GELOGW("Mem guard[%zu] value[%ld] does not match the actual value[%ld].", i, guard.guard_val, cur_val);
ss << "FFTS memory may has over write with block index: " << i;
msgs.emplace_back(ss.str());
return msgs;
}
}
ss << "FFTS memory check ok.";
msgs.emplace_back(ss.str());
return msgs;
}
REGISTER_KERNEL(FFTSTaskAndArgsCopy).RunFunc(FFTSTaskAndArgsCopy).TracePrinter(CheckMemGuard);
}
}
}