import os
import sys
import threading
import subprocess
import re
from itertools import combinations
from logging import getLogger
from megatron.training import get_args
from mindspeed.core.multi_modal.dist_train.dist_parallel_state import get_tensor_model_parallel_world_size
from mindspeed.log_config import log_rank_0
from mindspeed.core.qos.domain_info import domains, generate_masked_orthogonal_rank_groups, \
get_tensor_parallel_comm_domain, get_pipeline_parallel_comm_domain, get_data_parallel_comm_domain, \
get_context_parallel_comm_domain, get_expert_parallel_comm_domain, get_overlap_time_dict, get_overlap_space_dict, \
is_cross_boundary, is_a3_version
LOG = getLogger()
_DEFAULT_QOS = 4
_DEFAULT_QOS_SDMA_LOW = os.environ.get('QOS_SDMA_LOW', 2)
_DEFAULT_QOS_SDMA_MIDDLE = os.environ.get('QOS_SDMA_MIDDLE', 4)
_DEFAULT_QOS_SDMA_HIGH = os.environ.get('QOS_SDMA_HIGH', 6)
_DEFAULT_QOS_ROCE_LOW = os.environ.get('QOS_ROCE_LOW', 3)
_DEFAULT_QOS_ROCE_MIDDLE = os.environ.get('QOS_ROCE_MIDDLE', 4)
_DEFAULT_QOS_ROCE_HIGH = os.environ.get('QOS_ROCE_HIGH', 5)
sdma_qos_str_to_value = {
'low': _DEFAULT_QOS_SDMA_LOW,
'middle': _DEFAULT_QOS_SDMA_MIDDLE,
'high': _DEFAULT_QOS_SDMA_HIGH
}
roce_qos_str_to_value = {
'low': _DEFAULT_QOS_ROCE_LOW,
'middle': _DEFAULT_QOS_ROCE_MIDDLE,
'high': _DEFAULT_QOS_ROCE_HIGH
}
_PARALLEL_TYPES = [
'dp',
'dp-cp',
'intra-dp-cp',
'inter-dp-cp',
'cp',
'mp',
'tp',
'pp',
'embd',
'pos-embd',
'tp-dp-cp',
'tp-dp',
'tp-cp',
'ep',
'ep-tp',
'tp-ep-mp',
'tp-ep-pp',
'ep-dp',
'hcp'
]
domains = ('tp', 'dp', 'pp', 'ep', 'cp')
class Qos:
_instance = None
_lock = threading.Lock()
_initialize = False
def __new__(cls, *args, **kwargs):
if cls._instance is None:
with cls._lock:
cls._instance = super(Qos, cls).__new__(cls)
return cls._instance
def __init__(self, sdma_queue_list=None, roce_queue_list=None):
if Qos._initialize:
return
if sdma_queue_list is None:
self.sdma_queue_list = [_DEFAULT_QOS_SDMA_LOW, _DEFAULT_QOS_SDMA_MIDDLE,
_DEFAULT_QOS_SDMA_HIGH]
if roce_queue_list is None:
self.roce_queue_list = [_DEFAULT_QOS_ROCE_LOW, _DEFAULT_QOS_ROCE_MIDDLE,
_DEFAULT_QOS_ROCE_HIGH]
else:
self.sdma_queue_list = sdma_queue_list
self.roce_queue_list = roce_queue_list
self.args = get_args()
self.aiqos_mode = self.args.aiqos_mode if hasattr(self.args,
'aiqos_mode') and self.args.aiqos_mode is not None else "auto"
if self.aiqos_mode.lower() not in ['auto', 'manual']:
raise ValueError('aiqos mode must be "auto or manual"')
self.roce_aiqos_schedule = {}
self.sdma_aiqos_schedule = {}
self.init_qos()
Qos._initialize = True
def init_qos(self):
def parse_args(arg_str, target_dict):
if arg_str is None or arg_str.strip() == "":
raise ValueError("aiqos-schedule parameter cannot be empty")
clean_str = arg_str.strip()
valid_values = r'(high|low|middle)'
pattern = fr'^\{{\s*([a-zA-Z0-9_-]+:\s*{valid_values}\s*(,\s*[a-zA-Z0-9_-]+:\s*{valid_values}\s*)*)?\}}$'
if not re.match(pattern, clean_str, re.IGNORECASE):
raise ValueError(
f"Invalid aiqos-schedule format. Only {{key:value,key:value,...}} is supported (value must be high/low/middle). Current input: {arg_str}\n"
"Correct example: {tp:high,pp:low,dp:middle}"
)
clean_str = clean_str.lstrip('{').rstrip('}')
items = [item.strip() for item in clean_str.split(',') if item.strip()]
allowed_values = {'high', 'low', 'middle'}
for item in items:
if item.count(':') != 1:
raise ValueError(f"Invalid key-value pair format. Must be key:value. Current: {item}")
parallel_type, value = item.split(':', 1)
parallel_type = parallel_type.strip().lower()
value_str = value.strip().lower()
if value_str not in allowed_values:
raise ValueError(
f"Invalid value for {parallel_type}. Only 'high', 'low', 'middle' are allowed. Current: {value_str}"
)
target_dict[parallel_type] = value_str
if self.aiqos_mode.lower() == 'manual':
parse_args(self.args.aiqos_schedule, self.roce_aiqos_schedule)
parse_args(self.args.aiqos_schedule, self.sdma_aiqos_schedule)
for key, priority_str in self.roce_aiqos_schedule.items():
priority_str_lower = priority_str.strip().lower()
if priority_str_lower not in roce_qos_str_to_value:
raise ValueError(
f"Invalid QoS priority string: {priority_str}, only 'high'/'low'/'middle' are allowed")
self.roce_aiqos_schedule[key] = roce_qos_str_to_value[priority_str_lower]
for key, priority_str in self.sdma_aiqos_schedule.items():
priority_str_lower = priority_str.strip().lower()
if priority_str_lower not in sdma_qos_str_to_value:
raise ValueError(
f"Invalid QoS priority string: {priority_str}, only 'high'/'low'/'middle' are allowed")
self.sdma_aiqos_schedule[key] = sdma_qos_str_to_value[priority_str_lower]
elif self.aiqos_mode.lower() == 'auto':
self.cal_auto_qos()
self.init_domain_qos_schedule_rules()
log_rank_0(LOG.info, f'qos roce schedule: {self.roce_aiqos_schedule}')
if is_a3_version:
log_rank_0(LOG.info, f'qos sdma schedule: {self.sdma_aiqos_schedule}')
def set_parallel_roce_qos(self, parallel_type):
if parallel_type is None:
return _DEFAULT_QOS
if not self.args.aiqos_enable_roce:
return _DEFAULT_QOS
if parallel_type.lower() not in _PARALLEL_TYPES or parallel_type.lower() not in self.roce_aiqos_schedule:
return _DEFAULT_QOS
return self.roce_aiqos_schedule[parallel_type.lower()]
def set_parallel_sdma_qos(self, parallel_type):
if parallel_type is None:
return _DEFAULT_QOS
if parallel_type.lower() not in _PARALLEL_TYPES or parallel_type.lower() not in self.sdma_aiqos_schedule:
return _DEFAULT_QOS
return self.sdma_aiqos_schedule[parallel_type.lower()]
def cal_auto_qos(self):
parallel_comm_domain_list = [get_tensor_parallel_comm_domain(), get_data_parallel_comm_domain(),
get_pipeline_parallel_comm_domain(), get_expert_parallel_comm_domain(),
get_context_parallel_comm_domain(),
]
domain_partition_information = {
key: domain.rank_list
for key, domain in zip(domains, parallel_comm_domain_list)
}
if self.args.aiqos_enable_roce:
roce_qos_res = self.combination(parallel_comm_domain_list, domain_partition_information, link_type="ROCE")
for parallel_type, qos in roce_qos_res.items():
self.roce_aiqos_schedule[parallel_type] = qos
sdma_qos_res = self.combination(parallel_comm_domain_list, domain_partition_information, link_type="SDMA")
for parallel_type, qos in sdma_qos_res.items():
self.sdma_aiqos_schedule[parallel_type] = qos
return
def combination(self, parallel_comm_domain_list=None, domain_partition_information=None, link_type="SDMA"):
if parallel_comm_domain_list is None or domain_partition_information is None:
raise ValueError("parallel_comm_domain_list or domain_partition_information is None")
domain_nums = len(parallel_comm_domain_list)
if link_type == "ROCE":
queue_nums = len(self.roce_queue_list)
elif link_type == "SDMA":
queue_nums = len(self.sdma_queue_list)
time_overlap = get_overlap_time_dict()
space_overlap = get_overlap_space_dict(domain_partition_information, link_type=link_type)
comb = generate_distributions(domain_nums, queue_nums)
min_single_comb = comb[0]
degree = sys.maxsize
for each_comb in comb:
cur_degree = cal_conflict_degree(each_comb, parallel_comm_domain_list, time_overlap, space_overlap)
if cur_degree < degree:
degree = cur_degree
min_single_comb = each_comb
min_single_comb_log_info = [[domains[idx] for idx in num_list] for num_list in min_single_comb]
log_rank_0(LOG.info, f'{link_type} min_single_comb: {min_single_comb_log_info}')
return self.auto_qos_priority(min_single_comb, parallel_comm_domain_list, link_type=link_type)
def auto_qos_priority(self, min_single_comb, parallel_comm_domain_list, link_type="SDMA"):
rate = [0] * len(min_single_comb)
for each_queue in min_single_comb:
sum_comm_amount = 0
sum_comm_amount_no_overlap = 0
for each_flow in each_queue:
sum_comm_amount += parallel_comm_domain_list[each_flow].comm_amount
sum_comm_amount_no_overlap += parallel_comm_domain_list[each_flow].comm_amount_no_overlap
if sum_comm_amount == 0:
continue
else:
cur_rate = sum_comm_amount_no_overlap / sum_comm_amount
rate[min_single_comb.index(each_queue)] = cur_rate
sorted_rate = dict(sorted(zip(rate, min_single_comb), key=lambda x: x[0], reverse=True))
length = len(sorted_rate)
qos_res = {}
queue_index = 1 if length <= 2 else 0
q_list = self.sdma_queue_list if link_type == "SDMA" else self.roce_queue_list
if not q_list:
raise ValueError(f"Queue list is empty for link_type '{link_type}'. "
+ f"Please ensure self.{link_type.lower()}_queue_list is initialized.")
for value in sorted_rate.values():
for flow in value:
qos_res[domains[flow]] = q_list[queue_index]
queue_index += 1
return qos_res
def init_domain_qos_schedule_rules(self):
if self.args.num_experts is None:
self.sdma_aiqos_schedule['dp-cp'] = self.sdma_aiqos_schedule['dp']
self.sdma_aiqos_schedule['mp'] = self.sdma_aiqos_schedule['pp']
else:
self.sdma_aiqos_schedule['dp-cp'] = self.sdma_aiqos_schedule['dp']
self.sdma_aiqos_schedule['ep-dp'] = self.sdma_aiqos_schedule['dp']
self.sdma_aiqos_schedule['mp'] = self.sdma_aiqos_schedule['pp']
self.sdma_aiqos_schedule['tp-ep-pp'] = self.sdma_aiqos_schedule['pp']
self.sdma_aiqos_schedule['tp-ep-mp'] = max(self.sdma_aiqos_schedule['pp'], self.sdma_aiqos_schedule['tp'])
if self.args.aiqos_enable_roce:
if self.args.num_experts is None:
self.roce_aiqos_schedule['dp-cp'] = self.roce_aiqos_schedule['dp']
self.roce_aiqos_schedule['mp'] = self.roce_aiqos_schedule['pp']
else:
self.roce_aiqos_schedule['dp-cp'] = self.roce_aiqos_schedule['dp']
self.roce_aiqos_schedule['ep-dp'] = self.roce_aiqos_schedule['dp']
self.roce_aiqos_schedule['mp'] = self.roce_aiqos_schedule['pp']
self.roce_aiqos_schedule['tp-ep-pp'] = self.roce_aiqos_schedule['pp']
self.roce_aiqos_schedule['tp-ep-mp'] = max(self.roce_aiqos_schedule['pp'],
self.roce_aiqos_schedule['tp'])
def generate_distributions(m, n):
if m < 0 or n < 0:
return []
results = []
groups = []
def dfs(idx):
if idx == m:
results.append([group[:] for group in groups])
return
for group in groups:
if group and group[-1] < idx:
group.append(idx)
dfs(idx + 1)
group.pop()
if len(groups) < n:
groups.append([idx])
dfs(idx + 1)
groups.pop()
dfs(0)
return results
def cal_conflict_degree(single_comb_info, parallel_comm_domain_list, time_overlap, space_overlap):
degree = 0
for each_queue in single_comb_info:
for i, elem_i in enumerate(each_queue):
for j, elem_j in enumerate(each_queue[i + 1:], start=i + 1):
key = (domains[elem_i], domains[elem_j])
conflict_state = time_overlap[key] * space_overlap[key]
if conflict_state == 1:
degree += min(
parallel_comm_domain_list[i].comm_amount,
parallel_comm_domain_list[j].comm_amount
)
return degree