"""
"""
import os
import pypto
import pytest
import torch
import numpy as np
from numpy.testing import assert_allclose
import torch_npu
def test_vector_operation_round():
    """Check ``pypto.round`` against ``torch.round`` on a 32x32 FP32 tensor.

    The input is processed in 16x16 views with 8x8 vector tiles, rounded to
    two decimals, assembled into the output tensor, and compared with the
    torch reference using ``rtol=1e-3`` / ``atol=1e-3``.

    The NPU device index is read from the ``TILE_FWK_DEVICE_ID`` environment
    variable and defaults to 0.
    """
    device_id = int(os.environ.get('TILE_FWK_DEVICE_ID', 0))
    torch.npu.set_device(device_id)

    dtype = pypto.DT_FP32
    tiling = 32
    n, m = tiling * 1, tiling * 1
    shape = (n, m)
    view_shape = (16, 16)
    tile_shape = (8, 8)
    decimals = 2

    # Number of views needed to cover each dimension (ceil division).
    row_views = int(np.ceil(n / view_shape[0]))
    col_views = int(np.ceil(m / view_shape[1]))

    pypto.runtime._device_init()
    try:
        a = pypto.tensor(shape, dtype, "ROUND_TENSOR_a")
        b = pypto.tensor(shape, dtype, "ROUND_TENSOR_b")

        with pypto.function("ROUND", a, b):
            for b_idx in pypto.loop(row_views, name="LOOP_ROUND_L0", idx_name="b_idx"):
                for s_idx in pypto.loop(col_views, name="LOOP_ROUND_L1", idx_name="s_idx"):
                    # valid_shape is min(remaining extent, view extent) per
                    # dimension, so a trailing partial view would be clamped.
                    tile_a = pypto.view(
                        a,
                        view_shape,
                        [b_idx * view_shape[0], s_idx * view_shape[1]],
                        valid_shape=[
                            (pypto.symbolic_scalar(n) - b_idx * view_shape[0]).min(
                                pypto.symbolic_scalar(view_shape[0])),
                            (pypto.symbolic_scalar(m) - s_idx * view_shape[1]).min(
                                pypto.symbolic_scalar(view_shape[1])),
                        ],
                    )
                    pypto.set_vec_tile_shapes(tile_shape[0], tile_shape[1])
                    # NOTE(review): presumably move() replaces the view's
                    # contents with the rounded values - confirm in pypto docs.
                    tile_a.move(pypto.round(tile_a, decimals=decimals))
                    pypto.assemble(tile_a, [b_idx * view_shape[0], s_idx * view_shape[1]], b)

        # Random input: torch.rand is in [0, 1), so values lie in
        # [-6.15, 6.15) and exercise both signs.
        a_tensor = (torch.rand(n, m, dtype=torch.float32) * 100 - 50) * 0.123
        b_tensor = torch.zeros(n, m, dtype=torch.float32)
        pto_a_tensor = pypto.from_torch(a_tensor, "a_tensor")
        pto_b_tensor = pypto.from_torch(b_tensor, "b_tensor")
        # NOTE(review): the assertion below reads b_tensor, so the run is
        # assumed to write its result back into that host tensor - confirm.
        pypto.runtime._device_run_once_data_from_host(pto_a_tensor, pto_b_tensor)

        expected = torch.round(a_tensor, decimals=decimals)
        assert_allclose(b_tensor.flatten(), expected.flatten(), rtol=1e-3, atol=1e-3)
    finally:
        # Always finalize the device, even when graph construction, the run,
        # or the comparison raises; otherwise a failure skips finalization.
        pypto.runtime._device_fini()