* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <utility>
#include <algorithm>
#include "src/lite_mindrt.h"
#include "mindrt/include/mindrt.hpp"
#include "src/lite_kernel_util.h"
#include "src/common/tensor_util.h"
#include "src/runtime/inner_allocator.h"
#include "src/runtime/kernel/arm/base/partial_fusion.h"
#ifdef ENABLE_FP16
#include "src/runtime/kernel/arm/fp16/fp16_op_handler.h"
#endif
namespace mindspore::lite {
void LiteOpActor::RunOpData(OpData<lite::Tensor> *inputs, OpContext<lite::Tensor> *context) {
auto op_uuid = context->sequential_num_;
input_op_datas_[op_uuid].push_back(inputs);
inputs_data_[inputs->index_] = inputs->data_;
if (input_op_datas_[op_uuid].size() < kernel_->in_tensors().size()) {
return;
}
auto ret = InitInputData();
if (ret != RET_OK) {
input_op_datas_.erase(op_uuid);
context->SetFailed(ret);
return;
}
ret = RunKernel(*(reinterpret_cast<const KernelCallBack *>(context->kernel_call_back_before_)),
*(reinterpret_cast<const KernelCallBack *>(context->kernel_call_back_after_)));
if (ret != RET_OK) {
input_op_datas_.erase(op_uuid);
context->SetFailed(ret);
return;
}
input_op_datas_.erase(op_uuid);
AsyncOutput(context);
SetOutputData(context);
return;
}
bool OfflineIsolated(const std::vector<kernel::LiteKernel *> &kernels, const kernel::LiteKernel &this_kernel,
const lite::Tensor &this_input_tensor) {
if (this_input_tensor.IsGraphInput()) {
return false;
}
for (auto &kernel : kernels) {
if (kernel == &this_kernel) {
continue;
}
if (std::any_of(kernel->out_tensors().begin(), kernel->out_tensors().end(),
[&this_input_tensor](lite::Tensor *tensor) { return tensor == &this_input_tensor; })) {
return false;
}
}
return true;
}
void LiteOpActor::ReplaceNodeInTensor(kernel::LiteKernel *kernel, Tensor *old_tensor, Tensor *new_tensor) {
int ref_count = 0;
#ifndef DELEGATE_CLIP
if (kernel->desc().arch == kernel::kDelegate) {
ref_count++;
} else {
#endif
for (auto in_node : reinterpret_cast<kernel::SubGraphKernel *>(kernel)->in_nodes()) {
for (size_t node_in_index = 0; node_in_index < in_node->in_tensors().size(); node_in_index++) {
if (old_tensor == in_node->in_tensors()[node_in_index]) {
in_node->set_in_tensor(new_tensor, node_in_index);
ref_count++;
}
}
}
#ifndef DELEGATE_CLIP
}
#endif
new_tensor->set_init_ref_count(ref_count);
}
int LiteOpActor::IsolateInputData(std::vector<std::shared_ptr<LiteOpActor>> *actors) {
std::vector<kernel::LiteKernel *> kernels{};
std::transform(actors->begin(), actors->end(), std::back_inserter(kernels),
[](std::shared_ptr<LiteOpActor> actor) { return actor->kernel_; });
size_t in_tensor_size = kernel_->in_tensors().size();
for (size_t i = 0; i < in_tensor_size; i++) {
Tensor *old_tensor = kernel_->in_tensors()[i];
if (OfflineIsolated(kernels, *kernel_, *old_tensor)) {
if (old_tensor->data_type() == kNumberTypeFloat16 || old_tensor->data_type() == kNumberTypeFloat32) {
old_tensor->set_data_type(kernel_->desc().data_type);
}
#ifndef CONTROLFLOW_TENSORLIST_CLIP
if (old_tensor->data_type() == kObjectTypeTensorType) {
auto old_tensorlist = reinterpret_cast<TensorList *>(old_tensor);
if (old_tensorlist->tensors_data_type() == kNumberTypeFloat16 ||
old_tensorlist->tensors_data_type() == kNumberTypeFloat32) {
old_tensorlist->set_tensors_data_type(kernel_->desc().data_type);
}
}
#endif
old_tensor->set_allocator(kernel_->Context()->allocator);
continue;
}
TypeId new_data_type = old_tensor->data_type();
if (old_tensor->data_type() == kNumberTypeFloat16 || old_tensor->data_type() == kNumberTypeFloat32) {
new_data_type = kernel_->desc().data_type;
}
Tensor *new_tensor = new Tensor(new_data_type, old_tensor->shape(), old_tensor->format(), old_tensor->category());
if (new_tensor == nullptr) {
MS_LOG(ERROR) << "new Tensor failed.";
return RET_NULL_PTR;
}
new_tensor->set_allocator(old_tensor->allocator());
if (new_tensor->allocator() == nullptr && kernel_->Context() != nullptr &&
kernel_->desc().arch != kernel::kDelegate) {
new_tensor->set_allocator(kernel_->Context()->allocator);
}
new_tensor->set_tensor_name(kernel_->name() + "_duplicate_" + old_tensor->tensor_name());
for (LiteQuantParam quant : old_tensor->quant_params()) {
new_tensor->AddQuantParam(quant);
}
isolate_input_map_.insert(std::make_pair(new_tensor, old_tensor));
ReplaceNodeInTensor(kernel_, old_tensor, new_tensor);
kernel_->set_in_tensor(new_tensor, i);
}
return RET_OK;
}
int LiteOpActor::LiteActorInit(std::vector<std::shared_ptr<LiteOpActor>> *actors) {
auto ret = CompileArrow();
if (ret != RET_OK) {
MS_LOG(ERROR) << "compile arrow failed.";
return ret;
}
ret = PrepareOutputData();
if (ret != RET_OK) {
MS_LOG(ERROR) << "prepare output data failed.";
return ret;
}
ret = IsolateInputData(actors);
if (ret != RET_OK) {
MS_LOG(ERROR) << "isolate input data failed.";
return ret;
}
return RET_OK;
}
int LiteOpActor::ResizeGraphInput(const std::vector<mindspore::tensor::MSTensor *> &inputs,
const std::vector<std::vector<int>> &dims) {
for (auto map : isolate_input_map_) {
auto isolate_tensor = map.first;
auto src_tensor = map.second;
for (size_t i = 0; i < inputs.size(); i++) {
if (src_tensor == inputs[i]) {
isolate_tensor->set_shape(dims[i]);
}
}
}
return RET_OK;
}
int LiteOpActor::CompileArrowThroughOutputKernels() {
output_data_arrows_.clear();
int out_tensor_size = static_cast<int>(kernel_->out_tensors().size());
for (int i = 0; i < out_tensor_size; i++) {
for (auto out : kernel_->out_kernels()) {
int in_tensor_size = static_cast<int>(out->in_tensors().size());
int to_input_index = -1;
for (int j = 0; j < in_tensor_size; j++) {
if (kernel_->out_tensors()[i] == out->in_tensors()[j]) {
to_input_index = j;
break;
}
}
if (to_input_index == -1) {
continue;
}
auto id = out->name() + this->GetAID().Url();
auto arrow = std::make_shared<DataArrow>(i, AID(id), to_input_index);
if (arrow == nullptr) {
MS_LOG(ERROR) << "create DataArrow failed, out kernel: " << out->name();
return RET_ERROR;
}
output_data_arrows_.emplace_back(std::move(arrow));
}
}
return RET_OK;
}
#ifndef CONTROLFLOW_TENSORLIST_CLIP
int LiteOpActor::CompileArrowThroughPartialCall() {
#ifndef DELEGATE_CLIP
if (kernel_->desc().arch == kernel::kDelegate) {
MS_LOG(INFO) << "kernel is delegate subgraph kernel.";
return RET_OK;
}
#endif
auto *subgraph_kernel = reinterpret_cast<kernel::SubGraphKernel *>(kernel_);
if (subgraph_kernel == nullptr) {
MS_LOG(INFO) << "kernel is not subgraph kernel, no partial call.";
return RET_OK;
}
for (auto &node : subgraph_kernel->nodes()) {
if (node->type() != schema::PrimitiveType_Call) {
continue;
}
call_node_ = node;
auto partial_node = kernel::LiteKernelUtil::GetInputsSpecificNode(node, schema::PrimitiveType_PartialFusion);
if (!partial_node) {
continue;
}
partial_node_ = partial_node;
auto subgraph = reinterpret_cast<kernel::PartialFusionKernel *>(partial_node->kernel())->subgraph_kernel();
auto out_actor_id = subgraph_to_actor_.at(subgraph);
kernel_->set_out_tensors(partial_node->in_tensors());
for (size_t i = 0; i < partial_node->in_tensors().size(); ++i) {
auto arrow = std::make_shared<DataArrow>(i, out_actor_id, i);
if (arrow == nullptr) {
MS_LOG(ERROR) << "create DataArrow failed";
return RET_ERROR;
}
output_data_arrows_.emplace_back(std::move(arrow));
}
}
subgraph_kernel->DropNode(partial_node_);
subgraph_kernel->DropNode(call_node_);
return RET_OK;
}
#endif
int LiteOpActor::CompileArrow() {
int ret;
output_data_arrows_.clear();
#ifndef CONTROLFLOW_TENSORLIST_CLIP
ret = CompileArrowThroughPartialCall();
if (ret != RET_OK) {
output_data_arrows_.clear();
MS_LOG(ERROR) << "CompileArrowThroughPartialCall failed.";
return ret;
}
if (!output_data_arrows_.empty()) {
MS_LOG(INFO) << "CompileArrowThroughPartialCall done.";
return RET_OK;
}
#endif
ret = CompileArrowThroughOutputKernels();
if (ret != RET_OK) {
output_data_arrows_.clear();
MS_LOG(ERROR) << "CompileArrowThroughOutputKernels failed.";
return ret;
}
return ret;
}
void LiteOpActor::MoveTensorInputData(Tensor *dst_tensor, Tensor *src_tensor) {
MS_ASSERT(src_tensor != dst_tensor);
dst_tensor->FreeData();
dst_tensor->ResetRefCount();
dst_tensor->set_allocator(src_tensor->allocator());
src_tensor->allocator()->IncRefCount(src_tensor->data(), dst_tensor->ref_count());
if (src_tensor->data() != nullptr) {
dst_tensor->set_data(src_tensor->MutableData());
}
dst_tensor->set_own_data(src_tensor->own_data());
src_tensor->DecRefCount();
}
void LiteOpActor::MoveInputData(Tensor *dst_tensor, Tensor *src_tensor) {
if (src_tensor == dst_tensor) {
MS_LOG(INFO) << "no need to move.";
return;
}
MS_ASSERT(src_tensor->allocator() != nullptr);
#ifndef CONTROLFLOW_TENSORLIST_CLIP
if (src_tensor->data_type() == kObjectTypeTensorType) {
MoveTensorListInputData(reinterpret_cast<TensorList *>(dst_tensor), reinterpret_cast<TensorList *>(src_tensor));
} else {
MoveTensorInputData(dst_tensor, src_tensor);
}
#else
MoveTensorInputData(dst_tensor, src_tensor);
#endif
return;
}
void LiteOpActor::SetInputData(Tensor *dst_tensor, Tensor *src_tensor) {
dst_tensor->set_data(src_tensor->data());
dst_tensor->set_own_data(false);
}
int LiteOpActor::CastInputData(Tensor *dst, Tensor *src) {
int ret = RET_OK;
#ifndef CONTROLFLOW_TENSORLIST_CLIP
if (src->data_type() != kObjectTypeTensorType) {
ret = CastTensorInputData(dst, src);
} else {
ret = CastTensorListInputData(reinterpret_cast<TensorList *>(dst), reinterpret_cast<TensorList *>(src));
}
#else
ret = CastTensorInputData(dst, src);
#endif
src->DecRefCount();
return ret;
}
bool LiteOpActor::NeedCastData(Tensor *dst_tensor, Tensor *src_tensor) {
if (dst_tensor->data_type() != kObjectTypeTensorType && src_tensor->data_type() != kObjectTypeTensorType &&
dst_tensor->data_type() != src_tensor->data_type()) {
return true;
}
#ifndef CONTROLFLOW_TENSORLIST_CLIP
if (dst_tensor->data_type() == kObjectTypeTensorType && src_tensor->data_type() == kObjectTypeTensorType &&
reinterpret_cast<TensorList *>(dst_tensor)->tensors_data_type() !=
reinterpret_cast<TensorList *>(src_tensor)->tensors_data_type()) {
return true;
}
#endif
return false;
}
int LiteOpActor::CastTensorInputData(Tensor *dst, Tensor *src) {
dst->MallocData();
dst->ResetRefCount();
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
if (dst->shape() != src->shape()) {
MS_LOG(ERROR) << "dst tensor: " << dst->tensor_name() << " shape: " << dst->shape() << " vs "
<< "src tensor: " << src->tensor_name() << " shape: " << src->shape();
return RET_PARAM_INVALID;
}
auto dst_data = dst->MutableData();
auto src_data = src->MutableData();
auto src_nums_size = src->ElementsNum();
auto dst_data_type = static_cast<int>(dst->data_type());
auto src_data_type = static_cast<int>(src->data_type());
if (dst_data_type == kNumberTypeFloat32 && src_data_type == kNumberTypeFloat16) {
Float16ToFloat32_fp16_handler(src_data, dst_data, src_nums_size, support_fp16_);
} else if (dst_data_type == kNumberTypeFloat16 && src_data_type == kNumberTypeFloat32) {
Float32ToFloat16_fp16_handler(src_data, dst_data, src_nums_size, support_fp16_);
} else {
MS_LOG(ERROR) << "not support dst_data_type: " << dst_data_type << " src_data_type: " << src_data_type;
return RET_NOT_SUPPORT;
}
return RET_OK;
#endif
return RET_ERROR;
}
#ifndef CONTROLFLOW_TENSORLIST_CLIP
void LiteOpActor::MoveTensorListInputData(TensorList *dst_tensorlist, TensorList *src_tensorlist) {
MS_ASSERT(src_tensorlist != nullptr);
MS_ASSERT(dst_tensorlist != nullptr);
dst_tensorlist->FreeData();
dst_tensorlist->ResetRefCount();
dst_tensorlist->set_allocator(src_tensorlist->allocator());
auto src_tensorlist_tensors_size = src_tensorlist->tensors().size();
auto dst_tensorlist_tensors_size = dst_tensorlist->tensors().size();
if (src_tensorlist_tensors_size != dst_tensorlist_tensors_size) {
MS_LOG(ERROR) << "src tensorlist: " << src_tensorlist->tensor_name()
<< " tesnors size: " << src_tensorlist_tensors_size
<< " vs dst tensorlist: " << src_tensorlist->tensor_name()
<< " tensors size: " << dst_tensorlist_tensors_size;
return;
}
dst_tensorlist->set_own_data(src_tensorlist->own_data());
for (size_t i = 0; i < src_tensorlist_tensors_size; ++i) {
auto &src_tensor = src_tensorlist->tensors()[i];
auto &dst_tensor = dst_tensorlist->tensors()[i];
if (src_tensor->allocator() != nullptr) {
src_tensor->allocator()->IncRefCount(src_tensor->data(), dst_tensor->ref_count());
}
dst_tensor->set_own_data(src_tensor->own_data());
if (src_tensor->data() != nullptr) {
dst_tensor->set_data(src_tensor->MutableData());
}
dst_tensor->set_shape(src_tensor->shape());
}
if (src_tensorlist->IsConst() || src_tensorlist->IsGraphInput()) {
dst_tensorlist->set_own_data(false);
} else {
src_tensorlist->DecRefCount();
}
}
int LiteOpActor::CastTensorListInputData(TensorList *dst_tensorlist, TensorList *src_tensorlist) {
MS_ASSERT(src_tensorlist != nullptr);
MS_ASSERT(dst_tensorlist != nullptr);
dst_tensorlist->set_shape(src_tensorlist->shape());
std::vector<std::vector<int>> tensors_shapes{};
tensors_shapes.resize(src_tensorlist->tensors().size());
for (size_t i = 0; i < tensors_shapes.size(); ++i) {
tensors_shapes[i] = src_tensorlist->tensors()[i]->shape();
}
if (src_tensorlist->tensors_data_type() == kNumberTypeFloat16) {
dst_tensorlist->MallocTensorListData(kNumberTypeFloat32, tensors_shapes);
}
if (src_tensorlist->tensors_data_type() == kNumberTypeFloat32) {
dst_tensorlist->MallocTensorListData(kNumberTypeFloat16, tensors_shapes);
}
dst_tensorlist->set_allocator(src_tensorlist->allocator());
dst_tensorlist->ResetRefCount();
for (size_t i = 0; i < src_tensorlist->tensors().size(); ++i) {
auto &src_tensor = src_tensorlist->tensors()[i];
auto &dst_tensor = dst_tensorlist->tensors()[i];
CastTensorInputData(dst_tensor, src_tensor);
}
return RET_OK;
}
int LiteSwitchOpActor::CompileTrueBranchArrow() {
if (true_partial_node_ == nullptr) {
MS_LOG(ERROR) << "true_partial_node_ is nullptr.";
return RET_NULL_PTR;
}
auto subgraph = static_cast<kernel::PartialFusionKernel *>(true_partial_node_->kernel())->subgraph_kernel();
auto true_branch_actor_id = subgraph_to_actor_.at(subgraph);
for (size_t i = 0; i < true_partial_node_->in_tensors().size(); ++i) {
int out_tensor_size = static_cast<int>(kernel_->out_tensors().size());
for (int j = 0; j < out_tensor_size; ++j) {
if (true_partial_node_->in_tensors()[i] != kernel_->out_tensors()[j]) {
continue;
}
auto arrow = std::make_shared<DataArrow>(j, true_branch_actor_id, i);
if (arrow == nullptr) {
MS_LOG(ERROR) << "create DataArrow failed";
return RET_ERROR;
}
true_branch_output_data_arrows_.emplace_back(std::move(arrow));
}
}
return RET_OK;
}
int LiteSwitchOpActor::CompileFalseBranchArrow() {
if (false_partial_node_ == nullptr) {
MS_LOG(ERROR) << "false_partial_node_ is nullptr.";
return RET_NULL_PTR;
}
auto subgraph = static_cast<kernel::PartialFusionKernel *>(false_partial_node_->kernel())->subgraph_kernel();
auto false_branch_actor_id = subgraph_to_actor_.at(subgraph);
for (size_t i = 0; i < false_partial_node_->in_tensors().size(); ++i) {
int out_tensor_size = static_cast<int>(kernel_->out_tensors().size());
for (int j = 0; j < out_tensor_size; ++j) {
if (false_partial_node_->in_tensors()[i] != kernel_->out_tensors()[j]) {
continue;
}
auto arrow = std::make_shared<DataArrow>(j, false_branch_actor_id, i);
if (arrow == nullptr) {
MS_LOG(ERROR) << "create DataArrow failed";
return RET_ERROR;
}
false_branch_output_data_arrows_.emplace_back(std::move(arrow));
}
}
return RET_OK;
}
int LiteSwitchOpActor::GetSwitchAndCallNode(kernel::SubGraphKernel *subgraph_kernel) {
for (auto &node : subgraph_kernel->nodes()) {
if (node->type() != schema::PrimitiveType_Call) {
continue;
}
call_node_ = node;
auto switch_node = kernel::LiteKernelUtil::GetInputsSpecificNode(node, schema::PrimitiveType_Switch);
if (!switch_node) {
continue;
}
if (switch_node->in_tensors().size() < kSwitchMinInputTensorSize) {
MS_LOG(ERROR) << "actor name: " << this->GetAID() << "'s switch node " << switch_node->name()
<< " input tensor size: " << switch_node->in_tensors().size() << " is less than 3.";
return RET_ERROR;
}
switch_node_ = switch_node;
if (switch_node->in_kernels().size() == kSwitchMaxInputKernelSize) {
true_partial_node_ = switch_node->in_kernels().at(kSwitchTruePartialInputIndex);
false_partial_node_ = switch_node->in_kernels().at(kSwitchFalsePartialInputIndex);
}
if (switch_node->in_kernels().size() == kSwitchMinInputKernelSize) {
true_partial_node_ = switch_node->in_kernels().at(kSwitchTruePartialInputIndex - 1);
false_partial_node_ = switch_node->in_kernels().at(kSwitchFalsePartialInputIndex - 1);
}
break;
}
return RET_OK;
}
void LiteSwitchOpActor::AppendOutputTensors() {
for (auto &tensor : true_partial_node_->in_tensors()) {
if (std::find(output_tensors_.begin(), output_tensors_.end(), tensor) == output_tensors_.end()) {
output_tensors_.push_back(tensor);
}
}
for (auto &tensor : false_partial_node_->in_tensors()) {
if (std::find(output_tensors_.begin(), output_tensors_.end(), tensor) == output_tensors_.end()) {
output_tensors_.push_back(tensor);
}
}
kernel_->set_out_tensors(output_tensors_);
}
int LiteSwitchOpActor::CompileArrowThroughSwitchCall() {
auto *subgraph_kernel = reinterpret_cast<kernel::SubGraphKernel *>(kernel_);
if (subgraph_kernel == nullptr) {
MS_LOG(INFO) << "kernel is not subgraph kernel, no partial call.";
return RET_OK;
}
int ret = GetSwitchAndCallNode(subgraph_kernel);
if (ret != RET_OK) {
MS_LOG(ERROR) << "GetSwitchAndCallCnode failed.";
return ret;
}
AppendOutputTensors();
ret = CompileTrueBranchArrow();
if (ret != RET_OK) {
MS_LOG(ERROR) << "CompileTrueBranchArrow failed.";
true_branch_output_data_arrows_.clear();
return ret;
}
ret = CompileFalseBranchArrow();
if (ret != RET_OK) {
MS_LOG(ERROR) << "CompileFalseBranchArrow failed.";
false_branch_output_data_arrows_.clear();
true_branch_output_data_arrows_.clear();
return ret;
}
subgraph_kernel->DropNode(call_node_);
subgraph_kernel->DropNode(switch_node_);
subgraph_kernel->DropNode(true_partial_node_);
subgraph_kernel->DropNode(false_partial_node_);
return ret;
}
int LiteSwitchOpActor::CompileArrow() {
int ret = CompileArrowThroughSwitchCall();
if (ret != RET_OK) {
true_branch_output_data_arrows_.clear();
false_branch_output_data_arrows_.clear();
MS_LOG(ERROR) << "CompileArrowThroughSwitchCall failed.";
return ret;
}
if (!true_branch_output_data_arrows_.empty() && !false_branch_output_data_arrows_.empty()) {
MS_LOG(INFO) << "CompileArrowThroughSwitchCall done.";
return RET_OK;
}
ret = CompileArrowThroughOutputKernels();
if (ret != RET_OK) {
output_data_arrows_.clear();
MS_LOG(ERROR) << "CompileArrowThroughOutputKernels failed.";
return ret;
}
return ret;
}
int LiteSwitchOpActor::PrepareOutputData() {
true_branch_outputs_data_.resize(true_branch_output_data_arrows_.size());
for (size_t i = 0; i < true_branch_output_data_arrows_.size(); i++) {
auto &arrow = true_branch_output_data_arrows_[i];
auto data =
std::make_shared<OpData<Tensor>>(arrow->to_op_id_, (kernel_->out_tensors()).at(arrow->from_output_index_),
static_cast<int>(arrow->to_input_index_));
if (data == nullptr) {
MS_LOG(ERROR) << "new true_branch_output_data failed.";
return RET_NULL_PTR;
}
true_branch_outputs_data_.at(i) = data;
}
false_branch_outputs_data_.resize(false_branch_output_data_arrows_.size());
for (size_t i = 0; i < false_branch_output_data_arrows_.size(); i++) {
auto &arrow = false_branch_output_data_arrows_[i];
auto data =
std::make_shared<OpData<Tensor>>(arrow->to_op_id_, (kernel_->out_tensors()).at(arrow->from_output_index_),
static_cast<int>(arrow->to_input_index_));
if (data == nullptr) {
MS_LOG(ERROR) << "new false_branch_output_data failed.";
return RET_NULL_PTR;
}
false_branch_outputs_data_.at(i) = data;
}
return RET_OK;
}
void LiteSwitchOpActor::DecreaseTrueBranchInputTensor() {
switch_node_->in_tensors()[kSwitchCondTensorIndex]->DecRefCount();
for (auto input : true_partial_node_->in_tensors()) {
input->DecRefCount();
}
}
void LiteSwitchOpActor::DecreaseFalseBranchInputTensor() {
switch_node_->in_tensors()[kSwitchCondTensorIndex]->DecRefCount();
for (auto input : false_partial_node_->in_tensors()) {
input->DecRefCount();
}
}
void LiteSwitchOpActor::AsyncTrueBranchOutput(OpContext<Tensor> *context) {
MS_ASSERT(true_branch_output_data_arrows_.size() == true_branch_outputs_data_.size());
for (size_t i = 0; i < true_branch_output_data_arrows_.size(); ++i) {
auto &data = true_branch_outputs_data_.at(i);
Async(true_branch_output_data_arrows_[i]->to_op_id_, &mindspore::OpActor<Tensor>::RunOpData, data.get(), context);
}
}
void LiteSwitchOpActor::AsyncFalseBranchOutput(OpContext<Tensor> *context) {
MS_ASSERT(false_branch_output_data_arrows_.size() == false_branch_outputs_data_.size());
for (size_t i = 0; i < false_branch_output_data_arrows_.size(); ++i) {
auto &data = false_branch_outputs_data_.at(i);
Async(false_branch_output_data_arrows_[i]->to_op_id_, &mindspore::OpActor<Tensor>::RunOpData, data.get(), context);
}
}
void LiteSwitchOpActor::RunOpData(OpData<Tensor> *inputs, OpContext<Tensor> *context) {
auto op_uuid = context->sequential_num_;
input_op_datas_[op_uuid].push_back(inputs);
inputs_data_[inputs->index_] = inputs->data_;
if (input_op_datas_[op_uuid].size() < kernel_->in_tensors().size()) {
return;
}
int ret = InitInputData();
if (ret != RET_OK) {
input_op_datas_.erase(op_uuid);
context->SetFailed(ret);
return;
}
ret = RunKernel(*(reinterpret_cast<const KernelCallBack *>(context->kernel_call_back_before_)),
*(reinterpret_cast<const KernelCallBack *>(context->kernel_call_back_after_)));
if (ret != RET_OK) {
input_op_datas_.erase(op_uuid);
context->SetFailed(ret);
return;
}
input_op_datas_.erase(op_uuid);
auto cond_ptr = reinterpret_cast<bool *>(switch_node_->in_tensors()[kSwitchCondTensorIndex]->data());
if (cond_ptr == nullptr) {
MS_LOG(ERROR) << "switch cond input data is nullptr.";
context->SetFailed(RET_NULL_PTR);
return;
}
if (*cond_ptr) {
DecreaseFalseBranchInputTensor();
AsyncTrueBranchOutput(context);
} else {
DecreaseTrueBranchInputTensor();
AsyncFalseBranchOutput(context);
}
}
#endif
void LiteOpActor::SetInputShape() {
for (size_t i = 0; i < inputs_data_.size(); ++i) {
auto &input_tensor = kernel_->in_tensors()[i];
if (input_tensor->shape() == inputs_data_[i]->shape()) {
continue;
}
MS_LOG(DEBUG) << "inputs_data_[" << i << "].shape: " << inputs_data_[i]->shape() << " vs kernel_->in_tensors()["
<< i << "].shape: " << kernel_->in_tensors()[i]->shape() << " are not equal.";
MS_LOG(DEBUG) << "this->kernel_->name(): " << this->kernel_->name();
if (input_tensor->data_type() == kObjectTypeTensorType) {
#ifndef CONTROLFLOW_TENSORLIST_CLIP
auto input_tensorlist = reinterpret_cast<TensorList *>(input_tensor);
auto input_data_tensorlist = reinterpret_cast<TensorList *>(inputs_data_[i]);
input_tensorlist->FreeTensorListData();
input_tensorlist->set_element_shape(input_data_tensorlist->element_shape());
input_tensorlist->set_shape(input_data_tensorlist->shape());
std::vector<std::vector<int>> tensor_shape{};
std::transform(input_data_tensorlist->tensors().begin(), input_data_tensorlist->tensors().end(),
std::back_inserter(tensor_shape), [](Tensor *tensor_item) { return tensor_item->shape(); });
input_tensorlist->MallocTensorListData(input_data_tensorlist->tensors_data_type(), tensor_shape);
#endif
} else {
input_tensor->set_shape(inputs_data_[i]->shape());
input_tensor->set_format(inputs_data_[i]->format());
}
}
}
int LiteOpActor::InitInputData() {
SetInputShape();
for (size_t i = 0; i < inputs_data_.size(); ++i) {
auto dst_tensor = kernel_->in_tensors()[i];
auto src_tensor = inputs_data_[i];
if (dst_tensor->init_ref_count() == 0) {
src_tensor->DecRefCount();
continue;
}
if (NeedCastData(dst_tensor, src_tensor)) {
CastInputData(dst_tensor, src_tensor);
continue;
}
if (src_tensor->allocator() == nullptr || src_tensor->IsGraphInput()) {
SetInputData(dst_tensor, src_tensor);
} else {
MoveInputData(dst_tensor, src_tensor);
}
}
return RET_OK;
}
void LiteOpActor::AsyncOutput(OpContext<Tensor> *context) {
for (size_t i = 0; i < output_data_arrows_.size(); i++) {
auto data = outputs_data_.at(i);
Async(output_data_arrows_[i]->to_op_id_, &mindspore::OpActor<Tensor>::RunOpData, data.get(), context);
}
}
void LiteOpActor::AddResultIndex(size_t index) { results_index_.push_back(index); }
void LiteOpActor::SetOutputData(OpContext<Tensor> *context) {
for (auto index : results_index_) {
context->SetResult(index, RET_OK);
}
}
int LiteOpActor::PrepareOutputData() {
outputs_data_.resize(output_data_arrows_.size());
for (size_t i = 0; i < output_data_arrows_.size(); i++) {
auto &arrow = output_data_arrows_[i];
auto data =
std::make_shared<OpData<Tensor>>(arrow->to_op_id_, (kernel_->out_tensors()).at(arrow->from_output_index_),
static_cast<int>(arrow->to_input_index_));
if (data == nullptr) {
MS_LOG(ERROR) << "new output_data failed.";
return RET_NULL_PTR;
}
outputs_data_.at(i) = data;
}
return RET_OK;
}
std::vector<std::shared_ptr<LiteOpActor>> CreateOpActor(const std::vector<kernel::LiteKernel *> &kernels,
const lite::InnerContext *ctx) {
std::vector<std::shared_ptr<LiteOpActor>> actors;
std::unordered_map<kernel::LiteKernel *, AID> subgraph_name_AID_map{};
ActorThreadPool *thread_pool = reinterpret_cast<ActorThreadPool *>(ctx->thread_pool());
if (thread_pool == nullptr) {
MS_LOG(ERROR) << "thread pool is nullptr";
return actors;
}
for (auto &kernel : kernels) {
kernel->set_name(kernel->name() + "_" + to_string(actor_count++));
#ifndef CONTROLFLOW_TENSORLIST_CLIP
if ((kernel::LiteKernelUtil::IsSwitchCall(kernel))) {
auto switch_actor = std::make_shared<LiteSwitchOpActor>(kernel);
if (switch_actor == nullptr) {
MS_LOG(ERROR) << "create LiteSwitchOpActor failed: " << kernel->name();
actors.clear();
return actors;
}
switch_actor->set_thread_pool(thread_pool);
subgraph_name_AID_map[kernel] = switch_actor->GetAID();
actors.push_back(switch_actor);
} else {
#endif
auto actor = std::make_shared<LiteOpActor>(kernel);
if (actor == nullptr) {
MS_LOG(ERROR) << "create LiteOpActor failed: " << kernel->name();
actors.clear();
return actors;
}
actor->set_thread_pool(thread_pool);
subgraph_name_AID_map[kernel] = actor->GetAID();
actors.push_back(actor);
#ifndef CONTROLFLOW_TENSORLIST_CLIP
}
#endif
}
for (auto &actor : actors) {
actor->SetSubgraphAIDMap(subgraph_name_AID_map);
auto aid = mindspore::Spawn(actor);
}
return actors;
}
int MindrtInit() { return mindspore::Initialize("", "", "", ""); }
void MindrtTerminate(const std::vector<std::shared_ptr<LiteOpActor>> &actor_list) {
for (const auto &actor : actor_list) {
mindspore::Terminate(actor->GetAID());
}
}
}