"""
"""
import logging
from math import ceil, prod
from pathlib import Path
import sys
import numpy as np
import torch
import pypto
# The helper scripts live outside this package: climb six directory levels up
# from this file, then descend into the framework test-helper directory. The
# directory is appended to sys.path (once) so that the helper modules can be
# imported by plain name.
helper_path: Path = (
    Path(__file__).parent.parent.parent.parent.parent.parent
    / "framework/tests/cmake/scripts/helper"
).resolve()
if str(helper_path) not in sys.path:
    sys.path.append(str(helper_path))
from test_case_desc import TensorDesc
from test_case_runner import TestCaseRunner
from test_case_tools import get_dtype_by_name
def get_pto_dtype_by_name(name: str):
    """Translate a short dtype name into the matching ``pypto.DT_*`` constant.

    Args:
        name: One of the short names listed below (e.g. ``"fp16"``, ``"bool"``).

    Returns:
        The ``pypto`` dtype constant for ``name``; ``pypto.DT_FP32`` when the
        name is not one of the supported names.
    """
    # Every supported short name maps onto the pypto constant whose suffix is
    # the upper-cased name ("fp8e4m3" -> pypto.DT_FP8E4M3).
    supported_names = (
        "int4",
        "int8",
        "int16",
        "int32",
        "int64",
        "fp8",
        "fp16",
        "fp32",
        "hf4",
        "hf8",
        "fp8e4m3",
        "fp8e5m2",
        "fp8e8m0",
        "uint8",
        "uint16",
        "uint32",
        "uint64",
        "bool",
        "double",
        "bf16",
    )
    # Resolve all constants eagerly so a missing pypto attribute surfaces on
    # any call, not only when that particular name is requested.
    lookup = {
        short_name: getattr(pypto, f"DT_{short_name.upper()}")
        for short_name in supported_names
    }
    fallback = pypto.DT_FP32
    return lookup.get(name, fallback)
