/**
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file dequantize.h
* \brief INT8/INT16 反量化 Tile 算子
*
* - INT8 -> FP32: int8 -> half -> float, 然后 (src - offset) * scale
* - INT16 -> FP32: int16 -> float, 然后 (src - offset) * scale
*
* 只支持逐行反量化 (axis=-1),逐列反量化在 Operation 层通过 Transpose 实现
*/
#ifndef TILEOP_TILE_OPERATOR_DEQUANTIZE__H
#define TILEOP_TILE_OPERATOR_DEQUANTIZE__H
#include "pto_tile.h"
#include "utils/layout.h"
#include "utils/tile_tensor.h"
namespace pto {
/**
 * @brief Selects the integer width of the quantized source tensor for TDequant.
 *
 * The enumerator values are spelled out explicitly so their numeric encoding stays fixed.
 */
enum class DequantType {
    INT8 = 0,   // 8-bit integer source (dispatches to TDequantInt8)
    INT16 = 1,  // 16-bit integer source (dispatches to TDequantInt16)
};
}  // namespace pto
#ifndef PTO_CEIL
#define PTO_CEIL(x, y) ((((x) + (y)-1) / (y)) * (y))
#endif
#define OP_TILE_OP_TDEQUANT_INT8 TDequantInt8
* @brief INT8 反量化 (逐行)
* @param dst 输出 FP32 张量, 形状 [..., H, W]
* @param src 输入 INT8 张量, 形状 [..., H, W]
* @param scale 缩放因子, 形状 [..., H]
* @param offset 零点偏移, 形状 [..., H] (对称量化时传全0)
*
* 公式: dst = (src - offset) * scale
* 转换路径: int8 -> half -> float
*/
template <typename T0, typename T1, typename T2, typename T3>
TILEOP void TDequantInt8(T0 dst, T1 src, T2 scale, T3 offset) {
constexpr size_t expectSize = 5;
const auto dstLayout = dst.GetLayout();
const auto srcLayout = src.GetLayout();
const auto scaleLayout = scale.GetLayout();
const auto offsetLayout = offset.GetLayout();
auto dstShape0 = dstLayout.template GetShapeDim<0, expectSize>();
auto dstShape1 = dstLayout.template GetShapeDim<1, expectSize>();
auto dstShape2 = dstLayout.template GetShapeDim<2, expectSize>();
auto dstShape3 = dstLayout.template GetShapeDim<3, expectSize>();
auto dstShape4 = dstLayout.template GetShapeDim<4, expectSize>();
if (dstShape3 == 0 || dstShape4 == 0) {
return;
}
auto srcShape3 = srcLayout.template GetShapeDim<3, expectSize>();
auto srcShape4 = srcLayout.template GetShapeDim<4, expectSize>();
auto dstStride0 = dstLayout.template GetStrideDim<0, expectSize>();
auto dstStride1 = dstLayout.template GetStrideDim<1, expectSize>();
auto dstStride2 = dstLayout.template GetStrideDim<2, expectSize>();
auto srcStride0 = srcLayout.template GetStrideDim<0, expectSize>();
auto srcStride1 = srcLayout.template GetStrideDim<1, expectSize>();
auto srcStride2 = srcLayout.template GetStrideDim<2, expectSize>();
auto scaleStride1 = scaleLayout.template GetStrideDim<1, expectSize>();
auto scaleStride2 = scaleLayout.template GetStrideDim<2, expectSize>();
auto scaleStride3 = scaleLayout.template GetStrideDim<3, expectSize>();
auto offsetStride1 = offsetLayout.template GetStrideDim<1, expectSize>();
auto offsetStride2 = offsetLayout.template GetStrideDim<2, expectSize>();
auto offsetStride3 = offsetLayout.template GetStrideDim<3, expectSize>();
constexpr auto dstTileH = TileOp::GetTensorTileShapeDim<T0, 3, expectSize>();
constexpr auto dstTileW = TileOp::GetTensorTileShapeDim<T0, 4, expectSize>();
constexpr int paddedCol_dst = PTO_CEIL(dstTileW, static_cast<int>(TILE_ALIGNMENT_BYTES / sizeof(float)));
constexpr auto srcTileH = TileOp::GetTensorTileShapeDim<T1, 3, expectSize>();
constexpr auto srcTileW = TileOp::GetTensorTileShapeDim<T1, 4, expectSize>();
constexpr int paddedCol_src = PTO_CEIL(srcTileW, static_cast<int>(TILE_ALIGNMENT_BYTES / sizeof(int8_t)));
constexpr auto scaleTileW = TileOp::GetTensorTileShapeDim<T2, 4, expectSize>();
constexpr int paddedRow_scale = PTO_CEIL(scaleTileW, static_cast<int>(TILE_ALIGNMENT_BYTES / sizeof(float)));
constexpr auto offsetTileW = TileOp::GetTensorTileShapeDim<T3, 4, expectSize>();
constexpr int paddedRow_offset = PTO_CEIL(offsetTileW, static_cast<int>(TILE_ALIGNMENT_BYTES / sizeof(float)));
using DstDtype = typename T0::Type;
using SrcDtype = typename T1::Type;
using ScaleDtype = typename T2::Type;
using OffsetDtype = typename T3::Type;
using DstTileDefine = pto::Tile<pto::TileType::Vec, DstDtype, dstTileH, paddedCol_dst,
pto::BLayout::RowMajor, -1, -1>;
using SrcTileDefine = pto::Tile<pto::TileType::Vec, SrcDtype, srcTileH, paddedCol_src,
pto::BLayout::RowMajor, -1, -1>;
using ScaleTileDefine = pto::Tile<pto::TileType::Vec, ScaleDtype, paddedRow_scale, 1,
pto::BLayout::ColMajor, -1, -1>;
using OffsetTileDefine = pto::Tile<pto::TileType::Vec, OffsetDtype, paddedRow_offset, 1,
pto::BLayout::ColMajor, -1, -1>;
for (LoopVar n0Index = 0; n0Index < dstShape0; ++n0Index) {
for (LoopVar n1Index = 0; n1Index < dstShape1; ++n1Index) {
for (LoopVar n2Index = 0; n2Index < dstShape2; ++n2Index) {
DstTileDefine dstTile(dstShape3, dstShape4);
SrcTileDefine srcTile(srcShape3, srcShape4);
ScaleTileDefine scaleTile(srcShape3, 1);
OffsetTileDefine offsetTile(srcShape3, 1);
auto dstOffset = n0Index * dstStride0 + n1Index * dstStride1 + n2Index * dstStride2;
auto srcOffset = n0Index * srcStride0 + n1Index * srcStride1 + n2Index * srcStride2;
auto scaleOffset = n0Index * scaleStride1 + n1Index * scaleStride2 + n2Index * scaleStride3;
auto offsetOffset = n0Index * offsetStride1 + n1Index * offsetStride2 + n2Index * offsetStride3;
pto::TASSIGN(dstTile, (uint64_t)(dst.GetAddr() + dstOffset * sizeof(DstDtype)));
pto::TASSIGN(srcTile, (uint64_t)(src.GetAddr() + srcOffset * sizeof(SrcDtype)));
pto::TASSIGN(scaleTile, (uint64_t)(scale.GetAddr() + scaleOffset * sizeof(ScaleDtype)));
pto::TASSIGN(offsetTile, (uint64_t)(offset.GetAddr() + offsetOffset * sizeof(OffsetDtype)));
pto::TDEQUANT(dstTile, srcTile, scaleTile, offsetTile);
}
}
}
}
#define OP_TILE_OP_TDEQUANT_INT16 TDequantInt16
* @brief INT16 反量化 (逐行)
* @param dst 输出 FP32 张量, 形状 [..., H, W]
* @param src 输入 INT16 张量, 形状 [..., H, W]
* @param scale 缩放因子, 形状 [..., H]
* @param offset 零点偏移, 形状 [..., H] (对称量化时传全0)
*
* 公式: dst = (src - offset) * scale
* 转换路径: int16 -> float
*/
template <typename T0, typename T1, typename T2, typename T3>
TILEOP void TDequantInt16(T0 dst, T1 src, T2 scale, T3 offset) {
constexpr size_t expectSize = 5;
const auto dstLayout = dst.GetLayout();
const auto srcLayout = src.GetLayout();
const auto scaleLayout = scale.GetLayout();
const auto offsetLayout = offset.GetLayout();
auto dstShape0 = dstLayout.template GetShapeDim<0, expectSize>();
auto dstShape1 = dstLayout.template GetShapeDim<1, expectSize>();
auto dstShape2 = dstLayout.template GetShapeDim<2, expectSize>();
auto dstShape3 = dstLayout.template GetShapeDim<3, expectSize>();
auto dstShape4 = dstLayout.template GetShapeDim<4, expectSize>();
if (dstShape3 == 0 || dstShape4 == 0) {
return;
}
auto srcShape3 = srcLayout.template GetShapeDim<3, expectSize>();
auto srcShape4 = srcLayout.template GetShapeDim<4, expectSize>();
auto dstStride0 = dstLayout.template GetStrideDim<0, expectSize>();
auto dstStride1 = dstLayout.template GetStrideDim<1, expectSize>();
auto dstStride2 = dstLayout.template GetStrideDim<2, expectSize>();
auto srcStride0 = srcLayout.template GetStrideDim<0, expectSize>();
auto srcStride1 = srcLayout.template GetStrideDim<1, expectSize>();
auto srcStride2 = srcLayout.template GetStrideDim<2, expectSize>();
auto scaleStride1 = scaleLayout.template GetStrideDim<1, expectSize>();
auto scaleStride2 = scaleLayout.template GetStrideDim<2, expectSize>();
auto scaleStride3 = scaleLayout.template GetStrideDim<3, expectSize>();
auto offsetStride1 = offsetLayout.template GetStrideDim<1, expectSize>();
auto offsetStride2 = offsetLayout.template GetStrideDim<2, expectSize>();
auto offsetStride3 = offsetLayout.template GetStrideDim<3, expectSize>();
constexpr auto dstTileH = TileOp::GetTensorTileShapeDim<T0, 3, expectSize>();
constexpr auto dstTileW = TileOp::GetTensorTileShapeDim<T0, 4, expectSize>();
constexpr int paddedCol_dst = PTO_CEIL(dstTileW, static_cast<int>(TILE_ALIGNMENT_BYTES / sizeof(float)));
constexpr auto srcTileH = TileOp::GetTensorTileShapeDim<T1, 3, expectSize>();
constexpr auto srcTileW = TileOp::GetTensorTileShapeDim<T1, 4, expectSize>();
constexpr int paddedCol_src = PTO_CEIL(srcTileW, static_cast<int>(TILE_ALIGNMENT_BYTES / sizeof(int16_t)));
constexpr auto scaleTileW = TileOp::GetTensorTileShapeDim<T2, 4, expectSize>();
constexpr int paddedRow_scale = PTO_CEIL(scaleTileW, static_cast<int>(TILE_ALIGNMENT_BYTES / sizeof(float)));
constexpr auto offsetTileW = TileOp::GetTensorTileShapeDim<T3, 4, expectSize>();
constexpr int paddedRow_offset = PTO_CEIL(offsetTileW, static_cast<int>(TILE_ALIGNMENT_BYTES / sizeof(float)));
using DstDtype = typename T0::Type;
using SrcDtype = typename T1::Type;
using ScaleDtype = typename T2::Type;
using OffsetDtype = typename T3::Type;
using DstTileDefine = pto::Tile<pto::TileType::Vec, DstDtype, dstTileH, paddedCol_dst,
pto::BLayout::RowMajor, -1, -1>;
using SrcTileDefine = pto::Tile<pto::TileType::Vec, SrcDtype, srcTileH, paddedCol_src,
pto::BLayout::RowMajor, -1, -1>;
using ScaleTileDefine = pto::Tile<pto::TileType::Vec, ScaleDtype, paddedRow_scale, 1,
pto::BLayout::ColMajor, -1, -1>;
using OffsetTileDefine = pto::Tile<pto::TileType::Vec, OffsetDtype, paddedRow_offset, 1,
pto::BLayout::ColMajor, -1, -1>;
for (LoopVar n0Index = 0; n0Index < dstShape0; ++n0Index) {
for (LoopVar n1Index = 0; n1Index < dstShape1; ++n1Index) {
for (LoopVar n2Index = 0; n2Index < dstShape2; ++n2Index) {
DstTileDefine dstTile(dstShape3, dstShape4);
SrcTileDefine srcTile(srcShape3, srcShape4);
ScaleTileDefine scaleTile(srcShape3, 1);
OffsetTileDefine offsetTile(srcShape3, 1);
auto dstOffset = n0Index * dstStride0 + n1Index * dstStride1 + n2Index * dstStride2;
auto srcOffset = n0Index * srcStride0 + n1Index * srcStride1 + n2Index * srcStride2;
auto scaleOffset = n0Index * scaleStride1 + n1Index * scaleStride2 + n2Index * scaleStride3;
auto offsetOffset = n0Index * offsetStride1 + n1Index * offsetStride2 + n2Index * offsetStride3;
pto::TASSIGN(dstTile, (uint64_t)(dst.GetAddr() + dstOffset * sizeof(DstDtype)));
pto::TASSIGN(srcTile, (uint64_t)(src.GetAddr() + srcOffset * sizeof(SrcDtype)));
pto::TASSIGN(scaleTile, (uint64_t)(scale.GetAddr() + scaleOffset * sizeof(ScaleDtype)));
pto::TASSIGN(offsetTile, (uint64_t)(offset.GetAddr() + offsetOffset * sizeof(OffsetDtype)));
pto::TDEQUANT(dstTile, srcTile, scaleTile, offsetTile);
}
}
}
}
#define OP_TILE_OP_TDEQUANT TDequant
* @brief 统一反量化接口
* @tparam dequantType INT8 或 INT16
*
* 注意: TDequant 总是需要 4 个参数 (dst, src, scale, offset)
* 对称量化时,offset 传全 0 的张量
*/
template <pto::DequantType dequantType, typename T0, typename T1, typename T2, typename T3>
TILEOP void TDequant(T0 dst, T1 src, T2 scale, T3 offset) {
if constexpr (dequantType == pto::DequantType::INT8) {
TDequantInt8(dst, src, scale, offset);
} else if constexpr (dequantType == pto::DequantType::INT16) {
TDequantInt16(dst, src, scale, offset);
} else {
static_assert(dequantType == pto::DequantType::INT8 || dequantType == pto::DequantType::INT16,
"TDequant only supports INT8 or INT16 type.");
}
}
#endif