* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/train/train_session.h"
#include <sys/stat.h>
#include <algorithm>
#include <utility>
#include <vector>
#include <iostream>
#include <fstream>
#include <memory>
#include <queue>
#include <map>
#include "include/errorcode.h"
#include "src/executor.h"
#include "src/lite_model.h"
#include "src/lite_kernel_util.h"
#include "src/sub_graph_kernel.h"
#include "src/tensor.h"
#include "src/kernel_registry.h"
#include "src/common/prim_util.h"
#include "src/common/tensor_util.h"
#include "src/common/utils.h"
#include "src/runtime/kernel/arm/fp32_grad/convolution.h"
#include "src/runtime/kernel/arm/fp32/batchnorm_fp32.h"
#include "src/train/loss_kernel.h"
#include "src/train/optimizer_kernel.h"
#include "src/train/train_utils.h"
#include "src/train/train_export.h"
#include "src/train/opt_allocator.h"
#include "src/train/static_allocator.h"
#include "src/train/train_populate_parameter.h"
#include "src/train/train_populate_parameter_v0.h"
namespace mindspore {
namespace lite {
const char *kGradName = "Gradients";
const char *kOptimizerName = "optimizer";
TrainSession::TrainSession() {
is_train_session_ = true;
InitCallBack();
}
int TrainSession::Init(InnerContext *context, const TrainCfg *train_cfg) {
if (train_cfg != nullptr) {
if (train_cfg->mix_precision_cfg_.loss_scale_ <= 0) {
MS_LOG(ERROR) << "illegal loss scale configuration";
return RET_NULL_PTR;
}
cfg_ = *train_cfg;
}
if (context == nullptr) {
MS_LOG(ERROR) << "context cannot be nullptr";
return RET_NULL_PTR;
}
allocator_ = context->allocator;
return lite::LiteSession::Init(context);
}
std::vector<CreatorOp> TrainSession::ReplaceOps() {
const std::vector<CreatorOp> replace = {
};
mindspore::lite::KernelRegistry *reg = mindspore::lite::KernelRegistry::GetInstance();
std::vector<CreatorOp> results;
for (auto v : replace) {
const CreatorOp cl = make_tuple(std::get<0>(v), reg->GetCreator(std::get<0>(v)));
results.push_back(cl);
reg->RegKernel(std::get<0>(v), std::get<1>(v));
}
return results;
}
void TrainSession::RestoreOps(const std::vector<CreatorOp> &restore) {
mindspore::lite::KernelRegistry *reg = mindspore::lite::KernelRegistry::GetInstance();
for (auto v : restore) {
reg->RegKernel(std::get<0>(v), std::get<1>(v));
}
}
int TrainSession::AllocWorkSpace() {
size_t workspace_size = 0;
for (auto kernel : this->train_kernels_) {
if (workspace_size < static_cast<kernel::InnerKernel *>(kernel->kernel())->workspace_size()) {
workspace_size = static_cast<kernel::InnerKernel *>(kernel->kernel())->workspace_size();
}
}
workspace_ = malloc(workspace_size);
if (workspace_ == nullptr) {
MS_LOG(ERROR) << "cannot allocate " << workspace_size << " for workspace";
return RET_ERROR;
}
for (auto kernel : this->train_kernels_) {
static_cast<kernel::InnerKernel *>(kernel->kernel())->set_workspace(workspace_);
}
return RET_OK;
}
void TrainSession::FreeWorkSpace() {
if (workspace_ != nullptr) {
free(workspace_);
workspace_ = nullptr;
}
for (auto kernel : this->train_kernels_) {
static_cast<kernel::InnerKernel *>(kernel->kernel())->FreeWorkspace();
}
}
int TrainSession::InitCallBack() {
sched_mix_precision_callback_ = [&](const Model::Node *node) {
if (!context_->IsCpuFloat16Enabled()) {
return false;
}
if (cfg_.mix_precision_cfg_.is_raw_mix_precision_) {
auto out_tensor_indexs = node->output_indices_;
if (out_tensor_indexs.empty()) {
MS_LOG(DEBUG) << "Debug: " << node->name_ << " fp32";
return false;
}
auto is_fp16 = model_->all_tensors_.at(out_tensor_indexs[0])->dataType() == kNumberTypeFloat16;
MS_LOG(DEBUG) << "Debug: " << node->name_ << ((is_fp16) ? " fp16" : " fp32");
return is_fp16;
}
auto node_type = GetPrimitiveType(node->primitive_, SCHEMA_VERSION::SCHEMA_CUR);
if (node_type == schema::PrimitiveType_Cast) {
return false;
}
auto in_size = node->input_indices_.size();
bool force_fp16 = false;
for (std::size_t k = 0; k < in_size; k++) {
schema::Tensor *tensor = model_->all_tensors_.at(node->input_indices_[k]);
if ((tensor->dataType() == kNumberTypeFloat16) && (tensor->nodeType() == NodeType_ValueNode)) {
force_fp16 = true;
break;
}
}
const auto &node_name = node->name_;
bool is_fp16 = true;
if (!force_fp16) {
if (node_name.find(kOptimizerName) != std::string::npos) {
is_fp16 = false;
}
if ((node_name.find(get_loss_name()) != std::string::npos)) {
is_fp16 = false;
}
if ((cfg_.mix_precision_cfg_.keep_batchnorm_fp32_) &&
(node_type == schema::PrimitiveType_FusedBatchNorm || node_type == schema::PrimitiveType_BatchNorm ||
node_type == schema::PrimitiveType_BatchNormGrad)) {
is_fp16 = false;
}
}
MS_LOG(DEBUG) << "Debug: " << node_name << ((is_fp16) ? " fp16" : " fp32");
return is_fp16;
};
return RET_OK;
}
int TrainSession::AllocTensors(const std::vector<kernel::LiteKernel *> &kernels) {
if (!IS_STATIC_ALLOCATOR(allocator_)) return RET_OK;
OptAllocator allocator;
std::unordered_map<lite::Tensor *, int> ref_count;
std::unordered_map<lite::Tensor *, size_t> offset_map;
for (auto kernel : kernels) {
for (auto tensor : kernel->out_tensors()) {
size_t size = tensor->Size();
size_t offset = allocator.Malloc(size);
offset_map[tensor] = offset;
ref_count[tensor] = tensor->init_ref_count();
}
for (auto tensor : kernel->in_tensors()) {
if (tensor->category() == lite::Tensor::VAR) {
int count = ref_count[tensor] - 1;
ref_count[tensor] = count;
if (count == 0) {
allocator.Free(offset_map[tensor]);
}
}
}
}
if (tensors_data_ == nullptr) {
auto size = allocator.total_size();
auto buf = malloc(size);
if (buf == nullptr) {
MS_LOG(ERROR) << "cannot allocate buffer size" << size;
return RET_ERROR;
}
StaticAllocator *alloc = reinterpret_cast<StaticAllocator *>(allocator_.get());
alloc->SetContex(buf, size);
tensors_data_ = buf;
}
for (auto kernel : train_kernels_) {
for (auto tensor : kernel->out_tensors()) {
auto it = offset_map.find(tensor);
if (it != offset_map.end()) {
tensor->set_data(reinterpret_cast<void *>(reinterpret_cast<char *>(tensors_data_) + it->second));
}
}
}
return RET_OK;
}
int TrainSession::CompileGraph(lite::Model *model) { return lite::RET_ERROR; }
int TrainSession::CompileTrainGraph(std::shared_ptr<Model> model) {
model_ = model;
auto restore = ReplaceOps();
sched_cb_ = std::make_unique<SchedulerCb>(sched_mix_precision_callback_);
if (sched_cb_ == nullptr) {
MS_LOG(ERROR) << "Failed to create SchedulerCb node";
return RET_ERROR;
}
#ifdef ENABLE_V0
if (reinterpret_cast<LiteModel *>(model_.get())->GetSchemaVersion() == SCHEMA_VERSION::SCHEMA_V0) {
kernel::PopulateTrainV0Parameters();
}
#endif
if (reinterpret_cast<LiteModel *>(model_.get())->GetSchemaVersion() == SCHEMA_VERSION::SCHEMA_CUR) {
kernel::PopulateTrainParameters();
}
auto ret = lite::LiteSession::CompileGraph(model_.get());
if (ret != RET_OK) {
MS_LOG(ERROR) << "failed to compile train model";
return RET_ERROR;
}
orig_output_node_map_ = output_node_map_;
orig_output_tensor_map_ = output_tensor_map_;
orig_output_tensor_names_ = output_tensor_names_;
for (auto inTensor : inputs_) inTensor->MutableData();
RestoreOps(restore);
CompileTrainKernels();
CompileOptimizedKernels();
CompileTrainOutputs();
CompileEvalOutputs();
if (CompileInferenceKernels() != RET_OK) {
MS_LOG(ERROR) << "CompileInferenceKernels failed.";
return RET_ERROR;
}
ret = AllocWorkSpace();
if (ret != RET_OK) {
MS_LOG(ERROR) << "failed to allocate space";
return RET_ERROR;
}
ret = AllocTensors(train_kernels_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "failed to allocate space";
return RET_ERROR;
}
return RET_OK;
}
TrainSession::~TrainSession() {
FreeWorkSpace();
if (tensors_data_ != nullptr) {
free(tensors_data_);
tensors_data_ = nullptr;
}
}
int TrainSession::ExecKernels(const KernelCallBack &before, const KernelCallBack &after,
const std::vector<kernel::LiteKernel *> &run_kernels) {
for (auto *kernel : run_kernels) {
MS_ASSERT(kernel != nullptr);
auto ret = kernel->Execute(before, after);
if (RET_OK != ret) {
MS_LOG(ERROR) << "Execute kernel failed, name: " << kernel->name();
return ret;
}
}
return RET_OK;
}
void TrainSession::RestoreTensorData() {
for (auto &restored_origin_tensor : restored_origin_tensors_) {
auto *origin_tensor = restored_origin_tensor.first;
auto *restored_tensor = restored_origin_tensor.second;
MS_ASSERT(origin_tensor != nullptr);
MS_ASSERT(restored_tensor != nullptr);
bool own_data = restored_tensor->own_data();
if (origin_tensor->data() == nullptr) {
restored_tensor->FreeData();
} else {
origin_tensor->FreeData();
}
origin_tensor->set_data_type(restored_tensor->data_type());
origin_tensor->set_data(restored_tensor->data());
origin_tensor->set_own_data(own_data);
}
}
void TrainSession::FreeRestoreTensors() {
for (auto &restored_origin_tensor : restored_origin_tensors_) {
auto *restored_tensor = restored_origin_tensor.second;
restored_tensor->set_data(nullptr);
delete (restored_tensor);
}
restored_origin_tensors_.clear();
}
bool TrainSession::IsLossTensor(Tensor *tensor) {
MS_ASSERT(tensor != nullptr);
auto t_n = tensor->tensor_name();
return (t_n.find(get_loss_name()) != std::string::npos);
}
bool TrainSession::AllInputsNeedScale(kernel::LiteKernel *kernel) {
auto type = kernel->type();
bool is_scale = false;
switch (type) {
case schema::PrimitiveType_AbsGrad:
case schema::PrimitiveType_AddFusion:
case schema::PrimitiveType_SubFusion:
case schema::PrimitiveType_AddN:
for (auto &tensor : kernel->in_tensors()) {
is_scale = is_scale || tensor->IsScale();
}
return (is_scale);
default:
return false;
}
return false;
}
int TrainSession::MixPrecisionPreProcess(kernel::LiteKernel *kernel, float scale) {
auto kernel_type = kernel->desc().data_type;
auto all_scale = AllInputsNeedScale(kernel);
for (auto &tensor : kernel->in_tensors()) {
if ((tensor->IsScale() == false) && ((!IsLossKernel(kernel) && IsLossTensor(tensor)) || (all_scale == true))) {
ScaleTensor(tensor, scale);
}
if (tensor->data_type() != kernel_type) {
auto restore_tensor = CastTensor(tensor, kernel_type, this->context_->device_and_pkg_support_fp16());
if (restore_tensor != nullptr) {
restored_origin_tensors_[tensor] = restore_tensor;
}
}
}
return RET_OK;
}
int TrainSession::MixPrecisionPostProcess(kernel::LiteKernel *kernel) {
RestoreTensorData();
FreeRestoreTensors();
float scale = 1.0f;
auto all_scale = AllInputsNeedScale(kernel);
for (auto &tensor : kernel->in_tensors()) {
if (tensor->IsScale()) {
scale *= tensor->get_scale();
if (all_scale) {
break;
}
}
}
for (auto &tensor : kernel->out_tensors()) {
tensor->set_scale(scale);
}
for (auto &tensor : kernel->in_tensors()) {
if ((tensor->IsScale() == true) && ((!IsLossKernel(kernel) && IsLossTensor(tensor)) || (all_scale == true))) {
ScaleTensor(tensor, 1.0f / scale);
}
}
return RET_OK;
}
int TrainSession::MixPrecisionExecKernels(const KernelCallBack &before, const KernelCallBack &after,
const std::vector<kernel::LiteKernel *> &run_kernels) {
float scale = cfg_.mix_precision_cfg_.loss_scale_;
for (auto *kernel : run_kernels) {
MS_ASSERT(kernel != nullptr);
MixPrecisionPreProcess(kernel, scale);
auto ret = kernel->Execute(before, after);
if (RET_OK != ret) {
MixPrecisionPostProcess(kernel);
if (ret == RET_OUT_OF_TENSOR_RANGE) {
bool is_dynamic_scale = cfg_.mix_precision_cfg_.dynamic_loss_scale_;
cfg_.mix_precision_cfg_.loss_scale_ = std::max(((is_dynamic_scale) ? (scale / 2.f) : scale), 1.0f);
num_of_not_nan_iter_ = 0;
return RET_OK;
}
MS_LOG(ERROR) << "Execute kernel failed, name: " << kernel->name();
return ret;
}
MixPrecisionPostProcess(kernel);
}
if (cfg_.mix_precision_cfg_.dynamic_loss_scale_) {
num_of_not_nan_iter_++;
if (num_of_not_nan_iter_ >= cfg_.mix_precision_cfg_.num_of_not_nan_iter_th_) {
cfg_.mix_precision_cfg_.loss_scale_ = std::min((cfg_.mix_precision_cfg_.loss_scale_ * 2.0f), 65536.0f);
num_of_not_nan_iter_ = 0;
}
}
if (train_mode_ == false) {
for (auto t : this->outputs_) {
if (t->data_type() == kNumberTypeFloat16) {
auto restore = CastTensor(t, kNumberTypeFloat32, this->context_->device_and_pkg_support_fp16());
delete restore;
}
}
}
return RET_OK;
}
int TrainSession::RunGraph(const KernelCallBack &before, const KernelCallBack &after) {
auto ret = CheckTensorsInvalid(inputs_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "CheckInputs failed";
return ret;
}
this->outputs_.clear();
for (auto &ms_tensors : output_node_map_) {
for (auto &ms_tensor : ms_tensors.second) {
auto lite_tensor = static_cast<lite::Tensor *>(ms_tensor);
this->outputs_.push_back(lite_tensor);
}
}
if (this->context_ == nullptr) {
MS_LOG(ERROR) << "context is null";
return lite::RET_NULL_PTR;
}
auto &run_kernels = (train_mode_) ? train_kernels_ : inference_kernels_;
if (context_->IsCpuFloat16Enabled() && !cfg_.mix_precision_cfg_.is_raw_mix_precision_) {
ret = MixPrecisionExecKernels(before, after, run_kernels);
} else {
ret = ExecKernels(before, after, run_kernels);
}
if (ret != RET_OK) {
MS_LOG(ERROR) << "failed to run model kernels";
return ret;
}
if (train_mode_ && virtual_batch_multiplier_) {
virtual_batch_idx_++;
if (virtual_batch_idx_ >= virtual_batch_multiplier_) {
virtual_batch_idx_ = 0;
ret = OptimizerStep();
if (ret != RET_OK) {
MS_LOG(ERROR) << "failed to optimize model weights";
return ret;
}
}
}
return RET_OK;
}
int TrainSession::Train() {
train_mode_ = true;
virtual_batch_idx_ = 0;
for (auto &kernel : this->train_kernels_) {
MS_ASSERT(kernel != nullptr);
auto ret = kernel->Train();
if (ret != RET_OK) {
MS_LOG(ERROR) << kernel->name() << " failed to set train mode";
return RET_ERROR;
}
}
output_node_map_ = train_output_node_map_;
output_tensor_map_ = train_output_tensor_map_;
output_tensor_names_ = train_output_tensor_names_;
kernel::LiteKernelUtil::InitTensorInitRefCount(train_kernels_);
for (auto &ms_tensors : eval_output_node_map_) {
for (auto &ms_tensor : ms_tensors.second) {
lite::Tensor *lite_tensor = static_cast<lite::Tensor *>(ms_tensor);
lite_tensor->set_init_ref_count(lite_tensor->init_ref_count() + 1);
}
}
auto ret = AllocTensors(train_kernels_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "failed to allocate tensor space";
return RET_ERROR;
}
return RET_OK;
}
int TrainSession::Eval() {
train_mode_ = false;
virtual_batch_idx_ = 0;
for (auto &kernel : this->train_kernels_) {
MS_ASSERT(kernel != nullptr);
auto ret = kernel->Eval();
if (ret != RET_OK) {
MS_LOG(ERROR) << kernel->name() << " failed to set eval mode";
return RET_ERROR;
}
}
output_node_map_ = eval_output_node_map_;
output_tensor_map_ = eval_output_tensor_map_;
output_tensor_names_ = eval_output_tensor_names_;
kernel::LiteKernelUtil::InitTensorInitRefCount(inference_kernels_);
for (auto &ms_tensors : eval_output_node_map_) {
for (auto &ms_tensor : ms_tensors.second) {
lite::Tensor *lite_tensor = static_cast<lite::Tensor *>(ms_tensor);
lite_tensor->set_init_ref_count(lite_tensor->init_ref_count() + 1);
}
}
auto ret = AllocTensors(inference_kernels_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "failed to allocate space";
return RET_ERROR;
}
return RET_OK;
}
void TrainSession::CompileEvalOutputs() {
eval_output_node_map_.clear();
eval_output_tensor_map_.clear();
eval_output_tensor_names_.clear();
for (auto kernel : this->train_kernels_) {
if (IsLossKernel(kernel) && !(IsGradKernel(kernel))) {
for (auto in_kernel : kernel->in_kernels()) {
if (IsLossKernel(in_kernel) || IsGradKernel(in_kernel)) continue;
if (eval_output_node_map_.find(in_kernel->name()) == eval_output_node_map_.end()) {
auto *ms_tensor = in_kernel->out_tensors().at(0);
if (ms_tensor != nullptr) {
ms_tensor->set_init_ref_count(ms_tensor->init_ref_count() + 1);
eval_output_node_map_[in_kernel->name()].emplace_back(ms_tensor);
auto index = TSFindTensor(tensors_, ms_tensor);
if (index != tensors_.size()) {
eval_output_tensor_map_.insert(std::make_pair(std::to_string(index), ms_tensor));
if (!ms_tensor->tensor_name().empty()) {
eval_output_tensor_names_.emplace_back(ms_tensor->tensor_name());
} else {
eval_output_tensor_names_.emplace_back(std::to_string(index));
}
}
}
}
}
}
}
if (eval_output_node_map_.size() == 0) eval_output_node_map_ = orig_output_node_map_;
if (eval_output_tensor_map_.size() == 0) eval_output_tensor_map_ = orig_output_tensor_map_;
if (eval_output_tensor_names_.size() == 0) eval_output_tensor_names_ = orig_output_tensor_names_;
}
void TrainSession::CompileTrainOutputs() {
train_output_node_map_.clear();
train_output_tensor_map_.clear();
train_output_tensor_names_.clear();
for (auto kernel : this->train_kernels_) {
if (orig_output_node_map_.find(kernel->name()) == orig_output_node_map_.end()) continue;
if (IsMaskOutput(kernel)) continue;
if (train_output_node_map_.find(kernel->name()) == train_output_node_map_.end()) {
auto *ms_tensor = kernel->out_tensors().at(0);
if (ms_tensor != nullptr) {
train_output_node_map_[kernel->name()].emplace_back(ms_tensor);
auto index = TSFindTensor(tensors_, ms_tensor);
if (index != tensors_.size()) {
train_output_tensor_map_.insert(std::make_pair(std::to_string(index), ms_tensor));
if (!ms_tensor->tensor_name().empty()) {
train_output_tensor_names_.emplace_back(ms_tensor->tensor_name());
} else {
train_output_tensor_names_.emplace_back(std::to_string(index));
}
}
}
}
}
if (train_output_node_map_.size() == 0) train_output_node_map_ = orig_output_node_map_;
if (train_output_tensor_map_.size() == 0) train_output_tensor_map_ = orig_output_tensor_map_;
if (train_output_tensor_names_.size() == 0) train_output_tensor_names_ = orig_output_tensor_names_;
}
void TrainSession::BuildInferenceKernelsRecursive(kernel::LiteKernel *kernel, std::vector<kernel::LiteKernel *> *v) {
MS_ASSERT(kernel != nullptr);
MS_ASSERT(v != nullptr);
if (std::find(v->begin(), v->end(), kernel) == v->end()) {
for (auto in_node : kernel->in_kernels()) {
BuildInferenceKernelsRecursive(in_node, v);
}
if (!IsLossKernel(kernel)) v->push_back(kernel);
}
}
void TrainSession::CompileTrainKernels() {
train_kernels_.clear();
for (auto ori_kernel : kernels_) {
if (ori_kernel->subgraph_type() == kernel::kNotSubGraph) {
train_kernels_.push_back(ori_kernel);
} else {
auto sub_graph = reinterpret_cast<kernel::SubGraphKernel *>(ori_kernel);
for (auto kernel : sub_graph->nodes()) {
train_kernels_.push_back(kernel);
}
}
}
}
int TrainSession::CompileInferenceKernels() {
inference_kernels_.clear();
for (auto item : eval_output_node_map_) {
std::string kernel_name = item.first;
auto kernel = TSFindKernel(train_kernels_, kernel_name);
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr";
return RET_ERROR;
}
BuildInferenceKernelsRecursive(kernel, &inference_kernels_);
}
if (inference_kernels_.size() == 0) {
inference_kernels_ = this->train_kernels_;
}
return RET_OK;
}
void TrainSession::CompileOptimizedKernels() {
std::vector<lite::Tensor *> out_tensor;
for (auto kernel : this->train_kernels_) {
if (IsOptimizer(kernel)) {
std::copy(kernel->in_tensors().begin(), kernel->in_tensors().end(), std::back_inserter(out_tensor));
if (cfg_.accumulate_gradients_) {
auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
optimizer->SetOptimizerMode(kernel::WeightUpdateMode::ACCUMULATE_GRADS);
}
}
}
for (auto kernel : this->train_kernels_) {
if (!IsOptimizer(kernel)) {
for (auto it : kernel->in_tensors()) {
if (std::find(out_tensor.begin(), out_tensor.end(), it) != out_tensor.end()) {
kernel->SetTrainable(true);
break;
}
}
}
}
}
int TrainSession::SetLearningRate(float learning_rate) {
if (learning_rate < 0.0f) {
MS_LOG(ERROR) << "learning rate should more than 0";
return RET_ERROR;
}
for (auto kernel : this->train_kernels_) {
if (IsOptimizer(kernel)) {
auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
auto ret = optimizer->SetLearningRate(learning_rate);
if (ret != RET_OK) {
MS_LOG(ERROR) << kernel->name() << " failed to set learning rate";
return RET_ERROR;
}
}
}
return RET_OK;
}
float TrainSession::GetLearningRate() {
for (auto kernel : this->train_kernels_) {
if (IsOptimizer(kernel)) {
auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
return optimizer->GetLearningRate();
}
}
return 0.0;
}
std::vector<tensor::MSTensor *> TrainSession::GetOptimizerParams() const {
std::vector<tensor::MSTensor *> params;
for (auto kernel : this->train_kernels_) {
if (IsOptimizer(kernel)) {
auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
auto kernelParams = optimizer->GetOptimizerParams();
for (size_t ix = 0; ix < kernelParams.size(); ix++) {
auto kernelParam = kernelParams[ix];
auto name = kernelParam->tensor_name();
bool found = false;
for (size_t iy = 0; iy < params.size(); iy++) {
if (params[iy]->tensor_name() == name) {
found = true;
break;
}
}
if (!found) {
params.push_back(kernelParam);
}
}
}
}
return params;
}
int TrainSession::SetOptimizerParams(const std::vector<tensor::MSTensor *> ¶ms) {
for (size_t ix = 0; ix < params.size(); ix++) {
auto param = params[ix];
if (param == nullptr) {
MS_LOG(ERROR) << "Param tensor " << param->tensor_name() << " is null.";
return RET_ERROR;
}
bool found = false;
for (auto kernel : this->train_kernels_) {
if (IsOptimizer(kernel)) {
auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
found = optimizer->SetOptimizerParams(param);
if (found) break;
}
}
if (!found) {
MS_LOG(ERROR) << "Tensor " << param->tensor_name() << " with " << param->ElementsNum() << " elelmts and type "
<< param->data_type() << " is not a valid params tensor";
return RET_ERROR;
}
}
return RET_OK;
}
std::vector<tensor::MSTensor *> TrainSession::GetGradients() const {
std::vector<tensor::MSTensor *> gradients;
for (auto kernel : this->train_kernels_) {
if (IsOptimizer(kernel)) {
auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
auto kernelGradint = optimizer->GetGradients();
if (kernelGradint != nullptr) {
gradients.push_back(kernelGradint);
}
}
}
return gradients;
}
int TrainSession::ApplyGradients(const std::vector<tensor::MSTensor *> &gradients) {
auto current_gradients = GetGradients();
if (current_gradients.size() != gradients.size()) {
MS_LOG(ERROR) << "gradients vector has wrong size " << gradients.size() << " instead of "
<< current_gradients.size();
return RET_ERROR;
}
for (size_t ix = 0; ix < gradients.size(); ix++) {
auto gradient = gradients[ix];
if (gradient == nullptr) {
MS_LOG(ERROR) << "gradient tensor " << gradient->tensor_name() << " is null.";
return RET_ERROR;
}
bool found = false;
for (size_t iy = 0; iy < current_gradients.size(); iy++) {
auto current_gradient = current_gradients[iy];
if (current_gradient->tensor_name() == gradient->tensor_name()) {
found = true;
if (current_gradient->Size() == gradient->Size()) {
std::copy(static_cast<char *>(gradient->data()), static_cast<char *>(gradient->data()) + gradient->Size(),
static_cast<char *>(current_gradient->MutableData()));
} else {
MS_LOG(ERROR) << "gradient tensor " << gradient->tensor_name() << " has wrong size " << gradient->Size()
<< " instead of " << current_gradient->Size();
return RET_ERROR;
}
break;
}
}
if (!found) {
MS_LOG(ERROR) << "gradient tensor " << gradient->tensor_name() << " not found";
return RET_ERROR;
}
}
for (auto kernel : this->train_kernels_) {
if (IsOptimizer(kernel)) {
auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
optimizer->set_grad_sum_valid();
auto ret = optimizer->OptimizerStep();
if (ret != RET_OK) {
MS_LOG(ERROR) << "failed to optimize model weights";
return ret;
}
}
}
for (size_t ix = 0; ix < current_gradients.size(); ix++) {
delete current_gradients[ix];
}
return RET_OK;
}
int TrainSession::AdminSetupVirtualBatch(int virtual_batch_multiplier, float lr, float momentum) {
auto mod =
(virtual_batch_multiplier <= 1) ? kernel::WeightUpdateMode::NORMAL : kernel::WeightUpdateMode::VIRTUAL_BATCH;
virtual_batch_multiplier_ = (virtual_batch_multiplier <= 1) ? 0 : virtual_batch_multiplier;
virtual_batch_idx_ = 0;
for (auto kernel : this->train_kernels_) {
if (IsOptimizer(kernel)) {
auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
if (optimizer->get_optimizer_mode() != kernel::WeightUpdateMode::NORMAL &&
optimizer->get_optimizer_mode() != kernel::WeightUpdateMode::VIRTUAL_BATCH) {
MS_LOG(ERROR) << kernel->name() << " failed to set optimizer mode, conflict with accumulate grads";
return RET_ERROR;
}
auto ret = optimizer->SetOptimizerMode(mod);
if (ret != RET_OK) {
MS_LOG(ERROR) << kernel->name() << " failed to set optimizer mode";
return RET_ERROR;
}
if (mod == kernel::WeightUpdateMode::VIRTUAL_BATCH) {
lr = (lr < 0.0f) ? (optimizer->GetLearningRate() / static_cast<float>(virtual_batch_multiplier_)) : lr;
ret = optimizer->SetLearningRate(lr);
} else {
ret = optimizer->RestoreDefaultLearningRate();
}
if (ret != RET_OK) {
MS_LOG(ERROR) << kernel->name() << " failed to set learning rate";
return RET_ERROR;
}
}
if (IsBN(kernel) && kernel->IsTrainable()) {
auto batchnorm = static_cast<kernel::BatchnormCPUKernel *>(kernel->kernel());
auto ret = RET_OK;
if (mod == kernel::WeightUpdateMode::VIRTUAL_BATCH) {
momentum = (momentum < 0.0f) ? (batchnorm->get_momentum() / virtual_batch_multiplier_) : momentum;
ret = batchnorm->set_momentum(momentum);
} else {
ret = batchnorm->RestoreDefaultMomentum();
}
if (ret != RET_OK) {
MS_LOG(ERROR) << kernel->name() << " failed to set momentum";
return RET_ERROR;
}
}
}
return RET_OK;
}
int TrainSession::SetupVirtualBatch(int virtual_batch_multiplier, float lr, float momentum) {
int tmp = (virtual_batch_multiplier <= 1) ? 0 : virtual_batch_multiplier;
if (tmp != 0 && virtual_batch_multiplier_ != 0) {
AdminSetupVirtualBatch(0, lr, momentum);
}
return AdminSetupVirtualBatch(virtual_batch_multiplier, lr, momentum);
}
int TrainSession::OptimizerStep() {
for (auto kernel : this->train_kernels_) {
if (IsOptimizer(kernel)) {
auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
auto ret = optimizer->OptimizerStep();
if (ret != RET_OK) {
MS_LOG(ERROR) << kernel->name() << " failed to do optimize step";
return RET_ERROR;
}
}
}
return RET_OK;
}
bool TrainSession::IsLossKernel(const kernel::LiteKernel *kernel) const {
return (kernel->type() == schema::PrimitiveType_SoftmaxCrossEntropyWithLogits ||
kernel->type() == schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits ||
kernel->type() == schema::PrimitiveType_SmoothL1Loss ||
kernel->type() == schema::PrimitiveType_SmoothL1LossGrad ||
kernel->type() == schema::PrimitiveType_SigmoidCrossEntropyWithLogits ||
kernel->type() == schema::PrimitiveType_SigmoidCrossEntropyWithLogitsGrad) ||
kernel->name().find(cfg_.loss_name_) != std::string::npos;
}
bool TrainSession::IsGradKernel(const kernel::LiteKernel *kernel) const {
return kernel->name().find(kGradName) != std::string::npos;
}
bool TrainSession::IsOptimizer(kernel::LiteKernel *kernel) const {
return ((kernel->type() == schema::PrimitiveType_Adam) || (kernel->type() == schema::PrimitiveType_SGD) ||
(kernel->type() == schema::PrimitiveType_ApplyMomentum));
}
bool TrainSession::IsMaskOutput(kernel::LiteKernel *kernel) const {
return (IsOptimizer(kernel) || (kernel->type() == schema::PrimitiveType_Assign));
}
bool TrainSession::IsBN(kernel::LiteKernel *kernel) const {
return ((kernel->type() == schema::PrimitiveType_BatchNorm) ||
(kernel->type() == schema::PrimitiveType_FusedBatchNorm));
}
int TrainSession::Resize(const std::vector<tensor::MSTensor *> &inputs, const std::vector<std::vector<int>> &dims) {
FreeWorkSpace();
if (tensors_data_ != nullptr) {
free(tensors_data_);
tensors_data_ = nullptr;
}
auto ret = lite::LiteSession::Resize(inputs, dims);
if (ret != RET_OK) {
MS_LOG(ERROR) << "train resize input failed.";
return RET_ERROR;
}
ret = AllocWorkSpace();
if (ret != RET_OK) {
MS_LOG(ERROR) << "failed to allocate space";
return RET_ERROR;
}
ret = AllocTensors(train_kernels_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "train alloc failed after resize.";
return RET_ERROR;
}
return RET_OK;
}
int TrainSession::FindUseInTensorKernel(std::vector<kernel::LiteKernel *> *use_in_tensor_kernels,
const std::vector<lite::Tensor *> &kernel_in_tensors,
const std::vector<kernel::LiteKernel *> &inference_kernels) {
for (size_t i = 0; i < inference_kernels.size(); i++) {
for (size_t j = 0; j < kernel_in_tensors.size(); j++) {
if (IsContain(inference_kernels[i]->out_tensors(), kernel_in_tensors[j])) {
use_in_tensor_kernels->push_back(inference_kernels[i]);
}
}
}
return RET_OK;
}
int TrainSession::FindExportKernels(std::vector<kernel::LiteKernel *> *export_kernels,
const std::vector<std::string> &export_output_tensor_names,
const std::vector<kernel::LiteKernel *> &inference_kernels) {
std::vector<std::string> all_kernel_name = {};
std::transform(inference_kernels.begin(), inference_kernels.end(), std::back_inserter(all_kernel_name),
[](kernel::LiteKernel *kernel) { return kernel->name(); });
std::queue<std::string> need_kernel_names;
for (auto &kernel : inference_kernels) {
if (std::any_of(kernel->out_tensors().begin(), kernel->out_tensors().end(), [&](lite::Tensor *out_tensor) {
return IsContain(export_output_tensor_names, out_tensor->tensor_name());
})) {
need_kernel_names.push(kernel->name());
}
}
if (need_kernel_names.size() == 0) {
MS_LOG(ERROR) << "can not find tensor";
return RET_ERROR;
}
while (!need_kernel_names.empty()) {
auto kernel_name = need_kernel_names.front();
need_kernel_names.pop();
auto it = find(all_kernel_name.begin(), all_kernel_name.end(), kernel_name);
if (it == all_kernel_name.end()) {
MS_LOG(ERROR) << "not find kernel name in export trained model.";
return RET_ERROR;
}
auto kernel = inference_kernels[it - all_kernel_name.begin()];
if (!IsContain(*export_kernels, kernel)) {
export_kernels->push_back(kernel);
}
auto kernel_in_tensors = kernel->in_tensors();
std::vector<kernel::LiteKernel *> use_in_tensor_kernels;
auto status = FindUseInTensorKernel(&use_in_tensor_kernels, kernel_in_tensors, inference_kernels);
if (status != RET_OK) {
MS_LOG(ERROR) << "FindUseInTensorKernel failed.";
return RET_ERROR;
}
for (size_t i = 0; i < use_in_tensor_kernels.size(); i++) {
need_kernel_names.push(use_in_tensor_kernels[i]->name());
}
}
return RET_OK;
}
int TrainSession::Export(const std::string &file_name, ModelType model_type, QuantizationType quant_type,
FormatType format, std::vector<std::string> out_put_tensor_name) {
if (file_name.empty()) {
MS_LOG(ERROR) << "File name cannot be empty";
return RET_ERROR;
}
if (model_type > mindspore::lite::MT_INFERENCE || model_type < mindspore::lite::MT_TRAIN) {
MS_LOG(ERROR) << "Export model type parameter error";
return RET_ERROR;
}
if (quant_type < mindspore::lite::QT_DEFAULT || quant_type > mindspore::lite::QT_WEIGHT) {
MS_LOG(ERROR) << "Export quant type parameter error";
return RET_ERROR;
}
if (format != FT_FLATBUFFERS) {
MS_LOG(ERROR) << "Currently only flatbuffer format is supported";
return RET_ERROR;
}
bool orig_train_state = IsTrain();
Eval();
TrainExport texport(file_name);
int status = texport.ExportInit(model_.get()->name_, model_.get()->version_);
if (status != RET_OK) {
MS_LOG(ERROR) << "cannot init export";
return status;
}
if (!out_put_tensor_name.empty() && model_type == MT_INFERENCE) {
std::vector<kernel::LiteKernel *> export_kernels = {};
status = FindExportKernels(&export_kernels, out_put_tensor_name, inference_kernels_);
if (status != RET_OK) {
MS_LOG(ERROR) << "FindExportKernels failed.";
return RET_ERROR;
}
status = texport.ExportNet(export_kernels, tensors_, out_put_tensor_name, model_.get(), quant_type);
} else {
status = texport.ExportNet((model_type == MT_TRAIN) ? train_kernels_ : inference_kernels_, tensors_,
(model_type == MT_TRAIN) ? train_output_tensor_names_ : eval_output_tensor_names_,
model_.get(), quant_type);
}
if (status != RET_OK) {
MS_LOG(ERROR) << "cannot export Network";
return status;
}
status = texport.SaveToFile();
if (status != RET_OK) {
MS_LOG(ERROR) << "failed to save to " << file_name;
return status;
}
if (orig_train_state) Train();
return status;
}
std::vector<tensor::MSTensor *> TrainSession::GetFeatureMaps() const {
std::vector<tensor::MSTensor *> features;
for (auto cur_tensor : this->tensors_) {
if (cur_tensor->IsConst() && cur_tensor->data_type() == kNumberTypeFloat32) {
features.push_back(cur_tensor);
}
}
return features;
}
int TrainSession::UpdateFeatureMaps(const std::vector<tensor::MSTensor *> &features_map) {
for (auto feature : features_map) {
bool find = false;
for (auto tensor : tensors_) {
if (!tensor->IsConst() || tensor->data_type() != kNumberTypeFloat32) {
continue;
}
if (feature->tensor_name() != tensor->tensor_name()) {
continue;
}
if (feature->Size() != tensor->Size()) {
MS_LOG(ERROR) << "feature name:" << feature->tensor_name() << ",len diff:"
<< "old is:" << tensor->Size() << "new is:" << feature->Size();
return RET_ERROR;
}
find = true;
memcpy(tensor->data(), feature->data(), tensor->Size());
}
if (!find) {
MS_LOG(ERROR) << "cannot find feature:" << feature->tensor_name() << ",update failed";
return RET_ERROR;
}
}
return RET_OK;
}
}
session::LiteSession *session::TrainSession::CreateTrainSession(const std::string &fn, const lite::Context *context,
bool train_mode, const lite::TrainCfg *cfg) {
if (context == nullptr) {
MS_LOG(ERROR) << "context cannot be nullptr";
return nullptr;
}
auto session = std::make_unique<lite::TrainSession>();
if (session == nullptr) {
MS_LOG(ERROR) << "create session failed";
return nullptr;
}
if (context->allocator == nullptr) {
const_cast<lite::Context *>(context)->allocator = std::make_shared<StaticAllocator>();
if (context->allocator == nullptr) {
MS_LOG(ERROR) << " cannot convert to static allocation";
}
}
auto *inner_context = new (std::nothrow) mindspore::lite::InnerContext(context);
auto ret = session->Init(inner_context, cfg);
if (ret != mindspore::lite::RET_OK) {
MS_LOG(ERROR) << "init session failed";
return nullptr;
}
std::string filename = fn;
if (filename.substr(filename.find_last_of(".") + 1) != "ms") {
filename = filename + ".ms";
}
auto model = std::shared_ptr<lite::Model>(lite::Model::Import(filename.c_str()));
if (model == nullptr) {
MS_LOG(ERROR) << "create model for train session failed " << filename;
return nullptr;
}
ret = session->CompileTrainGraph(model);
if (ret != mindspore::lite::RET_OK) {
MS_LOG(ERROR) << "Compiling Train Graph session failed";
return nullptr;
}
if (train_mode) {
ret = session->Train();
} else {
ret = session->Eval();
}
if (ret != mindspore::lite::RET_OK) {
MS_LOG(ERROR) << "Could not switch to Train Modei " << train_mode;
return nullptr;
}
return session.release();
}
}