import getpass
import logging
import shlex
import socket
import subprocess
import threading
import time
import os
import sys
import paramiko
from dbmind.common.utils.checking import check_ssh_version
# Positions of stdin, stdout and stderr in the stream triple that
# SSH._exec_command indexes; they are also the valid ``blocking_fd`` values.
n_stdin, n_stdout, n_stderr = 0, 1, 2
class shlex_py38(shlex.shlex):
    """
    Use the read_token function of shlex in py38 to replace that in py37 to
    fix the incompatibility of punctuation_chars and whitespace_split in py37.

    Only ``read_token`` is overridden; every other behaviour is inherited
    from ``shlex.shlex``.
    """

    def read_token(self):
        """Read and return the next raw token from ``self.instream``.

        The lexer is a state machine driven by ``self.state``:
        ``None`` = end of input, ``' '`` = between tokens, ``'a'`` = inside a
        word, ``'c'`` = inside a run of punctuation characters, a quote
        character = inside that quotation, an escape character = the next
        character is escaped.

        :return: The token text. In POSIX mode an unquoted empty result is
            returned as ``None`` to signal end of input.
        :raises ValueError: On an unterminated quotation or an escape
            character at end of input.
        """
        quoted = False
        escapedstate = ' '  # state to return to after an escaped character
        while True:
            # Characters pushed back while splitting punctuation runs are
            # consumed before reading more input.
            if self.punctuation_chars and self._pushback_chars:
                nextchar = self._pushback_chars.pop()
            else:
                nextchar = self.instream.read(1)
            if nextchar == '\n':
                self.lineno += 1
            if self.state is None:
                self.token = ''  # past end of input
                break
            elif self.state == ' ':
                # Between tokens: decide what kind of token starts here.
                if not nextchar:
                    self.state = None  # end of input
                    break
                elif nextchar in self.whitespace:
                    if self.token or (self.posix and quoted):
                        break  # emit current token
                    else:
                        continue
                elif nextchar in self.commenters:
                    self.instream.readline()  # skip rest of the comment line
                    self.lineno += 1
                elif self.posix and nextchar in self.escape:
                    escapedstate = 'a'
                    self.state = nextchar
                elif nextchar in self.wordchars:
                    self.token = nextchar
                    self.state = 'a'
                elif nextchar in self.punctuation_chars:
                    self.token = nextchar
                    self.state = 'c'
                elif nextchar in self.quotes:
                    # Non-POSIX mode keeps the quote characters in the token.
                    if not self.posix:
                        self.token = nextchar
                    self.state = nextchar
                elif self.whitespace_split:
                    self.token = nextchar
                    self.state = 'a'
                else:
                    # Any other character is a single-character token.
                    self.token = nextchar
                    if self.token or (self.posix and quoted):
                        break  # emit current token
                    else:
                        continue
            elif self.state in self.quotes:
                # Inside a quotation opened by the character in self.state.
                quoted = True
                if not nextchar:
                    raise ValueError("No closing quotation")
                if nextchar == self.state:
                    if not self.posix:
                        # Non-POSIX: the closing quote ends the token.
                        self.token += nextchar
                        self.state = ' '
                        break
                    else:
                        # POSIX: continue the same word after the quote.
                        self.state = 'a'
                elif (self.posix and nextchar in self.escape and self.state
                      in self.escapedquotes):
                    escapedstate = self.state
                    self.state = nextchar
                else:
                    self.token += nextchar
            elif self.state in self.escape:
                if not nextchar:
                    raise ValueError("No escaped character")
                # Within quotes only the quote itself or the escape character
                # may be escaped; otherwise the escape character is kept.
                if (escapedstate in self.quotes and
                        nextchar != self.state and nextchar != escapedstate):
                    self.token += self.state
                self.token += nextchar
                self.state = escapedstate
            elif self.state in ('a', 'c'):
                # Inside a word ('a') or a punctuation run ('c').
                if not nextchar:
                    self.state = None  # end of input
                    break
                elif nextchar in self.whitespace:
                    self.state = ' '
                    if self.token or (self.posix and quoted):
                        break  # emit current token
                    else:
                        continue
                elif nextchar in self.commenters:
                    self.instream.readline()
                    self.lineno += 1
                    if self.posix:
                        self.state = ' '
                        if self.token or (self.posix and quoted):
                            break  # emit current token
                        else:
                            continue
                elif self.state == 'c':
                    if nextchar in self.punctuation_chars:
                        self.token += nextchar
                    else:
                        # End of the punctuation run: re-read this character
                        # on the next call unless it is whitespace.
                        if nextchar not in self.whitespace:
                            self._pushback_chars.append(nextchar)
                        self.state = ' '
                        break
                elif self.posix and nextchar in self.quotes:
                    self.state = nextchar
                elif self.posix and nextchar in self.escape:
                    escapedstate = 'a'
                    self.state = nextchar
                elif (nextchar in self.wordchars or nextchar in self.quotes
                      or (self.whitespace_split and
                          nextchar not in self.punctuation_chars)):
                    self.token += nextchar
                else:
                    # A character that cannot continue the word: push it back
                    # and emit the word collected so far.
                    if self.punctuation_chars:
                        self._pushback_chars.append(nextchar)
                    else:
                        self.pushback.appendleft(nextchar)
                    self.state = ' '
                    if self.token or (self.posix and quoted):
                        break  # emit current token
                    else:
                        continue
        result = self.token
        self.token = ''
        if self.posix and not quoted and result == '':
            result = None
        return result
