"""Device test for ``pypto.isfinite`` applied block-by-block to a 2-D FP16 tensor."""
from typing import List
import os
import pytest
import pypto
import torch
import torch_npu
from pypto import Tensor as PTensor, loop, ceildiv, SymInt
from pypto.frontend import jit, dynamic
# Block-wise isfinite kernel, wrapped by the pypto `jit` frontend decorator.
#
# Parameters:
#   x          -- input tensor, annotated as 2-D FP16 with both dims dynamic.
#   out        -- output tensor, annotated as 2-D BOOL with both dims dynamic;
#                 it is handed to pypto.assemble as the destination and the
#                 function itself returns nothing.
#   view_shape -- [rows, cols] of the block handled per loop iteration; also
#                 used as the stride between consecutive block offsets.
#   tile_shape -- values unpacked into pypto.set_vec_tile_shapes.
#
# NOTE(review): every block is requested at the full view_shape, including the
# last step along each axis. Presumably callers pass shapes that divide evenly
# or pypto handles the ragged edge itself -- confirm.
@jit
def isfinite_2d(
    x: PTensor([pypto.DYNAMIC, pypto.DYNAMIC], pypto.DT_FP16),
    out: PTensor([pypto.DYNAMIC, pypto.DYNAMIC], pypto.DT_BOOL),
    view_shape: List[SymInt],
    tile_shape: List[int],
):
    # Row and column extents of the input.
    b, s = x.shape
    # Presumably configures the vector-op tiling for the ops below -- see the
    # pypto documentation for the exact semantics.
    pypto.set_vec_tile_shapes(*tile_shape)
    # Visit ceil(b / view_shape[0]) x ceil(s / view_shape[1]) blocks.
    for i in loop(ceildiv(b, view_shape[0])):
        for j in loop(ceildiv(s, view_shape[1])):
            # Block of `view_shape` whose top-left corner is at
            # (i * view_shape[0], j * view_shape[1]) in `x`.
            tile = pypto.view(x, view_shape, [i * view_shape[0], j * view_shape[1]])
            result = pypto.isfinite(tile)
            # Place the per-block result into `out` at the same offsets.
            pypto.assemble(result, [i * view_shape[0], j * view_shape[1]], out)
            # NOTE(review): the original indentation of this statement was
            # lost; it is kept in the inner loop next to the binding it
            # removes. The reason for the explicit `del` is not visible
            # here -- confirm.
            del tile
def test_is_finite():
    """Check ``isfinite_2d`` on the NPU against ``torch.isfinite`` on the CPU.

    Builds a 32x128 FP16 tensor of uniform random values, overwrites randomly
    chosen elements with NaN, +inf and -inf, runs the kernel on the device
    selected by the ``TILE_FWK_DEVICE_ID`` environment variable (default 0)
    and requires the resulting boolean mask to match the reference exactly.
    """
    rows, cols = 32, 128
    view_shape = [rows, cols]
    tile_shape = [32, 32]
    specials_per_kind = 30

    # A private, seeded generator keeps the test data reproducible from run
    # to run without touching the global RNG state other tests may rely on.
    gen = torch.Generator().manual_seed(0)
    x_pt = torch.rand(rows, cols, dtype=torch.float16, generator=gen)

    # Scatter non-finite values through a flat view that shares storage with
    # x_pt. Indices may repeat or collide between kinds; later writes win.
    flat = x_pt.view(-1)
    for special in (torch.nan, torch.inf, -torch.inf):
        ids = torch.randint(rows * cols, (specials_per_kind,), generator=gen)
        flat[ids] = special

    device_id = int(os.environ.get("TILE_FWK_DEVICE_ID", 0))
    torch_npu.npu.set_device(device_id)
    x = x_pt.npu()

    golden = torch.isfinite(x_pt)
    out = torch.zeros((rows, cols), dtype=torch.bool, device=f"npu:{device_id}")
    isfinite_2d(x, out, view_shape, tile_shape)

    # Boolean masks must match exactly, so use torch.equal rather than the
    # tolerance-based torch.allclose, which is meant for floating-point data.
    actual = out.cpu()
    assert torch.equal(golden, actual), (
        f"isfinite mismatch in {(golden != actual).sum().item()} "
        f"of {golden.numel()} elements"
    )
# Allow running this test directly as a script, without the pytest runner.
if __name__ == "__main__":
    test_is_finite()