import importlib.util
import os
import sys
import types
import pytest
# Start from the resolved location of this file and climb five directory
# levels; the "autofuse/compiler/python" tree is looked up below that point.
BASE_DIR = os.path.realpath(__file__)
for _ in range(5):
    BASE_DIR = os.path.dirname(BASE_DIR)

# Directory holding the module under test, and the file itself.
PYTHON_DIR = os.path.join(BASE_DIR, "autofuse/compiler/python")
MODULE_PATH = os.path.join(PYTHON_DIR, "asc_codegen_compile.py")

# Dotted name the module under test is loaded as.
MODULE_NAME = "autofuse.compiler.python.asc_codegen_compile"
class _SimpleNamespace:
    """Bare attribute container: each keyword argument becomes an attribute."""

    def __init__(self, **kwargs):
        for attr_name in kwargs:
            setattr(self, attr_name, kwargs[attr_name])
class _DummyLogger:
    """Logger stand-in whose ``info``/``error``/``warning`` accept anything and return ``None``."""

    @staticmethod
    def info(*args, **kwargs):
        """Discard the message."""
        return None

    @staticmethod
    def error(*args, **kwargs):
        """Discard the message."""
        return None

    @staticmethod
    def warning(*args, **kwargs):
        """Discard the message."""
        return None
def _stub_module(name, **attrs):
    """Create an empty module called *name*, attach *attrs*, and register it.

    The new module replaces whatever ``sys.modules`` previously held under
    *name*, and is also returned to the caller.
    """
    stub = types.ModuleType(name)
    for attr_name in attrs:
        setattr(stub, attr_name, attrs[attr_name])
    sys.modules[name] = stub
    return stub
@pytest.fixture(scope="module")
def asc_codegen_compile_module():
    """Import ``asc_codegen_compile.py`` from ``MODULE_PATH`` with its dependencies stubbed.

    Placeholder modules for the ``tbe``, ``asc_op_compile_base`` and
    ``autofuse`` packages are registered in ``sys.modules`` first; the file is
    then executed under the name ``MODULE_NAME``.

    Returns:
        The loaded module object. The fixture is module-scoped, so the same
        object is handed to every test in this file.

    Raises:
        AssertionError: if no import spec/loader can be built for ``MODULE_PATH``.

    NOTE(review): nothing registered here is removed afterwards, so the stubs
    stay in ``sys.modules`` after the fixture has been set up.
    """
    # --- stubs for the ``tbe`` package tree ---
    _stub_module("tbe")
    _stub_module("tbe.common")
    _stub_module("tbe.common.buildcfg", get_current_build_config=lambda: {})
    _stub_module("tbe.common.utils")
    _stub_module(
        "tbe.common.utils.log",
        info=_DummyLogger.info,
        error=_DummyLogger.error,
        warning=_DummyLogger.warning,
    )
    _stub_module("tbe.common.utils.op_tiling", do_op_tiling=lambda *args, **kwargs: None)
    _stub_module("tbe.common.context", get_context=lambda: _SimpleNamespace(get_compile_info=lambda: {}))
    _stub_module("tbe.tikcpp")
    _stub_module(
        "tbe.tikcpp.compile_op",
        CommonUtility=_SimpleNamespace(print_compile_log=lambda *args, **kwargs: None),
        AscendCLogLevel=_SimpleNamespace(LOG_ERROR="error", LOG_DEBUG="debug", LOG_WARNING="warning"),
    )
    _stub_module(
        "tbe.tikcpp.get_op_tiling",
        TilingInfo=object,
        _change_param_name_to_name=lambda *args, **kwargs: None,
        gen_static_shape_v2=lambda *args, **kwargs: None,
    )
    sys.modules["tbe.tikcpp"].OpInfo = object

    # --- stubs for the ``asc_op_compile_base`` package tree ---
    _stub_module("asc_op_compile_base")
    _stub_module("asc_op_compile_base.common")
    _stub_module("asc_op_compile_base.common.platform")
    _stub_module("asc_op_compile_base.common.platform.platform_info", get_soc_spec=lambda *args, **kwargs: None)

    # --- parent packages of the module under test, linked parent -> child ---
    package = _stub_module("autofuse")
    compiler_pkg = _stub_module("autofuse.compiler")
    python_pkg = _stub_module("autofuse.compiler.python")
    package.compiler = compiler_pkg
    compiler_pkg.python = python_pkg

    # --- sibling modules living in the same package as the module under test ---
    package_prefix = MODULE_NAME.rsplit('.', 1)[0]
    _stub_module(package_prefix + ".pyautofuse", Schedule=object, CodeGen=object, ascir=_SimpleNamespace())
    _stub_module(
        package_prefix + ".ascbc_kernel_compile",
        ascbc_kernel_compile=lambda *args, **kwargs: ("kernel.o", "kernel.json"),
        camel_to_snake=lambda value: value,
    )
    _stub_module(package_prefix + ".compile_adapter", get_pgo_env_flag=lambda: False, get_pgo_topn=lambda: 5)

    spec = importlib.util.spec_from_file_location(MODULE_NAME, MODULE_PATH)
    # Validate the spec *before* its first use; checking it after
    # ``module_from_spec`` could never trigger.
    assert spec is not None and spec.loader is not None, "cannot create an import spec for " + MODULE_PATH
    module = importlib.util.module_from_spec(spec)

    # Register the module before executing it, as the importlib recipe for
    # loading a source file does, and roll the registration back on failure so
    # a half-initialised module is never left behind.
    sys.modules[MODULE_NAME] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        sys.modules.pop(MODULE_NAME, None)
        raise
    return module
