"""Smoke tests for tools/perf_data_collection/op_replay/ scripts."""
import importlib
import importlib.util
import subprocess
import sys
from contextlib import contextmanager
from types import SimpleNamespace
import types
from pathlib import Path
import pytest
from tools.perf_data_collection.op_replay import common
# Directory holding the replay scripts under test: four levels up from this file
# (``parents[3]``), then tools/perf_data_collection/op_replay.
OP_REPLAY_DIR = Path(__file__).resolve().parents[3].joinpath(
    "tools", "perf_data_collection", "op_replay"
)
# The scripts below are imported by bare module name, which only resolves once
# their directory is on sys.path; this entry is not removed afterwards.
if str(OP_REPLAY_DIR) not in sys.path:
    sys.path.insert(0, str(OP_REPLAY_DIR))
# Imported in this exact order; each name is the cached sys.modules entry that
# later tests monkeypatch.
dispatch_ffn, split_qkv, op_common, run_all_op = (
    importlib.import_module(module_name)
    for module_name in (
        "DispatchFFNCombine_run",
        "split_qkv_rmsnorm_rope_kernel_run",
        "common",
        "run_all_op",
    )
)
@contextmanager
def op_replay_import_path():
    """Make OP_REPLAY_DIR importable for the duration of the ``with`` body.

    If the directory is already on ``sys.path`` nothing is changed. Otherwise it
    is prepended and removed again on exit, even when the body raises.
    """
    entry = str(OP_REPLAY_DIR)
    if entry in sys.path:
        # Someone else owns this entry; leave sys.path untouched.
        yield
        return
    sys.path.insert(0, entry)
    try:
        yield
    finally:
        sys.path.remove(entry)
def import_op_replay_script(script: str):
    """Execute ``OP_REPLAY_DIR / script`` and return it as a fresh module object.

    The module is named ``_test_op_replay_<stem>``. It is not registered in
    ``sys.modules``, so every call re-executes the script and yields an
    independent copy that tests can monkeypatch freely. The script runs with
    OP_REPLAY_DIR importable so its bare-name sibling imports resolve.
    """
    script_path = OP_REPLAY_DIR / script
    spec = importlib.util.spec_from_file_location(
        f"_test_op_replay_{Path(script).stem}", script_path
    )
    assert spec and spec.loader
    loaded = importlib.util.module_from_spec(spec)
    with op_replay_import_path():
        spec.loader.exec_module(loaded)
    return loaded
class TestOpReplayScriptsExist:
    """Every replay script listed in the manifest below must exist as a file."""

    # Kept as literal file names so a grep for a script finds this manifest.
    EXPECTED_SCRIPTS = (
        "common.py",
        "replay_framework.py",
        "run_all_op.py",
        "MatMulV2_run.py",
        "MatMulV3_run.py",
        "RmsNorm_run.py",
        "SwiGlu_run.py",
        "QuantBatchMatmulV3_run.py",
        "BatchMatMulV2_run.py",
        "GroupedMatmul_run.py",
        "GroupedMatmulSwigluQuant_run.py",
        "LightningIndexer_run.py",
        "MoeTokenPermute_run.py",
        "MoeTokenUnpermute_run.py",
        "ScatterNdUpdate_run.py",
        "SparseFlashAttention_run.py",
        "TransposeBatchMatMul_run.py",
        "DispatchFFNCombine_run.py",
    )

    @pytest.mark.parametrize("script", EXPECTED_SCRIPTS)
    def test_script_exists(self, script):
        script_path = OP_REPLAY_DIR.joinpath(script)
        assert script_path.is_file()
class TestOpReplayImportMap:
    """Each newer replay script's ``main()`` must forward to its ``op.main``."""

    # Must stay a list: the test compares the recorded call order against it with ``==``.
    NEW_REPLAY_SCRIPTS = [
        "BatchMatMulV2_run.py",
        "GroupedMatmul_run.py",
        "GroupedMatmulSwigluQuant_run.py",
        "LightningIndexer_run.py",
        "MoeTokenPermute_run.py",
        "MoeTokenUnpermute_run.py",
        "ScatterNdUpdate_run.py",
        "SparseFlashAttention_run.py",
        "TransposeBatchMatMul_run.py",
    ]

    def test_new_replay_script_mains_are_coverage_visible(self, monkeypatch):
        invoked = []
        for script_name in self.NEW_REPLAY_SCRIPTS:
            replay_module = import_op_replay_script(script_name)

            def record_main(name=script_name):
                # Default argument binds the current name; a plain closure would
                # see only the last script by the time it is called.
                invoked.append(name)

            monkeypatch.setattr(replay_module.op, "main", record_main)
            replay_module.main()
        assert invoked == self.NEW_REPLAY_SCRIPTS
