from enum import Enum
from typing import List
import numpy as np
class DataFormat(Enum):
mk = '(m, k)'
kn = '(k, n)'
k4m4 = '(m/4, k/4, m4, k4)'
kn4 = '(n/4, k, n4)'
class MatrixTiling:
axis: int = None
tile_size: int = None
reverse: bool = None
def __init__(self, axis: int, tile_size: int, reverse: bool = False):
self.axis = axis
self.tile_size = tile_size
self.reverse = reverse
class FormageChange:
__align_completion__ = False
def __init__(self, vector_len: int = 4):
self.vector_len = vector_len
self.call_packing_by_format_ = {
(DataFormat.mk, DataFormat.k4m4): self.mk_to_k4m4,
(DataFormat.kn, DataFormat.kn4): self.nk_to_kn4,
}
def run(self, data: np.ndarray, src_format: DataFormat,
dst_format: DataFormat,
tiling_info: List[MatrixTiling]) -> np.ndarray:
blocks = self.do_tiling(data, tiling_info)
result = self.do_packing(blocks, src_format, dst_format)
return result.reshape(data.shape)
def do_tiling(self, data: np.ndarray,
tiling_info: List[MatrixTiling]) -> List[np.ndarray]:
blocks: List[np.ndarray] = [data]
def block_tiling(block_: np.ndarray, axis_: int, tiling_num_: int,
reverse_: bool, new_blocks_: List[np.ndarray]):
for i in range(tiling_num_):
start = i * tiling.tile_size
end = (i + 1) * tiling.tile_size
dim_ = block_.shape[axis_]
if dim_ < tiling.tile_size:
end = dim_
elif dim_ % tiling.tile_size:
if i == tiling_num_ - 2:
res = dim_ - start
half = res // 2
docker_len = half // self.vector_len * self.vector_len
if res - docker_len > docker_len + 1:
docker_len += self.vector_len
end = start + docker_len
elif i == tiling_num_ - 1:
res = dim_ - start + tiling.tile_size
half = res // 2
docker_len = half // self.vector_len * self.vector_len
if res - docker_len > docker_len + 1:
docker_len += self.vector_len
start = dim_ - res + docker_len
end = dim_
if reverse_:
start = dim_ - (i + 1) * tiling.tile_size
end = dim_ - i * tiling.tile_size
if axis_ == 0:
new_blocks_.append(block_[start:end])
elif axis_ == 1:
new_blocks_.append(block_[:, start:end])
else:
raise TypeError(f'Unsupport tiling axis[{tiling.axis}]')
for tiling in tiling_info:
new_blocks = []
for block in blocks:
tiling_num = int(block.shape[tiling.axis] / tiling.tile_size)
if block.shape[tiling.axis] % tiling.tile_size:
tiling_num += 1
block_tiling(block, tiling.axis, tiling_num, tiling.reverse,
new_blocks)
blocks = new_blocks
return blocks
def do_packing(self, blocks: List[np.ndarray], src_format: DataFormat,
dst_format: DataFormat) -> np.ndarray:
pack_func = self.call_packing_by_format_[(src_format, dst_format)]
result = np.array([], dtype=blocks[0].dtype)
for block in blocks:
result = np.concatenate([result, pack_func(block)], 0)
return result
def mk_to_k4m4(self, block: np.ndarray) -> np.ndarray:
m_slices = self.generate_slices(block.shape[0], self.vector_len)
k_slices = self.generate_slices(block.shape[1], self.vector_len)
result = np.array([], dtype=block.dtype)
m_loc = 0
for m_slice in m_slices:
k_loc = 0
for k_slice in k_slices:
slice_data = block[m_loc:m_loc + m_slice,
k_loc:k_loc + k_slice]
result = np.concatenate([result, slice_data.flatten()])
k_loc += k_slice
m_loc += m_slice
return result
def nk_to_kn4(self, block: np.ndarray) -> np.ndarray:
n_slices = self.generate_slices(block.shape[1], self.vector_len)
result = np.array([], dtype=block.dtype)
n_loc = 0
for n_slice in n_slices:
slice_data = block[:, n_loc:n_loc + n_slice]
result = np.concatenate([result, slice_data.flatten()])
n_loc += n_slice
return result
def generate_slices(self, dim: int, vec_len: int):
slices = [vec_len] * int(dim / vec_len)
boundary = dim % self.vector_len
if boundary == 0:
return slices
if vec_len == 12:
small_vec_len = 4
while small_vec_len:
while boundary >= small_vec_len:
slices.append(small_vec_len)
boundary -= small_vec_len
small_vec_len = int(small_vec_len / 2)
else:
small_vec_len = int(vec_len / 2)
while small_vec_len:
if boundary >= small_vec_len:
slices.append(small_vec_len)
boundary -= small_vec_len
small_vec_len = int(small_vec_len / 2)
return slices