def bytes2text(bs):
    """
    Convert a bytes object, or a non-empty list/tuple of chunks, to text.

    Bytes are decoded with undecodable bytes ignored; the result is stripped
    of surrounding whitespace.

    :param bs: ``bytes``, or a ``list``/``tuple`` whose first element decides
        whether the chunks are joined as bytes or as str.
    :return: The converted, stripped text; ``''`` for any other input,
        including empty sequences and ``None``.
    :raises TypeError: If a non-empty list/tuple starts with neither bytes
        nor str.
    """
    if isinstance(bs, bytes):
        return bs.decode(errors='ignore').strip()
    if type(bs) not in (list, tuple) or not bs:
        return ''
    head = bs[0]
    if isinstance(head, bytes):
        return b''.join(bs).decode(errors='ignore').strip()
    if isinstance(head, str):
        return ''.join(bs).strip()
    raise TypeError
class ExecutorFactory:
    """Build an executor that runs commands locally or over SSH.

    Configure it with the chainable ``set_*`` methods, then call
    :meth:`get_executor`.
    """

    # Host strings that always denote the local machine.
    _LOCAL_HOSTS = ('127.0.0.1', 'localhost', '::1')

    def __init__(self):
        """
        A factory class is used to produce executors.
        Here are two types of executors.
        One is implemented through Popen (generally used for local command execution)
        and the other is implemented through SSH (generally used for remote command execution).
        """
        self.host = None
        self.pwd = None
        self.port = 22
        self.me = getpass.getuser()  # OS user running this process
        self.user = self.me

    def set_host(self, host):
        """Set the target host; returns ``self`` for chaining."""
        self.host = host
        return self

    def set_user(self, user):
        """Set the login user; returns ``self`` for chaining."""
        self.user = user
        return self

    def set_pwd(self, pwd):
        """Set the login password; returns ``self`` for chaining."""
        self.pwd = pwd
        return self

    def set_port(self, port):
        """Set the SSH port; returns ``self`` for chaining."""
        self.port = port
        return self

    def get_executor(self):
        """Return an executor for the configured host and user.

        A ``LocalExec`` is returned when the host is local and the user is
        the current OS user; otherwise an ``SSH`` executor is created.

        :raises AssertionError: If an SSH executor is required but any of
            user, pwd, port or host is None.
        """
        if self._is_remote() or self.user != self.me:
            required = (('user', self.user), ('pwd', self.pwd),
                        ('port', self.port), ('host', self.host))
            missing = [name for name, value in required if value is None]
            if missing:
                raise AssertionError(
                    'Cannot create an SSH executor, missing: %s'
                    % ', '.join(missing))
            return SSH(host=self.host,
                       user=self.user,
                       pwd=self.pwd,
                       port=self.port)
        return LocalExec()

    def _is_remote(self):
        """Return True when ``self.host`` does not refer to this machine."""
        if not self.host:
            return False
        if self.host in self._LOCAL_HOSTS:
            return False
        hostname = socket.gethostname()
        if self.host == hostname:
            return False
        try:
            canonical, aliases, ip_address_list = socket.gethostbyname_ex(hostname)
        except OSError:
            # The machine's own hostname is not always resolvable (e.g. no
            # /etc/hosts entry). Without a list of local addresses, any other
            # host is treated as remote instead of crashing.
            return True
        local_names = {canonical, *aliases, *ip_address_list}
        return self.host not in local_names
