from __future__ import annotations
from typing import TYPE_CHECKING
import pytest
from tools.perf_data_collection.op_replay import replay_framework
if TYPE_CHECKING:
from pathlib import Path
def test_build_inputs_honors_dtype_override(monkeypatch: pytest.MonkeyPatch):
    """A per-index dtype override must replace the dtype parsed from the row.

    The row declares input 1 as ``INT64``. The op is configured with
    ``input_dtype_overrides={1: "DT_BOOL"}``, so ``build_inputs`` is expected
    to request ``DT_BOOL`` for that slot. Input 0 keeps its parsed
    ``DT_FLOAT16``.
    """
    seen: list[tuple[tuple[int, ...], str, str]] = []

    def record_tensor_request(*, shape, input_format, dtype_name):
        # Remember each request and echo it back as a plain dict, so the test
        # can inspect exactly what build_inputs asked for.
        seen.append((shape, input_format, dtype_name))
        return dict(shape=shape, input_format=input_format, dtype_name=dtype_name)

    # Stub out runtime initialisation and module lookup, then swap in the
    # recording fake for tensor construction.
    monkeypatch.setattr(replay_framework, "init_runtime", lambda: None)
    monkeypatch.setattr(
        replay_framework,
        "get_runtime_modules",
        lambda: ("torch", "torch_npu"),
    )
    monkeypatch.setattr(replay_framework, "build_input_tensor", record_tensor_request)

    replay = replay_framework.OpReplay(
        kernel_type="MaskedFill",
        description="test",
        usage_examples=["python test.py"],
        version_help="test",
        input_count=2,
        input_dtype_overrides={1: "DT_BOOL"},
    )
    row = {
        "Input Shapes": "2,3;2,3",
        "Input Formats": "ND;ND",
        "Input Data Types": "FLOAT16;INT64",
    }
    built = replay.build_inputs(row)

    dtypes = [item["dtype_name"] for item in built]
    assert dtypes == ["DT_FLOAT16", "DT_BOOL"]
    assert seen == [
        ((2, 3), "ND", "DT_FLOAT16"),
        ((2, 3), "ND", "DT_BOOL"),
    ]
def test_resolve_api_supports_nested_torch_paths(monkeypatch: pytest.MonkeyPatch):
    """``resolve_api`` must follow a dotted ``torch.*`` path through nested attributes.

    The fake torch object exposes ``nn.functional.softmax``. Resolving
    ``"torch.nn.functional.softmax"`` should therefore yield a callable that
    returns ``"softmax"``.
    """

    class _Functional:
        @staticmethod
        def softmax():
            return "softmax"

    class _NN:
        functional = _Functional()

    class _Torch:
        nn = _NN()

    # NOTE(review): get_runtime_modules appears to return a (torch, torch_npu)
    # pair. Only the first element is meaningful here; the second is an
    # opaque placeholder.
    monkeypatch.setattr(
        replay_framework,
        "get_runtime_modules",
        lambda: (_Torch(), object()),
    )

    replay = replay_framework.OpReplay(
        kernel_type="SoftmaxV2",
        api_path="torch.nn.functional.softmax",
        description="test",
        usage_examples=["python test.py"],
        version_help="test",
    )

    fn = replay.resolve_api()
    assert fn() == "softmax"
def test_main_replays_each_row_repeat_count_times(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path: Path,
    capsys: pytest.CaptureFixture[str],
):
    """``main`` must call ``run_case`` once per repeat for every CSV row.

    The CSV contains a single data row and the repeat count is stubbed to 3.
    ``run_case`` should therefore fire three times, and stdout should report
    ``Processed 3 Add rows``.
    """
    (tmp_path / "Add.csv").write_text(
        "Input Shapes,Input Formats,Input Data Types\n1,ND,FLOAT16\n",
        encoding="utf-8",
    )

    # Stub the availability check, data-directory lookup, repeat count and
    # synchronisation. This leaves main() exercising only the replay loop.
    monkeypatch.setattr(replay_framework, "ensure_npu_available", lambda: None)
    monkeypatch.setattr(replay_framework, "get_target_data_dir", lambda **_: tmp_path)
    monkeypatch.setattr(replay_framework, "get_replay_repeat_count", lambda _: 3)
    monkeypatch.setattr(replay_framework.OpReplay, "synchronize", lambda self: None)

    invocations: list[int] = []

    def stub_build_case(_row):
        return {"inputs": [], "kwargs": {}, "api": None}

    def stub_run_case(_case):
        # One marker per execution; only the count is asserted on.
        invocations.append(1)
        return "ok"

    replay = replay_framework.OpReplay(
        kernel_type="Add",
        description="test",
        usage_examples=["python test.py"],
        version_help="test",
        build_case=stub_build_case,
        run_case=stub_run_case,
    )
    monkeypatch.setattr("sys.argv", ["Add_run.py", "--database-path", str(tmp_path)])

    replay.main()

    stdout = capsys.readouterr().out
    assert len(invocations) == 3
    assert "Processed 3 Add rows" in stdout