* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file quantize_impl.h
* \brief
*/
#if !defined(__ASCENDC_INCLUDE_INTERNAL_HEADERS__)
#pragma message( \
"impl/adv_api/detail/quantization/quantize/quantize_impl.h is an internal header file and must not be used directly. Functions or variables defined in this file may be removed in the future. Please use \"#include \"adv_api/quantization/quantize.h\"\" and use public functions or variables defined in interface headers files.")
#define __ASCENDC_INCLUDE_INTERNAL_HEADERS__
#define __UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_QUANTIZATION_QUANTIZE_QUANTIZE_IMPL_H__
#endif
#ifndef IMPL_QUANTIZATION_QUANTIZE_QUANTIZE_IMPL_H
#define IMPL_QUANTIZATION_QUANTIZE_QUANTIZE_IMPL_H
#include "kernel_basic_intf.h"
#include "kernel_tensor.h"
#include "kernel_tiling/kernel_tiling.h"
#include "include/adv_api/quantization/quantize_utils.h"
#include "../../common/check.h"
namespace AscendC {
namespace QuantizeUtils {
template <typename DstT, typename SrcT, typename ScaleT>
__aicore__ inline void CheckApiDtypeValid()
{
static_assert(
SupportType<
Tuple<SrcT, ScaleT, DstT>, Tuple<half, half, fp8_e4m3fn_t>, Tuple<half, half, fp8_e5m2_t>,
Tuple<bfloat16_t, bfloat16_t, fp8_e4m3fn_t>, Tuple<bfloat16_t, bfloat16_t, fp8_e5m2_t>,
Tuple<float, float, fp8_e4m3fn_t>, Tuple<float, float, fp8_e5m2_t>, Tuple<half, float, fp8_e4m3fn_t>,
Tuple<half, float, fp8_e5m2_t>, Tuple<bfloat16_t, float, fp8_e4m3fn_t>,
Tuple<bfloat16_t, float, fp8_e5m2_t>, Tuple<half, half, hifloat8_t>, Tuple<half, half, int8_t>,
Tuple<bfloat16_t, bfloat16_t, hifloat8_t>, Tuple<bfloat16_t, bfloat16_t, int8_t>,
Tuple<float, float, hifloat8_t>, Tuple<float, float, int8_t>, Tuple<half, float, hifloat8_t>,
Tuple<half, float, int8_t>, Tuple<bfloat16_t, float, hifloat8_t>, Tuple<bfloat16_t, float, int8_t>,
Tuple<half, half, fp4x2_e1m2_t>, Tuple<half, half, fp4x2_e2m1_t>,
Tuple<bfloat16_t, bfloat16_t, fp4x2_e1m2_t>, Tuple<bfloat16_t, bfloat16_t, fp4x2_e2m1_t>,
Tuple<float, float, fp4x2_e1m2_t>, Tuple<float, float, fp4x2_e2m1_t>, Tuple<half, float, fp4x2_e1m2_t>,
Tuple<half, float, fp4x2_e2m1_t>, Tuple<bfloat16_t, float, fp4x2_e1m2_t>,
Tuple<bfloat16_t, float, fp4x2_e2m1_t>>(),
"Failed to check data type for Quantize");
}
template <typename DstT, typename ScaleT>
__simd_callee__ inline void StoreRes(__ubuf__ DstT* dstAddr, Reg::RegTensor<DstT>& vreg, Reg::MaskReg& mask)
{
if (SupportType<ScaleT, float>()) {
Reg::StoreAlign<DstT, Reg::StoreDist::DIST_PACK4_B32>(dstAddr, vreg, mask);
} else {
Reg::StoreAlign<DstT, Reg::StoreDist::DIST_PACK_B16>(dstAddr, vreg, mask);
}
}
template <typename DstT, typename ScaleT, const Reg::CastTrait& castTrait>
__simd_callee__ inline void TransRegForFp8(
Reg::RegTensor<ScaleT>& srcVreg, Reg::RegTensor<DstT>& dstVreg, Reg::MaskReg& mask)
{
if constexpr (castTrait.roundMode == RoundMode::CAST_RINT) {
Reg::Cast<DstT, ScaleT, castTrait>(dstVreg, srcVreg, mask);
} else {
Reg::Cast<DstT, ScaleT, LayoutZMrgZRndRSatS>(dstVreg, srcVreg, mask);
}
}
template <typename DstT, typename ScaleT, const Reg::CastTrait& castTrait>
__simd_callee__ inline void TransRegForHif8(
Reg::RegTensor<ScaleT>& srcVreg, Reg::RegTensor<DstT>& dstVreg, Reg::MaskReg& mask)
{
if constexpr (SupportType<ScaleT, bfloat16_t>()) {
Reg::MaskReg mask1;
Reg::MaskReg mask2 = Reg::CreateMask<ScaleT, Reg::MaskPattern::ALLF>();
Reg::RegTensor<float> f32Vreg;
Reg::RegTensor<float> f32Vreg2;
Reg::RegTensor<ScaleT> srcVreg2;
Reg::RegTensor<DstT> dstVreg2;
Reg::MaskInterleave<ScaleT>(mask1, mask2, mask, mask2);
Reg::Interleave(srcVreg, srcVreg2, srcVreg, srcVreg2);
Reg::Cast<float, ScaleT, layoutZMrgZ>(f32Vreg, srcVreg, mask1);
Reg::Cast<float, ScaleT, layoutZMrgZ>(f32Vreg2, srcVreg2, mask2);
if constexpr (castTrait.roundMode == RoundMode::CAST_ROUND || castTrait.roundMode == RoundMode::CAST_HYBRID) {
Reg::Cast<DstT, float, castTrait>(dstVreg, f32Vreg, mask1);
Reg::Cast<DstT, float, castTrait>(dstVreg2, f32Vreg2, mask2);
} else {
Reg::Cast<DstT, float, LayoutZMrgZRndASatS>(dstVreg, f32Vreg, mask1);
Reg::Cast<DstT, float, LayoutZMrgZRndASatS>(dstVreg2, f32Vreg2, mask2);
}
Reg::DeInterleave(
(Reg::RegTensor<ScaleT>&)dstVreg, (Reg::RegTensor<ScaleT>&)dstVreg2, (Reg::RegTensor<ScaleT>&)dstVreg,
(Reg::RegTensor<ScaleT>&)dstVreg2);
} else if constexpr (SupportType<ScaleT, half, float>()) {
if constexpr (castTrait.roundMode == RoundMode::CAST_ROUND || castTrait.roundMode == RoundMode::CAST_HYBRID) {
Reg::Cast<DstT, ScaleT, castTrait>(dstVreg, srcVreg, mask);
} else {
Reg::Cast<DstT, ScaleT, LayoutZMrgZRndASatS>(dstVreg, srcVreg, mask);
}
}
}
template <typename DstT, typename ScaleT, const Reg::CastTrait& castTrait>
__simd_callee__ inline void TransRegForS8(
Reg::RegTensor<ScaleT>& srcVreg, Reg::RegTensor<DstT>& dstVreg, Reg::MaskReg& mask)
{
if constexpr (SupportType<ScaleT, bfloat16_t>()) {
Reg::MaskReg mask1;
Reg::MaskReg mask2 = Reg::CreateMask<ScaleT, Reg::MaskPattern::ALLF>();
Reg::RegTensor<float> f32Vreg;
Reg::RegTensor<float> f32Vreg2;
Reg::RegTensor<ScaleT> srcVreg2;
Reg::RegTensor<DstT> dstVreg2;
Reg::MaskInterleave<ScaleT>(mask1, mask2, mask, mask2);
Reg::Interleave(srcVreg, srcVreg2, srcVreg, srcVreg2);
Reg::Cast<float, ScaleT, layoutZMrgZ>(f32Vreg, srcVreg, mask1);
Reg::Cast<float, ScaleT, layoutZMrgZ>(f32Vreg2, srcVreg2, mask2);
if constexpr (
castTrait.roundMode == RoundMode::CAST_RINT || castTrait.roundMode == RoundMode::CAST_ROUND ||
castTrait.roundMode == RoundMode::CAST_CEIL || castTrait.roundMode == RoundMode::CAST_FLOOR ||
castTrait.roundMode == RoundMode::CAST_TRUNC) {
Reg::Cast<int16_t, float, castTrait>((Reg::RegTensor<int16_t>&)f32Vreg, f32Vreg, mask1);
Reg::Cast<int16_t, float, castTrait>((Reg::RegTensor<int16_t>&)f32Vreg2, f32Vreg2, mask2);
} else {
Reg::Cast<int16_t, float, LayoutZMrgZRndRSatS>((Reg::RegTensor<int16_t>&)f32Vreg, f32Vreg, mask1);
Reg::Cast<int16_t, float, LayoutZMrgZRndRSatS>((Reg::RegTensor<int16_t>&)f32Vreg2, f32Vreg2, mask2);
}
Reg::Cast<half, int16_t, LayoutZMrgZRndRSatS>(
(Reg::RegTensor<half>&)f32Vreg, (Reg::RegTensor<int16_t>&)f32Vreg, mask1);
Reg::Cast<half, int16_t, LayoutZMrgZRndRSatS>(
(Reg::RegTensor<half>&)f32Vreg2, (Reg::RegTensor<int16_t>&)f32Vreg2, mask2);
Reg::Cast<DstT, half, LayoutZMrgZRndRSatS>(dstVreg, (Reg::RegTensor<half>&)f32Vreg, mask1);
Reg::Cast<DstT, half, LayoutZMrgZRndRSatS>(dstVreg2, (Reg::RegTensor<half>&)f32Vreg2, mask2);
Reg::DeInterleave(
(Reg::RegTensor<ScaleT>&)dstVreg, (Reg::RegTensor<ScaleT>&)dstVreg2, (Reg::RegTensor<ScaleT>&)dstVreg,
(Reg::RegTensor<ScaleT>&)dstVreg2);
} else if constexpr (SupportType<ScaleT, float>()) {
Reg::RegTensor<half> f16Vreg;
if constexpr (
castTrait.roundMode == RoundMode::CAST_RINT || castTrait.roundMode == RoundMode::CAST_ROUND ||
castTrait.roundMode == RoundMode::CAST_CEIL || castTrait.roundMode == RoundMode::CAST_FLOOR ||
castTrait.roundMode == RoundMode::CAST_TRUNC) {
Reg::Cast<int16_t, ScaleT, castTrait>((Reg::RegTensor<int16_t>&)f16Vreg, srcVreg, mask);
} else {
Reg::Cast<int16_t, ScaleT, LayoutZMrgZRndRSatS>((Reg::RegTensor<int16_t>&)f16Vreg, srcVreg, mask);
}
Reg::Cast<half, int16_t, LayoutZMrgZRndRSatS>(f16Vreg, (Reg::RegTensor<int16_t>&)f16Vreg, mask);
Reg::Cast<DstT, half, LayoutZMrgZRndRSatS>(dstVreg, f16Vreg, mask);
} else if constexpr (SupportType<ScaleT, half>()) {
if constexpr (
castTrait.roundMode == RoundMode::CAST_RINT || castTrait.roundMode == RoundMode::CAST_ROUND ||
castTrait.roundMode == RoundMode::CAST_CEIL || castTrait.roundMode == RoundMode::CAST_FLOOR ||
castTrait.roundMode == RoundMode::CAST_TRUNC) {
Reg::Cast<DstT, ScaleT, castTrait>(dstVreg, srcVreg, mask);
} else {
Reg::Cast<DstT, ScaleT, LayoutZMrgZRndRSatS>(dstVreg, srcVreg, mask);
}
}
}
template <typename DstT, typename ScaleT, const Reg::CastTrait& castTrait>
__simd_callee__ inline void TransRegForFp4(
Reg::RegTensor<ScaleT>& vreg, Reg::RegTensor<DstT>& dstVreg, Reg::MaskReg& mask)
{
Reg::RegTensor<bfloat16_t> bf16Vreg;
if constexpr (SupportType<ScaleT, float>()) {
Reg::Cast<bfloat16_t, ScaleT, LayoutZMrgZRndRSatS>(bf16Vreg, vreg, mask);
Reg::Pack<uint16_t, uint32_t, Reg::HighLowPart::LOWEST>(
(Reg::RegTensor<uint16_t>&)bf16Vreg, (Reg::RegTensor<uint32_t>&)bf16Vreg);
Reg::MaskPack(mask, mask);
if constexpr (
castTrait.roundMode == RoundMode::CAST_RINT || castTrait.roundMode == RoundMode::CAST_ROUND ||
castTrait.roundMode == RoundMode::CAST_CEIL || castTrait.roundMode == RoundMode::CAST_FLOOR ||
castTrait.roundMode == RoundMode::CAST_TRUNC) {
Reg::Cast<DstT, bfloat16_t, castTrait>(dstVreg, bf16Vreg, mask);
} else {
Reg::Cast<DstT, bfloat16_t, LayoutZMrgZRndRSatS>(dstVreg, bf16Vreg, mask);
}
} else if constexpr (SupportType<ScaleT, half>()) {
Reg::Cast<bfloat16_t, ScaleT, LayoutZMrgZRndRSatS>(bf16Vreg, vreg, mask);
if constexpr (
castTrait.roundMode == RoundMode::CAST_RINT || castTrait.roundMode == RoundMode::CAST_ROUND ||
castTrait.roundMode == RoundMode::CAST_CEIL || castTrait.roundMode == RoundMode::CAST_FLOOR ||
castTrait.roundMode == RoundMode::CAST_TRUNC) {
Reg::Cast<DstT, bfloat16_t, castTrait>(dstVreg, bf16Vreg, mask);
} else {
Reg::Cast<DstT, bfloat16_t, LayoutZMrgZRndRSatS>(dstVreg, bf16Vreg, mask);
}
} else if constexpr (SupportType<ScaleT, bfloat16_t>()) {
if constexpr (
castTrait.roundMode == RoundMode::CAST_RINT || castTrait.roundMode == RoundMode::CAST_ROUND ||
castTrait.roundMode == RoundMode::CAST_CEIL || castTrait.roundMode == RoundMode::CAST_FLOOR ||
castTrait.roundMode == RoundMode::CAST_TRUNC) {
Reg::Cast<DstT, bfloat16_t, castTrait>(dstVreg, vreg, mask);
} else {
Reg::Cast<DstT, bfloat16_t, LayoutZMrgZRndRSatS>(dstVreg, vreg, mask);
}
}
}
template <typename T>
__simd_callee__ inline void GetPerGroupScale(
__ubuf__ T* scaleUb, const int32_t start, const uint32_t groupSize, Reg::RegTensor<T>& scaleVreg)
{
if constexpr (SupportType<T, half, bfloat16_t>()) {
Reg::MaskReg fullMask = Reg::CreateMask<uint16_t, Reg::MaskPattern::ALL>();
Reg::RegTensor<int16_t> vciVreg;
Reg::RegTensor<uint16_t> indexVreg;
Reg::RegTensor<uint16_t> gsizeVreg;
Reg::Duplicate(gsizeVreg, static_cast<uint16_t>(groupSize));
Reg::Arange(vciVreg, static_cast<int16_t>(start));
Reg::Div(indexVreg, (Reg::RegTensor<uint16_t>&)vciVreg, gsizeVreg, fullMask);
Reg::Gather(scaleVreg, scaleUb, indexVreg, fullMask);
} else {
Reg::MaskReg fullMask = Reg::CreateMask<uint32_t, Reg::MaskPattern::ALL>();
Reg::RegTensor<int32_t> vciVreg;
Reg::RegTensor<uint32_t> indexVreg;
Reg::RegTensor<uint32_t> gsizeVreg;
Reg::Duplicate(gsizeVreg, static_cast<uint32_t>(groupSize));
Reg::Arange(vciVreg, static_cast<int32_t>(start));
Reg::Div(indexVreg, (Reg::RegTensor<uint32_t>&)vciVreg, gsizeVreg, fullMask);
Reg::Gather(scaleVreg, scaleUb, indexVreg, fullMask);
}
}
template <typename T>
__simd_callee__ inline void GetPerGroupScaleToFloat(
__ubuf__ T* scaleAddr, const int32_t start, const uint32_t groupSize, Reg::RegTensor<float>& floatVreg,
Reg::RegTensor<T>& tempVreg, Reg::MaskReg& mask)
{
if constexpr (SupportType<T, half, bfloat16_t>()) {
GetPerGroupScale(scaleAddr, start, groupSize, tempVreg);
Reg::UnPack<uint32_t, uint16_t>((Reg::RegTensor<uint32_t>&)tempVreg, (Reg::RegTensor<uint16_t>&)tempVreg);
Reg::Cast<float, T, layoutZMrgZ>(floatVreg, tempVreg, mask);
} else {
GetPerGroupScale(scaleAddr, start, groupSize, floatVreg);
}
}
template <typename T>
__simd_callee__ inline void DuplicateScalarToFloatVector(
Reg::RegTensor<float>& floatVreg, const T& scalar, Reg::RegTensor<T>& tempVreg, Reg::MaskReg& mask)
{
if constexpr (SupportType<T, half, bfloat16_t>()) {
Reg::Duplicate(tempVreg, scalar);
Reg::Cast<float, T, layoutZMrgZ>(floatVreg, tempVreg, mask);
} else {
Reg::Duplicate(floatVreg, scalar);
}
}
template <typename DstT, const Reg::CastTrait& castTrait>
__simd_callee__ inline void CastFp32DstToExpect(
Reg::RegTensor<float>& srcVreg, Reg::RegTensor<DstT>& dstVreg, Reg::MaskReg& mask)
{
if constexpr (SupportType<DstT, fp8_e4m3fn_t, fp8_e5m2_t>()) {
QuantizeUtils::TransRegForFp8<DstT, float, castTrait>(srcVreg, dstVreg, mask);
} else if constexpr (IsSameType<DstT, hifloat8_t>::value) {
QuantizeUtils::TransRegForHif8<DstT, float, castTrait>(srcVreg, dstVreg, mask);
} else {
QuantizeUtils::TransRegForS8<DstT, float, castTrait>(srcVreg, dstVreg, mask);
}
}
}
template <
const QuantizeConfig& config, typename DstT, typename SrcT, typename ScaleT, typename OffsetT,
typename ActualScaleT>
__simd_vf__ inline void QuantizePerGroupForKColCommonVF(
__ubuf__ DstT* dstUb, __ubuf__ SrcT* srcUb, __ubuf__ ActualScaleT* scaleUb, __ubuf__ ActualScaleT* offsetUb,
const OffsetT offset, const QuantizeParams params)
{
constexpr bool isScalarOffset = TypeUtils::IsInnerDefaultType<OffsetT>();
uint16_t rowNum = params.m;
uint32_t vecLen = GetVecLen() / sizeof(float);
uint16_t repeat = CeilDivision(params.n, vecLen);
uint16_t scaleK = CeilDivision(params.n, params.groupSize);
static constexpr Reg::CastTrait castTrait = {
Reg::RegLayout::ZERO, Reg::SatMode::SAT, Reg::MaskMergeMode::ZEROING, config.roundMode};
Reg::MaskReg mask;
Reg::RegTensor<float> srcVreg;
Reg::RegTensor<DstT> dstVreg;
Reg::RegTensor<float> scaleVreg;
Reg::RegTensor<float> offsetVreg;
Reg::RegTensor<ActualScaleT> tempScaleVreg;
Reg::RegTensor<ActualScaleT> tempOffsetVreg;
Reg::RegTensor<SrcT> tempSrcVreg;
Reg::MaskReg fullMask = Reg::CreateMask<uint32_t, Reg::MaskPattern::ALL>();
if constexpr (config.hasOffset && isScalarOffset) {
QuantizeUtils::DuplicateScalarToFloatVector<ActualScaleT>(offsetVreg, offset, tempOffsetVreg, fullMask);
}
for (uint16_t i = 0; i < rowNum; ++i) {
uint32_t sreg = params.n;
for (uint16_t j = 0; j < repeat; ++j) {
mask = Reg::UpdateMask<float>(sreg);
QuantizeUtils::GetPerGroupScaleToFloat(
scaleUb + i * scaleK, j * vecLen, params.groupSize, scaleVreg, tempScaleVreg, fullMask);
if constexpr (config.hasOffset && !isScalarOffset) {
QuantizeUtils::GetPerGroupScaleToFloat(
offsetUb + i * scaleK, j * vecLen, params.groupSize, offsetVreg, tempOffsetVreg, fullMask);
}
if constexpr (SupportType<SrcT, half, bfloat16_t>()) {
Reg::LoadAlign<SrcT, Reg::LoadDist::DIST_UNPACK_B16>(tempSrcVreg, srcUb + i * params.n + j * vecLen);
Reg::Cast<float, SrcT, layoutZMrgZ>(srcVreg, tempSrcVreg, mask);
} else {
Reg::LoadAlign<SrcT, Reg::LoadDist::DIST_NORM>(srcVreg, srcUb + i * params.n + j * vecLen);
}
Reg::Mul(srcVreg, srcVreg, scaleVreg, mask);
if constexpr (config.hasOffset) {
Reg::Add(srcVreg, srcVreg, offsetVreg, mask);
}
QuantizeUtils::CastFp32DstToExpect<DstT, castTrait>(srcVreg, dstVreg, mask);
QuantizeUtils::StoreRes<DstT, float>(dstUb + i * params.n + j * vecLen, dstVreg, mask);
}
}
}
template <const QuantizeConfig& config, typename DstT, typename SrcT, typename ScaleT, typename OffsetT>
__aicore__ inline void QuantizePerGroupForKColCommon(
const LocalTensor<DstT>& dstTensor, const LocalTensor<SrcT>& srcTensor, const ScaleT& scale, const OffsetT& offset,
const QuantizeParams& params)
{
using ActualScaleT = typename ScaleT::PrimType;
constexpr bool isScalarOffset = TypeUtils::IsInnerDefaultType<OffsetT>();
__ubuf__ DstT* dstUb = (__ubuf__ DstT*)dstTensor.GetPhyAddr();
__ubuf__ SrcT* srcUb = (__ubuf__ SrcT*)srcTensor.GetPhyAddr();
__ubuf__ ActualScaleT* scaleUb = (__ubuf__ ActualScaleT*)scale.GetPhyAddr();
__ubuf__ ActualScaleT* offsetUb = nullptr;
if constexpr (!isScalarOffset) {
offsetUb = (__ubuf__ ActualScaleT*)offset.GetPhyAddr();
}
QuantizePerGroupForKColCommonVF<config, DstT, SrcT, ScaleT, OffsetT, ActualScaleT>(
dstUb, srcUb, scaleUb, offsetUb, offset, params);
}
template <const QuantizeConfig& config, typename DstT, typename SrcT, typename ScaleT, typename ActualScaleT>
__simd_vf__ inline void QuantizePerGroupForKColFp4VF(
__ubuf__ DstT* dstUb, __ubuf__ SrcT* srcUb, __ubuf__ ActualScaleT* scaleUb, const QuantizeParams params)
{
uint16_t rowNum = params.m;
uint32_t vecLen = GetVecLen() / sizeof(float);
uint16_t repeat = CeilDivision(params.n, vecLen);
uint16_t scaleK = CeilDivision(params.n, params.groupSize);
static constexpr Reg::CastTrait castTrait = {
Reg::RegLayout::ZERO, Reg::SatMode::SAT, Reg::MaskMergeMode::ZEROING, config.roundMode};
Reg::MaskReg mask;
Reg::RegTensor<float> srcVreg;
Reg::RegTensor<DstT> dstVreg;
Reg::RegTensor<float> scaleVreg;
Reg::RegTensor<SrcT> tempVreg;
Reg::RegTensor<ActualScaleT> tempScaleVreg;
Reg::MaskReg fullMask = Reg::CreateMask<uint32_t, Reg::MaskPattern::ALL>();
for (uint16_t i = 0; i < rowNum; ++i) {
uint32_t sreg = params.n;
for (uint16_t j = 0; j < repeat; ++j) {
mask = Reg::UpdateMask<float>(sreg);
QuantizeUtils::GetPerGroupScaleToFloat(
scaleUb + i * scaleK, j * vecLen, params.groupSize, scaleVreg, tempScaleVreg, fullMask);
if constexpr (SupportType<SrcT, half, bfloat16_t>()) {
Reg::LoadAlign<SrcT, Reg::LoadDist::DIST_UNPACK_B16>(tempVreg, srcUb + i * params.n + j * vecLen);
Reg::Cast<float, SrcT, layoutZMrgZ>(srcVreg, tempVreg, mask);
} else {
Reg::LoadAlign<SrcT, Reg::LoadDist::DIST_NORM>(srcVreg, srcUb + i * params.n + j * vecLen);
}
Reg::Mul(srcVreg, srcVreg, scaleVreg, mask);
QuantizeUtils::TransRegForFp4<DstT, float, castTrait>(srcVreg, dstVreg, mask);
Reg::StoreAlign<uint8_t, Reg::StoreDist::DIST_PACK4_B32>(
(__ubuf__ uint8_t*)dstUb + (i * params.n + j * vecLen) / 2, (Reg::RegTensor<uint8_t>&)dstVreg, mask);
}
}
}
template <const QuantizeConfig& config, typename DstT, typename SrcT, typename ScaleT>
__aicore__ inline void QuantizePerGroupForKColFp4(
const LocalTensor<DstT>& dstTensor, const LocalTensor<SrcT>& srcTensor, const ScaleT& scale,
const QuantizeParams& params)
{
using ActualScaleT = typename ScaleT::PrimType;
__ubuf__ DstT* dstUb = (__ubuf__ DstT*)dstTensor.GetPhyAddr();
__ubuf__ SrcT* srcUb = (__ubuf__ SrcT*)srcTensor.GetPhyAddr();
__ubuf__ ActualScaleT* scaleUb = (__ubuf__ ActualScaleT*)scale.GetPhyAddr();
QuantizePerGroupForKColFp4VF<config, DstT, SrcT, ScaleT, ActualScaleT>(dstUb, srcUb, scaleUb, params);
}
template <const QuantizeConfig& config, typename DstT, typename SrcT, typename ScaleT, typename OffsetT>
__aicore__ inline void QuantizePerGroupForKCol(
const LocalTensor<DstT>& dstTensor, const LocalTensor<SrcT>& srcTensor, const ScaleT& scale, const OffsetT& offset,
const QuantizeParams& params)
{
static_assert(TypeUtils::IsLocalTensorType<ScaleT>(), "Quantize PerGroup ScaleT should be Tensor");
using ActualScaleT = typename ScaleT::PrimType;
constexpr bool isScalarOffset = TypeUtils::IsInnerDefaultType<OffsetT>();
if constexpr (isScalarOffset) {
static_assert(IsSameType<ActualScaleT, OffsetT>::value, "scale and offset should be the same PrimType");
} else {
using ActualOffsetT = typename OffsetT::PrimType;
static_assert(IsSameType<ActualScaleT, ActualOffsetT>::value, "scale and offset should be the same PrimType");
}
QuantizeUtils::CheckApiDtypeValid<DstT, SrcT, ActualScaleT>();
if constexpr (SupportType<DstT, fp8_e4m3fn_t, fp8_e5m2_t, hifloat8_t, int8_t>()) {
QuantizePerGroupForKColCommon<config, DstT, SrcT, ScaleT, OffsetT>(dstTensor, srcTensor, scale, offset, params);
} else if constexpr (SupportType<DstT, fp4x2_e2m1_t, fp4x2_e1m2_t>()) {
QuantizePerGroupForKColFp4<config, DstT, SrcT, ScaleT>(dstTensor, srcTensor, scale, params);
} else {
ASCENDC_ASSERT((false), { KERNEL_LOG(KERNEL_ERROR, "unsupport dstT for Quantize!"); });
}
}
template <
typename DstT, typename SrcT, typename ActualScaleT, bool hasOffset, bool isScalarOffset,
const Reg::CastTrait& castTrait>
__simd_callee__ inline void QuantizePerGroupForKRowCommonOneRow(
__ubuf__ DstT* dstAddr, __ubuf__ SrcT* srcAddr, __ubuf__ ActualScaleT* scaleAddr, __ubuf__ ActualScaleT* offsetAddr,
Reg::RegTensor<DstT>& dstVreg, Reg::RegTensor<float>& srcVreg, Reg::RegTensor<float>& scaleVreg,
Reg::RegTensor<float>& offsetVreg, uint16_t repeat, uint32_t n, uint32_t vecLen)
{
Reg::RegTensor<ActualScaleT> tempScaleVreg;
Reg::RegTensor<ActualScaleT> tempOffsetVreg;
Reg::RegTensor<SrcT> tempSrcVreg;
Reg::MaskReg fullMask = Reg::CreateMask<uint32_t, Reg::MaskPattern::ALL>();
Reg::MaskReg mask;
for (uint16_t j = 0; j < repeat; ++j) {
if constexpr (SupportType<ActualScaleT, half, bfloat16_t>()) {
Reg::LoadAlign<ActualScaleT, Reg::LoadDist::DIST_UNPACK_B16>(tempScaleVreg, scaleAddr + j * vecLen);
Reg::Cast<float, ActualScaleT, layoutZMrgZ>(scaleVreg, tempScaleVreg, fullMask);
if constexpr (hasOffset && !isScalarOffset) {
Reg::LoadAlign<ActualScaleT, Reg::LoadDist::DIST_UNPACK_B16>(tempOffsetVreg, offsetAddr + j * vecLen);
Reg::Cast<float, ActualScaleT, layoutZMrgZ>(offsetVreg, tempOffsetVreg, fullMask);
}
} else {
Reg::LoadAlign<ActualScaleT, Reg::LoadDist::DIST_NORM>(scaleVreg, scaleAddr + j * vecLen);
if constexpr (hasOffset && !isScalarOffset) {
Reg::LoadAlign<ActualScaleT, Reg::LoadDist::DIST_NORM>(offsetVreg, offsetAddr + j * vecLen);
}
}
mask = Reg::UpdateMask<float>(n);
if constexpr (SupportType<SrcT, half, bfloat16_t>()) {
Reg::LoadAlign<SrcT, Reg::LoadDist::DIST_UNPACK_B16>(tempSrcVreg, srcAddr + j * vecLen);
Reg::Cast<float, SrcT, layoutZMrgZ>(srcVreg, tempSrcVreg, mask);
} else {
Reg::LoadAlign<SrcT, Reg::LoadDist::DIST_NORM>(srcVreg, srcAddr + j * vecLen);
}
Reg::Mul(srcVreg, srcVreg, scaleVreg, mask);
if constexpr (hasOffset) {
Reg::Add(srcVreg, srcVreg, offsetVreg, mask);
}
QuantizeUtils::CastFp32DstToExpect<DstT, castTrait>(srcVreg, dstVreg, mask);
QuantizeUtils::StoreRes<DstT, float>(dstAddr + j * vecLen, dstVreg, mask);
}
}
template <
const QuantizeConfig& config, typename DstT, typename SrcT, typename ScaleT, typename OffsetT,
typename ActualScaleT>
__simd_vf__ inline void QuantizePerGroupForKRowCommonVF(
__ubuf__ DstT* dstUb, __ubuf__ SrcT* srcUb, __ubuf__ ActualScaleT* scaleUb, __ubuf__ ActualScaleT* offsetUb,
const OffsetT offset, const QuantizeParams params, uint16_t rowNum, uint16_t tailRow)
{
constexpr bool isScalarOffset = TypeUtils::IsInnerDefaultType<OffsetT>();
uint16_t mainRowGroup = rowNum / params.groupSize;
uint32_t vecLen = GetVecLen() / sizeof(float);
uint16_t repeat = CeilDivision(params.n, vecLen);
static constexpr Reg::CastTrait castTrait = {
Reg::RegLayout::ZERO, Reg::SatMode::SAT, Reg::MaskMergeMode::ZEROING, config.roundMode};
Reg::RegTensor<float> srcVreg;
Reg::RegTensor<DstT> dstVreg;
Reg::RegTensor<float> scaleVreg;
Reg::RegTensor<float> offsetVreg;
Reg::RegTensor<ActualScaleT> tempOffsetVreg;
Reg::MaskReg fullMask = Reg::CreateMask<uint32_t, Reg::MaskPattern::ALL>();
if constexpr (config.hasOffset && isScalarOffset) {
QuantizeUtils::DuplicateScalarToFloatVector<ActualScaleT>(offsetVreg, offset, tempOffsetVreg, fullMask);
}
for (uint16_t i0 = 0; i0 < mainRowGroup; ++i0) {
for (uint16_t i1 = 0; i1 < static_cast<uint16_t>(params.groupSize); ++i1) {
if constexpr (isScalarOffset) {
QuantizePerGroupForKRowCommonOneRow<
DstT, SrcT, ActualScaleT, config.hasOffset, isScalarOffset, castTrait>(
dstUb + (i0 * params.groupSize + i1) * params.n, srcUb + (i0 * params.groupSize + i1) * params.n,
scaleUb + i0 * params.n, nullptr, dstVreg, srcVreg, scaleVreg, offsetVreg, repeat, params.n,
vecLen);
} else {
QuantizePerGroupForKRowCommonOneRow<
DstT, SrcT, ActualScaleT, config.hasOffset, isScalarOffset, castTrait>(
dstUb + (i0 * params.groupSize + i1) * params.n, srcUb + (i0 * params.groupSize + i1) * params.n,
scaleUb + i0 * params.n, offsetUb + i0 * params.n, dstVreg, srcVreg, scaleVreg, offsetVreg, repeat,
params.n, vecLen);
}
}
}
for (uint16_t i = 0; i < tailRow; ++i) {
if constexpr (isScalarOffset) {
QuantizePerGroupForKRowCommonOneRow<DstT, SrcT, ActualScaleT, config.hasOffset, isScalarOffset, castTrait>(
dstUb + (mainRowGroup * params.groupSize + i) * params.n,
srcUb + (mainRowGroup * params.groupSize + i) * params.n, scaleUb + mainRowGroup * params.n, nullptr,
dstVreg, srcVreg, scaleVreg, offsetVreg, repeat, params.n, vecLen);
} else {
QuantizePerGroupForKRowCommonOneRow<DstT, SrcT, ActualScaleT, config.hasOffset, isScalarOffset, castTrait>(
dstUb + (mainRowGroup * params.groupSize + i) * params.n,
srcUb + (mainRowGroup * params.groupSize + i) * params.n, scaleUb + mainRowGroup * params.n,
offsetUb + mainRowGroup * params.n, dstVreg, srcVreg, scaleVreg, offsetVreg, repeat, params.n, vecLen);
}
}
}
template <const QuantizeConfig& config, typename DstT, typename SrcT, typename ScaleT, typename OffsetT>
__aicore__ inline void QuantizePerGroupForKRowCommon(
const LocalTensor<DstT>& dstTensor, const LocalTensor<SrcT>& srcTensor, const ScaleT& scale, const OffsetT& offset,
const QuantizeParams& params)
{
using ActualScaleT = typename ScaleT::PrimType;
constexpr bool isScalarOffset = TypeUtils::IsInnerDefaultType<OffsetT>();
__ubuf__ DstT* dstUb = (__ubuf__ DstT*)dstTensor.GetPhyAddr();
__ubuf__ SrcT* srcUb = (__ubuf__ SrcT*)srcTensor.GetPhyAddr();
__ubuf__ ActualScaleT* scaleUb = (__ubuf__ ActualScaleT*)scale.GetPhyAddr();
__ubuf__ ActualScaleT* offsetUb = nullptr;
if constexpr (!isScalarOffset) {
offsetUb = (__ubuf__ ActualScaleT*)offset.GetPhyAddr();
}
uint16_t rowNum = params.m;
uint16_t tailRow = rowNum % params.groupSize;
QuantizePerGroupForKRowCommonVF<config, DstT, SrcT, ScaleT, OffsetT, ActualScaleT>(
dstUb, srcUb, scaleUb, offsetUb, offset, params, rowNum, tailRow);
}
template <typename DstT, typename SrcT, typename ActualScaleT, const Reg::CastTrait& castTrait>
__simd_callee__ inline void QuantizePerGroupForKRowFp4OneRow(
__ubuf__ DstT* dstAddr, __ubuf__ SrcT* srcAddr, __ubuf__ ActualScaleT* scaleAddr, Reg::RegTensor<DstT>& dstVreg,
Reg::RegTensor<float>& srcVreg, Reg::RegTensor<float>& scaleVreg, uint16_t repeat, uint32_t n, uint32_t vecLen)
{
Reg::MaskReg mask;
Reg::RegTensor<SrcT> tempVreg;
Reg::RegTensor<ActualScaleT> tempScaleVreg;
for (uint16_t j = 0; j < repeat; ++j) {
mask = Reg::UpdateMask<float>(n);
if constexpr (SupportType<ActualScaleT, half, bfloat16_t>()) {
Reg::LoadAlign<ActualScaleT, Reg::LoadDist::DIST_UNPACK_B16>(tempScaleVreg, scaleAddr + j * vecLen);
Reg::Cast<float, ActualScaleT, layoutZMrgZ>(scaleVreg, tempScaleVreg, mask);
} else {
Reg::LoadAlign<ActualScaleT, Reg::LoadDist::DIST_NORM>(scaleVreg, scaleAddr + j * vecLen);
}
if constexpr (SupportType<SrcT, half, bfloat16_t>()) {
Reg::LoadAlign<SrcT, Reg::LoadDist::DIST_UNPACK_B16>(tempVreg, srcAddr + j * vecLen);
Reg::Cast<float, SrcT, layoutZMrgZ>(srcVreg, tempVreg, mask);
} else {
Reg::LoadAlign<SrcT, Reg::LoadDist::DIST_NORM>(srcVreg, srcAddr + j * vecLen);
}
Reg::Mul<float, Reg::MaskMergeMode::ZEROING>(srcVreg, srcVreg, scaleVreg, mask);
QuantizeUtils::TransRegForFp4<DstT, float, castTrait>(srcVreg, dstVreg, mask);
Reg::StoreAlign<uint8_t, Reg::StoreDist::DIST_PACK4_B32>(
(__ubuf__ uint8_t*)dstAddr + (j * vecLen) / 2, (Reg::RegTensor<uint8_t>&)dstVreg, mask);
}
}
template <const QuantizeConfig& config, typename DstT, typename SrcT, typename ScaleT, typename ActualScaleT>
__simd_vf__ inline void QuantizePerGroupForKRowFp4VF(
__ubuf__ DstT* dstUb, __ubuf__ SrcT* srcUb, __ubuf__ ActualScaleT* scaleUb, const QuantizeParams params,
uint16_t rowNum, uint16_t tailRow)
{
uint16_t mainRowGroup = rowNum / params.groupSize;
uint32_t vecLen = GetVecLen() / sizeof(float);
uint16_t repeat = CeilDivision(params.n, vecLen);
static constexpr Reg::CastTrait castTrait = {
Reg::RegLayout::ZERO, Reg::SatMode::SAT, Reg::MaskMergeMode::ZEROING, config.roundMode};
Reg::RegTensor<DstT> dstVreg;
Reg::RegTensor<float> srcVreg;
Reg::RegTensor<float> scaleVreg;
for (uint16_t i0 = 0; i0 < mainRowGroup; ++i0) {
for (uint16_t i1 = 0; i1 < static_cast<uint16_t>(params.groupSize); ++i1) {
QuantizePerGroupForKRowFp4OneRow<DstT, SrcT, ActualScaleT, castTrait>(
dstUb + ((i0 * params.groupSize + i1) * params.n) / 2, srcUb + (i0 * params.groupSize + i1) * params.n,
scaleUb + i0 * params.n, dstVreg, srcVreg, scaleVreg, repeat, params.n, vecLen);
}
}
for (uint16_t i = 0; i < tailRow; ++i) {
QuantizePerGroupForKRowFp4OneRow<DstT, SrcT, ActualScaleT, castTrait>(
dstUb + ((mainRowGroup * params.groupSize + i) * params.n) / 2,
srcUb + (mainRowGroup * params.groupSize + i) * params.n, scaleUb + mainRowGroup * params.n, dstVreg,
srcVreg, scaleVreg, repeat, params.n, vecLen);
}
}
template <const QuantizeConfig& config, typename DstT, typename SrcT, typename ScaleT>
__aicore__ inline void QuantizePerGroupForKRowFp4(
const LocalTensor<DstT>& dstTensor, const LocalTensor<SrcT>& srcTensor, const ScaleT& scale,
const QuantizeParams& params)
{
using ActualScaleT = typename ScaleT::PrimType;
__ubuf__ DstT* dstUb = (__ubuf__ DstT*)dstTensor.GetPhyAddr();
__ubuf__ SrcT* srcUb = (__ubuf__ SrcT*)srcTensor.GetPhyAddr();
__ubuf__ ActualScaleT* scaleUb = (__ubuf__ ActualScaleT*)scale.GetPhyAddr();
uint16_t rowNum = params.m;
uint16_t tailRow = rowNum % params.groupSize;
QuantizePerGroupForKRowFp4VF<config, DstT, SrcT, ScaleT, ActualScaleT>(
dstUb, srcUb, scaleUb, params, rowNum, tailRow);
}
template <const QuantizeConfig& config, typename DstT, typename SrcT, typename ScaleT, typename OffsetT>
__aicore__ inline void QuantizePerGroupForKRow(
const LocalTensor<DstT>& dstTensor, const LocalTensor<SrcT>& srcTensor, const ScaleT& scale, const OffsetT& offset,
const QuantizeParams& params)
{
static_assert(TypeUtils::IsLocalTensorType<ScaleT>(), "Quantize PerGroup ScaleT should be Tensor");
using ActualScaleT = typename ScaleT::PrimType;
constexpr bool isScalarOffset = TypeUtils::IsInnerDefaultType<OffsetT>();
if constexpr (isScalarOffset) {
static_assert(IsSameType<ActualScaleT, OffsetT>::value, "scale and offset should be the same PrimType");
} else {
using ActualOffsetT = typename OffsetT::PrimType;
static_assert(IsSameType<ActualScaleT, ActualOffsetT>::value, "scale and offset should be the same PrimType");
}
QuantizeUtils::CheckApiDtypeValid<DstT, SrcT, ActualScaleT>();
if constexpr (SupportType<DstT, fp8_e4m3fn_t, fp8_e5m2_t, hifloat8_t, int8_t>()) {
QuantizePerGroupForKRowCommon<config, DstT, SrcT, ScaleT, OffsetT>(dstTensor, srcTensor, scale, offset, params);
} else if constexpr (SupportType<DstT, fp4x2_e2m1_t, fp4x2_e1m2_t>()) {
QuantizePerGroupForKRowFp4<config, DstT, SrcT, ScaleT>(dstTensor, srcTensor, scale, params);
} else {
ASCENDC_ASSERT((false), { KERNEL_LOG(KERNEL_ERROR, "unsupport dstT for Quantize!"); });
}
}
template <
const QuantizeConfig& config, typename DstT, typename SrcT, typename ScaleT, typename OffsetT,
typename ActualScaleT>
__simd_vf__ inline void QuantizePerTokenCommonVF(
__ubuf__ DstT* dstUb, __ubuf__ SrcT* srcUb, __ubuf__ ActualScaleT* scaleUb, __ubuf__ ActualScaleT* offsetUb,
const OffsetT offset, const QuantizeParams params)
{
constexpr bool isScalarOffset = TypeUtils::IsInnerDefaultType<OffsetT>();
uint16_t rowNum = params.m;
uint32_t vecLen = GetVecLen() / sizeof(float);
uint16_t repeat = CeilDivision(params.n, vecLen);
static constexpr Reg::CastTrait castTrait = {
Reg::RegLayout::ZERO, Reg::SatMode::SAT, Reg::MaskMergeMode::ZEROING, config.roundMode};
Reg::MaskReg mask;
Reg::RegTensor<float> srcVreg;
Reg::RegTensor<DstT> dstVreg;
Reg::RegTensor<float> scaleVreg;
Reg::RegTensor<float> offsetVreg;
Reg::RegTensor<ActualScaleT> tempScaleVreg;
Reg::RegTensor<ActualScaleT> tempOffsetVreg;
Reg::RegTensor<SrcT> tempSrcVreg;
Reg::MaskReg fullMask = Reg::CreateMask<uint32_t, Reg::MaskPattern::ALL>();
if constexpr (config.hasOffset && isScalarOffset) {
QuantizeUtils::DuplicateScalarToFloatVector<ActualScaleT>(offsetVreg, offset, tempOffsetVreg, fullMask);
}
for (uint16_t i = 0; i < rowNum; ++i) {
if constexpr (SupportType<ActualScaleT, half, bfloat16_t>()) {
Reg::LoadAlign<ActualScaleT, Reg::LoadDist::DIST_BRC_B16>(tempScaleVreg, scaleUb + i);
Reg::Cast<float, ActualScaleT, layoutZMrgZ>(scaleVreg, tempScaleVreg, fullMask);
if constexpr (config.hasOffset && !isScalarOffset) {
Reg::LoadAlign<ActualScaleT, Reg::LoadDist::DIST_BRC_B16>(tempOffsetVreg, offsetUb + i);
Reg::Cast<float, ActualScaleT, layoutZMrgZ>(offsetVreg, tempOffsetVreg, fullMask);
}
} else {
Reg::LoadAlign<ActualScaleT, Reg::LoadDist::DIST_BRC_B32>(scaleVreg, scaleUb + i);
if constexpr (config.hasOffset && !isScalarOffset) {
Reg::LoadAlign<ActualScaleT, Reg::LoadDist::DIST_BRC_B32>(offsetVreg, offsetUb + i);
}
}
uint32_t sreg = params.n;
for (uint16_t j = 0; j < repeat; ++j) {
mask = Reg::UpdateMask<float>(sreg);
if constexpr (SupportType<SrcT, half, bfloat16_t>()) {
Reg::LoadAlign<SrcT, Reg::LoadDist::DIST_UNPACK_B16>(tempSrcVreg, srcUb + i * params.n + j * vecLen);
Reg::Cast<float, SrcT, layoutZMrgZ>(srcVreg, tempSrcVreg, mask);
} else {
Reg::LoadAlign<SrcT, Reg::LoadDist::DIST_NORM>(srcVreg, srcUb + i * params.n + j * vecLen);
}
Reg::Mul(srcVreg, srcVreg, scaleVreg, mask);
if constexpr (config.hasOffset) {
Reg::Add(srcVreg, srcVreg, offsetVreg, mask);
}
QuantizeUtils::CastFp32DstToExpect<DstT, castTrait>(srcVreg, dstVreg, mask);
QuantizeUtils::StoreRes<DstT, float>(dstUb + i * params.n + j * vecLen, dstVreg, mask);
}
}
}
template <const QuantizeConfig& config, typename DstT, typename SrcT, typename ScaleT, typename OffsetT>
__aicore__ inline void QuantizePerTokenCommon(
const LocalTensor<DstT>& dstTensor, const LocalTensor<SrcT>& srcTensor, const ScaleT& scale, const OffsetT& offset,
const QuantizeParams& params)
{
using ActualScaleT = typename ScaleT::PrimType;
constexpr bool isScalarOffset = TypeUtils::IsInnerDefaultType<OffsetT>();
__ubuf__ DstT* dstUb = (__ubuf__ DstT*)dstTensor.GetPhyAddr();
__ubuf__ SrcT* srcUb = (__ubuf__ SrcT*)srcTensor.GetPhyAddr();
__ubuf__ ActualScaleT* scaleUb = (__ubuf__ ActualScaleT*)scale.GetPhyAddr();
__ubuf__ ActualScaleT* offsetUb = nullptr;
if constexpr (!isScalarOffset) {
offsetUb = (__ubuf__ ActualScaleT*)offset.GetPhyAddr();
}
QuantizePerTokenCommonVF<config, DstT, SrcT, ScaleT, OffsetT, ActualScaleT>(
dstUb, srcUb, scaleUb, offsetUb, offset, params);
}
template <const QuantizeConfig& config, typename DstT, typename SrcT, typename ScaleT, typename OffsetT>
__aicore__ inline void QuantizePerToken(
const LocalTensor<DstT>& dstTensor, const LocalTensor<SrcT>& srcTensor, const ScaleT& scale, const OffsetT& offset,
const QuantizeParams& params)
{
static_assert(TypeUtils::IsLocalTensorType<ScaleT>(), "Quantize PerToken ScaleT should be Tensor");
using ActualScaleT = typename ScaleT::PrimType;
constexpr bool isScalarOffset = TypeUtils::IsInnerDefaultType<OffsetT>();
if constexpr (isScalarOffset) {
static_assert(IsSameType<ActualScaleT, OffsetT>::value, "scale and offset should be the same PrimType");
} else {
using ActualOffsetT = typename OffsetT::PrimType;
static_assert(IsSameType<ActualScaleT, ActualOffsetT>::value, "scale and offset should be the same PrimType");
}
QuantizeUtils::CheckApiDtypeValid<DstT, SrcT, ActualScaleT>();
QuantizePerTokenCommon<config, DstT, SrcT, ScaleT, OffsetT>(dstTensor, srcTensor, scale, offset, params);
}
template <
const QuantizeConfig& config, typename DstT, typename SrcT, typename ScaleT, typename OffsetT,
typename ActualScaleT>
__simd_vf__ inline void QuantizePerChannelCommonVF(
__ubuf__ DstT* dstUb, __ubuf__ SrcT* srcUb, __ubuf__ ActualScaleT* scaleUb, __ubuf__ ActualScaleT* offsetUb,
const OffsetT offset, const QuantizeParams params)
{
constexpr bool isScalarOffset = TypeUtils::IsInnerDefaultType<OffsetT>();
uint16_t colNum = params.n;
uint32_t vecLen = GetVecLen() / sizeof(float);
uint16_t repeat = CeilDivision(colNum, vecLen);
static constexpr Reg::CastTrait castTrait = {
Reg::RegLayout::ZERO, Reg::SatMode::SAT, Reg::MaskMergeMode::ZEROING, config.roundMode};
Reg::MaskReg mask;
Reg::RegTensor<float> srcVreg;
Reg::RegTensor<DstT> dstVreg;
Reg::RegTensor<float> scaleVreg;
Reg::RegTensor<float> offsetVreg;
Reg::RegTensor<ActualScaleT> tempScaleVreg;
Reg::RegTensor<ActualScaleT> tempOffsetVreg;
Reg::RegTensor<SrcT> tempSrcVreg;
Reg::MaskReg fullMask = Reg::CreateMask<uint32_t, Reg::MaskPattern::ALL>();
if constexpr (config.hasOffset && isScalarOffset) {
QuantizeUtils::DuplicateScalarToFloatVector<ActualScaleT>(offsetVreg, offset, tempOffsetVreg, fullMask);
}
for (uint16_t i = 0; i < static_cast<uint16_t>(params.m); ++i) {
uint32_t sreg = colNum;
for (uint16_t j = 0; j < repeat; ++j) {
mask = Reg::UpdateMask<float>(sreg);
if constexpr (SupportType<ActualScaleT, half, bfloat16_t>()) {
Reg::LoadAlign<ActualScaleT, Reg::LoadDist::DIST_UNPACK_B16>(tempScaleVreg, scaleUb + j * vecLen);
Reg::Cast<float, ActualScaleT, layoutZMrgZ>(scaleVreg, tempScaleVreg, fullMask);
if constexpr (config.hasOffset && !isScalarOffset) {
Reg::LoadAlign<ActualScaleT, Reg::LoadDist::DIST_UNPACK_B16>(tempOffsetVreg, offsetUb + j * vecLen);
Reg::Cast<float, ActualScaleT, layoutZMrgZ>(offsetVreg, tempOffsetVreg, fullMask);
}
} else {
Reg::LoadAlign<ActualScaleT, Reg::LoadDist::DIST_NORM>(scaleVreg, scaleUb + j * vecLen);
if constexpr (config.hasOffset && !isScalarOffset) {
Reg::LoadAlign<ActualScaleT, Reg::LoadDist::DIST_NORM>(offsetVreg, offsetUb + j * vecLen);
}
}
if constexpr (SupportType<SrcT, half, bfloat16_t>()) {
Reg::LoadAlign<SrcT, Reg::LoadDist::DIST_UNPACK_B16>(tempSrcVreg, srcUb + i * params.n + j * vecLen);
Reg::Cast<float, SrcT, layoutZMrgZ>(srcVreg, tempSrcVreg, mask);
} else {
Reg::LoadAlign<SrcT, Reg::LoadDist::DIST_NORM>(srcVreg, srcUb + i * params.n + j * vecLen);
}
Reg::Mul(srcVreg, srcVreg, scaleVreg, mask);
if constexpr (config.hasOffset) {
Reg::Add(srcVreg, srcVreg, offsetVreg, mask);
}
QuantizeUtils::CastFp32DstToExpect<DstT, castTrait>(srcVreg, dstVreg, mask);
QuantizeUtils::StoreRes<DstT, float>(dstUb + i * colNum + j * vecLen, dstVreg, mask);
}
}
}
template <const QuantizeConfig& config, typename DstT, typename SrcT, typename ScaleT, typename OffsetT>
__aicore__ inline void QuantizePerChannelCommon(
const LocalTensor<DstT>& dstTensor, const LocalTensor<SrcT>& srcTensor, const ScaleT& scale, const OffsetT& offset,
const QuantizeParams& params)
{
using ActualScaleT = typename ScaleT::PrimType;
constexpr bool isScalarOffset = TypeUtils::IsInnerDefaultType<OffsetT>();
__ubuf__ DstT* dstUb = (__ubuf__ DstT*)dstTensor.GetPhyAddr();
__ubuf__ SrcT* srcUb = (__ubuf__ SrcT*)srcTensor.GetPhyAddr();
__ubuf__ ActualScaleT* scaleUb = (__ubuf__ ActualScaleT*)scale.GetPhyAddr();
__ubuf__ ActualScaleT* offsetUb = nullptr;
if constexpr (!isScalarOffset) {
offsetUb = (__ubuf__ ActualScaleT*)offset.GetPhyAddr();
}
QuantizePerChannelCommonVF<config, DstT, SrcT, ScaleT, OffsetT, ActualScaleT>(
dstUb, srcUb, scaleUb, offsetUb, offset, params);
}
template <const QuantizeConfig& config, typename DstT, typename SrcT, typename ScaleT, typename OffsetT>
__aicore__ inline void QuantizePerChannel(
const LocalTensor<DstT>& dstTensor, const LocalTensor<SrcT>& srcTensor, const ScaleT& scale, const OffsetT& offset,
const QuantizeParams& params)
{
static_assert(TypeUtils::IsLocalTensorType<ScaleT>(), "Quantize PerChannel ScaleT should be Tensor");
using ActualScaleT = typename ScaleT::PrimType;
constexpr bool isScalarOffset = TypeUtils::IsInnerDefaultType<OffsetT>();
if constexpr (isScalarOffset) {
static_assert(IsSameType<ActualScaleT, OffsetT>::value, "scale and offset should be the same PrimType");
} else {
using ActualOffsetT = typename OffsetT::PrimType;
static_assert(IsSameType<ActualScaleT, ActualOffsetT>::value, "scale and offset should be the same PrimType");
}
QuantizeUtils::CheckApiDtypeValid<DstT, SrcT, ActualScaleT>();
QuantizePerChannelCommon<config, DstT, SrcT, ScaleT, OffsetT>(dstTensor, srcTensor, scale, offset, params);
}
template <const QuantizeConfig& config, typename DstT, typename SrcT, typename ScaleT, typename OffsetT>
__simd_vf__ inline void QuantizePerTensorCommonVF(
__ubuf__ DstT* dstUb, __ubuf__ SrcT* srcUb, const ScaleT scale, const OffsetT offset, const QuantizeParams params)
{
uint16_t rowNum = params.m;
constexpr uint32_t vecLen = GetVecLen() / sizeof(float);
uint16_t repeat = CeilDivision(params.n, vecLen);
static constexpr Reg::CastTrait castTrait = {
Reg::RegLayout::ZERO, Reg::SatMode::SAT, Reg::MaskMergeMode::ZEROING, config.roundMode};
Reg::MaskReg mask;
Reg::RegTensor<float> srcVreg;
Reg::RegTensor<DstT> dstVreg;
Reg::RegTensor<float> scaleVreg;
Reg::RegTensor<float> offsetVreg;
Reg::RegTensor<ScaleT> tempScaleVreg;
Reg::RegTensor<OffsetT> tempOffsetVreg;
Reg::RegTensor<SrcT> tempSrcVreg;
Reg::MaskReg fullMask = Reg::CreateMask<uint32_t, Reg::MaskPattern::ALL>();
QuantizeUtils::DuplicateScalarToFloatVector<ScaleT>(scaleVreg, scale, tempScaleVreg, fullMask);
if constexpr (config.hasOffset) {
QuantizeUtils::DuplicateScalarToFloatVector<ScaleT>(offsetVreg, offset, tempOffsetVreg, fullMask);
}
for (uint16_t i = 0; i < rowNum; ++i) {
uint32_t sreg = params.n;
for (uint16_t j = 0; j < repeat; ++j) {
mask = Reg::UpdateMask<float>(sreg);
if constexpr (SupportType<SrcT, half, bfloat16_t>()) {
Reg::LoadAlign<SrcT, Reg::LoadDist::DIST_UNPACK_B16>(tempSrcVreg, srcUb + i * params.n + j * vecLen);
Reg::Cast<float, SrcT, layoutZMrgZ>(srcVreg, tempSrcVreg, mask);
} else {
Reg::LoadAlign<SrcT, Reg::LoadDist::DIST_NORM>(srcVreg, srcUb + i * params.n + j * vecLen);
}
Reg::Mul(srcVreg, srcVreg, scaleVreg, mask);
if constexpr (config.hasOffset) {
Reg::Add(srcVreg, srcVreg, offsetVreg, mask);
}
QuantizeUtils::CastFp32DstToExpect<DstT, castTrait>(srcVreg, dstVreg, mask);
QuantizeUtils::StoreRes<DstT, float>(dstUb + i * params.n + j * vecLen, dstVreg, mask);
}
}
}
template <const QuantizeConfig& config, typename DstT, typename SrcT, typename ScaleT, typename OffsetT>
__aicore__ inline void QuantizePerTensorCommon(
const LocalTensor<DstT>& dstTensor, const LocalTensor<SrcT>& srcTensor, const ScaleT& scale, const OffsetT& offset,
const QuantizeParams& params)
{
__ubuf__ DstT* dstUb = (__ubuf__ DstT*)dstTensor.GetPhyAddr();
__ubuf__ SrcT* srcUb = (__ubuf__ SrcT*)srcTensor.GetPhyAddr();
QuantizePerTensorCommonVF<config, DstT, SrcT, ScaleT, OffsetT>(dstUb, srcUb, scale, offset, params);
}
template <const QuantizeConfig& config, typename DstT, typename SrcT, typename ScaleT, typename OffsetT>
__aicore__ inline void QuantizePerTensor(
const LocalTensor<DstT>& dstTensor, const LocalTensor<SrcT>& srcTensor, const ScaleT& scale, const OffsetT& offset,
const QuantizeParams& params)
{
static_assert(TypeUtils::IsInnerDefaultType<ScaleT>(), "Quantize PerTensor ScaleT should be Scalar");
static_assert(TypeUtils::IsInnerDefaultType<OffsetT>(), "Quantize PerTensor OffsetT should be Scalar");
static_assert(IsSameType<ScaleT, OffsetT>::value, "ScaleT and OffsetT should be the same type");
QuantizeUtils::CheckApiDtypeValid<DstT, SrcT, ScaleT>();
QuantizePerTensorCommon<config, DstT, SrcT, ScaleT, OffsetT>(dstTensor, srcTensor, scale, offset, params);
}
template <typename DstT, typename SrcT, typename ScaleT, typename OffsetT>
__aicore__ inline void CheckQuantizeParams(
const LocalTensor<DstT>& dstTensor, const LocalTensor<SrcT>& srcTensor, const ScaleT& scale, const OffsetT& offset)
{
CheckTensorPosition(dstTensor, "dstTensor", "VECIN, VECOUT, VECCALC");
CheckTensorPosition(srcTensor, "srcTensor", "VECIN, VECOUT, VECCALC");
if constexpr (TypeUtils::IsLocalTensorType<ScaleT>()) {
CheckTensorPosition(scale, "scale", "VECIN, VECOUT, VECCALC");
}
if constexpr (TypeUtils::IsLocalTensorType<OffsetT>()) {
CheckTensorPosition(offset, "offset", "VECIN, VECOUT, VECCALC");
}
static_assert(
SupportType<SrcT, half, float, bfloat16_t>(), "Quantize only support half/float/bfloat16_t input dtype");
if constexpr (TypeUtils::IsInnerDefaultType<ScaleT>()) {
static_assert(
SupportType<ScaleT, half, float, bfloat16_t>(), "Quantize only support half/float/bfloat16_t scale dtype");
} else {
using ActualScaleT = typename ScaleT::PrimType;
static_assert(
SupportType<ActualScaleT, half, float, bfloat16_t>(),
"Quantize only support half/float/bfloat16_t scale dtype");
}
}
template <const QuantizeConfig& config, typename DstT, typename SrcT, typename ScaleT, typename OffsetT>
__aicore__ inline void QuantizeImpl(
const LocalTensor<DstT>& dstTensor, const LocalTensor<SrcT>& srcTensor, const ScaleT& scale, const OffsetT& offset,
const QuantizeParams& params)
{
if ASCEND_IS_AIC {
return;
}
CheckQuantizeParams(dstTensor, srcTensor, scale, offset);
static_assert(
(config.policy == QuantizePolicy::PER_TENSOR || config.policy == QuantizePolicy::PER_CHANNEL ||
config.policy == QuantizePolicy::PER_TOKEN || config.policy == QuantizePolicy::PER_GROUP),
"unsupported policy for Quantize in current device!");
ASCENDC_ASSERT((params.n % GetDataBlockSizeInBytes() == 0), { KERNEL_LOG(KERNEL_ERROR, "n must be 32B aligned"); });
if constexpr (config.policy == QuantizePolicy::PER_TENSOR) {
static_assert(
SupportType<DstT, int8_t, fp8_e4m3fn_t, fp8_e5m2_t, hifloat8_t>(),
"Quantize PerTensor only support int8_t/fp8_e4m3fn_t/fp8_e5m2_t/hifloat8_t output dtype");
QuantizePerTensor<config, DstT, SrcT, ScaleT, OffsetT>(dstTensor, srcTensor, scale, offset, params);
} else if constexpr (config.policy == QuantizePolicy::PER_CHANNEL) {
static_assert(
SupportType<DstT, int8_t, fp8_e4m3fn_t, fp8_e5m2_t, hifloat8_t>(),
"Quantize PerChannel only support int8_t/fp8_e4m3fn_t/fp8_e5m2_t/hifloat8_t output dtype");
QuantizePerChannel<config, DstT, SrcT, ScaleT, OffsetT>(dstTensor, srcTensor, scale, offset, params);
} else if constexpr (config.policy == QuantizePolicy::PER_TOKEN) {
static_assert(
SupportType<DstT, int8_t, fp8_e4m3fn_t, fp8_e5m2_t, hifloat8_t>(),
"Quantize PerToken only support int8_t/fp8_e4m3fn_t/fp8_e5m2_t/hifloat8_t output dtype");
QuantizePerToken<config, DstT, SrcT, ScaleT, OffsetT>(dstTensor, srcTensor, scale, offset, params);
} else if constexpr (config.policy == QuantizePolicy::PER_GROUP) {
static_assert(
SupportType<DstT, int8_t, fp8_e4m3fn_t, fp8_e5m2_t, hifloat8_t, fp4x2_e2m1_t, fp4x2_e1m2_t>(),
"Quantize PerGroup only support "
"int8_t/fp8_e4m3fn_t/fp8_e5m2_t/hifloat8_t/fp4x2_e2m1_t/fp4x2_e1m2_t output dtype");
static_assert(((config.kDim == 1) || (config.kDim == 0)), "Quantize PerGroup only support K is axis 0/1!");
ASCENDC_ASSERT((params.groupSize > 0 && params.groupSize % 32 == 0), {
KERNEL_LOG(KERNEL_ERROR, "groupSize must be an integer multiple of 32 and greater then 0 !");
});
if constexpr (config.kDim == 1) {
QuantizePerGroupForKCol<config, DstT, SrcT, ScaleT, OffsetT>(dstTensor, srcTensor, scale, offset, params);
} else {
QuantizePerGroupForKRow<config, DstT, SrcT, ScaleT, OffsetT>(dstTensor, srcTensor, scale, offset, params);
}
}
}
}
#endif
#if defined(__UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_QUANTIZATION_QUANTIZE_QUANTIZE_IMPL_H__)
#undef __ASCENDC_INCLUDE_INTERNAL_HEADERS__
#undef __UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_QUANTIZATION_QUANTIZE_QUANTIZE_IMPL_H__
#endif