from pathlib import Path
import pytest
from impl.ops_math.dynamic import pows
from asc_op_compile_base.common.buildcfg.buildcfg import build_config
from asc_op_compile_base.common.buildcfg.buildcfg_mapping import kernel_meta_parent_dir, \
op_debug_config, tbe_debug_level
from asc_op_compile_base.common.ccec import current_build_config
from asc_op_compile_base.common.context import get_context, op_info
from asc_op_compile_base.common.context.op_context import OpContext
from asc_op_compile_base.common.platform.platform_info import set_current_compile_soc_info
def compile_sub_kernel(kernel_meta_dir, op_name, op_type, func, extend_op_info: dict = None):
    """Invoke ``func`` inside a static OpContext set up for super-kernel sub-kernel compilation.

    Writes several options into ``current_build_config()`` (callers in this
    file wrap the call in ``build_config()``), selects the ``Ascend910_9391``
    compile SoC, registers an ``OpInfo`` on the context and attaches a
    ``super_kernel_sub_info`` dict before calling ``func``.

    Args:
        kernel_meta_dir: Value stored under the ``kernel_meta_parent_dir`` option.
        op_name: Name of the ``OpInfo`` registered on the context.
        op_type: Type of the ``OpInfo`` registered on the context.
        func: Zero-argument callable that performs the actual op compilation.
        extend_op_info: Optional dict whose entries override the default
            ``super_kernel_sub_info`` values (applied only when truthy).
    """
    cfg = current_build_config()
    cfg[kernel_meta_parent_dir] = kernel_meta_dir
    cfg[tbe_debug_level] = 0

    set_current_compile_soc_info("Ascend910_9391")

    # Options applied after the SoC switch. The output directory is written a
    # second time here, mirroring the original ordering.
    # NOTE(review): confirm whether the SoC switch really resets it.
    post_soc_options = (
        (op_debug_config, ["dump_cce", ]),
        ('enable_deterministic_mode', 0),
        (kernel_meta_parent_dir, kernel_meta_dir),
        ('enable_super_kernel', 1),
    )
    cfg = current_build_config()
    for option_key, option_value in post_soc_options:
        cfg[option_key] = option_value

    sub_kernel_info = {
        'super_kernel_sub_loc': 'middle',
        'super_kernel_options': 'early-start=0',
        'super_kernel_count': 0,
        'super_kernel_sub_id': 0,
    }
    if extend_op_info:
        sub_kernel_info.update(extend_op_info)

    with OpContext('static'):
        graph_op = op_info.OpInfo(op_name, op_type)
        ctx = get_context()
        ctx.set_graph_op_info(graph_op)
        ctx.add_addition('super_kernel_sub_info', sub_kernel_info)
        func()
class SubkernelPath:
    """Describe where a compiled sub-kernel's artifacts are located.

    Attributes:
        root: Directory used as ``kernel_meta_parent_dir`` when compiling. It
            must support ``/`` joining (callers in this file pass a ``Path``).
        name: Artifact base name, e.g. ``"pows"``.
    """

    def __init__(self, path, name):
        self.root = path
        self.name = name

    def o(self):
        """Return the full object-file path ``<root>/kernel_meta/<name>.o``."""
        meta_dir = self.root / "kernel_meta"
        return meta_dir / f"{self.name}.o"

    def json(self):
        """Return only the file name ``<name>.json``, not a full path."""
        return f"{self.name}.json"
@pytest.fixture(scope="function")
def subkernel_is_inf(tmp_dir):
    """Compile a float16, shape-[1024] IsInf kernel as a super-kernel sub-kernel.

    Args:
        tmp_dir: Per-test temporary directory fixture (presumably provided by a
            conftest; not visible here).

    Returns:
        SubkernelPath rooted at ``<tmp_dir>/subkernel_is_inf`` with base name
        ``is_inf``.
    """
    # Same package path that make_1_in_1_out_subkernel_fixture resolves for
    # its IsInf fixtures (module impl.ops_math.dynamic.is_inf, attr is_inf).
    from impl.ops_math.dynamic.is_inf import is_inf

    kernel_meta_dir = Path(tmp_dir) / "subkernel_is_inf"
    x = {
        "shape": [1024],
        "ori_shape": [1024],
        "format": "ND",
        "ori_format": "ND",
        "dtype": "float16",
    }
    y = {
        "shape": [1024],
        "ori_shape": [1024],
        "format": "ND",
        "ori_format": "ND",
        "dtype": "float16",
    }
    # Previously the fixture stopped after building x/y and returned None;
    # actually compile the kernel and hand back where its artifacts live.
    with build_config():
        compile_sub_kernel(str(kernel_meta_dir), "IsInf", "IsInf", lambda: is_inf(x, y))
    return SubkernelPath(kernel_meta_dir, "is_inf")
