import ctypes
import hashlib
import os
import shutil
import subprocess
import tempfile
from pathlib import Path
from typing import Optional
from asc.runtime.cache import get_cache_manager
from ..utils import get_ascend_path
PRINT_INTERFACE_NAME = "print_interface"
def build_print_utils(obj_name: str, src_file: str, src_dir: str) -> str:
so_path = os.path.join(src_dir, f"{obj_name}.so")
if not os.path.exists(src_dir):
os.makedirs(src_dir, 0o750, exist_ok=True)
cxx = os.getenv("CC")
if cxx is None:
cpp = shutil.which("c++")
gxx = shutil.which("g++")
cxx = cpp if cpp is not None else gxx
if cxx is None:
raise RuntimeError("Failed to find C++ compiler")
cc_cmd = [cxx, src_file]
cc_cmd += ["-w"]
cc_cmd += [f"-L{os.path.join(get_ascend_path(), 'lib64')}", "-lascend_dump", "-lc_sec"]
cc_cmd += ["-shared", "-fPIC", "-o", so_path]
ret = subprocess.check_call(cc_cmd)
if ret == 0:
return so_path
else:
raise RuntimeError("Failed to compile " + src_file)
class PrintInterface(object):
def __init__(self):
dir_name = os.path.dirname(os.path.realpath(__file__))
src_path = os.path.join(dir_name, "print_utils.cpp")
src = Path(src_path).read_text()
version_cfg_info = ""
version_cfg = get_ascend_path() / "version.cfg"
if version_cfg.exists():
version_cfg_info += version_cfg.read_text()
key = hashlib.sha256((src + version_cfg_info).encode("utf-8")).hexdigest()
cache_manager = get_cache_manager(key)
rt_lib = cache_manager.get_file(f"{PRINT_INTERFACE_NAME}.so")
if rt_lib is None:
with tempfile.TemporaryDirectory() as tmpdir:
so_file = build_print_utils(PRINT_INTERFACE_NAME, src_path, tmpdir)
with open(so_file, "rb") as f:
rt_lib = cache_manager.put(f.read(), f"{PRINT_INTERFACE_NAME}.so", binary=True)
self.lib: ctypes.CDLL = ctypes.CDLL(rt_lib, ctypes.RTLD_LOCAL)
def call(self, *args):
fn_name = "PrintWorkSpace"
fn = getattr(self.lib, fn_name)
fn(*args)
print_interface: Optional[PrintInterface] = None
def load_print_interface():
global print_interface
print_interface = PrintInterface()
def call_print_interface(*args):
if print_interface is None:
load_print_interface()
print_interface.call(*args)