import glob
import os
import re
import shlex
import shutil
import site
import subprocess
import sys
import torch
from setuptools import Command, find_packages, setup
from torch.utils import cpp_extension
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
YES_VALUES = ("1", "true", "yes", "on")
MIN_TORCH_VERSION = (2, 5, 0)
def _version_text(version):
return ".".join(str(part) for part in version)
def _parse_numeric_version(value):
if not value:
return None
match = re.match(r"^\s*(\d+(?:\.\d+)*)", value)
if not match:
return None
return tuple(int(part) for part in match.group(1).split("."))
def _version_at_least(value, minimum):
parsed = _parse_numeric_version(value)
if not parsed:
return False
size = max(len(parsed), len(minimum))
padded = parsed + (0,) * (size - len(parsed))
padded_minimum = minimum + (0,) * (size - len(minimum))
return padded >= padded_minimum
def _version_less_than(value, maximum):
parsed = _parse_numeric_version(value)
if not parsed:
return False
size = max(len(parsed), len(maximum))
padded = parsed + (0,) * (size - len(parsed))
padded_maximum = maximum + (0,) * (size - len(maximum))
return padded < padded_maximum
def _env_enabled(name):
return os.environ.get(name, "").strip().lower() in YES_VALUES
def _ensure_torch_version():
version = getattr(torch, "__version__", None)
if _version_at_least(version, MIN_TORCH_VERSION):
return
raise RuntimeError(
"PyTorch is too old for bundled torchcsprng.\n"
f"Required: torch >= {_version_text(MIN_TORCH_VERSION)}.\n"
f"Current torch: {version or '<unknown>'}."
)
def _nvcc_release(cuda_home):
if not cuda_home:
return None
nvcc = os.path.join(cuda_home, "bin", "nvcc")
if not os.path.isfile(nvcc):
return None
try:
out = subprocess.check_output(
[nvcc, "--version"],
stderr=subprocess.STDOUT,
).decode("utf-8", "ignore")
except (OSError, subprocess.CalledProcessError):
return None
match = re.search(r"release\s+(\d+\.\d+)", out)
return match.group(1) if match else None
def _dedupe_existing_paths(paths):
result = []
seen = set()
for path in paths:
if not path:
continue
path = os.path.abspath(path)
if path in seen:
continue
if os.path.isdir(path):
seen.add(path)
result.append(path)
return result
def _cuda_home_candidates():
"""Return possible CUDA roots.
Supports:
- system CUDA: /usr/local/cuda, /usr/local/cuda-*
- conda CUDA: $CONDA_PREFIX
- nvcc discovered on PATH
- manually configured CUDA_HOME / CUDA_PATH
"""
candidates = []
for env_name in ("CUDA_HOME", "CUDA_PATH", "CONDA_PREFIX"):
value = os.environ.get(env_name)
if value:
candidates.append(value)
if cpp_extension.CUDA_HOME:
candidates.append(cpp_extension.CUDA_HOME)
path_nvcc = shutil.which("nvcc")
if path_nvcc:
candidates.append(os.path.dirname(os.path.dirname(path_nvcc)))
candidates.append("/usr/local/cuda")
candidates.extend(sorted(glob.glob("/usr/local/cuda-*"), reverse=True))
result = []
seen = set()
for candidate in candidates:
if not candidate:
continue
candidate = os.path.abspath(candidate)
if candidate in seen:
continue
seen.add(candidate)
result.append(candidate)
return result
def _python_site_dirs():
dirs = []
try:
dirs.extend(site.getsitepackages())
except Exception:
pass
try:
user_site = site.getusersitepackages()
if user_site:
dirs.append(user_site)
except Exception:
pass
conda_prefix = os.environ.get("CONDA_PREFIX")
if conda_prefix:
pattern = os.path.join(conda_prefix, "lib", "python*", "site-packages")
dirs.extend(glob.glob(pattern))
return _dedupe_existing_paths(dirs)
def _cuda_include_dirs():
"""Return CUDA include directories.
Handles:
- /usr/local/cuda*/include
- $CONDA_PREFIX/include
- $CONDA_PREFIX/targets/x86_64-linux/include
- pip NVIDIA wheels: site-packages/nvidia/*/include
"""
candidates = []
for root in _cuda_home_candidates():
candidates.extend(
[
os.path.join(root, "include"),
os.path.join(root, "targets", "x86_64-linux", "include"),
]
)
for site_dir in _python_site_dirs():
candidates.extend(glob.glob(os.path.join(site_dir, "nvidia", "*", "include")))
return _dedupe_existing_paths(candidates)
def _cuda_library_dirs():
"""Return CUDA library directories.
Handles:
- /usr/local/cuda*/lib64
- $CONDA_PREFIX/lib
- $CONDA_PREFIX/targets/x86_64-linux/lib
- pip NVIDIA wheels: site-packages/nvidia/*/lib
"""
candidates = []
for root in _cuda_home_candidates():
candidates.extend(
[
os.path.join(root, "lib64"),
os.path.join(root, "lib"),
os.path.join(root, "targets", "x86_64-linux", "lib"),
os.path.join(root, "targets", "x86_64-linux", "lib", "stubs"),
]
)
for site_dir in _python_site_dirs():
candidates.extend(glob.glob(os.path.join(site_dir, "nvidia", "*", "lib")))
return _dedupe_existing_paths(candidates)
def _find_header(header_name, include_dirs):
for include_dir in include_dirs:
candidate = os.path.join(include_dir, header_name)
if os.path.exists(candidate):
return candidate
return None
def _ensure_cuda_headers(include_dirs):
required_headers = [
"cuda_runtime.h",
"cublas_v2.h",
]
missing = []
found = {}
for header in required_headers:
path = _find_header(header, include_dirs)
if path:
found[header] = path
else:
missing.append(header)
if missing:
lines = [
"Error: CUDA development headers are missing.",
"",
"Missing headers:",
]
for header in missing:
lines.append(f" - {header}")
lines.extend(
[
"",
"Checked include directories:",
]
)
if include_dirs:
for include_dir in include_dirs:
lines.append(f" - {include_dir}")
else:
lines.append(" <none>")
lines.extend(
[
"",
"This usually means nvcc is available, but CUDA development headers",
"are not installed or are not visible to the build system.",
"",
"For pip/PyTorch CUDA wheels, headers may live under:",
" site-packages/nvidia/*/include",
"",
"For conda/miniforge CUDA environments, headers may live under:",
" $CONDA_PREFIX/targets/x86_64-linux/include",
"",
"To skip CUDA support for torchcsprng, set:",
" NSSMPC_SKIP_CSPRNG_CUDA=1",
]
)
raise RuntimeError("\n".join(lines))
print("Detected CUDA headers:")
for header, path in found.items():
print(f" {header}: {path}")
def _cxx_compiler_command():
env_cxx = os.environ.get("CXX")
if env_cxx:
return shlex.split(env_cxx)
if sys.platform.startswith("win"):
return ["cl"] if shutil.which("cl") else None
for name in ("g++", "c++", "clang++"):
if shutil.which(name):
return [name]
return None
def _compiler_version(command):
for args in (command + ["-dumpfullversion", "-dumpversion"], command + ["--version"]):
try:
out = subprocess.check_output(args, stderr=subprocess.STDOUT).decode("utf-8", "ignore")
except (OSError, subprocess.CalledProcessError):
continue
parsed = _parse_numeric_version(out)
if parsed:
return _version_text(parsed)
return None
def _compiler_kind(command):
name = os.path.basename(command[0]).lower()
if "clang" in name:
return "clang"
return "gcc"
def _cuda_compiler_issue(torch_cuda):
if not torch_cuda or not sys.platform.startswith("linux"):
return None
if os.environ.get("TORCH_DONT_CHECK_COMPILER_ABI", "").upper() in (
"ON",
"1",
"YES",
"TRUE",
"Y",
):
return None
command = _cxx_compiler_command()
if not command:
return "No C++ compiler command was found for CUDA extension builds."
version = _compiler_version(command)
if not version:
return f"Could not determine C++ compiler version for {' '.join(command)}."
kind = _compiler_kind(command)
bounds_map = (
getattr(cpp_extension, "CUDA_CLANG_VERSIONS", {})
if kind == "clang"
else getattr(cpp_extension, "CUDA_GCC_VERSIONS", {})
)
bounds = bounds_map.get(torch_cuda)
if not bounds:
return None
min_version, max_exclusive_version = tuple(bounds[0]), tuple(bounds[1])
if _version_at_least(version, min_version) and _version_less_than(version, max_exclusive_version):
return None
compiler_name = "clang++" if kind == "clang" else "g++"
return (
f"Detected {compiler_name}-compatible compiler {' '.join(command)} {version}, "
f"but CUDA {torch_cuda} requires {compiler_name} "
f">= {_version_text(min_version)}, < {_version_text(max_exclusive_version)} "
"for PyTorch CUDA extension builds."
)
def _ensure_cuda_compiler_compatible(torch_cuda):
issue = _cuda_compiler_issue(torch_cuda)
if issue:
raise RuntimeError(
"C++ compiler version is incompatible with this CUDA/PyTorch build.\n"
f"Reason: {issue}\n"
"Required: use a host C++ compiler version within PyTorch's CUDA compiler bounds, "
"or set NSSMPC_SKIP_CSPRNG_CUDA=1 for an intentional CPU-only torchcsprng build."
)
def _auto_set_cuda_home(torch_cuda):
"""Align CUDA_HOME to torch.version.cuda when possible.
This supports both system CUDA and conda/miniforge CUDA layouts.
"""
if not torch_cuda:
return True
current = os.environ.get("CUDA_HOME")
if current and _nvcc_release(current) == torch_cuda:
os.environ.setdefault("CUDA_PATH", current)
return True
candidates = _cuda_home_candidates()
for candidate in candidates:
if _nvcc_release(candidate) == torch_cuda:
os.environ["CUDA_HOME"] = candidate
os.environ.setdefault("CUDA_PATH", candidate)
print(
f"Notice: auto-set CUDA_HOME={candidate} "
f"(matches torch.version.cuda={torch_cuda})"
)
return True
return False
_ensure_torch_version()
version = open("version.txt", "r").read().strip()
sha = "Unknown"
package_name = "torchcsprng"
cwd = os.path.dirname(os.path.abspath(__file__))
try:
sha = (
subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=cwd)
.decode("ascii")
.strip()
)
except Exception:
pass
if os.getenv("BUILD_VERSION"):
version = os.getenv("BUILD_VERSION")
elif sha != "Unknown":
version += "+" + sha[:7]
print(f"Building wheel {package_name}-{version}")
def write_version_file():
version_path = os.path.join(cwd, "torchcsprng", "version.py")
with open(version_path, "w") as f:
f.write("__version__ = '{}'\n".format(version))
f.write("git_version = {}\n".format(repr(sha)))
write_version_file()
with open("README.md", "r") as fh:
long_description = fh.read()
def append_flags(flags, flags_to_append):
for flag in flags_to_append:
if flag not in flags:
flags.append(flag)
return flags
def get_extensions():
skip_cuda = _env_enabled("NSSMPC_SKIP_CSPRNG_CUDA")
build_cuda = not skip_cuda and (
torch.cuda.is_available() or os.getenv("FORCE_CUDA", "0") == "1"
)
if skip_cuda:
print(
"Notice: NSSMPC_SKIP_CSPRNG_CUDA is set; "
"building torchcsprng without CUDA support."
)
module_name = "torchcsprng"
extensions_dir = os.path.join(cwd, module_name, "csrc")
openmp = "ATen parallel backend: OpenMP" in torch.__config__.parallel_info()
main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
sources = main_file + source_cpu
extension = CppExtension
define_macros = []
cxx_flags = os.getenv("CXX_FLAGS", "")
if cxx_flags == "":
cxx_flags = []
else:
cxx_flags = cxx_flags.split(" ")
if openmp:
if sys.platform == "linux":
cxx_flags = append_flags(cxx_flags, ["-fopenmp"])
elif sys.platform == "win32":
cxx_flags = append_flags(cxx_flags, ["/openmp"])
include_dirs = []
library_dirs = []
if build_cuda:
if not _auto_set_cuda_home(torch.version.cuda):
raise RuntimeError(
"CUDA PyTorch is installed, but no matching CUDA Toolkit / nvcc "
"was found for bundled torchcsprng.\n"
f"Required: CUDA Toolkit / nvcc {torch.version.cuda}, "
"matching torch.version.cuda."
)
if os.environ.get("CUDA_HOME"):
cpp_extension.CUDA_HOME = os.environ["CUDA_HOME"]
extension = CUDAExtension
source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
sources += source_cuda
define_macros += [("WITH_CUDA", None)]
include_dirs = _cuda_include_dirs()
library_dirs = _cuda_library_dirs()
_ensure_cuda_headers(include_dirs)
_ensure_cuda_compiler_compatible(torch.version.cuda)
nvcc_flags = os.getenv("NVCC_FLAGS", "")
if nvcc_flags == "":
nvcc_flags = []
else:
nvcc_flags = nvcc_flags.split(" ")
for include_dir in include_dirs:
nvcc_flags = append_flags(nvcc_flags, [f"-I{include_dir}"])
nvcc_flags = append_flags(
nvcc_flags,
[
"--expt-extended-lambda",
"-Xcompiler",
],
)
extra_compile_args = {
"cxx": cxx_flags,
"nvcc": nvcc_flags,
}
print("Building torchcsprng with CUDA support.")
print("CUDA_HOME:", os.environ.get("CUDA_HOME") or "<unset>")
print("CUDA include dirs:")
for include_dir in include_dirs:
print(f" - {include_dir}")
print("CUDA library dirs:")
for library_dir in library_dirs:
print(f" - {library_dir}")
else:
extra_compile_args = {
"cxx": cxx_flags,
}
ext_modules = [
extension(
module_name + "._C",
sources,
define_macros=define_macros,
include_dirs=include_dirs,
library_dirs=library_dirs,
extra_compile_args=extra_compile_args,
)
]
return ext_modules
class fast_install(Command):
description = "Custom install command that cleans project and installs wheel"
user_options = []
def initialize_options(self):
pass
def finalize_options(self):
pass
def run(self):
os.system("python setup.py clean")
os.system("python setup.py bdist_wheel")
os.system(f"pip install {glob.glob('./dist/*.whl')[0]} --force-reinstall --no-deps")
class clean(Command):
description = "Custom clean command that cleans project based on .gitignore rules"
user_options = []
def initialize_options(self):
pass
def finalize_options(self):
pass
def run(self):
with open(".gitignore", "r") as f:
ignores = f.read()
start_deleting = False
for wildcard in filter(None, ignores.split("\n")):
if (
wildcard
== "# do not change or delete this comment - `python setup.py clean` deletes everything after this line"
):
start_deleting = True
if not start_deleting:
continue
for filename in glob.glob(wildcard, recursive=True):
try:
os.remove(filename)
print(f"Removed file: {filename}")
except OSError:
shutil.rmtree(filename, ignore_errors=True)
print(f"Removed directory: {filename}")
setup(
name=package_name,
version=version,
author="Pavel Belevich",
author_email="pbelevich@fb.com",
url="https://github.com/pytorch/csprng",
description="Cryptographically secure pseudorandom number generators for PyTorch",
long_description=long_description,
long_description_content_type="text/markdown",
license="BSD-3",
packages=find_packages(exclude=("test",)),
package_data={"": ["*.pyi"]},
classifiers=[
"Intended Audience :: Developers",
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: BSD License",
"Programming Language :: C++",
"Programming Language :: Python :: 3",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Mathematics",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
],
python_requires=">=3.10",
install_requires="torch>=2.5.0",
ext_modules=get_extensions(),
test_suite="test",
cmdclass={
"fast_install": fast_install,
"build_ext": BuildExtension,
"clean": clean,
},
)