class TestOpReplayArgparse:
    """Verify scripts accept --help without crashing (no NPU required)."""

    SCRIPTS_WITH_HELP = (
        "run_all_op.py",
        "MatMulV2_run.py",
        "BatchMatMulV2_run.py",
        "GroupedMatmul_run.py",
        "GroupedMatmulSwigluQuant_run.py",
        "LightningIndexer_run.py",
        "MoeTokenPermute_run.py",
        "MoeTokenUnpermute_run.py",
        "ScatterNdUpdate_run.py",
        "SparseFlashAttention_run.py",
        "TransposeBatchMatMul_run.py",
        "DispatchFFNCombine_run.py",
    )

    @pytest.mark.parametrize("script", SCRIPTS_WITH_HELP)
    def test_help_flag(self, script):
        # Run each script in a separate interpreter so argparse's SystemExit on
        # --help cannot affect the test process.
        command = [sys.executable, str(OP_REPLAY_DIR / script), "--help"]
        completed = subprocess.run(command, capture_output=True, text=True, timeout=10)
        assert completed.returncode == 0, f"--help failed for {script}: {completed.stderr}"
        assert "--device" in completed.stdout
class FakeTensor:
    """Metadata-only tensor stand-in: carries shape/dtype/device, holds no data.

    ``shape`` is stored exactly as passed, and ``ndim`` is computed once at
    construction. Derived tensors are new objects on the same device.
    """

    def __init__(self, shape=(1,), dtype="float32", device="npu"):
        self.shape, self.dtype, self.device = shape, dtype, device
        self.ndim = len(shape)

    def _clone(self, shape, dtype):
        """Build a new fake on this tensor's device with the given shape and dtype."""
        return FakeTensor(shape, dtype=dtype, device=self.device)

    def npu(self):
        """Already 'on' the device: return self unchanged."""
        return self

    def to(self, dtype):
        """Return a same-shaped copy carrying ``dtype``."""
        return self._clone(self.shape, dtype)

    def unsqueeze(self, _dim):
        """Prepend a size-1 axis; the requested dimension index is ignored."""
        return self._clone((1, *self.shape), self.dtype)
class FakeTorch:
    """Stand-in for the ``torch`` module in the DispatchFFNCombine replay tests.

    The tests pass the class itself (not an instance) where the runtime torch
    module is expected. It provides two dtype tokens, ``npu.synchronize``,
    ``ops._C_ascend.dispatch_ffn_combine``, ``arange`` and ``full``.
    """

    int32 = "int32"
    float32 = "float32"

    class Npu:
        @staticmethod
        def synchronize():
            """No-op device barrier; returns None."""

    npu = Npu()

    class Ops:
        class Ascend:
            @staticmethod
            def dispatch_ffn_combine(**kwargs):
                """Echo the caller-provided output buffers instead of computing anything."""
                out = kwargs["out"]
                expert_token_nums = kwargs["expert_token_nums"]
                return out, expert_token_nums

        _C_ascend = Ascend()

    ops = Ops()

    @staticmethod
    def arange(*_args, **_kwargs):
        """Ignore the requested range and return a 4-element int32 fake."""
        return FakeTensor((4,), dtype="int32")

    @staticmethod
    def full(shape, _fill_value, dtype=None):
        """Return a fake with ``shape`` (as a tuple); the fill value is discarded."""
        requested_shape = tuple(shape)
        return FakeTensor(requested_shape, dtype=dtype)
