* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <cstddef>
#include <iomanip>
#include <cinttypes>
#include "graph/ge_error_codes.h"
#include "graph/def_types.h"
#include "register/kernel_registry.h"
#include "framework/common/debug/log.h"
#include "exe_graph/runtime/tensor.h"
#include "kernel/memory/mem_block.h"
#include "runtime/mem.h"
#include "common/checker.h"
#include "runtime/kernel.h"
#include "proto/task.pb.h"
#include "common/plugin/ge_make_unique_util.h"
#include "core/debug/kernel_tracing.h"
#include "exe_graph/runtime/gert_mem_allocator.h"
#include "kernel/kernel_log.h"
#include "kernel/memory/caching_mem_allocator.h"
#include "core/executor/multi_thread_topological/executor/schedule/producer/producers/kernel_tags/critical_section_config.h"
#include "core/utils/tensor_utils.h"
#include "exe_graph/runtime/gert_tensor_data.h"
#include "exe_graph/lowering/shape_utils.h"
using namespace ge;
namespace gert {
namespace kernel {
const int32_t kInputNum = 2;
const int32_t kDimNum = 2;
enum class BuildAddrFromTensorComputeInputs {
kTensor,
kAllocator,
};
enum class SplitToSequenceArgIndex {
KAxis,
KInputXShape,
KInputXAddr,
KSplitShapes,
KInnerSplitAddrs
};
enum class GetSplitVecArgIndex {
KAxis,
KInputNum,
KInputXShape,
KSplitShapes,
KSplitAddr,
KInnerSplitAddrs
};
struct CopyTensorArg {
void* src;
void* dst;
int32_t type_size;
};
const std::set<ge::DataType> kSupportedDataType = {
DT_UINT8,
DT_UINT16,
DT_UINT32,
DT_UINT64,
DT_INT8,
DT_INT16,
DT_INT32,
DT_INT64,
DT_FLOAT16,
DT_FLOAT,
DT_DOUBLE,
DT_BOOL,
DT_COMPLEX64,
DT_COMPLEX128
};
ge::graphStatus AllocSpecifiedMem(KernelContext *context) {
auto gert_allocator = context->GetInputValue<GertAllocator *>(0);
GE_ASSERT_NOTNULL(gert_allocator);
auto shape_sizes = context->GetInputPointer<TypedContinuousVector<uint64_t>>(1);
GE_ASSERT_NOTNULL(shape_sizes);
auto tensor_num = shape_sizes->GetSize();
auto tensor_data_addrs = ContinuousVector::Create<TensorData*>(tensor_num);
auto tensor_data_vec = reinterpret_cast<ContinuousVector*>(tensor_data_addrs.get());
GE_ASSERT_NOTNULL(tensor_data_vec);
tensor_data_vec->SetSize(tensor_num);
auto shape_sizes_data = shape_sizes->GetData();
auto tensor_data_info = static_cast<TensorData**>(tensor_data_vec->MutableData());
for (size_t i = 0; i < tensor_num; ++i) {
auto mem_block = reinterpret_cast<memory::MultiStreamMemBlock *>(gert_allocator->Malloc(shape_sizes_data[i]));
KERNEL_CHECK_NOTNULL(mem_block);
KERNEL_CHECK(mem_block->GetAddr() != nullptr, "malloc failed, tensor size=%zu", shape_sizes_data[i]);
tensor_data_info[i] = new (std::nothrow) TensorData();
GE_ASSERT_NOTNULL(tensor_data_info[i]);
*(tensor_data_info[i]) =
TensorUtils::ToTensorData(mem_block->GetMemBlock(), mem_block->GetSize(), gert_allocator->GetPlacement());
KERNEL_TRACE(TRACE_STR_ALLOC_MEM", index %zu, tensor data address %p", 0, mem_block, mem_block->GetAddr(),
shape_sizes_data[i], i, tensor_data_info[i]->GetAddr());
}
auto tensor_data_av = context->GetOutput(0);
GE_ASSERT_NOTNULL(tensor_data_av);
tensor_data_av->SetWithDefaultDeleter<uint8_t[]>(tensor_data_addrs.release());
return ge::GRAPH_SUCCESS;
}
REGISTER_KERNEL(AllocSpecifiedMem).RunFunc(AllocSpecifiedMem).ConcurrentCriticalSectionKey(kKernelUseMemory);
ge::graphStatus FreeSpecifiedMem(KernelContext *context) {
auto addrs = context->MutableInputPointer<TypedContinuousVector<TensorData*>>(0);
GE_ASSERT_NOTNULL(addrs);
for (size_t i = 0; i < addrs->GetSize(); ++i) {
if (addrs->MutableData()[i] != nullptr) {
KERNEL_TRACE("[MEM]Free memory, index %zu, address %p", i, addrs->MutableData()[i]->GetAddr());
delete addrs->MutableData()[i];
addrs->MutableData()[i] = nullptr;
}
}
return ge::GRAPH_SUCCESS;
}
REGISTER_KERNEL(FreeSpecifiedMem).RunFunc(FreeSpecifiedMem).ConcurrentCriticalSectionKey(kKernelUseMemory);
ge::graphStatus CalcTensorSizeFromSpecifiedShape(KernelContext *context) {
auto datatype = context->GetInputValue<ge::DataType>(0);
auto shapes = context->GetInputPointer<TypedContinuousVector<Shape>>(1);
auto addrs_av = context->GetOutput(0);
GE_ASSERT_NOTNULL(shapes);
GE_ASSERT_NOTNULL(addrs_av);
auto shape_num = shapes->GetSize();
auto shape_data = shapes->GetData();
auto shape_sizes_addrs = ContinuousVector::Create<uint64_t>(shape_num);
GE_ASSERT_NOTNULL(shape_sizes_addrs);
auto shapes_addrs_vec = reinterpret_cast<ContinuousVector*>(shape_sizes_addrs.get());
shapes_addrs_vec->SetSize(shape_num);
auto shape_addrs_data = static_cast<uint64_t*>(shapes_addrs_vec->MutableData());
for (size_t i = 0; i < shape_num; i++) {
GE_ASSERT_GRAPH_SUCCESS(CalcAlignedSizeByShape(shape_data[i], datatype, shape_addrs_data[i]));
}
addrs_av->SetWithDefaultDeleter<uint8_t[]>(shape_sizes_addrs.release());
return ge::GRAPH_SUCCESS;
}
REGISTER_KERNEL(CalcTensorSizeFromSpecifiedShape).RunFunc(CalcTensorSizeFromSpecifiedShape);
template <typename T>
ge::graphStatus BuildSplitVecWithInput(const gert::StorageShape &split_storeage_shape,
const gert::GertTensorData *split_addr, std::vector<int64_t> &split_vec, const int32_t input_x_axis_dim) {
gert::Shape split_shape = split_storeage_shape.GetStorageShape();
auto split_rank = split_shape.GetDimNum();
GE_ASSERT_TRUE((split_rank < kDimNum), "split must be scalar or 1D, but given %u", split_rank);
uint32_t split_num = 0;
if (split_rank == 0) {
T split_scalar;
GE_ASSERT_RT_OK(rtMemcpy(&split_scalar, sizeof(T), split_addr->GetAddr(), sizeof(T),
RT_MEMCPY_DEVICE_TO_HOST));
GE_ASSERT_TRUE((split_scalar > 0), "split scalar must be greater than 0,%d", split_scalar);
auto num_even_split = input_x_axis_dim / split_scalar;
auto num_remain = input_x_axis_dim % split_scalar;
split_num = num_even_split;
if (num_remain != 0) {
split_num += 1;
}
split_vec.resize(split_num);
std::fill(split_vec.begin(), split_vec.begin() + num_even_split, split_scalar);
std::fill(split_vec.begin() + num_even_split, split_vec.end(), num_remain);
} else if (split_rank == 1) {
split_num = split_shape.GetDim(0);
split_vec.resize(split_num);
auto temp_array = MakeUnique<T[]>(split_num);
GE_ASSERT_RT_OK(rtMemcpy(temp_array.get(), sizeof(T) * split_num,
split_addr->GetAddr(), sizeof(T) * split_num, RT_MEMCPY_DEVICE_TO_HOST));
auto split_sum = 0;
for (size_t i = 0; i < split_num; i++) {
GE_ASSERT_TRUE((temp_array[i] > 0), "split must be greater than 0,but get:%ld", temp_array[i]);
split_vec[i] = temp_array[i];
split_sum += split_vec[i];
}
GE_ASSERT_EQ(split_sum, input_x_axis_dim);
}
return ge::GRAPH_SUCCESS;
}
ge::graphStatus GetSplitVec(gert::KernelContext *context) {
auto axis = context->GetInputValue<int32_t>(static_cast<size_t>(GetSplitVecArgIndex::KAxis));
auto input_num = context->GetInputValue<uint32_t>(static_cast<size_t>(GetSplitVecArgIndex::KInputNum));
auto input_storage_shape = context->GetInputPointer<gert::StorageShape>(
static_cast<size_t>(GetSplitVecArgIndex::KInputXShape));
GE_ASSERT_NOTNULL(input_storage_shape);
gert::Shape input_x_shape = input_storage_shape->GetStorageShape();
auto input_tensor_rank = static_cast<int32_t>(input_x_shape.GetDimNum());
GE_ASSERT_TRUE((axis >= -input_tensor_rank) && (axis <= (input_tensor_rank - 1)),
"axis %d, is out of range:[-%d,%d].", axis, input_tensor_rank, input_tensor_rank - 1);
if (axis < 0) {
axis += input_tensor_rank;
}
const int64_t input_x_axis_dim = input_x_shape.GetDim(axis);
std::vector<int64_t> split_sizes;
auto extend_ctx = reinterpret_cast<gert::ExtendedKernelContext *>(context);
if (input_num == kInputNum) {
auto split_storage_shape = *context->GetInputPointer<gert::StorageShape>(
static_cast<size_t>(GetSplitVecArgIndex::KSplitShapes));
auto split_addr =
context->GetInputPointer<gert::GertTensorData>(static_cast<size_t>(GetSplitVecArgIndex::KSplitAddr));
auto split_desc = extend_ctx->GetInputDesc(1);
GE_ASSERT_NOTNULL(split_desc);
ge::DataType split_data_type = split_desc->GetDataType();
GE_ASSERT_TRUE((split_data_type == DT_INT32) || (split_data_type == DT_INT64),
"only support int32_t or int64_t for split, %u", split_data_type);
switch (split_data_type) {
case DT_INT32:
GE_ASSERT_GRAPH_SUCCESS(BuildSplitVecWithInput<int32_t>(split_storage_shape, split_addr,
split_sizes, input_x_axis_dim), "Build split vec err");
break;
case DT_INT64:
GE_ASSERT_GRAPH_SUCCESS(BuildSplitVecWithInput<int64_t>(split_storage_shape, split_addr,
split_sizes, input_x_axis_dim), "Build split vec err");
break;
default:
return GRAPH_FAILED;
}
} else {
split_sizes.resize(input_x_axis_dim);
std::fill(split_sizes.begin(), split_sizes.begin() + input_x_axis_dim, 1);
}
auto output_num = split_sizes.size();
std::vector<gert::Shape> shape_after_split(output_num, input_x_shape);
for (size_t i = 0; i < output_num; i++) {
shape_after_split[i].SetDim(axis, split_sizes[i]);
}
auto addrs_av = context->GetOutput(0);
GE_ASSERT_NOTNULL(addrs_av);
auto addrs = gert::ContinuousVector::Create<gert::Shape>(output_num);
auto addrs_vec = reinterpret_cast<gert::ContinuousVector*>(addrs.get());
addrs_vec->SetSize(output_num);
auto shapes_addrs = reinterpret_cast<gert::Shape*>(addrs_vec->MutableData());
for (size_t i = 0; i < output_num; i++) {
shapes_addrs[i] = shape_after_split[i];
}
addrs_av->SetWithDefaultDeleter<uint8_t[]>(addrs.release());
return ge::GRAPH_SUCCESS;
}
REGISTER_KERNEL(GetSplitVecWithInput).RunFunc(GetSplitVec);
int64_t CaculateSize(const gert::Shape &shape, size_t start, size_t end) {
int64_t size = 1;
if (end > shape.GetDimNum()) {
return size;
}
for (auto i = start; i < end; i++) {
size *= shape.GetDim(i);
}
return size;
}
int64_t SizeToSplitAxis(const gert::Shape &shape, int32_t axis) {
return CaculateSize(shape, 0, axis);
}
int64_t SizeFromSplitAxis(const gert::Shape &shape, int32_t axis) {
return CaculateSize(shape, axis, shape.GetDimNum());
}
ge::graphStatus CopyTensor(int64_t M, int64_t N, int lda, int64_t ldb, const CopyTensorArg cpy_arg) {
if (lda == N && ldb == N) {
GE_ASSERT_RT_OK(rtMemcpy(cpy_arg.dst, N * cpy_arg.type_size, cpy_arg.src,
N * cpy_arg.type_size, RT_MEMCPY_DEVICE_TO_DEVICE));
return ge::GRAPH_SUCCESS;
}
GELOGD("M:%ld, N:%ld, lda:%d, ldb:%ld, type_size:%d", M, N, lda, ldb, cpy_arg.type_size);
for (size_t i = 0; i < static_cast<size_t>(M); i++) {
GE_ASSERT_RT_OK(rtMemcpy(static_cast<uint8_t*>(cpy_arg.dst) + ldb * i * cpy_arg.type_size,
N * cpy_arg.type_size,
static_cast<uint8_t*>(cpy_arg.src) + lda * i * cpy_arg.type_size,
N * cpy_arg.type_size, RT_MEMCPY_DEVICE_TO_DEVICE));
}
return ge::GRAPH_SUCCESS;
}
ge::graphStatus SplitToSequenceComputeKernel(gert::KernelContext *context) {
auto input_storage_shape = context->GetInputPointer<gert::StorageShape>(
static_cast<size_t>(SplitToSequenceArgIndex::KInputXShape));
GE_ASSERT_NOTNULL(input_storage_shape);
auto input_addr = context->GetInputPointer<gert::GertTensorData>(
static_cast<size_t>(SplitToSequenceArgIndex::KInputXAddr));
GE_ASSERT_NOTNULL(input_addr);
auto extend_ctx = reinterpret_cast<gert::ExtendedKernelContext *>(context);
auto input_x_desc = extend_ctx->GetInputDesc(0);
GE_ASSERT_NOTNULL(input_x_desc);
ge::DataType input_data_type = input_x_desc->GetDataType();
GE_ASSERT_TRUE(kSupportedDataType.count(input_data_type) != 0,
"SplitToSequence kernel data type [%d] not support.", input_data_type);
int32_t type_size = GetSizeByDataType(input_data_type);
auto axis = context->GetInputValue<int32_t>(static_cast<size_t>(SplitToSequenceArgIndex::KAxis));
auto split_shapes = context->GetInputPointer<gert::TypedContinuousVector<gert::Shape>>(
static_cast<size_t>(SplitToSequenceArgIndex::KSplitShapes));
GE_ASSERT_NOTNULL(split_shapes);
gert::Shape input_x_shape = input_storage_shape->GetStorageShape();
auto input_x_rank = static_cast<int32_t>(input_x_shape.GetDimNum());
if (axis < 0) {
axis += input_x_rank;
}
auto before_dims = SizeToSplitAxis(input_x_shape, axis);
auto after_dims_include_axis = SizeFromSplitAxis(input_x_shape, axis);
auto after_dims_exclude_axis = ((axis + 1) == input_x_rank) ? 1 : SizeFromSplitAxis(input_x_shape, axis + 1);
auto inner_tensor_addrs = context->MutableInputPointer<gert::TypedContinuousVector<
gert::TensorData*>>(static_cast<size_t>(SplitToSequenceArgIndex::KInnerSplitAddrs));
auto split_num = split_shapes->GetSize();
auto split_shape_data = split_shapes->GetData();
int64_t input_offset = 0;
CopyTensorArg cpy_arg;
for (size_t i = 0; i < split_num; i++) {
gert::Shape each_shape_data = split_shape_data[i];
auto split_size = each_shape_data.GetDim(axis);
cpy_arg.src = reinterpret_cast<uint8_t*>(input_addr->GetAddr()) + input_offset;
GE_ASSERT_NOTNULL(inner_tensor_addrs->MutableData()[i]);
cpy_arg.dst = inner_tensor_addrs->MutableData()[i]->GetAddr();
cpy_arg.type_size = type_size;
GE_ASSERT_GRAPH_SUCCESS(CopyTensor(before_dims, split_size * after_dims_exclude_axis,
after_dims_include_axis, split_size * after_dims_exclude_axis,
cpy_arg), "copy tensor err.");
input_offset += split_size * after_dims_exclude_axis * type_size;
}
return ge::GRAPH_SUCCESS;
}
REGISTER_KERNEL(SplitToSequenceDoCompute).RunFunc(SplitToSequenceComputeKernel);
ge::graphStatus BuildInferShapeFromTensorDoCompute(gert::KernelContext *context) {
GELOGD("Enter BuildInferShapeFromTensorDoCompute");
auto input_tensor = context->GetInputPointer<gert::Tensor>(0);
GE_ASSERT_NOTNULL(input_tensor);
auto storage_shape = context->GetOutputPointer<StorageShape>(0);
GE_ASSERT_NOTNULL(storage_shape);
storage_shape->MutableStorageShape() = input_tensor->GetStorageShape();
storage_shape->MutableOriginShape() = input_tensor->GetStorageShape();
return ge::GRAPH_SUCCESS;
}
ge::graphStatus BuildSingleInferShapeOutputs(const ge::FastNode *node, KernelContext *context) {
GELOGD("Enter BuildSingleInferShapeOutputs");
(void)node;
auto output0 = context->GetOutput(0);
GE_ASSERT_NOTNULL(output0);
auto extend_context = reinterpret_cast<ExtendedKernelContext *>(context);
GE_ASSERT_NOTNULL(extend_context);
auto output_desc = extend_context->GetOutputDesc(0UL);
GE_ASSERT_NOTNULL(output_desc);
auto shape_tensor = new (std::nothrow) Tensor(StorageShape(),
output_desc->GetFormat(), output_desc->GetDataType());
GE_ASSERT_NOTNULL(shape_tensor);
output0->SetWithDefaultDeleter(shape_tensor);
return ge::GRAPH_SUCCESS;
}
REGISTER_KERNEL(BuildInferShapeFromTensorCompute)
.RunFunc(BuildInferShapeFromTensorDoCompute)
.OutputsCreator(BuildSingleInferShapeOutputs);
ge::graphStatus BuildAddrFromTensorDoCompute(gert::KernelContext *context) {
GELOGD("Enter BuildAddrFromTensorDoCompute");
auto input_tensor =
context->GetInputPointer<gert::Tensor>(static_cast<size_t>(BuildAddrFromTensorComputeInputs::kTensor));
auto gert_allocator =
context->GetInputValue<GertAllocator *>(static_cast<size_t>(BuildAddrFromTensorComputeInputs::kAllocator));
GE_ASSERT_NOTNULL(input_tensor);
auto gtd = context->GetOutputPointer<GertTensorData>(0);
GE_ASSERT_NOTNULL(gtd);
return TensorUtils::ShareTdToGtd(input_tensor->GetTensorData(), *gert_allocator, *gtd);
}
ge::graphStatus AllocSingleOutputTensorData(const ge::FastNode *node, KernelContext *context) {
GELOGD("Enter AllocSingleOutputTensorData");
(void)node;
auto output0 = context->GetOutput(0);
GE_ASSERT_NOTNULL(output0);
auto tensor_data = new (std::nothrow) GertTensorData();
GE_ASSERT_NOTNULL(tensor_data);
output0->SetWithDefaultDeleter(tensor_data);
return ge::GRAPH_SUCCESS;
}
REGISTER_KERNEL(BuildAddrFromTensorCompute)
.RunFunc(BuildAddrFromTensorDoCompute)
.OutputsCreator(AllocSingleOutputTensorData)
.ConcurrentCriticalSectionKey(kKernelUseMemory);
ge::graphStatus GetSequenceHandleShapeKernel(KernelContext *context) {
auto output_storage_shape = context->GetOutputPointer<gert::StorageShape>(0);
GE_ASSERT_NOTNULL(output_storage_shape);
gert::Shape scalar_shape;
scalar_shape.SetScalar();
output_storage_shape->MutableStorageShape() = scalar_shape;
output_storage_shape->MutableOriginShape() = scalar_shape;
return ge::GRAPH_SUCCESS;
}
REGISTER_KERNEL(GetSequenceHandleShape)
.RunFunc(GetSequenceHandleShapeKernel)
.OutputsCreator(BuildSingleInferShapeOutputs);
}
}