import argparse
import os
import re
import shutil
import subprocess
import sys
import site
import time
import logging
import platform
from pathlib import Path
from typing import Dict, List, Optional, Tuple
def _format_cmd(command: List[str]) -> str:
return " ".join(map(str, command))
def _print_if_pattern(completed, patterns):
for line in (completed.stdout or "").splitlines():
if any(p.search(line) for p in patterns):
logging.info(line.rstrip())
for line in (completed.stderr or "").splitlines():
if any(p.search(line) for p in patterns):
logging.info(line.rstrip())
def _run_command_verbose(command: List[str], cwd_str: Optional[str]) -> None:
proc = subprocess.Popen(
[str(x) for x in command],
cwd=cwd_str,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
encoding="utf-8",
errors="replace",
bufsize=1,
)
if proc.stdout is None:
proc.kill()
raise RuntimeError("failed to capture process stdout in verbose mode")
merged_output_lines: List[str] = []
for line in proc.stdout:
merged_output_lines.append(line)
logging.info(line.rstrip())
return_code = proc.wait()
if return_code != 0:
raise subprocess.CalledProcessError(return_code, command, output="".join(merged_output_lines), stderr=None)
def _run_command_quiet(command: List[str], cwd_str: Optional[str], always_print_patterns: Optional[List[str]]) -> None:
completed = subprocess.run(
[str(x) for x in command],
cwd=cwd_str,
capture_output=True,
text=True,
encoding="utf-8",
errors="replace",
)
if completed.returncode != 0:
if completed.stdout:
logging.info(completed.stdout.rstrip())
if completed.stderr:
logging.info(completed.stderr.rstrip())
raise subprocess.CalledProcessError(
completed.returncode, command, output=completed.stdout, stderr=completed.stderr
)
if always_print_patterns:
patterns = [re.compile(p) for p in always_print_patterns]
_print_if_pattern(completed, patterns)
def run_command(
    command: List[str],
    cwd: Optional[Path] = None,
    *,
    title: Optional[str] = None,
    verbose: bool = False,
    always_print_patterns: Optional[List[str]] = None,
) -> float:
    """Run an external command and return its wall-clock duration in seconds.

    Args:
        command: argv list.
        cwd: working directory for the child (None = current directory).
        title: optional heading logged before the command runs.
        verbose: stream the full output live instead of capturing it.
        always_print_patterns: in quiet mode, regexes whose matching output lines
            are logged even when the command succeeds.

    Raises:
        RuntimeError: when the executable (``command[0]``) cannot be found.
        subprocess.CalledProcessError: when the command exits non-zero.
    """
    cwd_str = None if cwd is None else str(cwd)
    started = time.perf_counter()
    if title:
        logging.info(f"{title}")
    if verbose:
        banner = f" $ {_format_cmd(command)}"
        if cwd_str:
            banner += f"\n cwd: {cwd_str}"
        logging.info(banner)
    try:
        if not verbose:
            _run_command_quiet(command, cwd_str, always_print_patterns)
        else:
            _run_command_verbose(command, cwd_str)
    except FileNotFoundError as exc:
        raise RuntimeError(f"command not found: {command[0]}") from exc
    return time.perf_counter() - started
def add_user_scripts_to_path() -> None:
    """Prepend the per-user pip scripts directory to ``PATH`` when it exists and is absent.

    Windows: ``%APPDATA%\\Python\\PythonXY\\Scripts`` (skipped if APPDATA is unset).
    Elsewhere: ``<site.getuserbase()>/bin``.
    """
    if os.name != "nt":
        candidate = Path(site.getuserbase()) / "bin"
    else:
        appdata = os.environ.get("APPDATA")
        if not appdata:
            return
        version_tag = "".join(str(n) for n in sys.version_info[:2])
        candidate = Path(appdata, "Python", f"Python{version_tag}", "Scripts")
    if not candidate.exists():
        return
    entry = str(candidate)
    path_now = os.environ.get("PATH", "")
    if entry in path_now.split(os.pathsep):
        return
    os.environ["PATH"] = os.pathsep.join((entry, path_now))
