import os
import re
import shutil
from utils import ROOT_PATH
class ConfigrationError(Exception):
    """Raised for invalid inventory content and for failures while handling inventory files.

    The (misspelled) class name is part of the module's interface and is kept as is.
    """
class Mark:
    """Separator strings used when parsing and rendering inventory lines."""

    # Empty string; used as the delimiter of the generated config.
    EMPTY = ""
    # Separates a host from its parameters.
    SPACE = " "
    # Separates a key from its value.
    EQUAL = "="
class StrTool:
    """String helpers: identifier normalisation and restricted expression evaluation."""

    _NON_WORD_PATTERN = re.compile(r"[^a-zA-Z0-9]")
    # Character whitelist an expression must match before it is evaluated.
    _FURMULA_PATTERN = r'^[\w\s\.\+\-\*\/\(\)\'"]+$'
    # Substrings that are rejected even when the whitelist matches.
    _EXCEPTION = ["()"]
    # Globals handed to eval(): no builtins, only int() and str().
    _SAFE_EVAL_SCOPE = {
        '__builtins__': None,
        'int': int,
        'str': str
    }

    @classmethod
    def to_py_field(cls, src_field):
        """Return ``src_field`` with every non-alphanumeric character replaced by ``_``."""
        return cls._NON_WORD_PATTERN.sub("_", src_field)

    @classmethod
    def safe_eval(cls, expr):
        """Evaluate a restricted expression and return the result as a string.

        Args:
            expr: str, expression made of literals, arithmetic operators and
                calls to ``int``/``str``.

        Returns:
            str, ``str()`` of the evaluated value.

        Raises:
            ValueError: if ``expr`` contains characters outside the whitelist,
                contains a rejected substring, or references a name/attribute
                that starts with an underscore.
            SyntaxError: if ``expr`` is not a valid Python expression.
        """
        import ast

        if not re.fullmatch(cls._FURMULA_PATTERN, expr):
            raise ValueError("unsafe expression: {}".format(expr))
        if any(forbidden in expr for forbidden in cls._EXCEPTION):
            raise ValueError("unsafe expression: {}".format(expr))

        # The character whitelist alone still lets through attribute chains such
        # as ''.__class__.__mro__, which can escape the restricted scope. Reject
        # every underscore-prefixed identifier; string literals are unaffected
        # because only Name/Attribute nodes are inspected.
        for node in ast.walk(ast.parse(expr.strip(), mode="eval")):
            if isinstance(node, ast.Attribute):
                identifier = node.attr
            elif isinstance(node, ast.Name):
                identifier = node.id
            else:
                continue
            if identifier.startswith("_"):
                raise ValueError("unsafe expression: {}".format(expr))

        # NOTE(review): this is still eval() on text taken from the inventory
        # file. The checks above narrow what can run but do not bound cost
        # (e.g. chained ``**``); treat the inventory file as trusted input only.
        return str(eval(expr, cls._SAFE_EVAL_SCOPE))
class Var:
    """One entry of an inventory variables section."""

    def __init__(self, key, value):
        """Store the entry.

        Args:
            key: str, variable name.
            value: str or None, variable value; None when the entry has no value.
        """
        self.key = key
        self.value = value

    def to_string(self):
        """Render the entry as ``key=value``, or as ``key`` alone when it has no value.

        The inventory is read with ``allow_no_value=True``, so a key written
        without a value arrives here as None; concatenating it would raise
        TypeError.
        """
        if self.value is None:
            return self.key
        return self.key + Mark.EQUAL + self.value
class Host:
    """A single host line: an address followed by optional parameters."""

    def __init__(self, ip, params):
        self.ip, self.params = ip, params

    def to_string(self):
        """Return ``"<ip> <params>"``, or just the ip when params is empty/None."""
        if not self.params:
            return self.ip
        return self.ip + Mark.SPACE + self.params
class InventoryFilePath:
    """File names of the inventory files (joined with ROOT_PATH by InventoryFile)."""

    # Generated inventory file.
    ParsedFilePath = "parsed_inventory_file"
    # Inventory file that is read as input.
    OldFilePath = "inventory_file"
