import os
import platform
import re
from ansible.module_utils.check_output_manager import check_event
from ansible.module_utils.check_utils import CheckUtil as util
from ansible.module_utils.common_utils import result_handler, compare_version
from ansible.module_utils.deepseek_pd.config_info import MAX_SEQ_LEN_DICT, STATIC_CONFIG_DICT_A910_93, \
STATIC_CONFIG_DICT_A910B
A910_93_COMBINATIONS = [
(4, 1, 2, 2),
(8, 1, 2, 4),
(16, 1, 4, 4),
(24, 1, 6, 4),
]
A910B_COMBINATIONS = [
(2, 2, 1, 4),
(4, 2, 2, 4),
(4, 2, 1, 8),
]
A910_93_BASE_COMBINATION = (24, 1, 6, 4)
A910B_BASE_COMBINATION = (4, 2, 1, 8)
KUBERNETES_NAMESPACE_MAX_LENGTH = 63
_DOCKER = "docker"
_CONTAINERD = "containerd"
_IMAGES = "images"
_CTR = "ctr"
class DeepseekDpCheck:
def __init__(self, module, error_messages):
self.module = module
self.tags = list(filter(bool, self.module.params.get('tags', [])))
self.model_name = self.module.params["model_name"]
self.weight_mount_path = self.module.params["weight_mount_path"]
self.model_weight_path = self.module.params["model_weight_path"]
self.messages = []
self.mindie_image_name = self.module.params["mindie_image_name"]
self.mindie_image_file = self.module.params["mindie_image_file"]
self.p_instances_num = self.module.params["p_instances_num"]
self.d_instances_num = self.module.params["d_instances_num"]
self.single_p_instance_pod_num = self.module.params["single_p_instance_pod_num"]
self.single_d_instance_pod_num = self.module.params["single_d_instance_pod_num"]
self.expert_map_file = self.module.params["expert_map_file"]
self.job_id = self.module.params["job_id"]
self.max_seq_len = self.module.params["max_seq_len"]
self.mindie_host_log_path = self.module.params["mindie_host_log_path"]
self.job_id = self.module.params["job_id"]
self.python_version = module.params.get("python_version")
self.arch = platform.machine()
self.tls_config = self.module.params["tls_config"]
self.npu_info = module.params["npu_info"]
self.cluster_info = module.params.get("cluster_info")
self.container_runtime_type = self.module.params['container_runtime_type']
self.worker_groups = self.module.params.get("worker_groups")
self.master_groups = self.module.params.get("master_groups")
self.facts = dict()
self.error_messages = error_messages
seq_len_scene = MAX_SEQ_LEN_DICT.get(self.max_seq_len)
if self.npu_info.get("scene") == "a910_93":
self.static_config = STATIC_CONFIG_DICT_A910_93.get(seq_len_scene)
else:
self.static_config = STATIC_CONFIG_DICT_A910B.get(seq_len_scene)
@staticmethod
def match_image_name(line, target_image_name):
"""
匹配镜像名称,支持带前缀和不带前缀的情况
Args:
line (str): ctr images list 输出的一行
target_image_name (str): 目标镜像名称
Returns:
bool: 是否匹配
"""
if not line.strip():
return False
image_name_in_line = line.split()[0]
if image_name_in_line == target_image_name:
return True
if '/' in image_name_in_line and image_name_in_line.split('/')[-1] == target_image_name:
return True
return False
def check_mount_path(self):
"""
检查是否提供了挂载路径和路径是否实际存在
"""
paths_to_check = {
'weight_mount_path': self.weight_mount_path,
'model_weight_path': self.model_weight_path
}
for path_name, path_value in paths_to_check.items():
if not path_value:
util.record_error('[ASCEND][ERROR] Please provide a value for the {} parameter.'.format(path_name),
self.error_messages)
return
if not os.path.exists(path_value):
util.record_error('[ASCEND][ERROR] {}: {} is not existed.'.format(path_name, path_value),
self.error_messages)
return
if os.path.islink(path_value):
util.record_error("[ASCEND][ERROR] The specified {} '{}' should not be a symbolic link.".format(
path_name, path_value), self.error_messages)
return
mount_path = os.path.abspath(self.weight_mount_path)
model_path = os.path.abspath(self.model_weight_path)
if not model_path.startswith(mount_path):
util.record_error(
'[ASCEND][ERROR] The model_weight_path must be under the weight_mount_path directory.',
self.error_messages)
def check_image_name_and_file(self):
if not self.mindie_image_name and not self.mindie_image_file:
return
if self.mindie_image_file:
if not os.path.exists(self.mindie_image_file):
util.record_error("[ASCEND][ERROR] The specified mindie_image_file '{}' does not exist.".format(
self.mindie_image_file), self.error_messages)
if os.path.islink(self.mindie_image_file):
util.record_error(
"[ASCEND][ERROR] The specified mindie_image_file '{}' should not be a symbolic link.".format(
self.mindie_image_file), self.error_messages)
if self.mindie_image_name:
if not isinstance(self.mindie_image_name, str):
util.record_error("[ASCEND][ERROR] The mindie_image_name parameter must be a string, got {}.".format(
type(self.mindie_image_name).__name__), self.error_messages)
return
if ':' not in self.mindie_image_name:
util.record_error(
"[ASCEND][ERROR] The mindie_image_name '{}' must include a tag. "
"Valid format example: mindie:2.1.RC1-xx-xx-py311-ubuntu22.04-aarch64".format(
self.mindie_image_name), self.error_messages)
return
self.check_image_exists_in_registry()
def check_image_exists_in_registry(self):
"""
通过查询镜像是否存在
"""
if not self.container_runtime_type:
util.record_error(
"[ASCEND][ERROR] Query container_runtime_type failed.",
self.error_messages)
return
if _DOCKER in list(self.container_runtime_type.values())[0]:
container_runtime = _DOCKER
elif _CONTAINERD in list(self.container_runtime_type.values())[0]:
container_runtime = _CONTAINERD
else:
util.record_error(
"[ASCEND][ERROR] Invalid container runtime type. ", self.error_messages)
return
if container_runtime == _DOCKER:
docker_exists = self.module.get_bin_path(_DOCKER)
if not docker_exists:
util.record_error(
"[ASCEND][ERROR] Docker not found. Image '{}' not verified locally.".format(
self.mindie_image_name), self.error_messages)
return
check_cmd = ["docker", _IMAGES, "--format", "{{.Repository}}:{{.Tag}}"]
rc, out, err = self.module.run_command(check_cmd)
if rc != 0:
util.record_error(
"[ASCEND][ERROR] Failed to list {} images. Command '{}' failed with return code {}. "
"Error message: {}".format(container_runtime, " ".join(check_cmd), rc, err.strip()),
self.error_messages)
return
image_lines = out.splitlines()
if self.mindie_image_name in image_lines:
return
else:
ctr_exists = self.module.get_bin_path(_CTR)
if not ctr_exists:
util.record_error(
"[ASCEND][ERROR] Containerd not found. Image '{}' not verified locally.".format(
self.mindie_image_name), self.error_messages)
return
check_cmd_default = [_CTR, _IMAGES, "list"]
rc, out, err = self.module.run_command(check_cmd_default)
if rc == 0:
image_lines = out.splitlines()
if any(self.match_image_name(line, self.mindie_image_name) for line in image_lines if line.strip()):
return
check_cmd_k8s = [_CTR, "-n", "k8s.io", _IMAGES, "list"]
rc, out, err = self.module.run_command(check_cmd_k8s)
if rc != 0:
util.record_error(
"[ASCEND][ERROR] Failed to list {} images. Command '{}' failed with return code {}. "
"Error message: {}".format(container_runtime, " ".join(check_cmd_k8s), rc, err.strip()),
self.error_messages)
return
image_lines = out.splitlines()
if any(self.match_image_name(line, self.mindie_image_name) for line in image_lines if line.strip()):
return
util.record_error(
"[ASCEND][ERROR] mindie_image_name: '{}' not found locally.".format(self.mindie_image_name),
self.error_messages)
def check_expert_map_file(self):
if not self.expert_map_file:
return
if not os.path.exists(self.expert_map_file):
util.record_error("[ASCEND][ERROR] expert_map_file: {} not existed.".format(self.expert_map_file),
self.error_messages)
if os.path.islink(self.expert_map_file):
util.record_error(
"[ASCEND][ERROR] The specified expert_map_file '{}' should not be a symbolic link.".format(
self.expert_map_file), self.error_messages)
def check_job_id(self):
"""
job_id用来创建kubernetes的namespace, 因此要符合kubernetes命名规范
"""
if not self.job_id:
util.record_error("[ASCEND][ERROR] Please provide a value for the job_id parameter.",
self.error_messages)
if len(self.job_id) > KUBERNETES_NAMESPACE_MAX_LENGTH:
util.record_error("[ASCEND][ERROR] {} length should not exceed 63 characters.".format(self.job_id),
self.error_messages)
return
pattern = r'^[a-z0-9]([a-z0-9\-]{0,61}[a-z0-9])?$'
if not re.match(pattern, self.job_id):
util.record_error("[ASCEND][ERROR] job_id must follow Kubernetes naming convention: "
"contain only lowercase letters, numbers, and hyphens, start and end with "
"alphanumeric characters.", self.error_messages)
def check_model_name(self):
if not self.model_name:
return
if not isinstance(self.model_name, str):
util.record_error("[ASCEND][ERROR] The model_name parameter must be a string, got {}.".format(
type(self.model_name).__name__), self.error_messages)
def check_positive_integer(self, param_name, param_value):
"""
检查参数是否为正整数
Args:
param_name (str): 参数名称
param_value: 参数值
Returns:
bool: 检查是否通过
"""
if not param_value:
util.record_error("[ASCEND][ERROR] Please provide a value for the {} parameter.".format(param_name),
self.error_messages)
return False
param_int_value = None
if isinstance(param_value, int):
param_int_value = param_value
elif isinstance(param_value, str):
try:
param_int_value = int(param_value)
except (ValueError, TypeError) as e:
util.record_error("[ASCEND][ERROR] The {} parameter must be an integer, got '{}'.".format(
param_name, param_value), self.error_messages)
return False
else:
util.record_error("[ASCEND][ERROR] The {} parameter must be an integer, got {}.".format(
param_name, type(param_value).__name__), self.error_messages)
return False
if param_int_value is not None and param_int_value <= 0:
util.record_error("[ASCEND][ERROR] The {} parameter must be a positive integer, got {}.".format(
param_name, param_int_value), self.error_messages)
return False
return True
def check_p_and_d_params(self):
"""
检查 P 和 D 实例相关参数
需要确保 p_instances_num, d_instances_num, single_p_instance_pod_num, single_d_instance_pod_num
都存在且为正整数类型
"""
params_to_check = {
'p_instances_num': self.p_instances_num,
'd_instances_num': self.d_instances_num,
'single_p_instance_pod_num': self.single_p_instance_pod_num,
'single_d_instance_pod_num': self.single_d_instance_pod_num
}
for param_name, param_value in params_to_check.items():
self.check_positive_integer(param_name, param_value)
if not self._check_single_d_instance_pod_num():
return
scene = self.npu_info.get("scene", "")
if scene == "a910b":
self._check_a910b_combinations()
elif scene == "a910_93":
self._check_a910_93_combinations()
def check_cluster_node_ready_count(self):
"""
检查集群中处于Ready状态的节点数量是否满足P和D实例的总需求
需要满足条件:Ready节点数 >= (p_instances_num * single_p_instance_pod_num) +
(d_instances_num * single_d_instance_pod_num)
"""
has_all_required_params = (self.p_instances_num and self.single_p_instance_pod_num and
self.d_instances_num and self.single_d_instance_pod_num)
if has_all_required_params:
required_nodes = (self.p_instances_num * self.single_p_instance_pod_num +
self.d_instances_num * self.single_d_instance_pod_num)
lines = self.cluster_info.get('stdout_lines', [])
if not lines:
util.record_error(
"[ASCEND][ERROR] Cannot get cluster node information, skip node count check.",
self.error_messages)
return
ready_node_count = 0
for line in lines[1:] if len(lines) > 1 else lines:
fields = line.split()
if len(fields) >= 2 and fields[1].lower() == 'ready':
ready_node_count += 1
if ready_node_count < required_nodes:
util.record_error(
"[ASCEND][ERROR] Not enough ready nodes in the cluster. "
"Required: {} nodes, Available: {} ready nodes.".format(required_nodes, ready_node_count),
self.error_messages)
def _check_single_d_instance_pod_num(self):
"""检查 single_d_instance_pod_num 参数是否有效"""
if not self.static_config:
util.record_error("[ASCEND][ERROR] The static_config for max_seq_len:{} is none".format(self.max_seq_len),
self.error_messages)
return False
if self.single_d_instance_pod_num not in self.static_config["decode"].moe_ep:
valid_keys = list(self.static_config["decode"].moe_ep.keys())
util.record_error(
"[ASCEND][ERROR] The single_d_instance_pod_num:{} parameter must be one of {}".format(
self.single_d_instance_pod_num, sorted(valid_keys)), self.error_messages)
return False
return True
def _check_a910b_combinations(self):
"""检查 a910b 场景的参数组合"""
combination = (self.p_instances_num, self.single_p_instance_pod_num,
self.d_instances_num, self.single_d_instance_pod_num)
if combination in A910B_COMBINATIONS:
return
util.record_error(
"[ASCEND][ERROR] On this device, the combination of parameters must be one of: "
"2*2P+1*4D, 4*2P+2*4D, 4*2P+1*8D", self.error_messages)
def _check_a910_93_combinations(self):
"""检查 a910_93 场景的参数组合"""
combination = (self.p_instances_num, self.single_p_instance_pod_num,
self.d_instances_num, self.single_d_instance_pod_num)
if combination in A910_93_COMBINATIONS:
return
util.record_error(
"[ASCEND][ERROR] On this device, the combination of parameters must be one of: "
"4*1P+2*2D, 8*1P+2*4D, 16*1P+4*4D, 24*1P+6*4D", self.error_messages)
def check_max_seq_len(self):
if not self.max_seq_len:
util.record_error("[ASCEND][ERROR] Please provide a value for the max_seq_len parameter.")
return
self.check_positive_integer('max_seq_len', self.max_seq_len)
if self.max_seq_len not in MAX_SEQ_LEN_DICT:
util.record_error(
"[ASCEND][ERROR] The max_seq_len:{} parameter must be one of {}".format(self.max_seq_len, sorted(
MAX_SEQ_LEN_DICT.keys())), self.error_messages)
def check_mindie_host_log_path(self):
if not self.mindie_host_log_path:
return
if not os.path.exists(self.mindie_host_log_path):
util.record_error('[ASCEND][ERROR] mindie_host_log_path: {} not existed.'.format(self.mindie_host_log_path),
self.error_messages)
return
if not os.path.isdir(self.mindie_host_log_path):
util.record_error(
'[ASCEND][ERROR] mindie_host_log_path: {} is not a directory.'.format(self.mindie_host_log_path),
self.error_messages)
if os.path.islink(self.mindie_host_log_path):
util.record_error(
"[ASCEND][ERROR] The specified mindie_host_log_path '{}' should not be a symbolic link.".format(
self.mindie_host_log_path), self.error_messages)
def check_groups(self):
if not self.worker_groups or len(self.worker_groups) == 0:
util.record_error(
"[ASCEND][ERROR] Please provide at least one worker node.", self.error_messages)
if not self.master_groups or len(self.master_groups) == 0:
util.record_error(
"[ASCEND][ERROR] Please provide at least one master node.", self.error_messages)
@check_event
def check_deepseek_pd(self):
self.check_mount_path()
self.check_image_name_and_file()
self.check_expert_map_file()
self.check_job_id()
self.check_model_name()
self.check_p_and_d_params()
self.check_max_seq_len()
self.check_mindie_host_log_path()
self.check_cluster_node_ready_count()
self.check_groups()