"""
"""
import os
import pypto
import pytest
import torch
import torch_npu
import numpy as np
from numpy.testing import assert_allclose
# Element-wise add of two 2-D int32 tensors, compiled through the pypto JIT
# frontend. The result a + b is written into the caller-provided tensor c
# (via c.move) instead of being returned, so callers must preallocate c.
# All dimensions are declared pypto.STATIC; presumably the kernel is
# specialized to the concrete shapes it is first called with — TODO confirm
# against the pypto frontend docs (the "dyn" in the name does not match this).
# tiling: edge length of the square vector tile (default 32).
@pypto.frontend.jit()
def cust_dyn_func(
    a: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_INT32),
    b: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_INT32),
    c: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_INT32),
    tiling=32
):
    # Vector tile shape is tiling x tiling.
    pypto.set_vec_tile_shapes(tiling, tiling)
    # Single-iteration loop named "s0"; presumably it gives the frontend a loop
    # scope to lower the add into — NOTE(review): confirm it is required.
    for _ in pypto.loop(1, name="s0", idx_name="k"):
        c.move(pypto.add(a, b))
class Network(torch.nn.Module):
    """Module that wraps two cust_dyn_func kernel calls around eager torch ops.

    Used by the NPUGraph capture test; compute_golden is its CPU reference.
    """

    def forward(self, data1, data2, shape, tiling=32):
        """Run add -> kernel -> sub -> add -> kernel and return the last output.

        Args:
            data1, data2: int32 input tensors on the target device.
            shape: shape of the preallocated kernel output buffers.
            tiling: vector tile edge forwarded to cust_dyn_func.

        Returns:
            The output buffer filled by the second kernel call.
        """
        summed = torch.add(data1, data2)
        # cust_dyn_func writes into a preallocated buffer rather than returning one.
        first_out = torch.zeros(shape, dtype=torch.int32, device=data2.device)
        cust_dyn_func(summed, data2, first_out, tiling=tiling)
        staged = torch.sub(first_out, summed)
        # Rebinding drops the sub result here, exactly like the original flow.
        staged = torch.add(staged, summed)
        second_out = torch.zeros(shape, dtype=torch.int32, device=staged.device)
        cust_dyn_func(staged, data1, second_out, tiling=tiling)
        return second_out
def compute_golden(data1, data2):
    """Compute the expected Network output on the host with plain torch ops.

    Follows the same add/sub sequence as Network.forward, with an eager
    torch.add standing in for each cust_dyn_func call (which computes a + b).

    Args:
        data1, data2: input tensors (the test passes two int32 tensors of
            equal shape).

    Returns:
        A new tensor holding the result; the inputs are not modified.
    """
    pair_sum = torch.add(data1, data2)
    # Stands in for the first kernel call.
    staged = torch.add(pair_sum, data2)
    # Remove and re-add pair_sum, mirroring the eager ops in Network.forward.
    staged = torch.add(torch.sub(staged, pair_sum), pair_sum)
    # Stands in for the second kernel call.
    return torch.add(staged, data1)
def test_select_experts():
    """Capture Network in an NPUGraph, replay it, and compare with compute_golden.

    The name is kept for test-selection compatibility. What the test actually
    exercises is the pypto JIT kernel cust_dyn_func running inside a captured
    graph. Requires an NPU; the device index comes from TILE_FWK_DEVICE_ID
    (default 0).
    """
    shape = (256, 256)
    # Fixed seed so a failing comparison can be reproduced.
    rng = np.random.default_rng(0)
    # Floats in [-5, 5) become integers in [-4, 4] when cast to int32 (truncation).
    input0 = torch.from_numpy(rng.uniform(-5, 5, size=shape)).to(torch.int32)
    input1 = torch.from_numpy(rng.uniform(-5, 5, size=shape)).to(torch.int32)
    golden_out = compute_golden(input0, input1)

    torch_npu.npu.set_device(int(os.environ.get('TILE_FWK_DEVICE_ID', 0)))
    input0 = input0.npu()
    input1 = input1.npu()
    npu_mode = Network().npu()

    assert not torch_npu.npu.is_current_stream_capturing()
    s = torch_npu.npu.Stream()
    # Order the side stream after work already queued on the current stream
    # (the host-to-device copies above) before capturing on it.
    s.wait_stream(torch_npu.npu.current_stream())
    with torch_npu.npu.stream(s):
        g = torch_npu.npu.NPUGraph()
        torch_npu.npu.empty_cache()
        assert not torch_npu.npu.is_current_stream_capturing()
        g.capture_begin()
        npu_out = npu_mode(input0, input1, shape)
        # Record the flag and assert only after capture_end, so a failure
        # cannot leave the stream stuck in capture mode.
        was_capturing = torch_npu.npu.is_current_stream_capturing()
        g.capture_end()
    assert was_capturing
    torch_npu.npu.current_stream().wait_stream(s)

    try:
        g.replay()
        torch_npu.npu.current_stream().synchronize()
        # Copy the result to host while the graph is still alive, so the check
        # does not depend on what reset() does to graph-owned memory.
        npu_out = npu_out.cpu().detach().numpy()
    finally:
        # Release the graph even if replay or the copy fails.
        g.reset()

    golden_out = golden_out.cpu().detach().numpy()
    assert_allclose(npu_out, golden_out, rtol=5e-3, atol=5e-3)