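"""
Smoke test for a PyPTO JIT kernel running on an NPU device.

The kernel casts a bfloat16 tensor to float32 and squares it element-wise;
the test compares the device result against a host-side reference computed
with torch.
"""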
import os
import logging
import torch
import torch_npu
import pypto
def test_attention_residuals_pypto():
    """Run ``attention_residuals`` on an NPU and check it against a host reference.

    The device index is read from the ``TILE_FWK_DEVICE_ID`` environment
    variable (0 when unset). A random bfloat16 tensor of shape (4, 8, 512) is
    pushed through the kernel, and the float32 output is compared with the
    element-wise square of the input, computed on the CPU in float32.

    Raises:
        AssertionError: if any output element differs from the reference by
            more than rtol=3e-3 / atol=3e-3.
    """
    npu_index = int(os.environ.get('TILE_FWK_DEVICE_ID', 0))
    torch_npu.npu.set_device(npu_index)

    shape = (4, 8, 512)
    k = torch.randn(shape, dtype=torch.bfloat16).npu()
    k_out = torch.zeros(shape, dtype=torch.float32).npu()

    # Reference result: up-cast the bfloat16 input to float32 and square it on the host.
    k_golden = k.to(torch.float32).cpu() ** 2

    # Register the reference with PyPTO's golden-data verification. The list
    # presumably lines up with the kernel arguments (no golden for the input,
    # k_golden for the output) -- TODO confirm against the PyPTO docs.
    pypto.set_verify_golden_data(goldens=[None, k_golden])

    attention_residuals(k, k_out)

    k_out_cpu = k_out.cpu()
    tolerance = {"rtol": 3e-3, "atol": 3e-3}

    # Log the number of mismatching elements so a failure is easier to size up.
    not_close = (~torch.isclose(k_out_cpu, k_golden, **tolerance)).sum()
    logging.info(f"不相等数量:{not_close} / {k_out_cpu.numel()}")

    assert torch.allclose(k_out_cpu, k_golden, **tolerance)
@pypto.frontend.jit(debug_options={"compile_debug_mode": 0})
def attention_residuals(
        k_in: pypto.Tensor([pypto.DYNAMIC, pypto.DYNAMIC, pypto.STATIC]),
        k_out: pypto.Tensor([pypto.DYNAMIC, pypto.DYNAMIC, pypto.STATIC])):
    # PyPTO kernel: writes the element-wise square of k_in, computed in FP32,
    # into k_out.
    #
    # k_in  -- 3-D input tensor; the first two axes are declared DYNAMIC and
    #          the last one STATIC.
    # k_out -- 3-D output tensor with the same axis declaration; it is filled
    #          in place via pypto.assemble and nothing is returned.
    #
    # Documented with comments rather than a docstring so that the statement
    # list seen by the JIT frontend stays the same as before.
    n_t, n_l, n_d = k_in.shape

    # Fixed size used for the second axis of every view/reshape below; the real
    # extent n_l is passed separately as valid_shape. Presumably n_l must not
    # exceed 32 -- TODO confirm.
    padded_l = 32

    # Walk the first axis. Each iteration yields a start index and the number
    # of rows it covers (presumably 4 at a time, per unroll_list -- TODO
    # confirm how PyPTO handles extents that are not a multiple of 4).
    for t_start, t_count in pypto.loop_unroll(
            0, n_t, 1, name="Loop_t", idx_name="tIdx", unroll_list=[4]):
        # Window of [t_count, padded_l, n_d] at offset [t_start, 0, 0], of
        # which only [t_count, n_l, n_d] is marked valid.
        src_3d = pypto.view(
            k_in,
            [t_count, padded_l, n_d],
            [t_start, 0, 0],
            valid_shape=[t_count, n_l, n_d],
        )

        # The tile-shape calls are kept immediately before the ops they
        # preceded originally: 3-D tiles around the reshapes, 2-D tiles around
        # the cast and multiply.
        pypto.set_vec_tile_shapes(8, 16, 128)
        src_2d = pypto.reshape(
            src_3d,
            [t_count * padded_l, n_d],
            valid_shape=[t_count * n_l, n_d],
        )

        pypto.set_vec_tile_shapes(128, 128)
        src_fp32 = pypto.cast(src_2d, pypto.DT_FP32)
        squared_2d = pypto.mul(src_fp32, src_fp32)

        pypto.set_vec_tile_shapes(8, 16, 128)
        squared_3d = pypto.reshape(
            squared_2d,
            [t_count, padded_l, n_d],
            valid_shape=[t_count, n_l, n_d],
        )

        # Write the squared block into k_out at the offset it was read from.
        pypto.assemble(squared_3d, [t_start, 0, 0], k_out)
# Allow running this test directly as a script, without a test runner.
if __name__ == "__main__":
    test_attention_residuals_pypto()