import dataclasses
import os
import contextlib
import hashlib
import json
import logging
import subprocess
import sys
import sysconfig
from time import time, time_ns
from typing import (
Any,
Callable,
cast,
Dict,
Generator,
List,
NoReturn,
Optional,
Sequence,
Tuple,
TYPE_CHECKING,
TypeVar,
Union,
)
import torch
from torch._inductor import config
from torch._inductor.exc import CppCompileError
from torch._inductor.codecache import (
CacheBase,
get_lock_dir,
write,
LOCK_TIMEOUT,
DLLWrapper,
)
from torch._inductor.graph import GraphLowering
from torch._inductor.utils import (
clear_on_fresh_inductor_cache,
is_linux,
is_windows,
)
import torch_npu
from torch_npu.utils._error_code import ErrCode, pta_error
from .cpp_builder import library_paths
from . import config as npu_config
from .codegen.catlass.catlass_utils import get_npu_arch, _normalize_npu_arch
# Shares the upstream Inductor logger name, so these messages follow whatever
# logging configuration is applied to "torch._inductor".
log: logging.Logger = logging.getLogger("torch._inductor")
# Serialized form of an empty JSON object.
empty_json: str = "{}"
@contextlib.contextmanager
def lock_context(key):
from filelock import FileLock
lock_dir = get_lock_dir()
lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
with lock:
yield
def _catlass_include_paths() -> List[str]:
from .cpp_builder import get_ascend_home
ASCEND_HOME = get_ascend_home()
catlass_path = npu_config.catlass.catlass_dir
return [
os.path.realpath(os.path.join(ASCEND_HOME, "compiler/tikcpp")),
os.path.realpath(os.path.join(ASCEND_HOME, "compiler/tikcpp/tikcfw")),
os.path.realpath(os.path.join(ASCEND_HOME, "compiler/tikcpp/tikcfw/impl")),
os.path.realpath(os.path.join(ASCEND_HOME, "compiler/tikcpp/tikcfw/interface")),
os.path.realpath(os.path.join(ASCEND_HOME, "include")),
os.path.realpath(os.path.join(ASCEND_HOME, "include/experiment/runtime")),
os.path.realpath(os.path.join(ASCEND_HOME, "include/experiment/msprof")),
os.path.realpath(os.path.join(ASCEND_HOME, "pkg_inc")),
os.path.realpath(os.path.join(catlass_path, "include")),
]
def _ascend_lib_options() -> List[str]:
    """Return the linker flags for the Ascend runtime libraries.

    Each library directory (those from ``library_paths(npu=True)`` plus
    Python's ``LIBDIR`` when it is defined) gets both a ``-L`` search path and
    an ``-rpath`` entry. These are followed by the fixed list of ``-l``
    libraries.

    Raises:
        NotImplementedError: on non-Linux hosts.
    """
    # Fail fast before querying any platform-specific library locations.
    if not is_linux():
        raise NotImplementedError(
            "Unsupported env, failed to find ascend libs! Currently only Linux is supported."
        )

    lpaths = list(library_paths(npu=True))
    # sysconfig.get_config_var returns None when the variable is undefined.
    # Skip it rather than emitting a bogus "-LNone" / "-rpath=None".
    python_libdir = sysconfig.get_config_var("LIBDIR")
    if python_libdir:
        lpaths.append(python_libdir)

    extra_ldflags: List[str] = []
    for path in lpaths:
        extra_ldflags.extend([f"-L{path}", "-Xlinker", f"-rpath={path}"])

    # Link order is preserved from the original hand-written list.
    libraries = (
        "runtime",
        "stdc++",
        "ascendcl",
        "m",
        "tiling_api",
        "platform",
        "c_sec",
        "dl",
        "nnopbase",
    )
    extra_ldflags.extend(f"-l{lib}" for lib in libraries)
    return extra_ldflags
