/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/lite_kernel.h"
#include <algorithm>
#include "src/tensor.h"
#include "src/common/utils.h"
#include "src/common/version_manager.h"

namespace mindspore::kernel {
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;

// Reports whether this kernel can run with respect to `scope_tensors`.
// Only input tensors that are members of `scope_tensors` are consulted; each of
// those must report IsReady(). Inputs outside the scope never block readiness.
// Evaluation stops at the first in-scope input that is not ready.
bool LiteKernel::IsReady(const std::vector<lite::Tensor *> &scope_tensors) {
  // NOTE(review): kernel_ is not used directly here; presumably in_tensors()
  // reads through it — confirm.
  MS_ASSERT(kernel_ != nullptr);
  for (auto *input : this->in_tensors()) {
    const bool tracked = IsContain(scope_tensors, input);
    if (tracked && !input->IsReady()) {
      return false;
    }
  }
  return true;
}

// Sets each output tensor's initial reference count to the number of times it
// appears among the inputs of the downstream kernels in out_kernels_.
// When `mask_kernels` is non-null, only downstream kernels listed in it are
// counted; a null mask counts every downstream kernel.
void LiteKernel::InitOutTensorInitRefCount(const std::vector<LiteKernel *> *mask_kernels) {
  for (auto *out_tensor : this->out_tensors()) {
    MS_ASSERT(out_tensor != nullptr);
    size_t consumer_count = 0;
    for (auto *next_kernel : this->out_kernels_) {
      const bool in_mask = (mask_kernels == nullptr) ||
                           (std::find(mask_kernels->begin(), mask_kernels->end(), next_kernel) != mask_kernels->end());
      if (!in_mask) {
        continue;
      }
      // A downstream kernel may consume the same tensor several times; count each occurrence.
      const auto &next_inputs = next_kernel->in_tensors();
      consumer_count += std::count(next_inputs.begin(), next_inputs.end(), out_tensor);
    }
    out_tensor->set_init_ref_count(consumer_count);
  }
}

std::string LiteKernel::ToString() const {
  std::ostringstream oss;
  oss << "LiteKernel: " << this->name();
  oss << ", Type: " << this->type_str();
  oss << ", " << this->in_tensors().size() << " InputTensors:";
  for (auto tensor : in_tensors()) {
    oss << " " << tensor;
  }
  oss << ", " << this->out_tensors().size() << " OutputTensors:";
  for (auto tensor : out_tensors()) {
    oss << " " << tensor;
  }
  oss << ", " << this->in_kernels_.size() << " InputKernels:";
  for (auto in_kernel : in_kernels_) {
    oss << " " << in_kernel->name();
  }
  oss << ", " << this->out_kernels_.size() << " OutputKernels:";
  for (auto out_kernel : out_kernels_) {
    oss << " " << out_kernel->name();
  }
  return oss.str();
}

int LiteKernel::DoExecute() {
  auto ret = kernel_->Execute();
  if ((ret == lite::RET_OK) && (desc_.provider != kBuiltin)) {
    for (auto *output : out_tensors()) {
      MS_ASSERT(output != nullptr);
      output->ResetRefCount();
    }
    for (auto &in_tensor : in_tensors()) {
      MS_ASSERT(in_tensor != nullptr);
      in_tensor->DecRefCount();
    }
  }
  return ret;
}
}  // namespace mindspore::kernel