import os
import pypto
import numpy as np
import torch
import torch_npu
# Threshold compared against the branch selectors (`k`, `k0`) inside the JIT
# kernels in this file. Both selectors are set to 0 there, so the comparison
# `selector < CONDITION_THRESHOLD` is true and the `if` arm is the one taken.
CONDITION_THRESHOLD = 8
@pypto.frontend.jit()
def cust_hidden_loop_func(t0: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_FP32),
                          t1: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_FP32),
                          out: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_FP32)
                          ):
    """
    Implement hidden-loop logic with a conditional branch.

    Reference C++ case: DynamicBasicTest.HiddenLoopWithIf

    Sequence of operations as written:
      1. ``out`` receives ``t0 + t1``.
      2. ``k`` (0) is compared with ``CONDITION_THRESHOLD``; the ``if`` arm runs
         a two-iteration ``pypto.loop`` that updates ``t0`` with ``t0 + t1`` and
         then ``out`` with ``t0 + out``.

    With the inputs used by the tests in this file (t0 = 11, t1 = 20) this
    gives out = 31, then (t0, out) = (31, 62) and (51, 113).

    NOTE(review): the values above assume ``move()`` writes the result into the
    existing tensor, which is consistent with the golden value 113.0 checked by
    the tests — confirm against the pypto documentation.
    """
    # Set the vector tile shapes to (32, 32) before issuing any tensor ops.
    pypto.set_vec_tile_shapes(32, 32)
    # Branch selector; 0 is below CONDITION_THRESHOLD, so the `if` arm is taken.
    k = 0
    # This op is issued outside any explicit pypto.loop.
    # NOTE(review): presumably this is the "hidden loop" the case name refers
    # to — confirm against the C++ reference case.
    out.move(pypto.add(t0, t1))
    if k < CONDITION_THRESHOLD:
        for _ in pypto.loop(2, name="L02", idx_name="idx3"):
            # Order matters: t0 is updated first, and the updated t0 is then
            # accumulated into out.
            t0.move(pypto.add(t0, t1))
            out.move(pypto.add(t0, out))
    else:
        # Not reached while k is 0.
        # NOTE(review): unlike the `if` arm, this arm rebinds the local names
        # t0/out with `+` instead of calling move() — confirm this difference
        # is intentional.
        for _ in pypto.loop(2, name="L03", idx_name="idx4"):
            t0 = t0 + t1
            out = t0 + out
def test_hidden_loop_with_if_jit_function():
    """Run ``cust_hidden_loop_func`` on a single 32x32 input and verify it.

    The NPU device index is read from the ``TILE_FWK_DEVICE_ID`` environment
    variable (default 0). ``t0`` is filled with 11.0, ``t1`` with 20.0 and the
    output buffer with zeros; after the kernel runs, every output element is
    expected to equal 113.0.

    Raises:
        AssertionError: if the device result does not match the golden value.
    """
    device_id = int(os.environ.get("TILE_FWK_DEVICE_ID", 0))
    torch.npu.set_device(device_id)
    device = f"npu:{device_id}"

    n, m = 1, 1
    s = 32
    shape = (n * s, m * s)

    t0_data = torch.full(shape, 11.0, dtype=torch.float32, device=device)
    t1_data = torch.full(shape, 20.0, dtype=torch.float32, device=device)
    res = torch.zeros(shape, dtype=torch.float32, device=device)

    cust_hidden_loop_func(t0_data, t1_data, res)
    # Wait for the device to finish before copying the result back to the host.
    torch_npu.npu.synchronize()

    res_cpu = res.cpu()
    golden = torch.full(shape, 113.0, dtype=torch.float32)
    assert torch.allclose(golden, res_cpu, atol=1e-5), (
        f"cust_hidden_loop_func mismatch for shape {shape}: expected 113.0, "
        f"got min={res_cpu.min().item()}, max={res_cpu.max().item()}"
    )
def test_hidden_loop_with_if_multiple_shapes():
    """Run ``cust_hidden_loop_func`` over several input shapes and verify each.

    For every case the shape is ``(n * s, m * s)``; ``t0`` is filled with 11.0,
    ``t1`` with 20.0 and the output buffer with zeros. Every output element is
    expected to equal 113.0. All cases are executed before the check, so the
    failure message lists every shape that mismatched.

    Returns:
        bool: ``True`` when every case matches. The value is kept for callers
        that consumed the boolean result; a mismatch now raises instead of
        returning ``False`` so the test can actually fail under pytest.

    Raises:
        AssertionError: if any case does not match the golden value.
    """
    device_id = int(os.environ.get("TILE_FWK_DEVICE_ID", 0))
    torch.npu.set_device(device_id)
    device = f"npu:{device_id}"

    # NOTE(review): the "tiling" entry of each case is not consumed by this
    # test; the call below passes only the three tensors. Confirm whether the
    # kernel was meant to receive it.
    test_cases = [
        {"tiling": 16, "n": 1, "m": 1, "s": 32},
        {"tiling": 8, "n": 2, "m": 2, "s": 16},
        {"tiling": 32, "n": 1, "m": 2, "s": 16},
    ]

    failed_shapes = []
    for config in test_cases:
        n, m, s = config["n"], config["m"], config["s"]
        shape = (n * s, m * s)

        t0_data = torch.full(shape, 11.0, dtype=torch.float32, device=device)
        t1_data = torch.full(shape, 20.0, dtype=torch.float32, device=device)
        out_data = torch.zeros(shape, dtype=torch.float32, device=device)

        cust_hidden_loop_func(t0_data, t1_data, out_data)
        # Wait for the device to finish before copying the result to the host.
        torch_npu.npu.synchronize()

        golden = torch.full(shape, 113.0, dtype=torch.float32)
        if not torch.allclose(golden, out_data.cpu(), atol=1e-5):
            failed_shapes.append(shape)

    assert not failed_shapes, (
        f"cust_hidden_loop_func mismatch (expected 113.0) for shapes: "
        f"{failed_shapes}"
    )
    return True