def _bisheng_host_compiler_options() -> List[str]:
return [
"-fPIC",
"-fno-strict-aliasing",
"-fvisibility=hidden",
"-Wconversion",
]
def _bisheng_compiler_options(is_mix: bool = False) -> List[str]:
    """Return the bisheng device-side compile flags for the current NPU.

    Args:
        is_mix: when falsy, the ``-cube`` variant of the AI-core arch is targeted.

    Returns:
        The flags, in this order:
        1. ``--cce-aicore-arch``.
        2. The fixed optimisation, standard and language flags.
        3. The CATLASS arch define.
        4. Debug flags, only when ``npu_config.catlass.enable_debug_info`` is set.

    Raises:
        ValueError: if the normalised NPU arch is neither "910B" nor "910D".
    """
    # (normalised NPU arch, bisheng aicore arch, CATLASS arch define)
    known_archs = (
        ("910B", "dav-c220", "-DCATLASS_ARCH_A2_ENABLED"),
        ("910D", "dav-c310", "-DCATLASS_ARCH_A5_ENABLED"),
    )
    npu_arch = _normalize_npu_arch(get_npu_arch())
    arch_row = next((row for row in known_archs if npu_arch == row[0]), None)
    if arch_row is None:
        raise ValueError(f"Unrecognized NPU arch: {npu_arch}")

    _, core_arch, arch_define = arch_row
    if not is_mix:
        core_arch = core_arch + "-cube"

    flags = [
        f"--cce-aicore-arch={core_arch}",
        "-O2",
        "-std=c++17",
        "-xcce",
        "-DL2_CACHE_HINT",
        arch_define,
    ]
    if npu_config.catlass.enable_debug_info:
        flags += ["--lineinfo", "-g"]
    return flags
def _bisheng_compiler() -> Optional[str]:
if os.path.exists(os.getenv("ASCEND_HOME_PATH")):
return os.path.realpath(
os.path.join(
os.getenv("ASCEND_HOME_PATH", ""), "tools/ccec_compiler/bin/bisheng"
)
)
return "bisheng"
def catlass_compile_command(
    src_files: List[str],
    dst_file: str,
    dst_file_ext: str,
    extra_args: Optional[List[str]] = None,
    is_mix: bool = False,
) -> str:
    """Assemble the bisheng command line that builds CATLASS sources.

    Flags appear in this order:
    1. Device options.
    2. ``extra_args``.
    3. Host-compiler options, each wrapped in ``-Xcompiler``.
    4. ``-I`` include directories.
    5. Linker flags.

    Args:
        src_files: source files appended at the end of the command.
        dst_file: output path passed to ``-o``.
        dst_file_ext: one of:
            - "o": object file; adds ``-c``.
            - "so": shared library; adds ``-shared``.
            - "exe": executable.
        extra_args: optional extra bisheng flags, placed right after the
            device options.
        is_mix: forwarded to ``_bisheng_compiler_options``.

    Returns:
        The command as a single space-joined string. Nothing is quoted.

    Raises:
        NotImplementedError: for any other ``dst_file_ext``.
    """
    if extra_args is None:
        extra_args = []

    include_dirs = _catlass_include_paths()
    link_flags = _ascend_lib_options()
    host_opts = _bisheng_host_compiler_options()
    device_opts = _bisheng_compiler_options(is_mix)

    flags = device_opts + extra_args
    for host_opt in host_opts:
        # Options containing "=" go as a separate token after -Xcompiler;
        # the rest are attached to it with "=".
        if "=" in host_opt:
            flags.append(f"-Xcompiler {host_opt}")
        else:
            flags.append(f"-Xcompiler={host_opt}")
    flags.extend("-I" + inc for inc in include_dirs)
    flags.extend(link_flags)

    sources = " ".join(src_files)
    if dst_file_ext == "o":
        output_spec = f"-c -o {dst_file} {sources}"
    elif dst_file_ext in ("so", "exe"):
        if dst_file_ext == "so":
            flags.append("-shared")
        output_spec = f"-o {dst_file} {sources}"
    else:
        raise NotImplementedError(f"Unsupported output file suffix {dst_file_ext}!")

    command = f"{_bisheng_compiler()} {' '.join(flags)} {output_spec}"
    log.debug("Bisheng command: %s", command)
    return command
class NPUCompileError(CppCompileError):
    """Error raised when the bisheng compiler fails to build CATLASS source code.

    It derives from ``CppCompileError``, so handlers written for C++ compile
    failures also catch it.
    """