class IPRange:
    """An inclusive ``start-end`` IP range that is expanded with a fixed stride."""

    def __init__(self, ip_range, step_len):
        """Store the range.

        Args:
            ip_range: str, ``"<start ip>-<end ip>"``.
            step_len: int, distance between two consecutive generated addresses.
        """
        self.ip_range = ip_range
        self.step_len = step_len

    def expand_ip_range(self):
        """Expand the range into a list of address strings.

        Addresses are generated from the start address in steps of
        ``step_len``; the end address is always the last element, even when
        the stride does not land on it.

        Returns:
            list of str.

        Raises:
            ConfigrationError: if the range is not ``start-end`` with two valid
                addresses of the same IP version, if start is not below end, or
                if ``step_len`` is not positive.
        """
        import ipaddress

        try:
            # The unpacking is inside the try so that a range with too few or
            # too many '-' is reported like any other malformed range.
            start_ip, end_ip = self.ip_range.split('-')
            start = ipaddress.ip_address(start_ip)
            end = ipaddress.ip_address(end_ip)
        except ValueError as e:
            raise ConfigrationError("Parse ip range {} failed: {}.".format(self.ip_range, str(e))) from e
        if start.version != end.version:
            # Comparing an IPv4 with an IPv6 address raises TypeError.
            raise ConfigrationError(
                "Parse ip range {} failed: {}.".format(self.ip_range, "start and end IP versions differ"))
        if start >= end:
            raise ConfigrationError("Start IP {} must be less than to end IP {}.".format(start, end))
        if self.step_len <= 0:
            raise ConfigrationError("step_len {} must bigger than 0.".format(self.step_len))

        # Step over the integer values instead of adding to an address object:
        # address + step raises once it passes the top of the address space,
        # which happened for ranges ending near the maximum address.
        address_cls = type(start)
        ip_list = [str(address_cls(number)) for number in range(int(start), int(end) + 1, self.step_len)]
        if ip_list[-1] != str(end):
            ip_list.append(str(end))
        return ip_list
class HostParams:
    """Parameters attached to an IP-range host line.

    Keeps the raw parameter string together with a token dictionary, and
    renders one host line per address, evaluating ``{...}`` templates found in
    parameter values.
    """

    _STEP_LEN_PARAM_KEY = "step_len"
    _INDEX_KEY = "index"
    _IP_KEY = "ip"
    # A ``{...}`` template inside a parameter value (non-greedy).
    _PARAM_PATTERN = re.compile(r"\{.+?}")
    # ``ip`` / ``index`` as whole words, so that e.g. ``zip`` is left alone.
    _PLACEHOLDER_PATTERN = re.compile(r"\b(?:{}|{})\b".format(_IP_KEY, _INDEX_KEY))
    # The ``step_len`` token itself (optionally with ``=value``), delimited by
    # whitespace or the string boundaries, plus the whitespace that follows it.
    _STEP_LEN_TOKEN_PATTERN = re.compile(
        r"(?<!\S){}(?:=\S*)?(?!\S)\s*".format(re.escape(_STEP_LEN_PARAM_KEY)))

    def __init__(self, params):
        """Store and tokenize the parameters.

        Args:
            params: str or None, everything after the host on an inventory
                line; None (a host without parameters) is treated as "".
        """
        self.params = "" if params is None else params
        self.params_dict = self._convert_to_dict(self.params)

    @staticmethod
    def _convert_to_dict(params):
        """Split a parameter string into a dict of whitespace-separated tokens.

        Args:
            params: str, example: 'ansible_ssh_user="root" ansible_ssh_pass="test1234" step_len=3 flag'
        Response:
            res: dict, example:
                {"ansible_ssh_user": '"root"', "ansible_ssh_pass": '"test1234"', "step_len": "3", "flag": True}
            Values are kept verbatim (quotes included, no type conversion);
            a token without "=" maps to True.
        """
        if not params:
            return {}
        res = {}
        for param in params.split():
            # Split on the first "=" only, so values containing "=" stay intact.
            key, separator, value = param.partition(Mark.EQUAL)
            res[key] = value if separator else True
        return res

    def get_step_len(self):
        """Return the ``step_len`` parameter as a positive int (default 1).

        Raises:
            ConfigrationError: if the value is not an int or is not positive.
        """
        raw_step_len = self.params_dict.get(self._STEP_LEN_PARAM_KEY, 1)
        try:
            step_len = int(raw_step_len)
        except (ValueError, TypeError) as e:
            raise ConfigrationError("step_len {} must be int.".format(raw_step_len)) from e
        if step_len <= 0:
            raise ConfigrationError("step_len {} must bigger than 0.".format(step_len))
        return step_len

    def remove_step_len(self):
        """Drop the ``step_len`` token from both the raw string and the dict.

        Only the exact ``step_len`` token is removed; other tokens that merely
        contain that text (e.g. ``my_step_len=2``) are left untouched.
        """
        if self._STEP_LEN_PARAM_KEY not in self.params_dict:
            return
        self.params = self._STEP_LEN_TOKEN_PATTERN.sub("", self.params).strip()
        self.params_dict.pop(self._STEP_LEN_PARAM_KEY, None)

    def generate_new_params_str_list(self, ips):
        """Render one host line per address.

        ``step_len`` is removed first. Every ``{...}`` template found in a
        parameter value is evaluated per host with ``ip`` replaced by the
        quoted address and ``index`` by the 1-based position of the host.

        Args:
            ips: iterable of str, the expanded addresses.

        Returns:
            list of str, one ``"<ip> <params>"`` line per address.
        """
        self.remove_step_len()
        # The templates do not depend on the host, so collect them once.
        templates = [
            template
            for value in self.params_dict.values() if isinstance(value, str)
            for template in self._PARAM_PATTERN.findall(value)
        ]
        new_params_str_list = []
        for index, ip in enumerate(ips, start=1):
            substitutions = {self._IP_KEY: repr(ip), self._INDEX_KEY: repr(index)}
            new_params_str = self.params
            for template in templates:
                # template[1:-1] strips the surrounding braces.
                expression = self._PLACEHOLDER_PATTERN.sub(
                    lambda match: substitutions[match.group(0)], template[1:-1])
                new_params_str = new_params_str.replace(template, StrTool.safe_eval(expression))
            new_params_str_list.append(Host(ip, new_params_str).to_string())
        return new_params_str_list