class Executor:
    """Executor is an abstract class."""

    class Wrapper:
        """Inner abstract helper for asynchronous execution."""

        def __init__(self, stream):
            # Remember the stream that concrete wrappers read from.
            self.stream = stream

        def read(self):
            """Placeholder for subclasses; the base version returns None."""
            return None

    def exec_command_sync(self, command, *args, **kwargs):
        """Run ``command`` synchronously.

        Placeholder for subclasses; the base version ignores its arguments
        and returns None.
        """
        return None
class SSH(Executor):
    """Run commands on a remote host over a paramiko SSH connection."""

    def __init__(self, host, user, pwd, port=22, max_retry_times=5):
        """
        Use the paramiko library to establish an SSH connection with the remote server.
        You can run one or more commands.
        In addition, the `gsql` password information is not exposed.
        :param host: String type.
        :param user: String type.
        :param pwd: String type.
        :param port: Int type.
        :param max_retry_times: Int type. Maximum number of retries if the connection fails.
        """
        check_ssh_version()
        self.host = host
        self.user = user
        self.pwd = pwd
        self.port = port
        self.max_retry_times = max_retry_times
        self.retry_cnt = 0
        self.client = SSH._connect_ssh(host, user, pwd, port)
        # Exit status of the last command, stored per thread.
        self._exit_status = threading.local()
        self._exit_status.value = 0

    @staticmethod
    def _connect_ssh(host, user, pwd, port):
        """Open and return a connected ``paramiko.SSHClient``.

        Unknown host keys are accepted automatically (``AutoAddPolicy``).
        """
        client = paramiko.SSHClient()
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        client.connect(host, port, user, pwd)
        return client

    @property
    def exit_status(self):
        """Exit status of the calling thread's last command (0 if none)."""
        # threading.local attributes exist only in the thread that set them,
        # so fall back to 0 instead of raising AttributeError in others.
        return getattr(self._exit_status, 'value', 0)

    def _exec_in_shell(self, commands, timeout=None):
        """Send ``commands`` line by line to an interactive shell channel.

        An ``exit $?`` line is appended so the channel reports the status of
        the last command. Output is polled until the exit status arrives or
        ``timeout`` seconds pass.

        :return: ``(stdout_text, stderr_text)``.
        """
        chan = self.client.get_transport().open_session()
        chan.invoke_shell()
        buff_size = 32768
        stdout = []
        stderr = []
        for line in list(commands) + ['exit $?']:
            chan.send(line + '\n')
            while not chan.send_ready():
                time.sleep(0.1)
        start_time = time.monotonic()
        while not chan.exit_status_ready():
            if chan.recv_ready():
                stdout.append(chan.recv(buff_size))
            if chan.recv_stderr_ready():
                stderr.append(chan.recv_stderr(buff_size))
            if timeout and (time.monotonic() - start_time) > timeout:
                break
            time.sleep(0.1)
        # Output may still be buffered when the exit status arrives (or the
        # timeout fires); drain it so the tail of the output is not lost.
        while chan.recv_ready():
            stdout.append(chan.recv(buff_size))
        while chan.recv_stderr_ready():
            stderr.append(chan.recv_stderr(buff_size))
        chan.close()
        self._exit_status.value = chan.recv_exit_status()
        return bytes2text(stdout), bytes2text(stderr)

    def _exec_single(self, command, blocking_fd, exec_kwargs):
        """Run one command string with ``SSHClient.exec_command``.

        Common binary directories are appended to PATH first. The call waits
        until the channel behind stream ``blocking_fd`` reports an exit
        status.

        :return: ``(stdout_text, stderr_text)``.
        """
        bin_paths = '/usr/local/bin:/bin:/usr/bin:/usr/local/sbin:/usr/sbin'
        path_prefix = 'PATH=$PATH:%s && ' % bin_paths
        streams = self.client.exec_command(command=path_prefix + command,
                                           **exec_kwargs)
        channel = streams[blocking_fd].channel
        while not channel.exit_status_ready():
            time.sleep(0.1)
        self._exit_status.value = channel.recv_exit_status()
        return (bytes2text(streams[n_stdout].read()),
                bytes2text(streams[n_stderr].read()))

    def _exec_command(self, command, **kwargs):
        """Run ``command`` and return ``(stdout_text, stderr_text)``.

        A list/tuple is executed line by line in an interactive shell; a
        string is run with ``exec_command``. On ``paramiko.SSHException`` the
        connection is re-established and the same call (same command and
        keyword arguments) is retried up to ``max_retry_times`` times.

        :raises ConnectionError: When the retries are exhausted.
        """
        if self.client is None:
            self.client = SSH._connect_ssh(self.host, self.user,
                                           self.pwd, self.port)
        try:
            if isinstance(command, (list, tuple)):
                result_tup = self._exec_in_shell(command,
                                                 kwargs.get('timeout', None))
            else:
                # Work on a copy so a retry still sees 'fd' and the other
                # keyword arguments.
                exec_kwargs = dict(kwargs)
                blocking_fd = exec_kwargs.pop('fd', n_stdout)
                result_tup = self._exec_single(command, blocking_fd, exec_kwargs)
        except paramiko.SSHException as e:
            self.client.close()
            self.client = SSH._connect_ssh(self.host, self.user,
                                           self.pwd, self.port)
            if self.retry_cnt >= self.max_retry_times:
                # Reset so the next call gets a full set of retries again.
                self.retry_cnt = 0
                raise ConnectionError("Can not connect to remote host.") from e
            logging.warning("SSH: %s, so try to reconnect.", e)
            self.retry_cnt += 1
            return self._exec_command(command, **kwargs)
        self.retry_cnt = 0
        return result_tup

    def exec_command_sync(self, command, *args, **kwargs):
        """
        You can run one or more commands.
        :param command: Type: tuple, list or string.
        :param kwargs: blocking_fd means blocking and waiting for which standard streams.
            Other keyword arguments (e.g. timeout) are forwarded.
        :return: Execution result as ``(stdout_text, stderr_text)``.
        :raises ValueError: If blocking_fd is not 0, 1 or 2.
        """
        blocking_fd = kwargs.pop('blocking_fd', n_stdout)
        if not isinstance(blocking_fd, int) or blocking_fd > n_stderr or blocking_fd < n_stdin:
            raise ValueError('blocking_fd must be 0, 1 or 2')
        return self._exec_command(command, fd=blocking_fd, **kwargs)

    def close(self):
        """Close the SSH connection, if open."""
        if self.client:
            self.client.close()
            self.client = None
