import os
import sys
import logging
import runpy
import subprocess
import shutil
from setuptools import setup, find_packages
from setuptools.command.build_py import build_py as _build_py
from setuptools.dist import Distribution
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
logging.basicConfig(format='%(levelname)s: %(message)s', level=logging.INFO)
# wheel uses SOURCE_DATE_EPOCH for archive entry timestamps. 315532800 is
# 1980-01-01T00:00:00Z, the earliest date a ZIP entry can store, so timestamps
# do not vary between builds. Note: this overrides any value set by the caller.
os.environ["SOURCE_DATE_EPOCH"] = "315532800"
# Absolute path of the version module located next to this setup script.
VERSION_FILE = os.path.join(os.path.abspath(os.path.dirname(__file__)), "version.py")
# Environment variables that steer the build.
WHEEL_MODE_ENV = "MINDIESD_WHEEL_MODE"  # selects FIXED_WHEEL_MODE (default) or MULTI_TORCH_WHEEL_MODE
MULTI_TORCH_PLUGIN_DIR_ENV = "MINDIESD_MULTI_TORCH_PLUGIN_DIR"  # overrides the prebuilt plugin variant root
SKIP_OPS_BUILD_ENV = "MINDIESD_SKIP_OPS_BUILD"  # a truthy value skips build_ops.sh
# Accepted values of WHEEL_MODE_ENV.
FIXED_WHEEL_MODE = "fixed"
MULTI_TORCH_WHEEL_MODE = "multi_torch"
# Prebuilt plugin sub-directories expected in multi_torch mode, one per torch release.
SUPPORTED_TORCH_PLUGIN_VARIANTS = (
    "torch26",
    "torch27",
    "torch28",
    "torch29",
    "torch210",
)
def get_mindiesd_version():
    """Return ``__version__`` as defined by executing ``VERSION_FILE``.

    Raises:
        RuntimeError: if the executed module leaves ``__version__`` missing or falsy.
    """
    namespace = runpy.run_path(VERSION_FILE)
    found_version = namespace.get("__version__")
    if found_version:
        logging.info("Build version is: %s", found_version)
        return found_version
    raise RuntimeError("Failed to get version from %s" % VERSION_FILE)
def get_python_version():
    """Return the running interpreter's version tag, e.g. ``py310``, and log it.

    ``sys.version_info`` always provides integer ``major``/``minor`` fields, so
    no error handling is needed here.
    """
    python_version = f"py{sys.version_info.major}{sys.version_info.minor}"
    logging.info("Python version is: %s", python_version)
    return python_version
def get_wheel_mode():
    """Return the wheel build mode selected by ``MINDIESD_WHEEL_MODE``.

    The value is matched case-insensitively and surrounding whitespace is ignored.
    An unset or empty variable selects ``FIXED_WHEEL_MODE``.

    Raises:
        RuntimeError: if the variable holds anything other than
            ``FIXED_WHEEL_MODE`` or ``MULTI_TORCH_WHEEL_MODE``.
    """
    # ``or`` makes "MINDIESD_WHEEL_MODE=" (set but empty) behave like unset,
    # instead of being rejected as an unsupported mode.
    mode = os.environ.get(WHEEL_MODE_ENV, "").strip().lower() or FIXED_WHEEL_MODE
    if mode not in (FIXED_WHEEL_MODE, MULTI_TORCH_WHEEL_MODE):
        raise RuntimeError(
            f"Unsupported {WHEEL_MODE_ENV}={mode}. Expected one of: {FIXED_WHEEL_MODE}, {MULTI_TORCH_WHEEL_MODE}."
        )
    logging.info("Wheel build mode is: %s", mode)
    return mode
def is_env_enabled(env_name):
    """Return True when *env_name* holds a truthy flag: 1/true/yes/on.

    Matching is case-insensitive and ignores surrounding whitespace. An unset
    variable counts as disabled.
    """
    normalized = os.environ.get(env_name, "").strip().lower()
    return normalized in {"1", "true", "yes", "on"}
