// Copyright (c) 2023 Huawei Technologies Co., Ltd
// All rights reserved.
//
// Licensed under the BSD 3-Clause License  (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <ATen/native/TypeProperties.h>
#include "torch_npu/csrc/core/npu/NPUException.h"
#include "torch_npu/csrc/core/npu/NpuVariables.h"
#include "op_plugin/utils/KernelNpuOutputDtype.h"
#include "op_plugin/utils/op_api_common.h"
#include "op_plugin/OpApiInterface.h"
namespace op_infer {

// Output dtype of angle(self).
//
// - Complex input: the matching real component dtype
//   (ComplexHalf -> Half, ComplexFloat -> Float, ComplexDouble -> Double).
// - Integral input (bool included, see isIntegralType(..., true)): Float.
// - Any other dtype (floating point) is returned unchanged.
at::ScalarType angle_out_dtype(const at::Tensor& self)
{
    const at::ScalarType self_dtype = self.scalar_type();
    switch (self_dtype) {
        // Each complex dtype maps to its own real counterpart; ComplexHalf must
        // not be widened to Double.
        case at::ScalarType::ComplexHalf:
            return at::ScalarType::Half;
        case at::ScalarType::ComplexFloat:
            return at::ScalarType::Float;
        case at::ScalarType::ComplexDouble:
            return at::ScalarType::Double;
        default:
            break;
    }
    if (at::isIntegralType(self_dtype, true)) {
        return at::kFloat;
    }
    return self_dtype;
}

// Output dtype of polar(abs, angle): the complex counterpart of the promoted
// dtype of the two inputs.
//   Float  -> ComplexFloat
//   Double -> ComplexDouble
//   Half   -> ComplexHalf
// Any other promoted dtype is returned as-is (no complex mapping is applied).
at::ScalarType polar_out_dtype(const at::Tensor& abs, const at::Tensor& angle)
{
    const at::ScalarType promoted = at::native::result_type(abs, angle);
    switch (promoted) {
        case at::ScalarType::Float:
            return at::ScalarType::ComplexFloat;
        case at::ScalarType::Double:
            return at::ScalarType::ComplexDouble;
        case at::ScalarType::Half:
            return at::ScalarType::ComplexHalf;
        default:
            return promoted;
    }
}

// Result dtype for npu_group_norm_silu.
//
// The input tensor always participates in type promotion. When the SoC version
// is Ascend950 or later, one affine tensor is also promoted with it: weight if
// present, otherwise bias if present. On earlier SoCs only the input dtype
// is considered.
at::ScalarType npu_group_norm_silu_dst_type(const at::Tensor& input, const c10::optional<at::Tensor>& weight,
                                            const c10::optional<at::Tensor>& bias)
{
    at::native::ResultTypeState promotion =
        at::native::update_result_type_state(input, at::native::ResultTypeState{});

    const bool promote_with_affine = c10_npu::GetSocVersion() >= c10_npu::SocVersion::Ascend950;
    if (promote_with_affine) {
        // Prefer weight; bias is only consulted when weight is absent.
        const c10::optional<at::Tensor>& affine = weight.has_value() ? weight : bias;
        if (affine.has_value()) {
            promotion = at::native::update_result_type_state(affine.value(), promotion);
        }
    }

    return at::native::result_type(promotion);
}

// Output dtype for npu_group_quant.
//
// dst_dtype defaults to Char (int8) when not given; QInt8 is treated as Char.
// Only Char and QUInt4x2 are accepted (TORCH_CHECK otherwise). A QUInt4x2
// request is reported as at::ScalarType::Int; Char is returned unchanged.
at::ScalarType npu_group_quant_dst_type(c10::optional<at::ScalarType> dst_dtype)
{
    at::ScalarType requested = dst_dtype.has_value() ? dst_dtype.value() : at::ScalarType::Char;
    if (requested == at::kQInt8) {
        requested = at::kChar;
    }

    const bool is_int8 = requested == at::ScalarType::Char;
    const bool is_int4 = requested == at::ScalarType::QUInt4x2;
    TORCH_CHECK(is_int8 || is_int4,
                "dst_dtype must be Int8 or Int4" + OPS_ERROR(ErrCode::TYPE));

    return is_int4 ? at::ScalarType::Int : requested;
}

// Output dtype of y for npu_add_rms_norm_dynamic_quant.
//
// y_dtype defaults to Char (int8). Only Char and QUInt4x2 are accepted
// (TORCH_CHECK otherwise). QUInt4x2 is reported as at::ScalarType::Int.
at::ScalarType npu_add_rms_norm_dynamic_quant_y_dtype(c10::optional<at::ScalarType> y_dtype)
{
    const at::ScalarType requested = y_dtype.has_value() ? y_dtype.value() : at::ScalarType::Char;
    const bool is_int4 = requested == at::ScalarType::QUInt4x2;
    TORCH_CHECK(requested == at::ScalarType::Char || is_int4,
                "y_dtype must be torch.int8 or torch.quint4x2, but got ", requested, OPS_ERROR(ErrCode::PARAM));
    // aclnn int4 output uses DT_INT32 (8 int4 packed per int32)
    return is_int4 ? at::ScalarType::Int : requested;
}

// Result dtype of clamp(self, min, max) with tensor bounds.
//
// self is always promoted; each of max / min joins the promotion only when it
// is present (visited in the order max, then min).
// Raises via TORCH_CHECK when both min and max are absent.
at::ScalarType clamp_out_dtype(const at::Tensor& self, const c10::optional<at::Tensor>& min, const c10::optional<at::Tensor>& max)
{
    TORCH_CHECK(min.has_value() || max.has_value(),
                "torch.clamp:At least one of 'min' or 'max' must be not None!", OPS_ERROR(ErrCode::PARAM));

    at::native::ResultTypeState state = {};
    state = at::native::update_result_type_state(self, state);

    // Two independent checks replace the former three-way branch; the update
    // order (max before min) is unchanged.
    if (max.has_value()) {
        state = at::native::update_result_type_state(max.value(), state);
    }
    if (min.has_value()) {
        state = at::native::update_result_type_state(min.value(), state);
    }

    return at::native::result_type(state);
}

// Result dtype of clamp(self, min, max) with scalar bounds.
//
// self is always promoted; each of max / min joins the promotion only when it
// is present (visited in the order max, then min).
// Raises via TORCH_CHECK when both min and max are absent.
at::ScalarType clamp_scalar_out_dtype(const at::Tensor& self, const c10::optional<at::Scalar>& min, const c10::optional<at::Scalar>& max)
{
    TORCH_CHECK(min.has_value() || max.has_value(),
                "torch.clamp:At least one of 'min' or 'max' must be not None!", OPS_ERROR(ErrCode::PARAM));

    at::native::ResultTypeState state = {};
    state = at::native::update_result_type_state(self, state);

    // Two independent checks replace the former three-way branch; the update
    // order (max before min) is unchanged.
    if (max.has_value()) {
        state = at::native::update_result_type_state(max.value(), state);
    }
    if (min.has_value()) {
        state = at::native::update_result_type_state(min.value(), state);
    }

    return at::native::result_type(state);
}

// Output dtype of abs(self).
//
// Complex dtypes map to their real component dtype
// (ComplexFloat -> Float, ComplexDouble -> Double, ComplexHalf -> Half);
// every other dtype is returned unchanged.
at::ScalarType abs_out_dtype(const at::Tensor& self)
{
    const at::ScalarType input_dtype = self.scalar_type();
    switch (input_dtype) {
        case at::ScalarType::ComplexFloat:
            return at::ScalarType::Float;
        case at::ScalarType::ComplexDouble:
            return at::ScalarType::Double;
        case at::ScalarType::ComplexHalf:
            return at::ScalarType::Half;
        default:
            return input_dtype;
    }
}

// Output dtype for npu_moe_distribute_dispatch_setup.
//
// y_dtype is an integer dtype code. It is converted to an ACL data type with
// c10_npu::GetAclDataType, then back to an at::ScalarType with
// OpPreparation::convert_to_scalar_type.
// y_dtype is required: a missing value raises a TORCH_CHECK error with an
// error code rather than an unchecked optional access.
// NOTE(review): if callers expect an int8 default when y_dtype is omitted,
// replace this check with a fallback to at::kChar.
at::ScalarType npu_moe_distribute_dispatch_setup_out_dtype(c10::optional<int64_t> y_dtype)
{
    TORCH_CHECK(y_dtype.has_value(),
                "npu_moe_distribute_dispatch_setup: y_dtype must be provided", OPS_ERROR(ErrCode::PARAM));
    return at_npu::native::OpPreparation::convert_to_scalar_type(c10_npu::GetAclDataType(y_dtype.value()));
}

} // namespace op_infer