class TestDispatchFFNCombineReplayHelpers:
    """Exercise DispatchFFNCombine_run helpers on a freshly executed module copy.

    NPU, HCCL and process-launch paths are replaced with fakes so these run
    on any host.
    """

    @staticmethod
    def _fake_dfc_case(**extra):
        """Return the nine fake tensor slots of a one-expert DFC case, then ``extra`` entries.

        Shared by the debug-print and execute/run_row tests so their tensor
        fixtures cannot drift apart.
        """
        case = {
            "x": FakeTensor((2, 4)),
            "weight1_list": [FakeTensor((1, 4, 8))],
            "weight2_list": [FakeTensor((1, 8, 4))],
            "expert_idx": FakeTensor((2, 1), dtype="int32"),
            "scale1_list": [FakeTensor((1, 8))],
            "scale2_list": [FakeTensor((1, 4))],
            "probs": FakeTensor((2, 1)),
            "out": FakeTensor((2, 4)),
            "expert_token_nums": FakeTensor((1,), dtype="int32"),
        }
        case.update(extra)
        return case

    def test_argparser_and_simple_helpers(self, monkeypatch, capsys):
        module = import_op_replay_script("DispatchFFNCombine_run.py")
        parser = module.build_argparser()
        args = parser.parse_args(["--ep-size", "8", "--no-balanced", "--max-output-size", "123"])
        assert args.ep_size == 8
        assert args.balanced is False
        assert args.max_output_size == 123
        # Without an override the module default is used; with one, the override wins.
        monkeypatch.setattr(module, "MAX_OUTPUT_SIZE", None)
        assert module.infer_max_output_size((2, 4), 2) == module.DEFAULT_DFC_MAX_OUTPUT_SIZE
        monkeypatch.setattr(module, "MAX_OUTPUT_SIZE", 256)
        assert module.infer_max_output_size((2, 4), 2) == 256
        # Rows recorded with a different EP size are skipped (with a printed notice);
        # matching or blank EP sizes are kept.
        monkeypatch.setattr(module, "EP_SIZE", 16)
        assert module.should_skip_row_for_ep_size(Path("DispatchFFNCombine.csv"), 1, {"EP Size": "8"})
        assert not module.should_skip_row_for_ep_size(Path("DispatchFFNCombine.csv"), 1, {"EP Size": "16"})
        assert not module.should_skip_row_for_ep_size(Path("DispatchFFNCombine.csv"), 1, {"EP Size": ""})
        assert "does not match replay" in capsys.readouterr().out

    def test_shape_builders_validate_without_npu(self, monkeypatch):
        module = import_op_replay_script("DispatchFFNCombine_run.py")
        monkeypatch.setattr(module, "get_runtime_modules", lambda: (FakeTorch, object()))
        monkeypatch.setattr(module, "resolve_runtime_dtype", lambda name: name)
        with pytest.raises(ValueError, match="num_experts must be positive"):
            module.build_balanced_expert_idx_tensor((2, 2), 0)
        with pytest.raises(ValueError, match="scale shape mismatch"):
            module.build_scale_tensor((2, 3), (2, 4), "FLOAT")

    def test_debug_and_extension_paths_are_non_fatal(self, monkeypatch, capsys):
        module = import_op_replay_script("DispatchFFNCombine_run.py")
        monkeypatch.setenv("DFC_DEBUG_DEVICES", "1")
        monkeypatch.setattr(module, "_PRINTED_DFC_DEVICE_DEBUG", False)
        module.debug_dfc_tensor_devices(self._fake_dfc_case())
        assert "[DFC debug]" in capsys.readouterr().out
        # module.importlib is the shared importlib module, so this patch is
        # process-wide until monkeypatch undoes it after the test.
        monkeypatch.setattr(module.importlib.util, "find_spec", lambda _name: None)
        module.ensure_vllm_ascend_extension_loaded()
        assert "DispatchFFNCombine replay may fail" in capsys.readouterr().err

    def test_launch_torchrun_builds_command(self, monkeypatch):
        module = import_op_replay_script("DispatchFFNCombine_run.py")
        calls = []

        def fake_run(command, **kwargs):
            calls.append((command, kwargs))
            # Honour subprocess.run's contract (as the stub in TestDispatchFfnReplay
            # does) so the launcher may inspect the result without crashing.
            return subprocess.CompletedProcess(command, 0)

        monkeypatch.setattr(module, "find_free_port", lambda: 23456)
        monkeypatch.setattr(module.subprocess, "run", fake_run)
        module.launch_torchrun_and_wait(
            2,
            ["--database-path", "db"],
            nproc_per_node=2,
            nnodes=1,
            node_rank=0,
            master_addr="127.0.0.1",
            master_port=None,
        )
        command, kwargs = calls[0]
        assert "torch.distributed.run" in command
        assert "--master_port=23456" in command
        assert kwargs["env"]["_DFC_AUTO_TORCHRUN"] == "1"

    def test_row_and_operator_paths_can_be_stubbed(self, monkeypatch, tmp_path):
        module = import_op_replay_script("DispatchFFNCombine_run.py")
        monkeypatch.setattr(module, "get_runtime_modules", lambda: (FakeTorch, object()))
        monkeypatch.setattr(module, "ensure_vllm_ascend_extension_loaded", lambda: None)
        case = self._fake_dfc_case(
            group="hccl",
            max_output_size=64,
            expected_output_shapes=[(2, 4), (1,)],
            weight_kind="BF16",
            num_experts=1,
            global_num_experts=1,
            topk=1,
        )
        # FakeTorch.ops._C_ascend.dispatch_ffn_combine echoes the provided buffers back.
        out, expert_token_nums, used_fallback = module.execute_dfc_op(case)
        assert (out, expert_token_nums, used_fallback) == (case["out"], case["expert_token_nums"], False)
        monkeypatch.setattr(module, "build_row_case", lambda row, balanced: case)
        module.run_row(tmp_path / "DispatchFFNCombine.csv", 1, {}, balanced=True)

    def test_build_row_case_rejects_bad_metadata(self, monkeypatch):
        module = import_op_replay_script("DispatchFFNCombine_run.py")
        monkeypatch.setattr(module, "init_runtime", lambda: None)
        # Only two input slots are supplied; the builder requires seven.
        row = {
            "Input Shapes": "2,4;1,4,8",
            "Input Data Types": "BF16;BF16",
            "Input Formats": "ND;ND",
            "Output Shapes": "2,4;1",
            "Output Data Types": "BF16;INT32",
            "Output Formats": "ND;ND",
        }
        with pytest.raises(ValueError, match="seven input metadata slots"):
            module.build_row_case(row)

    def test_main_reports_missing_csv_before_npu_setup(self, monkeypatch, tmp_path):
        module = import_op_replay_script("DispatchFFNCombine_run.py")
        args = SimpleNamespace(
            repeat_count=1,
            ep_size=1,
            balanced=True,
            max_output_size=None,
            device="ATLAS_800_A3_752T_128G_DIE",
            vllm_version="0.18.0",
            database_path=tmp_path,
            torch_version=None,
            cann_version=None,
            update_mode="all",
            nproc_per_node=None,
            nnodes=1,
            node_rank=0,
            master_addr="127.0.0.1",
            master_port=None,
        )
        monkeypatch.setattr(module, "build_argparser", lambda: SimpleNamespace(parse_args=lambda: args))
        monkeypatch.setattr(module, "get_replay_repeat_count", lambda value: value)
        # tmp_path is empty, so no DispatchFFNCombine.csv can be found there.
        monkeypatch.setattr(module, "get_target_data_dir", lambda **_kwargs: tmp_path)
        with pytest.raises(FileNotFoundError, match="No DispatchFFNCombine.csv"):
            module.main()
