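"""
Verify how a packed FP4x2 uint8 input is exposed inside a kernel.

When the kernel argument is annotated as DT_FP4_E2M1, the tensor must report
that dtype and a logical last dimension of PACKED_K * 2. When it is annotated
as DT_FP4_E2M1X2, the tensor must report that dtype and keep the packed last
dimension PACKED_K.
"""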
import os
import pypto
import torch
# Number of rows in the packed uint8 test input.
M = 64
# Last-dimension size of the packed uint8 test input
# (FP4x2: presumably two FP4 values per byte -- confirm against pypto docs).
PACKED_K = 32
# Last-dimension size when every packed element counts as two FP4 values.
LOGICAL_K = PACKED_K * 2
# Kernel whose input is annotated as DT_FP4_E2M1. The test below passes it a
# packed uint8 tensor of shape (M, PACKED_K); the asserts in the body check
# what dtype and shape the kernel sees for that argument.
# run_mode is RunMode.SIM and compile_stage is CompStage.TENSOR_GRAPH
# (presumably: simulated run, compilation stopped at the tensor-graph stage
# -- TODO confirm against the pypto documentation).
@pypto.frontend.jit(
    runtime_options={"run_mode": pypto.RunMode.SIM},
    host_options={"compile_stage": pypto.CompStage.TENSOR_GRAPH})
def fp4_kernel(
    x: pypto.Tensor([], pypto.DT_FP4_E2M1),
    out: pypto.Tensor([], pypto.DT_INT32),
):
    # The argument must carry the annotated (unpacked) FP4 dtype.
    assert x.dtype == pypto.DT_FP4_E2M1
    # The last dimension must be twice the packed uint8 width.
    assert x.shape[-1] == PACKED_K * 2
    pypto.set_vec_tile_shapes(32, 32)
    # Small operation on `out` so the kernel has some work to compile;
    # presumably stores `out + 1` back into `out` -- confirm `move` semantics.
    out.move(out + 1)
def test_fp4x2_to_fp4_logical_shape():
    """Invoke ``fp4_kernel`` with an all-zero packed uint8 tensor.

    The input has shape ``(M, PACKED_K)``. The checks themselves are the
    asserts inside ``fp4_kernel``, so the test passes if the call completes
    without raising.
    """
    result = torch.zeros(1, dtype=torch.int32)
    packed_input = torch.zeros(M, PACKED_K, dtype=torch.uint8)
    fp4_kernel(packed_input, result)
# Kernel whose input is annotated as DT_FP4_E2M1X2. The test below passes it a
# packed uint8 tensor of shape (M, PACKED_K); the asserts in the body check
# that the kernel sees the packed dtype and the packed last dimension.
# run_mode is RunMode.SIM and compile_stage is CompStage.TENSOR_GRAPH
# (presumably: simulated run, compilation stopped at the tensor-graph stage
# -- TODO confirm against the pypto documentation).
@pypto.frontend.jit(
    runtime_options={"run_mode": pypto.RunMode.SIM},
    host_options={"compile_stage": pypto.CompStage.TENSOR_GRAPH})
def fp4x2_kernel(
    x: pypto.Tensor([], pypto.DT_FP4_E2M1X2),
    out: pypto.Tensor([], pypto.DT_INT32),
):
    # The argument must carry the annotated packed FP4x2 dtype.
    assert x.dtype == pypto.DT_FP4_E2M1X2
    # The last dimension must equal the packed uint8 width, with no doubling.
    assert x.shape[-1] == PACKED_K
    pypto.set_vec_tile_shapes(32, 32)
    # Small operation on `out` so the kernel has some work to compile;
    # presumably stores `out + 1` back into `out` -- confirm `move` semantics.
    out.move(out + 1)
def test_fp4x2_to_fp4x2_logical_shape():
    """Invoke ``fp4x2_kernel`` with an all-zero packed uint8 tensor.

    The input has shape ``(M, PACKED_K)``. The checks themselves are the
    asserts inside ``fp4x2_kernel``, so the test passes if the call completes
    without raising.
    """
    result = torch.zeros(1, dtype=torch.int32)
    packed_input = torch.zeros(M, PACKED_K, dtype=torch.uint8)
    fp4x2_kernel(packed_input, result)
if __name__ == "__main__":
    # Script entry point: run each check in declaration order.
    for check in (
        test_fp4x2_to_fp4_logical_shape,
        test_fp4x2_to_fp4x2_logical_shape,
    ):
        check()