class InventoryFile:
    """
    Convert the human-friendly inventory_file into an ansible-friendly one:
    read the old inventory file, parse it and generate a new inventory file.
    An example:
    The user can write a host range like this:
    10.10.10.1-10.10.10.3 ansible_ssh_user="root" ansible_ssh_pass="test1234"
    and this class expands it to:
    10.10.10.1 ansible_ssh_user="root" ansible_ssh_pass="test1234"
    10.10.10.2 ansible_ssh_user="root" ansible_ssh_pass="test1234"
    10.10.10.3 ansible_ssh_user="root" ansible_ssh_pass="test1234"
    """

    _MASTER_SEC = "master"
    _WORKER_SEC = "worker"
    _NPU_NODE_SEC = "npu_node"
    _APPLY_NODE_SEC = "apply"
    _ALL_VARS_SEC = "all:vars"
    _HCCN_SEC = "hccn"
    _HCCN_VARS = "hccn:vars"
    _OTHER_BUILD_IMAGE_SEC = "other_build_image"
    # Sections whose entries are host lines.
    _HOST_SECS = [_MASTER_SEC, _WORKER_SEC, _APPLY_NODE_SEC, _HCCN_SEC, _OTHER_BUILD_IMAGE_SEC, _NPU_NODE_SEC]
    # Sections whose entries are key=value variables.
    _VAR_SECS = [_HCCN_VARS, _ALL_VARS_SEC]
    # Only these sections are copied into the generated file.
    _ALL_SECS = _HOST_SECS + _VAR_SECS
    # At least one of these sections must be present (see parse()).
    _REQUIRED_SECS = [_MASTER_SEC, _WORKER_SEC]
    OLD_FILE_PATH = os.path.join(ROOT_PATH, InventoryFilePath.OldFilePath)
    PARSED_FILE_PATH = os.path.join(ROOT_PATH, InventoryFilePath.ParsedFilePath)
    config = None
    is_python2 = False
    is_parsed = False

    def __init__(self):
        """Read the old inventory file.

        When the ``configparser`` module cannot be imported the instance is
        flagged as python2 and nothing is read; parse() then only copies the
        file.

        Raises:
            ConfigrationError: if the inventory file cannot be read or parsed.
        """
        try:
            import configparser
        except ImportError:
            self.is_python2 = True
            return
        # Space and "=" both act as delimiters, so the first word of a host
        # line becomes the option name and the remainder its value.
        self.config = configparser.ConfigParser(
            delimiters=(Mark.SPACE, Mark.EQUAL), allow_no_value=True, interpolation=None)
        self.config.optionxform = str  # keep option names case-sensitive
        # The generated file stores each complete line as a value-less option.
        self.new_config = configparser.ConfigParser(delimiters=[Mark.EMPTY], allow_no_value=True)
        self.new_config.optionxform = str
        try:
            self.config.read(self.OLD_FILE_PATH)
        except (IOError, OSError, configparser.Error) as e:
            # configparser.Error is the module-level base exception; the
            # ConfigParser class has no ``Error`` attribute, so naming it there
            # turned every parse failure into an AttributeError.
            raise ConfigrationError(str(e)) from e

    def _copy(self):
        """Copy the old inventory file to the parsed path unchanged."""
        try:
            shutil.copy(self.OLD_FILE_PATH, self.PARSED_FILE_PATH)
        except OSError as e:
            raise ConfigrationError(
                "Copy {} to {} failed: {}".format(self.OLD_FILE_PATH, self.PARSED_FILE_PATH, str(e)))

    def _parse_hosts(self):
        """Return ``{normalised section name: [host line, ...]}`` for the host sections present.

        An entry whose name contains "-" is treated as an IP range and
        expanded into one line per address; any other entry is rendered as is.
        """
        res = {}
        for host_sec in self._HOST_SECS:
            if not self.config.has_section(host_sec):
                continue
            hosts = []
            for host_item in self.config.items(host_sec):
                if "-" in host_item[0]:
                    host_params = HostParams(host_item[1])
                    ip_range = host_item[0]
                    ips = IPRange(ip_range, host_params.get_step_len()).expand_ip_range()
                    hosts.extend(host_params.generate_new_params_str_list(ips))
                else:
                    hosts.append(Host(*host_item).to_string())
            res[StrTool.to_py_field(host_sec)] = hosts
        return res

    def _parse_vars(self):
        """Return ``{normalised section name: [variable line, ...]}`` for the variable sections present."""
        res = {}
        for var_sec in self._VAR_SECS:
            if not self.config.has_section(var_sec):
                continue
            res[StrTool.to_py_field(var_sec)] = [Var(*var).to_string() for var in self.config.items(var_sec)]
        return res

    def _generate_parsed_inventory_file(self, sec_dict):
        """Write the known sections, in their original order, to the parsed file.

        Args:
            sec_dict: dict, normalised section name -> list of rendered lines.

        Raises:
            ConfigrationError: if the parsed file cannot be written.
        """
        for section in self.config.sections():
            if section not in self._ALL_SECS:
                continue
            key = StrTool.to_py_field(section)
            self.new_config.add_section(section)
            for value in sec_dict.get(key, []):
                self.new_config.set(section, value)
        try:
            with open(self.PARSED_FILE_PATH, "w") as f:
                self.new_config.write(f)
        except (IOError, OSError) as e:
            raise ConfigrationError(
                "Failed to write parsed inventory file {}".format(self.PARSED_FILE_PATH)) from e

    def parse(self):
        """Generate the parsed inventory file once; later calls are no-ops.

        Raises:
            ConfigrationError: if neither a master nor a worker section exists,
                or if parsing, copying or writing fails.
        """
        if self.is_parsed:
            return
        if self.is_python2:
            self._copy()
        else:
            if not any(self.config.has_section(sec) for sec in self._REQUIRED_SECS):
                raise ConfigrationError("Either a worker group or a master group of host nodes must exist!")
            sec_dict = self._parse_hosts()
            var_sec_dict = self._parse_vars()
            sec_dict.update(var_sec_dict)
            self._generate_parsed_inventory_file(sec_dict)
        self.is_parsed = True

    def get_parsed_inventory_file_path(self):
        """Return the path of the generated inventory file."""
        return self.PARSED_FILE_PATH
# Module-level shared instance. Constructing it runs InventoryFile.__init__,
# which reads the old inventory file, so importing this module touches the
# filesystem and can raise ConfigrationError.
inventory_file = InventoryFile()