import os
import platform
import json
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils import common_info, path_manager, common_utils, compatibility_config
from ansible.module_utils.safe_file_handler import SafeFileHandler
class LabelNode:
    """Ansible module logic that collects, stores and applies Kubernetes node labels.

    The module is driven by the ``step`` parameter, dispatched in ``run``:

    * ``get_label``  -- compute the labels of the current node and return them
      as the ``node_label`` Ansible fact.
    * ``save_label`` -- serialize the labels passed in ``nodes_label`` into
      ``label_node.json`` as a list of ``Node`` manifests.
    * ``label``      -- merge every group's ``label_node.json`` into
      ``label_nodes.yaml`` and apply it with ``kubectl apply -f``.
    """

    # Labels added to a node when the ``master_node`` parameter is true.
    common_master_labels = {'masterselector': 'dls-master-node'}
    # Labels added to a node that is a worker (or that has no role at all).
    common_worker_labels = {
        'node-role.kubernetes.io/worker': 'worker',
        'workerselector': 'dls-worker-node'
    }

    def __init__(self):
        """Parse the module arguments and make sure the label work directory exists."""
        self.module = AnsibleModule(
            argument_spec=dict(
                step=dict(type='str', required=True),
                ansible_run_tags=dict(type='list'),
                node_name=dict(type='str'),
                master_node=dict(type='bool'),
                worker_node=dict(type='bool'),
                nodes_label=dict(type='dict'),
                group_count=dict(type='int'),
                noded_label=dict(type='str')
            )
        )
        params = self.module.params
        self.step = params['step']
        self.tags = params['ansible_run_tags']
        self.node_name = params['node_name']
        self.master_node = params['master_node']
        self.worker_node = params['worker_node']
        self.nodes_label = params['nodes_label']
        self.sub_groups = params['group_count']
        self.noded_label = params['noded_label']
        self.arch = platform.machine()
        self.facts = dict()
        self.label_yaml_dir = os.path.join(path_manager.TmpPath.DL_YAML_DIR, "label")
        # exist_ok avoids the check-then-create race when several module
        # invocations start on the same host at the same time.
        os.makedirs(self.label_yaml_dir, mode=0o750, exist_ok=True)

    def _get_noded_label(self):
        """Return the NodeD enable label, or an empty dict when it is not requested.

        The label is requested when ``noded_label`` is ``"on"`` or when the
        run tags contain ``noded`` or ``dl``.
        """
        # ``ansible_run_tags`` is optional, so it can be None; treat that as "no tags".
        tags = self.tags or []
        if self.noded_label == "on" or "noded" in tags or "dl" in tags:
            return {'nodeDEnable': 'on'}
        return {}

    def _get_device_plugin_label(self):
        """Build the hardware-related labels of the local node.

        Sources, in order:

        * ``platform.machine()``  -> ``host-arch``
        * ``lspci``               -> ``accelerator``
        * ``npu-smi info -m`` and ``npu-smi info -t board`` -> ``accelerator-type``
        * ``common_info.get_npu_info()`` -> ``server-usage``

        Commands whose binary is missing yield no lines (see
        ``iter_cmd_output``), so the matching labels are simply left out.

        :return: dict mapping label name to label value
        """
        labels = {}
        if self.arch == 'x86_64':
            labels['host-arch'] = 'huawei-x86'
        else:
            labels['host-arch'] = 'huawei-arm'

        # Detect the accelerator model from the PCI device id; a later
        # matching line overrides an earlier one.
        for line in self.iter_cmd_output('lspci'):
            if 'Processing accelerators' not in line:
                continue
            if 'Device d100' in line:
                labels['accelerator'] = 'huawei-Ascend310'
            if 'Device d500' in line:
                labels['accelerator'] = 'huawei-Ascend310P'
            if any(f'Device d{num}' in line for num in (801, 802, 803, 806)):
                labels['accelerator'] = 'huawei-Ascend910'

        # Count the Ascend entries and remember the first two columns of the
        # first entry; they are passed to ``npu-smi`` as ``-i`` and ``-c`` below.
        card_nums = 0
        npu_id = '0'
        chip_id = '0'
        for line in self.iter_cmd_output('npu-smi info -m'):
            fields = line.split(None, 2)
            if 'Ascend' in line and len(fields) == 3:
                card_nums += 1
                if card_nums == 1:
                    npu_id, chip_id = fields[0], fields[1]

        # The first "Board...: value" line gives the board id (lower-cased).
        board_id = ''
        for line in self.iter_cmd_output('npu-smi info -t board -i {} -c {}'.format(npu_id, chip_id)):
            if 'Board' in line and ':' in line:
                board_id = line.strip().split(':')[1].strip().lower()
                break

        if board_id in common_info.Atlas_800:
            # Only the 8-entry and 4-entry layouts get a label here.
            if card_nums == 8:
                labels['accelerator-type'] = 'module'
            elif card_nums == 4:
                labels['accelerator-type'] = 'half'
        elif board_id in common_info.Atlas_800_A2 + common_info.Atlas_900_A2_PoD:
            labels['accelerator-type'] = 'module-910b-8'
        elif board_id in common_info.Atlas_200T_A2_Box16:
            labels['accelerator-type'] = 'module-910b-16'
        elif board_id in common_info.Atlas_300T:
            labels['accelerator-type'] = 'card'
        elif board_id in common_info.Atlas_300T_A2:
            labels['accelerator-type'] = 'card-910b-2'
        elif board_id in common_info.Atlas_200T_A3_Box8 + common_info.Atlas_800I_A3:
            labels['accelerator-type'] = 'module-a3-16'
        elif board_id in common_info.Atlas_900_A3_PoD:
            labels['accelerator-type'] = 'module-a3-16-super-pod'
        elif board_id in common_info.Atlas_350_A5:
            labels['accelerator-type'] = '300I-A5-8'

        npu_info = common_info.get_npu_info()
        if npu_info.get('card') == compatibility_config.Hardware.ATLAS_800I_A2:
            labels['server-usage'] = 'infer'
        return labels

    def iter_cmd_output(self, cmd):
        """Yield the stdout lines of ``cmd``.

        Nothing is yielded when the executable (first word of ``cmd``) cannot
        be found or when the command prints nothing. The return code and
        stderr are ignored.

        :param cmd: command line as a single string
        """
        if not self.module.get_bin_path(cmd.split()[0]):
            return
        _rc, out, _err = self.module.run_command(cmd)
        if out:
            for line in out.splitlines():
                yield line

    def get_labels(self):
        """Compute the labels of this node and exit with them as Ansible facts.

        A master node gets ``common_master_labels``. A worker node, or a node
        with neither role set, gets ``common_worker_labels`` plus the device
        plugin and NodeD labels. The result is published as
        ``ansible_facts['node_label'] = {node_name: labels}``.
        """
        node_label = dict()
        if self.master_node:
            node_label.update(self.common_master_labels)
        # ``not node_label`` is only true when the node is not a master, so
        # this covers both "explicit worker" and "no role given".
        if self.worker_node or not node_label:
            node_label.update(self.common_worker_labels)
            node_label.update(self._get_device_plugin_label())
            node_label.update(self._get_noded_label())
        self.facts['node_label'] = {self.node_name: node_label}
        self.module.exit_json(changed=True, msg='{} successfully'.format(self.step), ansible_facts=self.facts)

    def save_labels(self):
        """Serialize the collected labels into ``label_node.json``.

        Each entry of ``nodes_label`` (node name -> labels) becomes a ``v1``
        ``Node`` manifest; the list of manifests is written as JSON through
        ``SafeFileHandler.safe_write``. The module fails when writing raises.
        """
        label_yaml = os.path.join(self.label_yaml_dir, "label_node.json")
        # ``nodes_label`` is optional, so it can be None; write an empty list then.
        nodes = [
            {
                "apiVersion": "v1",
                "kind": "Node",
                "metadata": {
                    "name": node_name,
                    "labels": node_label
                }
            }
            for node_name, node_label in (self.nodes_label or {}).items()
        ]
        try:
            SafeFileHandler.safe_write(label_yaml, json.dumps(nodes), 'w')
        except Exception as e:
            self.module.fail_json(msg='failed to write label file: {}, error: {}'.format(label_yaml, str(e)),
                                  changed=False, rc=1)
        self.module.exit_json(changed=True, msg='{} successfully'.format(self.step), ansible_facts=self.facts)

    def label_node(self):
        """Merge every group's label file and apply the result with kubectl.

        Reads ``<label dir>/groups/<group>/label_node.json`` for each group
        directory, dumps them all into ``label_nodes.yaml`` and runs
        ``kubectl apply -f label_nodes.yaml``.

        When ``group_count`` is given and only some (but not zero) of the
        groups have delivered a file yet, the step exits without applying
        anything. When ``group_count`` is not given this check is skipped.
        """
        groups_json_dir = os.path.join(self.label_yaml_dir, "groups")
        try:
            # Sorted so the generated yaml does not depend on directory order.
            groups_json = sorted(os.listdir(groups_json_dir))
        except OSError as e:
            self.module.fail_json(msg="Failed to list {}: {}".format(groups_json_dir, str(e)),
                                  changed=False, rc=1)
        if self.sub_groups is not None and 0 < len(groups_json) < self.sub_groups:
            self.module.exit_json(changed=False, rc=0, msg="Did not get all labels, skipped.")
        label_yaml = os.path.join(self.label_yaml_dir, "label_nodes.yaml")
        try:
            with open(label_yaml, 'w') as write_file:
                for groups_name in groups_json:
                    group_json_path = os.path.join(groups_json_dir, groups_name, "label_node.json")
                    try:
                        with open(group_json_path, 'r') as read_file:
                            data = json.load(read_file)
                    except (IOError, ValueError) as e:
                        # ValueError covers malformed JSON (json.JSONDecodeError).
                        self.module.fail_json(msg="Failed to read or parse {}: {}".format(group_json_path, str(e)),
                                              changed=False, rc=1)
                    # Kept outside the read handler so a write error is
                    # reported as a write error by the outer handler.
                    common_utils.dump_all_to_yaml(data, write_file)
        except IOError as e:
            self.module.fail_json(msg="Failed to write to {}: {}".format(label_yaml, str(e)), changed=False, rc=1)
        # An argv list keeps a path containing whitespace as one argument.
        cmd = ['kubectl', 'apply', '-f', label_yaml]
        self.module.run_command(cmd, check_rc=True)
        self.module.log(msg='apply yaml: {} for label nodes'.format(label_yaml))
        self.module.exit_json(changed=True, msg='{} successfully'.format(self.step), ansible_facts=self.facts)

    def run(self):
        """Dispatch to the handler selected by ``step``; fail on an unknown step."""
        steps = {
            'get_label': self.get_labels,
            'save_label': self.save_labels,
            'label': self.label_node
        }
        if self.step not in steps:
            self.module.fail_json(msg='invalid step: {}, choose from {}'.format(self.step, list(steps)))
        steps[self.step]()
def main():
    """Script entry point: build the module object and run the requested step."""
    LabelNode().run()


if __name__ == '__main__':
    main()