class TestRunAllOpHelpers:
    """Exercise run_all_op helpers on a freshly executed module copy, with launches stubbed."""

    @staticmethod
    def _run_kwargs(**overrides):
        """Return the keyword arguments shared by every run_script* call below, plus ``overrides``."""
        kwargs = {
            "database_path": Path("db"),
            "device": "ATLAS_800_A3_752T_128G_DIE",
            "vllm_ascend_version": None,
            "torch_version": None,
            "cann_version": None,
            "dispatch_ffn_combine_ep_size": None,
            "dispatch_ffn_combine_nproc_per_node": None,
            "dispatch_ffn_combine_nnodes": 1,
            "dispatch_ffn_combine_node_rank": 0,
            "dispatch_ffn_combine_master_addr": "127.0.0.1",
            "dispatch_ffn_combine_master_port": None,
        }
        kwargs.update(overrides)
        return kwargs

    def test_argparser_and_dispatch_args(self):
        module = import_op_replay_script("run_all_op.py")
        args = module.build_argparser().parse_args(
            [
                "--execution-mode",
                "subprocess",
                "--dispatch-ffn-combine-ep-size",
                "32",
            ]
        )
        assert args.execution_mode == "subprocess"
        assert args.dispatch_ffn_combine_ep_size == 32
        command = ["python", "DispatchFFNCombine_run.py"]
        module.append_dispatch_ffn_combine_args(
            command,
            Path("DispatchFFNCombine_run.py"),
            dispatch_ffn_combine_ep_size=32,
            dispatch_ffn_combine_nproc_per_node=16,
            dispatch_ffn_combine_nnodes=2,
            dispatch_ffn_combine_node_rank=1,
            dispatch_ffn_combine_master_addr="host0",
            dispatch_ffn_combine_master_port=29501,
        )
        # The DFC-specific flags are appended in-place, in this order, at the end.
        assert command[-12:] == [
            "--ep-size",
            "32",
            "--nproc-per-node",
            "16",
            "--nnodes",
            "2",
            "--node-rank",
            "1",
            "--master-addr",
            "host0",
            "--master-port",
            "29501",
        ]

    def test_run_script_modes_build_expected_invocations(self, monkeypatch, tmp_path):
        module = import_op_replay_script("run_all_op.py")
        script_path = tmp_path / "Add_run.py"
        script_path.write_text("print('ok')\n", encoding="utf-8")
        monkeypatch.setattr(module, "SCRIPT_DIR", tmp_path)
        monkeypatch.setattr(module, "build_database_cli_args", lambda **_kwargs: ["--database-path", "db"])

        calls = []

        def fake_run(command, **kwargs):
            calls.append((command, kwargs))
            # Match subprocess.run's return type so result inspection cannot crash.
            return subprocess.CompletedProcess(command, 0)

        monkeypatch.setattr(module.subprocess, "run", fake_run)
        module.run_script_subprocess(script_path, **self._run_kwargs(repeat_count=2, update_mode="all"))
        assert calls[0][0][1] == str(script_path)
        assert "--repeat-count" in calls[0][0]

        runpy_calls = []

        def fake_run_path(path, **kwargs):
            runpy_calls.append((path, kwargs))
            # runpy.run_path returns the executed module's globals dict.
            return {}

        monkeypatch.setattr(module.runpy, "run_path", fake_run_path)
        module.run_script_inprocess(
            script_path, **self._run_kwargs(repeat_count=None, update_mode="missing-only")
        )
        assert runpy_calls == [(str(script_path), {"run_name": "__main__"})]

    def test_run_script_dispatches_and_main_summarizes(self, monkeypatch, tmp_path):
        module = import_op_replay_script("run_all_op.py")
        script_path = tmp_path / "Add_run.py"
        script_path.write_text("print('ok')\n", encoding="utf-8")
        mode_calls = []
        monkeypatch.setattr(module, "run_script_subprocess", lambda *args, **kwargs: mode_calls.append("subprocess"))
        monkeypatch.setattr(module, "run_script_inprocess", lambda *args, **kwargs: mode_calls.append("inprocess"))
        module.run_script(
            script_path,
            **self._run_kwargs(repeat_count=None, update_mode="all", execution_mode="subprocess"),
        )
        assert mode_calls == ["subprocess"]

        args = SimpleNamespace(
            execution_mode="inprocess",
            op=None,
            device="ATLAS_800_A3_752T_128G_DIE",
            vllm_version=None,
            database_path=tmp_path,
            torch_version=None,
            cann_version=None,
            repeat_count=None,
            update_mode="all",
            dispatch_ffn_combine_ep_size=None,
            dispatch_ffn_combine_nproc_per_node=None,
            dispatch_ffn_combine_nnodes=1,
            dispatch_ffn_combine_node_rank=0,
            dispatch_ffn_combine_master_addr="127.0.0.1",
            dispatch_ffn_combine_master_port=None,
            continue_on_error=False,
        )
        monkeypatch.setattr(module, "build_argparser", lambda: SimpleNamespace(parse_args=lambda: args))
        monkeypatch.setattr(module, "reset_invalid_replay_rows", lambda: None)
        monkeypatch.setattr(module, "discover_run_scripts", lambda: [script_path])
        monkeypatch.setattr(module, "get_target_data_dir", lambda **_kwargs: tmp_path)
        monkeypatch.setattr(module, "has_operator_csv", lambda *_args: True)
        monkeypatch.setattr(module, "run_script", lambda **_kwargs: mode_calls.append("main"))
        monkeypatch.setattr(module, "get_invalid_replay_rows", lambda: [])
        monkeypatch.setattr(module, "print_invalid_replay_summary", lambda *_args, **_kwargs: None)
        # The status file is expected next to the scripts, i.e. under the patched SCRIPT_DIR.
        monkeypatch.setattr(module, "SCRIPT_DIR", tmp_path)
        module.main()
        assert mode_calls[-1] == "main"
        assert (tmp_path / "run_all_op_status.json").is_file()
