* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/parallel/spliter.h"
#include <algorithm>
#include <queue>
#include "tools/optimizer/fisson/fisson_util.h"
#include "tools/optimizer/parallel/split_strategy.h"
namespace mindspore {
namespace opt {
// Returns a pointer to the single shared Spliter object.
// The object is a function-local static, so every call yields the same address.
Spliter *Spliter::GetInstance() {
  static Spliter instance;
  return &instance;
}
// Records, for each node returned by func_graph->GetOrderedCnodes(), which of its
// inputs are themselves CNodes. Results accumulate in nodes_inputs_, keyed by the
// consuming node; an entry is created only when such an input is found.
void Spliter::VisitNodesInputs(const FuncGraphPtr &func_graph) {
  MS_ASSERT(func_graph != nullptr);
  for (const auto &consumer : func_graph->GetOrderedCnodes()) {
    if (utils::isa<CNodePtr>(consumer)) {
      for (const auto &operand : consumer->inputs()) {
        // Inputs that are not CNodes are ignored.
        if (utils::isa<CNodePtr>(operand)) {
          nodes_inputs_[consumer].insert(operand);
        }
      }
    }
  }
}
// Derives the reverse relation of nodes_inputs_: for each node returned by
// func_graph->GetOrderedCnodes(), every *other* node whose recorded input set
// contains it is added to nodes_outputs_[node]. Only edges already present in
// nodes_inputs_ can be discovered here.
void Spliter::VisitNodesOutputs(const FuncGraphPtr &func_graph) {
  for (const auto &producer : func_graph->GetOrderedCnodes()) {
    for (const auto &entry : nodes_inputs_) {
      const auto &consumer = entry.first;
      if (consumer == producer) {
        continue;  // a node is never recorded as its own output
      }
      for (const auto &consumed : entry.second) {
        if (consumed == producer) {
          nodes_outputs_[producer].insert(consumer);
        }
      }
    }
  }
}
// Scans func_graph for groups of directly connected Conv2D nodes and records the
// size counted for each group in match_numbers_.
//
// Steps:
//   1. Populate nodes_inputs_ / nodes_outputs_ via VisitNodesInputs / VisitNodesOutputs.
//   2. Walk func_graph->GetOrderedCnodes(). Every not-yet-visited Conv2D node with at
//      most kDefaultBatch recorded outputs seeds a breadth-first search over its
//      Conv2D inputs and outputs.
//   3. Insert each non-zero search count into match_numbers_.
//
// The whole scan returns early when an entry is not a CNode, or when a node's
// primitive carries an ops::kDeviceType attribute different from kDeviceTypeNone.
// A null func_graph is a no-op.
void Spliter::RecordGraphInfo(const FuncGraphPtr &func_graph) {
  if (func_graph == nullptr) {
    return;
  }
  VisitNodesInputs(func_graph);
  VisitNodesOutputs(func_graph);
  for (const auto &node : func_graph->GetOrderedCnodes()) {
    if (!utils::isa<CNodePtr>(node)) {
      return;
    }
    // Nodes with more than kDefaultBatch recorded outputs are never used as seeds.
    if (nodes_outputs_[node].size() > kDefaultBatch) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    auto prim = GetValueNode<PrimitivePtr>(cnode->input(kAnfPrimitiveIndex));
    // An assertion alone does not protect the dereference below if MS_ASSERT is
    // compiled out (assumed for non-debug builds -- TODO confirm). A node without a
    // primitive has no device-type attribute to read, so it is skipped.
    if (prim == nullptr) {
      continue;
    }
    // Look the attribute up once; a missing attribute means "no device type assigned".
    const auto device_type_attr = prim->GetAttr(ops::kDeviceType);
    const auto device_type = device_type_attr != nullptr ? GetValue<int>(device_type_attr) : kDeviceTypeNone;
    if (device_type != kDeviceTypeNone) {
      return;
    }
    if (match_visited_[node] || !IsConv2D(node)) {
      continue;
    }
    int match_num = 0;
    std::queue<AnfNodePtr> conv_nodes;
    conv_nodes.push(node);
    while (!conv_nodes.empty()) {
      auto curr_node = conv_nodes.front();
      conv_nodes.pop();
      if (match_visited_[curr_node]) {
        continue;
      }
      match_visited_[curr_node] = true;
      // NOTE(review): `break` stops at the first visited or non-Conv2D input, so which
      // inputs get enqueued depends on the container's iteration order -- confirm that
      // this (rather than `continue`) is intended.
      for (const auto &pre_input_node : nodes_inputs_[curr_node]) {
        if (match_visited_[pre_input_node] || !IsConv2D(pre_input_node)) {
          break;
        }
        conv_nodes.push(pre_input_node);
      }
      // A node with more than kDefaultBatch outputs ends the whole search: it is marked
      // visited but not counted, and anything still queued is dropped.
      if (nodes_outputs_[curr_node].size() > kDefaultBatch) {
        break;
      }
      // NOTE(review): same first-mismatch `break` as for the inputs above.
      for (const auto &post_output_node : nodes_outputs_[curr_node]) {
        if (match_visited_[post_output_node] || !IsConv2D(post_output_node)) {
          break;
        }
        conv_nodes.push(post_output_node);
      }
      match_num++;
    }
    if (match_num != 0) {
      match_numbers_.insert(match_num);
    }
  }
}
// Appends candidate_output to the list stored in graph_node_outputs_ under
// input_node_name.
//   - A null candidate_output is ignored.
//   - A candidate already present in the list is not added a second time.
void Spliter::UpdateNodeOutputs(const std::string &input_node_name, const AnfNodePtr &candidate_output) {
  if (candidate_output == nullptr) {
    return;
  }
  // operator[] creates an empty list on first use, so new and existing keys share
  // one code path and the key is looked up only once.
  auto &outputs = graph_node_outputs_[input_node_name];
  if (std::find(outputs.begin(), outputs.end(), candidate_output) == outputs.end()) {
    outputs.push_back(candidate_output);
  }
}
// Stores input_shapes in graph_node_input_shapes_ under node_name, overwriting any
// value previously stored for that name.
void Spliter::UpdateNodeInputShapes(const std::string &node_name, const std::vector<ShapeVector> &input_shapes) {
  auto &stored_shapes = graph_node_input_shapes_[node_name];
  stored_shapes = input_shapes;
}
// Stores output_shapes in graph_node_output_shapes_ under node_name, overwriting any
// value previously stored for that name.
void Spliter::UpdateNodeOutputShapes(const std::string &node_name, const std::vector<ShapeVector> &output_shapes) {
  auto &stored_shapes = graph_node_output_shapes_[node_name];
  stored_shapes = output_shapes;
}
}
}