#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################################
# Copyright (c) 2022 Huawei Technologies Co.,Ltd.
#
# openGauss is licensed under Mulan PSL v2.
# You can use this software according to the terms
# and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
#
#          http://license.coscl.org.cn/MulanPSL2
#
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS,
# WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.
# ----------------------------------------------------------------------------
# Description  : cm_install is a utility to deploy CM tool to openGauss database cluster.
#############################################################################

import cmd
import getopt
import os
import sys
import re
import shlex
import subprocess
try:
    import xml.etree.cElementTree as ETree
except ImportError:
    # xml.etree.cElementTree was removed in Python 3.9; plain ElementTree
    # already uses the C accelerator when it is available.
    import xml.etree.ElementTree as ETree

from CMLog import CMLog
from Common import *
from ErrorCode import ErrorCode
from InstallImpl import InstallImpl

# Make the directory two levels above sys.path[0] (the script's directory when
# run directly) importable. This runs after the project imports above, so it
# does not affect them; presumably it is needed by modules imported later
# (e.g. from InstallImpl) — TODO confirm.
sys.path.append(sys.path[0] + "/../../")

class Install:
    """
    Deploy the CM tool to an existing openGauss database cluster.

    run() drives the whole workflow: validate the command line and the
    environment file, make sure the OM tool is present and CM is not yet
    deployed, collect node information from `gs_om -t view` and the XML
    file, check host trust and cluster health, and finally delegate the
    actual installation to InstallImpl.
    """

    def __init__(self):
        # Inputs given on the command line (see parseCommandLine/checkParam).
        self.envFile = ""
        self.xmlFile = ""
        self.cmpkg = ""
        # Values read from the env file by getEnvParams().
        self.gaussHome = ""
        self.gaussLog = ""
        self.toolPath = ""
        self.tmpPath = ""
        # Node information filled in by getInfoListOfAllNodes().
        self.cmDirs = []
        self.hostnames = []
        self.localhostName = ""
        # nodename -> {"dataPath": <datanode data dir>, "port": <datanode port>}
        self.nodesInfo = dict()
        # Cluster state filled in by checkCluster().
        self.clusterStopped = False
        self.maxTerm = 0
        self.primaryTermAbnormal = False
        # CMLog instance created by initLogger(); None until then.
        self.logger = None

    def getLocalhostName(self):
        """Store the hostname of the machine running this script."""
        import socket
        self.localhostName = socket.gethostname()

    def getEnvParams(self):
        """Read GAUSSHOME, GAUSSLOG, GPHOME and PGHOST from the env file."""
        self.gaussHome = getEnvParam(self.envFile, "GAUSSHOME")
        self.gaussLog = getEnvParam(self.envFile, "GAUSSLOG")
        self.toolPath = getEnvParam(self.envFile, "GPHOME")
        self.tmpPath = getEnvParam(self.envFile, "PGHOST")

    def checkExeUser(self):
        """Refuse to run as root (uid 0)."""
        if os.getuid() == 0:
            CMLog.exitWithError(ErrorCode.GAUSS_501["GAUSS_50105"])

    # NOTE: the docstring of usage() is the help text printed to the user.
    def usage(self):
        """
cm_install is a utility to deploy CM tool to openGauss database cluster.

Usage:
    cm_install -? | --help
    cm_install -X XMLFILE [-e envFile] --cmpkg=cmpkgPath
General options:
    -X                                 Path of the XML configuration file.
    -e                                 Path of env file.
                                       Default value "~/.bashrc".
    --cmpkg                            Path of CM package.
    -?, --help                         Show help information for this
                                       utility, and exit the command line mode.
        """
        print(self.usage.__doc__)

    def parseCommandLine(self):
        """
        Parse sys.argv into self.xmlFile, self.envFile and self.cmpkg.

        Prints usage and exits with status 1 when no argument is given,
        and with status 0 for -?/--help.
        """
        if len(sys.argv) == 1:
            self.usage()
            sys.exit(1)

        try:
            opts, _ = getopt.getopt(sys.argv[1:], "?X:e:", ["help", "cmpkg="])
        except getopt.GetoptError as e:
            CMLog.exitWithError(ErrorCode.GAUSS_500["GAUSS_50000"] % str(e))

        for opt, value in opts:
            # Use equality: ("-X") is a plain string, not a tuple, so
            # `opt in ("-X")` would be a substring test.
            if opt in ("-?", "--help"):
                self.usage()
                sys.exit(0)
            elif opt == "-X":
                self.xmlFile = value
            elif opt == "-e":
                self.envFile = value
            elif opt == "--cmpkg":
                self.cmpkg = value

    def checkParam(self):
        """
        Validate the XML file, the CM package and the env file.

        The env file defaults to ~/.bashrc. If it sets
        MPPDB_ENV_SEPARATE_PATH, that path replaces it and must be an
        existing regular file.
        """
        if self.xmlFile == "":
            CMLog.exitWithError(ErrorCode.GAUSS_500["GAUSS_50001"] % 'X' + ".")
        checkXMLFile(self.xmlFile)

        if self.cmpkg == "":
            CMLog.exitWithError(ErrorCode.GAUSS_500["GAUSS_50001"] % '-cmpkg' + ".")
        if not os.path.exists(self.cmpkg):
            CMLog.exitWithError(ErrorCode.GAUSS_502["GAUSS_50201"] % self.cmpkg)
        if not os.path.isfile(self.cmpkg):
            CMLog.exitWithError(ErrorCode.GAUSS_502["GAUSS_50210"] % ("cmpkg " + self.cmpkg))

        if self.envFile == "":
            # expanduser falls back to the password database when HOME is unset,
            # instead of raising KeyError like os.environ['HOME'].
            self.envFile = os.path.join(os.path.expanduser("~"), ".bashrc")
        if not os.path.exists(self.envFile):
            CMLog.exitWithError(ErrorCode.GAUSS_502["GAUSS_50201"] % ("envFile " + self.envFile))
        if not os.path.isfile(self.envFile):
            CMLog.exitWithError(ErrorCode.GAUSS_502["GAUSS_50210"] % ("envFile " + self.envFile))
        mppdbEnv = getEnvParam(self.envFile, "MPPDB_ENV_SEPARATE_PATH")
        if mppdbEnv != "":
            self.envFile = mppdbEnv
        if self.envFile == "" or not os.path.exists(self.envFile) or not os.path.isfile(self.envFile):
            CMLog.exitWithError(ErrorCode.GAUSS_518["GAUSS_51802"] % 'MPPDB_ENV_SEPARATE_PATH' + ".")

    def _executeWithEnv(self, command):
        """
        Run `command` via `sh -c` after sourcing the env file.

        The env file path is shell-quoted. `command` is inserted verbatim,
        so it must be a trusted constant.
        Returns (status, output, cmdArgs); cmdArgs is the argv list that was
        executed and is used in error messages.
        """
        cmdStr = "source %s; %s" % (shlex.quote(self.envFile), command)
        cmdArgs = ['sh', '-c', cmdStr]
        status, output = common_execute_cmd(cmdArgs)
        return status, output, cmdArgs

    def checkOm(self):
        """Exit unless `gs_om --version` succeeds with the env file sourced."""
        status, output, cmdArgs = self._executeWithEnv("gs_om --version")
        if status != 0:
            errorDetail = "\nCommand: %s\nStatus: %s\nOutput: %s\n" % (
                str(cmdArgs), status, output)
            self.logger.logExit("OM tool is required." + errorDetail)

    def checkXMLFileSecurity(self):
        """
        Reject an XML file that declares a DTD or entities.

        Raises Exception("File have security risks.") if any line contains
        "<!DOCTYPE" or "<!ENTITY". I/O and decoding errors are re-raised as
        Exception with the original message, chained to the cause.
        """
        try:
            with open(self.xmlFile, "r", encoding='utf-8') as fb:
                lines = fb.readlines()
        except (OSError, UnicodeDecodeError) as e:
            raise Exception(str(e)) from e
        for line in lines:
            if "<!DOCTYPE" in line or "<!ENTITY" in line:
                raise Exception("File have security risks.")

    def initParserXMLFile(self):
        """
        Security-check and parse self.xmlFile, returning its root element.

        Raises Exception (GAUSS_51236 plus the underlying error) on failure.
        """
        try:
            self.checkXMLFileSecurity()
            domTree = ETree.parse(self.xmlFile)
            rootNode = domTree.getroot()
        except Exception as e:
            raise Exception(ErrorCode.GAUSS_512["GAUSS_51236"] + " Error: \n%s." % str(e)) from e
        return rootNode

    def _extractField(self, pattern, text, fieldName):
        """
        Return the first capture group of `pattern` in `text`, stripped.

        Exits through the logger when the field is missing, instead of
        failing with an IndexError.
        """
        matches = re.findall(pattern, text)
        if not matches:
            self.logger.logExit("Failed to get %s from static file." % fieldName)
        return matches[0].strip()

    def getInfoListOfAllNodes(self):
        """
        Collect the hostname, data path and port of every node, plus the cmDir
        list, and validate the CM parameters in the XML file.

        Hostnames/data paths/ports come from `gs_om -t view`. The XML must list
        the same hostnames in the same order and provide cmDir for every node.
        TODO: check the consistence of xml and installed cluster.
        """
        self.getLocalhostName()

        # get hostnames, data paths and ports from the static config
        status, output, cmdArgs = self._executeWithEnv("gs_om -t view")
        if status != 0:
            self.logger.logExit((ErrorCode.GAUSS_514["GAUSS_51400"] % str(cmdArgs)) +
                                "\nStatus:%d\nOutput:%s" % (status, output))
        if not output.strip():
            self.logger.logExit("Failed to get cluster info from static file.")
        # Sections are split on "azName ... : ..." lines; the text before the
        # first one is skipped (presumably a cluster-level header — verify
        # against the `gs_om -t view` output format).
        nodesStaticInfoStr = re.split(r"azName.*:.*", output)
        if len(nodesStaticInfoStr) < 2:
            self.logger.logExit("CM is not supported in single instance.")
        for nodeInfo in nodesStaticInfoStr[1:]:
            if not nodeInfo.strip():
                continue
            nodename = self._extractField(r"nodeName:(.*)", nodeInfo, "nodeName")
            self.hostnames.append(nodename)
            dataPath = self._extractField(r"datanodeLocalDataPath.*:(.*)", nodeInfo,
                                          "datanodeLocalDataPath")
            port = self._extractField(r"datanodePort.*:(.*)", nodeInfo, "datanodePort")
            self.nodesInfo[nodename] = {"dataPath": dataPath, "port": port}

        # get node info from XML
        try:
            rootNode = self.initParserXMLFile()
        except Exception as e:
            self.logger.logExit(str(e))
        deviceLists = rootNode.findall('DEVICELIST')
        if not deviceLists:
            self.logger.logExit(ErrorCode.GAUSS_512["GAUSS_51200"] % 'DEVICELIST')
        hostnamesInXML = []
        cmDict = {"cmsNum": "", "cmServerPortBase": "", "cmServerPortStandby": "",
                  "cmServerlevel": "", "cmServerListenIp1": "", "cmServerRelation": ""}
        for dev in deviceLists[0].findall('DEVICE'):
            for param in dev.findall('PARAM'):
                paraName = param.attrib.get('name', "")
                paraValue = param.attrib.get('value', "")
                if paraName == 'name':
                    hostnamesInXML.append(paraValue)
                elif paraName == 'cmDir':
                    self.cmDirs.append(paraValue)
                elif paraName == 'cmServerLevel':
                    # XML spells it cmServerLevel; the dict key is cmServerlevel.
                    cmDict['cmServerlevel'] = paraValue
                elif paraName in cmDict:
                    cmDict[paraName] = paraValue
        # check whether XML contains all nodes info (same order as static file)
        if self.hostnames != hostnamesInXML:
            self.logger.logExit("XML info is not consistent with static file.")
        # check params in xml; cmServerPortStandby is optional
        for item, itemValue in cmDict.items():
            if item != 'cmServerPortStandby' and itemValue == "":
                self.logger.logExit(ErrorCode.GAUSS_512["GAUSS_51200"] % item)
        if cmDict['cmsNum'] != '1':
            self.logger.logExit(ErrorCode.GAUSS_500["GAUSS_50024"] % 'cmsNum')
        if cmDict['cmServerlevel'] != '1':
            self.logger.logExit(ErrorCode.GAUSS_500["GAUSS_50024"] % 'cmServerlevel')
        if not cmDict['cmServerPortBase'].isdigit():
            self.logger.logExit(ErrorCode.GAUSS_500["GAUSS_50024"] % 'cmServerPortBase')
        if cmDict['cmServerPortStandby'] != "" and not cmDict['cmServerPortStandby'].isdigit():
            self.logger.logExit(ErrorCode.GAUSS_500["GAUSS_50024"] % 'cmServerPortStandby')
        if len(self.hostnames) != len(self.cmDirs):
            self.logger.logExit("\"cmDir\" of all nodes must be provided.")

    def checkHostTrust(self):
        """Delegate the host-trust check for all collected hostnames to checkHostsTrust."""
        checkHostsTrust(self.hostnames)

    def initLogger(self):
        """Create $GAUSSLOG/cm/cm_tool if needed and open the cm_install log there."""
        if not self.gaussLog:
            # Without GAUSSLOG the log directory would silently become a
            # relative path under the current working directory.
            CMLog.exitWithError(ErrorCode.GAUSS_518["GAUSS_51802"] % 'GAUSSLOG' + ".")
        logPath = os.path.join(self.gaussLog, "cm", "cm_tool")
        os.makedirs(logPath, exist_ok=True)
        self.logger = CMLog(logPath, "cm_install", "cm_install")

    def checkCM(self):
        """
        Exit if CM is already deployed, i.e. when the detailed cluster status
        contains a "CMServer State" line.
        """
        status, _, _ = self._executeWithEnv(
            "gs_om -t status --detail | grep 'CMServer State' > /dev/null")
        if status == 0:
            self.logger.logExit("CM exists in current cluster.")

    def checkCluster(self):
        """
        Check that the cluster may receive the CM tool.

        A fully "Manually stopped" cluster is accepted and sets
        self.clusterStopped. Otherwise the cluster must be Normal, and the
        primary/term check in _checkPrimaryTerm() is run.
        """
        status, output, cmdArgs = self._executeWithEnv("gs_om -t status --detail")
        if status != 0:
            errorDetail = "Detail:\nCommand:\n" + str(cmdArgs) + "\noutput:" + output
            self.logger.logExit(ErrorCode.GAUSS_516["GAUSS_51600"] + errorDetail)
        if "cluster_state   : Unavailable" in output:
            # It's permitted to deploy CM tool when cluster is stopped,
            # but not permitted when cluster is unavailable.
            if output.count("Manually stopped") == len(self.hostnames):
                self.clusterStopped = True
                return
            self.logger.logExit("The cluster is unavailable currently.")
        if "cluster_state   : Normal" not in output:
            self.logger.logExit("Cluster is running but its status is abnormal.")
        self._checkPrimaryTerm()

    def _checkPrimaryTerm(self):
        """
        Require exactly one primary and compare its term with the others.

        Records the largest term of all nodes in self.maxTerm. Sets
        self.primaryTermAbnormal and warns when the primary's term is 0 or
        smaller than self.maxTerm.
        """
        primaryCount = 0
        primaryTerm = 0
        sqlCmd = "select term from pg_catalog.pg_last_xlog_replay_location();"
        safeEnvFile = shlex.quote(self.envFile)
        for host in self.hostnames:
            isLocal = host == self.localhostName
            nodeInfo = self.nodesInfo[host]
            # grep exits 0 only when gs_ctl reports this instance as Primary.
            findPrimaryCmd = (
                "source %s; gs_ctl query -D %s | grep -i 'local_role.*Primary' > /dev/null"
                % (safeEnvFile, shlex.quote(nodeInfo["dataPath"])))
            roleStatus, _ = executeCmdOnHost(host, findPrimaryCmd, isLocal)
            isPrimary = roleStatus == 0
            if isPrimary:
                primaryCount += 1
            getTermLsnCmd = "source %s; gsql -d postgres -p %s -tA -c '%s'" % (
                safeEnvFile, shlex.quote(nodeInfo["port"]), sqlCmd)
            status, term = executeCmdOnHost(host, getTermLsnCmd, isLocal)
            termStr = str(term).strip()
            # Guard int(): empty or non-numeric output would otherwise crash.
            if status != 0 or not termStr.isdigit():
                self.logger.logExit("Failed to get term of host %s." % host)
            termValue = int(termStr)
            if isPrimary:
                primaryTerm = termValue
            self.maxTerm = max(self.maxTerm, termValue)

        if primaryCount != 1:
            self.logger.logExit("The number of primary is invalid.")
        if primaryTerm == 0 or primaryTerm < self.maxTerm:
            self.primaryTermAbnormal = True
            self.logger.warn("Term of primary is invalid or not maximal.\n"
                "Hint: it seems that the cluster is newly installed, so it's "
                "recommended to deploy CM tool while installing the cluster.")

    def run(self):
        """Run every pre-installation check, then hand over to InstallImpl."""
        self.checkExeUser()
        self.parseCommandLine()
        self.checkParam()
        self.getEnvParams()
        self.initLogger()
        self.checkOm()
        self.checkCM()
        self.getInfoListOfAllNodes()
        self.getLocalhostName()
        self.checkHostTrust()
        self.checkCluster()
        installImpl = InstallImpl(self)
        installImpl.run()

if __name__ == "__main__":
    # Script entry point: build the installer and run the full workflow.
    Install().run()