#include <ATen/native/TypeProperties.h>
#include "torch_npu/csrc/core/npu/NPUException.h"
#include "torch_npu/csrc/core/npu/NpuVariables.h"
#include "op_plugin/utils/KernelNpuOutputDtype.h"
#include "op_plugin/utils/op_api_common.h"
#include "op_plugin/OpApiInterface.h"
namespace op_infer {
// Output dtype of angle(self).
// Complex inputs map to the real dtype of the same precision
// (ComplexHalf -> Half, ComplexFloat -> Float, ComplexDouble -> Double).
// Integral inputs (bool included) produce Float, and any other dtype is returned unchanged.
at::ScalarType angle_out_dtype(const at::Tensor& self)
{
    const at::ScalarType self_dtype = self.scalar_type();
    if (self.is_complex()) {
        switch (self_dtype) {
            case at::ScalarType::ComplexHalf:
                // Keep half precision; ComplexHalf must not be widened to Double.
                return at::ScalarType::Half;
            case at::ScalarType::ComplexFloat:
                return at::ScalarType::Float;
            default:
                // The remaining complex dtype is ComplexDouble.
                return at::ScalarType::Double;
        }
    }
    // The second argument (true) makes bool count as integral.
    if (at::isIntegralType(self_dtype, true)) {
        return at::kFloat;
    }
    return self_dtype;
}
// Output dtype of polar(abs, angle).
// abs and angle are first promoted together with at::native::result_type.
// A Float, Double or Half result is then lifted to its complex counterpart;
// any other promoted dtype is returned as-is.
at::ScalarType polar_out_dtype(const at::Tensor& abs, const at::Tensor& angle)
{
    const at::ScalarType promoted = at::native::result_type(abs, angle);
    switch (promoted) {
        case at::ScalarType::Float:
            return at::ScalarType::ComplexFloat;
        case at::ScalarType::Double:
            return at::ScalarType::ComplexDouble;
        case at::ScalarType::Half:
            return at::ScalarType::ComplexHalf;
        default:
            return promoted;
    }
}
// Output dtype of npu_group_norm_silu.
// The result is always promoted from input. On SoC versions >= Ascend950, one affine
// parameter also takes part: weight when present, otherwise bias when present.
// Both are never folded in together.
at::ScalarType npu_group_norm_silu_dst_type(const at::Tensor& input, const c10::optional<at::Tensor>& weight,
    const c10::optional<at::Tensor>& bias)
{
    at::native::ResultTypeState promo_state =
        at::native::update_result_type_state(input, at::native::ResultTypeState{});
    const bool affine_participates = c10_npu::GetSocVersion() >= c10_npu::SocVersion::Ascend950;
    if (affine_participates) {
        // Weight takes precedence; bias is considered only when weight is absent.
        const c10::optional<at::Tensor>& affine = weight.has_value() ? weight : bias;
        if (affine.has_value()) {
            promo_state = at::native::update_result_type_state(affine.value(), promo_state);
        }
    }
    return at::native::result_type(promo_state);
}
// Output dtype of npu_group_quant.
// dst_dtype defaults to Char (int8), and QInt8 is treated as Char.
// Only Char or QUInt4x2 are accepted; anything else fails the TORCH_CHECK.
// A QUInt4x2 request is reported as Int.
// NOTE(review): presumably int4 results are packed into int32 storage — confirm against the kernel.
at::ScalarType npu_group_quant_dst_type(c10::optional<at::ScalarType> dst_dtype)
{
    at::ScalarType resolved = dst_dtype.has_value() ? dst_dtype.value() : at::ScalarType::Char;
    if (resolved == at::kQInt8) {
        resolved = at::kChar;
    }
    const bool is_int4 = (resolved == at::ScalarType::QUInt4x2);
    TORCH_CHECK(resolved == at::ScalarType::Char || is_int4,
        "dst_dtype must be Int8 or Int4" + OPS_ERROR(ErrCode::TYPE));
    return is_int4 ? at::ScalarType::Int : resolved;
}
// Output dtype for y of npu_add_rms_norm_dynamic_quant.
// y_dtype defaults to Char (int8), and only Char or QUInt4x2 are accepted.
// A QUInt4x2 request is reported as Int; Char is returned unchanged.
at::ScalarType npu_add_rms_norm_dynamic_quant_y_dtype(c10::optional<at::ScalarType> y_dtype)
{
    const at::ScalarType requested = y_dtype.value_or(at::ScalarType::Char);
    const bool is_int8 = requested == at::ScalarType::Char;
    const bool is_int4 = requested == at::ScalarType::QUInt4x2;
    TORCH_CHECK(is_int8 || is_int4,
        "y_dtype must be torch.int8 or torch.quint4x2, but got ", requested, OPS_ERROR(ErrCode::PARAM));
    return is_int4 ? at::ScalarType::Int : requested;
}
// Output dtype of clamp(self, min, max) with tensor bounds.
// At least one of min/max must be provided. The result is the promoted type of self
// and every bound that is present.
at::ScalarType clamp_out_dtype(const at::Tensor& self, const c10::optional<at::Tensor>& min, const c10::optional<at::Tensor>& max)
{
    TORCH_CHECK(min.has_value() || max.has_value(), "torch.clamp:At least one of 'min' or 'max' must be not None!",
        OPS_ERROR(ErrCode::PARAM));
    at::native::ResultTypeState state = {};
    state = at::native::update_result_type_state(self, state);
    // Fold in each bound that was supplied, visiting max before min.
    if (max.has_value()) {
        state = at::native::update_result_type_state(max.value(), state);
    }
    if (min.has_value()) {
        state = at::native::update_result_type_state(min.value(), state);
    }
    return at::native::result_type(state);
}
// Output dtype of clamp(self, min, max) with scalar bounds.
// At least one of min/max must be provided. The result is the promoted type of self
// and every scalar bound that is present.
at::ScalarType clamp_scalar_out_dtype(const at::Tensor& self, const c10::optional<at::Scalar>& min, const c10::optional<at::Scalar>& max)
{
    TORCH_CHECK(min.has_value() || max.has_value(), "torch.clamp:At least one of 'min' or 'max' must be not None!",
        OPS_ERROR(ErrCode::PARAM));
    at::native::ResultTypeState state = {};
    state = at::native::update_result_type_state(self, state);
    // Fold in each bound that was supplied, visiting max before min.
    if (max.has_value()) {
        state = at::native::update_result_type_state(max.value(), state);
    }
    if (min.has_value()) {
        state = at::native::update_result_type_state(min.value(), state);
    }
    return at::native::result_type(state);
}
// Output dtype of abs(self).
// Complex dtypes map to the real dtype of the same precision
// (ComplexFloat -> Float, ComplexDouble -> Double, ComplexHalf -> Half).
// Any other dtype is returned unchanged.
at::ScalarType abs_out_dtype(const at::Tensor& self)
{
    const at::ScalarType input_dtype = self.scalar_type();
    switch (input_dtype) {
        case at::ScalarType::ComplexFloat:
            return at::ScalarType::Float;
        case at::ScalarType::ComplexDouble:
            return at::ScalarType::Double;
        case at::ScalarType::ComplexHalf:
            return at::ScalarType::Half;
        default:
            return input_dtype;
    }
}
// Output dtype of npu_moe_distribute_dispatch_setup.
// When y_dtype is absent the result defaults to Char (int8).
// Otherwise y_dtype is converted through c10_npu::GetAclDataType and then
// OpPreparation::convert_to_scalar_type.
// NOTE(review): y_dtype is presumably a torch_npu dtype enum value accepted by GetAclDataType — confirm with callers.
at::ScalarType npu_moe_distribute_dispatch_setup_out_dtype(c10::optional<int64_t> y_dtype)
{
    // Without this guard, y_dtype.value() would throw on an empty optional and the
    // int8 default would never be used.
    if (!y_dtype.has_value()) {
        return at::kChar;
    }
    return at_npu::native::OpPreparation::convert_to_scalar_type(c10_npu::GetAclDataType(y_dtype.value()));
}
}