def make_1_in_1_out_subkernel_fixture(
    impl_module_name,
    func_name,
    op_name,
    op_type,
    extend_op_info=None,
    *,
    shape=(1024,),
    dtype="float16",
    module_prefix="impl.ops_math.dynamic",
):
    """Factory that builds a pytest fixture compiling a 1-input/1-output op as a sub-kernel.

    Args:
        impl_module_name: Module under ``module_prefix`` that implements the op.
            It is also the artifact base name of the returned ``SubkernelPath``.
        func_name: Name of the op entry function inside that module; it is
            called as ``func(x, y)``.
        op_name: Op name passed to ``compile_sub_kernel``. It also names the
            output directory ``<tmp_dir>/subkernel_<op_name>``.
        op_type: Op type passed to ``compile_sub_kernel``.
        extend_op_info: Optional dict forwarded to ``compile_sub_kernel`` to
            override the default super-kernel sub-info.
        shape: Shape used for both the input and the output descriptor.
        dtype: Dtype used for both descriptors.
        module_prefix: Package that contains ``impl_module_name``.

    Returns:
        A function-scoped pytest fixture that depends on ``tmp_dir`` and
        returns a ``SubkernelPath`` for the compiled kernel.
    """
    import importlib

    def _tensor_desc():
        # Build a fresh dict (and fresh shape lists) per call, so the input
        # and output never share one mutable descriptor object.
        return {
            "shape": list(shape),
            "ori_shape": list(shape),
            "format": "ND",
            "ori_format": "ND",
            "dtype": dtype,
        }

    @pytest.fixture(scope="function")
    def fixture_func(tmp_dir):
        kernel_meta_dir = Path(tmp_dir) / f"subkernel_{op_name}"
        module = importlib.import_module(f"{module_prefix}.{impl_module_name}")
        op_func = getattr(module, func_name)
        x = _tensor_desc()
        y = _tensor_desc()
        with build_config():
            compile_sub_kernel(
                str(kernel_meta_dir),
                op_name,
                op_type,
                extend_op_info=extend_op_info,
                func=lambda: op_func(x, y),
            )
        return SubkernelPath(kernel_meta_dir, impl_module_name)

    return fixture_func
# Override merged into the default super-kernel sub-info by compile_sub_kernel;
# it replaces that function's default "early-start=0" options string.
NEW_EXTEND_OP_INFO = dict(super_kernel_options="split-mode=1:early-start=1")
# Fixtures compiled with the default super-kernel sub-info (no override).
subkernel_is_inf_default = make_1_in_1_out_subkernel_fixture(
    "is_inf", "is_inf", "IsInf_Default", "IsInf"
)
subkernel_is_finite_default = make_1_in_1_out_subkernel_fixture(
    "is_finite", "is_finite", "IsFinite_Default", "IsFinite"
)

# The same ops compiled with NEW_EXTEND_OP_INFO merged into the sub-kernel info.
subkernel_is_inf_split_mode1 = make_1_in_1_out_subkernel_fixture(
    "is_inf", "is_inf", "IsInf_SplitMode1", "IsInf", NEW_EXTEND_OP_INFO
)
subkernel_is_finite_split_mode1 = make_1_in_1_out_subkernel_fixture(
    "is_finite", "is_finite", "IsFinite_SplitMode1", "IsFinite", NEW_EXTEND_OP_INFO
)
@pytest.fixture
def subkernel_inf(request):
    """Return the value of the fixture whose name arrives via ``request.param``.

    Meant for indirect parametrization, e.g.
    ``@pytest.mark.parametrize("subkernel_inf", [...], indirect=True)``.
    """
    return request.getfixturevalue(request.param)


@pytest.fixture
def subkernel_finite(request):
    """Return the value of the fixture whose name arrives via ``request.param``.

    Meant for indirect parametrization, e.g.
    ``@pytest.mark.parametrize("subkernel_finite", [...], indirect=True)``.
    """
    return request.getfixturevalue(request.param)
@pytest.fixture(scope="function")
def subkernel_pows_default(tmp_dir):
    """Compile a float16, shape-[1024] Pows kernel (two inputs, one output) as a sub-kernel.

    Args:
        tmp_dir: Per-test temporary directory fixture (presumably provided by a
            conftest; not visible here).

    Returns:
        SubkernelPath rooted at ``<tmp_dir>/subkernel_pows`` with base name
        ``pows``.
    """
    out_dir = Path(tmp_dir) / "subkernel_pows"

    def describe():
        # Independent dict per tensor; key order matches the op's descriptor layout.
        return dict(
            shape=[1024],
            ori_shape=[1024],
            format="ND",
            ori_format="ND",
            dtype="float16",
        )

    first_in, second_in, out_desc = describe(), describe(), describe()
    with build_config():
        compile_sub_kernel(
            str(out_dir), "Pows", "Pows", lambda: pows.pows(first_in, second_in, out_desc)
        )
    return SubkernelPath(out_dir, "pows")
@pytest.fixture
def subkernel_pows(request):
    """Return the value of the fixture whose name arrives via ``request.param``.

    Meant for indirect parametrization, e.g.
    ``@pytest.mark.parametrize("subkernel_pows", ["subkernel_pows_default"], indirect=True)``.
    """
    return request.getfixturevalue(request.param)