* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "list_value_serializer.h"
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "framework/common/debug/ge_log.h"
#include "graph/types.h"
#include "graph/utils/attr_utils.h"
#include "graph_metadef/graph/debug/ge_util.h"
#include "tensor_desc_serializer.h"
#include "tensor_serializer.h"
#include "named_attrs_serializer.h"
#include "graph_serializer.h"
#include "graph/ge_tensor.h"
#include "graph/def_types.h"
namespace ge {
// Shared-pointer aliases for graph and tensor objects.
using ComputeGraphPtr = std::shared_ptr<ComputeGraph>;
using GeTensorPtr = std::shared_ptr<GeTensor>;
// Short name for the protobuf message that stores list-typed attribute values.
using ListValue = proto::AttrDef::ListValue;
// Argument placeholders for std::bind expressions.
using std::placeholders::_1;
using std::placeholders::_2;
// Serializes a list-typed AnyValue into `def` by dispatching on av.GetValueType().
//
// av:  source value; its value type selects the SerializeListXxx helper to run.
// def: destination proto attribute, passed through to the selected helper.
//
// Returns the helper's status, or GRAPH_FAILED when the value type has no entry
// in the dispatch table.
graphStatus ListValueSerializer::Serialize(const AnyValue &av, proto::AttrDef &def) {
  // The helpers are plain (static) functions, so the table stores raw function
  // pointers instead of std::bind objects wrapped in std::function.
  using SerializeFunc = graphStatus (*)(const AnyValue &, proto::AttrDef &);
  static const std::map<AnyValue::ValueType, SerializeFunc> type_serializer_map = {
      {AnyValue::VT_LIST_INT, &ListValueSerializer::SerializeListInt},
      {AnyValue::VT_LIST_FLOAT, &ListValueSerializer::SerializeListFloat},
      {AnyValue::VT_LIST_BOOL, &ListValueSerializer::SerializeListBool},
      {AnyValue::VT_LIST_BYTES, &ListValueSerializer::SerializeListBuffer},
      {AnyValue::VT_LIST_DATA_TYPE, &ListValueSerializer::SerializeListDataType},
      {AnyValue::VT_LIST_STRING, &ListValueSerializer::SerializeListString},
      {AnyValue::VT_LIST_NAMED_ATTRS, &ListValueSerializer::SerializeListNamedAttrs},
      {AnyValue::VT_LIST_TENSOR_DESC, &ListValueSerializer::SerializeListGeTensorDesc},
      {AnyValue::VT_LIST_TENSOR, &ListValueSerializer::SerializeListGeTensor},
      {AnyValue::VT_LIST_GRAPH, &ListValueSerializer::SerializeListGraphDef},
  };

  // Query the type once; it is needed for both the lookup and the error log.
  const auto value_type = av.GetValueType();
  const auto iter = type_serializer_map.find(value_type);
  if (iter == type_serializer_map.end()) {
    GELOGE(GRAPH_FAILED, "Value type [%d] not support.", static_cast<int32_t>(value_type));
    return GRAPH_FAILED;
  }
  return iter->second(av, def);
}
// Deserializes the list stored in `def` into `av` by dispatching on
// def.list().val_type().
//
// def: source proto attribute; the val_type of its list selects the helper.
// av:  destination value, filled in by the selected DeserializeListXxx helper.
//
// Returns the helper's status, or GRAPH_FAILED when the list's val_type has no
// entry in the dispatch table.
graphStatus ListValueSerializer::Deserialize(const proto::AttrDef &def, AnyValue &av) {
  // The helpers are plain (static) functions, so the table stores raw function
  // pointers instead of std::bind objects wrapped in std::function.
  using DeserializeFunc = graphStatus (*)(const proto::AttrDef &, AnyValue &);
  static const std::map<ListValue::ListValueType, DeserializeFunc> type_deserializer_map = {
      {ListValue::VT_LIST_INT, &ListValueSerializer::DeserializeListInt},
      {ListValue::VT_LIST_FLOAT, &ListValueSerializer::DeserializeListFloat},
      {ListValue::VT_LIST_STRING, &ListValueSerializer::DeserializeListString},
      {ListValue::VT_LIST_BYTES, &ListValueSerializer::DeserializeListBuffer},
      {ListValue::VT_LIST_BOOL, &ListValueSerializer::DeserializeListBool},
      {ListValue::VT_LIST_DATA_TYPE, &ListValueSerializer::DeserializeListDataType},
      {ListValue::VT_LIST_NAMED_ATTRS, &ListValueSerializer::DeserializeListNamedAttrs},
      {ListValue::VT_LIST_TENSOR_DESC, &ListValueSerializer::DeserializeListGeTensorDesc},
      {ListValue::VT_LIST_TENSOR, &ListValueSerializer::DeserializeListGeTensor},
      {ListValue::VT_LIST_GRAPH, &ListValueSerializer::DeserializeListGraphDef},
  };

  // Read the type tag once; it is needed for both the lookup and the error log.
  const auto list_type = def.list().val_type();
  const auto iter = type_deserializer_map.find(list_type);
  if (iter == type_deserializer_map.end()) {
    GELOGE(GRAPH_FAILED, "Value type [%d] not support.", static_cast<int32_t>(list_type));
    return GRAPH_FAILED;
  }
  return iter->second(def, av);
}
// Writes the std::vector<int64_t> held by `av` into the repeated `i` field of
// def's list and tags the list as VT_LIST_INT.
// Returns GRAPH_FAILED when the vector cannot be read from `av`.
graphStatus ListValueSerializer::SerializeListInt(const AnyValue &av, proto::AttrDef &def) {
  std::vector<int64_t> ints;
  if (av.GetValue(ints) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Failed to get list_int attr.");
    return GRAPH_FAILED;
  }

  auto *const mutable_list = def.mutable_list();
  GE_CHECK_NOTNULL(mutable_list);

  // Start from an empty field so it ends up holding exactly `ints`.
  mutable_list->clear_i();
  for (const int64_t item : ints) {
    mutable_list->add_i(item);
  }
  mutable_list->set_val_type(ListValue::VT_LIST_INT);
  return GRAPH_SUCCESS;
}
// Writes the std::vector<std::string> held by `av` into the repeated `s` field
// of def's list and tags the list as VT_LIST_STRING.
// Returns GRAPH_FAILED when the vector cannot be read from `av`.
graphStatus ListValueSerializer::SerializeListString(const AnyValue &av, proto::AttrDef &def) {
  std::vector<std::string> strings;
  if (av.GetValue(strings) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Failed to get list_string attr.");
    return GRAPH_FAILED;
  }

  auto *const mutable_list = def.mutable_list();
  GE_CHECK_NOTNULL(mutable_list);

  // Start from an empty field so it ends up holding exactly `strings`.
  mutable_list->clear_s();
  for (const std::string &text : strings) {
    mutable_list->add_s(text);
  }
  mutable_list->set_val_type(ListValue::VT_LIST_STRING);
  return GRAPH_SUCCESS;
}
// Writes the std::vector<float> held by `av` into the repeated `f` field of
// def's list and tags the list as VT_LIST_FLOAT.
// Returns GRAPH_FAILED when the vector cannot be read from `av`.
graphStatus ListValueSerializer::SerializeListFloat(const AnyValue &av, proto::AttrDef &def) {
  std::vector<float> floats;
  if (av.GetValue(floats) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Failed to get list_float attr.");
    return GRAPH_FAILED;
  }

  auto *const mutable_list = def.mutable_list();
  GE_CHECK_NOTNULL(mutable_list);

  // Start from an empty field so it ends up holding exactly `floats`.
  mutable_list->clear_f();
  for (const float item : floats) {
    mutable_list->add_f(item);
  }
  mutable_list->set_val_type(ListValue::VT_LIST_FLOAT);
  return GRAPH_SUCCESS;
}
// Writes the std::vector<bool> held by `av` into the repeated `b` field of
// def's list and tags the list as VT_LIST_BOOL.
// Returns GRAPH_FAILED when the vector cannot be read from `av`.
graphStatus ListValueSerializer::SerializeListBool(const AnyValue &av, proto::AttrDef &def) {
  std::vector<bool> flags;
  if (av.GetValue(flags) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Failed to get list_bool attr.");
    return GRAPH_FAILED;
  }

  auto *const mutable_list = def.mutable_list();
  GE_CHECK_NOTNULL(mutable_list);

  // Start from an empty field so it ends up holding exactly `flags`.
  mutable_list->clear_b();
  for (const bool flag : flags) {
    mutable_list->add_b(flag);
  }
  mutable_list->set_val_type(ListValue::VT_LIST_BOOL);
  return GRAPH_SUCCESS;
}
// Writes the std::vector<GeTensorDesc> held by `av` into the repeated `td`
// field of def's list and tags the list as VT_LIST_TENSOR_DESC.
// Each descriptor is converted with GeTensorSerializeUtils::GeTensorDescAsProto.
// Returns GRAPH_FAILED when the vector cannot be read from `av`.
graphStatus ListValueSerializer::SerializeListGeTensorDesc(const AnyValue &av, proto::AttrDef &def) {
  std::vector<ge::GeTensorDesc> descs;
  if (av.GetValue(descs) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Failed to get list_tensor_desc attr.");
    return GRAPH_FAILED;
  }

  auto *const mutable_list = def.mutable_list();
  GE_CHECK_NOTNULL(mutable_list);

  mutable_list->clear_td();
  for (const ge::GeTensorDesc &desc : descs) {
    // Append a fresh proto entry and fill it from the descriptor.
    auto *const attr_proto = mutable_list->add_td();
    GE_CHECK_NOTNULL(attr_proto);
    GeTensorSerializeUtils::GeTensorDescAsProto(desc, attr_proto);
  }
  mutable_list->set_val_type(ListValue::VT_LIST_TENSOR_DESC);
  return GRAPH_SUCCESS;
}
// Writes the std::vector<GeTensor> held by `av` into the repeated `t` field of
// def's list and tags the list as VT_LIST_TENSOR.
// Each tensor is converted with GeTensorSerializeUtils::GeTensorAsProto.
// Returns GRAPH_FAILED when the vector cannot be read from `av`.
graphStatus ListValueSerializer::SerializeListGeTensor(const AnyValue &av, proto::AttrDef &def) {
  std::vector<GeTensor> tensors;
  if (av.GetValue(tensors) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Failed to get list_tensor attr_value.");
    return GRAPH_FAILED;
  }

  auto *const mutable_list = def.mutable_list();
  GE_CHECK_NOTNULL(mutable_list);

  mutable_list->clear_t();
  for (const GeTensor &tensor : tensors) {
    // Append a fresh proto entry and fill it from the tensor.
    auto *const attr_proto = mutable_list->add_t();
    GE_CHECK_NOTNULL(attr_proto);
    GeTensorSerializeUtils::GeTensorAsProto(tensor, attr_proto);
  }
  mutable_list->set_val_type(ListValue::VT_LIST_TENSOR);
  return GRAPH_SUCCESS;
}
// Writes the std::vector<Buffer> held by `av` into the repeated `bt` field of
// def's list and tags the list as VT_LIST_BYTES.
//
// Buffers with a null data pointer or zero size are skipped, so the serialized
// list can be shorter than the input.
// NOTE(review): skipping shifts the positions of later entries on a round
// trip; confirm with callers whether empty buffers should be kept as empty
// byte strings instead.
//
// Returns GRAPH_FAILED when the vector cannot be read from `av`.
graphStatus ListValueSerializer::SerializeListBuffer(const AnyValue &av, proto::AttrDef &def) {
  std::vector<Buffer> list_val;
  const graphStatus ret = av.GetValue(list_val);
  if (ret != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Failed to get list_buffer attr.");
    return GRAPH_FAILED;
  }
  const auto mutable_list = def.mutable_list();
  GE_CHECK_NOTNULL(mutable_list);
  mutable_list->clear_bt();
  // Iterate by const reference: the buffers are only read, so the per-element
  // copy is unnecessary.
  for (const auto &val : list_val) {
    const auto *const data = val.GetData();
    const auto size = val.GetSize();
    if ((data != nullptr) && (size > 0U)) {
      mutable_list->add_bt(data, size);
    }
  }
  mutable_list->set_val_type(proto::AttrDef::ListValue::VT_LIST_BYTES);
  return GRAPH_SUCCESS;
}
// Writes the std::vector<proto::GraphDef> held by `av` into the repeated `g`
// field of def's list and tags the list as VT_LIST_GRAPH.
// Returns GRAPH_FAILED when the vector cannot be read from `av`.
graphStatus ListValueSerializer::SerializeListGraphDef(const AnyValue &av, proto::AttrDef &def) {
  std::vector<proto::GraphDef> graph_defs;
  if (av.GetValue(graph_defs) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Failed to get list_graph attr_value.");
    return GRAPH_FAILED;
  }

  auto *const mutable_list = def.mutable_list();
  GE_CHECK_NOTNULL(mutable_list);

  mutable_list->clear_g();
  for (const proto::GraphDef &graph_def : graph_defs) {
    // Append an empty GraphDef and copy the source graph into it.
    auto *const mutable_graph = mutable_list->add_g();
    GE_CHECK_NOTNULL(mutable_graph);
    *mutable_graph = graph_def;
  }
  mutable_list->set_val_type(ListValue::VT_LIST_GRAPH);
  return GRAPH_SUCCESS;
}
// Writes the std::vector<NamedAttrs> held by `av` into the repeated `na` field
// of def's list and tags the list as VT_LIST_NAMED_ATTRS.
// Each element is serialized by the NamedAttrsSerializer looked up in
// AttrSerializerRegistry.
// Returns GRAPH_FAILED when the vector cannot be read from `av` or when any
// element fails to serialize.
graphStatus ListValueSerializer::SerializeListNamedAttrs(const AnyValue &av, proto::AttrDef &def) {
  std::vector<ge::NamedAttrs> named_attrs;
  if (av.GetValue(named_attrs) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Failed to get list_named_attr attr.");
    return GRAPH_FAILED;
  }

  auto *const mutable_list = def.mutable_list();
  GE_CHECK_NOTNULL(mutable_list);
  mutable_list->clear_na();

  // The registry hands back a generic serializer; the NamedAttrs overload of
  // Serialize lives on the concrete type, hence the downcast.
  auto *const named_attr_serializer = dynamic_cast<NamedAttrsSerializer *>(
      AttrSerializerRegistry::GetInstance().GetSerializer(GetTypeId<ge::NamedAttrs>()));
  GE_CHECK_NOTNULL(named_attr_serializer);

  for (const ge::NamedAttrs &named_attr : named_attrs) {
    auto *const attr_proto = mutable_list->add_na();
    GE_CHECK_NOTNULL(attr_proto);
    const auto status = named_attr_serializer->Serialize(named_attr, attr_proto);
    if (status == GRAPH_SUCCESS) {
      continue;
    }
    GELOGE(GRAPH_FAILED, "NamedAttr [%s] serialize failed.", named_attr.GetName().c_str());
    return GRAPH_FAILED;
  }
  mutable_list->set_val_type(ListValue::VT_LIST_NAMED_ATTRS);
  return GRAPH_SUCCESS;
}
// Writes the std::vector<DataType> held by `av` into the repeated `dt` field
// of def's list and tags the list as VT_LIST_DATA_TYPE.
// Each ge::DataType is cast to proto::DataType before being stored.
// Returns GRAPH_FAILED when the vector cannot be read from `av`.
graphStatus ListValueSerializer::SerializeListDataType(const AnyValue &av, proto::AttrDef &def) {
  std::vector<ge::DataType> data_types;
  if (av.GetValue(data_types) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Failed to get list_datatype attr.");
    return GRAPH_FAILED;
  }

  auto *const mutable_list = def.mutable_list();
  GE_CHECK_NOTNULL(mutable_list);

  mutable_list->clear_dt();
  for (const ge::DataType data_type : data_types) {
    mutable_list->add_dt(static_cast<proto::DataType>(data_type));
  }
  mutable_list->set_val_type(ListValue::VT_LIST_DATA_TYPE);
  return GRAPH_SUCCESS;
}
// Copies the repeated `i` field of def's list into a std::vector<int64_t> and
// stores it in `av`. Returns the status of av.SetValue.
graphStatus ListValueSerializer::DeserializeListInt(const proto::AttrDef &def, AnyValue &av) {
  const auto &ints = def.list().i();
  std::vector<int64_t> values(ints.begin(), ints.end());
  return av.SetValue(std::move(values));
}
// Copies the repeated `s` field of def's list into a std::vector<std::string>
// and stores it in `av`. Returns the status of av.SetValue.
graphStatus ListValueSerializer::DeserializeListString(const proto::AttrDef &def, AnyValue &av) {
  const auto &strings = def.list().s();
  std::vector<std::string> values(strings.begin(), strings.end());
  return av.SetValue(std::move(values));
}
// Copies the repeated `f` field of def's list into a std::vector<float> and
// stores it in `av`. Returns the status of av.SetValue.
graphStatus ListValueSerializer::DeserializeListFloat(const proto::AttrDef &def, AnyValue &av) {
  const auto &floats = def.list().f();
  std::vector<float> values(floats.begin(), floats.end());
  return av.SetValue(std::move(values));
}
// Copies the repeated `b` field of def's list into a std::vector<bool> and
// stores it in `av`. Returns the status of av.SetValue.
graphStatus ListValueSerializer::DeserializeListBool(const proto::AttrDef &def, AnyValue &av) {
  const auto &flags = def.list().b();
  std::vector<bool> values(flags.begin(), flags.end());
  return av.SetValue(std::move(values));
}
// Rebuilds a std::vector<GeTensorDesc> from the repeated `td` field of def's
// list and stores it in `av`. Each element is default-constructed and then
// filled by GeTensorSerializeUtils::AssembleGeTensorDescFromProto.
// Returns the status of av.SetValue.
graphStatus ListValueSerializer::DeserializeListGeTensorDesc(const proto::AttrDef &def, AnyValue &av) {
  const auto &desc_protos = def.list().td();
  std::vector<ge::GeTensorDesc> values(static_cast<size_t>(desc_protos.size()));
  size_t pos = 0U;
  for (const auto &desc_proto : desc_protos) {
    GeTensorSerializeUtils::AssembleGeTensorDescFromProto(&desc_proto, values[pos]);
    ++pos;
  }
  return av.SetValue(std::move(values));
}
// Rebuilds a std::vector<GeTensor> from the repeated `t` field of def's list
// and stores it in `av`. Each element is default-constructed and then filled
// by GeTensorSerializeUtils::AssembleGeTensorFromProto.
// Returns the status of av.SetValue.
graphStatus ListValueSerializer::DeserializeListGeTensor(const proto::AttrDef &def, AnyValue &av) {
  const auto &tensor_protos = def.list().t();
  std::vector<GeTensor> values(static_cast<size_t>(tensor_protos.size()));
  size_t pos = 0U;
  for (const auto &tensor_proto : tensor_protos) {
    GeTensorSerializeUtils::AssembleGeTensorFromProto(&tensor_proto, values[pos]);
    ++pos;
  }
  return av.SetValue(std::move(values));
}
// Rebuilds a std::vector<Buffer> from the repeated `bt` field of def's list
// and stores it in `av`. Each byte string is turned into a Buffer via
// Buffer::CopyFrom. Returns the status of av.SetValue.
graphStatus ListValueSerializer::DeserializeListBuffer(const proto::AttrDef &def, AnyValue &av) {
  const auto &blobs = def.list().bt();
  std::vector<Buffer> values(static_cast<size_t>(blobs.size()));
  size_t pos = 0U;
  for (const auto &blob : blobs) {
    // The proto stores bytes as char data; Buffer::CopyFrom takes uint8_t.
    const auto *const bytes = PtrToPtr<const char_t, const uint8_t>(blob.data());
    values[pos] = Buffer::CopyFrom(bytes, blob.size());
    ++pos;
  }
  return av.SetValue(std::move(values));
}
// Copies the repeated `g` field of def's list into a
// std::vector<proto::GraphDef> and stores it in `av`.
// Returns the status of av.SetValue.
graphStatus ListValueSerializer::DeserializeListGraphDef(const proto::AttrDef &def, AnyValue &av) {
  const auto &graph_defs = def.list().g();
  std::vector<proto::GraphDef> values(static_cast<size_t>(graph_defs.size()));
  size_t pos = 0U;
  for (const auto &graph_def : graph_defs) {
    values[pos] = graph_def;
    ++pos;
  }
  return av.SetValue(std::move(values));
}
// Rebuilds a std::vector<NamedAttrs> from the repeated `na` field of def's
// list and stores it in `av`. Each element is deserialized by the
// NamedAttrsSerializer registered for proto::AttrDef::ValueCase::kFunc.
// Returns GRAPH_FAILED when any element fails to deserialize, otherwise the
// status of av.SetValue.
graphStatus ListValueSerializer::DeserializeListNamedAttrs(const proto::AttrDef &def, AnyValue &av) {
  // The registry hands back a generic deserializer; the NamedAttrs overload of
  // Deserialize lives on the concrete type, hence the downcast.
  auto *const named_attr_deserializer = dynamic_cast<NamedAttrsSerializer *>(
      AttrSerializerRegistry::GetInstance().GetDeserializer(proto::AttrDef::ValueCase::kFunc));
  GE_CHECK_NOTNULL(named_attr_deserializer);

  const auto &list_def = def.list();
  const auto count = list_def.na_size();
  std::vector<ge::NamedAttrs> values(static_cast<size_t>(count));
  for (auto idx = 0; idx < count; ++idx) {
    const auto &attr_proto = list_def.na(idx);
    const auto status = named_attr_deserializer->Deserialize(attr_proto, values[static_cast<size_t>(idx)]);
    if (status == GRAPH_SUCCESS) {
      continue;
    }
    GELOGE(GRAPH_FAILED, "NamedAttr [%s] deserialize failed.", attr_proto.name().c_str());
    return GRAPH_FAILED;
  }
  return av.SetValue(std::move(values));
}
// Rebuilds a std::vector<DataType> from the repeated `dt` field of def's list
// and stores it in `av`. Each stored value is cast to ge::DataType.
// Returns the status of av.SetValue.
graphStatus ListValueSerializer::DeserializeListDataType(const proto::AttrDef &def, AnyValue &av) {
  const auto &list_def = def.list();
  const auto count = list_def.dt_size();
  std::vector<ge::DataType> values;
  values.reserve(static_cast<size_t>(count));
  for (auto idx = 0; idx < count; ++idx) {
    values.push_back(static_cast<DataType>(list_def.dt(idx)));
  }
  return av.SetValue(std::move(values));
}
// Register ListValueSerializer once per supported std::vector<T> element type.
// Every registration pairs that vector's type id with proto::AttrDef::kList.
// NOTE(review): REG_GEIR_SERIALIZER is defined outside this file; presumably it
// adds the serializer to AttrSerializerRegistry at static-init time - confirm.
REG_GEIR_SERIALIZER(list_int, ListValueSerializer, GetTypeId<std::vector<int64_t>>(), proto::AttrDef::kList);
REG_GEIR_SERIALIZER(list_str, ListValueSerializer, GetTypeId<std::vector<std::string>>(), proto::AttrDef::kList);
REG_GEIR_SERIALIZER(list_float, ListValueSerializer, GetTypeId<std::vector<float>>(), proto::AttrDef::kList);
REG_GEIR_SERIALIZER(list_bool, ListValueSerializer, GetTypeId<std::vector<bool>>(), proto::AttrDef::kList);
REG_GEIR_SERIALIZER(list_tensor_desc, ListValueSerializer,
GetTypeId<std::vector<GeTensorDesc>>(), proto::AttrDef::kList);
REG_GEIR_SERIALIZER(list_tensor, ListValueSerializer, GetTypeId<std::vector<GeTensor>>(), proto::AttrDef::kList);
REG_GEIR_SERIALIZER(list_buffer, ListValueSerializer, GetTypeId<std::vector<Buffer>>(), proto::AttrDef::kList);
REG_GEIR_SERIALIZER(list_graph_def, ListValueSerializer,
GetTypeId<std::vector<proto::GraphDef>>(), proto::AttrDef::kList);
REG_GEIR_SERIALIZER(list_named_attr, ListValueSerializer,
GetTypeId<std::vector<ge::NamedAttrs>>(), proto::AttrDef::kList);
REG_GEIR_SERIALIZER(list_data_type, ListValueSerializer, GetTypeId<std::vector<ge::DataType>>(), proto::AttrDef::kList);
}