def copy_so_files(src_dir, dest_dir):
    """Copy every regular ``*.so`` file directly under *src_dir* into *dest_dir*.

    *dest_dir* is created if needed. A missing *src_dir*, or one without any
    ``.so`` files, is logged as a warning and nothing is copied. Files keep their
    metadata (``shutil.copy2``) and are copied in sorted order.

    Args:
        src_dir: directory to scan (not recursive).
        dest_dir: destination directory; existing files with the same name are overwritten.
    """
    os.makedirs(dest_dir, exist_ok=True)
    if not os.path.isdir(src_dir):
        # e.g. build/plugin_build when the plugin build was skipped; warn rather
        # than crash in os.listdir.
        logging.warning("Source directory %s does not exist; no .so files copied", src_dir)
        return
    # Only regular files: a directory that happens to end in ".so" cannot be copied by copy2.
    so_files = sorted(
        name
        for name in os.listdir(src_dir)
        if name.endswith('.so') and os.path.isfile(os.path.join(src_dir, name))
    )
    if not so_files:
        logging.warning("No .so files found in %s", src_dir)
        return
    for so_file in so_files:
        src_file = os.path.join(src_dir, so_file)
        dest_file = os.path.join(dest_dir, so_file)
        shutil.copy2(src_file, dest_file)
        logging.info("Copied %s to %s", src_file, dest_file)
def copy_multi_torch_plugin_files(proj_root):
    """Copy prebuilt per-torch plugin libraries into ``mindiesd/plugin/<variant>``.

    The source root is ``$MINDIESD_MULTI_TORCH_PLUGIN_DIR`` when set and non-empty,
    otherwise ``<proj_root>/build/torch_plugin_variants``. Every variant in
    ``SUPPORTED_TORCH_PLUGIN_VARIANTS`` must provide
    ``<source root>/<variant>/libPTAExtensionOPS.so``. All ``.so`` files of each
    variant directory are then copied.

    Raises:
        RuntimeError: if any variant's ``libPTAExtensionOPS.so`` is missing. In
            that case nothing is copied.
    """
    # ``or``: an empty override falls back to the default location instead of
    # resolving variant directories relative to the current working directory.
    src_root = os.environ.get(MULTI_TORCH_PLUGIN_DIR_ENV) or os.path.join(
        proj_root, "build", "torch_plugin_variants"
    )
    dest_root = os.path.join(proj_root, "mindiesd", "plugin")
    logging.info("Using multi torch plugin source directory: %s", src_root)
    # Validate all variants before copying, so this call copies either every
    # variant or none of them.
    missing_variants = [
        variant
        for variant in SUPPORTED_TORCH_PLUGIN_VARIANTS
        if not os.path.isfile(os.path.join(src_root, variant, "libPTAExtensionOPS.so"))
    ]
    if missing_variants:
        raise RuntimeError(
            "Missing multi torch plugin .so files for variants: %s. "
            "Expected files under %s/<variant>/libPTAExtensionOPS.so." % (", ".join(missing_variants), src_root)
        )
    for variant in SUPPORTED_TORCH_PLUGIN_VARIANTS:
        copy_so_files(os.path.join(src_root, variant), os.path.join(dest_root, variant))
def ensure_plugin_init():
    """Ensure ``<cwd>/mindiesd/plugin`` exists and contains an ``__init__.py``.

    Missing directories and a missing ``__init__.py`` are created (the file is
    created empty). An existing ``__init__.py`` keeps its content.
    """
    package_dir = os.path.join(os.getcwd(), 'mindiesd/plugin')
    os.makedirs(package_dir, exist_ok=True)
    init_path = os.path.join(package_dir, '__init__.py')
    # Append mode creates the file when absent and never truncates it; nothing is written.
    with open(init_path, "a", encoding="utf-8"):
        pass
def run_script(script_path, args=None, cwd=None):
    """Run a shell script as ``bash <script_path> [args...]`` in *cwd*.

    The script's stderr is redirected to stdout.

    Raises:
        RuntimeError: if the script exits with a non-zero status. The original
            ``CalledProcessError`` is chained as the cause.
    """
    command = ['bash', script_path, *(args or [])]
    logging.info(">>> Running script: %s", ' '.join(command))
    try:
        subprocess.run(command, cwd=cwd, stderr=subprocess.STDOUT, check=True)
    except subprocess.CalledProcessError as err:
        logging.error("Script failed with return code %s", err.returncode)
        raise RuntimeError("Script execution failed: %s" % script_path) from err
