* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "kernel/memory/sink_weight_assigner.h"
#include "kernel/memory/caching_mem_allocator.h"
#include "kernel/kernel_log.h"
#include "exe_graph/runtime/kernel_context.h"
#include "exe_graph/runtime/tensor.h"
#include "graph/debug/ge_attr_define.h"
#include "common/checker.h"
#include "register/kernel_registry.h"
#include "runtime/model_v2_executor.h"
#include "exe_graph/runtime/gert_tensor_data.h"
#include "core/executor/multi_thread_topological/executor/schedule/producer/producers/kernel_tags/critical_section_config.h"
namespace gert {
namespace kernel {
ge::graphStatus CreateMemAssigner(KernelContext *context) {
auto assinger_chain = context->GetOutput(0);
GE_ASSERT_NOTNULL(assinger_chain);
auto mem_base =
context->GetInputPointer<gert::GertTensorData>(static_cast<size_t>(CreateMemAssignerInputs::kBaseMem));
auto assinger = new (std::nothrow) memory::SinkWeightAssigner(mem_base->GetAddr(), mem_base->GetSize());
GE_ASSERT_NOTNULL(assinger);
assinger_chain->SetWithDefaultDeleter(assinger);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus AssignWeightMemory(KernelContext *context) {
auto tensor = context->GetInputPointer<TensorData>(static_cast<size_t>(AssignWeightMemoryInputs::kOffsetTensor));
auto assigner = context->MutableInputPointer<memory::SinkWeightAssigner>(
static_cast<size_t>(AssignWeightMemoryInputs::kAssigner));
auto device_gtd =
context->GetOutputPointer<GertTensorData>(static_cast<size_t>(AssignWeightMemoryOutputs::kTensorData));
if (tensor == nullptr || assigner == nullptr || device_gtd == nullptr) {
GELOGE(ge::GRAPH_FAILED, "input or output is invalid.");
return ge::GRAPH_FAILED;
}
void *device_addr = assigner->Assign(tensor->GetSize(), ge::PtrToValue(tensor->GetAddr()));
if (device_addr == nullptr) {
GELOGE(ge::GRAPH_FAILED, "assign device addr failed.");
return ge::GRAPH_FAILED;
}
auto stream_id = context->GetInputPointer<int64_t>(static_cast<size_t>(AssignWeightMemoryInputs::kStreamId));
GE_ASSERT_NOTNULL(stream_id);
auto placement = device_gtd->GetPlacement();
auto size = tensor->GetSize();
*device_gtd = GertTensorData{device_addr, size, placement, *stream_id};
return ge::GRAPH_SUCCESS;
}
ge::graphStatus AssignMemoryOutputCreator(const ge::FastNode *node, KernelContext *context) {
(void)node;
auto input_gtd = context->GetInputPointer<GertTensorData>(
static_cast<size_t>(AssignWeightMemoryInputs::kOffsetTensor));
GE_ASSERT_NOTNULL(input_gtd);
auto chain = context->GetOutput(static_cast<size_t>(AssignWeightMemoryOutputs::kTensorData));
GE_ASSERT_NOTNULL(chain);
auto out_gtd = new (std::nothrow) GertTensorData();
GE_ASSERT_NOTNULL(out_gtd);
*out_gtd = GertTensorData{input_gtd->GetAddr(), input_gtd->GetSize(), input_gtd->GetPlacement(), -1};
chain->SetWithDefaultDeleter(out_gtd);
return ge::GRAPH_SUCCESS;
}
REGISTER_KERNEL(CreateMemAssigner).RunFunc(CreateMemAssigner);
REGISTER_KERNEL(AssignWeightMemory).RunFunc(AssignWeightMemory).OutputsCreator(AssignMemoryOutputCreator);
ge::graphStatus GetOrCreateWeightMem(KernelContext *context) {
auto required_size = context->GetInputValue<size_t>(static_cast<size_t>(GetOrCreateWeightInputs::kRequiredSize));
auto appointed_weight =
context->GetInputPointer<OuterWeightMem>(static_cast<size_t>(GetOrCreateWeightInputs::kAppointedWeight));
auto device_mem =
context->GetOutputPointer<GertTensorData>(static_cast<size_t>(GetOrCreateWeightOutputs::kTensorData));
auto gert_allocator =
context->GetInputValue<GertAllocator *>(static_cast<size_t>(GetOrCreateWeightInputs::kAllocator));
GE_ASSERT_NOTNULL(gert_allocator);
if ((appointed_weight->weight_ptr == nullptr) || (appointed_weight->weight_size < required_size)) {
GELOGI("appointed weight info is invalid, given_size[%zu], required_size[%zu], need to alloc inner.",
appointed_weight->weight_size, required_size);
auto block = gert_allocator->Malloc(required_size);
KERNEL_CHECK_NOTNULL(block);
KERNEL_CHECK(block->GetAddr() != nullptr, "malloc failed, tensor size[%zu]", required_size);
device_mem->ShareFrom({required_size, device_mem->GetPlacement(), gert_allocator->GetStreamId(), block});
} else {
*device_mem = GertTensorData{const_cast<void *>(appointed_weight->weight_ptr), appointed_weight->weight_size,
device_mem->GetPlacement(), gert_allocator->GetStreamId()};
}
return ge::GRAPH_SUCCESS;
}
ge::graphStatus GetOrCreateWeightOutputCreator(const ge::FastNode *node, KernelContext *context) {
(void)node;
auto gtd = new (std::nothrow) GertTensorData(0U, kOnDeviceHbm, -1, nullptr);
GE_ASSERT_NOTNULL(gtd);
auto chain = context->GetOutput(static_cast<size_t>(GetOrCreateWeightOutputs::kTensorData));
GE_ASSERT_NOTNULL(chain);
chain->SetWithDefaultDeleter(gtd);
return ge::GRAPH_SUCCESS;
}
REGISTER_KERNEL(GetOrCreateWeightMem)
.RunFunc(GetOrCreateWeightMem)
.OutputsCreator(GetOrCreateWeightOutputCreator)
.ConcurrentCriticalSectionKey(kKernelUseMemory);
}
}