class TestCommonModule:
    """Checks on common.py that must hold without an NPU runtime."""

    def test_module_imports_without_npu(self):
        """common.py imports without NPU; torch is lazy-loaded (stays None until init_runtime)."""
        for lazy_attr in ("torch", "torch_npu"):
            assert getattr(common, lazy_attr) is None

    def test_data_dir_points_to_profiling_database(self):
        """DATA_DIR resolves to the profiling_database/data/ tree."""
        tail = common.DATA_DIR.parts[-2:]
        assert tail == ("profiling_database", "data")

    def test_build_host_tensor_uses_empty_for_float_dtypes(self, monkeypatch):
        class RecordingTorch:
            """Fake torch exposing unique dtype sentinels and logging empty()/randint() calls.

            Named to avoid shadowing the module-level ``FakeTorch`` helper.
            """

            bool = object()
            int32 = object()
            int64 = object()
            float16 = object()
            bfloat16 = object()
            float32 = object()
            float64 = object()

            def __init__(self):
                self.empty_calls = []
                self.randint_calls = []

            def empty(self, shape, dtype):
                self.empty_calls.append((shape, dtype))
                return ("empty", shape, dtype)

            def randint(self, *args, **kwargs):
                self.randint_calls.append((args, kwargs))
                return ("randint", args, kwargs)

        recorder = RecordingTorch()
        monkeypatch.setattr(op_common, "get_runtime_modules", lambda: (recorder, None))
        bf16 = recorder.bfloat16
        tensor = op_common.build_host_tensor((2, 3), bf16)
        # A float dtype must take the empty() path exactly once and never randint().
        assert tensor == ("empty", (2, 3), bf16)
        assert recorder.empty_calls == [((2, 3), bf16)]
        assert not recorder.randint_calls