def test_get_replace_kernel_root_returns_none_when_env_is_missing(monkeypatch, asc_codegen_compile_module):
    """With ``AUTOFUSE_DFX_FLAGS`` removed from the environment the root is ``None``."""
    monkeypatch.delenv("AUTOFUSE_DFX_FLAGS", raising=False)

    kernel_root = asc_codegen_compile_module.get_replace_kernel_root()

    assert kernel_root is None
def test_replace_device_kernel_replaces_exact_kernel_file(tmpdir, asc_codegen_compile_module):
    """A kernel file placed directly in the replace root overwrites the built one."""
    kernel_file = "demo_graph_op_kernel.cpp"
    override_dir = tmpdir.mkdir("replace_root")
    device_dir = tmpdir.ensure("build", "device", dir=True)
    override_dir.join(kernel_file).write("new device kernel")
    built_kernel = device_dir.join(kernel_file)
    built_kernel.write("old device kernel")

    asc_codegen_compile_module.replace_device_kernel(str(override_dir), str(device_dir), "demo_graph")

    assert built_kernel.read() == "new device kernel"
def test_replace_host_files_replaces_host_tiling_group_and_removes_stale_files(tmpdir, asc_codegen_compile_module):
    """After replacement the host dir holds exactly the tiling sources from the replace root."""
    override_root = tmpdir.mkdir("replace_root")
    nested_source = override_root.ensure("a", "b", "whatever", dir=True)
    host_dir = tmpdir.ensure("build", "host", dir=True)
    nested_source.join("demo_graph_tiling_func_0.cpp").write("new0")
    nested_source.join("demo_graph_tiling_func_1.cpp").write("new1")
    # Pre-existing tiling file that has no counterpart in the replace root.
    host_dir.join("demo_graph_tiling_func_old.cpp").write("stale")

    asc_codegen_compile_module.replace_host_files(str(override_root), str(host_dir), "demo_graph")

    remaining = [entry.basename for entry in host_dir.visit(fil="demo_graph*tiling_func*.cpp")]
    remaining.sort()
    assert remaining == ["demo_graph_tiling_func_0.cpp", "demo_graph_tiling_func_1.cpp"]
