* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "single_op/task/rts_kernel_task_builder.h"
#include "graph/def_types.h"
#include "single_op/task/build_task_utils.h"
#include "common/plugin/ge_make_unique_util.h"
#include "common/aclrt_malloc_helper.h"
namespace ge {
namespace {
constexpr size_t kNumAddresses = 2UL;
}
void RtsKernelTaskBuilder::UpdateCopyKind(const OpDescPtr &op_desc, rtMemcpyKind_t &kind) {
std::vector<int64_t> v_output_memory_type;
std::vector<int64_t> v_input_memory_type;
const bool has_output_mem_type_attr =
AttrUtils::GetListInt(op_desc, ATTR_NAME_OUTPUT_MEM_TYPE_LIST, v_output_memory_type);
const bool has_input_mem_type_attr =
AttrUtils::GetListInt(op_desc, ATTR_NAME_INPUT_MEM_TYPE_LIST, v_input_memory_type);
const bool is_output_svm =
(v_output_memory_type.size() == 1U) && (v_output_memory_type[0U] == static_cast<int64_t>(RT_MEMORY_HOST_SVM));
const bool is_input_svm =
(v_input_memory_type.size() == 1U) && (v_input_memory_type[0U] == static_cast<int64_t>(RT_MEMORY_HOST_SVM));
if (is_output_svm && (!has_input_mem_type_attr)) {
kind = RT_MEMCPY_DEVICE_TO_HOST;
} else if (is_input_svm && (!has_output_mem_type_attr)) {
kind = RT_MEMCPY_HOST_TO_DEVICE;
} else {
}
}
Status RtsKernelTaskBuilder::BuildMemcpyAsyncTask(const GetOpDescFunc &get_op_desc_func,
const SingleOpModelParam ¶m, const domi::TaskDef &task_def,
OpTask *&op_task, StreamResource &resource) {
(void)resource;
const domi::MemcpyAsyncDef &kernel_def = task_def.memcpy_async();
auto task = MakeUnique<MemcpyAsyncTask>();
GE_CHECK_NOTNULL(task);
OpDescPtr op_desc = nullptr;
GE_CHK_STATUS_RET(get_op_desc_func(kernel_def.op_index(), op_desc), "get op desc failed");
task->SetOpDesc(op_desc);
task->dst_max_ = kernel_def.dst_max();
task->count_ = kernel_def.count();
task->kind_ = static_cast<rtMemcpyKind_t>(kernel_def.kind());
UpdateCopyKind(op_desc, task->kind_);
const auto addresses = BuildTaskUtils::JoinAddresses(BuildTaskUtils::GetAddresses(op_desc, param, false));
if (addresses.size() != kNumAddresses) {
GELOGE(INTERNAL_ERROR, "[Build][MemcpyAsyncTask] Invalid address count: %zu", addresses.size());
return INTERNAL_ERROR;
}
task->addresses_[0U] = static_cast<uintptr_t>(PtrToValue(addresses[0UL]));
task->addresses_[1U] = static_cast<uintptr_t>(PtrToValue(addresses[1UL]));
op_task = task.release();
return SUCCESS;
}
Status RtsKernelTaskBuilder::BuildNpuGetFloatStatusTask(const GetOpDescFunc &get_op_desc_func,
const SingleOpModelParam ¶m, const domi::TaskDef &task_def,
OpTask *&op_task, StreamResource &resource) {
(void)resource;
const domi::NpuGetFloatStatusDef &kernel_def = task_def.npu_get_float_status();
auto task = MakeUnique<NpuGetFloatStatusTask>();
GE_CHECK_NOTNULL(task);
task->mode_ = kernel_def.mode();
task->output_size_ = kernel_def.output_size();
task->args_size_ = sizeof(uint8_t *);
GE_CHK_RT_RET(ge::AclrtMalloc(&task->args_, task->args_size_, RT_MEMORY_HBM, GE_MODULE_NAME_U16));
GE_CHK_STATUS_RET(ModelUtils::GetRtAddress(param.runtime_param, static_cast<uintptr_t>(kernel_def.output_addr()),
task->output_addr_));
OpDescPtr op_desc = nullptr;
GE_CHK_STATUS_RET(get_op_desc_func(kernel_def.op_index(), op_desc), "get op desc failed");
task->SetOpDesc(op_desc);
op_task = task.release();
return SUCCESS;
}
Status RtsKernelTaskBuilder::BuildNpuClearFloatStatusTask(const GetOpDescFunc &get_op_desc_func,
const domi::TaskDef &task_def, OpTask *&op_task,
StreamResource &resource) {
(void)resource;
const domi::NpuClearFloatStatusDef &kernel_def = task_def.npu_clear_float_status();
std::unique_ptr<NpuClearFloatStatusTask> task = MakeUnique<NpuClearFloatStatusTask>();
GE_CHECK_NOTNULL(task);
task->mode_ = kernel_def.mode();
OpDescPtr op_desc = nullptr;
GE_CHK_STATUS_RET(get_op_desc_func(kernel_def.op_index(), op_desc), "get op desc failed");
task->SetOpDesc(op_desc);
op_task = task.release();
return SUCCESS;
}
Status RtsKernelTaskBuilder::BuildNpuGetFloatDebugStatusTask(const GetOpDescFunc &get_op_desc_func,
const SingleOpModelParam ¶m,
const domi::TaskDef &task_def, OpTask *&op_task,
StreamResource &resource) {
(void)resource;
const domi::NpuGetFloatDebugStatusDef &kernel_def = task_def.npu_get_float_debug_status();
auto task = MakeUnique<NpuGetFloatDebugStatusTask>();
GE_CHECK_NOTNULL(task);
task->mode_ = kernel_def.mode();
task->output_size_ = kernel_def.output_size();
task->args_size_ = sizeof(uint8_t *);
GE_CHK_RT_RET(ge::AclrtMalloc(&task->args_, task->args_size_, RT_MEMORY_HBM, GE_MODULE_NAME_U16));
GE_CHK_STATUS_RET(ModelUtils::GetRtAddress(param.runtime_param, static_cast<uintptr_t>(kernel_def.output_addr()),
task->output_addr_));
OpDescPtr op_desc = nullptr;
GE_CHK_STATUS_RET(get_op_desc_func(kernel_def.op_index(), op_desc), "get op desc failed");
task->SetOpDesc(op_desc);
op_task = task.release();
return SUCCESS;
}
Status RtsKernelTaskBuilder::BuildNpuClearFloatDebugStatusTask(const GetOpDescFunc &get_op_desc_func,
const domi::TaskDef &task_def, OpTask *&op_task,
StreamResource &resource) {
(void)resource;
const domi::NpuClearFloatDebugStatusDef &kernel_def = task_def.npu_clear_float_debug_status();
std::unique_ptr<NpuClearFloatDebugStatusTask> task = MakeUnique<NpuClearFloatDebugStatusTask>();
GE_CHECK_NOTNULL(task);
task->mode_ = kernel_def.mode();
OpDescPtr op_desc = nullptr;
GE_CHK_STATUS_RET(get_op_desc_func(kernel_def.op_index(), op_desc), "get op desc failed");
task->SetOpDesc(op_desc);
op_task = task.release();
return SUCCESS;
}
}