#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################################
# Copyright (c) 2020 Huawei Technologies Co.,Ltd.
#
# openGauss is licensed under Mulan PSL v2.
# You can use this software according to the terms
# and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
#
#          http://license.coscl.org.cn/MulanPSL2
#
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS,
# WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.
# ----------------------------------------------------------------------------
# Description  : gs_ssh is a utility to execute one command on all nodes.
#############################################################################
import os
import sys
import socket
from unittest import expectedFailure
# Bootstrap: make sure LD_LIBRARY_PATH begins with the gspylib/clib directory
# that sits next to this script before the gspylib imports below are executed.
# NOTE(review): clib presumably holds the bundled shared libraries the
# gspylib modules need at load time - confirm.
package_path = os.path.dirname(os.path.realpath(__file__))
ld_path = package_path + "/gspylib/clib"
# Case 1: the variable is not set at all. Set it to ld_path and replace the
# current process with a fresh run of this script (same argv, updated
# environment). os.execve does not return on success.
if 'LD_LIBRARY_PATH' not in os.environ:
    os.environ['LD_LIBRARY_PATH'] = ld_path
    os.execve(os.path.realpath(__file__), sys.argv, os.environ)
# Case 2: the variable is set but does not start with ld_path. Prepend ld_path
# and re-exec. In the re-executed process the value starts with ld_path, so
# neither branch is taken and execution continues with the imports.
if not os.environ.get('LD_LIBRARY_PATH').startswith(ld_path):
    os.environ['LD_LIBRARY_PATH'] = \
        ld_path + ":" + os.environ['LD_LIBRARY_PATH']
    os.execve(os.path.realpath(__file__), sys.argv, os.environ)

from gspylib.common.GaussLog import GaussLog
from gspylib.common.Common import DefaultValue
from gspylib.common.ErrorCode import ErrorCode
from gspylib.common.DbClusterInfo import dbClusterInfo
from gspylib.threads.SshTool import SshTool
from gspylib.common.ParameterParsecheck import Parameter
from gspylib.common.ParallelBaseOM import ParallelBaseOM
from gspylib.os.gsfile import g_file
from base_utils.os.env_util import EnvUtil
from base_utils.os.file_util import FileUtil
from domain_utils.cluster_file.version_info import VersionInfo
from base_utils.os.user_util import UserUtil