def test_find_host_replace_source_dir_raises_when_multiple_dirs_match(tmpdir, asc_codegen_compile_module):
    """Two sibling directories that both contain tiling sources trigger ``RuntimeError``."""
    override_root = tmpdir.mkdir("replace_root")
    first_dir = override_root.mkdir("x")
    second_dir = override_root.mkdir("y")
    first_dir.join("demo_graph_tiling_func_a.cpp").write("a")
    second_dir.join("demo_graph_tiling_func_b.cpp").write("b")

    with pytest.raises(RuntimeError):
        asc_codegen_compile_module.find_host_replace_source_dir(str(override_root), "demo_graph")
def test_replace_host_files_allows_missing_optional_headers(tmpdir, asc_codegen_compile_module):
    """Replacement succeeds with only a tiling source present; no headers appear in the host dir."""
    override_root = tmpdir.mkdir("replace_root")
    nested_source = override_root.mkdir("nested")
    host_dir = tmpdir.ensure("build", "host", dir=True)
    nested_source.join("demo_graph_tiling_func_0.cpp").write("new0")

    asc_codegen_compile_module.replace_host_files(str(override_root), str(host_dir), "demo_graph")

    assert host_dir.join("demo_graph_tiling_func_0.cpp").read() == "new0"
    # Neither header existed in the source dir, so neither may exist afterwards.
    for header_name in ("autofuse_tiling_data.h", "autofuse_tiling_func_common.h"):
        assert not host_dir.join(header_name).check()
def test_replace_host_files_skips_same_file_copy_when_source_is_host_dir(tmpdir, asc_codegen_compile_module):
    """When the host dir lives inside the replace root, its files keep their content."""
    override_root = tmpdir.mkdir("replace_root")
    host_dir = override_root.ensure("host", dir=True)
    expected_content = (
        ("demo_graph_tiling_func_0.cpp", "same cpp"),
        ("autofuse_tiling_data.h", "same header"),
    )
    for file_name, text in expected_content:
        host_dir.join(file_name).write(text)

    asc_codegen_compile_module.replace_host_files(str(override_root), str(host_dir), "demo_graph")

    for file_name, text in expected_content:
        assert host_dir.join(file_name).read() == text
def test_replace_host_files_copies_autofuse_cube_tiling_data_when_present(tmpdir, asc_codegen_compile_module):
    """``autofuse_cube_tiling_data.h`` next to the tiling source is copied to the host dir."""
    cube_header = "autofuse_cube_tiling_data.h"
    override_root = tmpdir.mkdir("replace_root")
    nested_source = override_root.mkdir("nested")
    host_dir = tmpdir.ensure("build", "host", dir=True)
    nested_source.join("demo_graph_tiling_func_0.cpp").write("new0")
    nested_source.join(cube_header).write("cube")

    asc_codegen_compile_module.replace_host_files(str(override_root), str(host_dir), "demo_graph")

    assert host_dir.join(cube_header).read() == "cube"
def test_host_miss_device_hit_keeps_device_replacement_independent(tmpdir, asc_codegen_compile_module):
    """Only a device kernel is supplied: the device file changes, the host dir stays empty."""
    kernel_file = "demo_graph_op_kernel.cpp"
    override_root = tmpdir.mkdir("replace_root")
    device_dir = tmpdir.ensure("build", "device", dir=True)
    host_dir = tmpdir.ensure("build", "host", dir=True)
    override_root.join(kernel_file).write("new device kernel")
    device_dir.join(kernel_file).write("old device kernel")
    root_arg = str(override_root)

    asc_codegen_compile_module.replace_device_kernel(root_arg, str(device_dir), "demo_graph")
    asc_codegen_compile_module.replace_host_files(root_arg, str(host_dir), "demo_graph")

    assert device_dir.join(kernel_file).read() == "new device kernel"
    assert len(host_dir.listdir()) == 0
