import json
import shlex
import socket
import stat
import argparse
import getpass
import logging
import logging.handlers
import platform
import shutil
import re
import os
import sys
import string
from subprocess import PIPE, Popen
from module_utils.path_manager import get_validated_env
from module_utils.safe_file_handler import SafeFileHandler
# Directory containing this module. SRC_PATH keeps pointing here; ROOT_PATH is
# re-pointed further down in this module when the path contains
# 'site-packages' or 'dist-packages'.
ROOT_PATH = SRC_PATH = os.path.dirname(__file__)
# Sentinel file location under ~/.local (how it is used is not visible in this part of the file).
NEXUS_SENTINEL_FILE = os.path.expanduser('~/.local/nexus.sentinel')
# Owner-only permission bits: rwx (0o700), rw (0o600) and r (0o400).
MODE_700 = stat.S_IRWXU
MODE_600 = stat.S_IRUSR | stat.S_IWUSR
MODE_400 = stat.S_IRUSR
LOG = logging.getLogger('ascend_deployer.utils')
# Column width used by run_cmd for its command banner and console progress lines.
MAX_LEN = 120
# Directories and top-level files that copy_scripts copies from SRC_PATH into ROOT_PATH.
dir_list = ['downloader', 'playbooks', 'tools', 'ansible_plugin', 'group_vars', 'patch', 'scripts', 'yamls',
            'library', 'module_utils', 'templates', 'large_scale_deploy']
file_list = ['install.sh', 'inventory_file', 'ansible.cfg',
             '__init__.py', 'start_deploy.py', 'jobs.py', 'utils.py',
             'version.json', 'large_scale_install.sh', 'large_scale_inventory.ini', 'large_scale_deployer.py']
# Splitting on a *capturing* group keeps the digit runs in the result, so a
# version string becomes alternating text / number fragments.
VERSION_PATTERN = re.compile(r"(\d+)")


def compare_version(src_version, target_version):
    """
    Compare two version strings fragment by fragment.

    Numeric fragments are compared as integers and all other fragments as
    plain strings.

    :param src_version: left-hand version string
    :param target_version: right-hand version string
    :return: 0 if equal, > 0 if src_version is greater, < 0 if it is smaller.
        For a differing numeric fragment the value is the integer difference;
        for a differing text fragment it is 1 or -1. If every shared fragment
        matches, the difference in fragment counts is returned.
    """

    def _fragment_delta(left, right):
        # Integer distance for two numbers, cmp()-style -1/0/1 otherwise.
        if left.isdigit() and right.isdigit():
            return int(left) - int(right)
        return (left > right) - (left < right)

    src_fragments = VERSION_PATTERN.split(src_version)
    dst_fragments = VERSION_PATTERN.split(target_version)
    deltas = (_fragment_delta(a, b) for a, b in zip(src_fragments, dst_fragments))
    for delta in deltas:
        if delta != 0:
            return delta
    return len(src_fragments) - len(dst_fragments)
def copy_scripts():
    """
    Populate ROOT_PATH with the deployer directories and files found in SRC_PATH.

    Does nothing when SRC_PATH and ROOT_PATH are the same directory. ROOT_PATH
    is created (mode 0o750) if missing. Every name in dir_list is copied as a
    tree and every name in file_list as a single file, but only when the source
    exists and the destination does not, so existing content is never replaced.
    """
    if SRC_PATH == ROOT_PATH:
        return
    if not os.path.exists(ROOT_PATH):
        os.makedirs(ROOT_PATH, mode=0o750)
    # Directories first, then single files -- each with its matching copy function.
    for copy_func, names in ((shutil.copytree, dir_list), (shutil.copy, file_list)):
        for entry in names:
            source = os.path.join(SRC_PATH, entry)
            target = os.path.join(ROOT_PATH, entry)
            if os.path.exists(source) and not os.path.exists(target):
                copy_func(source, target)
# When this module is loaded from an installed package (its directory lies under
# site-packages / dist-packages), point ROOT_PATH at '<deployer home>/ascend-deployer'
# instead and copy the deployer resources there. SRC_PATH still refers to the package.
if 'site-packages' in ROOT_PATH or 'dist-packages' in ROOT_PATH:
    # Default home: the current working directory (kept as-is on non-Linux systems).
    deployer_home = os.getcwd()
    if platform.system() == 'Linux':
        # Prefer ASCEND_DEPLOYER_HOME, falling back to HOME.
        # NOTE(review): if get_validated_env returns None for both variables,
        # os.path.join below raises TypeError -- confirm the helper's contract.
        deployer_home = get_validated_env('ASCEND_DEPLOYER_HOME') or get_validated_env('HOME')
    ROOT_PATH = os.path.join(deployer_home, 'ascend-deployer')
    copy_scripts()
