import contextlib
import os
import sys
import unittest
from unittest import mock
# Anchor every repository-relative location on the real (symlink-resolved)
# directory of this test file.
THIS_FILE_NAME = __file__
FILE_PATH = os.path.split(os.path.realpath(THIS_FILE_NAME))[0]
# Three directory levels above this file; the joined path is left un-normalized.
TOP_PATH = os.path.join(FILE_PATH, "../../../")
# Sub-directories of TOP_PATH used by the test setup.
API_ROOT_PATH, FRAMEWORK_PATH = (
    os.path.join(TOP_PATH, sub_dir) for sub_dir in ("build/adapter_ut", "tools/build/")
)
# Prepend the build-tools directory so the ``asc_op_compile_base`` package
# imported below can be found there.
sys.path.insert(0, FRAMEWORK_PATH)
import asc_op_compile_base
from asc_op_compile_base.asc_op_compiler import template_tiling
from asc_op_compile_base.asc_op_compiler.template_tiling import *
class TestCompileOp(unittest.TestCase):
    """Tests for template-tiling parsing and decoding in ``template_tiling``.

    Each test passes an ``ASCENDC_TPL_ARGS_DECL`` declaration string and an
    ``ASCENDC_TPL_LISTS`` selection string (macro names wrapped in ``@@``) to
    ``extract_template_tiling_info``, then either inspects the dict returned by
    ``decode_tiling`` or the ``RuntimeError`` raised while parsing.
    """

    # Declaration shared by most tests. Only the UINT_DECL_Z values vary between
    # tests, so the string is stored as a head/tail pair around those values.
    _DECLARE_HEAD = (
        "@@structFlashAttentionScore@@ =@@ASCENDC_TPL_ARGS_DECL_FlashAttentionScore@@ = {"
        "@@ASCENDC_TPL_DTYPE_DECL_X@@ = { 10, 30, 20},"
        "@@ASCENDC_TPL_FORMAT_DECL_Y@@ = {15, 25},"
        "@@ASCENDC_TPL_UINT_DECL_Z@@ = {"
    )
    _DECLARE_TAIL = "},@@ASCENDC_TPL_BOOL_DECL_S@@ = {0, 1}, };"
    DECLARE_PARAM_STR = _DECLARE_HEAD + "8, 2, 2, 0, 2, 3, 4, 5, 6" + _DECLARE_TAIL

    # Opening of a selection list followed by one valid X/Y/Z/S group; other
    # selections append a second group to it.
    _LISTS_WITH_VALID_GROUP = (
        "@@ASCENDC_TPL_LISTS@@ = {"
        "@@{@@ASCENDC_TPL_DTYPE_SEL_X@@ = { 10, 30},@@ASCENDC_TPL_FORMAT_SEL_Y@@ = { 15},"
        "@@ASCENDC_TPL_UINT_SEL_Z@@ = {1, 4, 6},@@ASCENDC_TPL_BOOL_SEL_S@@ = {0, 1}},"
    )
    # Selection used by the successful decode tests.
    SELECT_PARAM_STR = _LISTS_WITH_VALID_GROUP + (
        "@@{@@ASCENDC_TPL_DTYPE_SEL_X@@ = { 20},@@ASCENDC_TPL_FORMAT_SEL_Y@@ = { 15, 25},"
        "@@ASCENDC_TPL_UINT_SEL_Z@@ = { 1, 4, 6},@@ASCENDC_TPL_BOOL_SEL_S@@ = {0, 1}}, };"
    )
    # Same as SELECT_PARAM_STR but the second group also selects a tiling struct.
    STRUCT_SELECT_PARAM_STR = _LISTS_WITH_VALID_GROUP + (
        "@@{@@ASCENDC_TPL_DTYPE_SEL_X@@ = { 20},@@ASCENDC_TPL_FORMAT_SEL_Y@@ = { 15, 25},"
        "@@ASCENDC_TPL_UINT_SEL_Z@@ = { 1, 4, 6},@@ASCENDC_TPL_BOOL_SEL_S@@ = {0, 1},"
        "@@ASCENDC_TPL_TILING_STRUCT_SEL_tilingDataStruct@@ = {}},};"
    )

    @classmethod
    def _declare_with_uint_z(cls, uint_z):
        """Return the shared declaration with ``UINT_DECL_Z`` set to ``{uint_z}``."""
        return cls._DECLARE_HEAD + uint_z + cls._DECLARE_TAIL

    @staticmethod
    def _deterministic_select(value):
        """Return the struct selection whose first group has DETERMINISTIC_SEL ``{value}``."""
        head = (
            "@@ASCENDC_TPL_LISTS@@ = {@@{@@ASCENDC_TPL_DTYPE_SEL_X@@ = { 10, 30},"
            "@@ASCENDC_TPL_DETERMINISTIC_SEL@@ = {"
        )
        tail = (
            "},@@ASCENDC_TPL_FORMAT_SEL_Y@@ = { 15},"
            "@@ASCENDC_TPL_UINT_SEL_Z@@ = {1, 4, 6},@@ASCENDC_TPL_BOOL_SEL_S@@ = {0, 1}},"
            "@@{@@ASCENDC_TPL_DTYPE_SEL_X@@ = { 20},@@ASCENDC_TPL_FORMAT_SEL_Y@@ = { 15, 25},"
            "@@ASCENDC_TPL_UINT_SEL_Z@@ = { 1, 4, 6},@@ASCENDC_TPL_BOOL_SEL_S@@ = {0, 1},"
            "@@ASCENDC_TPL_TILING_STRUCT_SEL_tilingDataStruct@@ = {}},};"
        )
        return head + value + tail

    @staticmethod
    @contextlib.contextmanager
    def _op_context():
        """Enter a fresh ``OpContext`` whose ``get_addition`` is patched to return ''."""
        with asc_op_compile_base.common.context.op_context.OpContext():
            # get_context() must be evaluated after OpContext is entered so the
            # patch targets the context object that was just opened.
            with mock.patch.object(
                asc_op_compile_base.common.context.get_context(), "get_addition", return_value=""
            ):
                yield

    def _assert_extract_error(self, declare_param_str, select_param_str, message):
        """Assert extraction raises ``RuntimeError`` whose args are exactly ``(message,)``."""
        with self.assertRaises(RuntimeError) as ctx:
            extract_template_tiling_info(declare_param_str, select_param_str)
        self.assertEqual(ctx.exception.args, (message,))

    def setUp(self):
        print("-------------------SetUp----------------")

    def tearDown(self):
        print("-------------------TearDown-------------")

    def test_template_tiling(self):
        """decode_tiling() without a key contains the expected entry for key 17176852."""
        with self._op_context():
            extract_template_tiling_info(self.DECLARE_PARAM_STR, self.SELECT_PARAM_STR)
            result = decode_tiling()
            self.assertEqual(result.get(17176852).get("paramArgs"), ['20', '25', '6', '1'])

    def test_template_tiling_err_tpl(self):
        """A UINT selection ({4, 6}) whose flag cannot be recognized is rejected."""
        select_param_str = (
            "@@ASCENDC_TPL_LISTS@@ = {"
            "@@{@@ASCENDC_TPL_DTYPE_SEL_X@@ = { 10, 30},@@ASCENDC_TPL_FORMAT_SEL_Y@@ = { 15},"
            "@@ASCENDC_TPL_UINT_SEL_Z@@ = {4, 6},@@ASCENDC_TPL_BOOL_SEL_S@@ = {0, 1}},"
            "@@{@@ASCENDC_TPL_DTYPE_SEL_X@@ = { 20},@@ASCENDC_TPL_FORMAT_SEL_Y@@ = { 15, 25},"
            "@@ASCENDC_TPL_UINT_SEL_Z@@ = { 1, 4, 6},@@ASCENDC_TPL_BOOL_SEL_S@@ = {0, 1}}, };"
        )
        self._assert_extract_error(
            self.DECLARE_PARAM_STR, select_param_str,
            'Cannot find flag in ASCENDC_TPL_UINT_SEL Z! Value should be in [ASCENDC_TPL_UI_RANGE, ASCENDC_TPL_UI_LIST, ASCENDC_TPL_UI_MIX].',
        )

    def test_template_tiling_no_flag(self):
        """A UINT selection ({10, 4, 6}) whose flag cannot be recognized is rejected."""
        select_param_str = (
            "@@ASCENDC_TPL_LISTS@@ = {"
            "@@{@@ASCENDC_TPL_DTYPE_SEL_X@@ = { 10, 30},@@ASCENDC_TPL_FORMAT_SEL_Y@@ = { 15},"
            "@@ASCENDC_TPL_UINT_SEL_Z@@ = {10, 4, 6},@@ASCENDC_TPL_BOOL_SEL_S@@ = {0, 1}},"
            "@@{@@ASCENDC_TPL_DTYPE_SEL_X@@ = { 20},@@ASCENDC_TPL_FORMAT_SEL_Y@@ = { 15, 25},"
            "@@ASCENDC_TPL_UINT_SEL_Z@@ = { 1, 4, 6},@@ASCENDC_TPL_BOOL_SEL_S@@ = {0, 1}}, };"
        )
        self._assert_extract_error(
            self.DECLARE_PARAM_STR, select_param_str,
            'Cannot find flag in ASCENDC_TPL_UINT_SEL Z! Value should be in [ASCENDC_TPL_UI_RANGE, ASCENDC_TPL_UI_LIST, ASCENDC_TPL_UI_MIX].',
        )

    def test_template_tiling_null(self):
        """A key that was never selected is absent from decode_tiling()."""
        with self._op_context():
            extract_template_tiling_info(self.DECLARE_PARAM_STR, self.SELECT_PARAM_STR)
            result = decode_tiling()
            self.assertIsNone(result.get(17176851112))

    def test_template_tiling_single(self):
        """decode_tiling(key) returns the entry for that key."""
        with self._op_context():
            extract_template_tiling_info(self.DECLARE_PARAM_STR, self.SELECT_PARAM_STR)
            result = decode_tiling(17176852)
            self.assertEqual(result.get(17176852).get("paramArgs"), ['20', '25', '6', '1'])

    def test_template_tiling_single_null(self):
        """decode_tiling(key) for an unselected key yields no entry for it."""
        with self._op_context():
            extract_template_tiling_info(self.DECLARE_PARAM_STR, self.SELECT_PARAM_STR)
            result = decode_tiling(17176851112)
            self.assertIsNone(result.get(17176851112))

    def test_template_tiling_ui_range(self):
        """An announced UI_RANGE length (20) longer than the given values is rejected."""
        self._assert_extract_error(
            self._declare_with_uint_z("8, 2, 20, 0, 2, 3, 4, 5, 6"), self.SELECT_PARAM_STR,
            "UI_RANGE declare parse failed, because the announced length of ASCENDC_TPL_UINT_DECL Z is greater than actual values' length.",
        )

    def test_template_tiling_duplicated(self):
        """Duplicated UINT declare values and duplicated select groups are rejected."""
        with self._op_context():
            self._assert_extract_error(
                self._declare_with_uint_z("8, 2, 2, 0, 2, 3, 4, 3, 6"), self.SELECT_PARAM_STR,
                'There is duplicated number in ASCENDC_TPL_DECL_UINT Z! Duplicated List: [0, 1, 2, 3, 4, 3, 6].',
            )
        # Two identical selection groups.
        select_param_str = self._LISTS_WITH_VALID_GROUP + (
            "@@{@@ASCENDC_TPL_DTYPE_SEL_X@@ = { 10, 30},@@ASCENDC_TPL_FORMAT_SEL_Y@@ = { 15},"
            "@@ASCENDC_TPL_UINT_SEL_Z@@ = {1, 4, 6},@@ASCENDC_TPL_BOOL_SEL_S@@ = {0, 1}}, };"
        )
        with self._op_context():
            self._assert_extract_error(
                self.DECLARE_PARAM_STR, select_param_str,
                "ASCENDC_TPL_SELECT has duplicated definitions!",
            )

    def test_template_tiling_name(self):
        """A selection name (S111) that does not match the declaration is reported."""
        select_param_str = (
            "@@ASCENDC_TPL_LISTS@@ = {"
            "@@{@@ASCENDC_TPL_DTYPE_SEL_X@@ = { 10, 30},@@ASCENDC_TPL_FORMAT_SEL_Y@@ = { 15},"
            "@@ASCENDC_TPL_UINT_SEL_Z@@ = {1, 4, 6},@@ASCENDC_TPL_BOOL_SEL_S111@@ = {0, 1}},"
            "@@{@@ASCENDC_TPL_DTYPE_SEL_X@@ = { 20},@@ASCENDC_TPL_FORMAT_SEL_Y@@ = { 15, 25},"
            "@@ASCENDC_TPL_UINT_SEL_Z@@ = { 1, 4, 6},@@ASCENDC_TPL_BOOL_SEL_S@@ = {0, 1}}, };"
        )
        self._assert_extract_error(
            self.DECLARE_PARAM_STR, select_param_str,
            'There is missing marco define: S in ASCENDC_TPL_BOOL_SEL.',
        )

    def test_template_invalid_params(self):
        """Malformed selections and declarations raise descriptive RuntimeErrors."""
        declare_param_str = self.DECLARE_PARAM_STR

        # Empty UINT selection.
        select_param_str = self._LISTS_WITH_VALID_GROUP + (
            "@@{@@ASCENDC_TPL_DTYPE_SEL_X@@ = { 10, 30},@@ASCENDC_TPL_FORMAT_SEL_Y@@ = { 15},"
            "@@ASCENDC_TPL_UINT_SEL_Z@@ = {},@@ASCENDC_TPL_BOOL_SEL_S@@ = {0, 1}}, };"
        )
        self._assert_extract_error(
            declare_param_str, select_param_str, 'Length of ASCENDC_TPL_UINT_SEL Z is too short.'
        )

        # Empty FORMAT selection.
        select_param_str = self._LISTS_WITH_VALID_GROUP + (
            "@@{@@ASCENDC_TPL_DTYPE_SEL_X@@ = { 10, 30},@@ASCENDC_TPL_FORMAT_SEL_Y@@ = {},"
            "@@ASCENDC_TPL_UINT_SEL_Z@@ = {4, 6},@@ASCENDC_TPL_BOOL_SEL_S@@ = {0, 1}}, };"
        )
        self._assert_extract_error(
            declare_param_str, select_param_str, 'values of ASCENDC_TPL_FORMAT_SEL Y is empty!'
        )

        # BOOL selection containing a value other than 0/1.
        select_param_str = self._LISTS_WITH_VALID_GROUP + (
            "@@{@@ASCENDC_TPL_DTYPE_SEL_X@@ = { 10, 30},@@ASCENDC_TPL_FORMAT_SEL_Y@@ = {15},"
            "@@ASCENDC_TPL_UINT_SEL_Z@@ = {1,4, 6},@@ASCENDC_TPL_BOOL_SEL_S@@ = {3, 1}}, };"
        )
        self._assert_extract_error(
            declare_param_str, select_param_str,
            'There is invalid number in ASCENDC_TPL_BOOL_SEL S! Value should only be in [0, 1].',
        )

        # Bit width 1 cannot represent the seven declared values (same selection as above).
        self._assert_extract_error(
            self._declare_with_uint_z("1, 2, 2, 0, 2, 3, 4, 5, 6"), select_param_str,
            'Bit width:1 in ASCENDC_TPL_UINT_DECL Z is not enough to represent all values: [0, 1, 2, 3, 4, 5, 6]! Please make sure 2^bitWidth is greater than or equal to the [number of values].',
        )

        # Bit width 0 is rejected (same selection as above).
        self._assert_extract_error(
            self._declare_with_uint_z("0, 2, 2, 0, 2, 3, 4, 5, 6"), select_param_str,
            'Bit width in ASCENDC_TPL_UINT_DECL Z cannot be less than zero!',
        )

        # S is selected as FORMAT although it is declared as BOOL.
        declare_param_str = self._declare_with_uint_z("8, 2, 0, 2, 3, 4, 5, 6")
        select_param_str = self._LISTS_WITH_VALID_GROUP + (
            "@@{@@ASCENDC_TPL_DTYPE_SEL_X@@ = { 10, 30},@@ASCENDC_TPL_FORMAT_SEL_Y@@ = { 15},"
            "@@ASCENDC_TPL_UINT_SEL_Z@@ = {1, 4, 6},@@ASCENDC_TPL_FORMAT_SEL_S@@ = {0, 1}}, };"
        )
        self._assert_extract_error(
            declare_param_str, select_param_str,
            'S has different type in ASCENDC_TPL_FORMAT_SEL and ASCENDC_TPL_BOOL_DECL.',
        )

        # Selection name XXX has no matching declaration.
        select_param_str = self._LISTS_WITH_VALID_GROUP + (
            "@@ASCENDC_TPL_BOOL_SEL_XXX@@ = {0, 1}},"
            "@@{@@ASCENDC_TPL_DTYPE_SEL_X@@ = { 10, 30},@@ASCENDC_TPL_FORMAT_SEL_Y@@ = { 15},"
            "@@ASCENDC_TPL_UINT_SEL_Z@@ = {1, 1, 6},@@ASCENDC_TPL_BOOL_SEL_XXX@@ = {0, 1}}, };"
        )
        self._assert_extract_error(
            declare_param_str, select_param_str,
            'Cannot find ASCENDC_TPL_BOOL_SEL name: XXX in ASCENDC_TPL_BOOL_DECL.',
        )

        # Declared bit width of 64 exceeds the total limit (same selection as above).
        self._assert_extract_error(
            self._declare_with_uint_z("64, 2, 0, 2, 3, 4, 5, 6"), select_param_str,
            'name:Z, type:UINT, Total bit width cannot be greater than 64!',
        )

    def test_template_tiling_struct(self):
        """A TILING_STRUCT_SEL entry does not change the decoded parameters."""
        with self._op_context():
            extract_template_tiling_info(self.DECLARE_PARAM_STR, self.STRUCT_SELECT_PARAM_STR)
            result = decode_tiling()
            self.assertEqual(result.get(17176852).get("paramArgs"), ['20', '25', '6', '1'])

    def test_template_tiling_kernel(self):
        """The ASCENDC_TPL_KERNEL_TYPE_SEL value is reported as the entry's ``kernelType``."""
        declare_param_str = (
            "@@ASCENDC_TPL_ARGS_DECL_AddTemplateCustom@@ = {"
            "@@ASCENDC_TPL_DTYPE_DECL_D_T_X@@ = {10, 20},"
            "@@ASCENDC_TPL_DTYPE_DECL_D_T_Y@@ = {10, 20},"
            "@@ASCENDC_TPL_DTYPE_DECL_D_T_Z@@ = {10, 20},"
            "@@ASCENDC_TPL_UINT_DECL_TILE_NUM@@ = {8, 2, 2, 0, 2, 3, 5, 10, 12, 13, 9, 8},"
            "@@ASCENDC_TPL_BOOL_DECL_IS_SPLIT@@ = {0, 1},};"
        )
        select_param_str = (
            "@@ASCENDC_TPL_LISTS@@ = {"
            "@@{@@ASCENDC_TPL_KERNEL_TYPE_SEL@@ = {2}, "
            "@@ASCENDC_TPL_DTYPE_SEL_D_T_X@@ = {10}, "
            "@@ASCENDC_TPL_DTYPE_SEL_D_T_Y@@ = {10}, "
            "@@ASCENDC_TPL_DTYPE_SEL_D_T_Z@@ = {10}, "
            "@@ASCENDC_TPL_UINT_SEL_TILE_NUM@@ = {1, 1, 8}, "
            "@@ASCENDC_TPL_BOOL_SEL_IS_SPLIT@@ = {0, 1},}, "
            "@@{@@ASCENDC_TPL_KERNEL_TYPE_SEL@@ = {0}, "
            "@@ASCENDC_TPL_DTYPE_SEL_D_T_X@@ = {20}, "
            "@@ASCENDC_TPL_DTYPE_SEL_D_T_Y@@ = {20}, "
            "@@ASCENDC_TPL_DTYPE_SEL_D_T_Z@@ = {20}, "
            "@@ASCENDC_TPL_UINT_SEL_TILE_NUM@@ = {1, 1, 8}, "
            "@@ASCENDC_TPL_BOOL_SEL_IS_SPLIT@@ = {0, 1},},};"
        )
        with self._op_context():
            extract_template_tiling_info(declare_param_str, select_param_str)
            result = decode_tiling()
            self.assertEqual(result.get(17435146).get("paramArgs"), ['10', '10', '10', '1', '0'])
            self.assertEqual(result.get(17435146).get("kernelType"), 2)

    def test_template_tiling_datatype(self):
        """DATATYPE macros and arithmetic expressions are evaluated via safe_parse_value."""
        input0 = int(template_tiling.ASCENDC_TPL_INPUT_BIAS)
        output0 = int(template_tiling.ASCENDC_TPL_OUTPUT_BIAS)
        # Expressions that evaluate to the bias values themselves.
        input0_expr = f"({input0} + (3 * 7 - 21))"
        output0_expr = f"({output0} + ((8 // 4) - 2))"
        d_t_x_sel_expr = "(13 * 2 + 1)"
        d_t_y_sel_expr = "((5 - 2) * 1)"
        declare_param_str = (
            f"@@ASCENDC_TPL_ARGS_DECL_AddTemplateCustom@@ = {{@@ASCENDC_TPL_DATATYPE_DECL_D_T_X@@ = {{C_DT_FLOAT16, C_DT_BF16, {input0_expr}}},"
            f"@@ASCENDC_TPL_DATATYPE_DECL_D_T_Y@@ = {{C_DT_FLOAT, C_DT_INT32, {output0_expr}}},"
            "@@ASCENDC_TPL_BOOL_DECL_IS_SPLIT@@ = {0, 1},};"
        )
        select_param_str = (
            f"@@ASCENDC_TPL_LISTS@@ = {{@@{{@@ASCENDC_TPL_DATATYPE_SEL_D_T_X@@ = {{{d_t_x_sel_expr}}},"
            f"@@ASCENDC_TPL_DATATYPE_SEL_D_T_Y@@ = {{{d_t_y_sel_expr}}},@@ASCENDC_TPL_BOOL_SEL_IS_SPLIT@@ = {{1}},}},}};"
        )
        op_info = OpInfo(
            inputs=[{"dtype": ["float16", "bfloat16"]}],
            outputs=[{"dtype": ["float32", "int32"]}]
        )
        with self._op_context():
            with mock.patch.object(
                template_tiling, 'safe_parse_value', wraps=template_tiling.safe_parse_value
            ) as mock_safe_parse_value:
                extract_template_tiling_info(declare_param_str, select_param_str)
                dtype_options, dtype_select_indexes = extract_decl_param_options(op_info, "dtype")
                result = decode_tiling()
                parse_calls = mock_safe_parse_value.call_args_list
                for expr in ("C_DT_BF16", "C_DT_INT32", input0_expr, output0_expr,
                             d_t_x_sel_expr, d_t_y_sel_expr):
                    self.assertIn(mock.call(expr), parse_calls)
                self.assertEqual(dtype_options, [["float16", "bfloat16"], ["float32", "int32"]])
                self.assertEqual(dtype_select_indexes, [True, True])
                decoded_info = next(iter(result.values()))
                self.assertEqual(decoded_info.get("dtypeParams"), ["bfloat16", "int32"])
                self.assertEqual(
                    decoded_info.get("paramArgs"),
                    ["TypeFromId<27>::type", "TypeFromId<3>::type", "1"],
                )

    def test_template_tiling_datatype_parse_values(self):
        """safe_parse_value/extract_expr resolve literals, expressions and known macros."""
        self.assertEqual(safe_parse_value("27"), 27)
        self.assertEqual(safe_parse_value("(13 * 2 + 1)"), 27)
        self.assertEqual(safe_parse_value("C_DT_BF16"), 27)
        self.assertEqual(safe_parse_value("C_FORMAT_ND"), 2)
        self.assertEqual(safe_parse_value("C_DT_UNKNOWN"), -1)
        # The unknown macro (parsed as -1) does not appear in extract_expr's output.
        self.assertEqual(extract_expr("{C_DT_FLOAT16, (13 * 2 + 1), C_DT_UNKNOWN, C_FORMAT_ND}"), [1, 27, 2])

    def test_template_tiling_datatype_invalid_select_value(self):
        """A DATATYPE selection value (7) absent from the declaration is rejected."""
        declare_param_str = (
            "@@ASCENDC_TPL_ARGS_DECL_AddTemplateCustom@@ = {@@ASCENDC_TPL_DATATYPE_DECL_D_T_X@@ = {C_DT_FLOAT16, C_DT_BF16},"
            "@@ASCENDC_TPL_BOOL_DECL_IS_SPLIT@@ = {0, 1},};"
        )
        select_param_str = (
            "@@ASCENDC_TPL_LISTS@@ = {@@{@@ASCENDC_TPL_DATATYPE_SEL_D_T_X@@ = {(2 * 3 + 1)},"
            "@@ASCENDC_TPL_BOOL_SEL_IS_SPLIT@@ = {1},},};"
        )
        self._assert_extract_error(
            declare_param_str, select_param_str,
            "Cannot find value {7} from ASCENDC_TPL_DATATYPE_SEL D_T_X in ASCENDC_TPL_DATATYPE_DECL,"
            " please check your macro define.",
        )

    def test_template_tiling_datatype_invalid_parse_value(self):
        """A DATATYPE declaration made only of unknown macros is reported as empty."""
        declare_param_str = (
            "@@ASCENDC_TPL_ARGS_DECL_AddTemplateCustom@@ = {@@ASCENDC_TPL_DATATYPE_DECL_D_T_X@@ = {C_DT_UNKNOWN},};"
        )
        self._assert_extract_error(
            declare_param_str, "", "values of ASCENDC_TPL_DATATYPE_DECL D_T_X is empty!"
        )

    def test_template_tiling_deterministic(self):
        """DETERMINISTIC_SEL accepts false/0/1 and rejects 3 and multiple values."""
        with self._op_context():
            extract_template_tiling_info(self.DECLARE_PARAM_STR, self._deterministic_select("false"))
            result = decode_tiling()
            self.assertEqual(result.get(17176852).get("paramArgs"), ['20', '25', '6', '1'])
            for value in ("0", "1"):
                with self.subTest(deterministic=value):
                    extract_template_tiling_info(
                        self.DECLARE_PARAM_STR, self._deterministic_select(value)
                    )
            for value in ("3", "0, 1"):
                with self.subTest(deterministic=value):
                    with self.assertRaises(RuntimeError):
                        extract_template_tiling_info(
                            self.DECLARE_PARAM_STR, self._deterministic_select(value)
                        )
if __name__ == "__main__":
    # Run the test cases defined in this module when the file is executed directly.
    unittest.main()