"""
"""
from typing import List
import os
import pytest
import pypto
import torch
import torch_npu
from pypto import Tensor as PTensor, loop, ceildiv, SymInt
from pypto.frontend import jit, dynamic
# Tiled element-wise ceiling division kernel: writes pypto.ceil_div(x, y)
# into `out`, one window of `view_shape` at a time.
#
# Args:
#   x, y:       2-D int32 input tensors with dynamic dimensions.
#   out:        2-D int32 tensor with dynamic dimensions that receives the
#               result through pypto.assemble (the function returns nothing).
#   view_shape: two sizes [rows, cols] of the window taken from x and y on
#               each loop iteration.
#   tile_shape: values forwarded unchanged to pypto.set_vec_tile_shapes.
#
# NOTE(review): how pypto.view / pypto.assemble treat a trailing window that
# extends past the tensor edge (when a dimension is not a multiple of the
# view size) is not visible in this file -- confirm against the pypto docs.
@jit
def ceil_div_2d(
    x: PTensor([pypto.DYNAMIC, pypto.DYNAMIC], pypto.DT_INT32),
    y: PTensor([pypto.DYNAMIC, pypto.DYNAMIC], pypto.DT_INT32),
    out: PTensor([pypto.DYNAMIC, pypto.DYNAMIC], pypto.DT_INT32),
    view_shape: List[SymInt],
    tile_shape: List[int],
):
    # Both dimensions are dynamic, so b and s are only known at run time.
    b, s = x.shape
    pypto.set_vec_tile_shapes(*tile_shape)
    # ceildiv rounds the window count up so that every row/column of x is
    # covered by some (i, j) window.
    for i in loop(ceildiv(b, view_shape[0])):
        for j in loop(ceildiv(s, view_shape[1])):
            # Window (i, j) starts at element offset
            # [i * view_shape[0], j * view_shape[1]] in x, y and out alike.
            tile_x = pypto.view(x, view_shape, [i * view_shape[0], j * view_shape[1]])
            tile_y = pypto.view(y, view_shape, [i * view_shape[0], j * view_shape[1]])
            result = pypto.ceil_div(tile_x, tile_y)
            # Place the computed window into `out` at the same offset.
            pypto.assemble(result, [i * view_shape[0], j * view_shape[1]], out)
            # Drop the references to the input views once the window is done.
            # NOTE(review): whether the JIT frontend needs this is not shown here.
            del tile_x, tile_y
def test_ceil_div():
    """Check ``ceil_div_2d`` against a torch reference on random int32 data.

    Two random 32x128 int32 tensors are generated on the host, copied to the
    NPU selected by the ``TILE_FWK_DEVICE_ID`` environment variable (device 0
    when unset), and passed through the kernel. The result is compared with
    ``ceil(x / y)`` computed by torch on the host copies.
    """
    shape = (32, 128)
    view_shape = [32, 128]
    tile_shape = [32, 32]

    dividend = torch.randint(0, 100, shape, dtype=torch.int32)
    # Lower bound of 1 keeps zeros out of the divisor.
    divisor = torch.randint(1, 100, shape, dtype=torch.int32)

    device_id = int(os.environ.get("TILE_FWK_DEVICE_ID", 0))
    torch_npu.npu.set_device(device_id)
    out = torch.zeros(shape, dtype=torch.int32, device=f"npu:{device_id}")

    # `out` is already allocated on the NPU; the check below reads `out`
    # itself, so it relies on `out.npu()` handing the kernel that same tensor.
    # NOTE(review): presumably a no-op for a tensor on the current device.
    ceil_div_2d(dividend.npu(), divisor.npu(), out.npu(), view_shape, tile_shape)

    assert out.shape == shape

    expected = torch.div(dividend, divisor).ceil().to(torch.int32)
    assert torch.allclose(expected, out.cpu())
# Allow running this file directly as a script, without going through pytest.
if __name__ == "__main__":
    test_ceil_div()