#!/usr/bin/env python3
# coding: utf-8
# Copyright 2024 Huawei Technologies Co., Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
# ===========================================================================
import glob
import os
import shlex

from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils import common_info, common_utils, venv_installer


class UBEngineInstaller(object):
    """Ansible module that installs the UBEngine management package.

    Module parameters:
      * ``resources_dir`` (str, required): resources directory; ``~`` is expanded.
      * ``npu_info`` (dict, required): NPU description; only its ``scene`` key
        is read here.

    Installation is only attempted when ``npu_info["scene"]`` is ``"a910_95"``.
    Diagnostic lines are collected in ``self.messages`` and reported in the
    module's failure message.
    """

    def __init__(self):
        self.module = AnsibleModule(
            argument_spec=dict(
                resources_dir=dict(type="str", required=True),
                npu_info=dict(type="dict", required=True),
            )
        )
        self.resources_dir = os.path.expanduser(self.module.params["resources_dir"])
        self.npu_info = self.module.params["npu_info"]
        self.arch = common_info.ARCH
        self.local_path = common_info.get_local_path(os.getuid(), os.path.expanduser("~"))
        # Accumulated diagnostic lines; joined with newlines on failure.
        self.messages = []

    def module_failed(self):
        """Fail the module (rc=1, unchanged) reporting all collected messages."""
        return self.module.fail_json(msg="\n".join(self.messages), rc=1, changed=False)

    def module_success(self):
        """Exit the module as changed/successful; a reboot is still required."""
        return self.module.exit_json(msg="Install UBEngine success. Reboot needed for installation to take effect", rc=0, changed=True)

    def find_files(self, path, pattern):
        """Return the first file under ``path`` matching glob ``pattern``.

        Matches are sorted so the choice is deterministic when several files
        match (``glob.glob`` returns them in arbitrary order).

        Returns:
            str: the first matching path, or ``""`` if nothing matches.
        """
        self.messages.append("try to find {} for {}".format(path, pattern))
        matched_files = sorted(glob.glob(os.path.join(path, pattern)))
        self.messages.append("find files: " + ",".join(matched_files))
        if matched_files:
            return matched_files[0]
        return ""

    def find_ubengine_pkg(self):
        """Locate the UBEngine ``.run`` package for the current NPU scene.

        Fails the module when the scene is not ``a910_95``, when no package
        directory is registered for the scene, or when no package file matches.

        Returns:
            str: path of the first ``*npu-UBEngine-mgmt*linux*.run`` file found.
        """
        npu_scene = self.npu_info.get("scene")
        if npu_scene != "a910_95":
            self.messages.append("[ASCEND][ERROR] not support install UBEngine on this device. ")
            return self.module_failed()
        # self.resources_dir was already expanded in __init__.
        package_path = common_info.get_scene_dict(self.resources_dir).get(npu_scene)
        if not package_path:
            # Without this guard, os.path.join(None, ...) would crash with a
            # TypeError traceback instead of a clean module failure.
            self.messages.append(
                "[ASCEND][ERROR]failed to find package directory for scene {} under {}".format(
                    npu_scene, self.resources_dir))
            return self.module_failed()
        ubengine_file = self.find_files(package_path, r"*npu-UBEngine-mgmt*linux*.run")
        if not ubengine_file:
            # Report through module_failed so the search diagnostics recorded
            # by find_files are included in the failure message.
            self.messages.append(
                "[ASCEND][ERROR]failed to find ubengine package from {}".format(self.resources_dir))
            return self.module_failed()
        return ubengine_file

    def install_ubengine(self):
        """Run the located UBEngine package with ``bash <pkg> --full``."""
        pkg = self.find_ubengine_pkg()
        # Quote the path so a directory containing spaces or shell
        # metacharacters is still passed as a single argument.
        command = "bash {} --full".format(shlex.quote(pkg))
        # NOTE(review): the first element of run_command's result is discarded;
        # confirm run_command fails/raises on a non-zero exit, otherwise an
        # installer failure would be reported as success by run().
        _, messages = common_utils.run_command(self.module, command)
        self.messages.extend(messages)

    def run(self):
        """Module entry point: install the package, then report success."""
        self.install_ubengine()
        self.module_success()


def main():
    """Create the UBEngine installer module and execute it."""
    installer = UBEngineInstaller()
    installer.run()


if __name__ == "__main__":
    # Run the module only when this file is executed directly, not on import.
    main()