import re
import os
from abc import ABC, abstractmethod
from typing import List, Union
from torch.utils.cpp_extension import load
from torch.library import Library
import torch_npu
import mindspeed

# Name of the environment variable read by MindSpeedOpBuilder.get_cann_path to
# locate the CANN toolkit (its include/ and lib64/ dirs are used for builds).
ASCEND_HOME_PATH = "ASCEND_HOME_PATH"
# Torch library created at import time with kind "DEF" for the "mindspeed"
# namespace; op schemas are added to it by MindSpeedOpBuilder.register_op_proto.
AS_LIBRARY = Library("mindspeed", "DEF")


class MindSpeedOpBuilder(ABC):
    """Base class for JIT-compiling and loading MindSpeed C++ extension ops.

    Subclasses provide the C++ sources through :meth:`sources` and may extend
    :meth:`include_paths`, :meth:`cxx_args` or :meth:`extra_ldflags`.
    :meth:`load` builds the extension with ``torch.utils.cpp_extension.load``
    and caches the resulting module by op name.
    """

    _cann_path = None
    _torch_npu_path = None
    _cann_version = None
    # Compiled-module cache keyed by op name. It is stored on this base class,
    # so it is shared by every builder subclass.
    _loaded_ops = {}

    def __init__(self, name):
        """Create a builder for the extension module called *name*.

        Resolves the CANN toolkit path (``None`` when unavailable) and the
        directory of the installed ``torch_npu`` package.
        """
        self.name = name
        self._cann_path = self.get_cann_path()
        self._torch_npu_path = os.path.dirname(os.path.abspath(torch_npu.__file__))

    def get_cann_path(self):
        """Return ``$ASCEND_HOME_PATH`` if it names an existing path, else ``None``."""
        cann_path = os.environ.get(ASCEND_HOME_PATH)
        if cann_path and os.path.exists(cann_path):
            return cann_path
        return None

    def _require_cann_path(self):
        """Return the CANN path, or raise a descriptive error if it was not found.

        Raises:
            RuntimeError: if the CANN toolkit path could not be resolved.
        """
        if self._cann_path is None:
            # Without this check os.path.join(None, ...) raises an opaque TypeError.
            raise RuntimeError(
                "CANN toolkit not found: set the ASCEND_HOME_PATH environment "
                "variable to an existing CANN installation path before "
                f"building the '{self.name}' op."
            )
        return self._cann_path

    def get_absolute_paths(self, paths):
        """Resolve *paths* against the installed ``mindspeed`` package directory.

        Absolute entries come back unchanged, because ``os.path.join`` discards
        the base when the second argument is absolute.
        """
        mindspeed_path = os.path.abspath(os.path.dirname(mindspeed.__file__))
        return [os.path.join(mindspeed_path, path) for path in paths]

    def register_op_proto(self, op_proto: Union[str, List[str]]):
        """Define one or more op schemas in the ``mindspeed`` torch library.

        Args:
            op_proto: a single schema string or a list of schema strings.
        """
        if isinstance(op_proto, str):
            op_proto = [op_proto]
        for proto in op_proto:
            AS_LIBRARY.define(proto)

    @abstractmethod
    def sources(self):
        """Return the C++ source files to compile.

        Relative entries are resolved against the ``mindspeed`` package
        directory by :meth:`load`.
        """
        ...

    def include_paths(self):
        """Return header search directories for the extension build.

        Raises:
            RuntimeError: if the CANN toolkit path could not be resolved.
        """
        cann_path = self._require_cann_path()
        paths = [
            os.path.join(self._torch_npu_path, 'include'),
            os.path.join(self._torch_npu_path, 'include/third_party/hccl/inc'),
            os.path.join(self._torch_npu_path, 'include/third_party/acl/inc'),
            os.path.join(cann_path, 'include'),
        ]
        return paths

    def cxx_args(self):
        """Return the extra compiler flags (hardening, hidden visibility, ``-O2``)."""
        args = ['-fstack-protector-all', '-Wl,-z,relro,-z,now,-z,noexecstack', '-fPIC', '-pie',
                '-s', '-fvisibility=hidden', '-D_FORTIFY_SOURCE=2', '-O2']
        return args

    def extra_ldflags(self):
        """Return linker flags for ``libascendcl`` (CANN) and ``libtorch_npu``.

        Raises:
            RuntimeError: if the CANN toolkit path could not be resolved.
        """
        cann_path = self._require_cann_path()
        flags = [
            '-L' + os.path.join(cann_path, 'lib64'), '-lascendcl',
            '-L' + os.path.join(self._torch_npu_path, 'lib'), '-ltorch_npu'
        ]
        return flags

    def load(self, verbose=True):
        """Compile the extension on first use and return its module.

        The module is cached by ``self.name`` in the class-wide ``_loaded_ops``
        dict. Later calls with the same name return the cached module without
        rebuilding.

        Args:
            verbose: forwarded to ``torch.utils.cpp_extension.load``.
        """
        cache = MindSpeedOpBuilder._loaded_ops
        if self.name in cache:
            return cache[self.name]

        # ``load`` here is the module-level torch.utils.cpp_extension.load,
        # not this method.
        op_module = load(name=self.name,
                         sources=self.get_absolute_paths(self.sources()),
                         extra_include_paths=self.get_absolute_paths(self.include_paths()),
                         extra_cflags=self.cxx_args(),
                         extra_ldflags=self.extra_ldflags(),
                         verbose=verbose)
        cache[self.name] = op_module

        return op_module