class LocalExec(Executor):
    # Per-thread return code of the most recent command. It is created once
    # at class definition, so it exists before any instance is built. A new
    # instance therefore does not discard values stored by other threads.
    _exit_status = threading.local()

    def __init__(self):
        """
        Use the subprocess.Popen library to open a pipe.
        You can run one or more commands.
        In addition, the `gsql` password information is not exposed.
        """
        LocalExec._exit_status.value = 0

    @staticmethod
    def _communicate(proc, timeout):
        """Wait for ``proc``; on timeout kill and reap it, then re-raise."""
        try:
            return proc.communicate(timeout=timeout)
        except subprocess.TimeoutExpired:
            # communicate() does not kill the child on timeout; do it here
            # so the process does not leak.
            proc.kill()
            proc.communicate()
            raise

    @staticmethod
    def exec_command_sync(command, *args, **kwargs):
        """Run ``command`` locally without a shell.

        :param command: A list/tuple of command lines run one after another
            (a ``cd <dir>`` line sets the working directory of the following
            lines), or a single command line string handled by
            ``multiple_cmd_exec``.
        :param kwargs: ``timeout`` (seconds) is honoured; others are ignored.
        :return: ``[stdout_text, stderr_text]``.
        :raises subprocess.TimeoutExpired: If a command exceeds ``timeout``.
        """
        timeout = kwargs.get('timeout', None)
        if isinstance(command, (list, tuple)):
            stdout = []
            stderr = []
            cwd = None
            for line in command:
                stripped = line.strip()
                if stripped.startswith('cd '):
                    # shlex.split drops the blank after "cd" and any quotes
                    # around the path; relative paths build on the last cd.
                    target = os.path.expanduser(shlex.split(stripped)[1])
                    cwd = os.path.join(cwd, target) if cwd else target
                    continue
                proc = subprocess.Popen(shlex.split(line),
                                        stdout=subprocess.PIPE,
                                        stderr=subprocess.PIPE,
                                        shell=False,
                                        cwd=cwd)
                outs, errs = LocalExec._communicate(proc, timeout)
                LocalExec._exit_status.value = proc.returncode
                if outs:
                    stdout.append(outs)
                if errs:
                    stderr.append(errs)
            return [bytes2text(stdout), bytes2text(stderr)]
        returncode, outs, errs = multiple_cmd_exec(command, timeout=timeout)
        LocalExec._exit_status.value = returncode
        return [bytes2text(stream) for stream in (outs, errs)]

    @property
    def exit_status(self):
        """Return code of the calling thread's last command (0 if none)."""
        return getattr(LocalExec._exit_status, 'value', 0)
