* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "graph/manager/trans_var_data_utils.h"
#include "common/math/math_util.h"
#include "formats/formats.h"
#include "base/err_mgr.h"
#include "formats/utils/formats_trans_utils.h"
#include "common/datatype_transfer/datatype_transfer.h"
#include "framework/common/op/ge_op_utils.h"
#include "framework/common/framework_types_internal.h"
#include "common/thread_pool/thread_pool.h"
#include "graph/utils/type_utils.h"
#include "graph/utils/graph_utils.h"
#include "graph/debug/ge_attr_define.h"
#include "acl/acl_rt.h"
namespace ge {
namespace {
constexpr uint32_t kDefaultVarTransThreadNum = 16U;
class RtContextSwitchGuard {
public:
RtContextSwitchGuard(const uint32_t device_id) {
auto ret = aclrtGetCurrentContext(&last_);
if (ret != ACL_SUCCESS) {
REPORT_INNER_ERR_MSG("E19999", "Call aclrtGetCurrentContext failed, device_id:%u, ret:%d,",
device_id, ret);
GELOGE(RT_FAILED, "[Call][aclrtGetCurrentContext] Failed to get current context, device_id:%u, ret:%d",
device_id, ret);
return;
}
if (aclrtCreateContext(¤t_, static_cast<int32_t>(device_id)) != ACL_SUCCESS) {
REPORT_INNER_ERR_MSG("E19999", "Call aclrtCreateContext failed, device_id:%u", device_id);
GELOGE(RT_FAILED, "[Call][aclrtCreateContext] failed, device_id:%u", device_id);
return;
}
ret = aclrtSetCurrentContext(current_);
if (ret != ACL_SUCCESS) {
REPORT_INNER_ERR_MSG("E19999", "Call aclrtSetCurrentContext failed, device_id:%u, ret:%d", device_id, ret);
GELOGE(RT_FAILED, "[Call][aclrtSetCurrentContext] failed, device_id:%u, ret:%d", device_id, ret);
return;
}
GELOGD("Create and switch rt context %p for device %u, backup last %p.", current_, device_id, last_);
}
~RtContextSwitchGuard() {
try {
if (current_ != nullptr) {
const auto ret = aclrtDestroyContext(current_);
GELOGD("Destory current context %p result %d", current_, ret);
}
if (last_ != nullptr) {
const auto ret = aclrtSetCurrentContext(last_);
GELOGD("Recovery last context %p result %d.", last_, ret);
}
} catch (...) {
}
}
private:
aclrtContext last_ = nullptr;
aclrtContext current_ = nullptr;
};
int64_t CalcVarSizeInBytes(const GeTensorDesc &desc) {
int64_t var_size = GetSizeByDataType(desc.GetDataType());
if (var_size <= 0) {
REPORT_INNER_ERR_MSG("E19999", "Data type:%s in desc, it's size:%" PRId64 " < 0, check invalid",
TypeUtils::DataTypeToSerialString(desc.GetDataType()).c_str(), var_size);
GELOGE(PARAM_INVALID, "[Calc][VarDataSize] by data type %s failed.",
TypeUtils::DataTypeToSerialString(desc.GetDataType()).c_str());
return -1;
}
const auto shape = desc.GetShape();
const auto dim_num = shape.GetDimNum();
for (size_t dim_index = 0U; dim_index < dim_num; ++dim_index) {
if (CheckInt64MulOverflow(var_size, shape.GetDim(dim_index)) != SUCCESS) {
return -1;
}
var_size *= shape.GetDim(dim_index);
}
return var_size;
}
Status CopyVarToDevice(const NodePtr &var, const formats::TransResult &trans_result, void *const var_addr) {
GE_CHECK_NOTNULL(var);
GELOGD("Copy var %s from host to device, size %zu", var->GetName().c_str(), trans_result.length);
const auto ret = aclrtMemcpy(var_addr, trans_result.length, PtrToPtr<uint8_t, void>(trans_result.data.get()),
trans_result.length, ACL_MEMCPY_HOST_TO_DEVICE);
if (ret != ACL_SUCCESS) {
REPORT_INNER_ERR_MSG("E19999", "Call aclrtMemcpy failed, op:%s(%s), size:%" PRIu64 ", ret:%d,",
var->GetName().c_str(), var->GetType().c_str(), static_cast<uint64_t>(trans_result.length), ret);
GELOGE(RT_FAILED, "[Call][aclrtMemcpy] failed, op:%s(%s), size:%" PRIu64 ", ret:%d,", var->GetName().c_str(),
var->GetType().c_str(), trans_result.length, ret);
return RT_FAILED;
}
return SUCCESS;
}
Status CopyVarFromDevice(const uint64_t session_id, const NodePtr &var, std::unique_ptr<uint8_t[]> &var_data,
const GeTensorDesc &input_desc, const uint32_t device_id) {
uint8_t *var_logic = nullptr;
GE_CHECK_NOTNULL(var);
GE_CHECK_NOTNULL(VarManager::Instance(session_id));
const Status ret = VarManager::Instance(session_id)->GetVarAddr(var->GetName(), input_desc, var_logic);
if (ret != SUCCESS) {
GELOGE(INTERNAL_ERROR, "[Get][VarAddr] failed, node:%s, session_id:%" PRIu64 ", ret:%d.", var->GetName().c_str(),
session_id, ret);
return INTERNAL_ERROR;
}
uint8_t *const var_addr = VarManager::Instance(session_id)->GetVarMemoryAddr(var_logic, RT_MEMORY_HBM, device_id);
if (var_addr == nullptr) {
REPORT_INNER_ERR_MSG("E19999", "Get variable memory addr failed, mem_type:%u, op:%s(%s), session_id:%" PRIu64 "",
RT_MEMORY_HBM, var->GetName().c_str(), var->GetType().c_str(), session_id);
GELOGE(INTERNAL_ERROR, "[Get][VarMemoryAddr] failed, mem_type:%u, op:%s(%s), session_id:%" PRIu64,
RT_MEMORY_HBM, var->GetName().c_str(), var->GetType().c_str(), session_id);
return INTERNAL_ERROR;
}
const int64_t var_size_bytes = CalcVarSizeInBytes(input_desc);
if (var_size_bytes <= 0) {
return INTERNAL_ERROR;
}
std::unique_ptr<uint8_t[]> var_host = MakeUnique<uint8_t[]>(static_cast<size_t>(var_size_bytes));
if (var_host == nullptr) {
REPORT_INNER_ERR_MSG("E19999", "New host memory failed, size:%" PRId64 ", op:%s(%s), session_id:%" PRIu64 "",
var_size_bytes, var->GetName().c_str(), var->GetType().c_str(), session_id);
GELOGE(OUT_OF_MEMORY, "[New][Memory] for rt-host failed, size:%" PRId64 ", op:%s(%s), session_id:%" PRIu64,
var_size_bytes, var->GetName().c_str(), var->GetType().c_str(), session_id);
return OUT_OF_MEMORY;
}
const auto rt_rslt = aclrtMemcpy(PtrToPtr<uint8_t, void>(var_host.get()), static_cast<uint64_t>(var_size_bytes),
PtrToPtr<uint8_t, void>(var_addr), static_cast<uint64_t>(var_size_bytes), ACL_MEMCPY_DEVICE_TO_HOST);
if (rt_rslt != ACL_SUCCESS) {
REPORT_INNER_ERR_MSG("E19999", "Call aclrtMemcpy failed, size:%" PRId64 ", op:%s(%s), session_id:%" PRIu64 ", ret:%d",
var_size_bytes, var->GetName().c_str(), var->GetType().c_str(), session_id, rt_rslt);
GELOGE(RT_FAILED, "[Call][aclrtMemcpy] failed, size:%" PRId64 ", op:%s(%s), session_id:%" PRIu64 ", ret:%d",
var_size_bytes, var->GetName().c_str(), var->GetType().c_str(), session_id, rt_rslt);
return RT_FAILED;
}
GELOGD("Copy var %s from device to host, size %" PRId64, var->GetName().c_str(), var_size_bytes);
var_data.swap(var_host);
GELOGI("var_logic:%p, var_addr:%p", var_logic, var_addr);
return SUCCESS;
}
bool IsNoNeedTrans(const std::string &node_type) {
return (node_type == RESHAPE) || (node_type == REFORMAT) ||
(node_type == SQUEEZEV2) || (node_type == UNSQUEEZEV2);
}
Status TransVarOnHost(uint8_t *const var_data, const VarTransRoad &trans_road, formats::TransResult &result) {
formats::TransResult result_last_time{};
bool use_init_data = true;
for (const auto &trans_info : trans_road) {
if (IsNoNeedTrans(trans_info.node_type)) {
GELOGD("Skip to trans variable data on the reshape/reformat node");
continue;
}
uint8_t *src_data = nullptr;
if (use_init_data) {
src_data = var_data;
use_init_data = false;
} else {
src_data = result_last_time.data.get();
}
formats::TransResult tmp_result{};
if ((trans_info.node_type == TRANSDATA) || (trans_info.node_type == TRANSPOSED)) {
const auto src_format = trans_info.input.GetFormat();
const auto src_shape = trans_info.input.GetShape().GetDims();
const auto dst_format = trans_info.output.GetFormat();
const auto dst_shape = trans_info.output.GetShape().GetDims();
const auto data_type = trans_info.input.GetDataType();
const Format src_primary_format = static_cast<Format>(GetPrimaryFormat(src_format));
const Format dst_primary_format = static_cast<Format>(GetPrimaryFormat(dst_format));
const Format src_sub_format = static_cast<Format>(GetSubFormat(src_format));
const Format dst_sub_format = static_cast<Format>(GetSubFormat(dst_format));
const int64_t src_c0_format = GetC0Value(static_cast<int32_t>(src_format));
const int64_t dst_c0_format = GetC0Value(static_cast<int32_t>(dst_format));
GELOGD("Trans format from %s to %s, primary formats from %s to %s, src c0 is %" PRId64 ", dts c0 is %" PRId64 ", "
"shape %s to %s, data-type %s",
TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str(),
TypeUtils::FormatToSerialString(src_primary_format).c_str(),
TypeUtils::FormatToSerialString(dst_primary_format).c_str(), src_c0_format, dst_c0_format,
ToString(src_shape).c_str(), ToString(dst_shape).c_str(),
TypeUtils::DataTypeToSerialString(data_type).c_str());
const Status ret = formats::TransDataFormat(
{src_data, src_format, dst_format, src_primary_format, dst_primary_format, src_sub_format, dst_sub_format,
src_c0_format, dst_c0_format, src_shape, dst_shape, data_type},
tmp_result);
if (ret != SUCCESS) {
REPORT_INNER_ERR_MSG("E19999", "Trans format from %s to %s, primary formats from %s to %s, src c0 is "
"[%" PRId64 "], dts c0 is [%" PRId64 "], shape %s to %s failed, "
"data type:%s, ret:%u,",
TypeUtils::FormatToSerialString(src_format).c_str(),
TypeUtils::FormatToSerialString(dst_format).c_str(),
TypeUtils::FormatToSerialString(src_primary_format).c_str(),
TypeUtils::FormatToSerialString(dst_primary_format).c_str(), src_c0_format, dst_c0_format,
ToString(src_shape).c_str(), ToString(dst_shape).c_str(),
TypeUtils::DataTypeToSerialString(data_type).c_str(), ret);
GELOGE(INTERNAL_ERROR, "[Trans][Format] from %s to %s, primary formats from %s to %s, shape %s to %s failed, "
"src c0 is %" PRId64 ", dts c0 is %" PRId64 ", data type %s error code %u",
TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str(),
TypeUtils::FormatToSerialString(src_primary_format).c_str(),
TypeUtils::FormatToSerialString(dst_primary_format).c_str(),
ToString(src_shape).c_str(), ToString(dst_shape).c_str(), src_c0_format, dst_c0_format,
TypeUtils::DataTypeToSerialString(data_type).c_str(), ret);
return ret;
}
} else if (trans_info.node_type == CAST) {
const auto input_shape = trans_info.input.GetShape();
const int64_t src_data_size = (input_shape.GetShapeSize() == 0) ? 1 : input_shape.GetShapeSize();
const auto src_data_type = trans_info.input.GetDataType();
const auto dst_data_type = trans_info.output.GetDataType();
GELOGD("Trans data type from %s to %s, input shape %s, data size %" PRId64,
TypeUtils::DataTypeToSerialString(src_data_type).c_str(),
TypeUtils::DataTypeToSerialString(dst_data_type).c_str(), ToString(input_shape.GetDims()).c_str(),
src_data_size);
const Status ret = formats::TransTensorDataType(
{src_data, static_cast<size_t>(src_data_size), src_data_type, dst_data_type}, tmp_result);
if (ret != SUCCESS) {
REPORT_INNER_ERR_MSG("E19999", "Trans data type from %s to %s failed, input shape %s, "
"data size %" PRId64 ", ret:%u",
TypeUtils::DataTypeToSerialString(src_data_type).c_str(),
TypeUtils::DataTypeToSerialString(dst_data_type).c_str(),
ToString(input_shape.GetDims()).c_str(), src_data_size, ret);
GELOGE(INTERNAL_ERROR, "[Trans][DataType] from %s to %s failed, input shape %s, "
"data size %" PRId64 ", error code %u",
TypeUtils::DataTypeToSerialString(src_data_type).c_str(),
TypeUtils::DataTypeToSerialString(dst_data_type).c_str(), ToString(input_shape.GetDims()).c_str(),
src_data_size, ret);
return ret;
}
} else {
REPORT_INNER_ERR_MSG("E19999", "Trans var data failed, the trans type %s does not supported, check invalid",
trans_info.node_type.c_str());
GELOGE(UNSUPPORTED, "[Trans][VarData] failed, the trans type %s does not supported",
trans_info.node_type.c_str());
return UNSUPPORTED;
}
result_last_time = tmp_result;
}
result = result_last_time;
return SUCCESS;
}
Status ReAssignVarAddr(const uint64_t session_id,
const std::string &var_name,
const GeTensorDesc &tensor_desc,
void *&var_device,
const uint32_t device_id) {
GE_CHECK_NOTNULL(VarManager::Instance(session_id));
uint8_t *var_logic = nullptr;
const Status ret = VarManager::Instance(session_id)->GetVarAddr(var_name, tensor_desc, var_logic);
if (ret != SUCCESS) {
GELOGE(INTERNAL_ERROR, "[Get][VarAddr] failed, var name:%s, session_id:%" PRIu64 ", ret:%u",
var_name.c_str(), session_id, ret);
return INTERNAL_ERROR;
}
uint8_t *const var_addr = VarManager::Instance(session_id)->GetVarMemoryAddr(var_logic, RT_MEMORY_HBM, device_id);
if (var_addr == nullptr) {
REPORT_INNER_ERR_MSG("E19999", "Get variable memory addr failed, mem_type:%u, var_name:%s, session_id:%" PRIu64 ",",
RT_MEMORY_HBM, var_name.c_str(), session_id);
GELOGE(INTERNAL_ERROR, "[Get][VarMemoryAddr] failed, mem_type:%u, var_name:%s, session_id:%" PRIu64,
RT_MEMORY_HBM, var_name.c_str(), session_id);
return INTERNAL_ERROR;
}
var_device = var_addr;
GELOGI("var_logic:%p, var_addr:%p", var_logic, var_addr);
return SUCCESS;
}
Status TransVarData(const NodePtr &var, const VarTransRoad &trans_road, const uint64_t session_id,
const uint32_t device_id) {
GE_CHECK_NOTNULL(var);
bool need_trans = false;
for (auto &road : trans_road) {
if (!IsNoNeedTrans(road.node_type)) {
need_trans = true;
break;
}
}
if (!need_trans) {
return SUCCESS;
}
std::unique_ptr<uint8_t[]> var_data;
if (trans_road.empty()) {
REPORT_INNER_ERR_MSG("E19999", "Param trans_road is empty, session_id:%" PRIu64 ", check invalid", session_id);
GELOGE(INTERNAL_ERROR, "[Check][Param] trans_road is empty, session_id:%" PRIu64, session_id);
return INTERNAL_ERROR;
}
const GeTensorDesc &input_desc = trans_road.begin()->input;
auto ret = CopyVarFromDevice(session_id, var, var_data, input_desc, device_id);
if (ret != SUCCESS) {
return ret;
}
formats::TransResult trans_result{};
ret = TransVarOnHost(var_data.get(), trans_road, trans_result);
if (ret != SUCCESS) {
GELOGE(ret, "[Call][TransVarOnHost] failed, session_id:%" PRIu64 ", ret:%u", session_id, ret);
return ret;
}
void *var_device = nullptr;
ret = ReAssignVarAddr(session_id, var->GetName(), trans_road.rbegin()->output, var_device, device_id);
if (ret != SUCCESS) {
GELOGE(ret, "[Call][ReAssignVarAddr] failed, session id:%" PRIu64 ", op:%s, ret:%u",
session_id, var->GetName().c_str(), ret);
return ret;
}
ret = CopyVarToDevice(var, trans_result, var_device);
if (ret != SUCCESS) {
GELOGE(ret, "[Call][CopyVarToDevice] failed, var:%s, ret:%u", var->GetName().c_str(), ret);
return ret;
}
return SUCCESS;
}
Status TransTensor(const uint8_t *const var_data, const NodePtr &var_src, const NodePtr &var_dst,
formats::TransResult &result) {
GE_CHECK_NOTNULL(var_src);
GE_CHECK_NOTNULL(var_src->GetOpDesc());
GE_CHECK_NOTNULL(var_dst);
GE_CHECK_NOTNULL(var_dst->GetOpDesc());
const int64_t src_data_shape_size = var_src->GetOpDesc()->GetOutputDesc(0U).GetShape().GetShapeSize();
const auto src_data_datatype = var_src->GetOpDesc()->GetOutputDesc(0U).GetDataType();
const auto dst_data_datatype = var_dst->GetOpDesc()->GetOutputDesc(0U).GetDataType();
GE_IF_BOOL_EXEC(
src_data_datatype != dst_data_datatype,
const Status ret = formats::TransTensorDataType(
{var_data, static_cast<size_t>(src_data_shape_size), src_data_datatype, dst_data_datatype}, result);
if (ret != SUCCESS) {
REPORT_INNER_ERR_MSG("E19999", "Trans data type from %s to %s failed, data size %" PRId64 ", ret:%u",
TypeUtils::DataTypeToSerialString(src_data_datatype).c_str(),
TypeUtils::DataTypeToSerialString(dst_data_datatype).c_str(), src_data_shape_size, ret);
GELOGE(INTERNAL_ERROR, "[Trans][DataType] from %s to %s failed, data size %" PRId64 ", ret:%u",
TypeUtils::DataTypeToSerialString(src_data_datatype).c_str(),
TypeUtils::DataTypeToSerialString(dst_data_datatype).c_str(), src_data_shape_size, ret);
return ret;
});
return SUCCESS;
}
Status CopyTensorFromSrcVarNode(const NodePtr &var_src,
const NodePtr &var_dst,
const uint64_t session_id,
const uint32_t device_id) {
GE_IF_BOOL_EXEC((var_src == nullptr) || (var_dst == nullptr) ||
(var_src->GetOpDesc() == nullptr) || (var_dst->GetOpDesc() == nullptr),
REPORT_INNER_ERR_MSG("E19999", "Param var_src or var_dst is nullptr, session_id:"
"%" PRIu64 ", device_id:%u, check invalid", session_id, device_id);
GELOGE(FAILED, "[Check][Param] Param var_src or var_dst is nullptr, "
"session_id:%" PRIu64 ", device_id:%u", session_id, device_id);
return FAILED);
const GeTensorDesc output_desc = var_src->GetOpDesc()->GetOutputDesc(0U);
const DataType src_data_type = output_desc.GetDataType();
const GeShape src_shape = output_desc.GetShape();
const Format src_format = output_desc.GetFormat();
GELOGI("src_node %s, src_format %s, src_shape %s, src_type %s.", var_src->GetName().c_str(),
TypeUtils::FormatToSerialString(src_format).c_str(), ToString(src_shape.GetDims()).c_str(),
TypeUtils::DataTypeToSerialString(src_data_type).c_str());
const GeTensorDesc dst_tensor_desc = var_dst->GetOpDesc()->GetOutputDesc(0U);
const DataType data_type = dst_tensor_desc.GetDataType();
const GeShape data_shape = dst_tensor_desc.GetShape();
const Format data_format = dst_tensor_desc.GetFormat();
GELOGI("dst_node %s, src_format %s, src_shape %s, src_type %s.", var_dst->GetName().c_str(),
TypeUtils::FormatToSerialString(data_format).c_str(), ToString(data_shape.GetDims()).c_str(),
TypeUtils::DataTypeToSerialString(data_type).c_str());
std::unique_ptr<uint8_t[]> var_src_data;
const RtContextSwitchGuard switch_context(device_id);
auto ret = CopyVarFromDevice(session_id, var_src, var_src_data, output_desc, device_id);
GE_IF_BOOL_EXEC(ret != SUCCESS,
GELOGE(FAILED, "[Call][CopyVarFromDevice] failed, session id:%" PRIu64 ", var_src:%s",
session_id, var_src->GetName().c_str());
return ret);
formats::TransResult trans_result{};
ret = TransTensor(var_src_data.get(), var_src, var_dst, trans_result);
GE_IF_BOOL_EXEC(ret != SUCCESS,
GELOGE(INTERNAL_ERROR, "[Trans][Tensor] failed, var_src:%s, var_dst:%s",
var_src->GetName().c_str(), var_dst->GetName().c_str());
return ret);
void *var_device = nullptr;
ret = ReAssignVarAddr(session_id, var_dst->GetName(), dst_tensor_desc, var_device, device_id);
GE_IF_BOOL_EXEC(ret != SUCCESS,
GELOGE(INTERNAL_ERROR, "[Call][ReAssignVarAddr] failed, session id:%" PRIu64 ", var_dst:%s",
session_id, var_dst->GetName().c_str());
return ret);
ret = CopyVarToDevice(var_dst, trans_result, var_device);
GE_IF_BOOL_EXEC(ret != SUCCESS,
GELOGE(ret, "[Call][CopyVarToDevice] failed, var_dst:%s, ret:%u",
var_dst->GetName().c_str(), ret);
return ret);
return SUCCESS;
}
}
Status TransVarDataUtils::TransAllVarData(const std::vector<NodePtr> &variable_nodes,
const uint64_t session_id,
const uint32_t graph_id,
const uint32_t device_id) {
if (variable_nodes.empty()) {
GELOGI("No vars need trans, just return.");
return SUCCESS;
}
aclrtContext context = nullptr;
GE_CHK_RT_RET(aclrtGetCurrentContext(&context));
ThreadPool executor("ge_vartrans", kDefaultVarTransThreadNum, true);
std::vector<std::future<Status>> vector_future;
for (auto &node : variable_nodes) {
if (node == nullptr) {
continue;
}
if (node->GetType() != VARIABLE) {
continue;
}
auto const trans_func = [](const NodePtr &inner_node, const uint64_t inner_session_id, aclrtContext const ctx,
const uint32_t inner_graph_id, const uint32_t inner_device_id,
const error_message::ErrorManagerContext &error_context) -> Status {
error_message::SetErrMgrContext(error_context);
const auto rt_ret = aclrtSetCurrentContext(ctx);
if (rt_ret != ACL_SUCCESS) {
REPORT_INNER_ERR_MSG("E19999", "Call aclrtSetCurrentContext failed, session_id:%" PRIu64 ", graph_id:%u, ret:%d.",
inner_session_id, inner_graph_id, rt_ret);
GELOGE(RT_FAILED, "[Call][aclrtSetCurrentContext] failed, session_id:%" PRIu64 ", graph_id:%u, ret:%d.",
inner_session_id, inner_graph_id, rt_ret);
return RT_ERROR_TO_GE_STATUS(rt_ret);
}
uint32_t allocated_graph_id = 0U;
GE_CHECK_NOTNULL(VarManager::Instance(inner_session_id));
Status ret = VarManager::Instance(inner_session_id)->GetAllocatedGraphId(inner_node->GetName(),
allocated_graph_id);
if (ret != SUCCESS) {
REPORT_INNER_ERR_MSG("E19999", "Get allocated GraphId failed, session_id:%" PRIu64 ", graph_id:%u, ret:0x%X.",
inner_session_id, inner_graph_id, ret);
GELOGE(INTERNAL_ERROR, "[Get][AllocatedGraphId] failed, node:%s, graph_id:%u.", inner_node->GetName().c_str(),
inner_graph_id);
return INTERNAL_ERROR;
}
uint32_t changed_graph_id = 0U;
ret = VarManager::Instance(inner_session_id)->GetChangedGraphId(inner_node->GetName(), changed_graph_id);
const bool call_trans_var = ((ret == SUCCESS) && (changed_graph_id == inner_graph_id) &&
(changed_graph_id != allocated_graph_id));
if (call_trans_var) {
GELOGI("VarManager::GetChangedGraphId() success, node:%s, graph_id:%u.", inner_node->GetName().c_str(),
inner_graph_id);
VarTransRoad *const trans_road = VarManager::Instance(inner_session_id)->GetTransRoad(inner_node->GetName());
if (trans_road == nullptr) {
GELOGI("The variable %s does not have any trans road", inner_node->GetName().c_str());
return SUCCESS;
}
ret = TransVarData(inner_node, *trans_road, inner_session_id, inner_device_id);
if (ret != SUCCESS) {
GELOGE(INTERNAL_ERROR, "[Trans][VarData] failed, node:%s, graph_id:%u, session_id:%" PRIu64 ".",
inner_node->GetName().c_str(), inner_graph_id, inner_session_id);
return INTERNAL_ERROR;
}
VarManager::Instance(inner_session_id)->RemoveChangedGraphId(inner_node->GetName());
}
return SUCCESS;
};
std::future<Status> f = executor.commit(trans_func, node, session_id, context, graph_id, device_id,
error_message::GetErrMgrContext());
if (!f.valid()) {
GELOGE(FAILED, "[Check][Param] Future is invalid, session id:%" PRIu64 ", graph id:%u", session_id, graph_id);
return FAILED;
}
vector_future.push_back(std::move(f));
}
Status ret_status;
for (size_t i = 0U; i < vector_future.size(); ++i) {
ret_status = vector_future[i].get();
if (ret_status != SUCCESS) {
GELOGE(ret_status, "[Check][Param] trans %zu vardata failed.", i);
return ret_status;
}
}
return SUCCESS;
}
Status TransVarDataUtils::CopyVarData(const ComputeGraphPtr &compute_graph, const std::vector<NodePtr> &variable_nodes,
const uint64_t session_id, const uint32_t device_id) {
GELOGD("CopyVarData start: session_id:%" PRIu64 ".", session_id);
if (compute_graph == nullptr) {
REPORT_INNER_ERR_MSG("E19999", "Param compute_graph is nullptr, session_id:%" PRIu64 ", device_id:%u, check invalid.",
session_id, device_id);
GELOGE(FAILED, "[Check][Param] compute_graph is nullptr, session_id:%" PRIu64 ", device_id:%u.",
session_id, device_id);
return FAILED;
}
bool copy_value = false;
for (const auto &node : variable_nodes) {
GE_IF_BOOL_EXEC(((node->GetOpDesc() == nullptr) || (node->GetOpDesc()->GetType() != VARIABLE)), continue);
auto cp_from_node = ge::AttrUtils::GetStr(node->GetOpDesc(), "_copy_from_var_node");
if ((cp_from_node != nullptr) && (cp_from_node->length() != 0U)) {
GELOGI("Get original type of cp_from_node");
(void)ge::AttrUtils::GetBool(node->GetOpDesc(), "_copy_value", copy_value);
if (!copy_value) {
const auto src_node = compute_graph->FindNode(*cp_from_node);
GE_CHECK_NOTNULL(src_node);
GELOGI("current_var_node__: [%s] copy_from_var_node__: [%s].", node->GetName().c_str(),
src_node->GetName().c_str());
const auto ret = CopyTensorFromSrcVarNode(src_node, node, session_id, device_id);
GE_IF_BOOL_EXEC(ret != SUCCESS,
GELOGE(FAILED, "[Copy][Tensor] failed, src_node:%s, node:%s, session_id:%" PRIu64 ", "
"device_id:%u", src_node->GetName().c_str(), node->GetName().c_str(), session_id, device_id);
return FAILED);
(void)ge::AttrUtils::SetBool(node->GetOpDesc(), "_copy_value", true);
}
}
}
return SUCCESS;
}
}