class TestSplitQkvReplay:
    """split_qkv_rmsnorm_rope_kernel_run.build_case with its tensor builders stubbed out."""

    def test_build_case_accepts_legacy_two_output_rows(self, monkeypatch):
        # Stubs keep the exact parameter names of the real builders in case they
        # are called with keyword arguments.
        def fake_input_tensor(shape, tensor_format, dtype_name):
            return {"shape": shape, "format": tensor_format, "dtype": dtype_name}

        def fake_positions_tensor(shape, max_position_embeddings):
            return {"shape": shape, "max_position_embeddings": max_position_embeddings}

        def fake_weight_tensor(length, dtype_name):
            return (length, dtype_name)

        monkeypatch.setattr(split_qkv.op, "resolve_api", lambda: "fake_api")
        for attr_name, stub in (
            ("build_input_tensor", fake_input_tensor),
            ("build_positions_tensor", fake_positions_tensor),
            ("build_weight_tensor", fake_weight_tensor),
        ):
            monkeypatch.setattr(split_qkv, attr_name, stub)

        # A row whose Output Shapes column lists only two outputs.
        legacy_row = {
            "Input Shapes": "128,1152;64",
            "Input Formats": "ND;ND",
            "Input Data Types": "DT_BF16;DT_FLOAT",
            "Output Shapes": "128,1024;128,64",
        }
        kwargs = split_qkv.build_case(legacy_row)["kwargs"]
        assert kwargs["q_hidden_size"] == 1024
        assert kwargs["kv_hidden_size"] == 64
        assert kwargs["cos_sin_cache"]["shape"] == (2048, 64)
        assert kwargs["positions"]["shape"] == (128,)
