from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import shutil
import tarfile
import setuptools
# Make GNU the default archive format on the TarFile class.
tarfile.TarFile.format = tarfile.GNU_FORMAT

# Directory that contains this script, with symlinks resolved.
CUR_DIR = os.path.dirname(os.path.realpath(__file__))

# Export this script's ctime, truncated to whole seconds, as
# SOURCE_DATE_EPOCH (presumably so build tools stamp archives with it --
# confirm against the build pipeline).
os.environ['SOURCE_DATE_EPOCH'] = '{:d}'.format(
    int(os.path.getctime(os.path.realpath(__file__))))
class SetupTool():
    """Collect the values consumed by the ``setuptools.setup`` call.

    The constructor fills in these attributes:

    * ``packages``   -- result of ``setuptools.find_packages``.
    * ``version``    -- first line of ``amct_pytorch/.version``.
    * ``platform``   -- ``AMCT_PYTORCH_PLATFORM`` for ``sdist`` builds,
      ``None`` for every other command.
    * ``setup_args`` -- extra keyword arguments for ``setuptools.setup``
      (empty by default).
    """

    def __init__(self):
        self.set_packages()
        self.set_version()
        self.set_platform()
        self.setup_args = {}

    def set_packages(self):
        """Select the packages to ship based on the build mode.

        The ``amct_pytorch.experimental`` sub-packages are included only
        when the ``AMCT_EXPERIMENTAL`` environment variable equals ``TRUE``
        (compared case-insensitively).
        """
        enable_experimental = (
            os.getenv('AMCT_EXPERIMENTAL', '').upper() == 'TRUE')
        if enable_experimental:
            excluded = []
        else:
            excluded = ['amct_pytorch.experimental',
                        'amct_pytorch.experimental.*']
        self.packages = setuptools.find_packages(
            include=['amct_pytorch', 'amct_pytorch.*'],
            exclude=excluded)

    def set_version(self):
        """Read the package version from ``amct_pytorch/.version``.

        Only the first line of the file is used, stripped of surrounding
        whitespace.

        Raises:
            ValueError: if the version file is empty.
        """
        version_file = os.path.join(CUR_DIR, 'amct_pytorch', '.version')
        with open(version_file) as fid:
            # readline() avoids loading the whole file just for line one.
            first_line = fid.readline()
        if not first_line:
            raise ValueError(
                'version file {} is empty'.format(version_file))
        self.version = first_line.strip()

    def set_platform(self):
        """Record the target platform tag for ``sdist`` builds.

        ``platform`` is always assigned so the attribute exists for every
        command; it stays ``None`` unless ``sdist`` is on the command line.
        For ``sdist`` the value comes from ``AMCT_PYTORCH_PLATFORM`` with
        newline characters removed.

        Raises:
            RuntimeError: if ``sdist`` is requested but
                ``AMCT_PYTORCH_PLATFORM`` is not set.
        """
        self.platform = None
        if 'sdist' not in sys.argv:
            return
        platform = os.getenv('AMCT_PYTORCH_PLATFORM')
        if platform is None:
            # Fail with a message that names the variable instead of an
            # AttributeError on None.
            raise RuntimeError(
                'environment variable AMCT_PYTORCH_PLATFORM must be set '
                'when building an sdist')
        self.platform = platform.replace("\n", "")
# Resolve packages, version and platform once, up front; the setup() call and
# the sdist rename further down read them from this instance.
setup_tools = SetupTool()
def get_package_data():
    """Build the ``package_data`` mapping handed to ``setuptools.setup``.

    Returns:
        dict: package name -> list of file patterns. The ``''`` key carries
        the ``.version`` file; ``amct_pytorch.classic.graph_based`` carries
        the proto, csv and shared-library patterns.
    """
    graph_based_patterns = [
        'amct_pytorch/proto/*.proto',
        'amct_pytorch/capacity/*.csv',
        'lib/*.so',
    ]
    package_data = {'': ['.version']}
    package_data['amct_pytorch.classic.graph_based'] = graph_based_patterns
    return package_data
# Register the amct_pytorch distribution: metadata, package list and the
# data files to bundle.
setuptools.setup(
    name='amct_pytorch',
    version=setup_tools.version,
    description='Ascend Model Compression Toolkit for PyTorch',
    url='https://gitcode.com/cann/amct',
    packages=setup_tools.packages,
    classifiers=[
        'Development Status :: 5 - Production/Stable',
        'Intended Audience :: Developers',
        'Intended Audience :: Science/Research',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'Programming Language :: C++',
        'Programming Language :: Python :: 3'
    ],
    author='Huawei Technologies Co., Ltd.',
    license='Apache 2.0',
    extras_require={
        # Extras take requirement specifiers, not bare versions: the old
        # value "2.1" named a distribution literally called "2.1".
        # NOTE(review): assumes the intent was the PyTorch 2.1 series --
        # confirm the supported torch range before release.
        "pytorch": ["torch==2.1.*"]
    },
    package_data=get_package_data(),
    zip_safe=False,
    # Extra keyword arguments collected on the SetupTool instance.
    **setup_tools.setup_args
)
# After an sdist build, rename the archive under dist/ so that its file name
# also carries the "py3-none-<platform>" suffix.
if 'sdist' in sys.argv:
    _sdist_name = 'dist/amct_pytorch-{}.tar.gz'.format(setup_tools.version)
    _tagged_name = 'dist/amct_pytorch-{}-py3-none-{}.tar.gz'.format(
        setup_tools.version, setup_tools.platform)
    shutil.move(os.path.join(CUR_DIR, _sdist_name),
                os.path.join(CUR_DIR, _tagged_name))