* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <map>
#include <set>
#include "atb/graph_op_builder.h"
#include "atb/utils/log.h"
#include "atb/operation.h"
#include "atb/types.h"
namespace atb {
// Upper bound applied separately to the graph's input, output and internal tensor counts.
static constexpr size_t MAX_TENSOR_NUM = 256;
// Out-of-line definitions of the abstract builder's special members; neither has any work to do.
GraphOpBuilder::GraphOpBuilder() = default;
GraphOpBuilder::~GraphOpBuilder() = default;
// Concrete GraphOpBuilder. It maps user-facing tensor names to graph tensor ids and accumulates
// nodes into a GraphParam, which Build() hands to CreateOperation.
// Tensor ids are laid out as [graph inputs | graph outputs | internal tensors].
class GraphOpBuilderImpl final : public GraphOpBuilder {
public:
    GraphOpBuilderImpl() = default;
    ~GraphOpBuilderImpl() override = default;
    Status Init(const std::string &opName, const InferShapeFunc &inferShapeFunc,
        const SVector<std::string> &inTensorNames, const SVector<std::string> &outTensorNames) override;
    Status Reshape(const std::string &srcTensorName, const ReshapeFunc &reshapeFunc,
        const std::string &viewTensorName) override;
    Status AddOperation(Operation *operation, const SVector<std::string> &inTensorNames,
        const SVector<std::string> &outTensorNames) override;
    Operation *Build() override;
private:
    // Resolves a tensor name to its id; an unknown name is registered as a new internal tensor.
    uint32_t GetTensorId(const std::string &tensorName);
private:
    GraphParam graphParam_;                                                   // graph description built up so far
    uint32_t internalTensorNum_ = 0;                                          // internal tensors registered so far
    std::map<std::string, uint32_t> inTensorIds_;                             // graph input name -> id (from 0)
    std::map<std::string, uint32_t> outTensorIds_;                            // graph output name -> id (after inputs)
    std::map<std::string, uint32_t> internalTensorIds_;                       // intermediate name -> id (after outputs)
    std::map<std::string, std::pair<uint32_t, ReshapeFunc>> viewTensorIds_;   // view name -> (source id, reshape)
};
// Sets the graph name and infer-shape callback, and assigns ids to the graph's input tensors
// (0 .. in-1) and output tensors (in .. in+out-1).
// All arguments are validated before any member is modified, so a failed call leaves the builder
// untouched. A successful call replaces any in/out name mappings from a previous Init.
// Tensor names must be unique across inputs and outputs. GetTensorId derives internal ids from
// the sizes of the in/out maps, so a duplicate name would make internal ids collide with
// output ids.
// Returns NO_ERROR on success, ERROR_INVALID_IN_TENSOR_NUM when either count exceeds
// MAX_TENSOR_NUM, or ERROR_INVALID_PARAM on a duplicate name.
Status GraphOpBuilderImpl::Init(const std::string &opName, const InferShapeFunc &inferShapeFunc,
    const SVector<std::string> &inTensorNames, const SVector<std::string> &outTensorNames)
{
    size_t inTensorNum = inTensorNames.size();
    if (inTensorNum > MAX_TENSOR_NUM) {
        ATB_LOG(ERROR) << "inTensor num should be <= 256, but get " << inTensorNum;
        return ERROR_INVALID_IN_TENSOR_NUM;
    }
    size_t outTensorNum = outTensorNames.size();
    if (outTensorNum > MAX_TENSOR_NUM) {
        ATB_LOG(ERROR) << "outTensor num should be <= 256, but get " << outTensorNum;
        // NOTE(review): reuses the in-tensor error code for an out-tensor overflow; kept for callers.
        return ERROR_INVALID_IN_TENSOR_NUM;
    }

    // Build the mappings in locals and commit them only once everything has been validated.
    std::map<std::string, uint32_t> inTensorIds;
    std::map<std::string, uint32_t> outTensorIds;
    uint32_t id = 0;
    for (const std::string &inTensorName : inTensorNames) {
        if (!inTensorIds.emplace(inTensorName, id++).second) {
            ATB_LOG(ERROR) << "duplicate inTensor name: " << inTensorName;
            return ERROR_INVALID_PARAM;
        }
    }
    for (const std::string &outTensorName : outTensorNames) {
        // A name shared with an input would always resolve to the input id, leaving this output
        // slot unreachable.
        if (inTensorIds.count(outTensorName) != 0 || !outTensorIds.emplace(outTensorName, id++).second) {
            ATB_LOG(ERROR) << "duplicate outTensor name: " << outTensorName;
            return ERROR_INVALID_PARAM;
        }
    }

    graphParam_.name = opName;
    graphParam_.inferShapeFunc = inferShapeFunc;
    graphParam_.inTensorNum = static_cast<uint32_t>(inTensorNum);
    graphParam_.outTensorNum = static_cast<uint32_t>(outTensorNum);
    inTensorIds_.swap(inTensorIds);
    outTensorIds_.swap(outTensorIds);
    return NO_ERROR;
}
// Registers viewTensorName as a view of srcTensorName. The view resolves to the source tensor's
// id, and AddOperation attaches reshapeFunc whenever the view is used as a node input.
// An unknown srcTensorName is registered as a new internal tensor (via GetTensorId).
// An existing view with the same name is overwritten. Always returns NO_ERROR.
// NOTE(review): if srcTensorName is itself a view, only this reshapeFunc is recorded, against the
// earlier view's source id; the two reshape functions are not composed. Confirm this is intended.
Status GraphOpBuilderImpl::Reshape(const std::string &srcTensorName, const ReshapeFunc &reshapeFunc,
    const std::string &viewTensorName)
{
    const uint32_t sourceId = GetTensorId(srcTensorName);
    std::pair<uint32_t, ReshapeFunc> &viewEntry = viewTensorIds_[viewTensorName];
    viewEntry.first = sourceId;
    viewEntry.second = reshapeFunc;
    return NO_ERROR;
}
// Appends a node running `operation`, wiring its inputs and outputs by tensor name.
// Unknown names become new internal tensors. An input that names a view tensor carries that
// view's reshape function; every other input gets nullptr.
// The internal-tensor limit is checked before any state changes, so a rejected call leaves the
// builder unchanged and `operation` stays owned by the caller.
// Returns ERROR_INVALID_PARAM for a null operation, or ERROR_INVALID_IN_TENSOR_NUM when the
// internal tensor count would exceed MAX_TENSOR_NUM.
Status GraphOpBuilderImpl::AddOperation(Operation *operation, const SVector<std::string> &inTensorNames,
    const SVector<std::string> &outTensorNames)
{
    if (operation == nullptr) {
        ATB_LOG(ERROR) << "operation can not be nullptr!";
        return ERROR_INVALID_PARAM;
    }

    // A name is known when GetTensorId would resolve it without registering a new internal tensor.
    auto isKnownTensor = [this](const std::string &tensorName) {
        return inTensorIds_.count(tensorName) != 0 || outTensorIds_.count(tensorName) != 0 ||
               viewTensorIds_.count(tensorName) != 0 || internalTensorIds_.count(tensorName) != 0;
    };
    std::set<std::string> newInternalTensorNames;
    for (const std::string &inTensorName : inTensorNames) {
        if (!isKnownTensor(inTensorName)) {
            newInternalTensorNames.insert(inTensorName);
        }
    }
    for (const std::string &outTensorName : outTensorNames) {
        if (!isKnownTensor(outTensorName)) {
            newInternalTensorNames.insert(outTensorName);
        }
    }
    const size_t projectedInternalTensorNum = internalTensorNum_ + newInternalTensorNames.size();
    if (projectedInternalTensorNum > MAX_TENSOR_NUM) {
        ATB_LOG(ERROR) << "internalTensor num should be <= 256, but get " << projectedInternalTensorNum;
        return ERROR_INVALID_IN_TENSOR_NUM;
    }

    Node node;
    node.operation = operation;
    node.inTensorIds.resize(0);
    node.outTensorIds.resize(0);
    node.inTensorReshapeFuncs.resize(0);
    for (const std::string &inTensorName : inTensorNames) {
        node.inTensorIds.push_back(GetTensorId(inTensorName));
        const auto viewIt = viewTensorIds_.find(inTensorName);
        if (viewIt != viewTensorIds_.end()) {
            node.inTensorReshapeFuncs.push_back(viewIt->second.second);
        } else {
            node.inTensorReshapeFuncs.push_back(nullptr);
        }
    }
    for (const std::string &outTensorName : outTensorNames) {
        node.outTensorIds.push_back(GetTensorId(outTensorName));
    }
    graphParam_.nodes.push_back(node);
    return NO_ERROR;
}
// Records the final internal tensor count and creates the graph operation from graphParam_.
// On failure, every node operation collected so far is destroyed, its pointer is cleared, and an
// error is logged. The pointer written by CreateOperation is returned in either case.
Operation *GraphOpBuilderImpl::Build()
{
    graphParam_.internalTensorNum = internalTensorNum_;
    Operation *graphOp = nullptr;
    const Status createStatus = CreateOperation(graphParam_, &graphOp);
    if (createStatus == NO_ERROR) {
        return graphOp;
    }
    for (auto &node : graphParam_.nodes) {
        if (node.operation == nullptr) {
            continue;
        }
        DestroyOperation(node.operation);
        node.operation = nullptr;
    }
    ATB_LOG(ERROR) << "create graph op fail";
    return graphOp;
}
// Resolves tensorName to its graph tensor id. Lookup order:
//   1. graph inputs
//   2. graph outputs
//   3. view tensors (which resolve to their source tensor's id)
//   4. already-registered internal tensors
// An unknown name is registered as a new internal tensor. Its id follows all input and output
// ids, and internalTensorNum_ is incremented.
// NOTE(review): the internal id base is the sum of the in/out map sizes, which assumes the names
// given to Init were unique.
uint32_t GraphOpBuilderImpl::GetTensorId(const std::string &tensorName)
{
    const auto inIt = inTensorIds_.find(tensorName);
    if (inIt != inTensorIds_.end()) {
        return inIt->second;
    }
    const auto outIt = outTensorIds_.find(tensorName);
    if (outIt != outTensorIds_.end()) {
        return outIt->second;
    }
    const auto viewIt = viewTensorIds_.find(tensorName);
    if (viewIt != viewTensorIds_.end()) {
        return viewIt->second.first;
    }
    const auto internalIt = internalTensorIds_.find(tensorName);
    if (internalIt != internalTensorIds_.end()) {
        return internalIt->second;
    }
    const uint32_t internalTensorId =
        static_cast<uint32_t>(inTensorIds_.size() + outTensorIds_.size() + internalTensorNum_);
    ++internalTensorNum_;
    internalTensorIds_.emplace(tensorName, internalTensorId);
    return internalTensorId;
}
// Allocates a new GraphOpBuilderImpl and stores it in *builder.
// Release it with DestroyGraphOpBuilder. Returns ERROR_INVALID_PARAM when builder is null.
Status CreateGraphOpBuilder(GraphOpBuilder **builder)
{
    if (builder != nullptr) {
        *builder = new GraphOpBuilderImpl();
        return NO_ERROR;
    }
    ATB_LOG(ERROR) << "invalid param, builder is null";
    return ERROR_INVALID_PARAM;
}
// Deletes a builder obtained from CreateGraphOpBuilder. Deleting a null pointer is a no-op, so
// null input is accepted. Always returns NO_ERROR.
Status DestroyGraphOpBuilder(GraphOpBuilder *builder)
{
    delete builder;
    return NO_ERROR;
}
}