#!/usr/bin/env python3
# coding: utf-8
# Copyright 2024 Huawei Technologies Co., Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
# ===========================================================================
import glob
import re
import os.path
import shlex
import subprocess

from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils import common_info
from ansible.module_utils.common_utils import generate_table, McuMultiProcess
from ansible.module_utils.safe_file_handler import SafeFileHandler
from ansible.module_utils.utils import check_safe_command_arg

OK = "OK"
ERROR = "ERROR"
_CANN = "cann"
_SET_ENV_SCRIPT_PATH = "{}/{}/set_env.sh"
_SET_ENV_SCRIPT_CMD = ". {}/{}/set_env.sh"
_ASCEND_TOOLKIT = "ascend-toolkit"
_NNAE = "nnae"
_NNRT = "nnrt"
_NOT_INSTALLED = "not installed"

messages = []


def run_command(command, custom_env=None):
    try:
        env = os.environ.copy()
        if custom_env:
            env.update(custom_env)
        process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=env, executable="/bin/bash")
        stdout, stderr = process.communicate()
        if not isinstance(stdout, str):
            stdout = str(stdout, encoding='utf-8')
        if not isinstance(stderr, str):
            stderr = str(stderr, encoding='utf-8')
        return process.returncode == 0, stdout + stderr
    except (OSError, subprocess.SubprocessError) as e:
        # 捕获与子进程执行相关的具体异常
        return False, "Command execution failed: {}".format(str(e))
    except UnicodeDecodeError as e:
        # 捕获与字符编码相关的具体异常
        return False, "String decoding failed: {}".format(str(e))
    except Exception as e:
        # 捕获其他未预期的异常并记录详细信息
        return False, "Unexpected error occurred during command execution: {}".format(str(e))


def find_files(dir_path, file_name):
    targets = set()
    if not os.path.isdir(dir_path):
        return targets
    for root, _, files in os.walk(dir_path):
        if file_name in files:
            targets.add(os.path.realpath(os.path.join(root, file_name)))
    return targets


def info_to_dict(file_path):
    """
    load info file into json. e.g.
    A = B   => {"A": "B"}
    """
    info_dict = dict()
    if not os.path.isfile(file_path):
        return info_dict

    try:
        lines = SafeFileHandler.safe_read(file_path, mode='r').splitlines()
        for line in lines:
            split_line = line.split("=")
            if len(split_line) == 2:
                info_dict[split_line[0].strip()] = split_line[1].strip()
    except Exception as e:
        raise Exception("Failed to read file {}".format(file_path)) from e
    return info_dict


def get_npu_ids():
    ok, outputs = run_command("npu-smi info -m")
    if not ok:
        messages.append(outputs)
        return []
    npu_ids = []
    for line in outputs.splitlines():
        split_line = line.split()
        if not split_line:
            continue
        if '310' in split_line[-1] or '910' in split_line[-1] or '710' in split_line[-1]:
            npu_ids.append(split_line[0].strip())
    return npu_ids


def test_driver(cus_npu_info):
    if not os.path.exists("/usr/local/Ascend/driver/version.info"):
        return "not installed", ""
    ok, output = run_command(
        "npu-smi info",
        {"LD_LIBRARY_PATH": "/usr/local/Ascend/driver/lib64:/usr/local/Ascend/driver/lib64/common:"
                            "/usr/local/Ascend/driver/lib64/driver:"})
    if not ok:
        return ERROR, ""
    pattern = r"Version:\s*(\S+)"
    version = re.findall(pattern, output)[0].upper()

    if cus_npu_info == "300i-pro":
        checking_words = "300i"
    elif cus_npu_info == "300v-pro":
        checking_words = "300v"
    else:
        return OK, version

    for npu_id in get_npu_ids():
        ok, outputs = run_command("npu-smi info -t product -i {}".format(npu_id))
        if not ok:
            return ERROR, ""
        if checking_words not in outputs.lower():
            messages.append("you are installing driver of {} on hardware of {}".format(
                cus_npu_info, "300i-pro" if cus_npu_info == "300v-pro" else "300v-pro"))
            return ERROR, ""
    return OK, version


