* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "node_def_builder.h"

#include <cstdint>
#include <limits>
#include <utility>

#include "cpu_kernel_utils.h"
using namespace std;
namespace aicpu {
std::shared_ptr<NodeDef> NodeDefBuilder::CreateNodeDef() { return CpuKernelUtils::CpuKernelUtils::CreateNodeDef(); }
NodeDefBuilder::NodeDefBuilder(NodeDef* nodeDef, std::string name, std::string opName)
{
node_def_ = nodeDef;
name_ = name;
node_def_->SetOpType(opName);
}
void NodeDefBuilder::BuildNodeFromInputOutputNode(const InputOutputNode& node, bool isInput)
{
std::shared_ptr<Tensor> tensor;
if (isInput) {
tensor = node_def_->AddInputs();
} else {
tensor = node_def_->AddOutputs();
}
aicpu::CpuKernelUtils::SetTensorName(node.node, tensor);
tensor->SetDataType(node.d_type);
auto shape = tensor->GetTensorShape();
shape->SetDimSizes(node.dims);
shape->SetFormat(node.format);
int64_t dataSize = 1;
for (size_t i = 0; i < node.dims.size(); i++) {
dataSize = dataSize * node.dims[i];
}
dataSize = dataSize * GetSizeByDataType(node.d_type);
if (node.dims.empty()) {
dataSize = GetSizeByDataType(node.d_type);
}
if (node.data == nullptr) {
dataSize = 0;
}
tensor->SetDataSize(static_cast<uint64_t>(dataSize));
tensor->SetData(node.data);
}
NodeDefBuilder& NodeDefBuilder::Input(const InputOutputNode& input)
{
BuildNodeFromInputOutputNode(input, true);
return *this;
}
NodeDefBuilder& NodeDefBuilder::Output(const InputOutputNode& output)
{
BuildNodeFromInputOutputNode(output, false);
return *this;
}
NodeDefBuilder& NodeDefBuilder::Attr(std::string name, int32_t value)
{
auto attr = CpuKernelUtils::CreateAttrValue();
attr->SetInt(value);
(void)node_def_->AddAttrs(name, attr.get());
return *this;
}
NodeDefBuilder& NodeDefBuilder::Attr(std::string name, int64_t value)
{
auto attr = CpuKernelUtils::CreateAttrValue();
attr->SetInt(value);
node_def_->AddAttrs(name, attr.get());
return *this;
}
NodeDefBuilder& NodeDefBuilder::Attr(std::string name, float value)
{
auto attr = CpuKernelUtils::CreateAttrValue();
attr->SetFloat(value);
node_def_->AddAttrs(name, attr.get());
return *this;
}
NodeDefBuilder& NodeDefBuilder::Attr(std::string name, double value)
{
auto attr = CpuKernelUtils::CreateAttrValue();
attr->SetFloat(static_cast<float>(value));
node_def_->AddAttrs(name, attr.get());
return *this;
}
NodeDefBuilder& NodeDefBuilder::Attr(std::string name, bool value)
{
auto attr = CpuKernelUtils::CreateAttrValue();
attr->SetBool(value);
node_def_->AddAttrs(name, attr.get());
return *this;
}
NodeDefBuilder& NodeDefBuilder::Attr(std::string name, aicpu::DataType value)
{
auto attr = CpuKernelUtils::CreateAttrValue();
attr->SetDataType(value);
node_def_->AddAttrs(name, attr.get());
return *this;
}
NodeDefBuilder& NodeDefBuilder::Attr(std::string name, const std::vector<bool>& value)
{
auto attr = CpuKernelUtils::CreateAttrValue();
attr->SetListBool(value);
node_def_->AddAttrs(name, attr.get());
return *this;
}
NodeDefBuilder& NodeDefBuilder::Attr(std::string name, const std::string& value)
{
auto attr = CpuKernelUtils::CreateAttrValue();
attr->SetString(value);
node_def_->AddAttrs(name, attr.get());
return *this;
}
NodeDefBuilder& NodeDefBuilder::Attr(std::string name, const std::vector<std::string>& value)
{
auto attr = CpuKernelUtils::CreateAttrValue();
attr->SetListString(value);
node_def_->AddAttrs(name, attr.get());
return *this;
}
NodeDefBuilder& NodeDefBuilder::Attr(std::string name, const std::vector<int64_t>& value)
{
auto attr = CpuKernelUtils::CreateAttrValue();
attr->SetListInt(value);
node_def_->AddAttrs(name, attr.get());
return *this;
}
NodeDefBuilder& NodeDefBuilder::Attr(std::string name, const std::vector<std::vector<int64_t>>& value)
{
auto attr = CpuKernelUtils::CreateAttrValue();
attr->SetListListInt(value);
node_def_->AddAttrs(name, attr.get());
return *this;
}
NodeDefBuilder& NodeDefBuilder::Attr(std::string name, const std::vector<float>& value)
{
auto attr = CpuKernelUtils::CreateAttrValue();
attr->SetListFloat(value);
node_def_->AddAttrs(name, attr.get());
return *this;
}
NodeDefBuilder& NodeDefBuilder::Attr(std::string name, const std::vector<aicpu::DataType>& value)
{
auto attr = CpuKernelUtils::CreateAttrValue();
attr->SetListDataType(value);
node_def_->AddAttrs(name, attr.get());
return *this;
}
NodeDefBuilder& NodeDefBuilder::Attr(std::string name, const std::vector<int64_t>& dims, std::string type)
{
if (type == "shape") {
auto shape = CpuKernelUtils::CreateAttrValue();
auto value = CpuKernelUtils::CreateTensorShape();
value->SetDimSizes(dims);
(void)node_def_->AddAttrs(name, shape.get());
(void)shape->SetTensorShape(value.get());
}
return *this;
}
NodeDefBuilder& NodeDefBuilder::Attr(std::string name, const std::vector<std::vector<int64_t>>& shape_lists,
std::string type)
{
if (type == "shape_list") {
auto shapeItems = CpuKernelUtils::CreateAttrValue();
for (size_t i = 0; i < shape_lists.size(); i++) {
auto value = shapeItems->AddListTensorShape();
value->SetDimSizes(shape_lists[i]);
}
(void)node_def_->AddAttrs(name, shapeItems.get());
}
return *this;
}
NodeDefBuilder& NodeDefBuilder::Attr(std::string name, aicpu::Tensor* tensor)
{
auto attr = CpuKernelUtils::CreateAttrValue();
(void)attr->SetTensor(tensor);
(void)node_def_->AddAttrs(name, attr.get());
return *this;
}
NodeDefBuilder& NodeDefBuilder::Attr(std::string name, std::vector<aicpu::Tensor*>& tensors)
{
auto attr = CpuKernelUtils::CreateAttrValue();
(void)attr->SetListTensor(tensors);
(void)node_def_->AddAttrs(name, attr.get());
return *this;
}
}