class TestDispatchFfnReplay:
    """Tests against the module-level DispatchFFNCombine_run import (``dispatch_ffn``)."""

    def test_multinode_requires_explicit_master_port(self):
        # With nnodes > 1 and no master_port, launching must fail with a ValueError
        # that mentions --master-port.
        with pytest.raises(ValueError, match="--master-port"):
            dispatch_ffn.launch_torchrun_and_wait(
                32,
                [],
                nproc_per_node=16,
                nnodes=2,
                node_rank=0,
                master_addr="127.0.0.1",
                master_port=None,
            )

    def test_single_node_auto_port_still_launches(self, monkeypatch):
        # On a single node a missing master_port is filled from find_free_port().
        # The stubbed subprocess.run records its cmd/env/check arguments and
        # reports success.
        calls = []
        monkeypatch.setattr(dispatch_ffn, "find_free_port", lambda: 12345)
        monkeypatch.setattr(
            dispatch_ffn.subprocess,
            "run",
            lambda cmd, env, check: (calls.append((cmd, env, check)) or SimpleNamespace(returncode=0)),
        )
        dispatch_ffn.launch_torchrun_and_wait(
            16,
            ["--repeat-count", "1"],
            nproc_per_node=16,
            nnodes=1,
            node_rank=0,
            master_addr="127.0.0.1",
            master_port=None,
        )
        cmd, env, check = calls[0]
        assert "--master_port=12345" in cmd
        assert env["_DFC_AUTO_TORCHRUN"] == "1"
        assert check is True

    def test_extension_load_success_is_cached(self, monkeypatch):
        # Fake vllm_ascend / vllm_ascend.utils modules are injected into sys.modules
        # so enable_custom_op is found without a real install. Two calls must invoke
        # it only once, and the cached state must end up True.
        calls = []
        utils_mod = types.ModuleType("vllm_ascend.utils")
        utils_mod.enable_custom_op = lambda: calls.append("enable")
        package_mod = types.ModuleType("vllm_ascend")
        # Reset the one-element cache slot; None presumably means "not attempted yet".
        monkeypatch.setattr(dispatch_ffn, "_EXTENSION_LOAD_STATE", [None])
        monkeypatch.setitem(sys.modules, "vllm_ascend", package_mod)
        monkeypatch.setitem(sys.modules, "vllm_ascend.utils", utils_mod)
        dispatch_ffn.ensure_vllm_ascend_extension_loaded()
        dispatch_ffn.ensure_vllm_ascend_extension_loaded()
        assert calls == ["enable"]
        assert dispatch_ffn._EXTENSION_LOAD_STATE[0] is True

    def test_extension_load_failure_is_cached(self, monkeypatch):
        # enable_custom_op raises, the subsequent import of vllm_ascend.vllm_ascend_C
        # fails, and find_spec finds nothing. Across two calls the failure must be
        # warned about once, import attempted once, and the cached state be False.
        warnings = []
        imports = []
        utils_mod = types.ModuleType("vllm_ascend.utils")
        def fail_enable_custom_op():
            raise RuntimeError("missing extension")
        def fail_import_module(name):
            imports.append(name)
            raise ImportError(name)
        utils_mod.enable_custom_op = fail_enable_custom_op
        package_mod = types.ModuleType("vllm_ascend")
        # Gives the fake package a real file path; presumably the loader inspects
        # the package location when the extension import fails.
        package_mod.__file__ = __file__
        monkeypatch.setattr(dispatch_ffn, "_EXTENSION_LOAD_STATE", [None])
        monkeypatch.setattr(
            dispatch_ffn,
            "warn_vllm_ascend_extension_load_failure",
            lambda context, exc: warnings.append((context, type(exc).__name__)),
        )
        # NOTE(review): dispatch_ffn.importlib is presumably the shared importlib
        # module, making these two patches process-wide until monkeypatch undoes them.
        monkeypatch.setattr(dispatch_ffn.importlib, "import_module", fail_import_module)
        monkeypatch.setattr(dispatch_ffn.importlib.util, "find_spec", lambda name: None)
        monkeypatch.setitem(sys.modules, "vllm_ascend", package_mod)
        monkeypatch.setitem(sys.modules, "vllm_ascend.utils", utils_mod)
        dispatch_ffn.ensure_vllm_ascend_extension_loaded()
        dispatch_ffn.ensure_vllm_ascend_extension_loaded()
        assert warnings == [("enable_custom_op", "RuntimeError")]
        assert imports == ["vllm_ascend.vllm_ascend_C"]
        assert dispatch_ffn._EXTENSION_LOAD_STATE[0] is False

    def test_extension_load_warning_mentions_context(self, capsys):
        # The warning goes to stderr and names both the context and the exception type.
        dispatch_ffn.warn_vllm_ascend_extension_load_failure("unit-test", RuntimeError("missing"))
        captured = capsys.readouterr()
        assert "unit-test" in captured.err
        assert "RuntimeError" in captured.err
