#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# ----------------------------------------------------------------------------
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ----------------------------------------------------------------------------

import os
import shutil
import subprocess
import logging
from setuptools import setup, find_packages, Distribution, Command
from wheel.bdist_wheel import bdist_wheel

logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
PACKAGE_NAME = "ascend_ops"
VERSION = "1.0.0"
DESCRIPTION = "Example of PyTorch C++ and Ascend extensions"


class CleanCommand(Command):
    """
    usage: python setup.py clean
    """
    description = "Clean build artifacts from the source tree"
    user_options = []

    def initialize_options(self):
        pass

    def finalize_options(self):
        pass

    def run(self):
        folders_to_remove = ['build', 'dist', f'{PACKAGE_NAME}.egg-info']
        for folder in folders_to_remove:
            if os.path.exists(folder):
                shutil.rmtree(folder)
                logging.info(f"Removed folder: {folder}")
        for root, _, files in os.walk('.'):
            for file in files:
                if file.endswith(('.pyc', '.pyo')):
                    file_path = os.path.join(root, file)
                    os.remove(file_path)
                    logging.info(f"Removed file: {file_path}")
        logging.info("Cleaned build artifacts.")


class BinaryDistribution(Distribution):
    """
    Make this wheel not a pure python package
    """
    def is_pure(self):
        return False

    def has_ext_modules(self):
        return True


class ABI3Wheel(bdist_wheel):
    """
    Force to use abi3 tag for wheel, this wheel supports multiple python versions >= 3.8
    """
    def get_tag(self):
        python, abi, plat = super().get_tag()
        python = "cp38"
        abi = "abi3"
        return python, abi, plat

    def run(self):
        self.run_command('cmake_build')
        super().run()


class CMakeBuildCommand(Command):
    """
    Custom command to build CMake extensions
    """
    description = "Build CMake extensions"
    user_options = []

    def initialize_options(self):
        pass

    def finalize_options(self):
        pass

    def run(self):
        """
        This file `setup.py` and the CMakeLists.txt are in the same directory.
        Use multi-core to speed up compilation.
        """
        cpu_count = os.cpu_count() or 2
        num_jobs = str(cpu_count)
        # Get Torch and Torch NPU paths
        import torch
        TORCH_CMAKE_PATH = torch.utils.cmake_prefix_path
        Torch_DIR = os.path.join(TORCH_CMAKE_PATH, "Torch")
        logging.info(f"Using Torch path: {Torch_DIR}")
        import torch_npu
        TORCH_NPU_PATH = os.path.dirname(torch_npu.__file__)
        logging.info(f"Using Torch NPU path: {TORCH_NPU_PATH}")

        # Build the CMake project
        build_temp = os.path.join(os.getcwd(), 'build')
        cmake_generator = 'Ninja' if shutil.which('ninja') else 'Unix Makefiles'
        logging.info(f"Using CMake generator: {cmake_generator}")
        cmake_config_command = ['cmake', '-S', os.getcwd(), '-B', build_temp,
                                '-G', cmake_generator,
                                '-DCMAKE_BUILD_TYPE=Release',
                                f'-DTorch_DIR={Torch_DIR}',
                                f'-DTORCH_NPU_PATH={TORCH_NPU_PATH}'
                                ]
        subprocess.check_call(cmake_config_command, cwd=os.getcwd())
        subprocess.check_call(['cmake', '--build', build_temp, '--config', 'Release', '--parallel', num_jobs], cwd=os.getcwd())
        logging.info("CMake extensions built successfully.")


cmdclass = {
    'clean': CleanCommand,
    'bdist_wheel': ABI3Wheel,
    'cmake_build': CMakeBuildCommand,
}


setup(
    name=PACKAGE_NAME,
    version=VERSION,
    description=DESCRIPTION,
    packages=find_packages(),
    package_data={PACKAGE_NAME: ['*.so']},
    distclass=BinaryDistribution,
    cmdclass=cmdclass,
    zip_safe=False,
    install_requires=[
        "torch",
        "torch_npu"
    ],
    classifiers=[
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Operating System :: POSIX :: Linux",
    ],
    python_requires='>=3.8',
)