def ensure_cmake_tools() -> None:
    """Ensure ``cmake`` and ``ctest`` are on PATH, pip-installing cmake (``--user``) if not.

    Raises:
        RuntimeError: if either tool is still missing after the install.
    """
    required = ("cmake", "ctest")
    if all(shutil.which(tool) for tool in required):
        return
    logging.info("cmake/ctest not found, installing via pip...")
    run_command([sys.executable, "-m", "pip", "install", "--user", "cmake>=3.16"])
    # A --user install drops its scripts in the per-user dir, which may not be on PATH yet.
    add_user_scripts_to_path()
    if not all(shutil.which(tool) for tool in required):
        raise RuntimeError(
            "cmake/ctest still not found after installation; please add user Scripts/bin directory to PATH"
        )
def is_windows() -> bool:
    """Return True when running on Windows (checked via ``os.name`` and ``platform.system()``)."""
    return os.name == "nt" or platform.system().lower() == "windows"
def cmake_friendly_path(p: Optional[str]) -> Optional[str]:
    """Return *p* with backslashes turned into forward slashes on Windows.

    Empty or None input yields None; on other platforms *p* is returned unchanged.
    """
    if not p:
        return None
    return p.replace("\\", "/") if is_windows() else p
def get_compiler_major_version(compiler_path: str, timeout: float = 30.0) -> int:
    """Get the major version number of the compiler.

    Runs ``<compiler_path> --version`` and returns the leading integer of the first
    ``N.M`` number found in stdout (falling back to stderr).

    Args:
        compiler_path: path or command name of the compiler driver.
        timeout: seconds to wait for ``--version`` before giving up, so a hung
            wrapper cannot stall compiler detection forever.

    Returns:
        The major version, or 0 if the path is empty, the command cannot be run,
        exits non-zero, times out, or prints no recognisable version.
    """
    if not compiler_path:
        return 0
    logging.debug("Checking version for compiler: %s", compiler_path)
    try:
        result = subprocess.run(
            [compiler_path, "--version"],
            capture_output=True,
            text=True,
            errors="replace",
            check=False,
            timeout=timeout,
        )
    except (OSError, subprocess.SubprocessError, ValueError) as e:
        # OSError: missing/non-executable; SubprocessError: timeout;
        # ValueError: invalid argument (e.g. embedded NUL).
        logging.warning("Exception occurred while checking compiler version: %s", e)
        return 0
    if result.returncode != 0:
        logging.warning("Failed to run --version on: %s", compiler_path)
        return 0
    # Require "N.M" so stray "N." fragments are not mistaken for a version.
    version_re = re.compile(r"(\d+)\.\d+")
    match = version_re.search(result.stdout or "") or version_re.search(result.stderr or "")
    if not match:
        logging.warning("Could not parse a version number from %s --version", compiler_path)
        return 0
    version = int(match.group(1))
    logging.debug("Parsed version for %s: %d", compiler_path, version)
    return version
def _try_find_compiler(cxx_name: str, cc_name: str, min_ver: int) -> Tuple[Optional[str], Optional[str]]:
"""
Try to find a specific C++ compiler and check if the version meets the requirements.
"""
cxx_path = shutil.which(cxx_name)
if not cxx_path:
return None, None
ver = get_compiler_major_version(cxx_path)
logging.debug("Found candidate %s, version: %d (required: %d)", cxx_path, ver, min_ver)
if ver >= min_ver:
cc_path = shutil.which(cc_name)
logging.info("Selected compiler pair: %s / %s (Version >= %d)", cxx_path, cc_path, min_ver)
return cxx_path, cc_path
return None, None
def _auto_detect_compilers() -> Tuple[str, Optional[str]]:
    """Pick the first acceptable compiler pair: clang++ >= 15, else g++ >= 13.

    Raises:
        RuntimeError: if neither candidate is available with a sufficient version.
    """
    logging.info("CXX not specified, starting automatic detection...")
    candidates = (("clang++", "clang", 15), ("g++", "gcc", 13))
    for cxx_name, cc_name, min_major in candidates:
        cxx, cc = _try_find_compiler(cxx_name, cc_name, min_major)
        if cxx:
            return cxx, cc
    error_msg = (
        "Could not find a suitable compiler.\n"
        "Requirements:\n"
        " - clang++ >= 15\n"
        " - OR g++ >= 13"
    )
    logging.error(error_msg)
    raise RuntimeError(error_msg)
