#!/usr/bin/env python3
# coding: utf-8
# Copyright 2024 Huawei Technologies Co., Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
# ===========================================================================
import collections
import os.path
import json
import platform
import shutil
import subprocess
import shlex
import re

from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common_utils import McuMultiProcess
from ansible.module_utils.safe_file_handler import SafeFileHandler

# 定义最大文件大小限制为100MB
MAX_FILE_SIZE = 100 * 1024 * 1024


def run_command(command):
    try:
        output = subprocess.check_output(shlex.split(command))
        if not isinstance(output, str):
            output = str(output, encoding='utf-8')
        return 0, output
    except (subprocess.CalledProcessError, OSError) as e:
        # 捕获子进程执行相关的具体异常
        return -1, "Command execution failed: {}".format(str(e))
    except UnicodeDecodeError as e:
        # 捕获与字符编码相关的具体异常
        return -1, "String decoding failed: {}".format(str(e))
    except Exception as e:
        # 捕获其他未预期的异常并记录详细信息
        return -1, "Unexpected error occurred during command execution: {}".format(str(e))


def info_to_dict(file_path):
    """
    load info file into json. e.g.
    A = B   => {"A": "B"}
    """
    info_dict = dict()
    if not os.path.isfile(file_path):
        return info_dict
    
    try:
        lines = SafeFileHandler.safe_read(file_path, mode='r').splitlines()
        for line in lines:
            split_line = line.split("=")
            if len(split_line) == 2:
                info_dict[split_line[0].strip()] = split_line[1].strip()
    except Exception as e:
        raise Exception("Failed to read file {}".format(file_path)) from e
    return info_dict


def get_value_on_prefix_ignore_case(_dict, _key, default=None):
    for key, value in _dict.items():
        if key.lower().startswith(_key.lower()):
            return value
    return default


def find_files(dir_path, file_name):
    targets = set()
    if not os.path.isdir(dir_path):
        return targets
    for root, _, files in os.walk(dir_path):
        if file_name in files:
            targets.add(os.path.realpath(os.path.join(root, file_name)))
    return targets


def getinfo_from_xml(file_path, root_path):
    if not os.path.exists(file_path):
        return {}
    
    try:
        lines = SafeFileHandler.safe_read(file_path, mode='r').splitlines()
    except Exception as e:
        raise Exception("Failed to read file {}".format(file_path)) from e
    
    arches = {'ARM': 'aarch64', 'x86': 'x86_64'}
    info_dict = {}
    keyword_pattern = re.compile('>(.*)<')
    for line in lines:
        keyword = ""
        if keyword_pattern.findall(line):
            keyword = keyword_pattern.findall(line)[0]
        if 'OutterName' in line and keyword:
            info_dict['name'] = keyword
        if 'ProcessorArchitecture' in line and keyword:
            arch = keyword
            info_dict['install_arch'] = arches.get(arch, arch)
        if 'Version' in line and keyword:
            info_dict['version'] = keyword
    info_dict['install_path'] = root_path
    return info_dict


def get_item_info(info_dict, item):
    item_info = {}
    for key, value in info_dict.items():
        if item + "_install_path" in key.lower():
            item_version_info_path = os.path.join(value, item, "version.info")
            version_info = info_to_dict(item_version_info_path).get("Version", "")
            item_info = {"name": item, "install_arch": platform.machine(),
                         "install_path": value, "version": version_info}
    return item_info


def collect_app_info():
    info_dict = info_to_dict("/etc/ascend_install.info")
    driver_info = get_item_info(info_dict, "driver")
    firmware_info = get_item_info(info_dict, "firmware")
    apps_info = [item for item in (firmware_info, driver_info) if item]
    root_path = '/usr/local/Ascend'
    if os.getuid() != 0:
        root_path = os.path.expanduser('~/Ascend')
    for item in ['nnrt', 'toolkit', 'nnae', 'toolbox']:
        _item = item
        if item == 'toolkit':
            _item = 'ascend-toolkit'
        item_info_dir = os.path.join(root_path, _item, "latest")
        target_paths = find_files(item_info_dir, "ascend_" + item + "_install.info")
        for info_path in target_paths:
            item_info = info_to_dict(info_path)
            info_dict = {"name": item,
                         'install_path': get_value_on_prefix_ignore_case(item_info, "path", os.path.dirname(info_path)),
                         'install_arch': get_value_on_prefix_ignore_case(item_info, "arch", platform.machine()),
                         'version': get_value_on_prefix_ignore_case(item_info, "version", "")}
            apps_info.append(info_dict)
    other_packages = ["mindie_image"]
    for item in other_packages:
        info_dict = {"name": item, "version": ""}
        if item == 'mindie_image':
            version = get_mindie_image_version()
            info_dict["version"] = version

        if info_dict["version"]:
            apps_info.append(info_dict)

    ret = {
        "progress": "1.0",
        "operation": "app display",
        "result": apps_info
    }
    return ret


def get_hccn_info():
    ret, outputs = run_command("npu-smi info -l")
    hccn_info = {}
    if not ret:
        npu_ids = []
        for line in outputs.split('\n'):
            if "NPU ID" in line:
                npu_ids.append(line.split(":")[-1].strip())
        for npu_id in npu_ids:
            hccn_lines = ""
            status, outputs = run_command("hccn_tool -i {} -ip -g".format(npu_id))
            if not status:
                hccn_lines += outputs.strip()

            status, outputs = run_command("hccn_tool -i {} -ip -inet6 -g".format(npu_id))
            if not status:
                hccn_lines += outputs.strip()

            _, outputs = run_command("hccn_tool -i {} -net_health -g".format(npu_id))
            hccn_lines += outputs.strip()

            if hccn_lines:
                hccn_info[npu_id] = hccn_lines
    return hccn_info