def test_ubengine():
    if not os.path.exists("/usr/local/Ascend/driver/ube_mgmt/"):
        return "not installed", ""
    ok, _ = run_command(
        "hccn_tool -g -dev_info -i 0",
        {"LD_LIBRARY_PATH": "/usr/local/Ascend/driver/lib64:/usr/local/Ascend/driver/lib64/common:" 
                            "/usr/local/Ascend/driver/lib64/driver:"})
    if not ok:
        return ERROR, ""
    version_info = "/usr/local/Ascend/driver/version.info"
    if os.path.exists("/usr/local/Ascend/driver/ube_mgmt/version.info"):
        version_info = "/usr/local/Ascend/driver/ube_mgmt/version.info"

    # Read mami_version from version.info file
    mami_version = ""
    if os.path.exists(version_info):
        with open(version_info, 'r') as f:
            for line in f:
                if '=' in line:
                    mami_version = line.strip().split('=')[1]
                    break
    if not mami_version:
        return "not installed", ""
    return OK, mami_version


def test_mcu(module):
    """
    Test mcu and get the mcu version dict.

    This function obtains npu id information through 'npu-smi info -l',
    and then concurrently executes 'npu-smi upgrade -b mcu -i NPU_ID' to query the mcu version
    and generate the return value version_dict

    Args:
        module (AnsibleModule): Ansible module instance.

    Returns:
        version_dict: The mcu version dict
                       eg: {'npu_id_1': '24.2.1', 'npu_id_2': '24.2.1','npu_id_4':'24.2.1'}
    """
    if not module.get_bin_path('npu-smi'):
        return {}
    rc, out, _ = module.run_command(shlex.split("npu-smi info -l"))
    """     
    eg:
            Total Count                    : 1
    
            NPU ID                         : 8
            Product Name                   : IT21PDDA01
            Serial Number                  : 033VNY10MB000071
            Chip Count                     : 1
        A3:
            Total Count                    : 8
    
            NPU ID                         : 0
            Chip Count                     : 2
            ...
    """
    if rc != 0:
        return {}
    npu_id_list = []
    for line in out.splitlines():
        if 'NPU ID' in line:
            npu_id = line.split(':')[-1].strip()
            if npu_id.isdigit():
                npu_id_list.append(int(npu_id))

    mcu_multiprocess = McuMultiProcess(npu_id_list, module)
    results = mcu_multiprocess.multi_run_command()
    version_dict = {}
    for device_id in sorted(results.keys()):
        result = results[device_id]
        if not result.get('success'):
            if 'This device does not support querying version' in result.get('error'):
                version_dict.update({'npu_id_' + str(device_id): 'not support'})
            else:
                version_dict.update({'npu_id_' + str(device_id): 'ERROR'})
        for line in result.get('output').splitlines():
            # Version  : 24.15.15
            if 'Version' in line:
                version = line.split(':')[-1].strip()
                version_dict.update({'npu_id_' + str(device_id): version})
    return version_dict


def test_firmware():
    if not os.path.exists("/usr/local/Ascend/firmware/version.info"):
        return "not installed", ""

    if not os.path.exists("/usr/local/Ascend/driver/tools/upgrade-tool"):
        return "not installed", ""
    ok, output = run_command(
        "/usr/local/Ascend/driver/tools/upgrade-tool --device_index -1 --system_version",
        {"LD_LIBRARY_PATH": "/usr/local/Ascend/driver/lib64:/usr/local/Ascend/driver/lib64/common:"
                            "/usr/local/Ascend/driver/lib64/driver:"})
    if ok and "succeed" in output:
        file_path = '/usr/local/Ascend/firmware/version.info'
        try:
            content = SafeFileHandler.safe_read(file_path, mode='r')
            pattern = r"Version=(\S+)"
            match = re.search(pattern, content)
            version = match.group(1) if match else ""
            return OK, version
        except Exception as e:
            raise Exception("Failed to read firmware version file {}".format(file_path)) from e
    return ERROR, ""


def test_toolbox():
    ascend_install_path = common_info.get_ascend_install_path(os.getuid(), os.path.expanduser("~"))
    bin_path = os.path.join(ascend_install_path, "toolbox/latest/Ascend-DMI/bin/ascend-dmi")
    if not os.path.exists(bin_path):
        return "not installed", ""
    commands = [". {}/toolbox/set_env.sh".format(ascend_install_path), "ascend-dmi -v"]
    ok, output = run_command(" && ".join(commands))
    if ok:
        return OK, get_cann_version('toolbox')
    return ERROR, ""