class TestRunAllOp:
    """Tests against the module-level run_all_op import: CLI parsing and script discovery."""

    def test_argparser_parses_replay_options(self, tmp_path):
        parser = run_all_op.build_argparser()
        args = parser.parse_args(
            [
                "--database-path",
                str(tmp_path),
                "--device",
                "TEST_DEVICE",
                "--update-mode",
                "missing-only",
                "--execution-mode",
                "subprocess",
                "--op",
                "MatMulV2",
                "PadV3",
                "--continue-on-error",
            ]
        )
        assert args.database_path == tmp_path
        assert args.device == "TEST_DEVICE"
        assert args.update_mode == "missing-only"
        assert args.execution_mode == "subprocess"
        assert args.op == ["MatMulV2", "PadV3"]
        assert args.continue_on_error is True

    def test_discover_run_scripts(self):
        scripts = run_all_op.discover_run_scripts()
        assert len(scripts) > 0
        # The orchestrator must not list itself as a replay script.
        assert run_all_op.SELF_NAME not in [s.name for s in scripts]

    def test_filter_run_scripts_exact_match(self):
        scripts = [
            Path("MatMulV2_run.py"),
            Path("PadV3_run.py"),
            Path("RmsNorm_run.py"),
        ]
        filtered = run_all_op.filter_run_scripts(scripts, {"MatMulV2"})
        names = [s.name for s in filtered]
        assert names == ["MatMulV2_run.py"]

    def test_filter_run_scripts_none_returns_all(self):
        scripts = [Path("MatMulV2_run.py"), Path("PadV3_run.py")]
        filtered = run_all_op.filter_run_scripts(scripts, None)
        assert len(filtered) == 2

    def test_get_csv_name(self):
        assert run_all_op.get_csv_name(Path("MatMulV2_run.py")) == "MatMulV2.csv"
        assert run_all_op.get_csv_name(Path("PadV3_run.py")) == "PadV3.csv"

    def test_has_operator_csv(self, tmp_path):
        # The CSV sits one directory below the data root; the lookup is expected to find it there.
        nested_dir = tmp_path / "sub"
        nested_dir.mkdir(parents=True)
        (nested_dir / "MatMulV2.csv").write_text("x", encoding="utf-8")
        assert run_all_op.has_operator_csv(tmp_path, "MatMulV2.csv")
        assert not run_all_op.has_operator_csv(tmp_path, "Nonexistent.csv")
class TestDispatchFfnConstants:
    """Sanity checks on DispatchFFNCombine_run defaults and its standard argparser.

    All lookups go through the module-level ``dispatch_ffn`` handle instead of
    re-importing by bare name. A bare ``from DispatchFFNCombine_run import ...``
    silently depends on sys.path/sys.modules still being set up at call time.
    """

    def test_default_ep_size(self):
        assert dispatch_ffn.DEFAULT_EP_SIZE == 16
        assert dispatch_ffn.DEFAULT_DFC_REPEAT_COUNT > 0

    def test_default_max_output_size(self):
        assert dispatch_ffn.DEFAULT_DFC_MAX_OUTPUT_SIZE > 0

    def test_build_argparser(self):
        parser = dispatch_ffn.build_standard_argparser(
            description="test",
            usage_examples=["python test.py"],
            version_help="test",
        )
        args = parser.parse_args(["--database-path", "test_dir"])
        # --database-path is converted to a Path by the parser.
        assert args.database_path == Path("test_dir")