* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_SUB_GRAPH_KERNEL_H_
#define MINDSPORE_LITE_SRC_SUB_GRAPH_KERNEL_H_
#include <algorithm>
#include <atomic>
#include <cstdlib>
#include <map>
#include <memory>
#include <new>
#include <string>
#include <utility>
#include <vector>
#include "src/lite_kernel.h"
#include "src/executor.h"
#include "src/common/log_adapter.h"
#include "src/common/version_manager.h"
#include "src/cpu_info.h"
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
#include "nnacl/constant_of_shape_parameter.h"
#endif
namespace mindspore::kernel {
/// Small record that bundles a raw data pointer with an allocator pointer and an ownership flag.
///
/// NOTE(review): the exact meaning of `own_data_` (who frees `data_`, and with which allocator) is not
/// visible in this header -- confirm against the code that consumes DataStore.
struct DataStore {
  void *data_ = nullptr;            // pointer handed in by the creator
  Allocator *allocator_ = nullptr;  // set from `data_allocator`; presumably the allocator that produced data_
  bool own_data_ = true;            // ownership flag supplied by the creator

  /// Allocates a DataStore and fills it from the arguments.
  ///
  /// @param data            value stored in `data_`.
  /// @param own_data        value stored in `own_data_`.
  /// @param data_allocator  value stored in `allocator_`.
  /// @param allocator       allocator used for the DataStore object itself; when null, `malloc` is used.
  ///                        This pointer is NOT recorded in the returned object, so the caller has to
  ///                        remember how the DataStore storage must be released.
  /// @return the new DataStore, or nullptr (after logging an error) when the allocation fails.
  static DataStore *CreateDataStore(void *data = nullptr, bool own_data = true, Allocator *data_allocator = nullptr,
                                    Allocator *allocator = nullptr) {
    void *storage = nullptr;
    if (allocator == nullptr) {
      storage = malloc(sizeof(DataStore));
    } else {
      storage = allocator->Malloc(sizeof(DataStore));
    }
    if (storage == nullptr) {
      MS_LOG(ERROR) << "Malloc data_store failed";
      return nullptr;
    }
    // Construct the object in the raw storage so that its lifetime formally begins before the members
    // are written; assigning through a pointer obtained by a bare cast of malloc'ed memory does not
    // create an object under C++17 rules.
    auto *data_store = new (storage) DataStore();
    data_store->data_ = data;
    data_store->own_data_ = own_data;
    data_store->allocator_ = data_allocator;
    return data_store;
  }
};
/// Kernel that groups a set of LiteKernel nodes and exposes them as a single LiteKernel.
///
/// Every kernel stored in `nodes_` is deleted by the destructor. `executor_` is declared here but is
/// not released by this class's destructor.
class SubGraphKernel : public LiteKernel {
 public:
  /// @param in_kernels   kernels stored as the input nodes of the sub-graph.
  /// @param out_kernels  kernels stored as the output nodes of the sub-graph.
  /// @param nodes        all kernels of the sub-graph; they are deleted in the destructor.
  /// @param kernel       wrapped in a std::shared_ptr and forwarded to the LiteKernel base.
  SubGraphKernel(std::vector<LiteKernel *> in_kernels, std::vector<LiteKernel *> out_kernels,
                 std::vector<LiteKernel *> nodes, Kernel *kernel)
      : LiteKernel(std::shared_ptr<Kernel>(kernel)),
        nodes_(std::move(nodes)),
        in_nodes_(std::move(in_kernels)),
        out_nodes_(std::move(out_kernels)) {
    // Defaults; derived classes overwrite these in their own constructors.
    subgraph_type_ = kCpuFP32SubGraph;
    desc_.data_type = kNumberTypeFloat32;
  }

  /// Deletes every kernel held in `nodes_` and empties the list.
  ~SubGraphKernel() override {
    for (LiteKernel *owned_node : nodes_) {
      delete owned_node;
    }
    nodes_.clear();
  }

  /// @return true when every input node reports ready for `scope_tensors` (also true when there are
  ///         no input nodes).
  bool IsReady(const std::vector<lite::Tensor *> &scope_tensors) override {
    const auto input_ready = [&scope_tensors](LiteKernel *in_node) { return in_node->IsReady(scope_tensors); };
    return std::all_of(in_nodes_.begin(), in_nodes_.end(), input_ready);
  }

  // Defined out of line.
  int Prepare() override;

  /// Runs the sub-graph without before/after callbacks.
  int Execute() override { return Execute(nullptr, nullptr); }

  // Defined out of line.
  int Execute(const KernelCallBack &before, const KernelCallBack &after) override;

  // Defined out of line.
  int ReSize() override;

  // Defined out of line.
  void InitOutTensorInitRefCount(const std::vector<LiteKernel *> *mask_kernels) override;

  // Defined out of line.
  void InitInputTensorInitRefCount();

  /// No initialisation work at this level; always reports success.
  int Init() override { return mindspore::lite::RET_OK; }

  // Defined out of line.
  std::string ToString() const override;

  /// @return mutable reference to the list of all nodes.
  std::vector<LiteKernel *> &nodes() { return nodes_; }

  // Defined out of line.
  void DropNode(LiteKernel *node);

  /// @return a copy of the input-node list.
  std::vector<LiteKernel *> in_nodes() { return in_nodes_; }

  /// @return a copy of the output-node list.
  std::vector<LiteKernel *> out_nodes() { return out_nodes_; }

  /// Records the schema version for this sub-graph.
  void SetSchemaVersion(int schema_version) { schema_version_ = schema_version; }