@clear_on_fresh_inductor_cache
class CATLASSCodeCache:
    """Write, compile and memoize CATLASS kernels built with bisheng.

    How it works:
    - Each source file is written through Inductor's ``write`` helper with the
      full compile command as ``extra``, so different flags yield a different
      hash key and file.
    - Compiled artifacts are placed next to the source file (same stem,
      ``dst_file_ext`` suffix).
    - Each artifact is recorded in the class-level ``cache`` dict, which
      ``cache_clear`` empties.
    """

    @dataclasses.dataclass
    class CacheEntry:
        # Path of the written source file.
        input_path: str
        # Path of the compiled artifact produced from ``input_path``.
        output_path: str

    cache: Dict[str, CacheEntry] = {}
    cache_clear = staticmethod(cache.clear)
    _SOURCE_CODE_SUFFIX = "cpp"

    @classmethod
    def write(
        cls,
        source_code: str,
        dst_file_ext: str,
        is_mix: bool,
        extra_args: Optional[List[str]] = None,
    ) -> Tuple[str, str]:
        """Write ``source_code`` to a ``.cpp`` file keyed by its compile command.

        Args:
            source_code: CATLASS C++ source text.
            dst_file_ext: target artifact extension.
            is_mix: forwarded to ``catlass_compile_command``.
            extra_args: extra bisheng flags. They are part of the key, so the
                same source compiled with different flags does not collide.

        Returns:
            Tuple of (hash_key, source_file_path).
        """
        # With extra_args=None this yields exactly the same command string as
        # before, so keys for existing callers are unchanged.
        catlass_command = repr(
            catlass_compile_command(
                ["dummy_input"], "dummy_output", dst_file_ext, extra_args, is_mix
            )
        )
        key, input_path = write(
            source_code, cls._SOURCE_CODE_SUFFIX, extra=catlass_command
        )
        return key, input_path

    @classmethod
    def compile(
        cls,
        source_code: str,
        dst_file_ext: str,
        extra_args: Optional[List[str]] = None,
        is_mix: bool = False,
    ) -> Tuple[str, str, str]:
        """Compile CATLASS ``source_code`` into a file with ``dst_file_ext`` extension.

        Compilation is skipped when the key is already in ``cache``, or when
        the output file already exists on disk. Work for one key is done
        while holding that key's file lock.

        Args:
            source_code: CATLASS C++ source text.
            dst_file_ext: "o", "so" or "exe".
            extra_args: additional bisheng flags. They are included in the
                cache key.
            is_mix: forwarded to ``catlass_compile_command``.

        Returns:
            Tuple of (dst_file_path, hash_key, source_code_path).

        Raises:
            NPUCompileError: if bisheng exits with a non-zero status.
        """
        key, input_path = cls.write(source_code, dst_file_ext, is_mix, extra_args)
        if key not in cls.cache:
            with lock_context(key):
                # input_path ends in _SOURCE_CODE_SUFFIX; swap it for dst_file_ext.
                output_path = (
                    input_path[: -len(cls._SOURCE_CODE_SUFFIX)] + dst_file_ext
                )
                if not os.path.exists(output_path):
                    cmd = catlass_compile_command(
                        [input_path], output_path, dst_file_ext, extra_args, is_mix
                    )
                    log.debug("CATLASS Compilation: %s", cmd)
                    # split() rather than split(" "), so no empty argv entries
                    # are produced (e.g. from an empty string in extra_args).
                    cmd_parts = cmd.split()
                    start_time = time()
                    try:
                        subprocess.check_output(
                            cmd_parts, stderr=subprocess.STDOUT, env=os.environ
                        )
                    except subprocess.CalledProcessError as error:
                        raise NPUCompileError(cmd_parts, error.output) from error
                    log.info(
                        "CATLASS Compilation took %s seconds. Compile command: %s",
                        time() - start_time,
                        cmd,
                    )
                else:
                    log.debug(
                        "CATLASS Compilation skipped: %s since output already exists",
                        input_path,
                    )
                cls.cache[key] = CATLASSCodeCache.CacheEntry(input_path, output_path)
        return (cls.cache[key].output_path, key, input_path)

    @classmethod
    def load(
        cls, source_code: str, dst_file_ext: str, is_mix: bool = False
    ) -> Tuple[DLLWrapper, str, str]:
        """Compile ``source_code`` into a ``.so`` and load it.

        Returns:
            Tuple of (DLLWrapper, hash_key, source_code_path).

        Raises:
            RuntimeError: if ``dst_file_ext`` is not "so".
        """
        if dst_file_ext != "so":
            raise RuntimeError(
                f"Only support loading a .so file for now. "
                f"Requested file extension: {dst_file_ext}. Source code: {source_code}"
            )
        dst_file_path, hash_key, source_code_path = cls.compile(
            source_code, dst_file_ext, is_mix=is_mix
        )
        return (DLLWrapper(dst_file_path), hash_key, source_code_path)