#!/usr/bin/env python3
# coding: utf-8
# Copyright 2024 Huawei Technologies Co., Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
# ===========================================================================
import os
import grp
import pwd
import csv
import json
from collections import defaultdict
from ansible.module_utils.check_output_manager import check_event
from ansible.module_utils.check_utils import CheckUtil as util
from ansible.module_utils.safe_file_handler import SafeFileHandler

GB = 1024 * 1024 * 1024


class DLCheck:
    user = 'hwMindX'
    user_id = 9000
    group = 'hwMindX'
    group_id = 9000
    k8s_extra_space = 6 * GB
    dl_extra_space = 12 * GB
    throttle = 0.70

    def __init__(self, module, error_messages):
        self.module = module
        self.master_group = module.params.get("master_groups")
        self.cluster_info = module.params.get("cluster_info")
        self.ip = module.params.get('current_hostname')
        self.node_name = module.params.get("node_name")
        self.error_messages = error_messages
        self.master0_arch = module.params.get("master0_arch")
        self.master_arch = module.params.get("master_arch")
        self.master_groups = self.module.params['master_groups']
        self.worker_groups = self.module.params['worker_groups']
        self.current_hostname = self.module.params['current_hostname']

    def _get_k8s_pods(self):
        _, out, _ = self.module.run_command("kubectl get pods -A", check_rc=True)
        dl_pods_info = defaultdict(list)

        # dl中的组件
        match_list = ['device-plugin', 'volcano', 'noded', 'clusterd',
                      'ascend-operator', 'npu-exporter', 'resilience-controller']
        # 解析回显信息
        lines = out.splitlines()
        reader = csv.DictReader(lines, delimiter=' ', skipinitialspace=True)
        for row in reader:
            name = row['NAME']
            for match in match_list:
                if match in name:
                    dl_pods_info[match].append(row)
        return dl_pods_info

    @check_event
    def check_docker_runtime(self):
        # step1.查看default runtime 字段是否为ascend
        _, out, _ = self.module.run_command('docker info', check_rc=True)
        if 'Default Runtime: ascend' not in out:
            util.record_error("Please install docker runtime firstly", self.error_messages)
            return
        # step2.查看daemon.json文件
        try:
            docker_daemon = SafeFileHandler.safe_read_json('/etc/docker/daemon.json')
            if docker_daemon.get('default-runtime') != 'ascend' or 'ascend' not in docker_daemon.get('runtimes'):
                util.record_error("Please install docker runtime firstly", self.error_messages)
                return
            content = SafeFileHandler.safe_read(
                '/usr/local/Ascend/Ascend-Docker-Runtime/ascend_docker_runtime_install.info', 'r')
            for line in content.splitlines():
                if "version" in line:
                    return
        except FileNotFoundError:
            pass
        except Exception as e:
            util.record_error("check docker runtime error: {}".format(str(e)), self.error_messages)
        util.record_error("Please install docker runtime firstly", self.error_messages)
        return

    @check_event
    def check_volcano(self):
        msg = "Please install volcano firstly"
        all_k8s_pods_info = self._get_k8s_pods()
        volcano_controllers = {}
        volcano_scheduler = {}
        for pod in all_k8s_pods_info.get('volcano', []):
            # 验证命名空间
            if pod.get('NAMESPACE') == 'volcano-system':
                if 'volcano-controllers' in pod.get('NAME'):
                    volcano_controllers = pod
                elif 'volcano-scheduler' in pod.get('NAME'):
                    volcano_scheduler = pod

        if not volcano_controllers or not volcano_scheduler:
            util.record_error(msg, self.error_messages)
        return

    def check_master_arch(self):
        if not self.master_arch or not self.master0_arch:
            return
        if self.master0_arch != self.master_arch:
            util.record_error("The master nodes in the cluster have different architectures "
                              "and installation of DL in this cluster is not supported.", self.error_messages)

    def check_space_available(self):
        sv = os.statvfs('/')
        total = (sv.f_blocks * sv.f_frsize)
        used = (sv.f_blocks - sv.f_bfree) * sv.f_frsize
        usage = float(used + self.k8s_extra_space + self.dl_extra_space) / (total or 1)
        if usage > self.throttle:
            total_gb = "{:.2f}".format(total / GB)
            used_gb = "{:.2f}".format(used / GB)
            usage = "{:.2f}".format(usage)
            msg = 'Insufficient available remaining disk space for Docker containers, filesystems, or root ' \
                  'directories. Total disk space: {} GB, used disk space: {} GB. After installation, the disk ' \
                  'usage: {}, should be below {}'.format(total_gb, used_gb, usage, self.throttle)
            util.record_error(msg, self.error_messages)

    def check_inventory(self):
        master_cnt = len(self.master_group)
        if master_cnt == 0:
            util.record_error(
                "[ASCEND][ERROR] The master node configuration information is missing,"
                " please configure the master node info. For details about the master node configuration,"
                " see the user guide.", self.error_messages)
            return

        if master_cnt % 2 == 0:
            util.record_error("[ASCEND][ERROR] the number of Master nodes must be an odd number, "
                              "for example, 1, 3, 5 ,7. Please modify the master nodes configuration.",
                              self.error_messages)
            return

    def check_node_ready(self):
        """
        check the node on inventory whether in k8s cluster
        k8s cluster info like:
            NAME STATUS ROLES AGE VERSION INTERNAL-IP  ...
            name Ready  worker 1d  v1.2.0  0.0.0.0
        """
        name_pos = 0
        status_pos = 1
        ip_pos = 5
        for line in self.cluster_info.get('stdout_lines', []):
            contents = line.split()
            if not contents or len(contents) < 9:
                msg = "k8s cluster info not illegal. Please check either k8s is already installed or the cluster" \
                      "has been built"
                util.record_error(msg, self.error_messages)
                return
            if contents[name_pos] == self.node_name and contents[status_pos] == 'Ready' and contents[ip_pos] == self.ip:
                return
        msg = "current host {} not in k8s cluster. Please check your k8s " \
              "cluster info and join the node".format(self.node_name)
        util.record_error(msg, self.error_messages)
        return

    def check_container_runtime_consistency(self):
        """
        检查k8s集群中所有节点的容器运行时是否一致
        """
        runtime_list = []
        node_names = []

        for line in self.cluster_info.get('stdout_lines', [])[1:]:  # 跳过标题行
            line = line.strip()
            if line:
                parts = line.split()
                if len(parts) >= 8:  # 确保有足够的列
                    node_name = parts[0]
                    container_runtime = parts[-1]  # CONTAINER-RUNTIME 列
                    node_names.append(node_name)
                    runtime_list.append(container_runtime)

        if not runtime_list:
            return

        # 提取运行时类型(去除版本号)
        runtime_types = []
        for runtime in runtime_list:
            if 'docker' in runtime:
                runtime_types.append('docker')
            elif 'containerd' in runtime:
                runtime_types.append('containerd')
            else:
                runtime_types.append(runtime)

        # 检查是否所有节点使用相同的运行时
        if len(set(runtime_types)) > 1:
            # 找出不同的运行时类型
            unique_runtimes = list(set(runtime_types))
            msg = "All nodes in the k8s cluster must use the same container runtime. Found mixed " \
                  "runtimes: {}. ".format(' '.join(unique_runtimes))
            util.record_error(msg, self.error_messages)

    def check_user_and_group(self, uid, gid, username, groupname):
        errors = []

        def safe_check(check_func):
            try:
                check_func()
            except KeyError:
                pass
            except ValueError as e:
                errors.append(str(e))

        def raise_error(err):
            raise ValueError(err)

        # 检查 UID
        safe_check(
            lambda: pwd.getpwuid(uid).pw_name == username or raise_error(
                "UID {} exists,but username is {},instead of the expected {}".format(uid, pwd.getpwuid(uid).pw_name,
                                                                                     username))
        )

        # 检查 GID 对应的组
        safe_check(
            lambda: grp.getgrgid(gid).gr_name == groupname or raise_error(
                "GID {} exists,but group name is {},instead of the expected {}".format(gid, grp.getgrgid(gid).gr_name,
                                                                                       groupname))
        )

        # 检查用户名对应的 UID
        safe_check(
            lambda: pwd.getpwnam(username).pw_uid == uid or raise_error(
                "user {} exists,but UID is {},instead of the expected {}".format(username,
                                                                                 pwd.getpwnam(username).pw_uid, uid))
        )

        # 检查组名对应的 GID
        safe_check(
            lambda: grp.getgrnam(groupname).gr_gid == gid or raise_error(
                "group {} exists,but GID is {},instead of the expected {}".format(groupname,
                                                                                  grp.getgrnam(groupname).gr_gid, gid))
        )

        if errors:
            for error in errors:
                util.record_error(error, self.error_messages)

    @check_event
    def check_dl_basic(self):
        self.check_space_available()
        self.check_master_arch()
        self.check_node_ready()
        self.check_container_runtime_consistency()
        self.check_inventory()
        self.check_user_and_group(self.user_id, self.group_id, self.user, self.group)
        self.check_npu_installed()

    @check_event
    def check_dns(self):
        dns_file = "/etc/resolv.conf"
        if not os.path.exists(dns_file):
            util.record_error("[ASCEND][ERROR] Please config the DNS before install DL", self.error_messages)
            return

        try:
            lines = SafeFileHandler.safe_read(dns_file, "r").splitlines(True)
        except Exception as e:
            error_msg = "[ASCEND][ERROR] Failed to read DNS config file {}: {}".format(dns_file, str(e))
            util.record_error(error_msg, self.error_messages)
            return

        for line in lines:
            if "nameserver" in line:
                return
        util.record_error("[ASCEND][ERROR] Please config the DNS before install DL", self.error_messages)

    @check_event
    def check_resilience_controller_support(self):
        """
        resilience_controller installation only supports 910A1
        """
        card = util.get_card()
        if card != "910":
            util.record_error(
                "[ASCEND][ERROR] Your device does not support resilience-controller.",
                self.error_messages
            )

    def check_npu_installed(self):
        if self.current_hostname in self.worker_groups:
            driver_info = "/usr/local/Ascend/driver/version.info"
            firmware_info = "/usr/local/Ascend/firmware/version.info"

            if not os.path.exists(driver_info):
                util.record_error(
                    "[ASCEND][ERROR] Please install NPU driver firstly.", self.error_messages)

            if not os.path.exists(firmware_info):
                util.record_error(
                    "[ASCEND][ERROR] Please install NPU firmware firstly.", self.error_messages)