 protected:
  std::vector<LiteKernel *> nodes_{};      // all kernels; deleted in the destructor
  std::vector<LiteKernel *> in_nodes_{};   // input kernels
  std::vector<LiteKernel *> out_nodes_{};  // output kernels
  mindspore::lite::Executor *executor_ = nullptr;
  int schema_version_ = lite::SCHEMA_VERSION::SCHEMA_CUR;
};
/// Sub-graph tagged for the CPU architecture. Its destructor deletes `executor_`.
class CpuSubGraph : public SubGraphKernel {
 public:
  /// Forwards all arguments to SubGraphKernel, then marks the sub-graph as a CPU fp32 sub-graph.
  CpuSubGraph(std::vector<LiteKernel *> in_kernels, std::vector<LiteKernel *> out_kernels,
              std::vector<LiteKernel *> nodes, Kernel *kernel)
      : SubGraphKernel(std::move(in_kernels), std::move(out_kernels), std::move(nodes), kernel) {
    subgraph_type_ = kCpuFP32SubGraph;
    desc_.arch = kernel::KERNEL_ARCH::kCPU;
  }

  /// Releases the executor held in `executor_`.
  ~CpuSubGraph() override { delete executor_; }

  // Defined out of line.
  int Prepare() override;

  /// Delegates to the base-class initialisation.
  int Init() override { return SubGraphKernel::Init(); }

  /// Runs the sub-graph without before/after callbacks.
  int Execute() override { return Execute(nullptr, nullptr); }

  // Defined out of line.
  int Execute(const KernelCallBack &before, const KernelCallBack &after) override;
};
class CpuFp32SubGraph : public CpuSubGraph {
public:
CpuFp32SubGraph(std::vector<LiteKernel *> in_kernels, std::vector<LiteKernel *> out_kernels,
std::vector<LiteKernel *> nodes, Kernel *kernel)
: CpuSubGraph(std::move(in_kernels), std::move(out_kernels), std::move(nodes), kernel) {
subgraph_type_ = kCpuFP32SubGraph;
static std::atomic_int index = {0};
this->set_name("CpuFP32SubGraph" + std::to_string(index++));
desc_.data_type = kNumberTypeFloat32;
}
~CpuFp32SubGraph() override = default;
};
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
class CpuFp16SubGraph : public CpuSubGraph {
public:
CpuFp16SubGraph(std::vector<LiteKernel *> in_kernels, std::vector<LiteKernel *> out_kernels,
std::vector<LiteKernel *> nodes, Kernel *kernel)
: CpuSubGraph(std::move(in_kernels), std::move(out_kernels), std::move(nodes), kernel) {
subgraph_type_ = kCpuFP16SubGraph;
static std::atomic_int index = 0;
this->set_name("CpuFP16SubGraph" + std::to_string(index++));
desc_.data_type = kNumberTypeFloat16;
}
~CpuFp16SubGraph() override = default;
int Init() override {
const auto *context = this->Context();
MS_ASSERT(context != nullptr);
support_fp16_ = context->device_and_pkg_support_fp16();
return CpuSubGraph::Init();
}
int Prepare() override {
auto ret = CpuSubGraph::Prepare();
if (ret != RET_OK) {
return ret;
}
for (auto &node : this->nodes_) {
if (node->type() == schema::PrimitiveType_Cast) {
auto inputs = node->in_tensors();
MS_ASSERT(inputs.size() >= 2);
auto dst_tensor = inputs[1];
MS_ASSERT(dst_tensor != nullptr);
MS_ASSERT(dst_tensor->data_type() == kNumberTypeInt32);
MS_ASSERT(dst_tensor->data() != nullptr);
MS_ASSERT(dst_tensor->ElementsNum() == 1);
auto *dst_data = reinterpret_cast<int32_t *>(dst_tensor->data());
if (dst_data[0] == kNumberTypeFloat32) {
dst_data[0] = kNumberTypeFloat16;
}
auto outputs = node->out_tensors();
MS_ASSERT(outputs.size() == 1);
auto output = outputs.front();
MS_ASSERT(output != nullptr);
if (output->data_type() == kNumberTypeFloat32) {
output->set_data_type(kNumberTypeFloat16);
}
} else if (node->type() == schema::PrimitiveType_ConstantOfShape) {
auto param = node->op_parameter();
MS_ASSERT(param != nullptr);
if (static_cast<TypeId>(reinterpret_cast<ConstantOfShapeParameter *>(param)->data_type_ ==
kNumberTypeFloat32)) {
reinterpret_cast<ConstantOfShapeParameter *>(param)->data_type_ = kNumberTypeFloat16;
}
auto outputs = node->out_tensors();
MS_ASSERT(outputs.size() == 1);
auto output = outputs.front();
MS_ASSERT(output != nullptr);
if (output->data_type() == kNumberTypeFloat32) {
output->set_data_type(kNumberTypeFloat16);
}
}
}
return RET_OK;
}
private:
bool support_fp16_ = false;
};
#endif
class CustomSubGraph : public SubGraphKernel {
public:
CustomSubGraph(std::vector<LiteKernel *> in_kernels, std::vector<LiteKernel *> out_kernels,
std::vector<LiteKernel *> nodes, Kernel *kernel)
: SubGraphKernel(std::move(in_kernels), std::move(out_kernels), std::move(nodes), kernel) {
subgraph_type_ = kCustomSubGraph;
desc_.arch = kernel::KERNEL_ARCH::kCustom;
}
~CustomSubGraph() override { delete this->executor_; }
int Prepare() override;
int Init() override { return SubGraphKernel::Init(); }
int Execute() override { return Execute(nullptr, nullptr); }
int Execute(const KernelCallBack &before, const KernelCallBack &after) override;
};
}
#endif