"""
"""
import abc
import os
import logging
from collections.abc import Iterable
import inspect
import torch
import torch_npu
import numpy as np
from numpy.testing import assert_allclose
import pypto
class TestBuilder(abc.ABC):
    """Run a PyPTO kernel and compare its outputs with a golden function.

    A concrete subclass implements :meth:`get_input_from_param`, which returns
    the input tensors for the test.  Calling the builder instance then:

    1. evaluates ``kernel_golden(params, *inputs, None, ...)`` to obtain the
       expected outputs,
    2. allocates matching output tensors,
    3. runs ``kernel`` either through ``pypto.function`` (default) or through
       the JIT-style direct call (``jit=True``),
    4. compares every output against its golden with ``assert_allclose``.
    """

    # pypto dtype constant name -> accepted string spellings.  The constant is
    # looked up with ``getattr(pypto, name)`` only when a spelling matches, so
    # a constant that is absent from the installed pypto build only fails when
    # that particular dtype is actually requested.
    _DTYPE_ALIASES = (
        ("DT_INT4", ("int4",)),
        ("DT_INT8", ("int8", "torch.int8", "np.int8")),
        ("DT_INT16", ("int16", "torch.int16", "np.int16")),
        ("DT_INT32", ("int32", "torch.int32", "np.int32")),
        ("DT_INT64", ("int64", "torch.int64", "np.int64")),
        ("DT_FP8", ("float8", "torch.float8", "np.float8")),
        ("DT_FP16", ("float16", "half", "torch.float16", "np.float16")),
        ("DT_FP32", ("float32", "torch.float32", "np.float32")),
        ("DT_BF16", ("bfloat16", "torch.bfloat16")),
        ("DT_UINT8", ("uint8", "torch.uint8", "np.uint8")),
        ("DT_UINT16", ("uint16", "torch.uint16", "np.uint16")),
        ("DT_UINT32", ("uint32", "torch.uint32", "np.uint32")),
        ("DT_UINT64", ("uint64", "torch.uint64", "np.uint64")),
        ("DT_BOOL", ("bool", "torch.bool", "np.bool_")),
        # ``str(torch.float64)`` is 'torch.float64' and ``str(np.dtype('float64'))``
        # is 'float64', so those spellings are needed for real double tensors.
        ("DT_DOUBLE", ("torch.double", "np.double", "double",
                       "float64", "torch.float64", "np.float64")),
    )

    def __init__(self, params: tuple, kernel, kernel_golden, tiling: int):
        """Store the test configuration and create empty tensor registries.

        Args:
            params: Opaque parameter object forwarded to both ``kernel`` and
                ``kernel_golden``.
            kernel: Kernel under test.
            kernel_golden: Reference implementation, called as
                ``kernel_golden(params, *inputs, *[None] * output_count)``.
            tiling: Value passed twice to ``pypto.set_vec_tile_shapes``.
        """
        super().__init__()
        self.params = params
        self.kernel = kernel
        self.kernel_golden = kernel_golden
        self.tiling = tiling

        # Tensor registries filled by setup_inputs* / init_output*.
        self.input_pto_list = []
        self.output_pto_list = []
        self.input_data_list = []
        self.output_data_list = []
        self.tensor_list = []

        # Optional per-tensor dynamic-axis descriptions (JIT path only).
        self.input_dyn_axes = []
        self.output_dyn_axes = []

        # Comparison tolerances handed to assert_allclose.
        self.rtol_value = 1e-3
        self.atol_value = 1e-3

        self.device_id = int(os.environ.get('TILE_FWK_DEVICE_ID', 0))

    def __call__(self, on_board: bool = True, jit: bool = False):
        """Re-read the device id from the environment, then :meth:`run`."""
        self.device_id = int(os.environ.get('TILE_FWK_DEVICE_ID', 0))
        self.run(on_board, jit)

    def set_tol(self, rtol=1e-3, atol=1e-3):
        """Set the relative / absolute tolerances used when comparing results."""
        self.rtol_value = rtol
        self.atol_value = atol

    def set_dyn_axes(self, input_dyn_axes, output_dyn_axes):
        """Set the dynamic-axis descriptions used by the JIT path.

        Each list must be empty or have one entry per corresponding tensor
        (see :meth:`_convert_torch_to_pto`).
        """
        self.input_dyn_axes = input_dyn_axes
        self.output_dyn_axes = output_dyn_axes

    def setup_inputs(self, *args):
        """Register inputs for the non-JIT path.

        For every argument a ``pypto.tensor`` with the same shape and the
        converted dtype is created and stored next to the data itself.
        """
        for idx, item in enumerate(args):
            dtype = self.dtype_conversion(str(item.dtype))
            pto_tensor = pypto.tensor(item.shape, dtype, f"PTO_TENSOR_{idx}")
            self.input_pto_list.append(pto_tensor)
            self.input_data_list.append(item)

    def setup_inputs_jit(self, *args):
        """Register inputs for the JIT path; tuple arguments are flattened one level."""
        for item in args:
            if isinstance(item, tuple):
                self.input_pto_list.extend(item)
            else:
                self.input_pto_list.append(item)

    def get_input_list(self):
        """Return the registered input tensor list."""
        return self.input_pto_list

    def get_output_list(self):
        """Return the registered output tensor list."""
        return self.output_pto_list

    def get_input_data_list(self):
        """Return the input data registered by :meth:`setup_inputs`."""
        return self.input_data_list

    def get_output_data_list(self):
        """Return the output buffers allocated by :meth:`init_output`."""
        return self.output_data_list

    def init_output(self, goldens):
        """Allocate one zero-filled output per golden for the non-JIT path.

        Each golden must expose ``shape`` and ``dtype``.  Entries are appended
        to the existing output lists.
        """
        for idx, golden in enumerate(goldens):
            dtype = self.dtype_conversion(str(golden.dtype))
            pto_tensor = pypto.tensor(golden.shape, dtype, f"PTO_TENSOR_out_{idx}")
            output_data = torch.zeros(golden.shape, dtype=golden.dtype)
            self.output_pto_list.append(pto_tensor)
            self.output_data_list.append(output_data)

    def init_output_jit(self, goldens):
        """Allocate one zero-filled NPU tensor per golden for the JIT path."""
        for golden in goldens:
            pto_tensor = torch.zeros_like(golden, dtype=golden.dtype, device=f'npu:{self.device_id}')
            self.output_pto_list.append(pto_tensor)

    @abc.abstractmethod
    def get_input_from_param(self):
        """Return the sequence of input tensors for this test.

        The result is unpacked into the golden call and its length is used to
        work out how many output placeholders the golden expects.
        """
        pass

    def _compute_golden(self):
        """Call ``kernel_golden`` with ``self.inputs`` and ``None`` output placeholders.

        The number of placeholders is the golden's parameter count minus one
        (for ``params``) minus the number of inputs.
        """
        param_count = len(inspect.signature(self.kernel_golden).parameters)
        output_count = param_count - 1 - len(self.inputs)
        return self.kernel_golden(self.params, *self.inputs, *[None] * output_count)

    def run_pto(self, kernel, tiling, on_board: bool = True):
        """Compile ``kernel`` via ``pypto.function`` and optionally run and verify it.

        Args:
            kernel: Called as ``kernel(params, *inputs, *outputs)`` inside the
                ``pypto.function`` context.
            tiling: Passed twice to ``pypto.set_vec_tile_shapes``.
            on_board: When true, the kernel is launched on the device and the
                outputs are compared with ``self.golden_output``.

        Raises:
            AssertionError: If an output is not a ``pypto.tensor`` or a result
                is outside the configured tolerances.
        """
        if on_board:
            torch.npu.set_device(self.device_id)

        logging.info("Function compile ...")
        pypto.set_vec_tile_shapes(tiling, tiling)
        with pypto.function("MAIN", *self.input_pto_list, *self.output_pto_list) as rlf:
            for _ in rlf:
                kernel(self.params, *self.input_pto_list, *self.output_pto_list)
            # NOTE(review): source indentation was lost; `del rlf` is assumed
            # to sit inside the `with` block — confirm.
            del rlf
        assert all(isinstance(x, pypto.tensor) for x in self.output_pto_list)
        logging.info("Function compile done.")

        if on_board:
            logging.info("Kernel Launch ...")
            pto_input_data = [pypto.from_torch(tensor, f"IN_{idx}")
                              for idx, tensor in enumerate(self.input_data_list)]
            pto_output_data = [pypto.from_torch(tensor, f"OUT_{idx}")
                               for idx, tensor in enumerate(self.output_data_list)]
            pypto.runtime._device_run_once_data_from_host(*pto_input_data, *pto_output_data)
            logging.info("Kernel run finish.")

            # NOTE(review): the comparison is kept inside the on_board branch
            # because the output buffers are only zeros without a device run —
            # confirm against the intended layout.
            for idx, golden in enumerate(self.golden_output):
                # assert_allclose(actual, desired): the golden is the desired
                # value, matching the argument order used in run_pto_jit.
                assert_allclose(self.output_data_list[idx].cpu().flatten().tolist(),
                                golden.cpu().flatten().tolist(),
                                rtol=self.rtol_value, atol=self.atol_value)

    def run_pto_jit(self):
        """Run the kernel through the JIT-style direct call and verify outputs.

        The kernel is called as ``kernel(*inputs, *outputs, params)`` with
        tensors produced by ``pypto.from_torch``.

        Raises:
            AssertionError: If a result is outside the configured tolerances.
            RuntimeError: If a dynamic-axes list has the wrong length.
        """
        torch.npu.set_device(self.device_id)
        self.inputs = self.get_input_from_param()
        goldens = self._compute_golden()
        self.init_output_jit(goldens)

        pto_inputs = self._convert_torch_to_pto(self.input_pto_list, self.input_dyn_axes)
        pto_outputs = self._convert_torch_to_pto(self.output_pto_list, self.output_dyn_axes)
        self.kernel(*pto_inputs, *pto_outputs, self.params)
        torch_npu.npu.synchronize()

        for idx in range(len(goldens)):
            assert_allclose(np.array(self.output_pto_list[idx].cpu().flatten().tolist()),
                            np.array(goldens[idx].flatten().tolist()),
                            rtol=self.rtol_value, atol=self.atol_value)

    def _convert_torch_to_pto(self, tensors, dynamic_axes):
        """Wrap torch tensors with ``pypto.from_torch``.

        Args:
            tensors: Tensors to wrap.
            dynamic_axes: Either empty (no dynamic axes are passed) or one
                entry per tensor, forwarded as ``dynamic_axis=``.

        Raises:
            RuntimeError: If ``dynamic_axes`` is non-empty and its length
                differs from ``tensors``.
        """
        if len(tensors) == len(dynamic_axes):
            return [pypto.from_torch(tensor, dynamic_axis=axis)
                    for tensor, axis in zip(tensors, dynamic_axes)]
        if len(dynamic_axes) == 0:
            return [pypto.from_torch(tensor) for tensor in tensors]
        raise RuntimeError("The lengths of tensors and dynamic_axes must be identical.")

    def run(self, on_board: bool = True, jit: bool = False):
        """Execute the test.

        Args:
            on_board: For the non-JIT path, whether to initialise the device,
                launch the kernel and compare results.
            jit: When true, delegate entirely to :meth:`run_pto_jit`.
        """
        if jit:
            self.run_pto_jit()
            return

        if on_board:
            pypto.runtime._device_init()
        try:
            self.inputs = self.get_input_from_param()
            self.golden_output = self.torch_convert(self._compute_golden())
            self.init_output(self.golden_output)
            logging.info("PyPTO run is called.")
            self.run_pto(self.kernel, self.tiling, on_board)
            logging.info("PyPTO run finished.")
        finally:
            # A failed comparison raises AssertionError; the device must still
            # be finalised so the init/fini calls stay paired.
            if on_board:
                pypto.runtime._device_fini()

    def torch_convert(self, data_tuple: tuple):
        """Convert golden results to torch tensors; always returns a tuple.

        A non-tuple argument is treated as a single result and wrapped in a
        one-element tuple.  Within the tuple, numpy arrays and Python scalars
        become tensors, tensors pass through, non-string iterables are rebuilt
        with the same container type, and anything else is returned unchanged.
        """
        if not isinstance(data_tuple, tuple):
            data_tuple = (data_tuple, )

        def _convert(item):
            if isinstance(item, np.ndarray):
                return torch.from_numpy(item)
            if torch.is_tensor(item):
                return item
            if isinstance(item, (int, float, bool)):
                return torch.tensor(item)
            if isinstance(item, Iterable) and not isinstance(item, str):
                return type(item)(_convert(subitem) for subitem in item)
            return item

        return tuple(_convert(item) for item in data_tuple)

    def dtype_conversion(self, str_dtype):
        """Map a dtype string such as ``'torch.float32'`` to the pypto dtype constant.

        Raises:
            ValueError: If ``str_dtype`` is not a recognised spelling.
        """
        for attr_name, aliases in self._DTYPE_ALIASES:
            if str_dtype in aliases:
                return getattr(pypto, attr_name)
        raise ValueError(f"undefined dtype: {str_dtype}")