#include <cmath>
#include <c10/core/MemoryFormat.h>
#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/utils/OpAdapter.h"
using namespace std;
namespace {
// Lower bound applied to the quantization scale: choose_quantization_params
// raises any smaller scale to this value and widens the range accordingly.
constexpr float SMALL_SCALE_THRESHOLD = 6.1e-5f;

// Affine quantization parameters as produced by choose_quantization_params.
// Every member has a default so a default-constructed instance never holds an
// indeterminate value (`precision` is not assigned anywhere in this file).
struct TensorQuantizationParams {
    double scale{0.0};
    std::int32_t zero_point{0};
    int precision{0};
};
}
namespace op_api {
using npu_preparation = at_npu::native::OpPreparation;
// Folds the min/max observed in `x` into the running statistics, in place.
//
// x:                  tensor to observe.
// running_min/max:    running statistics; individual elements are overwritten.
// averaging_const:    weight of the new observation in the moving average.
// per_row_fake_quant: true  -> at::aminmax is taken over dimension 1 of `x` and
//                              elements [0, min(#rows, running_min.numel())) are updated;
//                     false -> at::aminmax is taken over all of `x` and only
//                              element 0 of each running tensor is updated.
// ch_axis:            only validated here (must be 0 in per-row mode).
//
// A running value that is infinite is replaced by the observation outright;
// otherwise it moves towards the observation by `averaging_const`.
void calculate_moving_average(const at::Tensor& x, at::Tensor & running_min, at::Tensor & running_max,
    float averaging_const, bool per_row_fake_quant, int ch_axis)
{
    // Single definition of the update rule shared by min and max, per-row and per-tensor.
    const auto blend = [averaging_const](float running, float observed) -> float {
        if (std::isinf(running)) {
            return observed;
        }
        return running + averaging_const * (observed - running);
    };

    at::Tensor x_min;
    at::Tensor x_max;
    if (per_row_fake_quant) {
        TORCH_CHECK(ch_axis == 0,
            "Per-channel FakeQuant in fused_moving_avg_obs_fake_quant is only supported on axis == 0",
            OPS_ERROR(ErrCode::PARAM));
        std::tie(x_min, x_max) = at::aminmax(x, 1);
        const auto row_count = std::min(x_min.numel(), running_min.numel());
        for (const auto i : c10::irange(row_count)) {
            // Every scalar is fetched with .item() exactly once and reused,
            // instead of re-reading the same running element for each use.
            const float prev_min = running_min[i].item().toFloat();
            const float prev_max = running_max[i].item().toFloat();
            running_min[i] = blend(prev_min, x_min[i].item().toFloat());
            running_max[i] = blend(prev_max, x_max[i].item().toFloat());
        }
    } else {
        std::tie(x_min, x_max) = at::aminmax(x);
        const float prev_min = running_min[0].item().toFloat();
        const float prev_max = running_max[0].item().toFloat();
        running_min[0] = blend(prev_min, x_min.item().toFloat());
        running_max[0] = blend(prev_max, x_max.item().toFloat());
    }
}
// Derives an affine quantization scale and zero point for the real range
// [min, max] mapped onto the integer range [qmin, qmax].
//
// min, max:                 observed real range; min must be <= max.
// qmin, qmax:               quantized range; must satisfy qmin < qmax (checked after
//                           the optional halving below).
// preserve_sparsity:        when the range straddles zero, make it symmetric and place
//                           the zero point at the midpoint of [qmin, qmax].
// force_scale_power_of_two: round the scale to a power of two.
// reduce_range:             halve qmin and qmax before anything else.
//
// Returns the chosen scale and the zero point clamped to [qmin, qmax]; `precision`
// is left at its zero-initialized value.
// Raises (via TORCH_CHECK) if min > max, qmin >= qmax or the scale is not positive.
TensorQuantizationParams choose_quantization_params(float min, float max, int32_t qmin, int32_t qmax,
    bool preserve_sparsity = false, bool force_scale_power_of_two = false, bool reduce_range = false)
{
    TORCH_CHECK(min <= max, "In choose_quantization_params, min should be less than or equal to max",
        OPS_ERROR(ErrCode::PARAM));
    if (reduce_range) {
        qmin = qmin / 2;
        qmax = qmax / 2;
    }
    if (min < 0 && max > 0 && preserve_sparsity) {
        const int symmetric_qmin = -((qmax - qmin) / 2 + 1);
        const int symmetric_qmax = (qmax - qmin) / 2;
        const double max_scale = std::max(std::fabs(min / symmetric_qmin), std::fabs(max / symmetric_qmax));
        min = max_scale * symmetric_qmin;
        max = max_scale * symmetric_qmax;
    }
    // The representable range must always contain zero.
    min = std::min(min, 0.f);
    max = std::max(max, 0.f);
    TORCH_CHECK(qmin < qmax, "In choose_quantization_params, qmin should be less than qmax",
        OPS_ERROR(ErrCode::PARAM));
    double scale = (static_cast<double>(max) - min) / (qmax - qmin);
    // A scale that is zero (or whose reciprocal overflows float) is replaced by a fixed fallback.
    if (float(scale) == 0.0f || std::isinf(1.0f / float(scale))) {
        scale = 0.1;
    }
    TORCH_CHECK(scale > 0, "quantization scale should be > 0", OPS_ERROR(ErrCode::VALUE));
    if (force_scale_power_of_two) {
        // std::ldexp yields exactly 2^k; an integer `1 << k` would be undefined
        // as soon as k reaches the bit width of int.
        if (scale < 1) {
            const int exponent = static_cast<int>(std::floor(std::log(1.0 / scale) / std::log(2.0)));
            scale = std::ldexp(1.0, -exponent);
        } else {
            const int exponent = static_cast<int>(std::ceil(std::log(scale) / std::log(2.0)));
            scale = std::ldexp(1.0, exponent);
        }
    }
    if (scale < SMALL_SCALE_THRESHOLD) {
        // Raise the scale to the threshold and stretch the range to match it.
        const float org_scale = scale;
        scale = SMALL_SCALE_THRESHOLD;
        if (min == 0.0f) {
            max = SMALL_SCALE_THRESHOLD * (qmax - qmin);
        } else if (max == 0.0f) {
            min = -SMALL_SCALE_THRESHOLD * (qmax - qmin);
        } else {
            const float amplifier = SMALL_SCALE_THRESHOLD / org_scale;
            min *= amplifier;
            max *= amplifier;
        }
    }
    // Two candidate zero points, one anchored at each end of the range; the one
    // with the smaller error term below is used.
    const double zero_point_from_min = qmin - min / static_cast<double>(scale);
    const double zero_point_from_max = qmax - max / static_cast<double>(scale);
    const double zero_point_from_min_error = std::abs(qmin) - std::abs(min / static_cast<double>(scale));
    const double zero_point_from_max_error = std::abs(qmax) - std::abs(max / static_cast<double>(scale));
    double initial_zero_point = zero_point_from_min_error < zero_point_from_max_error ?
        zero_point_from_min : zero_point_from_max;
    if (min < 0 && max > 0 && preserve_sparsity) {
        initial_zero_point = static_cast<double>(qmin + qmax) / 2;
    }
    // Nudge the zero point onto an integer inside [qmin, qmax].
    int32_t nudged_zero_point = 0;
    if (initial_zero_point < qmin) {
        nudged_zero_point = qmin;
    } else if (initial_zero_point > qmax) {
        nudged_zero_point = qmax;
    } else {
        nudged_zero_point = static_cast<int32_t>(std::nearbyint(initial_zero_point));
    }
    // Value-initialize so that `precision` is never returned indeterminate.
    TensorQuantizationParams result{};
    result.scale = scale;
    result.zero_point = nudged_zero_point;
    return result;
}
// Computes scale / zero point from the running min/max, writes them into `scale`
// and `zero_point`, and applies the matching fake-quantize op to `x`.
//
// per_row_fake_quant: true  -> one (scale, zero_point) pair per element of
//                              inp_running_min, then the per-channel fake-quantize
//                              op along `ch_axis`;
//                     false -> a single pair stored at index 0, then the per-tensor
//                              fake-quantize op with tensor qparams.
// symmetric_quant:    forwarded as `preserve_sparsity` to choose_quantization_params.
//
// Returns the tuple produced by the selected at:: "cachemask" fake-quantize op.
// The running statistics are read only through .item().toFloat(), so no raw
// float data pointer is taken from them.
std::tuple<at::Tensor, at::Tensor> choose_qparams_fake_quant(const at::Tensor& x,
    const at::Tensor& inp_running_min, const at::Tensor& inp_running_max, at::Tensor& scale, at::Tensor& zero_point,
    bool per_row_fake_quant, bool symmetric_quant, int qmin, int qmax, int ch_axis)
{
    if (per_row_fake_quant) {
        for (const auto i : c10::irange(inp_running_min.numel())) {
            const TensorQuantizationParams qparams = choose_quantization_params(
                inp_running_min[i].item().toFloat(), inp_running_max[i].item().toFloat(),
                qmin, qmax, symmetric_quant, false);
            scale[i] = qparams.scale;
            zero_point[i] = qparams.zero_point;
        }
        return at::fake_quantize_per_channel_affine_cachemask(x, scale, zero_point, ch_axis, qmin, qmax);
    }

    const TensorQuantizationParams qparams = choose_quantization_params(
        inp_running_min.item().toFloat(), inp_running_max.item().toFloat(),
        qmin, qmax, symmetric_quant, false);
    scale[0] = qparams.scale;
    zero_point[0] = qparams.zero_point;
    // The tensor-qparams variant takes the enable flag as a tensor; here it is always on.
    const auto fake_quant_enabled = at::ones(1, x.options().dtype(at::kLong));
    return at::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams(x, scale, zero_point,
        fake_quant_enabled, qmin, qmax);
}
// Shared implementation of the fused moving-average observer + fake-quantize op.
//
// Steps:
//   1. If `observer_on` is non-zero, update running_min / running_max from `self`.
//   2. If `fake_quant_on` is non-zero, derive scale / zero_point from the running
//      statistics and return the fake-quantize result for `self`.
//   3. Otherwise return a clone of `self` together with an all-true bool mask.
//
// running_min, running_max, scale and zero_point are modified in place: single-element
// running tensors are reshaped to 1-D, and in per-row mode empty running tensors are
// resized to self.size(ch_axis) and seeded with +inf / -inf.
//
// Raises (via TORCH_CHECK) if ch_axis >= self.dim(), or if ch_axis is negative on the
// per-row path that has to permute a non-2D input.
std::tuple<at::Tensor, at::Tensor> _fused_moving_avg_obs_fq_helper_common(
    const at::Tensor& self, const at::Tensor& observer_on, const at::Tensor& fake_quant_on,
    at::Tensor& running_min, at::Tensor& running_max, at::Tensor& scale, at::Tensor& zero_point,
    const double averaging_const, const int64_t quant_min, const int64_t quant_max, const int64_t ch_axis,
    bool per_row_fake_quant, bool symmetric_quant)
{
    TORCH_CHECK(ch_axis < self.dim(), "Error in fused_moving_avg_obs_fake_quant: ch_axis must be < self.dim()",
        OPS_ERROR(ErrCode::PARAM));
    const auto observe = observer_on.item().toInt();
    if (running_min.numel() == 1) {
        running_min = running_min.reshape(1);
        running_max = running_max.reshape(1);
    }

    // Tensor handed to the observer: `self` itself, or a 2-D view with the channel axis first.
    at::Tensor observed = self;
    if (per_row_fake_quant) {
        if (self.dim() != 2) {
            // ch_axis indexes the permutation vector directly, so a negative value
            // would access it out of bounds.
            TORCH_CHECK(ch_axis >= 0,
                "Error in fused_moving_avg_obs_fake_quant: ch_axis must be >= 0 for per-channel fake quant",
                OPS_ERROR(ErrCode::PARAM));
            // Identity permutation with axis 0 and ch_axis swapped.
            c10::SmallVector<int64_t, SIZE> res = op_infer::array_to_small_vector(self.sizes());
            std::iota(res.begin(), res.end(), 0);
            res[ch_axis] = 0;
            res[0] = ch_axis;
            observed = self.permute(res).flatten(1);
        }
        const int64_t size = self.size(ch_axis);
        if (running_min.numel() == 0) {
            // First use: allocate per-channel state; infinities mark "no observation yet".
            const float inf = std::numeric_limits<float>::infinity();
            running_min.resize_(size).fill_(inf);
            running_max.resize_(size).fill_(-inf);
            scale.resize_(size);
            zero_point.resize_(size);
        }
    }
    if (observe) {
        calculate_moving_average(observed, running_min, running_max, averaging_const, per_row_fake_quant, ch_axis);
    }

    const auto fake_quant = fake_quant_on.item().toInt();
    if (fake_quant) {
        return choose_qparams_fake_quant(self, running_min, running_max, scale, zero_point, per_row_fake_quant,
            symmetric_quant, quant_min, quant_max, ch_axis);
    }
    // Fake quantization disabled: pass the input through with an all-true mask.
    auto mask = at::ones_like(self, at::kBool, at::MemoryFormat::Preserve);
    return std::make_tuple(self.clone(), mask);
}
}