def set_cann_env():
    """
    description: 设置环境变量,优先级: toolkit > nnae
    """
    ascend_install_path = common_info.get_ascend_install_path(os.getuid(), os.path.expanduser("~"))
    commands = []
    if os.path.exists("{}/ascend-toolkit/set_env.sh".format(ascend_install_path)):
        commands.append(". {}/ascend-toolkit/set_env.sh".format(ascend_install_path))
    if os.path.exists("{}/cann/set_env.sh".format(ascend_install_path)):
        commands.append(". {}/cann/set_env.sh".format(ascend_install_path))
    if os.path.exists("{}/nnae/set_env.sh".format(ascend_install_path)):
        commands.append(". {}/nnae/set_env.sh".format(ascend_install_path))
    return commands


def test_python_package(package_name, python_version):
    local_path = common_info.get_local_path(os.getuid(), os.path.expanduser("~"))
    ascend_install_path = common_info.get_ascend_install_path(os.getuid(), os.path.expanduser("~"))
    paths = os.environ.get("PATH", "")
    paths = "{}/{}/bin:".format(local_path, python_version) + paths
    ld_paths = "{}/{}/lib:".format(local_path, python_version)
    ok, output = run_command("python3 -m pip list | grep {} | grep -v torch-mindio".format(package_name), custom_env={
        "PATH": paths, "LD_LIBRARY_PATH": ld_paths
    })
    if not ok:
        return "not installed", ""
    # 新版本在/usr/local/Ascend/cann/tools/ccec_compiler
    paths = "{}/ascend-toolkit/latest/atc/ccec_compiler/bin/:".format(ascend_install_path) + paths
    ld_paths = ("{}/gcc7.3.0/lib64;{}/{}/lib/{}/site-packages/{}/lib:{}/add-ons/:".
                format(local_path, local_path, python_version, python_version[:9],
                       package_name, ascend_install_path) + ld_paths)
    commands = set_cann_env()
    if package_name == "torch":
        if not commands:
            return ERROR, ""
        commands.append('python3 -c "import torch; import torch_npu; a = torch.randn(3, 4).npu(); print(a + a)"')
    if package_name == "mindspore":
        commands.append('python3 -c "import mindspore;mindspore.set_context(device_target=\'Ascend\');\
        mindspore.run_check()"')
    ok, output = run_command(" && ".join(commands), custom_env={
        "PATH": paths, "LD_LIBRARY_PATH": ld_paths
    })
    if not ok:
        return ERROR, ""
    if package_name == "torch":
        package_name = "torch_npu"
    commands = set_cann_env()
    commands.append('python3 -c "import {}; print({}.__version__)"'.format(package_name, package_name))
    ok, output = run_command(" && ".join(commands), custom_env={"PATH": paths, "LD_LIBRARY_PATH": ld_paths})
    version = output.split('\n')[0]
    return OK, version


def test_tensorflow(python_version):
    local_path = common_info.get_local_path(os.getuid(), os.path.expanduser("~"))
    ascend_install_path = common_info.get_ascend_install_path(os.getuid(), os.path.expanduser("~"))
    paths = os.environ.get("PATH", "")
    paths = "{}/{}/bin:".format(local_path, python_version) + paths
    ld_paths = "{}/{}/lib:{}/add-ons/:".format(local_path, python_version, ascend_install_path)
    ok, output = run_command('python3 -m pip list | grep -E "tensorflow |tensorflow-cpu"', custom_env={
        "PATH": paths, "LD_LIBRARY_PATH": ld_paths
    })
    if not ok:
        return "not installed", ""
    if "1.15.0" in output:
        version = "1.15.0"
    elif "2.6.5" in output:
        version = "2.6.5"
    else:
        return ERROR, ""

    commands = []
    if os.path.exists("{}/nnae/set_env.sh".format(ascend_install_path)):
        commands.append(". {}/nnae/set_env.sh".format(ascend_install_path))
    if os.path.exists("{}/ascend-toolkit/set_env.sh".format(ascend_install_path)):
        commands.append(". {}/ascend-toolkit/set_env.sh".format(ascend_install_path))
    if os.path.exists("{}/cann/set_env.sh".format(ascend_install_path)):
        commands.append(". {}/cann/set_env.sh".format(ascend_install_path))
    if version == "1.15.0":
        commands.append('python3 -c "import npu_bridge.estimator; import npu_bridge.hccl;'
                        ' from tensorflow.core.protobuf import rewriter_config_pb2"')
    if version == "2.6.5":
        commands.append('python3 -c "import npu_device; from tensorflow.core.protobuf import rewriter_config_pb2"')

    ok, output = run_command(" && ".join(commands), custom_env={
        "PATH": paths, "LD_LIBRARY_PATH": ld_paths
    })
    messages.append(output)
    if not ok:
        return ERROR, ""

    return OK, version


