"""
Reshape Add Operator for PyPTO
Process: reshape 2D -> 1D -> add 0.5 -> reshape 1D -> 2D
Loop tile block: [64, 64], Vec tile shape: [64, 64]
"""
import logging
import os
import pypto
import torch
import torch_npu
from numpy.testing import assert_allclose
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)
def get_device_id() -> int:
env_val = os.environ.get("TILE_FWK_DEVICE_ID", "0")
try:
return int(env_val)
except ValueError:
logger.error("TILE_FWK_DEVICE_ID must be an integer, got: %s", env_val)
raise
@pypto.frontend.jit(debug_options={"runtime_debug_mode": 0, "compile_debug_mode": 0})
def reshape_add_kernel(
x: pypto.Tensor([pypto.DYNAMIC, pypto.DYNAMIC], pypto.DT_FP32),
y: pypto.Tensor([pypto.DYNAMIC, pypto.DYNAMIC], pypto.DT_FP32),
tile_m: int,
tile_n: int,
m_size: int,
n_size: int):
m_dyn = x.shape[0]
n_dyn = x.shape[1]
pypto.set_vec_tile_shapes(tile_m, tile_n)
m_loop = (m_dyn + tile_m - 1) // tile_m
n_loop = (n_dyn + tile_n - 1) // tile_n
for m_idx in pypto.loop(m_loop, name="LOOP_M", idx_name="m_idx"):
m_offset = m_idx * tile_m
m_offset_end = pypto.min(m_offset + tile_m, m_dyn)
valid_m = m_offset_end - m_offset
for n_idx in pypto.loop(n_loop, name="LOOP_N", idx_name="n_idx"):
n_offset = n_idx * tile_n
n_offset_end = pypto.min(n_offset + tile_n, n_dyn)
valid_n = n_offset_end - n_offset
x_view = pypto.view(x, [tile_m, tile_n], [m_offset, n_offset], valid_shape=[valid_m, valid_n])
x_1d = pypto.reshape(x_view, [tile_m * tile_n], valid_shape=[valid_m * valid_n])
pypto.set_vec_tile_shapes(tile_m * tile_n)
result_1d = pypto.add(x_1d, 0.5)
result = pypto.reshape(result_1d, [tile_m, tile_n], valid_shape=[valid_m, valid_n])
pypto.set_vec_tile_shapes(tile_m, tile_n)
pypto.assemble(result, [m_offset, n_offset], y)
TEST_SHAPES = [
(100, 100),
(128, 100),
(100, 128),
(128, 128),
(64, 64),
]
def main():
device_id = get_device_id()
torch.npu.set_device(device_id)
logger.info("Running on NPU device %d", device_id)
tile_m, tile_n = 64, 64
for m, n in TEST_SHAPES:
x = torch.randn(m, n, dtype=torch.float32, device=f"npu:{device_id}")
y = torch.zeros(m, n, dtype=torch.float32, device=f"npu:{device_id}")
reshape_add_kernel(x, y, tile_m, tile_n, m, n)
torch.npu.synchronize()
golden = x + 0.5
max_diff = (y - golden).abs().max().item()
logger.info("shape [%d, %d] max_diff=%.6f", m, n, max_diff)
assert_allclose(y.cpu().numpy(), golden.cpu().numpy(), rtol=1e-3, atol=1e-3)
logger.info("All tests passed")
if __name__ == "__main__":
main()