#!/usr/bin/env python3
# coding: utf-8
# Copyright 2024 Huawei Technologies Co., Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
# ===========================================================================
import glob
import os

from ansible.module_utils.check_output_manager import check_event
from ansible.module_utils.check_utils import CheckUtil as util
from ansible.module_utils.common_info import SceneName
from ansible.module_utils.safe_file_handler import SafeFileHandler

GB = 1 << 30  # number of bytes in one gibibyte (1024 ** 3)

class CANNCheck:
    """Pre-installation checks for the CANN packages (toolkit, nnae, nnrt, kernels).

    Problems are not raised; each one is recorded into ``error_messages``
    through ``util.record_error``.
    """

    def __init__(self, module, npu_info, error_messages):
        """
        :param module: Ansible module; its ``params`` supply ``tags``, ``ascend_deployer_work_dir``,
                       ``python_version`` and ``packages``.
        :param npu_info: mapping describing the NPU environment; only its ``"scene"`` key is read here.
        :param error_messages: collector handed to ``util.record_error`` (presumably a list — confirm with caller).
        """
        self.module = module
        self.tags = module.params.get("tags")
        self.resource_dir = os.path.join(module.params.get("ascend_deployer_work_dir"), "resources")
        self.python_version = module.params.get("python_version")
        self.packages = module.params.get("packages")
        self.npu_info = npu_info
        self.error_messages = error_messages

    def _tag_set(self):
        """Return the requested tags as a set; a missing (None) ``tags`` param yields an empty set."""
        return set(self.tags or ())

    @check_event
    def check_kernels(self):
        """Record an error when the NPU scene is the inference scene, where kernels are not supported."""
        if self.npu_info.get("scene") == SceneName.Infer:
            util.record_error("[ASCEND][ERROR] kernels not support infer scene", self.error_messages)

    def check_driver_installation(self):
        """Verify the driver path recorded in /etc/ascend_install.info.

        Nothing is checked when the info file does not exist. For each line containing
        ``Driver_Install_Path_Param``, ``driver/version.info`` must exist under the recorded
        path; the first path that lacks it records an error and stops the check.
        A failure to read the info file is also recorded as an error.
        """
        ascend_info_path = "/etc/ascend_install.info"
        if not os.path.isfile(ascend_info_path):
            return
        try:
            lines = SafeFileHandler.safe_read(ascend_info_path, "r").splitlines()
        except Exception as e:
            error_msg = "[ASCEND][ERROR] Failed to read driver install info file {}: {}".format(ascend_info_path,
                                                                                                str(e))
            util.record_error(error_msg, self.error_messages)
            return

        for line in lines:
            if "Driver_Install_Path_Param" not in line:
                continue
            # Split on the first "=" only, so an install path that itself contains "=" stays intact.
            driver_install_path = line.split("=", 1)[-1].strip()
            if not os.path.isfile(os.path.join(driver_install_path, "driver/version.info")):
                util.record_error("[ASCEND][ERROR] The /etc/ascend_install.info file exists in the environment, "
                                  "and the file records the driver installation path. However, "
                                  "the driver/version.info does not exist in the installation path. "
                                  "Please check the driver is correctly installed.", self.error_messages)
                return

    def check_cann_install_path_permission(self):
        """Require /usr/local/Ascend, when it exists, to be owned by root (uid 0) with permission 755."""
        install_path = "/usr/local/Ascend"
        if not os.path.isdir(install_path):
            return
        # One stat call serves both the owner check and the permission check.
        path_stat = os.stat(install_path)
        if path_stat.st_uid != 0:
            util.record_error("[ASCEND][ERROR] The owner of the cann installation dir "
                              "'/usr/local/Ascend' must be root, change the owner to root", self.error_messages)
            return

        # Compare only the rwx bits (the last three octal digits of st_mode);
        # file-type and setuid/setgid/sticky bits are ignored.
        if (path_stat.st_mode & 0o777) != 0o755:
            util.record_error("[ASCEND][ERROR] When installing cann, the user and group of the installation path "
                              "must be root, and the permission must be 755. ", self.error_messages)

    @check_event
    def check_cann_basic(self):
        """Run the common CANN pre-install checks; each records its own errors."""
        self.check_driver_installation()
        self.check_cann_install_path_permission()
        self.check_if_nnrt_upgrade()
        self.check_disk_space()

    def check_disk_space(self):
        """
        Check that the root filesystem ('/') has at least 10 GB available.

        The check is skipped unless at least one of the tags toolkit, nnae, nnrt,
        kernels or mindspore_scene is requested. Insufficient space, or a failure
        to query the filesystem, is recorded as an error.
        """
        required_tags = {"toolkit", "nnae", "nnrt", "kernels", "mindspore_scene"}
        if not (self._tag_set() & required_tags):
            return

        required_space = 10 * GB

        try:
            sv = os.statvfs('/')
        except Exception as e:
            msg = '[ASCEND][ERROR] Failed to check disk space: {}'.format(str(e))
            util.record_error(msg, self.error_messages)
            return

        # Available space in bytes: blocks available to unprivileged users times the fragment size.
        available = sv.f_bavail * sv.f_frsize
        if available < required_space:
            available_gb = "{:.2f}".format(available / GB)
            required_gb = "{:.2f}".format(required_space / GB)
            msg = ('[ASCEND][ERROR] Insufficient available disk space. Available: {} GB, Required: {} GB. '
                   'Please ensure at least {} GB of free space is available.'.format(
                available_gb, required_gb, required_gb))
            util.record_error(msg, self.error_messages)

    def check_if_nnrt_upgrade(self):
        """Reject a toolkit/nnae upgrade on top of an existing NNRT installation.

        Reads the whitelist entry of the first install.conf matched by
        /usr/local/Ascend/cann/*/install.conf. If that entry is "nnrt" and the
        toolkit or nnae tag is requested, an error is recorded. A failure to read
        the config file is recorded as an error as well.
        """
        # The pattern has no "**", so a recursive glob would make no difference.
        install_conf_list = glob.glob("/usr/local/Ascend/cann/*/install.conf")
        if not install_conf_list:
            return
        install_conf = install_conf_list[0]
        if not os.path.exists(install_conf):
            return
        try:
            content = SafeFileHandler.safe_read(install_conf).splitlines()
        except Exception as e:
            util.record_error("[ASCEND][ERROR] Failed to read install config file {}: {}".format(install_conf,
                                                                                                str(e)),
                              self.error_messages)
            return

        whitelist_conf = ""
        for line in content:
            # Presumably a key=value line; the last line mentioning "whitelist" wins.
            if "whitelist" in line.lower():
                whitelist_conf = line.split("=", 1)[-1].strip()
        if whitelist_conf == "nnrt" and self._tag_set() & {"toolkit", "nnae"}:
            util.record_error(
                "[ASCEND][ERROR] NNRT is already installed. Toolkit or NNAE cannot be used for upgrading. "
                "Please use only the nnrt command for the upgrade.",
                self.error_messages)