def get_npu_info(outputs):
    check_next_line = False
    npus = collections.defaultdict(lambda: 0)
    for line in outputs.splitlines():
        if "====" in line:
            check_next_line = True
            continue
        if check_next_line:
            words = line.split()
            if len(words) > 11:
                npus[words[2]] += 1
            check_next_line = False
    return npus


def get_mcu_version(module):
    """
    Get the mcu version dict.

    This function obtains npu id information through 'npu-smi info -l',
    and then concurrently executes 'npu-smi upgrade -b mcu -i NPU_ID' to query the mcu version
    and generate the return value version_dict

    Args:
        module (AnsibleModule): Ansible module instance.

    Returns:
        version_dict: The mcu version dict
                       eg: {'npu_id_1': '24.2.1', 'npu_id_2': '24.2.1','npu_id_4':'24.2.1'}
    """
    if not module.get_bin_path('npu-smi'):
        return {}
    rc, out, _ = module.run_command(shlex.split("npu-smi info -l"))
    if rc != 0:
        return {}
    npu_id_list = []
    for line in out.splitlines():
        if 'NPU ID' in line:
            npu_id = line.split(':')[-1].strip()
            if npu_id.isdigit():
                npu_id_list.append(int(npu_id))

    mcu_multiprocess = McuMultiProcess(npu_id_list, module)
    results = mcu_multiprocess.multi_run_command()
    version_dict = {}
    for device_id in sorted(results.keys()):
        result = results[device_id]
        if not result.get('success'):
            if 'This device does not support querying version' in result.get('error'):
                version_dict.update({'npu_id_' + str(device_id): 'not support'})
            else:
                version_dict.update({'npu_id_' + str(device_id): 'ERROR'})
        for line in result.get('output').splitlines():
            # Version  : 24.15.15
            if 'Version' in line:
                version = line.split(':')[-1].strip()
                version_dict.update({'npu_id_' + str(device_id): version})
    return version_dict


def get_mindie_image_version():
    _, output = run_command("docker ps --filter name=MindIE --format {{.Names}}")
    container_names = output.splitlines()
    if 'MindIE' not in container_names:
        return ""
    version_path = "/usr/local/Ascend/mindie/latest/mindie-service/version.info"
    command = "docker exec MindIE cat {}".format(version_path)
    rc, output = run_command(command)
    if rc != 0:
        return ""
    try:
        for line in output.splitlines():
            if "Ascend-mindie :" in line:
                return line.split(":")[1].strip()
    except Exception:
        return ""
    return ""


def main():
    module = AnsibleModule(argument_spec=dict(
        ip=dict(type="str", required=True),
        only_package=dict(type="bool", required=True),
    )
    )
    ip = module.params["ip"]
    only_package = module.params["only_package"]
    if os.path.exists(os.path.expanduser("~/smartkit/reports/")):
        shutil.rmtree(os.path.expanduser("~/smartkit/reports/"))
    if not os.path.exists(os.path.expanduser("~/smartkit/reports")):
        os.makedirs(os.path.expanduser("~/smartkit/reports"), mode=0o750)
    app_info = collect_app_info()

    if only_package:
        outputs = ["[ASCEND]{:<16} {:<16}".format("Package", "Version"), ]
        outputs.append('-' * len(outputs[-1]))
        for app in app_info.get("result", []):
            outputs.append("{:<16} {:<16}".format(app['name'], app['version']))
        return module.exit_json(changed=True, rc=0, msg="\n".join(outputs))
    display_json_path = os.path.expanduser("~/smartkit/display.json")
    try:
        SafeFileHandler.safe_write_json(display_json_path, app_info)
    except Exception as e:
        module.fail_json(msg="Failed to write file {}: {}".format(display_json_path, str(e)))

    local_info = {"packages": app_info.get("result", [])}

    _, outputs = run_command("npu-smi info")
    driver_info_path = os.path.expanduser("~/smartkit/reports/driver_info.txt")
    try:
        SafeFileHandler.safe_write(driver_info_path, outputs)
    except Exception as e:
        module.fail_json(msg="Failed to write driver info to {}: {}".format(driver_info_path, str(e)))

    npus = get_npu_info(outputs)
    if npus:
        local_info['npu'] = ",".join(["{}:{}".format(npu_type, num) for npu_type, num in npus.items()])
        local_info['mcu'] = ",".join(["{}:{}".format(key, value) for key, value in get_mcu_version(module).items()])
    if "910" in local_info.get('npu', ''):
        hccn_info = get_hccn_info()
        if hccn_info:
            local_info['hccn'] = [value for value in hccn_info.values()]
    
    local_info_json_path = os.path.expanduser("~/smartkit/reports/local_info.json")
    try:
        SafeFileHandler.safe_write_json(local_info_json_path, {ip: local_info})
    except Exception as e:
        module.fail_json(msg="Failed to write file {}: {}".format(local_info_json_path, str(e)))

    module.exit_json(changed=True, rc=0)


if __name__ == "__main__":
    main()