@pypto.frontend.jit()
def op_hidden_loop_mix_loops(t1: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_FP32),
                             t2: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_FP32),
                             t3: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_FP32),
                             out: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_FP32)
                             ):
    """
    Interleave statements outside any explicit loop with ``pypto.loop`` regions.

    Sequence of operations as written:
      1. ``t0_temp`` is ``t1 + t1`` (the ``if`` arm, since ``k0`` is 0), then
         ``t0_temp + 1.0``.
      2. Loop "L02" multiplies ``t3`` by ``t2`` twice.
      3. ``out`` receives ``t3 - t0_temp``.
      4. Loop "L04" multiplies ``t0_temp`` by ``t2`` twice.
      5. ``out`` receives ``out + t0_temp``.

    With the inputs used by the test in this file (t1 = 20, t2 = 30, t3 = 40):
    t0_temp = 41, t3 = 36000, out = 35959, t0_temp = 36900, out = 72859.

    NOTE(review): the values above assume ``move()`` writes the result into the
    existing tensor, which is consistent with the golden value 72859.0 checked
    by the test — confirm against the pypto documentation.
    """
    # Set the vector tile shapes to (32, 32) before issuing any tensor ops.
    pypto.set_vec_tile_shapes(32, 32)
    # Branch selector; 0 is below CONDITION_THRESHOLD, so the `if` arm is taken.
    k0 = 0
    if k0 < CONDITION_THRESHOLD:
        t0_temp = pypto.add(t1, t1)
    else:
        # Not reached while k0 is 0.
        t0_temp = pypto.add(t2, t2)
    # Applied after the branch, to whichever t0_temp was produced.
    t0_temp.move(pypto.add(t0_temp, 1.0))
    for _ in pypto.loop(2, name="L02", idx_name="k"):
        t3.move(pypto.mul(t3, t2))
    # NOTE(review): source indentation was lost; this statement is placed after
    # loop "L02". Placing it inside the loop yields the same final `out`
    # because each execution overwrites `out` — confirm the intended placement
    # against the C++ reference case.
    out.move(pypto.sub(t3, t0_temp))
    for _ in pypto.loop(2, name="L04", idx_name="h"):
        t0_temp.move(pypto.mul(t0_temp, t2))
    # Runs once, after loop "L04": only this placement reproduces the golden
    # value 72859 checked by the test.
    out.move(pypto.add(out, t0_temp))
def test_hidden_loop_mix_loops_jit_function():
    """Run ``op_hidden_loop_mix_loops`` on a single 32x32 input and verify it.

    The NPU device index is read from the ``TILE_FWK_DEVICE_ID`` environment
    variable (default 0). ``t1``, ``t2`` and ``t3`` are filled with 20.0, 30.0
    and 40.0 respectively and the output buffer with zeros; every output
    element is expected to equal 72859.0.

    Returns:
        bool: ``True`` when the result matches. The value is kept for callers
        that consumed the boolean result; a mismatch now raises instead of
        returning ``False`` so the test can actually fail under pytest.

    Raises:
        AssertionError: if the device result does not match the golden value.
    """
    device_id = int(os.environ.get("TILE_FWK_DEVICE_ID", 0))
    torch.npu.set_device(device_id)
    device = f"npu:{device_id}"

    n, m = 1, 1
    s = 32
    shape = (n * s, m * s)

    t1_data = torch.full(shape, 20.0, dtype=torch.float32, device=device)
    t2_data = torch.full(shape, 30.0, dtype=torch.float32, device=device)
    t3_data = torch.full(shape, 40.0, dtype=torch.float32, device=device)
    out_data = torch.zeros(shape, dtype=torch.float32, device=device)

    op_hidden_loop_mix_loops(t1_data, t2_data, t3_data, out_data)
    # Wait for the device to finish before copying the result back to the host.
    torch_npu.npu.synchronize()

    out_cpu = out_data.cpu()
    golden = torch.full(shape, 72859.0, dtype=torch.float32)
    assert torch.allclose(golden, out_cpu, atol=1e-5), (
        f"op_hidden_loop_mix_loops mismatch for shape {shape}: expected "
        f"72859.0, got min={out_cpu.min().item()}, max={out_cpu.max().item()}"
    )
    return True