def _derive_cc_from_cxx(cxx_path: str) -> Optional[str]:
"""
Guess the corresponding CC based on the path name of CXX.
"""
if not cxx_path:
return None
logging.debug("Attempting to derive CC from CXX: %s", cxx_path)
name = Path(cxx_path).name
if "clang" in name:
logging.info("Derived CC as clang")
return shutil.which("clang")
if "g++" in name:
logging.info("Derived CC as gcc")
return shutil.which("gcc")
return None
def detect_compilers(cxx_arg: Optional[str], cc_arg: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
    """Resolve the C++/C compilers to hand to CMake.

    Precedence: explicit argument, then ``$CXX``/``$CC``, then auto-detection
    (C++) or derivation from the chosen C++ compiler (C). Relative names are
    resolved through PATH when possible, and the results are normalised with
    ``cmake_friendly_path``.
    """
    cxx = cxx_arg or os.environ.get("CXX")
    cc = cc_arg or os.environ.get("CC")
    if cxx:
        logging.info("Using explicit CXX: %s", cxx)
        if not Path(cxx).is_absolute():
            found_cxx = shutil.which(cxx)
            if found_cxx:
                logging.debug("Resolved relative path '%s' to '%s'", cxx, found_cxx)
                cxx = found_cxx
    else:
        cxx, auto_cc = _auto_detect_compilers()
        cc = cc or auto_cc
    if not cc:
        cc = _derive_cc_from_cxx(cxx)
    elif not Path(cc).is_absolute():
        found_cc = shutil.which(cc)
        if found_cc and found_cc != cc:
            logging.debug("Resolved relative path '%s' to '%s'", cc, found_cc)
            cc = found_cc
    cxx = cmake_friendly_path(cxx) if cxx else cxx
    cc = cmake_friendly_path(cc) if cc else cc
    logging.info("Final Compiler Selection -> CXX: %s, CC: %s", cxx, cc)
    return cxx, cc
def cmake_build(build_dir: Path, build_type: str) -> None:
    """Run ``cmake --build`` on *build_dir* in parallel for configuration *build_type*."""
    run_command(["cmake", "--build", str(build_dir), "--parallel", "--config", build_type])
def generate_golden(build_dir: Path, gen_script: Path) -> None:
    """Copy *gen_script* into *build_dir* as ``gen_data.py`` and run it from there."""
    copied = build_dir / "gen_data.py"
    shutil.copyfile(gen_script, copied)
    run_command([sys.executable, copied.name], cwd=build_dir)
def read_cmake_cache_var(build_dir: Path, var_name: str) -> Optional[str]:
    """Return the value of ``var_name`` from ``<build_dir>/CMakeCache.txt``.

    Cache entries look like ``NAME:TYPE=value``; comment lines (``//`` or ``#``)
    are skipped. Returns None if the cache is missing, unreadable, or has no
    such entry.
    """
    cache_file = build_dir / "CMakeCache.txt"
    if not cache_file.exists():
        return None
    prefix = var_name + ":"
    try:
        text = cache_file.read_text(encoding="utf-8", errors="replace")
    except OSError:
        return None
    for entry in text.splitlines():
        if entry and not entry.startswith(("//", "#")) and entry.startswith(prefix):
            return entry.partition("=")[2]
    return None
def find_binaries(build_dir: Path, build_type: str) -> Dict[str, Path]:
    """Map test-binary names to their paths under ``<build_dir>/bin``.

    On Windows, ``bin/<build_type>`` is searched instead when it exists, only
    ``.exe`` files are kept, and they are keyed by stem. Elsewhere every regular
    file is keyed by its file name. Returns ``{}`` if the directory is missing.
    """
    on_windows = os.name == "nt"
    search_dir = build_dir / "bin"
    if on_windows and (search_dir / build_type).exists():
        search_dir = search_dir / build_type
    if not search_dir.exists():
        return {}
    found: Dict[str, Path] = {}
    for entry in search_dir.iterdir():
        if not entry.is_file():
            continue
        if not on_windows:
            found[entry.name] = entry
        elif entry.suffix.lower() == ".exe":
            found[entry.stem] = entry
    return found
def run_gtest_binary(binary: Path, gtest_filter: Optional[str], build_type: str,
                     verbose: bool) -> None:
    """Execute a gtest binary, optionally restricted by ``--gtest_filter``.

    The binary runs from its own directory, except on Windows when it sits in a
    ``<build_type>`` subdirectory, in which case the parent of that is used.
    """
    argv: List[str] = [str(binary)]
    if gtest_filter:
        argv += [f"--gtest_filter={gtest_filter}"]
    parent = binary.parent
    in_config_dir = os.name == "nt" and parent.name.lower() == build_type.lower()
    run_command(argv, cwd=parent.parent if in_config_dir else parent, verbose=verbose)
def run_binary(binary: Path, build_type: str, cwd: Optional[Path] = None) -> None:
    """Execute *binary* with no arguments.

    The working directory is *cwd* if given, else the binary's directory. On
    Windows, a binary inside a ``<build_type>`` subdirectory always runs from
    that directory's parent, even when *cwd* is given.
    """
    in_config_dir = os.name == "nt" and binary.parent.name.lower() == build_type.lower()
    if in_config_dir:
        run_cwd = binary.parent.parent
    else:
        run_cwd = cwd or binary.parent
    run_command([str(binary)], cwd=run_cwd)
def build_and_run_demo(demo_name: str,
                       repo_root: Path,
                       build_type: str,
                       cxx: Optional[str],
                       cc: Optional[str], *,
                       verbose: bool) -> None:
    """Configure, build (from a clean build dir) and run one costmodel demo.

    Args:
        demo_name: one of ``"gemm"``, ``"flash_attn"``, ``"mla"``.
        repo_root: directory holding this script; demos are looked up under
            ``<repo_root>/../demos/costmodel`` (``gemm`` falls back to the legacy
            ``<repo_root>/demo``).
        build_type: CMake build type / configuration name.
        cxx, cc: optional compiler paths forwarded to CMake.
        verbose: stream full cmake/demo output instead of quiet mode.

    Raises:
        RuntimeError: unknown demo name, missing demo sources, or no binary
            after the build.
    """
    demos_root = repo_root / ".." / "demos" / "costmodel"
    demo_map: dict[str, tuple[Path, str]] = {
        "gemm": (demos_root / "gemm_demo", "gemm_demo"),
        "flash_attn": (demos_root / "flash_attention_demo", "flash_attention_demo"),
        "mla": (demos_root / "mla_attention_demo", "mla_attention_demo"),
    }
    if demo_name not in demo_map:
        raise RuntimeError(f"unknown demo: {demo_name}")
    demo_src, exe_stem = demo_map[demo_name]
    legacy_demo_src = repo_root / "demo"
    if demo_name == "gemm" and not demo_src.exists() and legacy_demo_src.exists():
        demo_src = legacy_demo_src
        exe_stem = "gemm_demo"
    if not demo_src.exists():
        raise RuntimeError(f"demo dir not found: {demo_src}")

    # Always rebuild the demo from scratch.
    demo_build = demo_src / "build"
    if demo_build.exists():
        shutil.rmtree(demo_build)
    demo_build.mkdir(parents=True, exist_ok=True)
    run_command(
        [
            "cmake",
            "-S",
            str(demo_src),
            "-B",
            str(demo_build),
            f"-DCMAKE_BUILD_TYPE={build_type}",
            *([f"-DCMAKE_C_COMPILER={cc}"] if cc else []),
            *([f"-DCMAKE_CXX_COMPILER={cxx}"] if cxx else []),
        ],
        title="[STEP] demo: cmake configure",
        verbose=verbose,
    )
    run_command(
        ["cmake", "--build", str(demo_build), "--parallel", "--config", build_type],
        title="[STEP] demo: cmake build",
        verbose=verbose,
    )

    exe_name = f"{exe_stem}.exe" if os.name == "nt" else exe_stem
    exe = demo_build / exe_name
    if os.name == "nt":
        # Multi-config generators put the exe under <build>/<Config>/, single-config
        # ones (e.g. MinGW Makefiles, Ninja) directly in <build>/. Prefer the former
        # (as before) and fall back to the latter, mirroring find_binaries().
        config_exe = demo_build / build_type / exe_name
        if config_exe.exists() or not exe.exists():
            exe = config_exe
    if not exe.exists():
        raise RuntimeError(f"demo binary not found: {exe}")
    in_config_dir = os.name == "nt" and exe.parent.name.lower() == build_type.lower()
    run_command([str(exe)],
                cwd=exe.parent.parent if in_config_dir else exe.parent,
                title=f"[STEP] demo: run {exe_stem}",
                verbose=verbose,
                always_print_patterns=[r"^perf:"])
def _format_seconds(seconds: float) -> str:
if seconds < 1:
return f"{seconds*1000:.0f}ms"
return f"{seconds:.2f}s"
def _render_table(headers: List[str], rows: List[List[str]]) -> str:
widths = [len(h) for h in headers]
for row in rows:
for i, cell in enumerate(row):
widths[i] = max(widths[i], len(cell))
def fmt_row(cols: List[str]) -> str:
return "| " + " | ".join(c.ljust(widths[i]) for i, c in enumerate(cols)) + " |"
sep = "+-" + "-+-".join("-" * w for w in widths) + "-+"
out = [sep, fmt_row(headers), sep]
out.extend(fmt_row(r) for r in rows)
out.append(sep)
return "\n".join(out)
def _parse_duration_seconds(s: str) -> float:
if s.endswith("ms"):
return float(s[:-2]) / 1000.0
if s.endswith("s"):
return float(s[:-1])
return 0.0
def parse_arguments():
    """Define this script's command-line options and parse ``sys.argv``.

    Returns:
        argparse.Namespace with one attribute per option (dashes become
        underscores, e.g. ``--build-type`` -> ``build_type``).
    """
    parser = argparse.ArgumentParser(
        description="Build & run costmodel simulator ST unit tests (tests/costmodel/st, st_fit, and st_a5_fit)",
        epilog=("Examples:\n  python run_costmodel.py --build-type Release\n"
                "  python run_costmodel.py --testcase tadd --build-type Release\n"
                "  python run_costmodel.py --testcase tadds_fit --suite st_fit --build-type Release\n"
                "  python run_costmodel.py --testcase tadd_fit --suite st_a5_fit --build-type Release\n"
                "  python run_costmodel.py --no-build --gtest_filter TADDTest.*\n"
                "  python run_costmodel.py --demo gemm\n"
                "  python run_costmodel.py --demo flash_attn\n"
                "  python run_costmodel.py --demo mla\n"
                "  python run_costmodel.py --demo all\n"
                ),
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument("--verbose", action="store_true",
                        help="Show full output from cmake/msbuild/gtest (default: quiet, only structured logs).")
    parser.add_argument("-t", "--testcase", help="Run a single testcase (e.g. tadd). Default: run all built bin.")
    parser.add_argument(
        "--suite",
        choices=["st", "st_fit", "st_a5_fit"],
        default="st",
        help="Test suite to run.",
    )
    parser.add_argument("-g", "--gtest_filter", help="Optional gtest filter (e.g. 'TADDTest.case1').")
    parser.add_argument("--cxx", help="C++ compiler (e.g. clang++). Default: $CXX or auto-detect.")
    parser.add_argument("--cc", help="C compiler (e.g. clang). Default: $CC or auto-detect.")
    parser.add_argument("--build-type", default="Release", choices=["Release", "Debug", "RelWithDebInfo", "MinSizeRel"],
                        help="CMake build type.")
    parser.add_argument("--build-dir", default=None, help="Build directory. Default: tests/costmodel/<suite>/build")
    parser.add_argument("--no-clean", action="store_true", help="(Deprecated) No-op; kept for backward compatibility.")
    parser.add_argument("--clean", action="store_true", help="Delete build dir and rebuild.")
    parser.add_argument("--rebuild", action="store_true", help="Force re-configure and rebuild.")
    parser.add_argument("--no-build", action="store_true", help="Skip cmake configure/build, only run existing bin")
    parser.add_argument("--no-gen", action="store_true", help="Skip running testcase gen_data.py.")
    # NOTE(review): nothing in this script reads args.xml_dir (no --gtest_output is
    # passed to the test binaries); confirm whether this option should be wired up.
    parser.add_argument("--xml-dir", default=None, help="If set, write gtest xml reports under this directory")
    parser.add_argument("--no-install", action="store_true",
                        help="Do not auto-install missing tools (cmake/ctest via pip) "
                             "or add the user scripts dir to PATH.")
    parser.add_argument("--demo", choices=["gemm", "flash_attn", "mla", "all"], default=None,
                        help="Build & run demo program (e.g. 'gemm', 'flash_attn'). "
                             "Note: demo runs alone (does not run costmodel ST).")
    parser.add_argument("--demo-only", action="store_true",
                        help="Requires --demo; the demo runs without costmodel ST (same effect as --demo alone).")
    parser.add_argument("--generator", default=None,
                        help="CMake generator (required on Windows, e.g. 'MinGW Makefiles' or 'Ninja').")
    parser.add_argument("--cmake_prefix_path", default=None, help="-DCMAKE_PREFIX_PATH=<path> e.g. D:\\gtest")
    return parser.parse_args()
def setup_environment(args) -> None:
    """Unless ``--no-install`` was given, extend PATH and make sure cmake/ctest exist."""
    if args.no_install:
        return
    add_user_scripts_to_path()
    ensure_cmake_tools()
def log_build_info(args, cxx, cc) -> None:
    """Log the build type and whichever of the C++/C compilers are set."""
    logging.info(f"[INFO] build_type={args.build_type}")
    for label, value in (("cxx", cxx), ("cc", cc)):
        if value:
            logging.info(f"[INFO] {label}={value}")
def run_demo_mode(args, repo_root, cxx, cc) -> int:
    """Build and run the demo(s) selected by ``--demo`` (``all`` = gemm, flash_attn, mla).

    Returns:
        0 on success; 2 when ``--demo-only`` is given without ``--demo``.
        Build or run failures propagate as exceptions from ``build_and_run_demo``.
    """
    if args.demo_only and not args.demo:
        logging.error("error: --demo-only requires --demo")
        return 2
    # After the check above args.demo is set whenever main() dispatches here;
    # the "gemm" fallback only guards direct callers.
    demo_name = args.demo or "gemm"
    logging.info("\n== DEMO ==")
    demos = ["gemm", "flash_attn", "mla"] if demo_name == "all" else [demo_name]
    t0 = time.perf_counter()
    for name in demos:
        build_and_run_demo(
            demo_name=name, repo_root=repo_root, build_type=args.build_type, cxx=cxx, cc=cc, verbose=args.verbose
        )
    demo_time = time.perf_counter() - t0
    logging.info(f"[PASS] demo: {demo_name} ({_format_seconds(demo_time)})")
    return 0
def maybe_generate_formula_params(source_dir: Path, repo_root: Path, verbose: bool) -> None:
    """For the ST suites, run each architecture's ``gen_formula_params_header.py``.

    The generators live at
    ``<repo_root>/../include/pto/costmodel/{a2a3,a5}/formula_costmodel/`` and run
    with ``<repo_root>/..`` as the working directory. Other suite names are a no-op.

    Raises:
        RuntimeError: if a generator script is missing.
    """
    st_suites = {"st", "st_fit", "st_a5_fit", "perf_sim_st"}
    if source_dir.name not in st_suites:
        return
    project_root = repo_root.parent
    costmodel_include = project_root / "include" / "pto" / "costmodel"
    for arch_dir in ("a2a3", "a5"):
        gen_script = costmodel_include / arch_dir / "formula_costmodel" / "gen_formula_params_header.py"
        if not gen_script.exists():
            raise RuntimeError(f"formula params generator not found: {gen_script}")
        run_command(
            [sys.executable, str(gen_script)],
            cwd=project_root,
            title=f"[STEP] {source_dir.name}: generate {arch_dir} formula_params_generated.hpp",
            verbose=verbose,
        )
def run_test_mode(args, repo_root, cxx, cc) -> int:
    """Generate headers, build the selected ST suite if needed, then run its tests.

    Relative ``--build-dir`` values are resolved against *repo_root*.

    Returns:
        2 if the suite directory is missing or the build could not start,
        otherwise the result of ``execute_tests``.
    """
    source_dir = repo_root / "costmodel" / args.suite
    if not source_dir.exists():
        logging.error(f"error: not found costmodel dir: {source_dir}")
        return 2
    maybe_generate_formula_params(source_dir, repo_root, args.verbose)
    if args.build_dir:
        build_dir = Path(args.build_dir)
    else:
        build_dir = source_dir / "build"
    if not build_dir.is_absolute():
        build_dir = (repo_root / build_dir).resolve()
    if args.clean and build_dir.exists():
        shutil.rmtree(build_dir)
    if not determine_need_build(args, source_dir, build_dir):
        logging.info("\n== BUILD ==")
        logging.info("[SKIP] build (already built)")
    elif not perform_build(args, source_dir, build_dir, cxx, cc):
        return 2
    return execute_tests(args, source_dir, build_dir)
def parse_expected_testcases(source_dir: Path) -> Optional[set[str]]:
    """Return the testcase names listed in ``set(ALL_TESTCASES ...)``.

    Reads ``<source_dir>/testcase/CMakeLists.txt``. Items may be separated by
    whitespace or ``;`` and may be double-quoted; ``#`` comments are ignored.

    Returns:
        The set of names, or None if the file or the ``set(ALL_TESTCASES ...)``
        call is absent.
    """
    cmake_list = source_dir / "testcase" / "CMakeLists.txt"
    if not cmake_list.exists():
        return None
    text = cmake_list.read_text(encoding="utf-8", errors="replace")
    # Strip comments first so a ')' inside a comment can't end the match early.
    code = "\n".join(line.split("#", 1)[0] for line in text.splitlines())
    # CMake command names are case-insensitive. \b excludes "unset(" and names
    # such as ALL_TESTCASES_EXTRA.
    m = re.search(r"\b(?i:set)\s*\(\s*ALL_TESTCASES\b(.*?)\)", code, flags=re.DOTALL)
    if not m:
        return None
    tokens = (tok.strip('"') for tok in re.split(r"[\s;]+", m.group(1)))
    return {tok for tok in tokens if tok}
def determine_need_build(args, source_dir: Path, build_dir: Path) -> bool:
    """Decide whether cmake configure + build must run before executing tests.

    Never builds with ``--no-build``. Otherwise a build is needed when:

    - the cached TEST_CASE does not match the request;
    - ``--rebuild`` or ``--clean`` is given;
    - the build dir has no CMakeCache.txt;
    - the binaries needed for this run are missing.
    """
    have_build_dir = build_dir.exists()
    binaries_before = find_binaries(build_dir, args.build_type) if have_build_dir else {}
    configured_testcase = read_cmake_cache_var(build_dir, "TEST_CASE") if have_build_dir else None

    if args.testcase:
        # Single-testcase mode: the cache must be configured for exactly this testcase.
        config_mismatch = configured_testcase != args.testcase
        # Only the requested binary is needed. Requiring every ALL_TESTCASES entry here
        # would force a rebuild on every run whenever a -DTEST_CASE build produces
        # just that one binary.
        have_requested_binary = args.testcase in binaries_before
    else:
        # All-testcases mode: a cached TEST_CASE means the tree was last configured
        # for a single testcase (perform_build only passes -DTEST_CASE with --testcase).
        config_mismatch = bool(configured_testcase)
        have_requested_binary = bool(binaries_before)
        expected = parse_expected_testcases(source_dir)
        if expected and expected.difference(binaries_before):
            have_requested_binary = False

    if args.no_build:
        return False
    return (
        config_mismatch
        or args.rebuild
        or args.clean
        or not (build_dir / "CMakeCache.txt").exists()
        or not have_requested_binary
    )
def perform_build(args, source_dir, build_dir, cxx, cc) -> bool:
    """Run cmake configure and build for the suite in *source_dir*.

    Without ``--testcase``, any cached TEST_CASE is removed (``-UTEST_CASE``).
    On Windows a ``--generator`` is mandatory.

    Returns:
        True after a successful build; False if the Windows generator is missing.
        Failing cmake commands raise from ``run_command``.
    """
    build_dir.mkdir(parents=True, exist_ok=True)
    logging.info("\n== BUILD ==")
    if is_windows() and not args.generator:
        logging.error("On Windows, must specify --generator (\"MinGW Makefiles\" or \"Ninja\", etc..)")
        return False
    configure_cmd = ["cmake"]
    if not args.testcase:
        configure_cmd.append("-UTEST_CASE")
    configure_cmd += ["-S", str(source_dir), "-B", str(build_dir), f"-DCMAKE_BUILD_TYPE={args.build_type}"]
    if args.testcase:
        configure_cmd.append(f"-DTEST_CASE={args.testcase}")
    if cc:
        configure_cmd.append(f"-DCMAKE_C_COMPILER={cc}")
    if cxx:
        configure_cmd.append(f"-DCMAKE_CXX_COMPILER={cxx}")
    if args.generator:
        configure_cmd += ["-G", args.generator]
    if args.cmake_prefix_path:
        configure_cmd.append(f"-DCMAKE_PREFIX_PATH={args.cmake_prefix_path}")
    configure_seconds = run_command(configure_cmd, title="[STEP] cmake configure", verbose=args.verbose)
    build_seconds = run_command(
        ["cmake", "--build", str(build_dir), "--parallel", "--config", args.build_type],
        title="[STEP] cmake build",
        verbose=args.verbose,
    )
    logging.info(f"[PASS] build ({_format_seconds(configure_seconds + build_seconds)})")
    return True
def execute_tests(args, source_dir, build_dir) -> int:
    """Run the requested testcase (or all built ones, sorted by name) and print a summary.

    Returns:
        2 if no binaries exist or the requested testcase was not built; 0 after
        the selected tests ran (a failing test raises instead of returning).
    """
    binaries = find_binaries(build_dir, args.build_type)
    if not binaries:
        logging.error(f"error: no binaries found under {build_dir / 'bin'} (did build succeed?)")
        return 2
    selected: list[tuple[str, Path]]
    if not args.testcase:
        selected = sorted(binaries.items())
    elif args.testcase in binaries:
        selected = [(args.testcase, binaries[args.testcase])]
    else:
        known = ", ".join(sorted(binaries))
        logging.error(f"error: unknown testcase '{args.testcase}'. Built binaries: {known}")
        return 2
    results = run_selected_tests(args, source_dir, build_dir, selected)
    if results:
        print_test_summary(results)
    return 0
def run_selected_tests(args, source_dir, build_dir, selected) -> List[List[str]]:
    """Run each ``(testcase, binary)`` pair in *selected* in order.

    Each attempt is logged and recorded as ``[testcase, status, elapsed]``. The
    first failure is re-raised, so later tests are not run.

    Returns:
        The result rows when every test passed.
    """
    logging.info("\n== TESTS ==")
    results: List[List[str]] = []
    for testcase, binary in selected:
        # Assume failure until the binary returns cleanly. Otherwise a BaseException
        # (e.g. KeyboardInterrupt, not caught by `except Exception`) would either hit
        # an unbound `status` in `finally` or be logged with the previous test's PASS.
        status = "FAIL"
        t0 = time.perf_counter()
        try:
            run_gtest_binary(
                binary=binary,
                gtest_filter=args.gtest_filter,
                build_type=args.build_type,
                verbose=args.verbose
            )
            status = "PASS"
        finally:
            elapsed = time.perf_counter() - t0
            results.append([testcase, status, _format_seconds(elapsed)])
            logging.info(f"[{status}] {testcase} ({_format_seconds(elapsed)})")
    return results
def print_test_summary(results) -> None:
    """Log a table of ``[target, status, time]`` rows plus a TOTAL row.

    The total is the sum of the already-formatted per-test times. *results* is
    not modified: the TOTAL row goes into a new list, so repeated calls never
    accumulate extra TOTAL rows in the caller's data.
    """
    logging.info("\n== SUMMARY ==")
    total_time_s = sum(_parse_duration_seconds(r[2]) for r in results)
    rows = [*results, ["TOTAL", "", _format_seconds(total_time_s)]]
    logging.info(_render_table(["Target", "Status", "Time"], rows))
def main() -> int:
    """Script entry point: parse args, prepare tools and compilers, then dispatch.

    Demo mode is chosen when ``--demo`` or ``--demo-only`` is given; otherwise the
    ST test mode runs. Returns the selected mode's exit code.
    """
    args = parse_arguments()
    setup_environment(args)
    repo_root = Path(__file__).resolve().parent
    cxx, cc = detect_compilers(args.cxx, args.cc)
    log_build_info(args, cxx, cc)
    run_mode = run_demo_mode if (args.demo or args.demo_only) else run_test_mode
    return run_mode(args, repo_root, cxx, cc)
if __name__ == "__main__":
    # Timestamped, source-located log lines at INFO level for all script output.
    _log_format = '%(asctime)s - %(filename)s:%(lineno)d - %(levelname)s: %(message)s'
    logging.basicConfig(format=_log_format, level=logging.INFO)
    raise SystemExit(main())