class PTOTestCaseRunner(TestCaseRunner):
    """Runs one pypto operation over tiled views of its input tensors.

    The runner declares pypto tensors from tensor descriptions, generates the
    source of a pypto function that loops over the views of the inputs,
    applies the operation to each view and assembles the results into the
    outputs, and finally launches that function on the device.

    ``self._view_shape``, ``self._params``, ``self._op_func`` and
    ``self.tile_shape`` are read here but not assigned in this class; they are
    presumably provided by ``TestCaseRunner`` (verify against the base class).
    """

    def __init__(
        self,
        operation: str,
        input_tensors: list,
        output_tensors: list,
        view_shape: tuple,
        tile_shape: tuple,
        params: dict,
    ):
        """Store the operation name and normalise the tensor descriptions.

        Args:
            operation: Name given to the generated ``pypto.function``.
            input_tensors: Input descriptions; ``dict`` entries are converted
                with ``TensorDesc.from_dict``, anything else is kept as is.
            output_tensors: Output descriptions, normalised the same way.
            view_shape: Forwarded to ``TestCaseRunner``.
            tile_shape: Forwarded to ``TestCaseRunner``.
            params: Forwarded to ``TestCaseRunner``.
        """
        super().__init__(view_shape, tile_shape, params)
        self._operation = operation
        self._input_tensors = self._as_tensor_descs(input_tensors)
        self._output_tensors = self._as_tensor_descs(output_tensors)

    @staticmethod
    def _as_tensor_descs(tensors: list) -> list:
        """Convert ``dict`` entries to ``TensorDesc``; keep other entries."""
        return [
            TensorDesc.from_dict(tensor) if isinstance(tensor, dict) else tensor
            for tensor in tensors
        ]

    @staticmethod
    def _declare_tensors(tensor_descs: list) -> list:
        """Declare one ``pypto.tensor`` (shape, dtype, name) per description."""
        return [
            pypto.tensor(
                tensor_desc.shape,
                get_pto_dtype_by_name(tensor_desc.dtype),
                tensor_desc.name,
            )
            for tensor_desc in tensor_descs
        ]

    def gen_loop_range_tuple(self):
        """Return the number of views needed along each dimension.

        The count is ``ceil(dim / view_dim)`` computed from the shape of the
        first input tensor.

        Raises:
            ValueError: If the first input's rank differs from the view rank.
        """
        first_shape = self._input_tensors[0].shape
        if len(first_shape) != len(self._view_shape):
            raise ValueError(
                "The lengths of input tensors and view shape are not same."
            )
        return tuple(
            ceil(dim / view_dim)
            for dim, view_dim in zip(first_shape, self._view_shape)
        )

    def input_tensors(self):
        """Declare the pypto tensors for all inputs."""
        return self._declare_tensors(self._input_tensors)

    def input_data(self):
        """Generate a flat torch tensor of host data for every input.

        Values are drawn uniformly from the description's data range and cast
        to the input dtype; when the range is degenerate (min == max) the
        buffer is filled with that constant instead. Each buffer is
        one-dimensional with ``prod(shape)`` elements.
        """
        input_data = []
        for input_tensor in self._input_tensors:
            min_value = input_tensor.data_range.min
            max_value = input_tensor.data_range.max
            element_count = prod(input_tensor.shape)
            np_dtype = get_dtype_by_name(input_tensor.dtype)
            if min_value != max_value:
                data = np.random.uniform(
                    min_value, max_value, element_count
                ).astype(np_dtype)
            else:
                data = np.full(element_count, max_value, dtype=np_dtype)
            input_data.append(torch.from_numpy(data))
        return input_data

    def output_tensors(self):
        """Declare the pypto tensors for all outputs."""
        return self._declare_tensors(self._output_tensors)

    def output_data(self):
        """Allocate a flat, zero-filled torch tensor for every output."""
        return [
            torch.from_numpy(
                np.zeros(
                    prod(output_tensor.shape),
                    dtype=get_dtype_by_name(output_tensor.dtype),
                )
            )
            for output_tensor in self._output_tensors
        ]

    def exec_dyn_func(self, input_tensors: list, output_tensors: list):
        """Generate and execute the pypto function for this operation.

        The generated source opens ``pypto.function``, nests one ``pypto.loop``
        per view dimension, takes a view of every input at the current loop
        offset, applies ``op_func`` and assembles each result into the matching
        output at the same offset (with two dimensions swapped for
        ``"Transpose"``).

        Args:
            input_tensors: Declared pypto input tensors; each must expose
                ``shape``.
            output_tensors: Declared pypto output tensors.
        """
        loop_range_tuple = self.gen_loop_range_tuple()
        tab = "    "
        lines = [
            "import pypto",
            "import os",
            "import torch",
            "import torch_npu",
            "",
            "device_id = int(os.environ.get('TILE_FWK_DEVICE_ID', 0))",
            "torch.npu.set_device(device_id)",
            "",
            f"with pypto.function('{self._operation}', *input_tensors, *output_tensors):",
        ]
        # One nested loop per view dimension, each one level deeper.
        for index, value in enumerate(loop_range_tuple):
            lines.append(
                f"{tab * (index + 1)}for index_{index} in pypto.loop({value}):"
            )
        prefix = tab * (len(loop_range_tuple) + 1)
        lines.append(prefix + "input_data = []")

        # Offset expression of the current view along every dimension.
        view_offset = [
            f"index_{index} * {self._view_shape[index]}"
            for index in range(len(loop_range_tuple))
        ]
        for index, tensor in enumerate(input_tensors):
            view_shape = [
                min(dim, view_dim)
                for dim, view_dim in zip(tensor.shape, self._view_shape)
            ]
            # A size-1 dimension is always read at offset 0. The adjusted
            # offsets are kept local to this tensor: overwriting view_offset
            # here would leak the zeroed (or rank-truncated) offsets into the
            # following inputs and into the output assemble offsets.
            tensor_offset = [
                "0" if dim == 1 else offset
                for dim, offset in zip(tensor.shape, view_offset)
            ]
            lines.append(
                f"{prefix}input_{index} = pypto.view("
                f"input_tensors[{index}], {view_shape}, "
                f"[{', '.join(tensor_offset)}])"
            )
            lines.append(f"{prefix}input_data.append(input_{index})")

        lines += [
            prefix + "res = []",
            prefix + "for _ in enumerate(output_tensors):",
            prefix + tab + "res.append(pypto.tensor())",
            prefix + "if len(res) == 1:",
            prefix + tab + "res[0].move(op_func(input_data, params))",
            prefix + "else:",
            prefix + tab + "for dst_, src_ in zip(res, op_func(input_data, params)):",
            prefix + tab * 2 + "dst_.move(src_)",
        ]

        # The result of a transpose lands at the offset with the two
        # transposed dimensions exchanged.
        output_offset = list(view_offset)
        if self._operation == "Transpose":
            first_dim = self._params["first_dim"]
            second_dim = self._params["second_dim"]
            output_offset[first_dim], output_offset[second_dim] = (
                output_offset[second_dim],
                output_offset[first_dim],
            )

        lines += [
            prefix + "for dst_, src_ in zip(output_tensors, res):",
            prefix + tab + f"pypto.assemble(src_, [{', '.join(output_offset)}], dst_)",
            prefix + "for input in input_data:",
            prefix + tab + "del input",
            prefix + "for tmp in res:",
            prefix + tab + "del tmp",
        ]
        function = "\n".join(lines) + "\n"
        logging.info(function)
        pypto.set_vec_tile_shapes(*self.tile_shape)
        # NOTE: the operation name is interpolated into source that is then
        # exec'd, so it must come from a trusted test description.
        exec(
            function,
            {
                "input_tensors": input_tensors,
                "output_tensors": output_tensors,
                "op_func": self._op_func,
                "params": self._params,
            },
        )

    def tear_up(self):
        """Initialise the pypto device runtime."""
        pypto.runtime._device_init()

    def tear_down(self):
        """Finalise the pypto device runtime."""
        pypto.runtime._device_fini()

    def run_on_device(self, inputs: list) -> list:
        """Run the function once on the device and return the outputs.

        Args:
            inputs: Host torch tensors, one per input.

        Returns:
            One torch tensor per output, reshaped to the described output
            shape.
        """
        output = self.output_data()
        pto_inputs_tensor = [
            pypto.from_torch(tensor, f"IN_{idx}")
            for idx, tensor in enumerate(inputs)
        ]
        # NOTE(review): the outputs reuse the "IN_" name prefix of the inputs;
        # this looks like a copy-paste slip ("OUT_" intended?) - confirm
        # whether the runtime relies on these names before renaming.
        pto_output_tensor = [
            pypto.from_torch(tensor, f"IN_{idx}")
            for idx, tensor in enumerate(output)
        ]
        pypto.runtime._device_run_once_data_from_host(
            *pto_inputs_tensor, *pto_output_tensor
        )
        # The second argument presumably selects the torch dtype - verify
        # against get_dtype_by_name.
        return [
            torch.tensor(
                host_tensor,
                dtype=get_dtype_by_name(tensor_desc.dtype, True),
            ).reshape(tensor_desc.shape)
            for host_tensor, tensor_desc in zip(output, self._output_tensors)
        ]