#include <ATen/ops/_empty_affine_quantized.h>
#include <torch/library.h>
#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/utils/OpAdapter.h"
namespace acl_op {
namespace {
#if VERSION_BETWEEN(V2R2, VERSION_NEWEST)
/**
 * Element-wise addition of two per-tensor-affine QInt8 quantized tensors,
 * requantized to (output_scale, output_zero_point).
 *
 * The arithmetic is emulated in float on the device:
 *   a   = int_repr(qa) - zero_point(qa)
 *   b   = int_repr(qb) - zero_point(qb)
 *   sum = a + b * (scale(qb) / scale(qa))          // real sum expressed in units of scale(qa)
 *   out = round(clamp(sum * (scale(qa) / output_scale) + output_zero_point, -128, 127))
 * When kReluFused is true, ReLU is applied to `sum` before requantization.
 *
 * @tparam kReluFused  true for quantized::add_relu, false for quantized::add.
 * @param qa  first operand; must be QInt8 with per-tensor affine quantization.
 * @param qb  second operand; must match qa in shape, qscheme and dtype.
 * @param output_scale  quantization scale of the result; must be non-zero.
 * @param output_zero_point  zero point recorded on, and applied to, the result.
 * @return a QInt8 quantized tensor with qa's shape. Empty inputs give an empty
 *         quantized tensor carrying the requested quantization parameters.
 */
template <bool kReluFused = false>
at::Tensor add(at::Tensor qa, at::Tensor qb, double output_scale, int64_t output_zero_point)
{
    // Validate first so that empty inputs with inconsistent metadata are rejected as well.
    TORCH_CHECK(
        output_scale != 0, "output_scale can not be zero");
    TORCH_CHECK(
        qa.sizes() == qb.sizes(),
        "Quantized npu add currently expects both input tensors to be the same shape");
    TORCH_CHECK(
        qa.qscheme() == c10::kPerTensorAffine,
        "Only per tensor quantization is supported in Add.");
    TORCH_CHECK(
        qa.qscheme() == qb.qscheme(),
        "Both inputs to Add must have the same quantization scheme.");
    TORCH_CHECK(
        qa.scalar_type() == qb.scalar_type(),
        "Add operands should have same data type.");
    TORCH_CHECK(
        qa.scalar_type() == at::ScalarType::QInt8,
        "Add operands expect scalar type QInt8");

    const auto result_options = qa.options().dtype(at::ScalarType::QInt8);
    const auto memory_format = qa.suggest_memory_format();

    // Nothing to compute: return a defined (not undefined) empty quantized tensor.
    if (qa.numel() == 0) {
        return at::_empty_affine_quantized(qa.sizes(), result_options,
            output_scale, output_zero_point, memory_format);
    }

    const double scale_a = qa.q_scale();
    const double scale_b = qb.q_scale();

    // Remove each operand's zero point so that the values are proportional to the
    // real (dequantized) values: real = (q - zero_point) * scale.
    at::Tensor qa_float = qa.int_repr().to(at::kFloat) - static_cast<double>(qa.q_zero_point());
    at::Tensor qb_float = qb.int_repr().to(at::kFloat) - static_cast<double>(qb.q_zero_point());

    // Real sum divided by scale_a; the sign matches the real sum when scale_a is positive
    // (NOTE(review): assumes the quantizer guarantees scale > 0), so ReLU here is ReLU
    // on the real values.
    at::Tensor calc_tensor = qa_float + qb_float * (scale_b / scale_a);
    if (kReluFused) {
        calc_tensor = at::relu(calc_tensor);
    }

    // Requantize into the output's (scale, zero_point) and saturate to the int8 range.
    calc_tensor = calc_tensor * (scale_a / output_scale) + static_cast<double>(output_zero_point);
    calc_tensor = at::clamp(calc_tensor, -128.0, 127.0);
    at::Tensor tmp_result = calc_tensor.round().to(at::kChar);

    at::Tensor result = at::_empty_affine_quantized(qa.sizes(), result_options,
        output_scale, output_zero_point, memory_format);
    // Point the quantized result at the computed int8 storage.
    at_npu::native::NPUNativeFunctions::set_(result, tmp_result);
    return result;
}
// Bind the NPU kernels above to the `quantized` operator namespace for the
// QuantizedPrivateUse1 dispatch key:
//   quantized::add_relu -> add<true>  (ReLU fused into the sum)
//   quantized::add      -> add<false> (plain sum)
TORCH_LIBRARY_IMPL(quantized, QuantizedPrivateUse1, lib) {
    lib.impl("quantized::add_relu", TORCH_FN(add<true>));
    lib.impl("quantized::add", TORCH_FN(add<false>));
}
#endif
}
}