* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <sstream>
#include "graph/ge_error_codes.h"
#include "runtime/mem.h"
#include "register/kernel_registry_impl.h"
#include "common/checker.h"
#include "exe_graph/runtime/gert_tensor_data.h"
#include "common/runtime_api_wrapper.h"
#include "graph/utils/tensor_utils.h"
#include "exe_graph/runtime/storage_shape.h"
#include "graph/node.h"
#include "acl/acl_rt.h"
namespace gert {
namespace kernel {
// Run function of the InitCmoArgs kernel.
// It performs no work at run time: the context is deliberately ignored and
// GRAPH_SUCCESS is always returned.
ge::graphStatus InitCmoArgs(KernelContext *context) {
  static_cast<void>(context);  // unused on purpose
  return ge::GRAPH_SUCCESS;
}
// Outputs creator for the InitCmoArgs kernel.
// Reads the CMO op code from input 0, allocates a value-initialized
// rtCmoTaskInfo_t, stores the op code in it (narrowed to 16 bits) and places
// the object into output 0.
//
// node:    unused.
// context: kernel context providing input 0 (op code) and output 0.
// Returns GRAPH_SUCCESS on success; the GE_ASSERT_* macros bail out when
// output 0 is missing or the allocation fails.
ge::graphStatus CreateCmoLaunchArgs(const ge::FastNode *node, KernelContext *context) {
  static_cast<void>(node);

  const auto raw_op_code = context->GetInputValue<uint32_t>(0UL);

  auto output_chain = context->GetOutput(0UL);
  GE_ASSERT_NOTNULL(output_chain);

  // Non-throwing allocation: a failure surfaces as nullptr and is caught below.
  auto *task_info = new (std::nothrow) rtCmoTaskInfo_t();
  GE_ASSERT_NOTNULL(task_info);
  task_info->opCode = static_cast<uint16_t>(raw_op_code);

  // NOTE(review): presumably the output takes over deletion of task_info via
  // its default deleter - confirm against the Chain API.
  output_chain->SetWithDefaultDeleter(task_info);
  return ge::GRAPH_SUCCESS;
}
// Registers the InitCmoArgs kernel: InitCmoArgs is its run function and
// CreateCmoLaunchArgs is its outputs creator.
REGISTER_KERNEL(InitCmoArgs).RunFunc(InitCmoArgs).OutputsCreator(CreateCmoLaunchArgs);
// Fills the prefetch length and source address of a CMO task.
//
// Inputs read from the context:
//   0: rtCmoTaskInfo_t (mutable) - task info to update
//   1: GertTensorData            - supplies the tensor base address
//   2: StorageShape              - storage shape used to compute the element count
//   3: ge::DataType              - element data type
//   4: uint32_t                  - upper bound for the prefetch length
//   5: int64_t                   - byte offset into the tensor
//
// The tensor byte size is derived from the storage shape and data type. The
// offset must lie in [0, tensor byte size). The prefetch length is the number
// of bytes remaining after the offset, capped at the upper bound from input 4,
// and the source address is the tensor address plus the offset.
//
// Returns GRAPH_SUCCESS on success; the GE_ASSERT_* macros bail out when an
// input pointer is null, the tensor size is not positive, or the offset is
// out of range.
ge::graphStatus UpdatePrefetchTaskInfo(KernelContext *const context) {
  auto cmo_args = context->MutableInputPointer<rtCmoTaskInfo_t>(0UL);
  GE_ASSERT_NOTNULL(cmo_args);
  auto tensor_data = context->GetInputPointer<GertTensorData>(1UL);
  GE_ASSERT_NOTNULL(tensor_data);
  auto storage_shape = context->GetInputPointer<StorageShape>(2UL);
  GE_ASSERT_NOTNULL(storage_shape);
  const auto shape_size = storage_shape->GetStorageShape().GetShapeSize();
  const auto dtype = context->GetInputValue<ge::DataType>(3UL);
  const auto max_len = context->GetInputValue<uint32_t>(4UL);
  const int64_t *offset = context->GetInputPointer<int64_t>(5UL);
  GE_ASSERT_NOTNULL(offset);

  const int64_t tensor_len = ge::GetSizeInBytes(shape_size, dtype);
  GE_ASSERT_TRUE(tensor_len > 0);
  const bool is_offset_valid = (*offset >= 0) && (*offset < tensor_len);
  GE_ASSERT_TRUE(is_offset_valid, "The offset [%ld] should be within the range of [0, %ld).", *offset, tensor_len);

  // The offset check above guarantees remaining_len > 0.
  const int64_t remaining_len = tensor_len - (*offset);
  // Clamp in 64-bit arithmetic BEFORE narrowing. Casting the remaining length
  // to uint32_t first would wrap for tensors larger than 4 GiB and yield a
  // length far smaller than max_len.
  const int64_t prefetch_len = std::min(remaining_len, static_cast<int64_t>(max_len));
  cmo_args->lengthInner = static_cast<uint32_t>(prefetch_len);
  cmo_args->sourceAddr = ge::PtrToValue(tensor_data->GetAddr()) + static_cast<uint64_t>(*offset);
  return ge::GRAPH_SUCCESS;
}
// Registers the UpdatePrefetchTaskInfo kernel with UpdatePrefetchTaskInfo as its run function.
REGISTER_KERNEL(UpdatePrefetchTaskInfo).RunFunc(UpdatePrefetchTaskInfo);
// Launches a CMO task.
// Input 0 is the rtCmoTaskInfo_t describing the task and input 1 is the
// stream it is launched on. The task is submitted through
// ge::rtCmoTaskLaunch with a flag value of 0.
//
// Returns GRAPH_SUCCESS on success; the GE_ASSERT_* macros bail out when the
// task info is null or the launch call does not report success.
ge::graphStatus LaunchCmoTask(KernelContext *const context) {
  const auto *const task_info = context->GetInputPointer<rtCmoTaskInfo_t>(0UL);
  GE_ASSERT_NOTNULL(task_info);

  const auto launch_stream = context->GetInputValue<aclrtStream>(1UL);
  GE_ASSERT_RT_OK(ge::rtCmoTaskLaunch(task_info, launch_stream, 0U));

  return ge::GRAPH_SUCCESS;
}
// Trace printer for the LaunchCmoTask kernel.
// Produces a single trace line containing the task's inner length, its
// source address and the stream taken from input 1.
//
// context: kernel context; input 0 is the rtCmoTaskInfo_t, input 1 the stream.
// Returns a one-element vector holding the formatted line; GE_ASSERT_NOTNULL
// bails out early when the task info is null.
static std::vector<std::string> LaunchCmoTracer(const KernelContext *context) {
  const auto *const task_info = context->GetInputPointer<rtCmoTaskInfo_t>(0UL);
  GE_ASSERT_NOTNULL(task_info);
  const auto launch_stream = context->GetInputValue<aclrtStream>(1UL);

  std::stringstream trace;
  trace << "Launch cmo task with inner_len:" << task_info->lengthInner;
  trace << ", src:" << task_info->sourceAddr;
  trace << " on stream:" << launch_stream;
  return {trace.str()};
}
// Registers the LaunchCmoTask kernel: LaunchCmoTask is its run function and
// LaunchCmoTracer is its trace printer.
REGISTER_KERNEL(LaunchCmoTask).RunFunc(LaunchCmoTask).TracePrinter(LaunchCmoTracer);
}
}