def dequote(text):
    """Strip quotation marks from a single shell word.

    :param text: A word such as ``'"a b"'`` or ``"'x'"``.
    :return: The first word produced by POSIX shell-style splitting, or
        ``''`` when ``text`` is empty or whitespace only.
    :raises ValueError: On unbalanced quotes (from ``shlex.split``).
    """
    words = shlex.split(text)
    # shlex.split returns [] for empty/blank input; avoid an IndexError.
    return words[0] if words else ''
def to_cmds(cmdline):
    """Split a command line into word lists at unescaped separators.

    The recognised separators are ``|``, ``||``, ``&&`` and ``;``.

    :param cmdline: The command line string.
    :return: ``(cmds, require_stdin)``. ``cmds`` is a list of word lists,
        unquoted by ``dequote``, or by ``shlex.split`` when no separators
        remain. ``require_stdin[i]`` is True when command ``i`` followed a
        ``|``; the first entry is always False.
    """
    separators = {'|', '||', '&&', ';'}
    escaped = '\\'

    def get_separators(s):
        """Tokenize ``s`` and locate separator tokens.

        :return: ``(tokens, separator_indexes)`` where each index points into
            ``tokens`` at an unescaped separator.
        """
        # shlex_py38 works around Python < 3.8 mishandling
        # punctuation_chars together with whitespace_split.
        if sys.version_info < (3, 8):
            lex = shlex_py38(s, punctuation_chars=True)
        else:
            lex = shlex.shlex(s, punctuation_chars=True)
        lex.whitespace_split = True
        tokens = list(lex)
        real_tokens = []
        separator_indexes = []
        escape_count = 0
        for token in tokens:
            # A bare backslash token escapes the token that follows it.
            if token == escaped:
                escape_count += 1
                continue
            real_tokens.append(token)
            if token in separators:
                if escape_count == 0 and len(real_tokens) > 0:
                    separator_indexes.append(len(real_tokens) - 1)
            if escape_count > 0:
                escape_count -= 1
        # Backslashes with nothing left to escape are kept as literal tokens.
        for _ in range(escape_count):
            real_tokens.append(escaped)
        # Drop trailing separators (e.g. "ls ;") together with their tokens.
        while separator_indexes and separator_indexes[-1] == len(real_tokens) - 1:
            separator_indexes.pop()
            real_tokens.pop()
        # Without separators, re-split the whole string in POSIX mode.
        # NOTE(review): this re-split also brings back any trailing
        # separator removed above, e.g. "ls ;" -> ['ls', ';'] and
        # "echo hi;" -> ['echo', 'hi;']; confirm whether that is intended.
        if len(separator_indexes) == 0:
            real_tokens = shlex.split(s)
        return real_tokens, separator_indexes

    cmd_words, seps = get_separators(cmdline)
    if len(seps) == 0:
        return [cmd_words], [False]
    cmds = []
    require_stdin = [False]  # the first command never reads from a pipe
    cmd_start = 0
    while seps:
        sep_index = seps.pop(0)
        cmds.append(list(map(dequote, cmd_words[cmd_start:sep_index])))
        # The command after this separator reads stdin only after a pipe.
        require_stdin.append(cmd_words[sep_index] == '|')
        cmd_start = sep_index + 1
    last_one = list(map(dequote, cmd_words[cmd_start:]))
    if len(last_one) > 0:
        cmds.append(last_one)
    # Keep one flag per command actually produced.
    return cmds, require_stdin[:len(cmds)]