def merge_compile_commands(proj_root, build_dir):
    """Merge the per-stage compile_commands.json files into ``<proj_root>/compile_commands.json``.

    Stages are read in a fixed order (AscendC ops, PyTorch plugin, TIK ops).
    A stage file is skipped with a log message when it is missing, unreadable,
    not valid JSON, or not a JSON list. Non-object entries are skipped too.
    Entries are de-duplicated on (directory, file, command, arguments), and the
    first occurrence wins. Nothing is written when no entry is found.

    Args:
        proj_root: directory that receives the merged compile_commands.json.
        build_dir: build directory holding the per-stage files.
    """
    import json

    sources = [
        ("AscendC ops", os.path.join(build_dir, "compile_commands_ascendc.json")),
        ("PyTorch plugin", os.path.join(build_dir, "plugin_build", "compile_commands.json")),
        ("TIK ops", os.path.join(build_dir, "compile_commands_tik.json")),
    ]
    merged = []
    seen = set()
    for stage_name, path in sources:
        if not os.path.isfile(path):
            logging.info("compile_commands.json not found for %s: %s", stage_name, path)
            continue
        try:
            with open(path, 'r', encoding="utf-8") as f:
                entries = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logging.warning("Failed to parse %s: %s", path, e)
            continue
        if not isinstance(entries, list):
            logging.warning("Unexpected format in %s, expected list", path)
            continue
        added = 0
        for entry in entries:
            if not isinstance(entry, dict):
                logging.warning("Skipping non-object entry in %s: %r", path, entry)
                continue
            # A compilation database entry may give the compiler invocation as a
            # "command" string or an "arguments" list. Key on both so distinct
            # "arguments"-style invocations are not collapsed. json.dumps yields a
            # hashable key even when values are lists.
            key = json.dumps(
                [
                    entry.get("directory", ""),
                    entry.get("file", ""),
                    entry.get("command", ""),
                    entry.get("arguments"),
                ],
                sort_keys=True,
            )
            if key in seen:
                continue
            seen.add(key)
            merged.append(entry)
            added += 1
        logging.info("Merged %s entries from %s (%s total)", added, stage_name, len(entries))
    if merged:
        output_path = os.path.join(proj_root, "compile_commands.json")
        with open(output_path, 'w', encoding="utf-8") as f:
            json.dump(merged, f, indent=2)
        logging.info("Merged compile_commands.json written to %s (%s total entries)", output_path, len(merged))
    else:
        logging.info("No compile_commands.json entries found to merge")
