/*
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include "slice_operation.h"
#include <mki/utils/platform/platform_info.h>
#include "atb/utils/log.h"
#include "atb/utils/tensor_check.h"
#include "atb/utils/param_to_json.h"
#include "atb/utils/singleton.h"
#include "atb/operation/atb_operation_ir_cfg.h"
#include "atb/operation/op_param_funcs.h"
#include "slice_aclnn_runner.h"
#include "slice_ops_runner.h"

namespace atb {
static const uint32_t IN_TENSOR_NUM = 1;
static const uint32_t OUT_TENSOR_NUM = 1;

template <> Status CreateOperation(const infer::SliceParam &opParam, Operation **operation)
{
    if (operation == nullptr) {
        return ERROR_INVALID_PARAM;
    }
    OP_PARAM_RSV_CHECK(opParam);
    if (Mki::PlatformInfo::Instance().GetPlatformType() == Mki::PlatformType::ASCEND_950) {
        if (SliceAclnnRunner::LoadAclnnFuncs() != NO_ERROR) {
            ATB_LOG(ERROR) << "Load aclnn function failed, please check your CANN version.";
            return ERROR_CANN_ERROR;
        }
    }
    *operation = new (std::nothrow) SliceOperation(opParam);
    if (*operation == nullptr) {
        ATB_LOG(ERROR) << "failed to new operation";
        return ERROR_OUT_OF_HOST_MEMORY;
    }
    return NO_ERROR;
}

SliceOperation::SliceOperation(const infer::SliceParam &param) : OperationBase("SliceOperation"), param_(param)
{
    operationIr_ = GetSingleton<AtbOperationIrCfg>().GetOperationIr("SliceOperation");
}

SliceOperation::~SliceOperation() {}

uint32_t SliceOperation::GetInputNum() const
{
    return IN_TENSOR_NUM;
}

uint32_t SliceOperation::GetOutputNum() const
{
    return OUT_TENSOR_NUM;
}

Status SliceOperation::InferShapeImpl(const SVector<TensorDesc> &inTensorDescs,
                                      SVector<TensorDesc> &outTensorDescs) const
{
    outTensorDescs.at(0) = inTensorDescs.at(0);

    const SVector<int64_t> &offsets = param_.offsets;
    const SVector<int64_t> &size = param_.size;
    for (size_t i = 0; i < param_.size.size(); i++) {
        int64_t offsetValue = offsets.at(i);
        int64_t sizeValue = size.at(i);
        int64_t xDim = inTensorDescs.at(0).shape.dims[i];
        if (offsetValue < 0) {
            offsetValue = offsetValue + xDim;
        }
        if (sizeValue == -1) {
            sizeValue = xDim - offsetValue;
        }

        outTensorDescs.at(0).shape.dims[i] = sizeValue;
    }

    return NO_ERROR;
}

Status SliceOperation::InferShapeCheckImpl(const SVector<TensorDesc> &inTensorDescs) const
{
    return ParamCheck(inTensorDescs.at(0));
}

Status SliceOperation::SetupCheckImpl(const SVector<Tensor> &inTensors, const SVector<Tensor> &outTensors) const
{
    Status status = ParamCheck(inTensors.at(0).desc);
    if (status != NO_ERROR) {
        return status;
    }

    const SVector<int64_t> &offsets = param_.offsets;
    const SVector<int64_t> &size = param_.size;
    for (size_t i = 0; i < param_.size.size(); i++) {
        int64_t offsetValue = offsets.at(i);
        int64_t sizeValue = size.at(i);
        int64_t xDim = inTensors.at(0).desc.shape.dims[i];
        if (offsetValue < 0) {
            offsetValue = offsetValue + xDim;
        }
        if (sizeValue == -1) {
            sizeValue = xDim - offsetValue;
        }

        if (outTensors.at(0).desc.shape.dims[i] != sizeValue) {
            ATB_LOG(ERROR) << "outTensor dim and param size are different at: " << i
                           << " dim : " << outTensors.at(0).desc.shape.dims[i] << " param size: " << sizeValue;
            return ERROR_INVALID_TENSOR_DIM;
        }
    }

    return NO_ERROR;
}

Status SliceOperation::ParamCheck(TensorDesc inTensorDesc) const
{
    size_t sliceSize = param_.size.size();
    auto xTensorDims = inTensorDesc.shape.dimNum;
    if (sliceSize != param_.offsets.size()) {
        ATB_LOG(ERROR) << "SliceOperation InferShapeImpl failed: offsets length must be equal to size length, "
                       << "the offsets length is " << param_.offsets.size() << ". the size length is "
                       << param_.size.size();
        return ERROR_INVALID_PARAM;
    }
    if (sliceSize != xTensorDims) {
        ATB_LOG(ERROR) << "SliceOperation InferShapeImpl failed: size length must be equal to the input tensor dimNum, "
                       << "the size length is " << param_.size.size() << ". the input tensor dimNum is "
                       << inTensorDesc.shape.dimNum;
        return ERROR_INVALID_PARAM;
    }

    const SVector<int64_t> &offsets = param_.offsets;
    const SVector<int64_t> &size = param_.size;
    for (size_t i = 0; i < param_.size.size(); i++) {
        int64_t offsetValue = offsets.at(i);
        int64_t sizeValue = size.at(i);
        int64_t xDim = inTensorDesc.shape.dims[i];
        if (offsetValue < -xDim) {
            ATB_LOG(ERROR) << "SliceOperation InferShapeImpl: wrong offset: " << offsetValue;
            return ERROR_INVALID_PARAM;
        }
        if (offsetValue < 0) {
            offsetValue = offsetValue + xDim;
        }
        if (sizeValue == -1) {
            sizeValue = xDim - offsetValue;
        } else if (sizeValue < -1) {
            ATB_LOG(ERROR) << "SliceOperation InferShapeImpl: Wrong size: " << sizeValue;
            return ERROR_INVALID_PARAM;
        }
        if (std::numeric_limits<int64_t>::max() - offsetValue < sizeValue) {
            ATB_LOG(ERROR) << "SliceOperation InferShapeImpl: Calculate the total size overflow: " << offsetValue
                           << " size: " << sizeValue;
            return ERROR_INVALID_PARAM;
        }
        if (offsetValue + sizeValue > xDim) {
            ATB_LOG(ERROR) << "SliceOperation InferShapeImpl: Wrong offsets or size, offsets:" << offsetValue
                           << " size: " << sizeValue;
            return ERROR_INVALID_PARAM;
        }
    }

    return NO_ERROR;
}

std::shared_ptr<Runner> SliceOperation::CreateRunner(Context &context) const
{
    (void)context;
    if (Mki::PlatformInfo::Instance().GetPlatformType() == Mki::PlatformType::ASCEND_950) {
        return std::make_shared<SliceAclnnRunner>(param_);
    }
    return std::make_shared<SliceOpsRunner>(param_);
}

nlohmann::json SliceOperation::GetParamJson() const
{
    return OpParamToJson(param_);
}
} // namespace atb