"""
"""
from dataclasses import dataclass, field
from typing import Optional
import os
import numpy as np
import torch
import torch_npu
import pypto
import pytest
from numpy.testing import assert_allclose
import torch.nn.functional as F
FP32 = pypto.DT_FP32
FP16 = pypto.DT_FP16
INT32 = pypto.DT_INT32
INT8 = pypto.DT_INT8
UINT64 = pypto.DT_UINT64
UINT32 = pypto.DT_UINT32
@dataclass
class ShapeConfig:
ori_shape: list
m_tile_shape: list
k_tile_shape: list
n_tile_shape: list
view_shape: list
in_dtype: pypto.DataType
out_dtype: pypto.DataType
a_trans: bool = False
b_trans: bool = False
a_format_nz: bool = False
b_format_nz: bool = False
c_format_nz: bool = False
gm_acc: bool = False
@dataclass
class ExtendParams:
bias_shape: list = field(default_factory=list)
bias_dtype: np.dtype = None
scale_shape: list = field(default_factory=list)
scale_dtype: np.dtype = None
scale: int = None
relu_type: int = None
def trans_nd_to_fractal_nz(data: torch.Tensor, keep_m_dim=False):
def _gen_axes_for_transpose(offset, base):
return [x for x in range(offset)] + [x + offset for x in base]
def _ceil_div(a, b):
return (a + b - 1) // b
ori_shape = data.shape
m_ori, n_ori = ori_shape[-2:]
batch_ori = ori_shape[:-2]
batch_num = len(batch_ori)
batch_padding = (0,) * batch_num
m0 = 16
n0 = 32 // data.dtype.itemsize
if data.dtype == torch.int32:
n0 = 16
m1, n1 = _ceil_div(m_ori, m0), _ceil_div(n_ori, n0)
padding_m = m1 * m0 - m_ori
padding_n = n1 * n0 - n_ori
if not keep_m_dim:
data = F.pad(data, (batch_padding + (0, padding_n, 0, padding_m)), "constant")
array_trans = _gen_axes_for_transpose(len(data.shape) - 2, [2, 0, 1, 3])
data = data.reshape(batch_ori + (m1, m0, n1, n0)).permute(*array_trans).contiguous()
else:
data = F.pad(data, (batch_padding + (0, padding_n, 0, 0)), "constant")
array_trans = _gen_axes_for_transpose(len(data.shape) - 2, [1, 0, 2])
data = data.reshape(batch_ori + (m_ori, n1, n0)).permute(*array_trans).contiguous()
return data
@pypto.frontend.jit(
debug_options={"runtime_debug_mode": 0, "compile_debug_mode": 0}
)
def matmul_kernel_with_mn_split(
a_tensor: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_FP16),
b_tensor: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_FP16),
out_tensor: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_FP32),
shape_info: ShapeConfig,
):
m = shape_info.ori_shape[0]
n = shape_info.ori_shape[2]
m_view = shape_info.view_shape[0]
n_view = shape_info.view_shape[1]
pypto.set_cube_tile_shapes(shape_info.m_tile_shape, shape_info.k_tile_shape, shape_info.n_tile_shape,
enable_split_k=shape_info.gm_acc)
m_loop = (m + m_view - 1) // m_view
n_loop = (n + n_view - 1) // n_view
for m_idx in pypto.loop(0, m_loop, 1, name="LOOP_L0_mIdx", idx_name="m_idx"):
for n_idx in pypto.loop(0, n_loop, 1, name="LOOP_L0_nIdx", idx_name="n_idx"):
if shape_info.a_trans:
a_view = a_tensor[:, m_idx * m_view: m_idx * m_view + m_view]
else:
a_view = a_tensor[m_idx * m_view: m_idx * m_view + m_view, :]
if shape_info.b_trans:
b_view = b_tensor[n_idx * n_view: n_idx * n_view + n_view, :]
else:
b_view = b_tensor[:, n_idx * n_view: n_idx * n_view + n_view]
out_view = pypto.matmul(a_view, b_view, a_trans=shape_info.a_trans, b_trans=shape_info.b_trans,
out_dtype=shape_info.out_dtype)
out_tensor[m_idx * m_view: m_idx * m_view + m_view, n_idx * n_view: n_idx * n_view + n_view] = out_view
@pypto.frontend.jit(
debug_options={"runtime_debug_mode": 0, "compile_debug_mode": 0}
)
def bmm_kernel_with_no_mn_split(
a_tensor: pypto.Tensor([pypto.STATIC, pypto.STATIC, pypto.STATIC], pypto.DT_FP16),
b_tensor: pypto.Tensor([pypto.STATIC, pypto.STATIC, pypto.STATIC], pypto.DT_FP16),
out_tensor: pypto.Tensor([pypto.STATIC, pypto.STATIC, pypto.STATIC], pypto.DT_FP32),
shape_info: ShapeConfig,
):
pypto.set_cube_tile_shapes(shape_info.m_tile_shape, shape_info.k_tile_shape, shape_info.n_tile_shape,
enable_split_k=shape_info.gm_acc)
result = pypto.matmul(a_tensor, b_tensor, a_trans=shape_info.a_trans, b_trans=shape_info.b_trans,
out_dtype=shape_info.out_dtype)
out_tensor.move(result)
@pytest.mark.soc("950", "910")
def test_mm_with_mn_split():
device_id = os.environ.get('TILE_FWK_DEVICE_ID', 0)
torch.npu.set_device(int(device_id))
m = 69
k = 99
n = 129
tile_m = 64
tile_k = 64
tile_n = 64
m_view = 128
n_view = 256
shape_info = ShapeConfig([m, k, n], [tile_m, tile_m], [tile_k, tile_k], [tile_n, tile_n], [m_view, n_view], FP16,
FP32, True, True, False, False, False, False)
a1_tensor = torch.rand([k, m], dtype=torch.float16, device=f"npu:{device_id}")
b1_tensor = torch.rand([n, k], dtype=torch.float16, device=f"npu:{device_id}")
c1_tensor = torch.zeros([m, n], dtype=torch.float32, device=f"npu:{device_id}")
golden = torch.matmul(a1_tensor.to(torch.float32).T, b1_tensor.to(torch.float32).T)
matmul_kernel_with_mn_split(
a1_tensor, b1_tensor, c1_tensor,
shape_info
)
assert torch.allclose(c1_tensor.cpu().to(torch.float32), golden.cpu().to(torch.float32), atol=1e-3, rtol=1e-3)
@pytest.mark.soc("950", "910")
def test_mm_with_mn_split_nz():
device_id = os.environ.get('TILE_FWK_DEVICE_ID', 0)
torch_npu.npu.config.allow_internal_format = True
torch.npu.set_device(int(device_id))
m = 64
k = 128
n = 128
tile_m = 64
tile_k = 64
tile_n = 64
m_view = 128
n_view = 256
shape_info = ShapeConfig([m, k, n], [tile_m, tile_m], [tile_k, tile_k], [tile_n, tile_n], [m_view, n_view], FP16,
FP32, True, True, True, True, False, False)
a1_tensor = torch.rand([k, m], dtype=torch.float16, device=f'npu:{device_id}')
b1_tensor = torch.rand([n, k], dtype=torch.float16, device=f'npu:{device_id}')
c1_tensor = torch.zeros([m, n], dtype=torch.float32, device=f'npu:{device_id}')
a1_tensor_nz = torch_npu.npu_format_cast(a1_tensor, 29) if shape_info.a_format_nz else a1_tensor
b1_tensor_nz = torch_npu.npu_format_cast(b1_tensor, 29) if shape_info.b_format_nz else b1_tensor
golden = torch.matmul(a1_tensor.to(torch.float32).T, b1_tensor.to(torch.float32).T)
matmul_kernel_with_mn_split(
a1_tensor_nz, b1_tensor_nz, c1_tensor,
shape_info
)
assert torch.allclose(c1_tensor.cpu().to(torch.float32), golden.cpu().to(torch.float32), atol=1e-3, rtol=1e-3)
torch.npu.synchronize()
@pytest.mark.soc("950", "910")
def test_bmm_with_mn_split():
device_id = os.environ.get('TILE_FWK_DEVICE_ID', 0)
torch.npu.set_device(int(device_id))
b = 3
m = 63
k = 127
n = 129
tile_m = 64
tile_k = 64
tile_n = 64
shape_info = ShapeConfig([b, m, k, n], [tile_m, tile_m], [tile_k, tile_k], [tile_n, tile_n], [-1, -1], FP16, FP32,
True, False, False, False, False, False)
a1_tensor = torch.rand([b, k, m], dtype=torch.float16, device=f'npu:{device_id}')
b1_tensor = torch.rand([b, k, n], dtype=torch.float16, device=f'npu:{device_id}')
c1_tensor = torch.zeros([b, m, n], dtype=torch.float32, device=f'npu:{device_id}')
golden = torch.matmul(a1_tensor.to(torch.float32).transpose(-2, -1), b1_tensor.to(torch.float32))
bmm_kernel_with_no_mn_split(
a1_tensor, b1_tensor, c1_tensor,
shape_info
)
assert torch.allclose(c1_tensor.cpu().to(torch.float32), golden.cpu().to(torch.float32), atol=1e-3, rtol=1e-3)