def test_device_miss_host_hit_keeps_host_replacement_independent(tmpdir, asc_codegen_compile_module):
    """Only host tiling sources are supplied: the host file is copied, the device kernel is untouched."""
    kernel_file = "demo_graph_op_kernel.cpp"
    tiling_file = "demo_graph_tiling_func_0.cpp"
    override_root = tmpdir.mkdir("replace_root")
    nested_source = override_root.mkdir("nested")
    device_dir = tmpdir.ensure("build", "device", dir=True)
    host_dir = tmpdir.ensure("build", "host", dir=True)
    nested_source.join(tiling_file).write("new0")
    device_dir.join(kernel_file).write("old device kernel")
    root_arg = str(override_root)

    asc_codegen_compile_module.replace_device_kernel(root_arg, str(device_dir), "demo_graph")
    asc_codegen_compile_module.replace_host_files(root_arg, str(host_dir), "demo_graph")

    assert device_dir.join(kernel_file).read() == "old device kernel"
    assert host_dir.join(tiling_file).read() == "new0"
def test_host_replacement_happens_once_before_static_shape_compile_and_uses_final_host_dir(monkeypatch,
                                                                                           tmpdir,
                                                                                           asc_codegen_compile_module):
    """Drive ``asc_graph_compile`` with every collaborator faked and inspect the call log.

    Checks that ``replace_host_files`` is invoked exactly once, that it receives
    the ``host/cv_common`` directory, and that it is recorded before
    ``static_shape_compile``.
    """
    calls = []
    work_dir = str(tmpdir)
    host_cv_common_dir = tmpdir.ensure("host", "cv_common", dir=True)

    def fake_get_graph_basic_info(params, args):
        return ("demo_graph", 1, 1, True, False, {})

    def fake_create_compile_dirs(temp_dir):
        return (str(tmpdir.ensure("device", dir=True)), str(tmpdir.ensure("host", dir=True)))

    def fake_generate_device_and_host_code(**kwargs):
        return ("kernel", {"k": "v"})

    def fake_is_static_compile(params, tiling_func_srcs):
        return True

    def fake_ascbc_matmul_kernel_tiling_pro(*args, **kwargs):
        # Flip the caller-owned flag cell in place.
        kwargs["use_cv_common"][0] = True

    def fake_static_shape_compile(**kwargs):
        if "use_cv_common" in kwargs:
            kwargs["use_cv_common"][0] = True
        calls.append(("static_shape_compile", kwargs["graph_name"]))

    def fake_get_replace_kernel_root():
        return tmpdir.ensure("replace_root", dir=True)

    def fake_replace_host_files(root, host_dir, graph_name):
        calls.append(("replace_host", host_dir, graph_name))

    def fake_replace_kernel(*args, **kwargs):
        calls.append(("replace_device",))

    def fake_ascbc_kernel_compile(*args, **kwargs):
        return ("kernel.o", "kernel.json")

    def fake_asc_graph_compile_post(*args, **kwargs):
        calls.append(("post", args[0]))

    def fake_timestamp_set(*args, **kwargs):
        return None

    # NOTE(review): other tests in this file call ``replace_device_kernel``;
    # confirm the module also exposes ``replace_kernel``, which is patched here.
    replacements = (
        ("get_graph_basic_info", fake_get_graph_basic_info),
        ("create_compile_dirs", fake_create_compile_dirs),
        ("generate_device_and_host_code", fake_generate_device_and_host_code),
        ("is_static_compile", fake_is_static_compile),
        ("ascbc_matmul_kernel_tiling_pro", fake_ascbc_matmul_kernel_tiling_pro),
        ("static_shape_compile", fake_static_shape_compile),
        ("get_replace_kernel_root", fake_get_replace_kernel_root),
        ("replace_host_files", fake_replace_host_files),
        ("replace_kernel", fake_replace_kernel),
        ("ascbc_kernel_compile", fake_ascbc_kernel_compile),
        ("asc_graph_compile_post", fake_asc_graph_compile_post),
        ("timestamp_set", fake_timestamp_set),
    )
    for attr_name, fake in replacements:
        monkeypatch.setattr(asc_codegen_compile_module, attr_name, fake)

    compile_params = {
        "schedule_results": object(),
        "vector_core_num": 1,
        "impl_mode": None,
    }
    asc_codegen_compile_module.asc_graph_compile("arg0",
                                                 "kernel_name",
                                                 temp_dir=work_dir,
                                                 params=compile_params)

    host_replacements = [entry for entry in calls if entry[0] == "replace_host"]
    assert len(host_replacements) == 1
    assert host_replacements[0][1] == str(host_cv_common_dir)

    call_names = [entry[0] for entry in calls]
    assert call_names.index("replace_host") < call_names.index("static_shape_compile")