class ParallelSsh(ParallelBaseOM):
    """
    The class is used to execute one command on all nodes.

    Typical flow (see run()): parse the -c option, build an SshTool for the
    cluster nodes, write the command into a temporary shell file, copy that
    file to every node, execute it there and print the per-node result.
    """

    def __init__(self):
        """
        function: initialize the parameters
        input : NA
        output: NA
        """
        ParallelBaseOM.__init__(self)
        self.userInfo = ""
        # Command string given with -c; filled in by parseCommandLine().
        self.cmd = ""

    def usage(self):
        """
gs_ssh is a utility to execute one command on all %s cluster nodes.

Usage:
  gs_ssh -? | --help
  gs_ssh -V | --version
  gs_ssh -c COMMAND

General options:
  -c                             Command to be executed in cluster.
  -?, --help                     Show help information for this utility,
                                 and exit the command line mode.
  -V, --version                  Show version information.
        """
        # The docstring above IS the help text that gets printed, so it must
        # not be reformatted; %s is replaced with the product name.
        print(self.usage.__doc__ % VersionInfo.PRODUCT_NAME)

    def parseCommandLine(self):
        """
        function: parse command line
        input : NA
        output: NA (sets self.cmd; exits the process on --help or when
                the -c parameter is missing)
        """
        # Parse command
        paraDict = Parameter().ParameterCommandLine("ssh")
        # If help is included in the parameter,
        # the help message is printed and exited
        if "helpFlag" in paraDict:
            self.usage()
            sys.exit(0)
        # Gets the cmd parameter
        if "cmd" in paraDict:
            self.cmd = paraDict.get("cmd")
        # The cmd parameter is required. A None value is treated the same as
        # an empty one so that the literal text "None" is never executed.
        if not self.cmd:
            GaussLog.exitWithError(ErrorCode.GAUSS_500["GAUSS_50001"]
                                   % 'c' + ".")

    def check_nodename_recognized(self, node_names):
        """
        function: Check if nodename can be recognized.
        input : node_names, iterable of node names to resolve
        output: True if socket.getaddrinfo succeeds for every name,
                False (after printing a warning) at the first failure
        """
        for name in node_names:
            # The try block wraps a single lookup so that "name" is always
            # bound when the warning below is formatted.
            try:
                socket.getaddrinfo(name, None)
            except Exception:
                GaussLog.printMessage(
                    "[Warning] Name or service not known on node [%s]." % name)
                return False
        return True

    def initGlobal(self):
        """
        function: Init global parameter
        input : NA
        output: NA (sets self.user, self.clusterInfo and self.sshTool;
                exits the process with the error text on any failure)
        """
        try:
            # Get user information
            self.user = UserUtil.getUserInfo()["name"]
            self.clusterInfo = dbClusterInfo()
            self.clusterInfo.initFromStaticConfig(self.user)

            nodeNames = self.clusterInfo.getClusterNodeNames()
            if self.check_nodename_recognized(nodeNames):
                self.sshTool = SshTool(nodeNames)
            else:
                # At least one node name could not be resolved:
                # address the nodes by their SSH IPs instead.
                nodeSshIps = self.clusterInfo.getClusterSshIps()[0]
                self.sshTool = SshTool(nodeSshIps)
        except Exception as e:
            GaussLog.exitWithError(str(e))

    def executeCommand(self):
        """
        function: execute command
        input : NA
        output: NA (prints the per-node result and the collected output;
                exits the process with the error text on any failure)
        """
        tmpDir = EnvUtil.getTmpDirFromEnv()
        # Temporary shell file that carries the command; the pid keeps the
        # name distinct between concurrently running gs_ssh processes.
        cmdFile = "%s/ClusterCall_%d.sh" % (tmpDir, os.getpid())
        # Command that deletes the temporary shell file; built once and
        # used by both the normal path and the error path.
        cmdFileRm = g_file.SHELL_CMD_DICT["deleteFile"] % (cmdFile, cmdFile)
        try:
            FileUtil.createFile(cmdFile, True, DefaultValue.FILE_MODE)

            # Writes the cmd command to the shell file; leaving the "with"
            # block flushes and closes the file.
            with open(cmdFile, "a") as fp:
                fp.write("#!/bin/sh")
                fp.write(os.linesep)
                fp.write("%s" % self.cmd)
                fp.write(os.linesep)

            # Distribute the shell file to the temporary directory
            # for each node
            self.sshTool.scpFiles(cmdFile, tmpDir + '/')
            # Execute the shell file on all nodes
            cmdExecute = g_file.SHELL_CMD_DICT["execShellFile"] % cmdFile
            (status, output) = self.sshTool.getSshStatusOutput(cmdExecute)

            # Resolve the execution results of all nodes
            failedNodes = []
            succeedNodes = []
            for node, nodeStatus in status.items():
                if nodeStatus != DefaultValue.SUCCESS:
                    failedNodes.append(node)
                else:
                    succeedNodes.append(node)

            if failedNodes and succeedNodes:
                # Some nodes fail to execute. Every node name is followed
                # by one blank, which keeps the printed text unchanged.
                GaussLog.printMessage(
                    "Failed to execute command on %s."
                    % "".join("%s " % node for node in failedNodes))
                GaussLog.printMessage(
                    "Successfully execute command on %s.\n"
                    % "".join("%s " % node for node in succeedNodes))
            elif not failedNodes:
                # All nodes execute successfully
                GaussLog.printMessage(
                    "Successfully execute command on all nodes.\n")
            else:
                # All nodes fail to execute
                GaussLog.printMessage(
                    "Failed to execute command on all nodes.\n")
            # Output Execution result
            GaussLog.printMessage("Output:\n%s" % output)
            # Delete the temporary shell file at all nodes
            self.sshTool.executeCommand(cmdFileRm)

        except Exception as e:
            # Best-effort cleanup: a failure while removing the temporary
            # file must not hide the error that brought us here.
            try:
                self.sshTool.executeCommand(cmdFileRm)
            except Exception as cleanupError:
                GaussLog.printMessage(
                    "[Warning] Failed to remove temporary file [%s]: %s"
                    % (cmdFile, str(cleanupError)))
            GaussLog.exitWithError(str(e))

    def run(self):
        """
        function: Perform the whole process
        input : NA
        output: NA
        """
        # parse cmd lines
        self.parseCommandLine()
        # init globals
        self.initGlobal()
        # execute command
        self.executeCommand()


if __name__ == '__main__':
    # Script entry point.
    # gs_ssh refuses to run with uid 0 (root).
    if os.getuid() == 0:
        GaussLog.exitWithError(ErrorCode.GAUSS_501["GAUSS_50105"])

    # Build the tool and run the whole flow; any exception raised on the way
    # is reported through GaussLog.exitWithError.
    try:
        ParallelSsh().run()
    except Exception as err:
        GaussLog.exitWithError(str(err))

    sys.exit(0)