import pytest
import importlib.util
import os
import sys
import types
_BASE_DIR = os.path.dirname(
os.path.dirname(
os.path.dirname(
os.path.dirname(
os.path.dirname(os.path.realpath(__file__)))))
)
_PYTHON_DIR = os.path.join(_BASE_DIR, "autofuse/compiler/python")
MODULE_NAME = "autofuse.compiler.python.asc_codegen_compile"
MODULE_PATH = os.path.join(_PYTHON_DIR, "asc_codegen_compile.py")
class SimpleNamespace(object):
def __init__(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)
class DummyLogger(object):
@staticmethod
def info(*args, **kwargs):
return None
@staticmethod
def error(*args, **kwargs):
return None
@staticmethod
def warning(*args, **kwargs):
return None
def stub_module(name, **attrs):
module = types.ModuleType(name)
for key, value in attrs.items():
setattr(module, key, value)
sys.modules[name] = module
return module
@pytest.fixture(scope="module")
def asc_codegen_compile_module():
"""Fixture to load asc_codegen_compile module with mocked dependencies"""
stub_module("tbe")
stub_module("tbe.common")
stub_module("tbe.common.buildcfg", get_current_build_config=lambda: {})
stub_module("tbe.common.utils")
stub_module("tbe.common.utils.log",
info=DummyLogger.info,
error=DummyLogger.error,
warning=DummyLogger.warning)
stub_module("tbe.common.utils.op_tiling",
do_op_tiling=lambda *args, **kwargs: {
"tiling_data": {},
"tiling_key": 12345,
"block_dim": 4
})
stub_module("tbe.common.context",
get_context=lambda: SimpleNamespace(get_compile_info=lambda: {}))
stub_module("tbe.tikcpp")
stub_module("tbe.tikcpp.compile_op",
CommonUtility=SimpleNamespace(print_compile_log=lambda *args, **kwargs: None),
AscendCLogLevel=SimpleNamespace(LOG_ERROR="error", LOG_DEBUG="debug", LOG_WARNING="warning",
LOG_INFO="info"))
class MockTilingInfo:
def __init__(self):
self.tiling_data = None
self.tiling_key = None
self.file_content = None
stub_module("tbe.tikcpp.get_op_tiling",
TilingInfo=MockTilingInfo,
_change_param_name_to_name=lambda *args, **kwargs: None,
gen_static_shape_v2=lambda *args, **kwargs: "mock_tiling_content")
sys.modules["tbe.tikcpp"].OpInfo = object
stub_module("asc_op_compile_base")
stub_module("asc_op_compile_base.common")
stub_module("asc_op_compile_base.common.platform")
stub_module("asc_op_compile_base.common.platform.platform_info",
get_soc_spec=lambda key: 8 if key == "vector_core_cnt" else None)
package = stub_module("autofuse")
compiler_pkg = stub_module("autofuse.compiler")
python_pkg = stub_module("autofuse.compiler.python")
package.compiler = compiler_pkg
compiler_pkg.python = python_pkg
package_prefix = MODULE_NAME.rsplit('.', 1)[0]
stub_module(package_prefix + ".pyautofuse",
Schedule=object,
CodeGen=object,
ascir=SimpleNamespace())
stub_module(package_prefix + ".ascbc_kernel_compile",
ascbc_kernel_compile=lambda *args, **kwargs: ("kernel.o", "kernel.json"),
camel_to_snake=lambda value: ''.join(['_' +
c.lower() if c.isupper() else c for c in value]).lstrip('_'))
stub_module(package_prefix + ".compile_adapter",
get_pgo_env_flag=lambda: False,
get_pgo_topn=lambda: 5)
spec = importlib.util.spec_from_file_location(MODULE_NAME, MODULE_PATH)
module = importlib.util.module_from_spec(spec)
if spec is not None and spec.loader is not None:
spec.loader.exec_module(module)
sys.modules[MODULE_NAME] = module
return module
class TestBuildConvArgs:
"""测试 _build_conv_args 函数"""
@staticmethod
def test_build_conv_args_basic(asc_codegen_compile_module):
"""测试基本的 Conv2D 参数构建 - 验证格式转换功能"""
args_list = [
{"shape": [1, 64, 224, 224], "format": "NCHW", "dtype": "float16"},
{"shape": [7, 7, 3, 64], "format": "HWCN", "dtype": "float16"},
{"shape": [1, 64, 112, 112], "format": "NCHW", "dtype": "float16"},
{"shape": [1, 64, 224, 224], "format": "NCHW", "dtype": "float16"},
]
input_num = 2
data_format = "NCHW"
origin_inputs, origin_outputs, inputs = asc_codegen_compile_module.build_conv_args(
args_list, input_num, data_format
)
assert len(origin_inputs) == 2
assert len(inputs) == 2
assert len(origin_outputs) == 1
assert inputs[0]["param_name"] == "input0"
assert inputs[0]["format"] == "NCHW"
assert inputs[1]["param_name"] == "input1"
assert inputs[1]["format"] == "NCHW"
assert list(inputs[1]["shape"]) == [64, 3, 7, 7]
assert origin_outputs[0]["format"] == "NCHW"
assert origin_outputs[0]["param_name"] == "output0"
@staticmethod
def test_build_conv_args_nhwc_to_nchw(asc_codegen_compile_module):
"""测试 NHWC -> NCHW 格式转换"""
args_list = [
{"shape": [1, 224, 224, 64], "format": "NHWC", "dtype": "float16"},
{"shape": [1, 64, 224, 224], "format": "NCHW", "dtype": "float16"},
]
input_num = 1
data_format = "NCHW"
origin_inputs, origin_outputs, inputs = asc_codegen_compile_module.build_conv_args(
args_list, input_num, data_format
)
assert inputs[0]["format"] == "NCHW"
assert list(inputs[0]["shape"]) == [1, 64, 224, 224]
assert inputs[0]["ori_format"] == "NCHW"
@staticmethod
def test_build_conv_args_same_format_no_conversion(asc_codegen_compile_module):
"""测试格式相同时不转换"""
args_list = [
{"shape": [1, 64, 224, 224], "format": "NCHW", "dtype": "float16"},
{"shape": [1, 64, 224, 224], "format": "NCHW", "dtype": "float16"},
]
input_num = 1
data_format = "NCHW"
origin_inputs, origin_outputs, inputs = asc_codegen_compile_module.build_conv_args(
args_list, input_num, data_format
)
assert inputs[0]["format"] == "NCHW"
assert inputs[0]["shape"] == [1, 64, 224, 224]
class TestGetGraphBasicInfo:
"""测试 get_graph_basic_info 函数新增的 is_conv 返回值"""
@staticmethod
def test_get_graph_basic_info_returns_is_conv_for_conv2d(asc_codegen_compile_module):
"""测试 Conv2D 场景返回 is_conv=True"""
params = {'vector_core_num': 8}
class MockScheduleResults:
def get_name(self):
return "Conv2DFusionGraph"
def get_input_num(self):
return 3
def get_output_num(self):
return 1
def is_cube_type(self):
return True
def get_cube_attributes(self):
return {"cube_attributes": {"input_num": 3}}
def is_conv_type(self):
return True
args = [
{"shape": [1, 64, 224, 224]},
{"shape": [64, 3, 7, 7]},
{"shape": [1, 64, 224, 224]},
"conv2d_kernel"
]
params['schedule_results'] = MockScheduleResults()
graph_name, input_num, output_num, is_cube, is_conv, cube_attrs = \
asc_codegen_compile_module.get_graph_basic_info(params, args)
assert "conv2" in graph_name.lower()
assert input_num == 3
assert output_num == 1
assert is_cube == True
assert is_conv == True
assert cube_attrs is not None
@staticmethod
def test_get_graph_basic_info_returns_is_conv_false_for_matmul(asc_codegen_compile_module):
"""测试 MatMul 场景返回 is_conv=False"""
params = {'vector_core_num': 8}
class MockScheduleResults:
def get_name(self):
return "MatMulFusionGraph"
def get_input_num(self):
return 2
def get_output_num(self):
return 1
def is_cube_type(self):
return True
def get_cube_attributes(self):
return {"cube_attributes": {"is_batch": False}}
def is_conv_type(self):
return False
args = [
{"shape": [1024, 1024]},
{"shape": [1024, 1024]},
{"shape": [1024, 1024]},
"matmul_kernel"
]
params['schedule_results'] = MockScheduleResults()
graph_name, input_num, output_num, is_cube, is_conv, cube_attrs = \
asc_codegen_compile_module.get_graph_basic_info(params, args)
assert "matmul" in graph_name.lower() or "mat" in graph_name.lower()
assert is_cube == True
assert is_conv == False
class TestGenerateCmakeLists:
"""测试 generate_cmake_lists 函数包含 Conv2D 编译路径"""
@staticmethod
def test_generate_cmake_lists_includes_conv2d_paths(asc_codegen_compile_module, tmpdir):
"""测试生成的 CMakeLists.txt 包含 Conv2D 头文件路径"""
host_build_dir = str(tmpdir)
asc_codegen_compile_module.generate_cmake_lists(
"conv2d_graph",
"conv2d_kernel",
host_build_dir,
is_last_compile=True,
is_static_shape=False,
is_cube=True
)
cmake_file = os.path.join(host_build_dir, "CMakeLists.txt")
assert os.path.exists(cmake_file)
with open(cmake_file, 'r') as f:
content = f.read()
assert "conv2d_v2" in content
assert "ops_nn/ascendc/conv2d_v2" in content
assert "cube_kernel_tiling_wrapper.cpp" in content
@staticmethod
def test_generate_cmake_lists_includes_matmul_paths(asc_codegen_compile_module, tmpdir):
"""测试生成的 CMakeLists.txt 包含 MatMul 头文件路径"""
host_build_dir = str(tmpdir)
asc_codegen_compile_module.generate_cmake_lists(
"matmul_graph",
"matmul_kernel",
host_build_dir,
is_last_compile=True,
is_static_shape=True,
is_cube=True
)
cmake_file = os.path.join(host_build_dir, "CMakeLists.txt")
assert os.path.exists(cmake_file)
with open(cmake_file, 'r') as f:
content = f.read()
assert "mat_mul_v3" in content
assert "ops_nn/ascendc/mat_mul_v3" in content
class TestStaticShapeCompileHasattrCheck:
"""测试 static_shape_compile 函数的 hasattr 检查"""
@staticmethod
def test_static_shape_compile_uses_hasattr_for_gen_const_tiling_data(asc_codegen_compile_module, tmpdir):
"""验证 static_shape_compile 使用 hasattr 检查 GenConstTilingData"""
temp_dir = str(tmpdir)
host_dir = os.path.join(temp_dir, "host")
device_dir = os.path.join(temp_dir, "device")
os.makedirs(host_dir, exist_ok=True)
os.makedirs(device_dir, exist_ok=True)
original_ascendc_clean = getattr(asc_codegen_compile_module, 'ascendc_clean', None)
original_ascbc_host_compile = getattr(asc_codegen_compile_module, 'ascbc_host_compile', None)
original_static_shape_kernel_proc = getattr(asc_codegen_compile_module, 'static_shape_kernel_proc', None)
def mock_ascendc_clean(*args, **kwargs):
pass
def mock_ascbc_host_compile(*args, **kwargs):
pass
def mock_static_shape_kernel_proc(*args, **kwargs):
pass
asc_codegen_compile_module.ascendc_clean = mock_ascendc_clean
asc_codegen_compile_module.ascbc_host_compile = mock_ascbc_host_compile
asc_codegen_compile_module.static_shape_kernel_proc = mock_static_shape_kernel_proc
try:
pass
finally:
if original_ascendc_clean:
asc_codegen_compile_module.ascendc_clean = original_ascendc_clean
if original_ascbc_host_compile:
asc_codegen_compile_module.ascbc_host_compile = original_ascbc_host_compile
if original_static_shape_kernel_proc:
asc_codegen_compile_module.static_shape_kernel_proc = original_static_shape_kernel_proc
class TestDynamicShapeCompile:
"""测试新增的 dynamic_shape_compile 函数"""
@staticmethod
def test_dynamic_shape_compile_exists(asc_codegen_compile_module):
"""验证 dynamic_shape_compile 函数存在"""
assert hasattr(asc_codegen_compile_module, 'dynamic_shape_compile')
@staticmethod
def test_dynamic_shape_compile_signature(asc_codegen_compile_module):
"""验证 dynamic_shape_compile 函数签名"""
import inspect
sig = inspect.signature(asc_codegen_compile_module.dynamic_shape_compile)
params = list(sig.parameters.keys())
assert 'kernel_name' in params
assert 'temp_dir' in params
assert 'graph_name' in params
assert 'use_cv_common' in params
assert 'is_cube' in params
class TestAscbcConvKernelTilingPro:
"""测试新增的 ascbc_conv_kernel_tiling_pro 函数"""
@staticmethod
def test_ascbc_conv_kernel_tiling_pro_exists(asc_codegen_compile_module):
"""验证 ascbc_conv_kernel_tiling_pro 函数存在"""
assert hasattr(asc_codegen_compile_module, 'ascbc_conv_kernel_tiling_pro')
@staticmethod
def test_ascbc_conv_kernel_tiling_pro_signature(asc_codegen_compile_module):
"""验证函数签名"""
import inspect
sig = inspect.signature(asc_codegen_compile_module.ascbc_conv_kernel_tiling_pro)
params = list(sig.parameters.keys())
assert 'temp_dir' in params
assert 'graph_name' in params
assert 'kernel_name' in params
assert 'cube_attrs' in params
assert 'tiling_key_list' in params
class TestAscbcMatmulKernelDynamicTilingPro:
"""测试新增的 ascbc_matmul_kernel_dynamic_tiling_pro 函数"""
@staticmethod
def test_ascbc_matmul_kernel_dynamic_tiling_pro_exists(asc_codegen_compile_module):
"""验证函数存在"""
assert hasattr(asc_codegen_compile_module, 'ascbc_matmul_kernel_dynamic_tiling_pro')
@staticmethod
def test_ascbc_matmul_kernel_dynamic_tiling_pro_signature(asc_codegen_compile_module):
"""验证函数签名"""
import inspect
sig = inspect.signature(asc_codegen_compile_module.ascbc_matmul_kernel_dynamic_tiling_pro)
params = list(sig.parameters.keys())
assert 'temp_dir' in params
assert 'graph_name' in params
assert 'kernel_name' in params
assert 'cube_attrs' in params