def get_value_on_prefix_ignore_case(_dict, _key, default=None):
    for key, value in _dict.items():
        if key.lower().startswith(_key.lower()):
            return value
    return default


def get_cann_version(item):
    root_path = '/usr/local/Ascend'
    _item = item
    if item == 'ascend-toolkit':
        _item = 'toolkit'
    item_info_dir = os.path.join(root_path, item, "latest")
    target_paths = find_files(item_info_dir, "ascend_" + _item + "_install.info")
    version = ""
    for info_path in target_paths:
        item_info = info_to_dict(info_path)
        version = get_value_on_prefix_ignore_case(item_info, "version", "")
    return version


def _get_new_cann_version(item):
    root_path = '/usr/local/Ascend'
    if item not in ['ascend-toolkit', 'nnrt', 'nnae']:
        return ""
    if not os.path.exists(os.path.join(root_path, _CANN)):
        return ""
    if "not installed" in _check_version_in_whitelist(item):
        return ""
    item_info_dir = os.path.join(root_path, _CANN)
    _item = 'toolkit'
    target_paths = find_files(item_info_dir, "ascend_" + _item + "_install.info")
    version = ""

    for info_path in target_paths:
        item_info = info_to_dict(info_path)
        version = get_value_on_prefix_ignore_case(item_info, "version", "")
    return version


def _check_version_in_whitelist(package_name):
    """
    为了兼容新版本的cann包按照toolkit>nnae>nnrt的包含关系显示版本信息
    优先显示新版本的信息
    | 测试参数 | `whitelist=nnrt` | `whitelist=nnae` | `whitelist=no` |
    |----------|------------------|------------------|----------------|
    | `--test=nnrt` | 显示nnrt的版本信息 | 显示nnrt的版本信息 | 显示nnrt的版本信息 |
    | `--test=nnae` | nnae not installed | 显示nnae的版本信息 | 显示nnae的版本信息 |
    | `--test=toolkit` | toolkit not installed | toolkit not installed | 显示toolkit的版本信息 |
    """
    whitelist_conf = _get_whitelist_conf()

    whitelist_dict = {"no": {_ASCEND_TOOLKIT: "",
                             _NNAE: "",
                             _NNRT: ""},
                      _NNAE: {_ASCEND_TOOLKIT: _NOT_INSTALLED,
                               _NNAE: "",
                               _NNRT: ""},
                      _NNRT: {_ASCEND_TOOLKIT: _NOT_INSTALLED,
                               _NNAE: _NOT_INSTALLED,
                               _NNRT: ""},
                      }
    version = whitelist_dict.get(whitelist_conf, {}).get(package_name, "")
    if _NOT_INSTALLED in version:
        return _NOT_INSTALLED
    return ""


def _get_whitelist_conf():
    whitelist_conf = ""
    install_conf_pattern = "/usr/local/Ascend/cann/*/install.conf"
    install_conf_list = glob.glob(install_conf_pattern, recursive=True)
    if not install_conf_list:
        return ""
    install_conf = install_conf_list[0]
    if os.path.exists(install_conf):
        content = SafeFileHandler.safe_read(install_conf).splitlines()
        for line in content:
            if "whitelist" in line.lower():
                whitelist_conf = line.split("=")[-1]
    return whitelist_conf


def test_cann_packages(package_name, python_version):
    ascend_install_path = common_info.get_ascend_install_path(os.getuid(), os.path.expanduser("~"))
    cann_path = os.path.join(ascend_install_path, "{}/latest".format(package_name))

    if not os.path.exists(os.path.join(ascend_install_path, _CANN)) and not os.path.exists(cann_path):
        return _NOT_INSTALLED, ""

    commands = []
    check_safe_command_arg(ascend_install_path)
    check_safe_command_arg(package_name)
    if package_name in ["ascend-toolkit", "nnae", "nnrt"]:
        # 新版本的cann会生成一个软连接ascend-toolkit路径的软连接
        if _NOT_INSTALLED in _check_version_in_whitelist(package_name) and (
                not os.path.exists(cann_path) or package_name == "ascend-toolkit"):
            return _NOT_INSTALLED, ""

        if os.path.exists(_SET_ENV_SCRIPT_PATH.format(ascend_install_path, _CANN)):
            commands.append(_SET_ENV_SCRIPT_CMD.format(ascend_install_path, _CANN))
        elif os.path.exists(_SET_ENV_SCRIPT_PATH.format(ascend_install_path, package_name)):
            commands.append(_SET_ENV_SCRIPT_CMD.format(ascend_install_path, package_name))
    elif os.path.exists(_SET_ENV_SCRIPT_PATH.format(ascend_install_path, package_name)):
        commands.append(_SET_ENV_SCRIPT_CMD.format(ascend_install_path, package_name))
    else:
        return ERROR, ""
    commands.append('python3 -c "import acl"')
    paths = os.environ.get("PATH", "")
    local_path = common_info.get_local_path(os.getuid(), os.path.expanduser("~"))
    paths = "{}/{}/bin:".format(local_path, python_version) + paths
    ld_paths = "{}/{}/lib:".format(local_path, python_version)
    ok, output = run_command(" && ".join(commands), custom_env={
        "PATH": paths, "LD_LIBRARY_PATH": ld_paths
    })
    if not ok:
        return ERROR, ""
    return OK, _get_new_cann_version(package_name) or get_cann_version(package_name)