def multiple_cmd_exec(cmdline, **communicate_kwargs):
    """This function only returns the execution result of the last
    command. And only support the basic scenarios.
    Notice: this function is only a simple wrapper of popen,
    which doesn't support complicated scenarios. e.g.,
    `echo $?`, `dirname $(pwd)`, `PATH=$PATH; echo $PATH`, `echo abc >&2`

    ``cd <dir>`` and ``export K=V`` are applied to the commands that follow
    them instead of being executed. All other commands are started without
    waiting for one another. A command after ``|`` reads the previous
    process's stdout.

    :param cmdline: The command line string.
    :param communicate_kwargs: Passed to ``Popen.communicate`` of the last
        process. A truthy ``input`` is instead written to the stdin of the
        first process started.
    :return: ``(returncode, stdout_bytes, stderr_bytes)`` of the last
        process started, or ``(0, b'', b'')`` if no process is started
        (e.g. the line only contains ``cd``/``export``).
    """
    cmds, require_stdin = to_cmds(cmdline)
    has_input = bool(communicate_kwargs.get('input'))
    process_list = []
    cwd = os.getcwd()
    env = os.environ.copy()
    try:
        for index, cmd in enumerate(cmds):
            if not cmd:
                continue
            # Expand the first "$NAME" word whose variable is set and
            # non-empty, unless it is the command itself or follows "\".
            dollar_index = -1
            for i, word in enumerate(cmd):
                if word.startswith('$') and word.count("$") == 1 and env.get(word[1:]):
                    dollar_index = i
                    break
            if 0 < dollar_index < len(cmd) and cmd[dollar_index - 1] != '\\':
                cmd[dollar_index] = env.get(cmd[dollar_index][1:], '')
            if cmd[0] == 'cd':
                # Like a shell: relative paths build on the current cwd and
                # a bare "cd" goes home.
                target = cmd[1] if len(cmd) > 1 else '~'
                cwd = os.path.join(cwd, os.path.expanduser(target))
                continue
            if cmd[0] == 'export' and len(cmd) > 1 and '=' in cmd[1]:
                k, v = cmd[1].split('=', 1)  # values may contain '='
                env[k] = v
                continue
            upstream = None
            if not process_list:
                stdin = subprocess.PIPE if has_input else None
            elif require_stdin[index]:
                upstream = process_list[-1].stdout
                stdin = upstream
            else:
                stdin = None
            _p = subprocess.Popen(
                cmd, stdin=stdin,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                shell=False,
                env=env,
                cwd=cwd
            )
            if upstream is not None:
                # Close the parent's copy of the pipe so the upstream process
                # receives SIGPIPE if this one exits early.
                upstream.close()
            process_list.append(_p)
            if has_input and len(process_list) == 1:
                _p.stdin.write(communicate_kwargs.pop('input'))
                _p.stdin.close()
        if not process_list:
            return 0, b'', b''
        last_process = process_list[-1]
        # communicate() would try to flush an already-closed stdin.
        if last_process.stdin and last_process.stdin.closed:
            last_process.stdin = None
        outs, errs = last_process.communicate(**communicate_kwargs)
    finally:
        # Also covers a Popen failure part-way through the pipeline.
        for _p in process_list:
            _p.terminate()
    return last_process.returncode, outs, errs