Copyright (c) 2026 Huawei Technologies Co., Ltd.
This program is free software, you can redistribute it and/or modify it under the terms and conditions of
CANN Open Software License Agreement Version 2.0 (the "License").
Please refer to the License for details. You may not use this file except in compliance with the License.
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
See LICENSE in the root of the software repository for the full text of the License.
*/
#include <pto/pto-inst.hpp>
#include <pto/common/constants.hpp>
#include <pto/npu/a5/TQuant.hpp>
#include "acl/acl.h"
#include <type_traits>
using namespace pto;
#define PTO_CEIL(x, y) ((((x) + (y) - 1) / (y)) * (y))
#define PTO_DIV_ROUNDUP(x, y) (((x) + (y) - 1) / (y))
namespace TQuantTest {
constexpr int TQUANT_A5_UB_ALIGN_BYTES = 32;
constexpr int TQUANT_A5_UB_SIZE_BYTES = 256 * 1024;
template <int validRows, int validCols, int mode, pto::QuantScaleAlg scaleAlg = pto::QuantScaleAlg::OCP>
__global__ AICORE void runTQuant(__gm__ uint8_t __out__ *out_e8m0, __gm__ uint8_t __out__ *out_fp8,
__gm__ float __in__ *src)
{
constexpr int paddedCols = PTO_CEIL(validCols, 32);
constexpr int groupedCols_flattened = validRows * (paddedCols / 32);
constexpr int groupedCols_valid = paddedCols / 32;
constexpr int groupedCols_flat_f32 = PTO_CEIL(groupedCols_flattened, 8);
constexpr int groupedCols_flat_u8 = PTO_CEIL(groupedCols_flattened, 32);
using SrcGlobal = GlobalTensor<float, Shape<1, 1, 1, validRows, validCols>, pto::Stride<1, 1, 1, validCols, 1>>;
using DstE8Global =
GlobalTensor<uint8_t, Shape<1, 1, 1, 1, groupedCols_flattened>, pto::Stride<1, 1, 1, validCols, 1>>;
using DstFP8Global = GlobalTensor<int8_t, Shape<1, 1, 1, validRows, validCols>, pto::Stride<1, 1, 1, validCols, 1>>;
using SrcTile = Tile<TileType::Vec, float, validRows, paddedCols, BLayout::RowMajor, -1, -1, SLayout::NoneBox, 512,
PadValue::Zero>;
using DstE8Tile = Tile<TileType::Vec, uint8_t, 1, groupedCols_flat_u8, BLayout::RowMajor, -1, -1, SLayout::NoneBox,
512, PadValue::Zero>;
using DstFP8Tile = Tile<TileType::Vec, int8_t, validRows, paddedCols, BLayout::RowMajor, validRows, paddedCols,
SLayout::NoneBox, 512, PadValue::Zero>;
using MaxTile = Tile<TileType::Vec, float, 1, groupedCols_flat_f32, BLayout::RowMajor, -1, -1>;
using ScalingTile = Tile<TileType::Vec, float, 1, groupedCols_flat_f32, BLayout::RowMajor, -1, -1>;
SrcTile srcTile(validRows, validCols);
ScalingTile scalingTile(1, groupedCols_flattened);
DstFP8Tile fp8Tile;
DstE8Tile e8Tile(1, groupedCols_flattened);
MaxTile maxPerGpTile(1, groupedCols_flattened);
SrcGlobal srcGlobal(src);
DstE8Global e8Global(out_e8m0);
DstFP8Global fp8Global((__gm__ int8_t *)out_fp8);
TASSIGN(srcTile, 0x0);
TASSIGN(maxPerGpTile, 0x20100);
TASSIGN(scalingTile, 0x21820);
TASSIGN(e8Tile, 0x24100);
TASSIGN(fp8Tile, 0x25100);
TLOAD(srcTile, srcGlobal);
if constexpr (mode == 0) {
#ifndef __PTO_AUTO__
set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
#endif
TQUANT<pto::QuantType::MXFP8, DstFP8Tile, SrcTile, DstE8Tile, MaxTile, ScalingTile, scaleAlg>(
fp8Tile, srcTile, &e8Tile, &maxPerGpTile, &scalingTile);
#ifndef __PTO_AUTO__
set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
#endif
TSTORE(e8Global, e8Tile);
TSTORE(fp8Global, fp8Tile);
} else {
constexpr int groupedCols_valid = paddedCols / 32;
using E8ZzTile = Tile<TileType::Vec, uint8_t, validRows, groupedCols_valid, BLayout::RowMajor, -1, -1,
SLayout::RowMajor, 32, PadValue::Zero>;
using E8StoreTile = Tile<TileType::Vec, uint8_t, 1, groupedCols_flattened, BLayout::RowMajor, -1, -1,
SLayout::NoneBox, 512, PadValue::Zero>;
constexpr int tmpBufSize = (16 + (validRows / 16) * (groupedCols_valid / 2) + 16) * sizeof(uint16_t);
constexpr int tmpBufSizeAligned = PTO_CEIL(tmpBufSize, 32);
using TmpTile = Tile<TileType::Vec, uint8_t, 1, tmpBufSizeAligned, BLayout::RowMajor, -1, -1, SLayout::NoneBox,
512, PadValue::Zero>;
constexpr int virtualRow = PTO_CEIL(validRows, FRACTAL_NZ_ROW) + 1;
using DstNZ_int8 = Tile<TileType::Vec, int8_t, virtualRow, paddedCols, BLayout::ColMajor, validRows, paddedCols,
SLayout::RowMajor, 512, PadValue::Null, CompactMode::RowPlusOne>;
DstNZ_int8 fp8TileNZ;
E8ZzTile e8ZzTile(validRows, groupedCols_valid);
E8StoreTile e8StoreTile(1, groupedCols_flattened);
TmpTile tmpTile(1, tmpBufSizeAligned);
TASSIGN(fp8TileNZ, 0x0);
TASSIGN(e8ZzTile, 0x20100);
TASSIGN(e8StoreTile, 0x20100);
TASSIGN(tmpTile, 0x30100);
set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
TQUANT<pto::QuantType::MXFP8, DstFP8Tile, SrcTile, DstE8Tile, MaxTile, ScalingTile, scaleAlg>(
fp8Tile, srcTile, &e8Tile, &maxPerGpTile, &scalingTile);
TMOV(e8ZzTile, e8Tile, tmpTile);
TMOV(fp8TileNZ, fp8Tile);
set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
TSTORE(e8Global, e8StoreTile);
using DstFp8GlobalNZ = GlobalTensor<int8_t, TileShape2D<int8_t, validRows, paddedCols, Layout::NZ>,
BaseShape2D<int8_t, validRows, paddedCols, Layout::NZ>, Layout::NZ>;
DstFp8GlobalNZ fp8GlobalNZ((__gm__ int8_t *)out_fp8);
TSTORE(fp8GlobalNZ, fp8TileNZ);
}
}
template <int validRows, int validCols, int mode>
__global__ AICORE void runTQuantInt8Sym(__gm__ int8_t __out__ *out_s8, __gm__ float __in__ *src,
__gm__ float __in__ *scale)
{
constexpr int paddedCols_b32 = PTO_CEIL(validCols, BLOCK_BYTE_SIZE / sizeof(float));
constexpr int paddedCols_b8 = PTO_CEIL(validCols, BLOCK_BYTE_SIZE / sizeof(int8_t));
using SrcGlobal = GlobalTensor<float, Shape<1, 1, 1, validRows, validCols>, pto::Stride<1, 1, 1, validCols, 1>>;
using DstGlobal = GlobalTensor<int8_t, Shape<1, 1, 1, validRows, validCols>, pto::Stride<1, 1, 1, validCols, 1>>;
using ParaGlobal = GlobalTensor<float, Shape<1, 1, 1, validRows, 1>, pto::Stride<1, 1, 1, 1, 1>, pto::Layout::DN>;
using SrcTile = Tile<TileType::Vec, float, validRows, paddedCols_b32, BLayout::RowMajor, -1, -1>;
using DstTile = Tile<TileType::Vec, int8_t, validRows, paddedCols_b8, BLayout::RowMajor, -1, -1>;
using ParaTile = Tile<TileType::Vec, float, validRows, 1, BLayout::ColMajor, -1, -1>;
using TmpTile = Tile<TileType::Vec, float, validRows, BLOCK_BYTE_SIZE / sizeof(float), BLayout::ColMajor, -1, -1>;
SrcTile srcTile(validRows, validCols);
DstTile dstS8Tile(validRows, validCols);
ParaTile scaleTile(validRows, 1);
TmpTile tmpTile(validRows, BLOCK_BYTE_SIZE / sizeof(float));
SrcGlobal srcGlobal(src);
DstGlobal dstGlobal(out_s8);
ParaGlobal scaleGlobal(scale);
TASSIGN(srcTile, 0x0);
TASSIGN(dstS8Tile, 0x0);
TASSIGN(scaleTile, 0x20100);
TASSIGN(tmpTile, 0x28000);
TLOAD(srcTile, srcGlobal);
TLOAD(scaleTile, scaleGlobal);
#ifndef __PTO_AUTO__
set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
#endif
TQUANT<pto::QuantType::INT8_SYM, DstTile, SrcTile, ParaTile>(dstS8Tile, srcTile, scaleTile, tmpTile);
#ifndef __PTO_AUTO__
set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
#endif
TSTORE(dstGlobal, dstS8Tile);
}
template <int validRows, int validCols, int mode>
__global__ AICORE void runTQuantInt8Asym(__gm__ uint8_t __out__ *out_u8, __gm__ float __in__ *src,
__gm__ float __in__ *scale, __gm__ float __in__ *offset)
{
constexpr int paddedCols_b32 = PTO_CEIL(validCols, BLOCK_BYTE_SIZE / sizeof(float));
constexpr int paddedCols_b8 = PTO_CEIL(validCols, BLOCK_BYTE_SIZE / sizeof(uint8_t));
using SrcGlobal = GlobalTensor<float, Shape<1, 1, 1, validRows, validCols>, pto::Stride<1, 1, 1, validCols, 1>>;
using DstGlobal = GlobalTensor<uint8_t, Shape<1, 1, 1, validRows, validCols>, pto::Stride<1, 1, 1, validCols, 1>>;
using ParaGlobal = GlobalTensor<float, Shape<1, 1, 1, validRows, 1>, pto::Stride<1, 1, 1, 1, 1>, pto::Layout::DN>;
using SrcTile = Tile<TileType::Vec, float, validRows, paddedCols_b32, BLayout::RowMajor, -1, -1>;
using DstTile = Tile<TileType::Vec, uint8_t, validRows, paddedCols_b8, BLayout::RowMajor, -1, -1>;
using ParaTile = Tile<TileType::Vec, float, validRows, 1, BLayout::ColMajor, -1, -1>;
using TmpTile = Tile<TileType::Vec, float, validRows, BLOCK_BYTE_SIZE / sizeof(float), BLayout::ColMajor, -1, -1>;
SrcTile srcTile(validRows, validCols);
DstTile dstU8Tile(validRows, validCols);
ParaTile scaleTile(validRows, 1);
ParaTile offsetTile(validRows, 1);
TmpTile tmpTile(validRows, BLOCK_BYTE_SIZE / sizeof(float));
SrcGlobal srcGlobal(src);
DstGlobal dstGlobal(out_u8);
ParaGlobal scaleGlobal(scale);
ParaGlobal offsetGlobal(offset);
TASSIGN(srcTile, 0x0);
TASSIGN(dstU8Tile, 0x20100);
TASSIGN(scaleTile, 0x30100);
TASSIGN(offsetTile, 0x32500);
TASSIGN(tmpTile, 0x34000);
TLOAD(srcTile, srcGlobal);
TLOAD(scaleTile, scaleGlobal);
TLOAD(offsetTile, offsetGlobal);
#ifndef __PTO_AUTO__
set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
#endif
TQUANT<pto::QuantType::INT8_ASYM, DstTile, SrcTile, ParaTile>(dstU8Tile, srcTile, scaleTile, tmpTile, &offsetTile);
#ifndef __PTO_AUTO__
set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
#endif
TSTORE(dstGlobal, dstU8Tile);
}
template <int validRows, int validCols, int mode, pto::QuantScaleAlg scaleAlg = pto::QuantScaleAlg::OCP>
void LaunchTQuantMXFP8(uint8_t *dst, float *src, uint8_t *dst_exp, void *stream)
{
runTQuant<validRows, validCols, mode, scaleAlg><<<1, nullptr, stream>>>(dst_exp, dst, src);
}
template <int validRows, int validCols, int mode, pto::QuantType quantType>
void LaunchTQuantInt8(std::conditional_t<quantType == pto::QuantType::INT8_SYM, int8_t, uint8_t> *dst, float *src,
float *scale, void *stream, float *offset = nullptr)
{
if constexpr (quantType == pto::QuantType::INT8_SYM) {
runTQuantInt8Sym<validRows, validCols, mode><<<1, nullptr, stream>>>(dst, src, scale);
} else {
runTQuantInt8Asym<validRows, validCols, mode><<<1, nullptr, stream>>>(dst, src, scale, offset);
}
}
template <int validRows, int validCols, int mode, pto::QuantScaleAlg scaleAlg = pto::QuantScaleAlg::OCP>
__global__ AICORE void runTQuantBF16(__gm__ uint8_t __out__ *out_e8m0, __gm__ uint8_t __out__ *out_fp8,
__gm__ bfloat16_t __in__ *src)
{
constexpr int paddedCols = PTO_CEIL(validCols, 32);
constexpr int groupedCols_flattened = validRows * (paddedCols / 32);
constexpr int groupedCols_valid = paddedCols / 32;
constexpr int groupedCols_flat_u8 = PTO_CEIL(groupedCols_flattened, 32);
constexpr int groupedCols_flat_b16 = PTO_CEIL(groupedCols_flattened, 16);
using SrcGlobal =
GlobalTensor<bfloat16_t, Shape<1, 1, 1, validRows, validCols>, pto::Stride<1, 1, 1, validCols, 1>>;
using DstE8Global =
GlobalTensor<uint8_t, Shape<1, 1, 1, 1, groupedCols_flattened>, pto::Stride<1, 1, 1, groupedCols_flattened, 1>>;
using DstFP8Global = GlobalTensor<int8_t, Shape<1, 1, 1, validRows, validCols>, pto::Stride<1, 1, 1, validCols, 1>>;
using SrcTile = Tile<TileType::Vec, bfloat16_t, validRows, paddedCols, BLayout::RowMajor, -1, -1, SLayout::NoneBox,
512, PadValue::Zero>;
using DstE8Tile = Tile<TileType::Vec, uint8_t, 1, groupedCols_flat_u8, BLayout::RowMajor, -1, -1, SLayout::NoneBox,
512, PadValue::Zero>;
using DstFP8Tile = Tile<TileType::Vec, int8_t, validRows, paddedCols, BLayout::RowMajor, validRows, paddedCols,
SLayout::NoneBox, 512, PadValue::Zero>;
using MaxTile = Tile<TileType::Vec, bfloat16_t, 1, groupedCols_flat_b16, BLayout::RowMajor, -1, -1>;
using ScalingTile = Tile<TileType::Vec, bfloat16_t, 1, groupedCols_flat_b16, BLayout::RowMajor, -1, -1>;
SrcTile srcTile(validRows, validCols);
ScalingTile scalingTile(1, groupedCols_flattened);
DstFP8Tile fp8Tile;
DstE8Tile e8Tile(1, groupedCols_flattened);
MaxTile maxPerGpTile(1, groupedCols_flattened);
SrcGlobal srcGlobal(src);
DstE8Global e8Global(out_e8m0);
DstFP8Global fp8Global((__gm__ int8_t *)out_fp8);
TASSIGN(srcTile, 0x0);
TASSIGN(maxPerGpTile, 0x8100);
TASSIGN(scalingTile, 0x9200);
TASSIGN(e8Tile, 0xA300);
TASSIGN(fp8Tile, 0xB400);
TLOAD(srcTile, srcGlobal);
if constexpr (mode == 0) {
#ifndef __PTO_AUTO__
set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
#endif
TQUANT<pto::QuantType::MXFP8, DstFP8Tile, SrcTile, DstE8Tile, MaxTile, ScalingTile, scaleAlg>(
fp8Tile, srcTile, &e8Tile, &maxPerGpTile, &scalingTile);
#ifndef __PTO_AUTO__
set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
#endif
TSTORE(e8Global, e8Tile);
TSTORE(fp8Global, fp8Tile);
} else {
constexpr int groupedCols_valid = paddedCols / 32;
using E8ZzTile = Tile<TileType::Vec, uint8_t, validRows, groupedCols_valid, BLayout::RowMajor, -1, -1,
SLayout::RowMajor, 32, PadValue::Zero>;
using E8StoreTile = Tile<TileType::Vec, uint8_t, 1, groupedCols_flattened, BLayout::RowMajor, -1, -1,
SLayout::NoneBox, 512, PadValue::Zero>;
constexpr int tmpBufSize = (16 + (validRows / 16) * (groupedCols_valid / 2) + 16) * sizeof(uint16_t);
constexpr int tmpBufSizeAligned = PTO_CEIL(tmpBufSize, 32);
using TmpTile = Tile<TileType::Vec, uint8_t, 1, tmpBufSizeAligned, BLayout::RowMajor, -1, -1, SLayout::NoneBox,
512, PadValue::Zero>;
constexpr int virtualRow = PTO_CEIL(validRows, FRACTAL_NZ_ROW) + 1;
using DstNZ_int8 = Tile<TileType::Vec, int8_t, virtualRow, paddedCols, BLayout::ColMajor, validRows, paddedCols,
SLayout::RowMajor, 512, PadValue::Null, CompactMode::RowPlusOne>;
DstNZ_int8 fp8TileNZ;
E8ZzTile e8ZzTile(validRows, groupedCols_valid);
E8StoreTile e8StoreTile(1, groupedCols_flattened);
TmpTile tmpTile(1, tmpBufSizeAligned);
TASSIGN(fp8TileNZ, 0x0);
TASSIGN(e8ZzTile, 0x8100);
TASSIGN(e8StoreTile, 0x8100);
TASSIGN(tmpTile, 0x9200);
set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
TQUANT<pto::QuantType::MXFP8, DstFP8Tile, SrcTile, DstE8Tile, MaxTile, ScalingTile, scaleAlg>(
fp8Tile, srcTile, &e8Tile, &maxPerGpTile, &scalingTile);
TMOV(e8ZzTile, e8Tile, tmpTile);
TMOV(fp8TileNZ, fp8Tile);
set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
TSTORE(e8Global, e8StoreTile);
using DstFp8GlobalNZ = GlobalTensor<int8_t, TileShape2D<int8_t, validRows, paddedCols, Layout::NZ>,
BaseShape2D<int8_t, validRows, paddedCols, Layout::NZ>, Layout::NZ>;
DstFp8GlobalNZ fp8GlobalNZ((__gm__ int8_t *)out_fp8);
TSTORE(fp8GlobalNZ, fp8TileNZ);
}
}
template <int validRows, int validCols, int mode, pto::QuantScaleAlg scaleAlg = pto::QuantScaleAlg::OCP>
void LaunchTQuantMXFP8_BF16(uint8_t *dst, uint16_t *src, uint8_t *dst_exp, void *stream)
{
runTQuantBF16<validRows, validCols, mode, scaleAlg><<<1, nullptr, stream>>>(dst_exp, dst, (bfloat16_t *)src);
}
template <int validRows, int validCols, int mode, pto::QuantScaleAlg scaleAlg = pto::QuantScaleAlg::OCP>
__global__ AICORE void runTQuantFP16(__gm__ uint8_t __out__ *out_e8m0, __gm__ uint8_t __out__ *out_fp8,
__gm__ half __in__ *src)
{
constexpr int paddedCols = PTO_CEIL(validCols, 32);
constexpr int groupedCols_flattened = validRows * (paddedCols / 32);
constexpr int groupedCols_valid = paddedCols / 32;
constexpr int groupedCols_flat_u8 = PTO_CEIL(groupedCols_flattened, 32);
constexpr int groupedCols_flat_b16 = PTO_CEIL(groupedCols_flattened, 16);
using SrcGlobal = GlobalTensor<half, Shape<1, 1, 1, validRows, validCols>, pto::Stride<1, 1, 1, validCols, 1>>;
using DstE8Global =
GlobalTensor<uint8_t, Shape<1, 1, 1, 1, groupedCols_flattened>, pto::Stride<1, 1, 1, validCols, 1>>;
using DstFP8Global = GlobalTensor<int8_t, Shape<1, 1, 1, validRows, validCols>, pto::Stride<1, 1, 1, validCols, 1>>;
using SrcTile = Tile<TileType::Vec, half, validRows, paddedCols, BLayout::RowMajor, -1, -1, SLayout::NoneBox, 512,
PadValue::Zero>;
using DstE8Tile = Tile<TileType::Vec, uint8_t, 1, groupedCols_flat_u8, BLayout::RowMajor, -1, -1, SLayout::NoneBox,
512, PadValue::Zero>;
using DstFP8Tile = Tile<TileType::Vec, int8_t, validRows, paddedCols, BLayout::RowMajor, validRows, paddedCols,
SLayout::NoneBox, 512, PadValue::Zero>;
using MaxTile = Tile<TileType::Vec, half, 1, groupedCols_flat_b16, BLayout::RowMajor, -1, -1>;
using ScalingTile = Tile<TileType::Vec, half, 1, groupedCols_flat_b16, BLayout::RowMajor, -1, -1>;
SrcTile srcTile(validRows, validCols);
ScalingTile scalingTile(1, groupedCols_flattened);
DstFP8Tile fp8Tile;
DstE8Tile e8Tile(1, groupedCols_flattened);
MaxTile maxPerGpTile(1, groupedCols_flattened);
SrcGlobal srcGlobal(src);
DstE8Global e8Global(out_e8m0);
DstFP8Global fp8Global((__gm__ int8_t *)out_fp8);
TASSIGN(srcTile, 0x0);
TASSIGN(maxPerGpTile, 0x10100);
TASSIGN(scalingTile, 0x10600);
TASSIGN(e8Tile, 0x20700);
TASSIGN(fp8Tile, 0x20C00);
TLOAD(srcTile, srcGlobal);
if constexpr (mode == 0) {
#ifndef __PTO_AUTO__
set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
#endif
TQUANT<pto::QuantType::MXFP8, DstFP8Tile, SrcTile, DstE8Tile, MaxTile, ScalingTile, scaleAlg>(
fp8Tile, srcTile, &e8Tile, &maxPerGpTile, &scalingTile);
#ifndef __PTO_AUTO__
set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
#endif
TSTORE(e8Global, e8Tile);
TSTORE(fp8Global, fp8Tile);
} else {
constexpr int groupedCols_valid = paddedCols / 32;
using E8ZzTile = Tile<TileType::Vec, uint8_t, validRows, groupedCols_valid, BLayout::RowMajor, -1, -1,
SLayout::RowMajor, 32, PadValue::Zero>;
using E8StoreTile = Tile<TileType::Vec, uint8_t, 1, groupedCols_flattened, BLayout::RowMajor, -1, -1,
SLayout::NoneBox, 512, PadValue::Zero>;
constexpr int tmpBufSize = (16 + (validRows / 16) * (groupedCols_valid / 2) + 16) * sizeof(uint16_t);
constexpr int tmpBufSizeAligned = PTO_CEIL(tmpBufSize, 32);
using TmpTile = Tile<TileType::Vec, uint8_t, 1, tmpBufSizeAligned, BLayout::RowMajor, -1, -1, SLayout::NoneBox,
512, PadValue::Zero>;
constexpr int virtualRow = PTO_CEIL(validRows, FRACTAL_NZ_ROW) + 1;
using DstNZ_int8 = Tile<TileType::Vec, int8_t, virtualRow, paddedCols, BLayout::ColMajor, validRows, paddedCols,
SLayout::RowMajor, 512, PadValue::Null, CompactMode::RowPlusOne>;
DstNZ_int8 fp8TileNZ;
E8ZzTile e8ZzTile(validRows, groupedCols_valid);
E8StoreTile e8StoreTile(1, groupedCols_flattened);
TmpTile tmpTile(1, tmpBufSizeAligned);
TASSIGN(fp8TileNZ, 0x0);
TASSIGN(e8ZzTile, 0x10100);
TASSIGN(e8StoreTile, 0x10100);
TASSIGN(tmpTile, 0x10600);
set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
TQUANT<pto::QuantType::MXFP8, DstFP8Tile, SrcTile, DstE8Tile, MaxTile, ScalingTile, scaleAlg>(
fp8Tile, srcTile, &e8Tile, &maxPerGpTile, &scalingTile);
TMOV(e8ZzTile, e8Tile, tmpTile);
TMOV(fp8TileNZ, fp8Tile);
set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
TSTORE(e8Global, e8StoreTile);
using DstFp8GlobalNZ = GlobalTensor<int8_t, TileShape2D<int8_t, validRows, paddedCols, Layout::NZ>,
BaseShape2D<int8_t, validRows, paddedCols, Layout::NZ>, Layout::NZ>;
DstFp8GlobalNZ fp8GlobalNZ((__gm__ int8_t *)out_fp8);
TSTORE(fp8GlobalNZ, fp8TileNZ);
}
}
template <int validRows, int validCols, int mode, pto::QuantScaleAlg scaleAlg = pto::QuantScaleAlg::OCP>
void LaunchTQuantMXFP8_FP16(uint8_t *dst, uint16_t *src, uint8_t *dst_exp, void *stream)
{
runTQuantFP16<validRows, validCols, mode, scaleAlg><<<1, nullptr, stream>>>(dst_exp, dst, (half *)src);
}
template <typename SrcT, int staticRows, int staticCols, int validRows, int validCols,
pto::QuantScaleAlg scaleAlg = pto::QuantScaleAlg::OCP>
struct MxFp8Exp2DSpec {
static_assert(staticCols % 64 == 0, "static tail axis must be 64-aligned");
static_assert(validCols % 64 == 0, "valid tail axis must be 64-aligned");
static_assert(validRows <= staticRows && validCols <= staticCols, "valid shape must fit in static tile shape");
static constexpr int rows = validRows;
static constexpr int cols = validCols;
static constexpr pto::QuantScaleAlg scaleAlgValue = scaleAlg;
static constexpr int validGroups = validCols / 32;
static constexpr int expCols = PTO_CEIL(staticCols / 32, 32);
static constexpr int packedGroups = validRows * validGroups;
static constexpr int packedGroupsAligned = PTO_CEIL(packedGroups, 16);
static constexpr int ubAlign = TQUANT_A5_UB_ALIGN_BYTES;
static constexpr int srcBytes = staticRows * staticCols * sizeof(SrcT);
static constexpr int maxBytes = packedGroupsAligned * sizeof(SrcT);
static constexpr int scalingBytes = maxBytes;
static constexpr int e8Bytes = staticRows * expCols * sizeof(uint8_t);
static constexpr int fp8Bytes = staticRows * staticCols * sizeof(uint8_t);
static constexpr int maxOffset = PTO_CEIL(srcBytes, ubAlign);
static constexpr int scalingOffset = PTO_CEIL(maxOffset + maxBytes, ubAlign);
static constexpr int e8Offset = PTO_CEIL(scalingOffset + scalingBytes, ubAlign);
static constexpr int fp8Offset = PTO_CEIL(e8Offset + e8Bytes, ubAlign);
static_assert(fp8Offset + fp8Bytes < TQUANT_A5_UB_SIZE_BYTES, "TQuant MXFP8 2D test UB layout exceeds UB size");
using SrcGlobal = GlobalTensor<SrcT, Shape<1, 1, 1, validRows, validCols>, pto::Stride<1, 1, 1, validCols, 1>>;
using DstE8Global =
GlobalTensor<uint8_t, Shape<1, 1, 1, validRows, validGroups>, pto::Stride<1, 1, 1, validGroups, 1>>;
using DstFP8Global = GlobalTensor<int8_t, Shape<1, 1, 1, validRows, validCols>, pto::Stride<1, 1, 1, validCols, 1>>;
using SrcTile = Tile<TileType::Vec, SrcT, staticRows, staticCols, BLayout::RowMajor, -1, -1, SLayout::NoneBox, 512,
PadValue::Zero>;
using DstE8Tile = Tile<TileType::Vec, uint8_t, staticRows, expCols, BLayout::RowMajor, -1, -1, SLayout::NoneBox,
512, PadValue::Zero>;
using DstFP8Tile = Tile<TileType::Vec, int8_t, staticRows, staticCols, BLayout::RowMajor, -1, -1, SLayout::NoneBox,
512, PadValue::Zero>;
using MaxTile = Tile<TileType::Vec, SrcT, 1, packedGroupsAligned, BLayout::RowMajor, -1, -1>;
using ScalingTile = Tile<TileType::Vec, SrcT, 1, packedGroupsAligned, BLayout::RowMajor, -1, -1>;
};
template <typename Spec>
__global__ AICORE void runTQuantExp2D(__gm__ uint8_t __out__ *out_e8m0, __gm__ uint8_t __out__ *out_fp8,
__gm__ typename Spec::SrcGlobal::RawDType __in__ *src)
{
typename Spec::SrcTile srcTile(Spec::rows, Spec::cols);
typename Spec::ScalingTile scalingTile(1, Spec::packedGroups);
typename Spec::DstFP8Tile fp8Tile(Spec::rows, Spec::cols);
typename Spec::DstE8Tile e8Tile(Spec::rows, Spec::validGroups);
typename Spec::MaxTile maxPerGpTile(1, Spec::packedGroups);
typename Spec::SrcGlobal srcGlobal(src);
typename Spec::DstE8Global e8Global(out_e8m0);
typename Spec::DstFP8Global fp8Global((__gm__ int8_t *)out_fp8);
TASSIGN(srcTile, 0x0);
TASSIGN(maxPerGpTile, Spec::maxOffset);
TASSIGN(scalingTile, Spec::scalingOffset);
TASSIGN(e8Tile, Spec::e8Offset);
TASSIGN(fp8Tile, Spec::fp8Offset);
TLOAD(srcTile, srcGlobal);
#ifndef __PTO_AUTO__
set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
#endif
TQUANT<pto::QuantType::MXFP8, typename Spec::DstFP8Tile, typename Spec::SrcTile, typename Spec::DstE8Tile,
typename Spec::MaxTile, typename Spec::ScalingTile, Spec::scaleAlgValue>(fp8Tile, srcTile, &e8Tile,
&maxPerGpTile, &scalingTile);
#ifndef __PTO_AUTO__
set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
#endif
TSTORE(e8Global, e8Tile);
TSTORE(fp8Global, fp8Tile);
}
template <int staticRows, int staticCols, int validRows, int validCols,
pto::QuantScaleAlg scaleAlg = pto::QuantScaleAlg::OCP>
void LaunchTQuantMXFP8_FP32_Exp2D(uint8_t *dst, float *src, uint8_t *dst_exp, void *stream)
{
using Spec = MxFp8Exp2DSpec<float, staticRows, staticCols, validRows, validCols, scaleAlg>;
runTQuantExp2D<Spec><<<1, nullptr, stream>>>(dst_exp, dst, src);
}
template <int staticRows, int staticCols, int validRows, int validCols,
pto::QuantScaleAlg scaleAlg = pto::QuantScaleAlg::OCP>
void LaunchTQuantMXFP8_BF16_Exp2D(uint8_t *dst, uint16_t *src, uint8_t *dst_exp, void *stream)
{
using Spec = MxFp8Exp2DSpec<bfloat16_t, staticRows, staticCols, validRows, validCols, scaleAlg>;
runTQuantExp2D<Spec><<<1, nullptr, stream>>>(dst_exp, dst, (bfloat16_t *)src);
}
template <int staticRows, int staticCols, int validRows, int validCols,
pto::QuantScaleAlg scaleAlg = pto::QuantScaleAlg::OCP>
void LaunchTQuantMXFP8_FP16_Exp2D(uint8_t *dst, uint16_t *src, uint8_t *dst_exp, void *stream)
{
using Spec = MxFp8Exp2DSpec<half, staticRows, staticCols, validRows, validCols, scaleAlg>;
runTQuantExp2D<Spec><<<1, nullptr, stream>>>(dst_exp, dst, (half *)src);
}
PTO_INTERNAL void CompactFp4PackedRows(__ubuf__ uint8_t *dstPtr, __ubuf__ uint8_t *srcPtr, uint32_t rows,
uint32_t validPackedCols, uint32_t srcStride, uint32_t dstStride)
{
constexpr uint32_t elementsPerRepeat = REPEAT_BYTE / sizeof(uint8_t);
RegTensor<uint8_t> vreg;
UnalignReg ureg;
uint16_t repeatTimes = CeilDivision(validPackedCols, elementsPerRepeat);
for (uint16_t row = 0; row < (uint16_t)rows; ++row) {
uint32_t remaining = validPackedCols;
__ubuf__ uint8_t *srcRow = srcPtr + row * srcStride;
__ubuf__ uint8_t *dstRow = dstPtr + row * dstStride;
for (uint16_t repeat = 0; repeat < repeatTimes; ++repeat) {
uint32_t cols = remaining > elementsPerRepeat ? elementsPerRepeat : remaining;
MaskReg preg = CreatePredicate<uint8_t>(cols);
uint32_t offset = repeat * elementsPerRepeat;
vldas(ureg, srcRow + offset);
vldus(vreg, ureg, srcRow + offset);
vsts(vreg, dstRow, offset, NORM_B8, preg);
remaining -= cols;
}
}
}
template <typename SrcT, int validRows, int validCols>
struct MxFp4B16Spec {
static constexpr int rows = validRows;
static constexpr int paddedCols = PTO_CEIL(validCols, 32);
static constexpr int packedCols = paddedCols / 2;
static constexpr int validPackedCols = PTO_DIV_ROUNDUP(validCols, 2);
static constexpr int groupedColsFlattened = validRows * (paddedCols / 32);
static constexpr int groupedColsValid = paddedCols / 32;
static constexpr int groupedColsE8Static = PTO_CEIL(groupedColsValid, 32);
static constexpr int groupedColsB16Static = PTO_CEIL(groupedColsValid, 16);
static constexpr int groupedColsFlatAligned = PTO_CEIL(groupedColsFlattened, 32);
static constexpr int groupedColsB16FlatAligned = PTO_CEIL(groupedColsFlattened, 16);
static constexpr int fp4FlatAligned = PTO_CEIL(validRows * packedCols, 32);
static constexpr int validPackedColsAligned = PTO_CEIL(validPackedCols, 32);
static constexpr int ubAlign = TQUANT_A5_UB_ALIGN_BYTES;
static constexpr int srcBytes = validRows * paddedCols * sizeof(SrcT);
static constexpr int maxBytes = groupedColsB16FlatAligned * sizeof(SrcT);
static constexpr int scalingBytes = maxBytes;
static constexpr int e8Bytes = validRows * groupedColsE8Static * sizeof(uint8_t);
static constexpr int fp4Bytes = validRows * packedCols * sizeof(uint8_t);
static constexpr int fp4StoreBytes = validRows * validPackedColsAligned * sizeof(uint8_t);
static constexpr int maxOffset = PTO_CEIL(srcBytes, ubAlign);
static constexpr int scalingOffset = PTO_CEIL(maxOffset + maxBytes, ubAlign);
static constexpr int e8Offset = PTO_CEIL(scalingOffset + scalingBytes, ubAlign);
static constexpr int fp4Offset = PTO_CEIL(e8Offset + e8Bytes, ubAlign);
static constexpr int fp4StoreOffset = PTO_CEIL(fp4Offset + fp4Bytes, ubAlign);
using SrcGlobal = GlobalTensor<SrcT, Shape<1, 1, 1, validRows, validCols>, pto::Stride<1, 1, 1, validCols, 1>>;
using DstE8Global =
GlobalTensor<uint8_t, Shape<1, 1, 1, validRows, groupedColsValid>, pto::Stride<1, 1, 1, groupedColsValid, 1>>;
using DstFP4Global =
GlobalTensor<uint8_t, Shape<1, 1, 1, validRows, validPackedCols>, pto::Stride<1, 1, 1, validPackedCols, 1>>;
using SrcTile = Tile<TileType::Vec, SrcT, validRows, paddedCols, BLayout::RowMajor, -1, -1, SLayout::NoneBox, 512,
PadValue::Zero>;
using DstE8Tile = Tile<TileType::Vec, uint8_t, validRows, groupedColsE8Static, BLayout::RowMajor, -1, -1,
SLayout::NoneBox, 512, PadValue::Zero>;
using DstFP4Tile = Tile<TileType::Vec, float4_e2m1x2_t, 1, fp4FlatAligned, BLayout::RowMajor, -1, -1,
SLayout::NoneBox, 512, PadValue::Zero>;
using DstBytesTile = Tile<TileType::Vec, uint8_t, validRows, validPackedColsAligned, BLayout::RowMajor, -1, -1,
SLayout::NoneBox, 512, PadValue::Zero>;
using MaxTile = Tile<TileType::Vec, SrcT, 1, groupedColsB16FlatAligned, BLayout::RowMajor, -1, -1>;
using ScalingTile = Tile<TileType::Vec, SrcT, 1, groupedColsB16FlatAligned, BLayout::RowMajor, -1, -1>;
};
template <typename Spec>
PTO_INTERNAL void AssignMxFp4B16Tiles(typename Spec::SrcTile &srcTile, typename Spec::MaxTile &maxPerGpTile,
typename Spec::ScalingTile &scalingTile, typename Spec::DstE8Tile &e8Tile,
typename Spec::DstFP4Tile &fp4Tile, typename Spec::DstBytesTile &fp4BytesTile)
{
TASSIGN(srcTile, 0x0);
TASSIGN(maxPerGpTile, Spec::maxOffset);
TASSIGN(scalingTile, Spec::scalingOffset);
TASSIGN(e8Tile, Spec::e8Offset);
TASSIGN(fp4Tile, Spec::fp4Offset);
TASSIGN(fp4BytesTile, Spec::fp4StoreOffset);
}
template <typename Spec>
PTO_INTERNAL void StoreMxFp4B16Result(typename Spec::DstFP4Tile &fp4Tile, typename Spec::DstBytesTile &fp4BytesTile,
typename Spec::DstE8Tile &e8Tile, typename Spec::DstE8Global &e8Global,
typename Spec::DstFP4Global &fp4Global)
{
__VEC_SCOPE__
{
mem_bar(VST_VLD);
CompactFp4PackedRows((__ubuf__ uint8_t *)fp4BytesTile.data(), (__ubuf__ uint8_t *)fp4Tile.data(), Spec::rows,
Spec::validPackedCols, Spec::packedCols, Spec::validPackedColsAligned);
mem_bar(VST_VST);
}
#ifndef __PTO_AUTO__
set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
#endif
TSTORE(e8Global, e8Tile);
TSTORE(fp4Global, fp4BytesTile);
}
template <typename SrcT, int validRows, int validCols, pto::QuantScaleAlg scaleAlg = pto::QuantScaleAlg::OCP>
__global__ AICORE void runTQuantMXFP4E2M1B16(__gm__ uint8_t __out__ *out_e8m0, __gm__ uint8_t __out__ *out_fp4,
__gm__ SrcT __in__ *src)
{
using Spec = MxFp4B16Spec<SrcT, validRows, validCols>;
static_assert(Spec::fp4StoreOffset + Spec::fp4StoreBytes <= TQUANT_A5_UB_SIZE_BYTES,
"MXFP4 E2M1 test kernel UB layout exceeds UB capacity.");
typename Spec::SrcTile srcTile(validRows, validCols);
typename Spec::DstFP4Tile fp4Tile(1, Spec::fp4Bytes);
typename Spec::DstBytesTile fp4BytesTile(validRows, Spec::validPackedCols);
typename Spec::DstE8Tile e8Tile(validRows, Spec::groupedColsValid);
typename Spec::MaxTile maxPerGpTile(1, Spec::groupedColsFlattened);
typename Spec::ScalingTile scalingTile(1, Spec::groupedColsFlattened);
typename Spec::SrcGlobal srcGlobal(src);
typename Spec::DstE8Global e8Global(out_e8m0);
typename Spec::DstFP4Global fp4Global(out_fp4);
AssignMxFp4B16Tiles<Spec>(srcTile, maxPerGpTile, scalingTile, e8Tile, fp4Tile, fp4BytesTile);
TLOAD(srcTile, srcGlobal);
#ifndef __PTO_AUTO__
set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
#endif
if constexpr (scaleAlg == pto::QuantScaleAlg::OCP) {
TQUANT<pto::QuantType::MXFP4_E2M1, typename Spec::DstFP4Tile, typename Spec::SrcTile, typename Spec::DstE8Tile,
typename Spec::MaxTile>(fp4Tile, srcTile, &e8Tile, &maxPerGpTile, &scalingTile);
} else {
TQUANT<pto::QuantType::MXFP4_E2M1, typename Spec::DstFP4Tile, typename Spec::SrcTile, typename Spec::DstE8Tile,
typename Spec::MaxTile, typename Spec::ScalingTile, scaleAlg>(fp4Tile, srcTile, &e8Tile, &maxPerGpTile,
&scalingTile);
}
StoreMxFp4B16Result<Spec>(fp4Tile, fp4BytesTile, e8Tile, e8Global, fp4Global);
}
template <int validRows, int validCols, pto::QuantScaleAlg scaleAlg = pto::QuantScaleAlg::OCP>
void LaunchTQuantMXFP4_E2M1_FP16(uint8_t *dst, uint16_t *src, uint8_t *dst_exp, void *stream)
{
runTQuantMXFP4E2M1B16<half, validRows, validCols, scaleAlg><<<1, nullptr, stream>>>(dst_exp, dst, (half *)src);
}
template <int validRows, int validCols, pto::QuantScaleAlg scaleAlg = pto::QuantScaleAlg::OCP>
void LaunchTQuantMXFP4_E2M1_BF16(uint8_t *dst, uint16_t *src, uint8_t *dst_exp, void *stream)
{
runTQuantMXFP4E2M1B16<bfloat16_t, validRows, validCols, scaleAlg>
<<<1, nullptr, stream>>>(dst_exp, dst, (bfloat16_t *)src);
}
template <typename SrcT, int staticRows, int staticCols, int validRows, int validCols,
pto::QuantScaleAlg scaleAlg = pto::QuantScaleAlg::OCP>
struct MxFp4B16Exp2DSpec {
static_assert(staticCols % 32 == 0, "static tail axis must be 32-aligned");
static_assert(validCols % 32 == 0, "valid tail axis must be 32-aligned");
static_assert(validRows <= staticRows && validCols <= staticCols, "valid shape must fit in static tile shape");
static constexpr int rows = validRows;
static constexpr int cols = validCols;
static constexpr pto::QuantScaleAlg scaleAlgValue = scaleAlg;
static constexpr int validGroups = validCols / 32;
static constexpr int expCols = PTO_CEIL(staticCols / 32, 32);
static constexpr int packedGroups = validRows * validGroups;
static constexpr int packedGroupsAligned = PTO_CEIL(packedGroups, 16);
static constexpr int staticPackedCols = staticCols / 2;
static constexpr int packedCols = staticPackedCols;
static constexpr int validPackedCols = PTO_DIV_ROUNDUP(validCols, 2);
static constexpr int validPackedColsAligned = PTO_CEIL(validPackedCols, 32);
static constexpr int ubAlign = TQUANT_A5_UB_ALIGN_BYTES;
static constexpr int srcBytes = staticRows * staticCols * sizeof(SrcT);
static constexpr int maxBytes = packedGroupsAligned * sizeof(SrcT);
static constexpr int scalingBytes = maxBytes;
static constexpr int e8Bytes = staticRows * expCols * sizeof(uint8_t);
static constexpr int fp4Bytes = staticRows * staticPackedCols * sizeof(uint8_t);
static constexpr int fp4StoreBytes = validRows * validPackedColsAligned * sizeof(uint8_t);
static constexpr int maxOffset = PTO_CEIL(srcBytes, ubAlign);
static constexpr int scalingOffset = PTO_CEIL(maxOffset + maxBytes, ubAlign);
static constexpr int e8Offset = PTO_CEIL(scalingOffset + scalingBytes, ubAlign);
static constexpr int fp4Offset = PTO_CEIL(e8Offset + e8Bytes, ubAlign);
static constexpr int fp4StoreOffset = PTO_CEIL(fp4Offset + fp4Bytes, ubAlign);
using SrcGlobal = GlobalTensor<SrcT, Shape<1, 1, 1, validRows, validCols>, pto::Stride<1, 1, 1, validCols, 1>>;
using DstE8Global =
GlobalTensor<uint8_t, Shape<1, 1, 1, validRows, validGroups>, pto::Stride<1, 1, 1, validGroups, 1>>;
using DstFP4Global =
GlobalTensor<uint8_t, Shape<1, 1, 1, validRows, validPackedCols>, pto::Stride<1, 1, 1, validPackedCols, 1>>;
using SrcTile = Tile<TileType::Vec, SrcT, staticRows, staticCols, BLayout::RowMajor, -1, -1, SLayout::NoneBox, 512,
PadValue::Zero>;
using DstE8Tile = Tile<TileType::Vec, uint8_t, staticRows, expCols, BLayout::RowMajor, -1, -1, SLayout::NoneBox,
512, PadValue::Zero>;
using DstFP4Tile = Tile<TileType::Vec, float4_e2m1x2_t, 1, fp4Bytes, BLayout::RowMajor, -1, -1, SLayout::NoneBox,
512, PadValue::Zero>;
using DstBytesTile = Tile<TileType::Vec, uint8_t, validRows, validPackedColsAligned, BLayout::RowMajor, -1, -1,
SLayout::NoneBox, 512, PadValue::Zero>;
using MaxTile = Tile<TileType::Vec, SrcT, 1, packedGroupsAligned, BLayout::RowMajor, -1, -1>;
using ScalingTile = Tile<TileType::Vec, SrcT, 1, packedGroupsAligned, BLayout::RowMajor, -1, -1>;
};
template <typename Spec>
__global__ AICORE void runTQuantMXFP4E2M1B16Exp2D(__gm__ uint8_t __out__ *out_e8m0, __gm__ uint8_t __out__ *out_fp4,
__gm__ typename Spec::SrcGlobal::RawDType __in__ *src)
{
static_assert(Spec::fp4StoreOffset + Spec::fp4StoreBytes <= TQUANT_A5_UB_SIZE_BYTES,
"MXFP4 E2M1 2D test kernel UB layout exceeds UB capacity.");
typename Spec::SrcTile srcTile(Spec::rows, Spec::cols);
typename Spec::DstFP4Tile fp4Tile(1, Spec::fp4Bytes);
typename Spec::DstBytesTile fp4BytesTile(Spec::rows, Spec::validPackedCols);
typename Spec::DstE8Tile e8Tile(Spec::rows, Spec::validGroups);
typename Spec::MaxTile maxPerGpTile(1, Spec::packedGroups);
typename Spec::ScalingTile scalingTile(1, Spec::packedGroups);
typename Spec::SrcGlobal srcGlobal(src);
typename Spec::DstE8Global e8Global(out_e8m0);
typename Spec::DstFP4Global fp4Global(out_fp4);
TASSIGN(srcTile, 0x0);
TASSIGN(maxPerGpTile, Spec::maxOffset);
TASSIGN(scalingTile, Spec::scalingOffset);
TASSIGN(e8Tile, Spec::e8Offset);
TASSIGN(fp4Tile, Spec::fp4Offset);
TASSIGN(fp4BytesTile, Spec::fp4StoreOffset);
TLOAD(srcTile, srcGlobal);
#ifndef __PTO_AUTO__
set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
#endif
TQUANT<pto::QuantType::MXFP4_E2M1, typename Spec::DstFP4Tile, typename Spec::SrcTile, typename Spec::DstE8Tile,
typename Spec::MaxTile, typename Spec::ScalingTile, Spec::scaleAlgValue>(fp4Tile, srcTile, &e8Tile,
&maxPerGpTile, &scalingTile);
StoreMxFp4B16Result<Spec>(fp4Tile, fp4BytesTile, e8Tile, e8Global, fp4Global);
}
template <int staticRows, int staticCols, int validRows, int validCols,
pto::QuantScaleAlg scaleAlg = pto::QuantScaleAlg::OCP>
void LaunchTQuantMXFP4_E2M1_BF16_Exp2D(uint8_t *dst, uint16_t *src, uint8_t *dst_exp, void *stream)
{
using Spec = MxFp4B16Exp2DSpec<bfloat16_t, staticRows, staticCols, validRows, validCols, scaleAlg>;
runTQuantMXFP4E2M1B16Exp2D<Spec><<<1, nullptr, stream>>>(dst_exp, dst, (bfloat16_t *)src);
}
template <int staticRows, int staticCols, int validRows, int validCols,
pto::QuantScaleAlg scaleAlg = pto::QuantScaleAlg::OCP>
void LaunchTQuantMXFP4_E2M1_FP16_Exp2D(uint8_t *dst, uint16_t *src, uint8_t *dst_exp, void *stream)
{
using Spec = MxFp4B16Exp2DSpec<half, staticRows, staticCols, validRows, validCols, scaleAlg>;
runTQuantMXFP4E2M1B16Exp2D<Spec><<<1, nullptr, stream>>>(dst_exp, dst, (half *)src);
}
}
template void TQuantTest::LaunchTQuantMXFP8<32, 32, 0>(uint8_t *dst, float *src, uint8_t *dst_exp, void *stream);
template void TQuantTest::LaunchTQuantMXFP8<32, 64, 0>(uint8_t *dst, float *src, uint8_t *dst_exp, void *stream);
template void TQuantTest::LaunchTQuantMXFP8<64, 128, 0>(uint8_t *dst, float *src, uint8_t *dst_exp, void *stream);
template void TQuantTest::LaunchTQuantMXFP8<128, 128, 0>(uint8_t *dst, float *src, uint8_t *dst_exp, void *stream);
template void TQuantTest::LaunchTQuantMXFP8<32, 64, 1>(uint8_t *dst, float *src, uint8_t *dst_exp, void *stream);
template void TQuantTest::LaunchTQuantMXFP8<64, 128, 1>(uint8_t *dst, float *src, uint8_t *dst_exp, void *stream);
template void TQuantTest::LaunchTQuantMXFP8<64, 256, 1>(uint8_t *dst, float *src, uint8_t *dst_exp, void *stream);
template void TQuantTest::LaunchTQuantMXFP8<64, 512, 1>(uint8_t *dst, float *src, uint8_t *dst_exp, void *stream);
template void TQuantTest::LaunchTQuantMXFP8<128, 128, 1>(uint8_t *dst, float *src, uint8_t *dst_exp, void *stream);
template void TQuantTest::LaunchTQuantMXFP8<15, 32, 0>(uint8_t *dst, float *src, uint8_t *dst_exp, void *stream);
template void TQuantTest::LaunchTQuantMXFP8<7, 64, 0>(uint8_t *dst, float *src, uint8_t *dst_exp, void *stream);
template void TQuantTest::LaunchTQuantMXFP8<33, 64, 0>(uint8_t *dst, float *src, uint8_t *dst_exp, void *stream);
template void TQuantTest::LaunchTQuantMXFP8<13, 192, 0>(uint8_t *dst, float *src, uint8_t *dst_exp, void *stream);
template void TQuantTest::LaunchTQuantMXFP8<32, 128, 0, pto::QuantScaleAlg::NV>(uint8_t *dst, float *src,
uint8_t *dst_exp, void *stream);
template void TQuantTest::LaunchTQuantMXFP8<2, 256, 0, pto::QuantScaleAlg::NV>(uint8_t *dst, float *src,
uint8_t *dst_exp, void *stream);
template void TQuantTest::LaunchTQuantInt8<64, 128, 0, pto::QuantType::INT8_SYM>(int8_t *dst, float *src, float *scale,
void *stream, float *offset);
template void TQuantTest::LaunchTQuantInt8<128, 128, 0, pto::QuantType::INT8_SYM>(int8_t *dst, float *src, float *scale,
void *stream, float *offset);
template void TQuantTest::LaunchTQuantInt8<256, 128, 0, pto::QuantType::INT8_SYM>(int8_t *dst, float *src, float *scale,
void *stream, float *offset);
template void TQuantTest::LaunchTQuantInt8<64, 128, 0, pto::QuantType::INT8_ASYM>(uint8_t *dst, float *src,
float *scale, void *stream,
float *offset);
template void TQuantTest::LaunchTQuantInt8<128, 128, 0, pto::QuantType::INT8_ASYM>(uint8_t *dst, float *src,
float *scale, void *stream,
float *offset);
template void TQuantTest::LaunchTQuantInt8<256, 128, 0, pto::QuantType::INT8_ASYM>(uint8_t *dst, float *src,
float *scale, void *stream,
float *offset);
template void TQuantTest::LaunchTQuantInt8<32, 72, 0, pto::QuantType::INT8_ASYM>(uint8_t *dst, float *src, float *scale,
void *stream, float *offset);
template void TQuantTest::LaunchTQuantMXFP8_BF16<32, 128, 0>(uint8_t *dst, uint16_t *src, uint8_t *dst_exp,
void *stream);
template void TQuantTest::LaunchTQuantMXFP8_BF16<64, 128, 0>(uint8_t *dst, uint16_t *src, uint8_t *dst_exp,
void *stream);
template void TQuantTest::LaunchTQuantMXFP8_BF16<128, 128, 0>(uint8_t *dst, uint16_t *src, uint8_t *dst_exp,
void *stream);
template void TQuantTest::LaunchTQuantMXFP8_BF16<14, 16, 0>(uint8_t *dst, uint16_t *src, uint8_t *dst_exp,
void *stream);
template void TQuantTest::LaunchTQuantMXFP8_BF16<7, 48, 0>(uint8_t *dst, uint16_t *src, uint8_t *dst_exp, void *stream);
template void TQuantTest::LaunchTQuantMXFP8_BF16<18, 136, 0>(uint8_t *dst, uint16_t *src, uint8_t *dst_exp,
void *stream);
template void TQuantTest::LaunchTQuantMXFP8_BF16<1, 32, 0>(uint8_t *dst, uint16_t *src, uint8_t *dst_exp, void *stream);
template void TQuantTest::LaunchTQuantMXFP8_BF16<2, 16, 0>(uint8_t *dst, uint16_t *src, uint8_t *dst_exp, void *stream);
template void TQuantTest::LaunchTQuantMXFP8_BF16<4, 16, 0>(uint8_t *dst, uint16_t *src, uint8_t *dst_exp, void *stream);
template void TQuantTest::LaunchTQuantMXFP8_BF16<8, 16, 0>(uint8_t *dst, uint16_t *src, uint8_t *dst_exp, void *stream);
template void TQuantTest::LaunchTQuantMXFP8_BF16<16, 16, 0>(uint8_t *dst, uint16_t *src, uint8_t *dst_exp,
void *stream);
template void TQuantTest::LaunchTQuantMXFP8_BF16<3, 32, 0>(uint8_t *dst, uint16_t *src, uint8_t *dst_exp, void *stream);
template void TQuantTest::LaunchTQuantMXFP8_BF16<5, 96, 0>(uint8_t *dst, uint16_t *src, uint8_t *dst_exp, void *stream);
template void TQuantTest::LaunchTQuantMXFP8_BF16<1, 16, 0>(uint8_t *dst, uint16_t *src, uint8_t *dst_exp, void *stream);
template void TQuantTest::LaunchTQuantMXFP8_BF16<4, 256, 0>(uint8_t *dst, uint16_t *src, uint8_t *dst_exp,
void *stream);
template void TQuantTest::LaunchTQuantMXFP8_BF16<4, 512, 0>(uint8_t *dst, uint16_t *src, uint8_t *dst_exp,
void *stream);
template void TQuantTest::LaunchTQuantMXFP8_BF16<3, 256, 0>(uint8_t *dst, uint16_t *src, uint8_t *dst_exp,
void *stream);
template void TQuantTest::LaunchTQuantMXFP8_BF16<5, 256, 0>(uint8_t *dst, uint16_t *src, uint8_t *dst_exp,
void *stream);
template void TQuantTest::LaunchTQuantMXFP8_BF16<18, 138, 0>(uint8_t *dst, uint16_t *src, uint8_t *dst_exp,
void *stream);
template void TQuantTest::LaunchTQuantMXFP8_BF16<1, 192, 0>(uint8_t *dst, uint16_t *src, uint8_t *dst_exp,
void *stream);
template void TQuantTest::LaunchTQuantMXFP8_BF16<1, 198, 0>(uint8_t *dst, uint16_t *src, uint8_t *dst_exp,
void *stream);
template void TQuantTest::LaunchTQuantMXFP8_BF16<32, 128, 1>(uint8_t *dst, uint16_t *src, uint8_t *dst_exp,
void *stream);
template void TQuantTest::LaunchTQuantMXFP8_BF16<64, 128, 1>(uint8_t *dst, uint16_t *src, uint8_t *dst_exp,
void *stream);
template void TQuantTest::LaunchTQuantMXFP8_BF16<128, 128, 1>(uint8_t *dst, uint16_t *src, uint8_t *dst_exp,
void *stream);
template void TQuantTest::LaunchTQuantMXFP8_BF16<32, 128, 0, pto::QuantScaleAlg::NV>(uint8_t *dst, uint16_t *src,
uint8_t *dst_exp, void *stream);
template void TQuantTest::LaunchTQuantMXFP8_BF16<64, 128, 0, pto::QuantScaleAlg::NV>(uint8_t *dst, uint16_t *src,
uint8_t *dst_exp, void *stream);
template void TQuantTest::LaunchTQuantMXFP8_BF16<128, 128, 0, pto::QuantScaleAlg::NV>(uint8_t *dst, uint16_t *src,
uint8_t *dst_exp, void *stream);
template void TQuantTest::LaunchTQuantMXFP8_BF16<7, 48, 0, pto::QuantScaleAlg::NV>(uint8_t *dst, uint16_t *src,
uint8_t *dst_exp, void *stream);
template void TQuantTest::LaunchTQuantMXFP8_BF16<2, 256, 0, pto::QuantScaleAlg::NV>(uint8_t *dst, uint16_t *src,
uint8_t *dst_exp, void *stream);
template void TQuantTest::LaunchTQuantMXFP8_FP16<32, 128, 0>(uint8_t *dst, uint16_t *src, uint8_t *dst_exp,
void *stream);
template void TQuantTest::LaunchTQuantMXFP8_FP16<64, 128, 0>(uint8_t *dst, uint16_t *src, uint8_t *dst_exp,
void *stream);
template void TQuantTest::LaunchTQuantMXFP8_FP16<128, 128, 0>(uint8_t *dst, uint16_t *src, uint8_t *dst_exp,
void *stream);
template void TQuantTest::LaunchTQuantMXFP8_FP16<4, 256, 0>(uint8_t *dst, uint16_t *src, uint8_t *dst_exp,
void *stream);
template void TQuantTest::LaunchTQuantMXFP8_FP16<11, 640, 0>(uint8_t *dst, uint16_t *src, uint8_t *dst_exp,
void *stream);
template void TQuantTest::LaunchTQuantMXFP8_FP16<32, 128, 0, pto::QuantScaleAlg::NV>(uint8_t *dst, uint16_t *src,
uint8_t *dst_exp, void *stream);
template void TQuantTest::LaunchTQuantMXFP8_FP16<64, 128, 0, pto::QuantScaleAlg::NV>(uint8_t *dst, uint16_t *src,
uint8_t *dst_exp, void *stream);
template void TQuantTest::LaunchTQuantMXFP8_FP16<128, 128, 0, pto::QuantScaleAlg::NV>(uint8_t *dst, uint16_t *src,
uint8_t *dst_exp, void *stream);
template void TQuantTest::LaunchTQuantMXFP8_FP16<2, 256, 0, pto::QuantScaleAlg::NV>(uint8_t *dst, uint16_t *src,
uint8_t *dst_exp, void *stream);
template void TQuantTest::LaunchTQuantMXFP8_BF16_Exp2D<3, 128, 1, 64>(uint8_t *dst, uint16_t *src, uint8_t *dst_exp,
void *stream);
template void TQuantTest::LaunchTQuantMXFP4_E2M1_FP16<2, 128>(uint8_t *dst, uint16_t *src, uint8_t *dst_exp,
void *stream);
template void TQuantTest::LaunchTQuantMXFP4_E2M1_FP16<32, 1024>(uint8_t *dst, uint16_t *src, uint8_t *dst_exp,
void *stream);
template void TQuantTest::LaunchTQuantMXFP4_E2M1_FP16<2, 256, pto::QuantScaleAlg::NV>(uint8_t *dst, uint16_t *src,
uint8_t *dst_exp, void *stream);
template void TQuantTest::LaunchTQuantMXFP4_E2M1_BF16<2, 128>(uint8_t *dst, uint16_t *src, uint8_t *dst_exp,
void *stream);
template void TQuantTest::LaunchTQuantMXFP4_E2M1_BF16<32, 1024>(uint8_t *dst, uint16_t *src, uint8_t *dst_exp,
void *stream);
template void TQuantTest::LaunchTQuantMXFP4_E2M1_BF16<2, 256, pto::QuantScaleAlg::NV>(uint8_t *dst, uint16_t *src,
uint8_t *dst_exp, void *stream);
template void TQuantTest::LaunchTQuantMXFP4_E2M1_BF16_Exp2D<4, 128, 2, 128, pto::QuantScaleAlg::NV>(uint8_t *dst,
uint16_t *src,
uint8_t *dst_exp,
void *stream);
template void TQuantTest::LaunchTQuantMXFP8_FP16<32, 128, 1>(uint8_t *dst, uint16_t *src, uint8_t *dst_exp,
void *stream);
template void TQuantTest::LaunchTQuantMXFP8_FP16<64, 128, 1>(uint8_t *dst, uint16_t *src, uint8_t *dst_exp,
void *stream);
template void TQuantTest::LaunchTQuantMXFP8_FP16<128, 128, 1>(uint8_t *dst, uint16_t *src, uint8_t *dst_exp,
void *stream);