#!/usr/bin/env python3
# coding: utf-8
# Copyright 2024 Huawei Technologies Co., Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
# ===========================================================================
import os

from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.check_output_manager import check_event, wait_for_finish
from ansible.module_utils.common_info import get_os_and_arch, get_npu_info, need_skip_sys_package, UNKNOWN
from ansible.module_utils.common_utils import run_command
from ansible.module_utils.compatibility_config import EOL_CARD, EOL_MODEL, CARD_OS_COMPONENTS_MAP, \
    MODEL_TAGS_NOT_SUPPORT
from ansible.module_utils.check_library_utils.dl_checks import DLCheck
from ansible.module_utils.check_library_utils.frame_checks import FrameCheck
from ansible.module_utils.check_library_utils.npu_checks import NPUCheck
from ansible.module_utils.check_library_utils.check_user import UserCheck
from ansible.module_utils.check_library_utils.localhost_check import LocalhostCheck
from ansible.module_utils.check_library_utils.check_k8s_device_ip import K8sDeviceIpCheck
from ansible.module_utils.check_library_utils.cann_checks import CANNCheck
from ansible.module_utils.check_utils import CheckUtil as util
from ansible.module_utils.check_library_utils.mindie_image_check import MindIEImageCheck
from ansible.module_utils.check_library_utils.deepseek_pd_check import DeepseekDpCheck
from ansible.module_utils.check_library_utils.deepseek_cntr_check import DeepseekCntrCheck
from ansible.module_utils.check_output_manager import CHECK_OUTPUT_MANAGER

EOL_MSG = "[ASCEND] The lifecycle of {} is over and is no longer supported"
SUPPORT_MSG = "[ASCEND] {} has no support for {} on this device"

MAX_CIRCLES = 8
DL_TAGS = {'dl', 'ascend-docker-runtime', 'clusterd', 'volcano', 'ascend-operator',
           'ascend-device-plugin', 'noded', 'npu-exporter', 'resilience-controller'}


