// Commit 11ce62f6, created on 2021-09-26 (commit history).
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/inner_kernel.h"
#include <algorithm>
#include "src/tensor.h"
#include "src/common/utils.h"
#include "src/runtime/infer_manager.h"

namespace mindspore::kernel {
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;

void InnerKernel::AllocWorkspace() {
  workspace_ = malloc(workspace_size());
  if (workspace_ == nullptr) {
    MS_LOG(ERROR) << "fail to alloc " << workspace_size() << "in kernel" << name();
    return;
  }
  ws_allocated_ = true;
}

// Releases the workspace if this kernel allocated it, then resets the
// workspace bookkeeping: workspace_ becomes null and ws_allocated_ false.
// A workspace not flagged as allocated by this kernel is only detached,
// never freed.
void InnerKernel::FreeWorkspace() {
  const bool owned_by_kernel = ws_allocated_;
  if (owned_by_kernel) {
    free(workspace_);
  }
  ws_allocated_ = false;
  workspace_ = nullptr;
}

int InnerKernel::PreProcess() {
  if (!InferShapeDone()) {
    auto ret = lite::KernelInferShape(in_tensors_, out_tensors_, op_parameter_);
    if (ret != 0) {
      MS_LOG(ERROR) << "InferShape fail!";
      return ret;
    }
    ret = ReSize();
    if (ret != 0) {
      MS_LOG(ERROR) << "ReSize fail!ret: " << ret;
      return ret;
    }
  }

  for (auto *output : this->out_tensors()) {
    MS_ASSERT(output != nullptr);
    if (registry_data_type_ == kNumberTypeFloat16 && output->data_type() == kNumberTypeFloat32) {
      output->set_data_type(kNumberTypeFloat16);
    }
    auto ret = output->MallocData();
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "MallocData failed";
      return ret;
    }
    output->ResetRefCount();
  }
  return RET_OK;
}

int InnerKernel::Execute() {
  auto ret = PreProcess();
  if (lite::RET_OK != ret) {
    MS_LOG(ERROR) << "run kernel PreProcess failed, name: " << this->name();
    return ret;
  }

  if (op_parameter_->is_zero_shape_ == false) {
    ret = Run();
    if (lite::RET_OK != ret) {
      MS_LOG(ERROR) << "run kernel failed, name: " << this->name();
      return ret;
    }
  }

  ret = PostProcess();
  if (lite::RET_OK != ret) {
    MS_LOG(ERROR) << "run kernel PostProcess failed, name: " << this->name();
    return ret;
  }
  return lite::RET_OK;
}
}  // namespace mindspore::kernel