class ValidChoices(argparse.Action):
    """
    argparse action that stores the supplied values with duplicates removed.

    The first occurrence of each value is kept and the command-line order is
    preserved, so the stored list is deterministic. De-duplicating through a
    set would yield a hash-dependent order that can differ between runs.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        seen = set()
        unique_values = []
        for value in values:
            if value not in seen:
                seen.add(value)
                unique_values.append(value)
        setattr(namespace, self.dest, unique_values)
class SkipCheck(argparse.Action):
    """
    argparse action that converts a textual flag into a bool.

    Stores True when the supplied value equals "true" (case-insensitive) and
    False for any other value.
    """

    def __call__(self, parser, namespace, value, option_string=None):
        enabled = value.lower() == "true"
        setattr(namespace, self.dest, enabled)
def pretty_format(text):
    """
    Split a 'label: a,b,c' help string into display lines.

    :param text: help text that must contain at least one ':'
    :return: [label-including-first-colon, item, item, ...] where the items are
        the remainder split on ','
    :raises ValueError: if text contains no ':'
    """
    cut = text.index(':') + 1
    return [text[:cut]] + text[cut:].split(',')
class HelpFormatter(argparse.HelpFormatter):
    """Help formatter that prints 'label: a,b,c' help texts one item per line."""

    def _split_lines(self, text, width):
        """
        Split help text into lines.

        Text without ':' is word-wrapped to width without breaking on hyphens;
        text containing ':' is laid out by pretty_format (width is ignored).
        """
        if ':' not in text:
            import textwrap
            return textwrap.wrap(text, width, break_on_hyphens=False)
        return pretty_format(text)
def args_with_comma(args):
    """
    Flatten 'key=a,b' and 'a,b' style arguments into separate list items.

    An argument with a single '=' after its first character is split into the
    key and its comma-separated values. An argument whose first '=' starts a
    '==' (e.g. 'pkg==1.0') keeps that part intact. Empty comma pieces are
    dropped.

    :param args: iterable of argument strings
    :return: flat list of argument pieces
    """
    flattened = []
    for raw in args:
        eq_pos = raw.find('=')
        # eq_pos == 0 means a leading '=', which is not treated as a key separator.
        if eq_pos > 0 and eq_pos != raw.find('=='):
            flattened.append(raw[:eq_pos])
            raw = raw[eq_pos + 1:]
        flattened.extend(piece for piece in raw.split(',') if piece)
    return flattened
def get_python_version_list():
origin_py_version_file = os.path.join(ROOT_PATH, 'downloader', 'python_version.json')
update_py_version = os.path.join(ROOT_PATH, 'downloader', 'obs_downloader_config', 'python_version.json')
python_version_json = update_py_version if os.path.exists(update_py_version) else origin_py_version_file
data = SafeFileHandler.safe_read_json(python_version_json)
available_python_list = [item['filename'].rstrip('.tar.xz') for item in data]
return available_python_list
def get_name_list(dir_path, prefix, suffix):
    """
    Collect the '<name>' part of entries in dir_path named '<prefix><name><suffix>'.

    :param dir_path: directory to scan (not recursive; os.listdir raises if it is missing)
    :param prefix: required leading text of an entry name, removed from the result
    :param suffix: required trailing text of an entry name, removed from the result
    :return: sorted list of the extracted names
    """
    items = []
    for file_name in os.listdir(dir_path):
        if file_name.startswith(prefix) and file_name.endswith(suffix):
            # Slice off exactly the matched prefix/suffix; str.replace() would
            # also delete further occurrences inside the name itself.
            items.append(file_name[len(prefix):len(file_name) - len(suffix)])
    return sorted(items)
# Components for which get_hosts_name returns 'master,worker' instead of 'worker'.
dl_items = ['ascend-device-plugin', 'ascend-docker-runtime', 'ascend-operator', 'noded',
            'npu-exporter', 'resilience-controller', 'volcano', 'clusterd', 'dl', 'deepseek_pd']
# Names discovered from the playbooks shipped under ROOT_PATH/playbooks at import
# time: the '<name>' part of 'install_<name>.yml', 'scene_<name>.yml' and
# 'upgrade_<name>.yml' (see get_name_list). os.listdir raises if a directory is missing.
install_items = get_name_list(os.path.join(ROOT_PATH, "playbooks", "install"), 'install_', '.yml')
scene_items = get_name_list(os.path.join(ROOT_PATH, "playbooks", "scene"), 'scene_', '.yml')
patch_items = get_name_list(os.path.join(ROOT_PATH, "playbooks", "install", "patch"), "install_", ".yml")
upgrade_items = get_name_list(os.path.join(ROOT_PATH, "playbooks", "install", "upgrade"), "upgrade_", ".yml")
# Component names for testing and check modes; presumably the allowed choices of the
# corresponding command-line options (their use is not shown in this chunk).
test_items = ['all', 'firmware', 'driver', 'nnrt', 'nnae', 'toolkit', 'toolbox', 'mindspore', 'pytorch',
              'tensorflow', 'fault-diag', 'ascend-docker-runtime', 'ascend-device-plugin', 'volcano',
              'noded', 'clusterd', 'ascend-operator', 'npu-exporter', 'resilience-controller',
              'mindie_image', 'mcu', "ubengine"]
check_items = ['full', 'fast']
# Ansible stdout callback plugin names (collection-qualified where applicable).
# NOTE(review): presumably the accepted values for selecting Ansible's
# stdout_callback -- the consumer of this list is not visible in this chunk.
stdout_callbacks = [
    "ansible.posix.cgroup_perf_recap",
    "ansible.posix.debug",
    "ansible.posix.json",
    "ansible.posix.jsonl",
    "ansible.posix.profile_roles",
    "ansible.posix.profile_tasks",
    "ansible.posix.skippy",
    "ansible.posix.timer",
    "ansible_log",
    "community.general.cgroup_memory_recap",
    "community.general.context_demo",
    "community.general.counter_enabled",
    "community.general.dense",
    "community.general.diy",
    "community.general.elastic",
    "community.general.hipchat",
    "community.general.jabber",
    "community.general.log_plays",
    "community.general.loganalytics",
    "community.general.logdna",
    "community.general.logentries",
    "community.general.logstash",
    "community.general.mail",
    "community.general.nrdp",
    "community.general.null",
    "community.general.opentelemetry",
    "community.general.say",
    "community.general.selective",
    "community.general.slack",
    "community.general.splunk",
    "community.general.sumologic",
    "community.general.syslog_json",
    "community.general.unixy",
    "community.general.yaml",
    "default",
    "deploy_info_output_plugin",
    "junit",
    "minimal",
    "oneline",
    "standard",
    "tree"
]
# Rotation settings used by the LOGGING_CONFIG file handlers:
# keep up to 5 backups and roll over once a log reaches 20 MiB.
LOG_MAX_BACKUP_COUNT = 5
LOG_MAX_SIZE = 20 * 1024 * 1024
# Detailed log and operation log, both stored directly under ROOT_PATH.
LOG_FILE = os.path.join(ROOT_PATH, 'install.log')
LOG_OPERATION_FILE = os.path.join(ROOT_PATH, 'install_operation.log')
class UserHostFilter(logging.Filter):
    """
    Logging filter that stamps every record with the local user and client host.

    Sets ``record.user`` (the user running this process) and ``record.host``
    (the first field of SSH_CLIENT, or 'localhost' when it is unavailable) so
    the 'extra' formatter can use %(user)s and %(host)s. Both values are
    computed once, when the class body is executed at import time.
    """
    user = getpass.getuser()
    # Characters allowed in SSH_CLIENT: letters, digits, and the punctuation
    # found in IPv4/IPv6 addresses and port lists.
    ssh_client_whitelist = string.digits + string.ascii_letters + '~-+_./ :[]'
    # SSH_CLIENT is normally "<client ip> <client port> <server port>"; keep the
    # first field. Fall back to 'localhost' when the helper returns a falsy value
    # (unset or, presumably, rejected) or a whitespace-only string -- the latter
    # previously made .split()[0] raise IndexError while importing this module.
    host = ((get_validated_env('SSH_CLIENT', whitelist=ssh_client_whitelist, check_symlink=False,
                               check_owner=False) or '').split() or ['localhost'])[0]

    def filter(self, record):
        """Attach user/host to the record; never drops a record."""
        record.user = self.user
        record.host = self.host
        return True
class RotatingFileHandler(logging.handlers.RotatingFileHandler):
    """
    RotatingFileHandler that also adjusts file permissions around a rollover.

    Referenced by LOGGING_CONFIG via the dotted path 'utils.RotatingFileHandler'.
    """

    def doRollover(self):
        """Rotate the log; the outgoing file is made read-only and the new active file 0o600."""
        try:
            # Make the current log owner-read-only before the base class renames it to a backup.
            os.chmod(self.baseFilename, 0o400)
        except OSError:
            # Fallback: make the oldest backup '<file>.<LOG_MAX_BACKUP_COUNT>' owner read/write.
            # Uses the module constant rather than self.backupCount; they agree for the
            # handlers defined in LOGGING_CONFIG.
            # NOTE(review): the intent of this fallback is not evident here, and this chmod
            # is itself unguarded -- if that backup does not exist, OSError escapes
            # doRollover (after the finally block below has still performed the rollover).
            os.chmod('{}.{}'.format(self.baseFilename, LOG_MAX_BACKUP_COUNT), 0o600)
        finally:
            # Always perform the standard rotation, then restrict the new active log file.
            logging.handlers.RotatingFileHandler.doRollover(self)
            os.chmod(self.baseFilename, 0o600)
# logging.config dictionary schema (version 1).
# NOTE(review): presumably applied with logging.config.dictConfig elsewhere -- not shown here.
LOGGING_CONFIG = {
    "version": 1,
    # Do not disable loggers that already exist when the configuration is applied.
    "disable_existing_loggers": False,
    "formatters": {
        # Timestamp, user@host (set on each record by UserHostFilter), level, source location, message.
        'extra': {
            'format': "%(asctime)s %(user)s@%(host)s [%(levelname)s] "
                      "[%(filename)s:%(lineno)d %(funcName)s] %(message)s"
        }
    },
    "filters": {
        # '()' names the factory used to build the filter.
        "user_host": {
            '()': UserHostFilter
        }
    },
    "handlers": {
        # Size-rotated detailed log written to LOG_FILE.
        # NOTE(review): the class is resolved by import path, so this module is assumed
        # to be importable as 'utils'.
        "install": {
            "level": "DEBUG",
            "formatter": "extra",
            "class": 'utils.RotatingFileHandler',
            "filename": LOG_FILE,
            'maxBytes': LOG_MAX_SIZE,
            'backupCount': LOG_MAX_BACKUP_COUNT,
            'encoding': "UTF-8",
            "filters": ["user_host"],
        },
        # Size-rotated INFO-and-above operation log written to LOG_OPERATION_FILE.
        "install_operation": {
            "level": "INFO",
            "formatter": "extra",
            "class": 'utils.RotatingFileHandler',
            "filename": LOG_OPERATION_FILE,
            'maxBytes': LOG_MAX_SIZE,
            'backupCount': LOG_MAX_BACKUP_COUNT,
            'encoding': "UTF-8",
            "filters": ["user_host"],
        },
    },
    "loggers": {
        # Parent of LOG ('ascend_deployer.utils'). The INFO logger level filters out DEBUG
        # records before they reach the DEBUG-level "install" handler.
        "ascend_deployer": {
            "handlers": ["install"],
            "level": "INFO",
            "propagate": True,
        },
        "install_operation": {
            "handlers": ["install_operation"],
            "level": "INFO",
            "propagate": True,
        },
    }
}
def run_cmd(args, oneline=False, **kwargs):
    """
    Run a command, log its output and raise if it fails.

    :param args: command string or argument sequence. A string is split with
        shlex (POSIX rules on Linux) unless ``shell=True`` is passed.
    :param oneline: when True, stream stdout line by line: each line is logged
        and echoed to the console as a fixed-width (MAX_LEN) progress line that
        overwrites the previous one via '\\r'.
    :param kwargs: forwarded to subprocess.Popen. ``stdout``, ``stderr`` and
        ``universal_newlines`` may be overridden; by default stderr is captured,
        stdout is captured only in oneline mode, and text mode is used.
    :return: list of stdout lines (empty when stdout was not captured)
    :raises Exception: on a non-zero exit status. The message is the captured
        stderr, unless it is empty or contains '[ASCEND][WARNING]'.
    """
    if not kwargs.get('shell') and isinstance(args, str):
        args = shlex.split(args, posix=platform.system() == 'Linux')
    cmd = args if isinstance(args, str) else ' '.join(args)
    LOG.info(cmd.center(MAX_LEN, '-'))
    stdout = kwargs.pop('stdout', PIPE if oneline else None)
    stderr = kwargs.pop('stderr', PIPE)
    text = kwargs.pop('universal_newlines', True)
    output = []
    process = Popen(args, stdout=stdout, stderr=stderr, universal_newlines=text, **kwargs)
    if oneline:
        err = None
        try:
            # readline() signals EOF with '' in text mode but b'' in binary mode;
            # a mismatched sentinel would make this loop spin forever.
            eof = '' if text else b''
            for line in iter(process.stdout.readline, eof):
                if not text:
                    line = line.decode('utf-8', 'replace')
                line = line.strip()
                output.append(line)
                LOG.info(line)
                # Every console line is exactly MAX_LEN wide so that, after '\r',
                # it completely overwrites the previous one.
                if len(line) <= MAX_LEN:
                    line += (MAX_LEN - len(line)) * " "
                else:
                    line = line[:MAX_LEN - 3] + "..."
                sys.stdout.write("\r{}".format(line))
            # NOTE: stderr is drained only after stdout reaches EOF; a child that
            # fills the stderr pipe buffer before closing stdout would block.
            if process.stderr is not None:
                err = process.stderr.read()
        finally:
            # Release the pipe file descriptors (communicate() does this in the
            # other branch).
            for pipe in (process.stdout, process.stderr):
                if pipe is not None:
                    pipe.close()
        process.wait()
    else:
        out, err = process.communicate()
        if isinstance(out, str):
            output = out.splitlines()
            for line in output:
                LOG.info(line)
    if process.returncode:
        if err and '[ASCEND][WARNING]' not in str(err):
            raise Exception(err)
        raise Exception("returned non-zero exit status {}".format(process.returncode))
    elif err and '[ASCEND][WARNING]' in str(err):
        print(err)
    return output
def install_pkg(name, *paths):
    """
    Install a bundled .deb/.rpm package unless ``name`` is already on PATH.

    :param name: executable whose presence on PATH means the package is installed
    :param paths: path components of the package below ROOT_PATH/resources. The
        '.deb' extension (when dpkg is available) or '.rpm' is appended if the
        path has neither.
    :return: output lines of the install command, or None when skipped
    :raises Exception: when not running as root, or when the install command fails
    """
    # shutil.which replaces distutils.spawn.find_executable; distutils was
    # removed from the standard library in Python 3.12.
    if shutil.which(name):
        LOG.info('{} is already installed, skip'.format(name))
        return
    if shutil.which('dpkg'):
        prefix_cmd = "dpkg --force-all -i"
        suffix_cmd = '.deb'
    else:
        prefix_cmd = "rpm -ivUh --force --nodeps --replacepkgs"
        suffix_cmd = '.rpm'
    pkg_path = os.path.join(ROOT_PATH, 'resources', *paths)
    if not pkg_path.endswith(('.deb', '.rpm')):
        pkg_path += suffix_cmd
    # The command runs through a shell, and ROOT_PATH may come from the user's
    # environment (HOME / ASCEND_DEPLOYER_HOME), so quote the path. shlex.quote
    # leaves paths made only of shell-safe characters unchanged.
    cmd = "{} {}".format(prefix_cmd, shlex.quote(pkg_path))
    if getpass.getuser() != 'root':
        raise Exception('no permission to run cmd: {}, please run command with root user firstly'.format(cmd))
    return run_cmd(cmd, oneline=True, shell=True)
def get_hosts_name(tags):
    """
    Choose the host pattern for the given tag(s).

    :param tags: a single tag string or a list of tags (other types are
        treated as matching nothing)
    :return: 'master,worker' if any tag is in dl_items, otherwise 'worker'
    """
    if isinstance(tags, str):
        needs_master = tags in dl_items
    elif isinstance(tags, list):
        needs_master = bool(set(tags) & set(dl_items))
    else:
        needs_master = False
    return 'master,worker' if needs_master else 'worker'
class Validator:
    """
    Helpers for validating address-like strings such as IP addresses.
    """

    @staticmethod
    def _is_valid_address(family, ip):
        """
        Return True if socket.inet_pton accepts ip for the given address family.

        Non-string input returns False instead of letting inet_pton raise TypeError.
        """
        if not isinstance(ip, str):
            return False
        try:
            socket.inet_pton(family, ip)
        except (socket.error, ValueError, AttributeError):
            return False
        return True

    @staticmethod
    def is_valid_ipv4(ip):
        """
        return True if the ip is ipv4 else False
        :param ip: the string of ip address
        :return: bool, true if ipv4 otherwise false
        """
        return Validator._is_valid_address(socket.AF_INET, ip)

    @staticmethod
    def is_valid_ipv6(ip):
        """
        return True if the ip is ipv6 else False
        :param ip: the string of ip address; non-string values return False
            (previously they raised TypeError, unlike is_valid_ipv4)
        :return: bool, true if ipv6 otherwise false
        """
        return Validator._is_valid_address(socket.AF_INET6, ip)

    def is_valid_ip(self, ip):
        """
        :param ip: the string of ip address, or "localhost" (case-insensitive)
        :return: bool: true if ip is "localhost", a valid IPv4 or a valid IPv6 address
        """
        if not isinstance(ip, str):
            return False
        if ip.lower() == "localhost":
            return True
        return self.is_valid_ipv4(ip) or self.is_valid_ipv6(ip)