class CustomBuildPy(_build_py):
    """``build_py`` command that builds and stages MindIE-SD's native artifacts first.

    Before running the stock ``build_py`` it:

    1. makes every regular file directly under ``build/`` read-only (0o444);
    2. builds the Ascend operators with ``build/build_ops.sh``, unless
       ``MINDIESD_SKIP_OPS_BUILD`` is enabled or ``csrc/ops`` is missing;
    3. in ``fixed`` mode, builds the PyTorch plugin with ``build/build_plugin.sh``
       when ``csrc/plugin`` exists; in ``multi_torch`` mode, copies the prebuilt
       per-torch plugin variants instead;
    4. merges the per-stage compile_commands.json files into the project root;
    5. in ``fixed`` mode, copies ``build/plugin_build/*.so`` into ``mindiesd/plugin``.

    All paths are resolved against the current working directory.
    """

    @staticmethod
    def _log_banner(title):
        """Log *title* between two 60-character separator lines."""
        logging.info("%s", "=" * 60)
        logging.info(title)
        logging.info("%s", "=" * 60)

    @staticmethod
    def _make_build_files_read_only(build_dir):
        """Set every regular file directly under *build_dir* to mode 0o444.

        The build scripts are invoked as ``bash <script>`` (see run_script), so
        they only need to be readable.
        """
        # NOTE(review): this also covers files generated in build/ by earlier
        # builds (e.g. compile_commands_*.json); confirm the build scripts never
        # rewrite them in place.
        for name in os.listdir(build_dir):
            path = os.path.join(build_dir, name)
            if os.path.isfile(path):
                os.chmod(path, 0o444)

    def _run_ops_build(self, proj_root, build_dir):
        """Build the Ascend operators unless skipped via env or their sources are absent."""
        ops_dir = os.path.join(proj_root, 'csrc', 'ops')
        if is_env_enabled(SKIP_OPS_BUILD_ENV):
            logging.info("Skipping Ascend operators build because %s is enabled.", SKIP_OPS_BUILD_ENV)
        elif os.path.isdir(ops_dir):
            self._log_banner("Building Ascend operators...")
            run_script(os.path.join(build_dir, 'build_ops.sh'), args=[build_dir], cwd=build_dir)
        else:
            logging.warning("The path of custom op operators %s does not exist.", ops_dir)

    def _run_plugin_build(self, proj_root, build_dir):
        """Build the fixed-mode PyTorch plugin when ``csrc/plugin`` exists."""
        plugin_dir = os.path.join(proj_root, 'csrc', 'plugin')
        if not os.path.isdir(plugin_dir):
            logging.warning("The path of op plugins %s does not exist.", plugin_dir)
            return
        self._log_banner("Building PyTorch plugins...")
        run_script(os.path.join(build_dir, 'build_plugin.sh'), args=[build_dir], cwd=build_dir)

    @staticmethod
    def _stage_fixed_plugin(proj_root, build_dir):
        """Copy fixed-mode plugin libraries from ``build/plugin_build`` into ``mindiesd/plugin``."""
        source_dir = os.path.join(build_dir, 'plugin_build')
        if not os.path.isdir(source_dir):
            # The plugin build is skipped with only a warning when csrc/plugin is
            # missing. Report the absent output instead of failing in os.listdir.
            logging.warning("Plugin build output %s does not exist; no plugin .so files were packaged.", source_dir)
            return
        copy_so_files(source_dir, os.path.join(proj_root, 'mindiesd', 'plugin'))

    def run(self):
        """Build/stage native artifacts, then run the standard ``build_py``.

        Raises:
            Exception: any failure of a build step is logged and re-raised unchanged.
        """
        proj_root = os.path.abspath(os.getcwd())
        build_dir = os.path.join(proj_root, 'build')
        wheel_mode = get_wheel_mode()
        logging.info("%s", "=" * 60)
        logging.info("Starting MindIE-SD Build Process")
        logging.info("Project root: %s", proj_root)
        logging.info("Build directory: %s", build_dir)
        logging.info("%s", "=" * 60)
        # Logs the interpreter tag (pyXY); the return value is not needed here.
        get_python_version()
        self._make_build_files_read_only(build_dir)
        try:
            self._run_ops_build(proj_root, build_dir)
            if wheel_mode == FIXED_WHEEL_MODE:
                self._run_plugin_build(proj_root, build_dir)
            else:
                self._log_banner("Packaging prebuilt PyTorch plugin variants...")
                copy_multi_torch_plugin_files(proj_root)
            merge_compile_commands(proj_root, build_dir)
            if wheel_mode == FIXED_WHEEL_MODE:
                self._stage_fixed_plugin(proj_root, build_dir)
            self._log_banner("Build completed successfully!")
        except Exception as e:
            logging.error("Build failed: %s", e)
            raise
        # Native files are now in mindiesd/, so the standard build can pick them up as package data.
        super().run()
class BDistWheel(_bdist_wheel):
    """``bdist_wheel`` that always builds a non-pure (platform-specific) wheel.

    The package ships compiled ``.so`` files, so the wheel is forced to be non-pure.
    """

    def finalize_options(self):
        super(BDistWheel, self).finalize_options()
        # Assigned after the base class has computed its own value, so this one wins.
        self.root_is_pure = False
class BinaryDistribution(Distribution):
    """Distribution that unconditionally reports extension modules.

    The native libraries are produced by shell scripts rather than declared
    ``ext_modules``, so setuptools is told explicitly that compiled code is present.
    """

    def has_ext_modules(self):
        """Always report compiled code in this distribution."""
        return True
if __name__ == "__main__":
    mindie_sd_version = get_mindiesd_version()
    build_wheel_mode = get_wheel_mode()
    ensure_plugin_init()
    # multi_torch mode stages one sub-directory per torch variant under
    # mindiesd/plugin, so its .so glob must recurse. Fixed mode has a flat layout.
    if build_wheel_mode == MULTI_TORCH_WHEEL_MODE:
        plugin_glob = "plugin/**/*.so"
    else:
        plugin_glob = "plugin/*.so"
    package_data = {"mindiesd": ["ops/**/*", plugin_glob]}
    setup(
        name="mindiesd",
        version=mindie_sd_version,
        author="ascend",
        description="build wheel for mindie sd",
        setup_requires=[],
        install_requires=["torch", "torch_npu"],
        zip_safe=False,
        python_requires=">=3.10",
        include_package_data=True,
        packages=find_packages(),
        package_data=package_data,
        cmdclass={"build_py": CustomBuildPy, "bdist_wheel": BDistWheel},
        distclass=BinaryDistribution,
    )