* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "graph_builder/assign_device_mem.h"
#include "graph_builder/converter_checker.h"
#include "graph/utils/tensor_utils.h"
#include "graph/debug/ge_attr_define.h"
#include "exe_graph/lowering/frame_selector.h"
#include "common/plugin/ge_make_unique_util.h"
#include "lowering/placement/placed_lowering_register.h"
#include "runtime/model_v2_executor.h"
#include "kernel/memory/sink_weight_assigner.h"
#include "kernel/common_kernel_impl/memory_copy.h"
namespace gert {
namespace {
// Index of the weight offset inside the vector returned by GetFlattenOffsetInfo (see its use in SinkWeightData).
constexpr uint32_t kFlattenOffsetKey = 0U;
// Index of the weight size inside that same vector.
constexpr uint32_t kFlattenSizeKey = 1U;
/**
 * Publishes the model weight size as the unique value holder "RequireSize".
 *
 * The size is read from global_data, logged, and wrapped in a const holder that is
 * created through bg::FrameSelector::OnInitRoot. The holder list returned by
 * GetOrCreateUniqueValueHolder is intentionally discarded; callers look the holder
 * up again by its "RequireSize" key.
 *
 * @param global_data lowering global data that supplies the weight size and stores the holder
 */
void PrepareForWeightSink(LoweringGlobalData &global_data) {
  using HolderList = std::vector<bg::ValueHolderPtr>;

  const auto weight_bytes = global_data.GetModelWeightSize();
  GELOGD("RequireSize[%zu]", weight_bytes);

  // Builds the single const holder carrying the required weight size.
  const auto build_size_const = [&weight_bytes]() -> HolderList {
    HolderList outputs;
    outputs.emplace_back(bg::ValueHolder::CreateConst(&weight_bytes, sizeof(weight_bytes)));
    return outputs;
  };
  // Runs the builder above through OnInitRoot.
  const auto build_on_init_root = [&build_size_const]() -> HolderList {
    return bg::FrameSelector::OnInitRoot(build_size_const);
  };

  (void)global_data.GetOrCreateUniqueValueHolder("RequireSize", build_on_init_root);
}
}
/**
 * Returns the unique "BaseWeightAddr" holder, creating it on first use.
 *
 * On creation, a "GetOrCreateWeightMem" node is built through bg::FrameSelector::OnInitRoot
 * from three inputs: the "RequireSize" holder (published by PrepareForWeightSink), the
 * "OuterWeightMem" holder, and an allocator for {kOnDeviceHbm, kAllocNodeWorkspace}.
 *
 * @param global_data lowering global data that owns the unique value holders
 * @return the first holder registered under "BaseWeightAddr", or nullptr when no holder
 *         could be created (the holder list came back empty)
 */
bg::ValueHolderPtr AssignDeviceMem::GetBaseWeightAddr(LoweringGlobalData &global_data) {
  PrepareForWeightSink(global_data);
  auto weight_mem = [&global_data]() -> std::vector<bg::ValueHolderPtr> {
    return bg::FrameSelector::OnInitRoot([&global_data]() -> std::vector<bg::ValueHolderPtr> {
      auto required_size_holder = bg::HolderOnInit(global_data.GetUniqueValueHolder("RequireSize"));
      auto outer_weight_holder = bg::HolderOnInit(global_data.GetUniqueValueHolder("OuterWeightMem"));
      auto allocator_holder = global_data.GetOrCreateAllocator({kOnDeviceHbm, AllocatorUsage::kAllocNodeWorkspace});
      return {bg::ValueHolder::CreateSingleDataOutput("GetOrCreateWeightMem",
                                                      {required_size_holder, outer_weight_holder, allocator_holder})};
    });
  };
  const auto &holders = global_data.GetOrCreateUniqueValueHolder("BaseWeightAddr", weight_mem);
  // Indexing an empty vector is undefined behaviour; report the failure as a null holder instead.
  if (holders.empty()) {
    return nullptr;
  }
  return holders[0];
}
/**
 * Returns the unique "CreateMemAssigner" holder, creating it on first use.
 *
 * On creation, a "CreateMemAssigner" node is built through bg::FrameSelector::OnInitRoot
 * whose only input is the base weight address from GetBaseWeightAddr.
 *
 * @param global_data lowering global data that owns the unique value holders
 * @return the first holder registered under "CreateMemAssigner", or nullptr when no holder
 *         could be created (the holder list came back empty)
 */
bg::ValueHolderPtr AssignDeviceMem::GetOrCreateMemAssigner(LoweringGlobalData &global_data) {
  auto mem_assigner_holder_creator = [&global_data]() -> std::vector<bg::ValueHolderPtr> {
    return bg::FrameSelector::OnInitRoot([&global_data]() -> std::vector<bg::ValueHolderPtr> {
      auto device_mem_holder = bg::HolderOnInit(GetBaseWeightAddr(global_data));
      return {bg::ValueHolder::CreateSingleDataOutput("CreateMemAssigner", {device_mem_holder})};
    });
  };
  const auto &holders = global_data.GetOrCreateUniqueValueHolder("CreateMemAssigner", mem_assigner_holder_creator);
  // Indexing an empty vector is undefined behaviour; report the failure as a null holder instead.
  if (holders.empty()) {
    return nullptr;
  }
  return holders[0];
}
/*
    ┌───────────┐        ┌──────────────┐
    │RequireSize│        │OuterWeightMem│
    └───────────┴───┬────┴──────────────┘
                    │
           ┌────────▼───────────┐
           │GetOrCreateWeightMem│
           └────────┬───────────┘
                    │
┌─────┐      ┌──────▼──────────┐       ┌─────┐
│Const│      │CreateMemAssigner│       │Const│
└─────┴───┬──┴─────────────────┘       └──┬──┘
          │                               │
   ┌──────▼───────────┐             ┌─────▼──────┐
   │AssignWeightMemory│             │SpliteTensor│
   └──────────────────┴─────┬───────┴────────────┘
                            │
                            │
                    ┌───────▼──────┐
                    │SinkWeightData│
                    └──────────────┘
*/
/**
 * Output lowering that sinks the weight of a Const node into device memory.
 *
 * Steps performed:
 *  1. Reads the "value" tensor attribute from the node's op desc.
 *  2. Obtains the flattened {offset, size} pair of that tensor via GetFlattenOffsetInfo and wraps it
 *     in a const TensorData holder (offset stored as the address, size as the size).
 *  3. Builds a kernel::kAssignWeightMemory node from that holder, the mem assigner and the stream id.
 *  4. Builds a "SinkWeightData" node from src.address, the assigned memory, the stream id and the stream.
 *  5. Adds a dependency between every holder in src.order_holders and the new node.
 *
 * @param const_node           node being lowered; must not be null
 * @param global_data          lowering global data (mem assigner, streams)
 * @param data_type            unused
 * @param src                  lowering result that provides the source address, shape and order holders
 * @param dst_logic_stream_id  unused; the stream id is taken from the op desc instead
 * @param dst                  receives src.shape, the new address (placement kOnDeviceHbm) and has_init = true
 * @return ge::GRAPH_SUCCESS on success; ge::PARAM_INVALID when the "value" attribute cannot be read;
 *         otherwise whatever failure status the GE_ASSERT_* macros return when a check fails
 */
ge::graphStatus SinkWeightData(const ge::Node* const_node, LoweringGlobalData &global_data, ge::DataType data_type,
                               const OutputLowerResult &src, int64_t dst_logic_stream_id, OutputLowerResult &dst) {
  (void)data_type;
  (void)dst_logic_stream_id;
  // Guard the node pointer before the first dereference.
  GE_ASSERT_NOTNULL(const_node);
  const auto op_desc = const_node->GetOpDescBarePtr();
  GE_ASSERT_NOTNULL(op_desc);
  ge::ConstGeTensorPtr tensor;
  if (!ge::AttrUtils::GetTensor(op_desc, "value", tensor)) {
    GELOGE(ge::PARAM_INVALID, "Failed to get tensor value from const %s", const_node->GetName().c_str());
    return ge::PARAM_INVALID;
  }
  std::vector<int64_t> flatten_weight = GetFlattenOffsetInfo(&(tensor->GetTensorDesc()));
  // Both the offset (index 0) and the size (index 1) are read below, so a non-empty check is not enough.
  GE_ASSERT_TRUE(flatten_weight.size() > kFlattenSizeKey);
  TensorData weight_offset;
  weight_offset.SetAddr(ge::ValueToPtr(flatten_weight[kFlattenOffsetKey]), nullptr);
  weight_offset.SetSize(static_cast<size_t>(flatten_weight[kFlattenSizeKey]));
  GELOGI("get const[%s], offset[%ld],size[%ld]", const_node->GetNamePtr(),
         flatten_weight[kFlattenOffsetKey], flatten_weight[kFlattenSizeKey]);
  auto weight_offset_holder = bg::ValueHolder::CreateConst(&weight_offset, sizeof(TensorData));
  auto get_or_create = AssignDeviceMem::GetOrCreateMemAssigner(global_data);
  // The stream id comes from the op desc, not from dst_logic_stream_id.
  const int64_t logic_stream_id = op_desc->GetStreamId();
  auto stream_id_holder = bg::ValueHolder::CreateConst(&logic_stream_id, sizeof(logic_stream_id));
  GE_ASSERT_TRUE(IsValidHolder(stream_id_holder));
  auto assign_mem_holder = bg::ValueHolder::CreateSingleDataOutput(
      kernel::kAssignWeightMemory, {weight_offset_holder, bg::HolderOnInit(get_or_create), stream_id_holder});
  auto sink_weight_data = bg::DevMemValueHolder::CreateSingleDataOutput(
      "SinkWeightData", {src.address, assign_mem_holder, stream_id_holder,
                         global_data.GetStreamById(logic_stream_id)}, logic_stream_id);
  GE_ASSERT_NOTNULL(sink_weight_data);
  // Tie every ordering holder of the source to the sink node.
  for (const auto &src_ordered_holder : src.order_holders) {
    bg::ValueHolder::AddDependency(src_ordered_holder, sink_weight_data);
  }
  dst.shape = src.shape;
  dst.address = sink_weight_data;
  dst.address->SetPlacement(kOnDeviceHbm);
  dst.has_init = true;
  return ge::GRAPH_SUCCESS;
}
// Registers SinkWeightData through REGISTER_OUTPUT_LOWER under the name Const
// (presumably the op type whose outputs it lowers — confirm against the macro definition).
REGISTER_OUTPUT_LOWER(Const, SinkWeightData);
}