class CompatibilityCheck(object):

    _WORKER_NODES = ["worker"]

    def __init__(self):
        self.module = AnsibleModule(
            argument_spec=dict(
                tags=dict(type='list'),
                python_version=dict(type='str', required=False),
                force_upgrade_npu=dict(type="bool", required=False),
                group_names=dict(type="list"),
                hosts_name=dict(type="str", required=False),
                master_groups=dict(type="list", required=False),
                worker_groups=dict(type="list", required=False),
                host_info=dict(type="dict", required=False),
                current_hostname=dict(type="str", required=False),
                ascend_deployer_work_dir=dict(type='str', required=True),
                npu_num=dict(type='int', required=False),
                packages=dict(type='dict', required=True),
                node_name=dict(type="str", required=True),
                resource_dir=dict(type="str", required=True),
                cluster_info=dict(type='dict', required=True),
                davinci_devices=dict(type="list", required=False),
                weights_path=dict(type="str", required=False),
                ip_address=dict(type="str", required=True),
                mtos_sys_dir=dict(typy="str", required=True),
                use_k8s_version=dict(type="str", required=False),
                master0_arch=dict(type="str", required=True),
                worker0_arch=dict(typy="str", required=True),
                master_arch=dict(typy="str", required=True),
                groups=dict(type="dict", required=True),
                other_build_image_arch=dict(typy="str", required=True),
                model_name=dict(type="str", required=False),
                weight_mount_path=dict(type="str", required=False),
                model_weight_path=dict(type="str", required=False),
                mindie_image_name=dict(type="str", required=False),
                mindie_image_file=dict(type="str", required=False),
                expert_map_file=dict(type="str", required=False),
                job_id=dict(type="str", required=False),
                p_instances_num=dict(type="int", required=False),
                d_instances_num=dict(type="int", required=False),
                single_p_instance_pod_num=dict(type="int", required=False),
                single_d_instance_pod_num=dict(type="int", required=False),
                max_seq_len=dict(type="int", required=False),
                mindie_host_log_path=dict(type="str", required=False),
                tls_config=dict(type="dict", required=False),
                npu_info=dict(type="dict", required=False),
                cntr_mnt_path=dict(type="str", required=False),
                worker_num=dict(type="int", required=False),
                master_ip=dict(type="str", required=False),
                container_runtime_type=dict(type='dict', required=False),
                HARDWARE_TYPE=dict(type="str", required=False, default="")
            ))
        self.tags = list(filter(bool, self.module.params.get('tags', [])))
        if 'all' in self.tags:
            self.tags.remove('all')
        # "group_names": [ "worker"]
        self.group_names = self.module.params.get("group_names")
        self.worker_groups = self.module.params.get("worker_groups")
        # "inventory_hostname": "localhost" "inventory_hostname": "192.168.1.1"
        self.current_hostname = self.module.params.get("current_hostname")
        # "hosts_name": "master,worker"
        self.hosts_name = self.module.params.get("hosts_name").split(",")
        self.packages = self.module.params.get("packages")
        self.npu_info = get_npu_info(hardware_type=self.module.params.get("HARDWARE_TYPE", ""))
        self.os_and_arch = get_os_and_arch()
        self.card = util.get_card(hardware_type=self.module.params.get("HARDWARE_TYPE", ""))
        self.error_messages = []
        self.dl_check = DLCheck(self.module, self.error_messages)
        self.frame_checks = FrameCheck(self.module, self.npu_info, self.error_messages)
        self.npu_check = NPUCheck(self.module, self.error_messages)
        self.user_check = UserCheck(self.module, self.error_messages)
        self.k8s_device_ip_check = K8sDeviceIpCheck(self.module, self.error_messages)
        self.cann_check = CANNCheck(self.module, self.npu_info, self.error_messages)
        self.mindie_image_check = MindIEImageCheck(self.module, self.error_messages)
        self.deepseek_dp_check = DeepseekDpCheck(self.module, self.error_messages)
        self.deepseek_cntr_check = DeepseekCntrCheck(self.module, self.error_messages)
        self.localhost_check = LocalhostCheck(self.module, self.error_messages)
        self.tags_config = self.init_config()
        self.resources_dir = os.path.join(self.module.params.get("ascend_deployer_work_dir"), "resources")
        self.resource_dir = os.path.expanduser(self.module.params["resource_dir"])
        self.ip_address = self.module.params.get("ip_address")
        self.os_components_dict = dict()

    def base_check(self):
        if self.current_hostname in self.worker_groups:
            self.check_os_and_card_compatibility()

    def feature_check(self):
        all_checks = set()
        for tag in self.tags:
            tag_config = self.tags_config.get(tag)
            if not tag_config:
                continue
            if not set(self.group_names).intersection(set(tag_config.get("nodes"))):
                continue
            checks = self.tags_config.get(tag).get("checks")
            if tag in DL_TAGS and self.current_hostname not in self.worker_groups:
                checks.remove(self.npu_check.check_npu)
            if tag == 'ascend-device-plugin' and set(self.tags).intersection({'dl', 'ascend-docker-runtime'}):
                checks.remove(self.dl_check.check_docker_runtime)
            if tag in {'ascend-operator', 'clusterd'} and set(self.tags).intersection({'dl', 'volcano'}):
                checks.remove(self.dl_check.check_volcano)
            all_checks.update(set(checks))

        for check_handler in all_checks:
            check_handler()

    def _localhost_check(self):
        if "localhost" in self.current_hostname:
            if 'MTOS' in self.os_and_arch and 'sys_pkg' in self.tags:
                self.localhost_check.check_mtos_kernel_devel_pkg()
            if set(self.tags).intersection(DL_TAGS):
                self.localhost_check.check_dl_executor()
                self.localhost_check.check_dl_diff_arch()

    def run(self):
        self.base_check()
        self._localhost_check()
        self.user_check.check_root()
        self.user_check.check_user_privilege_escalation()
        self.feature_check()
        wait_for_finish()
        check_result = {self.ip_address: CHECK_OUTPUT_MANAGER.generate_check_output()}
        if self.error_messages:
            self.error_messages.append("For check details, please see ~/.ascend_deployer/"
                                       "deploy_info/check_res_output.json.")
            return self.module.exit_json(stdout='\n'.join(self.error_messages), fail_flag=True,
                                         check_result_json=check_result)
        self.module.exit_json(stdout='\n'.join(self.error_messages), changed=True, rc=0, fail_flag=False,
                              check_result_json=check_result)

    @check_event
    def check_os_and_card_compatibility(self):
        self.check_card()
        self.check_model()
        if self.os_and_arch in self.os_components_dict:
            self.check_components()

    def check_card(self):
        card = self.npu_info.get('card')
        if card == UNKNOWN:
            # No card or could not recognise npu card
            return
        if card in EOL_CARD:
            util.record_error(EOL_MSG.format(card), self.error_messages)
            return
        self.os_components_dict = CARD_OS_COMPONENTS_MAP.get(card, dict())
        if not self.os_components_dict:
            util.record_error(
                "Check card failed: cannot find card support dict, ascend-deployer not support on this device",
                self.error_messages)
            return
        if self.os_and_arch not in self.os_components_dict:
            util.record_error("Check device supports os failed: {} is not supported on this device"
                              .format(self.os_and_arch), self.error_messages)

    def check_model(self):
        model = self.npu_info.get('model')
        if model in EOL_MODEL:
            util.record_error(EOL_MSG.format(model), self.error_messages)
        unsupported_tags = MODEL_TAGS_NOT_SUPPORT.get(model, [])
        for tag in self.tags:
            if tag in unsupported_tags:
                util.record_error("Check model failed: " + SUPPORT_MSG.format(tag, model), self.error_messages)

    def check_components(self):
        infer_devices = ('A300i-pro', 'A300i-duo', 'A200i-a2', 'Atlas 800I A2')
        card = self.npu_info.get('card')
        not_support_components = []
        supported_tags = self.os_components_dict.get(self.os_and_arch, [])
        for tag in self.tags:
            if tag not in supported_tags or (card in infer_devices and 'mindspore' in tag):
                # infer devices do not support mindspore anymore.
                not_support_components.append(tag)
        if not_support_components:
            util.record_error(
                "Check os failed: " + SUPPORT_MSG.format(','.join(not_support_components), self.os_and_arch),
                self.error_messages)

    def filter_cann_check(self):
        plugins = ["nnrt", "nnae", "toolbox", "toolkit", "kernels"]
        for plugin in plugins:
            if self.packages.get(plugin):
                return True

        return False


    def init_config(self):
        return {
            'resilience-controller': {
                "checks": [
                    self.npu_check.check_npu,
                    self.dl_check.check_dl_basic,
                    self.dl_check.check_dns,
                    self.dl_check.check_resilience_controller_support,
                    self.k8s_device_ip_check.k8s_device_ip_check],
                "nodes": self.hosts_name
            },
            'npu': {"checks": [self.npu_check.check_npu], "nodes": self._WORKER_NODES},
            'mcu': {"checks": [self.npu_check.check_npu, self.npu_check.check_mcu], "nodes": self.hosts_name},
            'firmware': {"checks": [self.npu_check.check_npu_health, self.npu_check.check_firmware],
                         "nodes": self._WORKER_NODES},
            'driver': {"checks": [self.npu_check.check_npu_health, self.npu_check.check_driver],
                       "nodes": self._WORKER_NODES},
            'pytorch_dev': {"checks": [self.npu_check.check_npu, self.cann_check.check_kernels,
                                       self.frame_checks.check_torch],
                            "nodes": self.hosts_name},
            'pytorch_run': {"checks": [self.npu_check.check_npu, self.cann_check.check_kernels,
                                       self.frame_checks.check_torch],
                            "nodes": self.hosts_name},
            'tensorflow_dev': {"checks": [self.npu_check.check_npu, self.frame_checks.check_tensorflow,
                                          self.cann_check.check_kernels],
                               "nodes": self.hosts_name},
            'tensorflow_run': {"checks": [self.npu_check.check_npu, self.frame_checks.check_tensorflow,
                                          self.cann_check.check_kernels],
                               "nodes": self.hosts_name},
            'npu-exporter': {"checks": [self.npu_check.check_npu, self.dl_check.check_dl_basic, self.dl_check.check_dns,
                                        self.mindie_image_check.check_npu_installed,
                                        self.k8s_device_ip_check.k8s_device_ip_check],
                             "nodes": self._WORKER_NODES},
            'noded': {"checks": [self.npu_check.check_npu, self.dl_check.check_dl_basic, self.dl_check.check_dns,
                                 self.mindie_image_check.check_npu_installed,
                                 self.k8s_device_ip_check.k8s_device_ip_check],
                      "nodes": self._WORKER_NODES},
            'volcano': {"checks": [self.npu_check.check_npu, self.dl_check.check_dl_basic, self.dl_check.check_dns,
                                   self.k8s_device_ip_check.k8s_device_ip_check],
                        "nodes": self.hosts_name},
            'ascend-operator': {"checks": [self.npu_check.check_npu, self.dl_check.check_dl_basic,
                                           self.dl_check.check_dns, self.dl_check.check_volcano,
                                           self.k8s_device_ip_check.k8s_device_ip_check],
                                "nodes": ["master"]},
            'clusterd': {"checks": [self.npu_check.check_npu, self.dl_check.check_dl_basic, self.dl_check.check_dns,
                                    self.dl_check.check_volcano, self.k8s_device_ip_check.k8s_device_ip_check],
                         "nodes": ["master"]},
            'kernels': {"checks": [self.cann_check.check_kernels, self.cann_check.check_cann_basic],
                        "nodes": self.hosts_name},
            'dl': {"checks": [self.npu_check.check_npu, self.dl_check.check_dl_basic, self.dl_check.check_dns,
                              self.k8s_device_ip_check.k8s_device_ip_check],
                   "nodes": self.hosts_name},
            'toolkit': {"checks": [self.cann_check.check_cann_basic], "nodes": self.hosts_name},
            'nnrt': {"checks": [self.cann_check.check_cann_basic], "nodes": self.hosts_name},
            'nnae': {"checks": [self.cann_check.check_cann_basic], "nodes": self.hosts_name},
            'toolbox': {"checks": [self.cann_check.check_cann_basic], "nodes": self.hosts_name},
            'mindspore': {"checks": [self.frame_checks.check_mindspore], "nodes": self.hosts_name},
            'mindspore_scene': {"checks": [self.npu_check.check_npu, self.frame_checks.check_mindspore,
                                           self.cann_check.check_cann_basic, self.cann_check.check_kernels],
                                "nodes": self.hosts_name},
            'ascend-device-plugin': {"checks": [self.npu_check.check_npu, self.dl_check.check_dl_basic,
                                                self.dl_check.check_dns, self.mindie_image_check.check_npu_installed,
                                                self.dl_check.check_docker_runtime,
                                                self.k8s_device_ip_check.k8s_device_ip_check],
                                     "nodes": self._WORKER_NODES},
            'ascend-docker-runtime': {"checks": [self.npu_check.check_npu, self.dl_check.check_dns,
                                                 self.mindie_image_check.check_npu_installed],
                                      "nodes": self._WORKER_NODES},
            'offline_dev': {"checks": [self.npu_check.check_npu, self.cann_check.check_kernels],
                            "nodes": self.hosts_name},
            'offline_run': {"checks": [self.npu_check.check_npu], "nodes": self.hosts_name},
            'tensorflow': {"checks": [self.frame_checks.check_tensorflow], "nodes": self.hosts_name},
            'pytorch': {"checks": [self.frame_checks.check_torch], "nodes": self.hosts_name},
            'mindie_image': {"checks": [self.mindie_image_check.check_mindie_image], "nodes": self.hosts_name},
            'deepseek_pd': {"checks": [self.deepseek_dp_check.check_deepseek_pd], "nodes": self.hosts_name},
            'deepseek_cntr': {
                "checks": [self.deepseek_cntr_check.check_deepseek_cntr, self.mindie_image_check.check_npu_installed],
                "nodes": self.hosts_name},
            'ubengine': {"checks": [self.npu_check.check_npu, self.mindie_image_check.check_npu_installed],
                         "nodes": self.hosts_name},
        }


if __name__ == '__main__':
    CompatibilityCheck().run()