* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/lite_kernel_util.h"
#include <queue>
#include <unordered_map>
#include <set>
#include "src/sub_graph_kernel.h"
namespace mindspore::kernel {
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
// Gathers the output tensors of every kernel in `kernels` into a single set.
// A tensor reported by more than one kernel appears only once in the result.
std::set<lite::Tensor *> LiteKernelUtil::AllOutTensor(const std::vector<kernel::LiteKernel *> &kernels) {
  std::set<lite::Tensor *> produced;
  for (auto *node : kernels) {
    const auto &node_outputs = node->out_tensors();
    produced.insert(node_outputs.begin(), node_outputs.end());
  }
  return produced;
}
// Returns the kernels of `kernels` that read at least one tensor which is neither constant
// (Tensor::IsConst()) nor an output of any kernel in `kernels`.
// The result keeps the order of `kernels` and lists each kernel at most once.
std::vector<kernel::LiteKernel *> LiteKernelUtil::SubgraphInputNodes(const std::vector<kernel::LiteKernel *> &kernels) {
  std::vector<kernel::LiteKernel *> input_nodes;
  const std::set<lite::Tensor *> all_out_tensors = AllOutTensor(kernels);
  for (const auto &kernel : kernels) {
    MS_ASSERT(kernel != nullptr);
    bool kernel_is_input = false;
    for (auto *input : kernel->in_tensors()) {
      // Constants and tensors produced inside the kernel list do not make a kernel an input node.
      if (input->IsConst() || all_out_tensors.find(input) != all_out_tensors.end()) {
        continue;
      }
      kernel_is_input = true;
      break;
    }
    if (kernel_is_input && !lite::IsContain(input_nodes, kernel)) {
      input_nodes.push_back(kernel);
    }
  }
  return input_nodes;
}
// Returns the kernels of `kernels` that act as outputs of the group. A kernel qualifies when
//   - is_model_output() is true, or
//   - it has no out_kernels() but does have out_tensors(), or
//   - at least one of its out_kernels() is not a member of `kernels`.
// The result keeps the order of `kernels` and lists each kernel at most once.
std::vector<kernel::LiteKernel *> LiteKernelUtil::SubgraphOutputNodes(
  const std::vector<kernel::LiteKernel *> &kernels) {
  const std::set<kernel::LiteKernel *> members(kernels.begin(), kernels.end());
  std::vector<kernel::LiteKernel *> output_nodes;
  for (const auto &node : kernels) {
    MS_ASSERT(node != nullptr);
    bool is_output = node->is_model_output() || (node->out_kernels().empty() && !node->out_tensors().empty());
    if (!is_output) {
      const auto &successors = node->out_kernels();
      is_output = std::any_of(successors.begin(), successors.end(),
                              [&members](kernel::LiteKernel *succ) { return members.count(succ) == 0; });
    }
    if (is_output && !lite::IsContain(output_nodes, node)) {
      output_nodes.push_back(node);
    }
  }
  return output_nodes;
}
// Collects the tensors that enter the kernel group `kernels` from outside.
// For every node returned by SubgraphInputNodes(kernels), a tensor from its in_tensors() is taken when
//   - Tensor::IsGraphInput() is true, or
//   - it is an output tensor of one of the node's in_kernels() that is not a member of `kernels`.
// Tensors are appended in discovery order and never duplicated.
std::vector<lite::Tensor *> LiteKernelUtil::SubgraphInputTensors(const std::vector<kernel::LiteKernel *> &kernels) {
  std::vector<lite::Tensor *> input_tensors;
  const auto append_unique = [&input_tensors](lite::Tensor *tensor) {
    if (!lite::IsContain(input_tensors, tensor)) {
      input_tensors.push_back(tensor);
    }
  };
  for (auto *entry_node : SubgraphInputNodes(kernels)) {
    const auto &consumed = entry_node->in_tensors();
    // Graph-level inputs first.
    for (auto *tensor : consumed) {
      if (tensor->IsGraphInput()) {
        append_unique(tensor);
      }
    }
    // Then tensors handed over by producers that live outside the group.
    for (auto *producer : entry_node->in_kernels()) {
      if (std::find(kernels.begin(), kernels.end(), producer) != kernels.end()) {
        continue;
      }
      const auto &produced = producer->out_tensors();
      for (auto *tensor : consumed) {
        if (std::find(produced.begin(), produced.end(), tensor) != produced.end()) {
          append_unique(tensor);
        }
      }
    }
  }
  return input_tensors;
}
// Collects the tensors that leave the kernel group `kernels`.
// For every node returned by SubgraphOutputNodes(kernels), a tensor from its out_tensors() is taken when
//   - Tensor::IsGraphOutput() is true, or
//   - it is an input tensor of one of the node's out_kernels() that is not a member of `kernels`.
// Tensors are appended in discovery order and never duplicated.
std::vector<lite::Tensor *> LiteKernelUtil::SubgraphOutputTensors(const std::vector<kernel::LiteKernel *> &kernels) {
  std::vector<lite::Tensor *> output_tensors;
  const auto append_unique = [&output_tensors](lite::Tensor *tensor) {
    if (!lite::IsContain(output_tensors, tensor)) {
      output_tensors.push_back(tensor);
    }
  };
  for (auto *exit_node : SubgraphOutputNodes(kernels)) {
    const auto &produced = exit_node->out_tensors();
    // Graph-level outputs first.
    for (auto *tensor : produced) {
      if (tensor->IsGraphOutput()) {
        append_unique(tensor);
      }
    }
    // Then tensors read by consumers that live outside the group.
    for (auto *consumer : exit_node->out_kernels()) {
      if (std::find(kernels.begin(), kernels.end(), consumer) != kernels.end()) {
        continue;
      }
      const auto &consumed = consumer->in_tensors();
      for (auto *tensor : produced) {
        if (std::find(consumed.begin(), consumed.end(), tensor) != consumed.end()) {
          append_unique(tensor);
        }
      }
    }
  }
  return output_tensors;
}
// Reorders `*kernels` in place so that every kernel is placed after all of its in_kernels().
//
// The traversal is breadth-first: it starts from the kernels whose in_kernels() is empty and
// follows out_kernels(). A kernel is emitted into `*kernels` when it is taken off the work queue,
// and a successor is queued once every one of its in_kernels() has been emitted. Emitting at
// dequeue time means a kernel with several predecessors is queued exactly once, by the last
// predecessor to be emitted.
//
// Returns RET_OK on success. Returns RET_ERROR when `kernels` is null, when an out_kernels() edge
// points at a kernel that has already been emitted, or when the number of emitted kernels differs
// from the original size (for example because some kernel never became ready). On RET_ERROR the
// contents of `*kernels` are the kernels emitted up to that point, not the original list.
int LiteKernelUtil::TopologicalSortKernels(std::vector<kernel::LiteKernel *> *kernels) {
  if (kernels == nullptr) {
    MS_LOG(ERROR) << "TopologicalSortKernels failed, kernels is nullptr";
    return RET_ERROR;
  }
  const auto old_kernels = *kernels;
  kernels->clear();
  std::queue<kernel::LiteKernel *> kernel_queue;
  // Seed the queue with every kernel that has no predecessor.
  for (auto *root_kernel : old_kernels) {
    if (root_kernel->in_kernels().empty()) {
      kernel_queue.push(root_kernel);
    }
  }
  while (!kernel_queue.empty()) {
    auto *cur_kernel = kernel_queue.front();
    kernel_queue.pop();
    MS_ASSERT(cur_kernel != nullptr);
    // Emit the kernel now; its successors test readiness against the emitted list below.
    kernels->emplace_back(cur_kernel);
    for (auto *next_kernel : cur_kernel->out_kernels()) {
      if (lite::IsContain(*kernels, next_kernel)) {
        MS_LOG(ERROR) << "TopologicalSortKernels failed, loop exist";
        return RET_ERROR;
      }
      const auto &in_kernels = next_kernel->in_kernels();
      const bool ready = std::all_of(in_kernels.begin(), in_kernels.end(), [&kernels](kernel::LiteKernel *in_kernel) {
        return lite::IsContain(*kernels, in_kernel);
      });
      if (ready) {
        kernel_queue.push(next_kernel);
      }
    }
  }
  if (kernels->size() != old_kernels.size()) {
    MS_LOG(ERROR) << "TopologicalSortKernels failed, kernels size before sort: " << old_kernels.size()
                  << ", kernels size after sort: " << kernels->size();
    return RET_ERROR;
  }
  return RET_OK;
}
// Calls InitOutTensorInitRefCount() on every kernel in `kernels`, handing each one a pointer to
// the complete kernel list.
void LiteKernelUtil::InitTensorInitRefCount(const std::vector<kernel::LiteKernel *> &kernels) {
  for (auto it = kernels.cbegin(); it != kernels.cend(); ++it) {
    (*it)->InitOutTensorInitRefCount(&kernels);
  }
}
int LiteKernelUtil::SetInput(const LiteKernel &kernelMod, const std::vector<lite::Tensor *> &inputs) { return -1; }
#ifndef CONTROLFLOW_TENSORLIST_CLIP
// Reports whether the subgraph `kernel` contains a Switch node that
//   - has a PartialFusion node among its in_kernels(),
//   - has exactly one out_kernel, and
//   - that out_kernel is of type Call.
// Returns false for a null `kernel` and, unless DELEGATE_CLIP is defined, for kernels whose
// desc().arch is kDelegate.
bool LiteKernelUtil::IsSwitchCall(kernel::LiteKernel *kernel) {
  // Reject null before the first dereference (kernel->desc() below).
  if (kernel == nullptr) {
    return false;
  }
#ifndef DELEGATE_CLIP
  if (kernel->desc().arch == kernel::kDelegate) {
    return false;
  }
#endif
  // NOTE(review): unchecked downcast - assumes every non-delegate kernel passed here really is a
  // SubGraphKernel; confirm against callers.
  auto *subgraph_kernel = reinterpret_cast<kernel::SubGraphKernel *>(kernel);
  for (auto &node : subgraph_kernel->nodes()) {
    if (node->type() == schema::PrimitiveType_Switch &&
        InputsContainsSpecificNode(node, schema::PrimitiveType_PartialFusion) && node->out_kernels().size() == 1 &&
        node->out_kernels().front()->type() == schema::PrimitiveType_Call) {
      return true;
    }
  }
  return false;
}
#endif
// Returns the first entry of kernel->in_kernels() whose type() equals `primitive_type`,
// or nullptr when no entry matches.
kernel::LiteKernel *LiteKernelUtil::GetInputsSpecificNode(const kernel::LiteKernel *kernel,
                                                          const schema::PrimitiveType &primitive_type) {
  const auto &producers = kernel->in_kernels();
  const auto match = std::find_if(producers.begin(), producers.end(),
                                  [&primitive_type](auto *candidate) { return candidate->type() == primitive_type; });
  if (match == producers.end()) {
    return nullptr;
  }
  return *match;
}
// True when GetInputsSpecificNode() finds a kernel of type `primitive_type` among the
// in_kernels() of `kernel`.
bool LiteKernelUtil::InputsContainsSpecificNode(const kernel::LiteKernel *kernel,
                                                const schema::PrimitiveType &primitive_type) {
  return GetInputsSpecificNode(kernel, primitive_type) != nullptr;
}
// Rebuilds the in_kernels / out_kernels links of every kernel in `kernels` from tensor sharing.
//
// Pass 1 records, for each tensor, the kernel that lists it in out_tensors() (a later kernel
// overwrites an earlier one) and all kernels that list it in in_tensors().
// Pass 2 resets each kernel's links and then
//   - adds as in-kernel the recorded producer of each of its input tensors, and
//   - adds as out-kernel every recorded consumer of each of its output tensors,
// never linking a kernel to itself.
void LiteKernelUtil::FindAllInoutKernels(const std::vector<kernel::LiteKernel *> &kernels) {
  std::unordered_map<lite::Tensor *, kernel::LiteKernel *> producer_of;
  std::unordered_map<lite::Tensor *, std::vector<kernel::LiteKernel *>> consumers_of;
  for (auto *node : kernels) {
    for (auto *produced : node->out_tensors()) {
      producer_of[produced] = node;
    }
    for (auto *consumed : node->in_tensors()) {
      consumers_of[consumed].push_back(node);
    }
  }
  for (auto *node : kernels) {
    node->set_in_kernels({});
    for (auto *consumed : node->in_tensors()) {
      const auto producer_it = producer_of.find(consumed);
      if (producer_it == producer_of.end() || producer_it->second == node) {
        continue;
      }
      node->AddInKernel(producer_it->second);
    }
    node->set_out_kernels({});
    for (auto *produced : node->out_tensors()) {
      const auto consumers_it = consumers_of.find(produced);
      if (consumers_it == consumers_of.end()) {
        continue;
      }
      for (auto *consumer : consumers_it->second) {
        if (consumer != node) {
          node->AddOutKernel(consumer);
        }
      }
    }
  }
}
}