import unittest
from unittest.mock import patch
import os
from copy import deepcopy
from mskl.optune.kernel_modifier import Replacer
class TestReplacer(unittest.TestCase):
    """Tests for ``Replacer``.

    Judging by the fixture and assertions below, ``Replacer`` substitutes
    config values into the kernel source lines that are marked ``// tunable``.
    """

    def setUp(self):
        """Build the in-memory kernel source and the file paths used by each test."""
        # Served by the mocked ``get_file_lines`` instead of reading
        # ``kernel_file_path`` from disk. Index 2 is the ``LayoutA`` line and
        # index 22 is the ``L0C_Tile_Shape`` line; both carry a ``// tunable``
        # marker and are the lines the assertions below inspect.
        self.kernel_file_content = [
            "using ArchTag = actlass::arch::AscendV220;\n",
            " using ElementA = half;\n",
            " using LayoutA = actlass::layout::RowMajor; // tunable\n",
            " using ElementB = half;\n",
            " using LayoutB = actlass::layout::RowMajor;\n",
            " using ElementC = half;\n",
            " using LayoutC = actlass::layout::RowMajor;\n",
            " using ElementAccumulator = float;\n",
            "\n",
            " using StoreOpClass = actlass::epilogue::process::StoreOp<ArchTag, ElementAccumulator, ElementC, LayoutC,\n",
            " actlass::epilogue::process::QuantGranularity::NO_QUANT, false>;\n",
            " using GemmKernel = typename actlass::gemm::kernel::DefaultGemm<\n",
            " ElementA,\n",
            " LayoutA,\n",
            " ElementB,\n",
            " LayoutB,\n",
            " ElementC,\n",
            " LayoutC,\n",
            " ElementAccumulator,\n",
            " ArchTag,\n",
            " actlass::arch::OpClassCube,\n",
            " actlass::arch::OpMultiplyAdd,\n",
            " actlass::gemm::GemmShape<128, 256, 256>, // tunable: L0C_Tile_Shape\n",
            " actlass::gemm::GemmShape<128, 256, 64>,\n",
            " void,\n",
            " void,\n",
            " StoreOpClass,\n",
            " actlass::epilogue::block::InterimTargetType::GM_DESTINATION,\n",
            " void,\n",
            " actlass::epilogue::block::InterimTargetType::UNDEFINED,\n",
            " void,\n",
            " void,\n",
            " typename actlass::gemm::block::GemmIdentityBlockSwizzle<>\n",
            " >::GemmKernel;\n",
        ]
        self.kernel_file_path = "test_kernel.cpp"
        self.output_file_path = "test_output.cpp"
        self.output_file_path_case_2 = "test_output_2.cpp"

    def tearDown(self):
        """Remove every output file a test in this class may have written."""
        # Both outputs must be cleaned up; otherwise the file written by
        # test_replace_src_with_config is left behind in the working directory.
        for path in (self.output_file_path, self.output_file_path_case_2):
            if os.path.exists(path):
                os.remove(path)

    # NOTE(review): patching at ``mskl.utils.autotune_utils`` assumes Replacer
    # looks ``get_file_lines`` up through that module at call time — confirm.
    @patch("mskl.utils.autotune_utils.get_file_lines")
    def test_replace_config(self, mock_get_file_lines):
        """``replace_config`` rewrites the tunable lines and writes the file with mode 0o640."""
        mock_get_file_lines.return_value = deepcopy(self.kernel_file_content)
        replacer = Replacer(self.kernel_file_path)
        node = {
            "LayoutA": "actlass::xxx",
            "L0C_Tile_Shape": "<128, 128, 128>,",
        }
        replacer.replace_config(node, self.output_file_path)
        with open(self.output_file_path, "r", encoding="utf-8") as file:
            lines = file.readlines()
        self.assertEqual(lines[2], ' using LayoutA = actlass::xxx;\n')
        self.assertEqual(lines[22], ' <128, 128, 128>,\n')
        self.assertEqual(os.stat(self.output_file_path).st_mode & 0o777, 0o640)

    @patch("mskl.utils.autotune_utils.get_file_lines")
    def test_replace_src_with_config(self, mock_get_file_lines):
        """The one-shot ``replace_src_with_config`` produces the same rewrite as ``replace_config``."""
        mock_get_file_lines.return_value = deepcopy(self.kernel_file_content)
        config = {
            "LayoutA": "actlass::xxx",
            "L0C_Tile_Shape": "<128, 128, 128>,",
        }
        Replacer.replace_src_with_config(self.kernel_file_path, self.output_file_path_case_2, config)
        with open(self.output_file_path_case_2, "r", encoding="utf-8") as file:
            lines = file.readlines()
        self.assertEqual(lines[2], ' using LayoutA = actlass::xxx;\n')
        self.assertEqual(lines[22], ' <128, 128, 128>,\n')
        self.assertEqual(os.stat(self.output_file_path_case_2).st_mode & 0o777, 0o640)

    @patch("mskl.utils.autotune_utils.get_file_lines")
    def test_write_to_file(self, mock_get_file_lines):
        """``_write_to_file`` writes the given lines verbatim with mode 0o640."""
        mock_get_file_lines.return_value = deepcopy(self.kernel_file_content)
        replacer = Replacer(self.kernel_file_path)
        lines = deepcopy(self.kernel_file_content)
        replacer._write_to_file(lines, self.output_file_path)
        with open(self.output_file_path, "r", encoding="utf-8") as file:
            self.assertEqual(file.readlines(), lines)
        self.assertEqual(os.stat(self.output_file_path).st_mode & 0o777, 0o640)
# Allow running this test module directly (``python <file>``) in addition to
# discovery by a test runner.
if __name__ == "__main__":
    unittest.main()