class _LogCapture:
    """Minimal logger replacement that stores every ``info`` message in ``messages``."""

    def __init__(self):
        self.messages = []

    def info(self, message, *args):
        """Record *message*, %-formatted with *args* when any are given."""
        rendered = message % args if args else message
        self.messages.append(rendered)
def test_replace_host_files_logs_source_target_removed_and_copied(monkeypatch, tmpdir, asc_codegen_compile_module):
    """``replace_host_files`` emits the four expected ``info`` log fragments."""
    override_root = tmpdir.mkdir("replace_root")
    nested_source = override_root.ensure("nested", dir=True)
    host_dir = tmpdir.ensure("build", "host", dir=True)
    recorder = _LogCapture()
    nested_source.join("demo_graph_tiling_func_0.cpp").write("new0")
    nested_source.join("autofuse_tiling_data.h").write("tiling")
    host_dir.join("demo_graph_tiling_func_old.cpp").write("stale")
    # Swap the module-level logger so emitted messages can be inspected.
    monkeypatch.setattr(asc_codegen_compile_module, "logger", recorder)

    asc_codegen_compile_module.replace_host_files(str(override_root), str(host_dir), "demo_graph")

    expected_fragments = (
        "replace host source dir:",
        "replace host target dir:",
        "cleanup stale host tiling files:",
        "copied host files:",
    )
    for fragment in expected_fragments:
        assert any(fragment in line for line in recorder.messages)
def test_find_host_replace_source_dir_error_message_contains_conflict_dirs(tmpdir, asc_codegen_compile_module):
    """The ``RuntimeError`` raised on ambiguity names both conflicting directories."""
    override_root = tmpdir.mkdir("replace_root")
    first_dir = override_root.mkdir("x")
    second_dir = override_root.mkdir("y")
    first_dir.join("demo_graph_tiling_func_a.cpp").write("a")
    second_dir.join("demo_graph_tiling_func_b.cpp").write("b")

    with pytest.raises(RuntimeError) as caught:
        asc_codegen_compile_module.find_host_replace_source_dir(str(override_root), "demo_graph")

    text = str(caught.value)
    for conflicting_dir in (first_dir, second_dir):
        assert str(conflicting_dir) in text
def test_replace_host_files_does_not_touch_infershape_cpp(tmpdir, asc_codegen_compile_module):
    """The tiling source is replaced while the existing infershape source keeps its content."""
    infershape_file = "demo_graph_infershape.cpp"
    tiling_file = "demo_graph_tiling_func_0.cpp"
    override_root = tmpdir.mkdir("replace_root")
    nested_source = override_root.mkdir("nested")
    host_dir = tmpdir.ensure("build", "host", dir=True)
    nested_source.join(tiling_file).write("new0")
    nested_source.join(infershape_file).write("new infershape")
    host_dir.join(infershape_file).write("old infershape")

    asc_codegen_compile_module.replace_host_files(str(override_root), str(host_dir), "demo_graph")

    assert host_dir.join(tiling_file).read() == "new0"
    assert host_dir.join(infershape_file).read() == "old infershape"