def test_fault_diag():
    try:
        env = os.environ.copy()
        process = subprocess.Popen("ascend-fd version", stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True,
                                   env=env, executable="/bin/bash")
        stdout, stderr = process.communicate()
        if not isinstance(stdout, str):
            stdout = str(stdout, encoding='utf-8')
        if not isinstance(stderr, str):
            stderr = str(stderr, encoding='utf-8')
        if process.returncode == 0:
            pattern = r"ascend-fd v(\d+[+.]\d+[+.]\w+)"
            match = re.search(pattern, stdout + stderr)
            version = match.group(1) if match else ""
            return OK, version
        elif process.returncode == 127:
            return "not installed", ""
    except Exception:
        return ERROR, ""
    return ERROR, ""


def test_mindie_image():
    _, output = run_command("docker ps --filter name=MindIE --format {{.Names}}")
    container_names = output.splitlines()
    if 'MindIE' not in container_names:
        return "not installed", ""
    command = ["docker exec MindIE "]
    version_path = "/usr/local/Ascend/mindie/latest/mindie-service/version.info"
    command.append("cat {}".format(version_path))
    ok, output = run_command(" ".join(command))
    if not ok:
        return ERROR, ""
    for line in output.splitlines():
        if "Ascend-mindie :" in line:
            return OK, line.split(":")[1].strip()
    return ERROR, ""


def main():
    module = AnsibleModule(argument_spec=dict(
        ansible_run_tags=dict(type="list", required=True),
        cus_npu_info=dict(type="str", required=True),
        ip=dict(type="str", required=True),
        python_version=dict(type="str", required=True)
    )
    )
    ansible_run_tags = set(module.params["ansible_run_tags"])
    cus_npu_info = module.params.get("cus_npu_info", "")
    python_version = module.params.get("python_version", "")
    if 'whole' in ansible_run_tags:
        ansible_run_tags = ["driver", "firmware", "toolbox", "mindspore",
                            "pytorch", "tensorflow", "nnae", "nnrt", 
                            "toolkit", "fault-diag", "mindie_image", "mcu", "ubengine"]
    result = {}
    if "driver" in ansible_run_tags:
        result["driver"] = test_driver(cus_npu_info)
    if "ubengine" in ansible_run_tags:
        result["ubengine"] = test_ubengine()
    if "firmware" in ansible_run_tags:
        result["firmware"] = test_firmware()
    if 'mcu' in ansible_run_tags:
        result["mcu"] = test_mcu(module)
    if "toolbox" in ansible_run_tags:
        result["toolbox"] = test_toolbox()
    if "mindspore" in ansible_run_tags:
        result["mindspore"] = test_python_package("mindspore", python_version)
    if "pytorch" in ansible_run_tags:
        result["pytorch"] = test_python_package("torch", python_version)
    if "tensorflow" in ansible_run_tags:
        result["tensorflow"] = test_tensorflow(python_version)
    if "nnae" in ansible_run_tags:
        result["nnae"] = test_cann_packages("nnae", python_version)
    if "nnrt" in ansible_run_tags:
        result["nnrt"] = test_cann_packages("nnrt", python_version)
    if "toolkit" in ansible_run_tags:
        result["toolkit"] = test_cann_packages("ascend-toolkit", python_version)
    if "fault-diag" in ansible_run_tags:
        result["fault-diag"] = test_fault_diag()
    if "mindie_image" in ansible_run_tags:
        result["mindie_image"] = test_mindie_image()

    formatted_result = {module.params.get("ip", ""): result}

    module.exit_json(changed=True, rc=0, msg="\n".join(messages), result=formatted_result)


if __name__ == "__main__":
    main()